Skip to content

Commit

Permalink
[Fix] Revise unit test of correlation (open-mmlab#1368)
Browse files (browse the repository at this point in the history)
* [Fix] Revise unit test of correlation

* rename

* lint

* lint

* lint

* lint
  • Loading branch information
MeowZheng authored Sep 25, 2021
1 parent 9d4571e commit 745aa73
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 322 deletions.
192 changes: 75 additions & 117 deletions mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,8 +15,9 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <vector>

#include <iostream>
#include <vector>

using namespace torch;

Expand All @@ -28,17 +29,10 @@ using namespace torch;
#define THREADS_BACKWARD 16

template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
const TensorAcc4R rInput2,
TensorAcc5R output,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH,
int dilation_patchW,
int dH, int dW)
{
__global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);
Expand All @@ -56,42 +50,35 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,

__shared__ scalar_t prod_sum[THREADS_FORWARD];

for (int ph = 0; ph < patchH; ++ph)
{
for (int ph = 0; ph < patchH; ++ph) {
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw)
{
for (int pw = 0; pw < patchW; ++pw) {
int pw_dilated = pw * dilation_patchW - patchRadW;
prod_sum[thread] = 0;
for (int i = 0; i < kH; ++i)
{
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
if WITHIN_BOUNDS (i1, i2, iH, iH)
{
for (int j = 0; j < kW; ++j)
{
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if WITHIN_BOUNDS (j1, j2, iW, iW)
{
for (int c = thread; c < C; c += THREADS_FORWARD)
{
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
}
if
WITHIN_BOUNDS(i1, i2, iH, iH) {
for (int j = 0; j < kW; ++j) {
int j1 = start_j + j * dilationW;
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
for (int c = thread; c < C; c += THREADS_FORWARD) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
prod_sum[thread] += v1 * v2;
}
}
}
}
}
}
// accumulate
__syncthreads();
if (thread == 0)
{
if (thread == 0) {
scalar_t reduce_sum = 0;
for (int index = 0; index < THREADS_FORWARD; ++index)
{
for (int index = 0; index < THREADS_FORWARD; ++index) {
reduce_sum += prod_sum[index];
}
output[n][ph][pw][h][w] = reduce_sum;
Expand All @@ -101,18 +88,12 @@ __global__ void correlation_forward_cuda_kernel(const TensorAcc4R rInput1,
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_output,
const TensorAcc4R input2,
TensorAcc4R grad_input1,
const int kH, const int kW,
const int patchH, const int patchW,
const int padH, const int padW,
const int dilationH, const int dilationW,
const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW,
const int batch)
{

__global__ void correlation_backward_cuda_kernel_input1(
const TensorAcc5R grad_output, const TensorAcc4R input2,
TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int patchW, const int padH, const int padW, const int dilationH,
const int dilationW, const int dilation_patchH, const int dilation_patchW,
const int dH, const int dW, const int batch) {
const int iH = input2.size(2);
const int iW = input2.size(3);

Expand All @@ -137,29 +118,23 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;

for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
int i1 = h + dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w + dilation_patchW * (pw - patchRadW);
if (WITHIN_BOUNDS(i1, j1, iH, iW))
{
if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
scalar_t val = input2[n][c][i1][j1];
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
}
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}
Expand All @@ -168,13 +143,10 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o

__syncthreads();

if (ph_off == 0 && pw_off == 0)
{
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
}
}
Expand All @@ -183,17 +155,11 @@ __global__ void correlation_backward_cuda_kernel_input1(const TensorAcc5R grad_o
}

template <typename scalar_t>
__global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_output,
const TensorAcc4R input1,
TensorAcc4R grad_input2,
int kH, int kW,
int patchH, int patchW,
int padH, int padW,
int dilationH, int dilationW,
int dilation_patchH, int dilation_patchW,
int dH, int dW,
int batch)
{
__global__ void correlation_backward_cuda_kernel_input2(
const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
int dilation_patchW, int dH, int dW, int batch) {
const int iH = input1.size(2);
const int iW = input1.size(3);

Expand All @@ -216,50 +182,42 @@ __global__ void correlation_backward_cuda_kernel_input2(const TensorAcc5R grad_o
__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
prod_sum[ph_off][pw_off] = 0;

for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD)
{
for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
int i1 = h - dilation_patchH * (ph - patchRadH);
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD)
{
for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
int j1 = w - dilation_patchW * (pw - patchRadW);
if WITHIN_BOUNDS (i1, j1, iH, iW)
{
scalar_t val = input1[n][c][i1][j1];

const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;

for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH)
{
int i2 = (h_3) / dH;
if (i2 * dH != h_3)
continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW)
{
int j2 = (w_3) / dW;
if (j2 * dW != w_3)
continue;
if WITHIN_BOUNDS (i2, j2, H, W)
{
prod_sum[ph_off][pw_off] += grad_output[n][ph][pw][i2][j2] * val;
if
WITHIN_BOUNDS(i1, j1, iH, iW) {
scalar_t val = input1[n][c][i1][j1];

const int h_2 = i1 + padH;
const int w_2 = j1 + padW;
const int min_h = h_2 - dilatedKH;
const int min_w = w_2 - dilatedKW;

for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
int i2 = (h_3) / dH;
if (i2 * dH != h_3) continue;
for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
int j2 = (w_3) / dW;
if (j2 * dW != w_3) continue;
if
WITHIN_BOUNDS(i2, j2, H, W) {
prod_sum[ph_off][pw_off] +=
grad_output[n][ph][pw][i2][j2] * val;
}
}
}
}
}
}
}

__syncthreads();

if (ph_off == 0 && pw_off == 0)
{
if (ph_off == 0 && pw_off == 0) {
scalar_t reduce_sum = 0;
for (int ph = 0; ph < THREADS_BACKWARD; ++ph)
{
for (int pw = 0; pw < THREADS_BACKWARD; ++pw)
{
for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
reduce_sum += prod_sum[ph][pw];
}
}
Expand Down
Loading

0 comments on commit 745aa73

Please sign in to comment.