Commit

batch-wise backward
// something is wrong..
// need debugging
CharlesShang committed Dec 18, 2018
1 parent 618511d commit 4e4524c
Showing 3 changed files with 130 additions and 128 deletions.
247 changes: 123 additions & 124 deletions src/cuda/dcn_v2_cuda.cu
@@ -16,18 +16,19 @@ extern THCState *state;
// [batch gemm]
// https://github.com/pytorch/pytorch/blob/master/aten/src/THC/generic/THCTensorMathBlas.cu

__global__ void createBatchGemmBuffer(const float **input_b, float **output_b,
float **columns_b, const float **ones_b,
const float **weight_b, const float **bias_b,
float *input, float *output,
float *columns, float *ones,
float *weight, float *bias,
const int input_stride, const int output_stride,
const int columns_stride, const int ones_stride,
const int num_batches)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_batches)
{
input_b[idx] = input + idx * input_stride;
output_b[idx] = output + idx * output_stride;
columns_b[idx] = columns + idx * columns_stride;
@@ -89,8 +90,8 @@ dcn_v2_cuda_forward(const at::Tensor &input,
auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

// prepare for batch-wise computing, which is significantly faster than instance-wise computing
// when batch size is large.
// launch batch threads
int matrices_size = batch * sizeof(float *);
auto input_b = static_cast<const float **>(THCudaMalloc(state, matrices_size));
@@ -102,18 +103,18 @@ dcn_v2_cuda_forward(const at::Tensor &input,

const int block = 128;
const int grid = (batch + block - 1) / block;

createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
input_b, output_b,
columns_b, ones_b,
weight_b, bias_b,
input.data<scalar_t>(),
output.data<scalar_t>(),
columns.data<scalar_t>(),
ones.data<scalar_t>(),
weight.data<scalar_t>(),
bias.data<scalar_t>(),
channels * width * height,
channels_out * width_out * height_out,
channels * kernel_h * kernel_w * height_out * width_out,
height_out * width_out,
@@ -148,15 +149,15 @@ dcn_v2_cuda_forward(const at::Tensor &input,
long m = channels_out;
long n = height_out * width_out;
long k = channels * kernel_h * kernel_w;
THCudaBlas_SgemmBatched(state,
'n',
'n',
n,
m,
k,
1.0f,
(const float **)columns_b, n,
weight_b, k,
1.0f,
output_b, n,
batch);
@@ -171,25 +172,26 @@ dcn_v2_cuda_forward(const at::Tensor &input,
}
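For reference, the forward batched GEMM above ('n', 'n', with m = channels_out, n = height_out * width_out, k = channels * kernel_h * kernel_w) computes, per batch entry b and read in row-major terms,

output[b] (channels_out x height_out*width_out)
    = weight (channels_out x channels*kernel_h*kernel_w)
      * columns[b] (channels*kernel_h*kernel_w x height_out*width_out)
      + output[b]

with beta = 1.0f keeping whatever is already in output (presumably the bias term filled in through the ones/bias buffers set up above). cuBLAS is column-major, which is why the dimensions are passed as (n, m, k); this shape reading is inferred from the call itself, since part of the surrounding context is collapsed in this diff.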

__global__ void createBatchGemmBufferBackward(
float **grad_output_b,
float **columns_b,
float **ones_b,
float **weight_b,
float **grad_weight_b,
float **grad_bias_b,
float *grad_output,
float *columns,
float *ones,
float *weight,
float *grad_weight,
float *grad_bias,
const int grad_output_stride,
const int columns_stride,
const int ones_stride,
const int num_batches)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_batches)
{
grad_output_b[idx] = grad_output + idx * grad_output_stride;
columns_b[idx] = columns + idx * columns_stride;
ones_b[idx] = ones + idx * ones_stride;
@@ -268,106 +270,104 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,

const int block = 128;
const int grid = (batch + block - 1) / block;

createBatchGemmBufferBackward<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
grad_output_b,
columns_b,
ones_b,
weight_b,
grad_weight_b,
grad_bias_b,
grad_output.data<scalar_t>(),
columns.data<scalar_t>(),
ones.data<scalar_t>(),
weight.data<scalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
channels_out * height_out * width_out,
channels * kernel_h * kernel_w * height_out * width_out,
height_out * width_out,
batch);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;
THCudaBlas_SgemmBatched(state,
'n',
't',
n,
m,
k,
1.0f,
(const float **)grad_output_b, n,
(const float **)weight_b, m,
0.0f,
columns_b, n,
batch);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset.data<scalar_t>(),
grad_mask.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());
long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;
// gradient w.r.t. weight
THCudaBlas_SgemmBatched(state,
't',
'n',
n_,
m_,
k_,
1.0f,
(const float **)columns_b, k_,
(const float **)grad_output_b, k_,
1.0f,
grad_weight_b, n_,
batch);

// gradient w.r.t. bias
THCudaBlas_SgemmBatched(state,
't',
'n',
m_,
1,
k_,
1.0f,
(const float **)grad_output_b, k_,
(const float **)ones_b, k_,
1.0f,
grad_bias_b, m_,
batch);

THCudaFree(state, grad_output_b);
THCudaFree(state, columns_b);
@@ -377,6 +377,5 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
THCudaFree(state, grad_bias_b);

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias};
}
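As a side note, here is a minimal, self-contained sketch of the pointer-table plus batched-GEMM pattern that both the forward and backward paths above rely on, written against plain cuBLAS rather than the THCudaBlas wrappers. The helper names (build_ptr_table, batched_gemm) and the explicit handle handling are mine, not the repo's, and error checking is omitted.

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Fill per-sample pointer tables so a single batched GEMM can replace a loop of
// per-sample GEMMs. A stride of 0 makes every entry alias the same matrix
// (e.g. a weight shared across the batch).
__global__ void build_ptr_table(const float **a_b, const float **b_b, float **c_b,
                                const float *a, const float *b, float *c,
                                int a_stride, int b_stride, int c_stride,
                                int num_batches)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_batches)
    {
        a_b[idx] = a + idx * a_stride;
        b_b[idx] = b + idx * b_stride;
        c_b[idx] = c + idx * c_stride;
    }
}

// Row-major view: C[i] (m x n) = beta * C[i] + B[i] (m x k) * A[i] (k x n).
// cuBLAS is column-major, so the sizes are passed as (n, m, k), exactly as in
// the THCudaBlas_SgemmBatched calls above.
void batched_gemm(cublasHandle_t handle,
                  const float *a, const float *b, float *c,
                  int m, int n, int k,
                  int a_stride, int b_stride, int c_stride,
                  float beta, int batch)
{
    const float **a_b;
    const float **b_b;
    float **c_b;
    cudaMalloc((void **)&a_b, batch * sizeof(float *));
    cudaMalloc((void **)&b_b, batch * sizeof(float *));
    cudaMalloc((void **)&c_b, batch * sizeof(float *));

    const int block = 128;
    const int grid = (batch + block - 1) / block;
    build_ptr_table<<<grid, block>>>(a_b, b_b, c_b, a, b, c,
                                     a_stride, b_stride, c_stride, batch);

    const float alpha = 1.0f;
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       n, m, k, &alpha,
                       a_b, n, b_b, k, &beta,
                       c_b, n, batch);

    cudaFree(a_b);
    cudaFree(b_b);
    cudaFree(c_b);
}

Read this way, the three backward GEMMs above compute, per sample, grad_columns = weight^T * grad_output, grad_weight += grad_output * columns^T, and grad_bias += grad_output * ones. One caveat: if grad_weight_b and grad_bias_b all point at the same accumulator (their setup sits in a collapsed hunk), the beta = 1.0f accumulation writes one output matrix from many batch entries concurrently, which batched GEMM does not guarantee to be safe; that may be what the "// something is wrong.." in the commit message refers to, but that is a guess, not something this diff confirms.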
3 changes: 2 additions & 1 deletion src/cuda/dcn_v2_im2col_cuda.cu
@@ -133,11 +133,11 @@ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const int height_col, const int width_col,
float *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
// here columns is of shape (N, c*kw*kh, oh * ow), need to adapt axis
// launch channels * batch_size * height_col * width_col cores
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
@@ -270,6 +270,7 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const int height_col, const int width_col,
float *grad_offset, float *grad_mask)
{
// launch batch_size * (2 * kernel_h * kernel_w * deformable_group) * height_col * width_col cores
CUDA_KERNEL_LOOP(index, n)
{
float val = 0, mval = 0;
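The kernels in this file iterate via CUDA_KERNEL_LOOP. Its definition is not part of this diff; it is presumably the usual Caffe-style grid-stride loop, along the lines of:

#define CUDA_KERNEL_LOOP(i, n)                                   \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
         i += blockDim.x * gridDim.x)

so the "launch ... cores" comments mean one logical iteration per output element, with the grid-stride loop absorbing any mismatch between the launch size and n.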
8 changes: 5 additions & 3 deletions test.py
@@ -63,6 +63,8 @@ def check_zero_offset():
print('Zero offset passed')
else:
print('Zero offset failed')
print(input)
print(output)

def check_gradient_dconv():

@@ -91,7 +93,7 @@ def check_gradient_dconv():

print('check_gradient_dconv: ',
gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias,
stride, padding, dilation, deformable_groups),
eps=1e-3, atol=1e-4, rtol=1e-2))


@@ -259,8 +261,8 @@ def example_mdpooling():
if inC == outC:
check_zero_offset()

check_gradient_dpooling()
check_gradient_dconv()
# """
# ****** Note: backward is not reentrant error may not be a serious problem,
# ****** since the max error is less than 1e-7,
