Skip to content

Commit

Permalink
Instance-wise backward
Browse files Browse the repository at this point in the history
Failed to adapting backward to batch-wise. So the best option is to roll back to former version. Now the FP is batch-wise, BP not.
  • Loading branch information
CharlesShang committed Dec 18, 2018
1 parent 4e4524c commit e07ee3b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 140 deletions.
198 changes: 76 additions & 122 deletions src/cuda/dcn_v2_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({batch, height_out, width_out}, input.options());
auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

auto grad_input = at::zeros_like(input);
Expand All @@ -256,126 +256,80 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,

using scalar_t = float;

// prepare for batch-wise computing, which is significantly faster than instance-wise computing
// when batch size is large.
// launch batch threads
int matrices_size = batch * sizeof(float *);

auto grad_output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto ones_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto weight_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto grad_weight_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto grad_bias_b = static_cast<float **>(THCudaMalloc(state, matrices_size));

const int block = 128;
const int grid = (batch + block - 1) / block;

createBatchGemmBufferBackward<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
grad_output_b,
columns_b,
ones_b,
weight_b,
grad_weight_b,
grad_bias_b,
grad_output.data<scalar_t>(),
columns.data<scalar_t>(),
ones.data<scalar_t>(),
weight.data<scalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
channels_out * height_out * width_out,
channels * kernel_h * kernel_w * height_out * width_out,
height_out * width_out,
batch);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;
THCudaBlas_SgemmBatched(state,
'n',
't',
n,
m,
k,
1.0f,
(const float **)grad_output_b, n,
(const float **)weight_b, m,
0.0f,
columns_b, n,
batch);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset.data<scalar_t>(),
grad_mask.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());
long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;
// gradient w.r.t. weight
THCudaBlas_SgemmBatched(state,
't',
'n',
n_,
m_,
k_,
1.0f,
(const float **)columns_b, k_,
(const float **)grad_output_b, k_,
1.0f,
grad_weight_b, n_,
batch);

// gradient w.r.t. bias
THCudaBlas_SgemmBatched(state,
't',
'n',
m_,
1,
k_,
1.0f,
(const float **)grad_output_b, k_,
(const float **)ones_b, k_,
1.0f,
grad_bias_b, m_,
batch);

THCudaFree(state, grad_output_b);
THCudaFree(state, columns_b);
THCudaFree(state, ones_b);
THCudaFree(state, weight_b);
THCudaFree(state, grad_weight_b);
THCudaFree(state, grad_bias_b);
for (int b = 0; b < batch; b++)
{
auto input_n = input.select(0, b);
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto grad_output_n = grad_output.select(0, b);
auto grad_input_n = grad_input.select(0, b);
auto grad_offset_n = grad_offset.select(0, b);
auto grad_mask_n = grad_mask.select(0, b);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;

THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
grad_output_n.data<scalar_t>(), n,
weight.data<scalar_t>(), m, 0.0f,
columns.data<scalar_t>(), n);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset_n.data<scalar_t>(),
grad_mask_n.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input_n.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());

long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;

THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
columns.data<scalar_t>(), k_,
grad_output_n.data<scalar_t>(), k_, 1.0f,
grad_weight.data<scalar_t>(), n_);

// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
THCudaBlas_Sgemv(state,
't',
k_, m_, 1.0f,
grad_output_n.data<scalar_t>(), k_,
ones.data<scalar_t>(), 1, 1.0f,
grad_bias.data<scalar_t>(), 1);
}

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias};
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
};
}
27 changes: 9 additions & 18 deletions src/cuda/dcn_v2_im2col_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
{
// NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
// here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis

// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
Expand Down Expand Up @@ -205,23 +206,18 @@ __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const int height_col, const int width_col,
float *grad_im)
{
// launch (batch_size * channels * kernel_h * kernel_w * height_col * width_col) cores
CUDA_KERNEL_LOOP(index, n)
{
// const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int j = (index / width_col / height_col) % kernel_w;
// const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int i = (index / width_col / height_col / kernel_w) % kernel_h;
// const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
const int c = (index / width_col / height_col / kernel_w / kernel_h) % channels;
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output

const int deformable_group_index = c / channel_per_deformable_group;

int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
// int b = (index / width_col / height_col) % batch_size;
int b = (index / width_col / height_col / channels / kernel_h / kernel_w) % batch_size;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;

Expand Down Expand Up @@ -270,7 +266,6 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const int height_col, const int width_col,
float *grad_offset, float *grad_mask)
{
// lanuch batch_size * (2 * kernel_h * kernel_w * deformable_group) * height_col * width_col cores
CUDA_KERNEL_LOOP(index, n)
{
float val = 0, mval = 0;
Expand All @@ -283,8 +278,7 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
// const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * width_col * height_col;
const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
Expand All @@ -293,14 +287,11 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,

for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
// const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int col_pos = (((col_c + b * channels * kernel_h * kernel_w) * height_col) + h) * width_col + w;
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;

// int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int j = col_c % kernel_w;
// int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int i = (col_c / kernel_w) % kernel_h;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
Expand Down

0 comments on commit e07ee3b

Please sign in to comment.