Commit 33c4c26
Add docs to the function

titu1994 committed May 31, 2018
1 parent 5bd9c2e
Showing 3 changed files with 68 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ from non_local import non_local_block

ip = Input(shape=(##)) # this can be a rank 3, 4 or 5 tensor shape
x = ConvND(...) # again, this can be Conv1D, Conv2D or Conv3D
-x = non_local_block(x, computation_compression=2, mode='embedded')
+x = non_local_block(x, compression=2, mode='embedded')
...

```
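A fuller usage sketch (the layer sizes here are hypothetical; assumes Keras 2.x, matching the imports used in this repo):

```python
from keras.layers import Input, Conv2D
from keras.models import Model

from non_local import non_local_block

ip = Input(shape=(32, 32, 3))               # rank 4 input: a 32x32 RGB image
x = Conv2D(64, (3, 3), padding='same')(ip)  # 64-channel feature map
x = non_local_block(x, compression=2, mode='embedded')
model = Model(ip, x)
model.summary()
```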
95 changes: 65 additions & 30 deletions non_local.py
@@ -1,10 +1,32 @@
-from keras.layers import Activation, Reshape, Lambda, concatenate, dot, add
+from keras.layers import Activation, Reshape, Lambda, dot, add
from keras.layers import Conv1D, Conv2D, Conv3D
from keras.layers import MaxPool1D
from keras import backend as K


-def non_local_block(ip, computation_compression=2, mode='embedded'):
+def non_local_block(ip, intermediate_dim=None, compression=2,
+                    mode='embedded', add_residual=True):
+    """
+    Adds a Non-Local block for self attention to the input tensor.
+    Input tensor can be of rank 3 (temporal), 4 (spatial) or 5 (spatio-temporal).
+    Arguments:
+        ip: input tensor
+        intermediate_dim: the dimension of the intermediate representation. Can be
+            `None` or a positive integer. If `None`, computes the intermediate
+            dimension as half of the input channel dimension.
+        compression: compresses the intermediate representation during
+            the dot products to reduce memory consumption. Default is 2,
+            which halves the time/space/spatio-temporal dimension for the
+            intermediate step. Set to 1 to prevent computation compression.
+        mode: mode of operation. Can be one of `embedded`, `gaussian`, `dot` or
+            `concatenate`.
+        add_residual: Boolean deciding whether the residual connection should be
+            added. Default is True for ResNets, and False for self attention.
+    Returns:
+        a tensor of the same shape as the input
+    """
    channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
    ip_shape = K.int_shape(ip)

@@ -13,19 +35,20 @@ def non_local_block(ip, computation_compression=2, mode='embedded'):

    dim1, dim2, dim3 = None, None, None

-    if len(ip_shape) == 3:  # time series data
+    # check rank and calculate the input shape
+    if len(ip_shape) == 3:  # temporal / time series data
        rank = 3
        batchsize, dim1, channels = ip_shape

-    elif len(ip_shape) == 4:  # image data
+    elif len(ip_shape) == 4:  # spatial / image data
        rank = 4

        if channel_dim == 1:
            batchsize, channels, dim1, dim2 = ip_shape
        else:
            batchsize, dim1, dim2, channels = ip_shape

-    elif len(ip_shape) == 5:  # Video / Voxel data
+    elif len(ip_shape) == 5:  # spatio-temporal / Video or Voxel data
        rank = 5

        if channel_dim == 1:
@@ -36,6 +59,19 @@ def non_local_block(ip, computation_compression=2, mode='embedded'):
    else:
        raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial) or 5 (spatio-temporal)')

+    # verify the intermediate dimension is specified correctly
+    if intermediate_dim is None:
+        intermediate_dim = channels // 2
+
+        if intermediate_dim < 1:
+            intermediate_dim = 1
+
+    else:
+        intermediate_dim = int(intermediate_dim)
+
+        if intermediate_dim < 1:
+            raise ValueError('`intermediate_dim` must be either `None` or a positive integer.')
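+    # e.g. with channels=64 and intermediate_dim=None this resolves to 32;
+    # with channels=1 it falls back to the minimum of 1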

    if mode == 'gaussian':  # Gaussian instantiation
        x1 = Reshape((-1, channels))(ip)  # xi
        x2 = Reshape((-1, channels))(ip)  # xj
@@ -44,37 +80,35 @@ def non_local_block(ip, computation_compression=2, mode='embedded'):

    elif mode == 'dot':  # Dot instantiation
        # theta path
-        theta = _convND(ip, rank, channels // 2)
-        theta = Reshape((-1, channels // 2))(theta)
+        theta = _convND(ip, rank, intermediate_dim)
+        theta = Reshape((-1, intermediate_dim))(theta)

        # phi path
-        phi = _convND(ip, rank, channels // 2)
-        phi = Reshape((-1, channels // 2))(phi)
+        phi = _convND(ip, rank, intermediate_dim)
+        phi = Reshape((-1, intermediate_dim))(phi)

        f = dot([theta, phi], axes=2)

-        # scale the values to make it size invariant
-        if batchsize is not None:
-            f = Lambda(lambda z: 1. / batchsize * z)(f)
-        else:
-            f = Lambda(lambda z: 1. / 128 * z)(f)
+        size = K.int_shape(f)
+
+        # scale the values to make it size invariant
+        f = Lambda(lambda z: (1. / float(size[-1])) * z)(f)
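+        # (appears to follow the 1/N normalization of the dot-product
+        # instantiation in Wang et al. 2018, with N = number of positions in phi)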

    elif mode == 'concatenate':  # Concatenation instantiation
-        raise NotImplemented('Concatenation mode has not been implemented yet')
+        raise NotImplementedError('Concatenate mode has not been implemented yet')

    else:  # Embedded Gaussian instantiation
        # theta path
-        theta = _convND(ip, rank, channels // 2)
-        theta = Reshape((-1, channels // 2))(theta)
+        theta = _convND(ip, rank, intermediate_dim)
+        theta = Reshape((-1, intermediate_dim))(theta)

        # phi path
-        phi = _convND(ip, rank, channels // 2)
-        phi = Reshape((-1, channels // 2))(phi)
+        phi = _convND(ip, rank, intermediate_dim)
+        phi = Reshape((-1, intermediate_dim))(phi)

-        if computation_compression > 1:
+        if compression > 1:
            # shielded computation (subsample phi to cut the dot-product cost)
-            phi = MaxPool1D(computation_compression)(phi)
+            phi = MaxPool1D(compression)(phi)

        f = dot([theta, phi], axes=2)
        f = Activation('softmax')(f)
@@ -83,34 +117,35 @@ def non_local_block(ip, computation_compression=2, mode='embedded'):
    g = _convND(ip, rank, intermediate_dim)
    g = Reshape((-1, intermediate_dim))(g)

-    if computation_compression > 1 and mode == 'embedded':
+    if compression > 1 and mode == 'embedded':
        # shielded computation
-        g = MaxPool1D(computation_compression)(g)
+        g = MaxPool1D(compression)(g)

    # compute output path
    y = dot([f, g], axes=[2, 1])

    # reshape to input tensor format
    if rank == 3:
-        y = Reshape((dim1, channels // 2))(y)
+        y = Reshape((dim1, intermediate_dim))(y)
    elif rank == 4:
        if channel_dim == -1:
-            y = Reshape((dim1, dim2, channels // 2))(y)
+            y = Reshape((dim1, dim2, intermediate_dim))(y)
        else:
-            y = Reshape((channels // 2, dim1, dim2))(y)
+            y = Reshape((intermediate_dim, dim1, dim2))(y)
    else:
        if channel_dim == -1:
-            y = Reshape((dim1, dim2, dim3, channels // 2))(y)
+            y = Reshape((dim1, dim2, dim3, intermediate_dim))(y)
        else:
-            y = Reshape((channels // 2, dim1, dim2, dim3))(y)
+            y = Reshape((intermediate_dim, dim1, dim2, dim3))(y)

    # project filters
    y = _convND(y, rank, channels)

    # residual connection
-    residual = add([ip, y])
+    if add_residual:
+        y = add([ip, y])

-    return residual
+    return y


def _convND(ip, rank, channels):
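The body of `_convND` is collapsed in this view. From its call sites above, it is presumably a pointwise (kernel size 1) N-dimensional convolution selected by rank; a minimal sketch under that assumption (the bias and padding settings are guesses, not taken from this commit):

```python
def _convND(ip, rank, channels):
    assert rank in [3, 4, 5], 'Rank of input must be 3, 4 or 5'

    # pointwise convolution: projects only the channel dimension
    if rank == 3:
        x = Conv1D(channels, 1, padding='same', use_bias=False)(ip)
    elif rank == 4:
        x = Conv2D(channels, (1, 1), padding='same', use_bias=False)(ip)
    else:
        x = Conv3D(channels, (1, 1, 1), padding='same', use_bias=False)(ip)
    return x
```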
4 changes: 2 additions & 2 deletions nonlocal_resnet.py
@@ -157,7 +157,7 @@ def f(x):
        # Non Local Block
        if filters >= 256:
            print("Filters : ", filters, "Adding Non Local Blocks")
-            x = non_local_block(x, mode='embedded', computation_compression=True)
+            x = non_local_block(x, mode='embedded', compression=2)

        return x

@@ -457,5 +457,5 @@ def NonLocalResNet152(input_shape, classes):


if __name__ == '__main__':
-    model = NonLocalResNet18((32, 32, 3), classes=10)
+    model = NonLocalResNet18((128, 160, 3), classes=10)
    model.summary()
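As a quick sanity check of the changed entry point, a hedged sketch (the optimizer, loss, and random data are illustrative assumptions, not part of this commit):

```python
import numpy as np

model = NonLocalResNet18((128, 160, 3), classes=10)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# one training step on random data to confirm shapes flow end to end
x = np.random.rand(2, 128, 160, 3)
y = np.eye(10)[np.random.randint(0, 10, size=2)]
model.train_on_batch(x, y)
```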
