Skip to content

Commit

Permalink
add the example symbol of modulated deformable RoIpooling layer
Browse files Browse the repository at this point in the history
  • Loading branch information
ancientmooner committed Dec 1, 2018
1 parent 1c1ca4a commit 1595b18
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion DCNv2_op/example_symbol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import mx

"""
demo symbol of using modulated deformable convolution
"""
def modulated_deformable_conv(data, name, num_filter, stride, lr_mult=1):
def modulated_deformable_conv(data, name, num_filter, stride, lr_mult=0.1):
weight_var = mx.sym.Variable(name=name+'_conv2_offset_weight', init=mx.init.Zero(), lr_mult=lr_mult)
bias_var = mx.sym.Variable(name=name+'_conv2_offset_bias', init=mx.init.Zero(), lr_mult=lr_mult)
conv2_offset = mx.symbol.Convolution(name=name + '_conv2_offset', data=data, num_filter=27,
Expand All @@ -14,3 +16,49 @@ def modulated_deformable_conv(data, name, num_filter, stride, lr_mult=1):
num_filter=num_filter, pad=(1, 1), kernel=(3, 3), stride=stride,
num_deformable_group=1, no_bias=True)
return conv2

"""
demo symbol of using modulated deformable RoI pooling
"""
def modulated_deformable_roi_pool(data, rois, spatial_scale, imfeat_dim=256, deform_fc_dim=1024, roi_size=7, trans_std=0.1):
roi_align = mx.contrib.sym.DeformablePSROIPooling(name='roi_align',
data=data,
rois=rois,
group_size=1,
pooled_size=roi_size,
sample_per_part=2,
no_trans=True,
part_size=roi_size,
output_dim=imfeat_dim,
spatial_scale=spatial_scale)

feat_deform = mx.symbol.FullyConnected(name='fc_deform_1', data=roi_align, num_hidden=deform_fc_dim)
feat_deform = mx.sym.Activation(data=feat_deform, act_type='relu', name='fc_deform_1_relu')

feat_deform = mx.symbol.FullyConnected(name='fc_deform_2', data=feat_deform, num_hidden=deform_fc_dim)
feat_deform = mx.sym.Activation(data=feat_deform, act_type='relu', name='fc_deform_2_relu')

feat_deform = mx.symbol.FullyConnected(name='fc_deform_3', data=feat_deform, num_hidden=roi_size * roi_size * 3)

roi_offset = mx.sym.slice_axis(feat_deform, axis=1, begin=0, end=roi_size * roi_size * 2)
roi_offset = mx.sym.reshape(roi_offset, shape=(-1, 2, roi_size, roi_size))

roi_mask = mx.sym.slice_axis(feat_deform, axis=1, begin=roi_size * roi_size * 2, end=None)
roi_mask_sigmoid = mx.sym.Activation(roi_mask, act_type='sigmoid')
roi_mask_sigmoid = mx.sym.reshape(roi_mask_sigmoid, shape=(-1, 1, roi_size, roi_size))

deform_roi_pool = mx.contrib.sym.DeformablePSROIPooling(name='deform_roi_pool',
data=data,
rois=rois,
trans=roi_offset,
group_size=1,
pooled_size=roi_size,
sample_per_part=2,
no_trans=False,
part_size=roi_size,
output_dim=imfeat_dim,
spatial_scale=spatial_scale,
trans_std=trans_std)

modulated_deform_roi_pool = mx.sym.broadcast_mul(deform_roi_pool, roi_mask_sigmoid)
return modulated_deform_roi_pool

0 comments on commit 1595b18

Please sign in to comment.