Skip to content

Commit

Permalink
add DCNv2 ops
Browse files Browse the repository at this point in the history
  • Loading branch information
ancientmooner committed Dec 1, 2018
1 parent f4e1637 commit 65de16f
Show file tree
Hide file tree
Showing 13 changed files with 2,708 additions and 2 deletions.
39 changes: 39 additions & 0 deletions DCNv2_op/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# DCNv2 operators

## Introduction

This folder provides the operators used in [DCNv2](https://arxiv.org/abs/1811.11168):

```
@article{DCNv2_2018,
title={Deformable ConvNets v2: More Deformable, Better Results},
author={Xizhou Zhu and Han Hu and Stephen Lin and Jifeng Dai},
journal={arXiv:1811.11168},
year={2018}
}
```

There are two operators in this folder. The first one is an updated deformable convolution operator. We have two major modifications:

* The new operator simutaneously processes multiple images in one computation loop, rather than processing one image in one loop as in the old operator.

Both the old and new operators use the following computation pipeline (illustrated by a 3x3 deformable convolution with input data of NxCxHxW and output data of NxC'xHxW):

for i in range(N/S):
step 1 (slicing): slicing the input data at the batch dimension from i*S to (i+1)*S, input (NxCxHxW) -> sliced input (SxCxHxW)
step 2 (deformable im2col): sliced input (SxCxHxW)+sliced offset (Sx18xHxW) -> column (Cx9xSxHxW)
step 3 (MatMul&reshape): weight matrix (C'x 9C) * column (9CxSHW) -> temp sliced output (C'xSxHxW) -> sliced output (SxC'xHxW)
step 4 (Merge): merge sliced output to form the whole output data (NxC'xHxW)
end

In the old operator, S is fixed as 1. In the new operator, S can be set by a new *im2col_step* parameter its default value is min(N, 64). The new operator is significantly faster than the old one when the image batch size is large (e.g. 32 as a usual practice in ImageNet classification).

* The boundary processing scheme is modified.

In the old operator, the pixel with any one coordinate (x or y) in (-inf, 0) is set as 0. The pixel with one coordinate in [H(W)-1,H(W)] is bilinear sampled assuming the pixel value on H(W) is the same as on H(W)-1. The pixel with any one coordinate in (H(W), inf) is set as 0.

In the new operator, the input image is firstly padded with zeros and then bilinear sampling is performed on all range of locations.

The new boundary scheme has little influence on tasks with large output feature map size (e.g. object detection), but can lead to better accuracy on tasks with small output feature map size (e.g. 7x7 as a usual practice in ImageNet classification).

The second operator is a new modulated deformable convolution operator introduced in the DCNv2 paper. Please see [example_symbol.py](https://github.com/msracver/Deformable-ConvNets/blob/master/DCNv2_op/example_symbol.py) for an example usage.
523 changes: 523 additions & 0 deletions DCNv2_op/deformable_convolution-inl.h

Large diffs are not rendered by default.

89 changes: 89 additions & 0 deletions DCNv2_op/deformable_convolution.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*!
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_convolution.cc
* \brief
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
*/

#include "./deformable_convolution-inl.h"

namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(DeformableConvolutionParam);

template<>
Operator* CreateOp<cpu>(DeformableConvolutionParam param, int dtype,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
Context ctx) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new DeformableConvolutionOp<cpu, DType>(param);
})
return op;
}

// DO_BIND_DISPATCH comes from operator_common.h
Operator *DeformableConvolutionProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
std::vector<TShape> out_shape, aux_shape;
std::vector<int> out_type, aux_type;
CHECK(InferType(in_type, &out_type, &aux_type));
CHECK(InferShape(in_shape, &out_shape, &aux_shape));
DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], in_shape, &out_shape, ctx);
}

MXNET_REGISTER_OP_PROPERTY(_contrib_DeformableConvolution, DeformableConvolutionProp)
.describe(R"code(Compute 2-D deformable convolution on 4-D input.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
For 2-D deformable convolution, the shapes are
- **data**: *(batch_size, channel, height, width)*
- **offset**: *(batch_size, num_deformable_group * kernel[0] * kernel[1] * 2, height, width)*
- **weight**: *(num_filter, channel, kernel[0], kernel[1])*
- **bias**: *(num_filter,)*
- **out**: *(batch_size, num_filter, out_height, out_width)*.
Define::
f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1
then we have::
out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])
If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
The default data ``layout`` is *NCHW*, namely *(batch_size, channle, height,
width)*.
If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
evenly into *g* parts along the channel axis, and also evenly split ``weight``
along the first dimension. Next compute the convolution on the *i*-th part of
the data with the *i*-th weight part. The output is obtained by concating all
the *g* results.
If ``num_deformable_group`` is larger than 1, denoted by *dg*, then split the
input ``offset`` evenly into *dg* parts along the channel axis, and also evenly
split ``out`` evenly into *dg* parts along the channel axis. Next compute the
deformable convolution, apply the *i*-th part of the offset part on the *i*-th
out.
Both ``weight`` and ``bias`` are learnable parameters.
)code" ADD_FILELINE)
.add_argument("data", "NDArray-or-Symbol", "Input data to the DeformableConvolutionOp.")
.add_argument("offset", "NDArray-or-Symbol", "Input offset to the DeformableConvolutionOp.")
.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.")
.add_arguments(DeformableConvolutionParam::__FIELDS__());

} // namespace op
} // namespace mxnet
28 changes: 28 additions & 0 deletions DCNv2_op/deformable_convolution.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*!
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file deformable_convolution.cu
* \brief
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
*/

#include "./deformable_convolution-inl.h"
#include <vector>

namespace mxnet {
namespace op {
template<>
Operator* CreateOp<gpu>(DeformableConvolutionParam param, int dtype,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
Context ctx) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new DeformableConvolutionOp<gpu, DType>(param);
})
return op;
}

} // namespace op
} // namespace mxnet

16 changes: 16 additions & 0 deletions DCNv2_op/example_symbol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
demo symbol of using modulated deformable convolution
"""
def modulated_deformable_conv(data, name, num_filter, stride, lr_mult=1):
weight_var = mx.sym.Variable(name=name+'_conv2_offset_weight', init=mx.init.Zero(), lr_mult=lr_mult)
bias_var = mx.sym.Variable(name=name+'_conv2_offset_bias', init=mx.init.Zero(), lr_mult=lr_mult)
conv2_offset = mx.symbol.Convolution(name=name + '_conv2_offset', data=data, num_filter=27,
pad=(1, 1), kernel=(3, 3), stride=stride, weight=weight_var, bias=bias_var, lr_mult=lr_mult)
conv2_offset_t = mx.sym.slice_axis(conv2_offset, axis=1, begin=0, end=18)
conv2_mask = mx.sym.slice_axis(conv2_offset, axis=1, begin=18, end=None)
conv2_mask = 2 * mx.sym.Activation(conv2_mask, act_type='sigmoid')

conv2 = mx.contrib.symbol.ModulatedDeformableConvolution(name=name + '_conv2', data=act1, offset=conv2_offset_t, mask=conv2_mask,
num_filter=num_filter, pad=(1, 1), kernel=(3, 3), stride=stride,
num_deformable_group=1, no_bias=True)
return conv2
Loading

0 comments on commit 65de16f

Please sign in to comment.