Speedup via batch-wise computing

CharlesShang committed Dec 18, 2018
1 parent ae58bad commit 618511d

Showing 2 changed files with 174 additions and 85 deletions.
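The core of the change replaces the per-sample loop in the backward pass with pointer-array batched GEMMs: a small kernel fills device arrays of per-sample matrix pointers, and a single batched BLAS call then processes the whole mini-batch. The sketch below shows the same pattern in isolation; it is only an illustration that calls cublasSgemmBatched directly (the commit itself goes through the THCudaBlas_SgemmBatched wrapper), and all names in it are made up for the example.

```cuda
#include <cuda_runtime.h>
#include <cublas_v2.h>

// Fill per-sample pointer arrays so one batched GEMM can replace a loop of GEMMs.
__global__ void fill_pointer_arrays(const float **A_b, const float **B_b, float **C_b,
                                    const float *A, const float *B, float *C,
                                    int stride, int num_batches)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_batches)
    {
        A_b[idx] = A + idx * stride; // each entry points at one sample's matrix
        B_b[idx] = B + idx * stride;
        C_b[idx] = C + idx * stride;
    }
}

// Illustrative only: `batch` independent n x n products C_i = A_i * B_i in one call.
void batched_gemm_example(cublasHandle_t handle,
                          const float *A, const float *B, float *C,
                          int n, int batch)
{
    const float **A_b;
    const float **B_b;
    float **C_b;
    const int matrices_size = batch * sizeof(float *);
    cudaMalloc((void **)&A_b, matrices_size);
    cudaMalloc((void **)&B_b, matrices_size);
    cudaMalloc((void **)&C_b, matrices_size);

    const int block = 128;
    const int grid = (batch + block - 1) / block;
    fill_pointer_arrays<<<grid, block>>>(A_b, B_b, C_b, A, B, C, n * n, batch);

    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       n, n, n, &alpha,
                       A_b, n, B_b, n, &beta,
                       C_b, n, batch);

    cudaFree(A_b);
    cudaFree(B_b);
    cudaFree(C_b);
}
```

Launching one batched call instead of `batch` separate GEMMs avoids per-iteration kernel-launch and BLAS-call overhead, which is where the speedup comes from when the batch size is large.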
231 changes: 154 additions & 77 deletions src/cuda/dcn_v2_cuda.cu
@@ -135,9 +135,6 @@ dcn_v2_cuda_forward(const at::Tensor &input,
output_b, n_,
batch);

// NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, col_buffer is of shape (c*kw*kh, N, oh, ow)
// here columns is of shape (N, c*kw*kh, oh * ow), need to swap axis
// auto columns_transpose = columns.transpose(0, 1).contiguous();
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input.data<scalar_t>(),
offset.data<scalar_t>(),
@@ -173,6 +170,37 @@ dcn_v2_cuda_forward(const at::Tensor &input,
return output;
}

__global__ void createBatchGemmBufferBackward(
float ** grad_output_b,
float ** columns_b,
float ** ones_b,
float ** weight_b,
float ** grad_weight_b,
float ** grad_bias_b,
float * grad_output,
float * columns,
float * ones,
float * weight,
float * grad_weight,
float * grad_bias,
const int grad_output_stride,
const int columns_stride,
const int ones_stride,
const int num_batches)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_batches) {
grad_output_b[idx] = grad_output + idx * grad_output_stride;
columns_b[idx] = columns + idx * columns_stride;
ones_b[idx] = ones + idx * ones_stride;

// share weight, grad_weight, and grad_bias within the mini-batch
weight_b[idx] = weight;
grad_weight_b[idx] = grad_weight;
grad_bias_b[idx] = grad_bias;
}
}

std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -214,8 +242,8 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto ones = at::ones({batch, height_out, width_out}, input.options());
auto columns = at::empty({batch, channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

auto grad_input = at::zeros_like(input);
@@ -226,78 +254,127 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,

using scalar_t = float;

for (int b = 0; b < batch; b++)
{
auto input_n = input.select(0, b);
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto grad_output_n = grad_output.select(0, b);
auto grad_input_n = grad_input.select(0, b);
auto grad_offset_n = grad_offset.select(0, b);
auto grad_mask_n = grad_mask.select(0, b);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;

THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
grad_output_n.data<scalar_t>(), n,
weight.data<scalar_t>(), m, 0.0f,
columns.data<scalar_t>(), n);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset_n.data<scalar_t>(),
grad_mask_n.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input_n.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());

long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;

THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
columns.data<scalar_t>(), k_,
grad_output_n.data<scalar_t>(), k_, 1.0f,
grad_weight.data<scalar_t>(), n_);

// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
THCudaBlas_Sgemv(state,
't',
k_, m_, 1.0f,
grad_output_n.data<scalar_t>(), k_,
ones.data<scalar_t>(), 1, 1.0f,
grad_bias.data<scalar_t>(), 1);
}
// prepare for batch-wise computing, which is significantly faster than instance-wise computing
// when batch size is large.
// launch batch threads
int matrices_size = batch * sizeof(float *);

auto grad_output_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto columns_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto ones_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto weight_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto grad_weight_b = static_cast<float **>(THCudaMalloc(state, matrices_size));
auto grad_bias_b = static_cast<float **>(THCudaMalloc(state, matrices_size));

const int block = 128;
const int grid = (batch + block - 1) / block;

createBatchGemmBufferBackward<<<grid, block, 0, THCState_getCurrentStream(state)>>>(grad_output_b,
columns_b,
ones_b,
weight_b,
grad_weight_b,
grad_bias_b,
grad_output.data<scalar_t>(),
columns.data<scalar_t>(),
ones.data<scalar_t>(),
weight.data<scalar_t>(),
grad_weight.data<scalar_t>(),
grad_bias.data<scalar_t>(),
channels_out * height_out * width_out,
channels * kernel_h * kernel_w * height_out * width_out,
height_out * width_out,
batch);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;
THCudaBlas_SgemmBatched(
state,
'n',
't',
n,
m,
k,
1.0f,
(const float **)grad_output_b, n,
(const float **)weight_b, m,
0.0f,
columns_b, n,
batch);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset.data<scalar_t>(),
grad_mask.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
columns.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
batch, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());
long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;
// gradient w.r.t. weight
THCudaBlas_SgemmBatched(
state,
't',
'n',
n_,
m_,
k_,
1.0f,
(const float **)columns_b, k_,
(const float **)grad_output_b, k_,
1.0f,
grad_weight_b, n_,
batch);

// gradient w.r.t. bias
THCudaBlas_SgemmBatched(
state,
't',
'n',
m_,
1,
k_,
1.0f,
(const float **)grad_output_b, k_,
(const float **)ones_b, k_,
1.0f,
grad_bias_b, m_,
batch);

THCudaFree(state, grad_output_b);
THCudaFree(state, columns_b);
THCudaFree(state, ones_b);
THCudaFree(state, weight_b);
THCudaFree(state, grad_weight_b);
THCudaFree(state, grad_bias_b);

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
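For reference, the last two batched GEMMs in the backward pass accumulate the weight and bias gradients. Below is a minimal CPU sketch of what they compute, under the layout described above (grad_output viewed as N x C_out x P with P = height_out * width_out, columns as N x CK x P with CK = channels * kernel_h * kernel_w, row-major and contiguous, grouping ignored); the function name and loop structure are mine, not part of the commit.

```cpp
// Reference-only view of the batched weight/bias gradient GEMMs above.
void backward_weight_bias_reference(const float *grad_output, const float *columns,
                                    float *grad_weight, float *grad_bias,
                                    int N, int C_out, int CK, int P)
{
    for (int n = 0; n < N; ++n)
        for (int co = 0; co < C_out; ++co)
        {
            const float *go = grad_output + (static_cast<long>(n) * C_out + co) * P;
            for (int ck = 0; ck < CK; ++ck)
            {
                const float *col = columns + (static_cast<long>(n) * CK + ck) * P;
                for (int p = 0; p < P; ++p)
                    grad_weight[co * CK + ck] += go[p] * col[p]; // grad_output_n * columns_n^T
            }
            for (int p = 0; p < P; ++p)
                grad_bias[co] += go[p]; // the GEMM against the ones buffer is a spatial sum
        }
}
```

Because every entry of grad_weight_b and grad_bias_b points at the same buffer (see createBatchGemmBufferBackward) and beta is 1.0f, the batched calls accumulate across the mini-batch rather than overwrite.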
28 changes: 20 additions & 8 deletions src/cuda/dcn_v2_im2col_cuda.cu
@@ -135,6 +135,9 @@ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
{
CUDA_KERNEL_LOOP(index, n)
{
// NOTE(CharlesShang): different from Dai Jifeng's MXNet implementation, where col_buffer is of shape (c*kw*kh, N, oh, ow);
// here columns is of shape (N, c*kw*kh, oh*ow), so the axis order must be adapted
// launch channels * batch_size * height_col * width_col threads
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
@@ -202,18 +205,23 @@ __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const int height_col, const int width_col,
float *grad_im)
{
// launch (batch_size * channels * kernel_h * kernel_w * height_col * width_col) threads
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int j = (index / width_col / height_col) % kernel_w;
// const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int i = (index / width_col / height_col / kernel_w) % kernel_h;
// const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
const int c = (index / width_col / height_col / kernel_w / kernel_h) % channels;
// compute the start and end of the output

const int deformable_group_index = c / channel_per_deformable_group;

int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
// int b = (index / width_col / height_col) % batch_size;
int b = (index / width_col / height_col / channels / kernel_h / kernel_w) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;

@@ -274,7 +282,8 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
// const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * width_col * height_col;
const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
Expand All @@ -283,11 +292,14 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,

for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
// const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int col_pos = (((col_c + b * channels * kernel_h * kernel_w) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;

int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
// int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int j = col_c % kernel_w;
// int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int i = (col_c / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
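The index arithmetic changes in dcn_v2_im2col_cuda.cu all follow from the new column-buffer layout (N, channels*kernel_h*kernel_w, height_out*width_out): the batch index now sits in the slowest-varying position instead of between the kernel and spatial dimensions. The host-side sketch below mirrors the decomposition the updated col2im kernel uses and checks it against the inverse mapping; the struct and function names are illustrative, not part of the commit.

```cpp
#include <cassert>

struct ColIndex { int b, c, i, j, h_out, w_out; };

// Decompose a flat index the same way modulated_deformable_col2im_gpu_kernel now does.
ColIndex decompose(long long index, int channels, int kernel_h, int kernel_w,
                   int height_col, int width_col, int batch_size)
{
    ColIndex r;
    r.w_out = index % width_col;
    r.h_out = (index / width_col) % height_col;
    r.j     = (index / width_col / height_col) % kernel_w;
    r.i     = (index / width_col / height_col / kernel_w) % kernel_h;
    r.c     = (index / width_col / height_col / kernel_w / kernel_h) % channels;
    r.b     = (index / width_col / height_col / channels / kernel_h / kernel_w) % batch_size;
    return r;
}

// The inverse mapping: batch-major, then (c, i, j), then the output spatial position.
long long compose(const ColIndex &r, int channels, int kernel_h, int kernel_w,
                  int height_col, int width_col)
{
    return ((((static_cast<long long>(r.b) * channels + r.c) * kernel_h + r.i) * kernel_w + r.j)
                * height_col + r.h_out) * width_col + r.w_out;
}

int main()
{
    // Round-trip check on a toy configuration.
    const int channels = 3, kernel_h = 3, kernel_w = 3, height_col = 5, width_col = 7, batch = 4;
    const long long total = 1LL * batch * channels * kernel_h * kernel_w * height_col * width_col;
    for (long long idx = 0; idx < total; ++idx)
    {
        ColIndex r = decompose(idx, channels, kernel_h, kernel_w, height_col, width_col, batch);
        assert(compose(r, channels, kernel_h, kernel_w, height_col, width_col) == idx);
    }
    return 0;
}
```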
