Skip to content

Commit

Permalink
Merge pull request CharlesShang#52 from palver7/master
Browse files Browse the repository at this point in the history
modified DCNv2 to work on CPU and GPU
  • Loading branch information
CharlesShang authored Apr 1, 2020
2 parents d7897dc + 455dd87 commit 2805f07
Show file tree
Hide file tree
Showing 11 changed files with 1,482 additions and 136 deletions.
65 changes: 6 additions & 59 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,65 +1,12 @@
## Deformable Convolutional Networks V2 with Pytorch 1.0

## CPU and GPU version

### Build
```bash
./make.sh # build
python test.py # run examples and gradient check
```

### An Example
- deformable conv
```python
from dcn_v2 import DCN
input = torch.randn(2, 64, 128, 128).cuda()
# wrap all things (offset and mask) in DCN
dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda()
output = dcn(input)
print(output.shape)
```
- deformable roi pooling
```python
from dcn_v2 import DCNPooling
input = torch.randn(2, 32, 64, 64).cuda()
batch_inds = torch.randint(2, (20, 1)).cuda().float()
x = torch.randint(256, (20, 1)).cuda().float()
y = torch.randint(256, (20, 1)).cuda().float()
w = torch.randint(64, (20, 1)).cuda().float()
h = torch.randint(64, (20, 1)).cuda().float()
rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1)

# mdformable pooling (V2)
# wrap all things (offset and mask) in DCNPooling
dpooling = DCNPooling(spatial_scale=1.0 / 4,
pooled_size=7,
output_dim=32,
no_trans=False,
group_size=1,
trans_std=0.1).cuda()

dout = dpooling(input, rois)
```
### Note
Now the master branch is for pytorch 1.0 (new ATen API), you can switch back to pytorch 0.4 with,
```bash
git checkout pytorch_0.4
python testcpu.py # run examples and gradient check on cpu
python testcuda.py # run examples and gradient check on gpu
```

### Known Issues:

- [x] Gradient check w.r.t offset (solved)
- [ ] Backward is not reentrant (minor)

This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).

<s>I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes.
However, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some
non-differential points? </s>

Update: all gradient check passes with double precision.

Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for
float `<1e-15` for double),
so it may not be a serious problem (?)

Please post an issue or PR if you have any comments.

This repository contains the Pytorch DCNv2 implementation by Charles Shang. It has been modified so that it can run on both CPU and GPU.

9 changes: 7 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@

requirements = ["torch", "torchvision"]


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))


os.environ["CC"] = "g++"
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {"cxx": []}
define_macros = []


if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
Expand All @@ -38,7 +41,9 @@ def get_extensions():
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
raise NotImplementedError('Cuda is not availabel')
#raise NotImplementedError('Cuda is not available')
pass


sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
Expand Down
273 changes: 216 additions & 57 deletions src/cpu/dcn_v2_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,74 +1,233 @@
#include <vector>
#include "cpu/dcn_v2_im2col_cpu.h"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
//#include <ATen/cuda/CUDAContext.h>

#include <TH/TH.h>
//#include <THC/THCAtomics.cuh>
//#include <THC/THCDeviceUtils.cuh>

//extern THCState *state;

// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
// modified from the CUDA version for CPU use by Daniel K. Suhendro

at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int deformable_group)
{
AT_ERROR("Not implement on cpu");
}

std::vector<at::Tensor>
dcn_v2_cpu_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int dilation_h,
const int dilation_w,
const int deformable_group)
{
AT_ERROR("Not implement on cpu");
}
// THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
/*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
{
AT_ERROR("Not implement on cpu");
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);

const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);

// printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
// printf("Channels: %d %d\n", channels, channels_kernel);
// printf("Channels: %d %d\n", channels_out, channels_kernel);

AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);

AT_ASSERTM(channels == channels_kernel,
"Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);

const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

using scalar_t = float;
for (int b = 0; b < batch; b++)
{
auto input_n = input.select(0, b);
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto output_n = output.select(0, b);

// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
long m_ = channels_out;
long n_ = height_out * width_out;
long k_ = 1;
THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
ones.contiguous().data<scalar_t>(), k_,
bias.contiguous().data<scalar_t>(), k_, 0.0f,
output_n.data<scalar_t>(), n_);

modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group,
columns.data<scalar_t>());

//(k * m) x (m * n)
// Y = WC
long m = channels_out;
long n = height_out * width_out;
long k = channels * kernel_h * kernel_w;
THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
columns.data<scalar_t>(), n,
weight.data<scalar_t>(), k, 1.0f,
output_n.data<scalar_t>(), n);
}
return output;
}

std::tuple<at::Tensor, at::Tensor>
dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
const at::Tensor &input,
const at::Tensor &bbox,
const at::Tensor &trans,
const at::Tensor &top_count,
const int no_trans,
const float spatial_scale,
const int output_dim,
const int group_size,
const int pooled_size,
const int part_size,
const int sample_per_part,
const float trans_std)
std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
const at::Tensor &offset,
const at::Tensor &mask,
const at::Tensor &grad_output,
int kernel_h, int kernel_w,
int stride_h, int stride_w,
int pad_h, int pad_w,
int dilation_h, int dilation_w,
int deformable_group)
{
AT_ERROR("Not implement on cpu");

THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");

/*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/

const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);

const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);

AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
"Input shape and kernel shape wont match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_);

AT_ASSERTM(channels == channels_kernel,
"Input shape and kernel channels wont match: (%d vs %d).", channels, channels_kernel);

const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

auto grad_input = at::zeros_like(input);
auto grad_weight = at::zeros_like(weight);
auto grad_bias = at::zeros_like(bias);
auto grad_offset = at::zeros_like(offset);
auto grad_mask = at::zeros_like(mask);

using scalar_t = float;

for (int b = 0; b < batch; b++)
{
auto input_n = input.select(0, b);
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto grad_output_n = grad_output.select(0, b);
auto grad_input_n = grad_input.select(0, b);
auto grad_offset_n = grad_offset.select(0, b);
auto grad_mask_n = grad_mask.select(0, b);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;

THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
grad_output_n.data<scalar_t>(), n,
weight.data<scalar_t>(), m, 0.0f,
columns.data<scalar_t>(), n);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset_n.data<scalar_t>(),
grad_mask_n.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cpu(columns.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input_n.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());

long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;

THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
columns.data<scalar_t>(), k_,
grad_output_n.data<scalar_t>(), k_, 1.0f,
grad_weight.data<scalar_t>(), n_);

// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
THFloatBlas_gemv('t', k_, m_, 1.0f,
grad_output_n.data<scalar_t>(), k_,
ones.data<scalar_t>(), 1, 1.0f,
grad_bias.data<scalar_t>(), 1);
}

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
};
}
Loading

0 comments on commit 2805f07

Please sign in to comment.