Skip to content

Commit

Permalink
Create BasicConvLSTMCell.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Bei-Jin authored Aug 20, 2018
1 parent 39d382e commit 6f44c4b
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions code/BasicConvLSTMCell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Convolutional LSTM cell implementation from https://github.com/loliverhennigh
"""
import logging

import tensorflow as tf

class BasicConvLSTMCell(object):
    """Basic convolutional LSTM recurrent network cell.

    Gates are computed with a single shared convolution over the
    channel-concatenated ``[inputs, h]`` tensor (see `_conv_linear`),
    instead of the matrix multiply of a plain LSTM.
    """

    def __init__(self, shape, filter_size, num_features, forget_bias=1.0,
                 input_size=None, state_is_tuple=False, activation=tf.nn.tanh):
        """Initialize the basic Conv LSTM cell.

        Args:
            shape: int tuple, the (height, width) of the cell.
            filter_size: int tuple, the (height, width) of the conv filter.
            num_features: int, the channel depth of the cell state.
            forget_bias: float, bias added to the forget gate so the cell
                does not forget everything at the start of training.
            input_size: Deprecated and unused.
            state_is_tuple: If True, accepted and returned states are
                2-tuples of `c_state` and `m_state`. If False, they are
                concatenated along the channel axis (axis 3).
            activation: Activation function of the inner states.
        """
        if input_size is not None:
            # Fixed: `logging` was referenced without being imported, and
            # `warn` is deprecated in favor of `warning`.
            logging.warning("%s: The input_size parameter is deprecated.", self)
        self.shape = shape
        self.filter_size = filter_size
        self.num_features = num_features
        # Fixed: state_size/output_size read self._num_units, which was
        # never assigned (AttributeError). For this cell the per-position
        # state depth is num_features.
        self._num_units = num_features
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation

    @property
    def state_size(self):
        # Fixed: LSTMStateTuple was an undefined bare name; use the one
        # from tf.nn.rnn_cell.
        return (tf.nn.rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None, reuse=False):
        """Run one step of the conv LSTM.

        Args:
            inputs: 4D input tensor, [batch, height, width, channels].
            state: either an (c, h) tuple or a single tensor with c and h
                concatenated on axis 3, matching `state_is_tuple`.
            scope: VariableScope for the created subgraph;
                defaults to the class name.
            reuse: whether to reuse already-created variables.

        Returns:
            A (new_h, new_state) pair; new_state matches the input
            state format.
        """
        with tf.variable_scope(scope or type(self).__name__, reuse=reuse):
            if self._state_is_tuple:
                c, h = state
            else:
                # States were packed along the channel axis; split them back.
                c, h = tf.split(axis=3, num_or_size_splits=2, value=state)
            # Parameters of all four gates are concatenated into one
            # convolution for efficiency.
            concat = _conv_linear([inputs, h], self.filter_size,
                                  self.num_features * 4, True)
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=concat)

            new_c = (c * tf.nn.sigmoid(f + self._forget_bias) +
                     tf.nn.sigmoid(i) * self._activation(j))
            new_h = self._activation(new_c) * tf.nn.sigmoid(o)

            if self._state_is_tuple:
                # Fixed: LSTMStateTuple was an undefined bare name here too.
                new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
            else:
                new_state = tf.concat(axis=3, values=[new_c, new_h])
            return new_h, new_state

def _conv_linear(args, filter_size, num_features, bias,
                 bias_start=0.0, scope=None, reuse=False):
    """Convolutional analogue of a linear (fully-connected) layer.

    Concatenates `args` on the channel axis and applies a single SAME-padded
    stride-1 convolution, optionally adding a bias.

    Args:
        args: a 4D Tensor or a list of 4D [batch, h, w, c] Tensors.
        filter_size: int tuple of filter height and width.
        num_features: int, number of output features (channels).
        bias: boolean, whether to add a bias term.
        bias_start: starting value to initialize the bias; 0 by default.
        scope: VariableScope for the created subgraph; defaults to "Conv".
        reuse: for reusing already existing weights.

    Returns:
        A 4D Tensor with shape [batch, h, w, num_features].

    Raises:
        ValueError: if some of the arguments has unspecified or wrong shape.
    """

    # Calculate the total size of arguments on the channel dimension.
    total_arg_size_depth = 0
    shapes = [a.get_shape().as_list() for a in args]
    for shape in shapes:
        if len(shape) != 4:
            raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes))
        if not shape[3]:
            # Fixed: message previously said shape[4] although the code
            # checks shape[3] (the channel dimension).
            raise ValueError("Linear expects shape[3] of arguments: %s" % str(shapes))
        else:
            total_arg_size_depth += shape[3]

    dtype = [a.dtype for a in args][0]

    # Now the computation.
    with tf.variable_scope(scope or "Conv", reuse=reuse):
        matrix = tf.get_variable(
            "Matrix", [filter_size[0], filter_size[1],
                       total_arg_size_depth, num_features], dtype=dtype)
        if len(args) == 1:
            res = tf.nn.conv2d(args[0], matrix, strides=[1, 1, 1, 1], padding='SAME')
        else:
            # Single conv over the concatenated inputs is equivalent to (and
            # cheaper than) separate convs summed together.
            res = tf.nn.conv2d(tf.concat(axis=3, values=args), matrix,
                               strides=[1, 1, 1, 1], padding='SAME')
        if not bias:
            return res
        bias_term = tf.get_variable(
            "Bias", [num_features],
            dtype=dtype, initializer=tf.constant_initializer(bias_start,
                                                             dtype=dtype)
        )
        return res + bias_term

0 comments on commit 6f44c4b

Please sign in to comment.