Skip to content

Commit

Permalink
remove redundancies on rank (shorter sweeter)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewrgarcia committed Jul 7, 2023
1 parent 334b033 commit 3401ead
Showing 1 changed file with 12 additions and 27 deletions.
39 changes: 12 additions & 27 deletions non_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def non_local_block(ip, intermediate_dim=None, compression=2,
a tensor of same shape as input
"""
channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
ip_shape = K.int_shape(ip)
input_shape = K.int_shape(ip)

if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')
Expand All @@ -40,28 +40,18 @@ def non_local_block(ip, intermediate_dim=None, compression=2,
dim1, dim2, dim3 = None, None, None

# check rank and calculate the input shape
if len(ip_shape) == 3: # temporal / time series data
rank = 3
batchsize, dim1, channels = ip_shape

elif len(ip_shape) == 4: # spatial / image data
rank = 4

if channel_dim == 1:
batchsize, channels, dim1, dim2 = ip_shape
else:
batchsize, dim1, dim2, channels = ip_shape
rank = len(input_shape)
if rank not in [3,4,5]:
raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial) or 5 (spatio-temporal)')

elif len(ip_shape) == 5: # spatio-temporal / Video or Voxel data
rank = 5
elif rank == 3:
batchsize, dims, channels = input_shape

else:
if channel_dim == 1:
batchsize, channels, dim1, dim2, dim3 = ip_shape
batchsize, channels, *dims = input_shape
else:
batchsize, dim1, dim2, dim3, channels = ip_shape

else:
raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial) or 5 (spatio-temporal)')
batchsize, *dims, channels = input_shape

# verify correct intermediate dimension specified
if intermediate_dim is None:
Expand Down Expand Up @@ -130,17 +120,12 @@ def non_local_block(ip, intermediate_dim=None, compression=2,

# reshape to input tensor format
if rank == 3:
y = Reshape((dim1, intermediate_dim))(y)
elif rank == 4:
if channel_dim == -1:
y = Reshape((dim1, dim2, intermediate_dim))(y)
else:
y = Reshape((intermediate_dim, dim1, dim2))(y)
y = Reshape((dims, intermediate_dim))(y)
else:
if channel_dim == -1:
y = Reshape((dim1, dim2, dim3, intermediate_dim))(y)
y = Reshape((*dims, intermediate_dim))(y)
else:
y = Reshape((intermediate_dim, dim1, dim2, dim3))(y)
y = Reshape((intermediate_dim, *dims))(y)

# project filters
y = _convND(y, rank, channels)
Expand Down

0 comments on commit 3401ead

Please sign in to comment.