Commit

adapt to torchsparse
hua0x522 committed Mar 30, 2024
1 parent fcfabb1 commit 9f6cb5d
Showing 4 changed files with 19 additions and 12 deletions.
16 changes: 10 additions & 6 deletions torchsparse/backend/convolution/flash_conv_sort_s2_cuda.cu
@@ -82,10 +82,14 @@ __device__ void store_C_s2(uint32_t* reg_C, half* C, int* reorder_loc, int M, in
             int col = shm_col + blockIdx.y * 64;
             int row_8 = reorder_loc[row + 8];
             row = reorder_loc[row];
-            C[row * N + col] = __float2half(*(float*)&reg_C[m * 16 + n * 4]);
-            C[row * N + col + 1] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 1]);
-            C[row_8 * N + col] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 2]);
-            C[row_8 * N + col + 1] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 3]);
+            if (row < M) {
+                C[row * N + col] = __float2half(*(float*)&reg_C[m * 16 + n * 4]);
+                C[row * N + col + 1] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 1]);
+            }
+            if (row_8 < M) {
+                C[row_8 * N + col] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 2]);
+                C[row_8 * N + col + 1] = __float2half(*(float*)&reg_C[m * 16 + n * 4 + 3]);
+            }
         }
     }
 }
@@ -157,10 +161,10 @@ __global__ void flash_conv_sort_s2_kernel(half* inputs, half* weights, int* reor
 }
 
 torch::Tensor flash_conv_sort_s2_cuda(torch::Tensor inputs, torch::Tensor weights, torch::Tensor reorder_map,
-                                      torch::Tensor reduced_mask, torch::Tensor reorder_loc) {
+                                      torch::Tensor reduced_mask, torch::Tensor reorder_loc, int num_out_feats) {
     int c_in = weights.size(1);
     int c_out = weights.size(2);
-    int n_points = reorder_map.size(0);
+    int n_points = num_out_feats;
     int kernel_size = reorder_map.size(1);
 
     auto options = torch::TensorOptions().dtype(inputs.dtype()).device(inputs.device());
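The two hunks in this file explain each other: once n_points comes from the new num_out_feats argument rather than reorder_map.size(0), the row count M is no longer tied to the reorder map, so a tile row remapped through reorder_loc can land at or past the end of C, and every store in the epilogue needs a bounds check. A minimal sketch of that guarded-store pattern, with a hypothetical helper name and a simplified accumulator layout (not the kernel's actual fragment layout):

#include <cuda_fp16.h>

// Sketch only: each thread owns results for two rows of a 16-row MMA tile,
// `row` and `row_8 = row + 8` after remapping through reorder_loc. When M is
// not a multiple of the tile height, the last tile hangs past the end of C,
// so each row is checked against M before its two elements are written.
__device__ void store_rows_guarded(const float* acc, half* C,
                                   int row, int row_8, int col, int M, int N) {
    if (row < M) {
        C[row * N + col]     = __float2half(acc[0]);
        C[row * N + col + 1] = __float2half(acc[1]);
    }
    if (row_8 < M) {
        C[row_8 * N + col]     = __float2half(acc[2]);
        C[row_8 * N + col + 1] = __float2half(acc[3]);
    }
}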
3 changes: 2 additions & 1 deletion torchsparse/backend/convolution/flash_conv_sort_s2_cuda.h
@@ -4,4 +4,5 @@ at::Tensor flash_conv_sort_s2_cuda(torch::Tensor inputs,
                                    torch::Tensor weights,
                                    torch::Tensor reorder_map,
                                    torch::Tensor reduced_mask,
-                                   torch::Tensor reorder_loc);
+                                   torch::Tensor reorder_loc,
+                                   int num_out_feats);
8 changes: 5 additions & 3 deletions torchsparse/backend/convolution/flash_conv_sort_s3_cuda.cu
@@ -123,22 +123,24 @@ __global__ void flash_conv_sort_s3_kernel(half* inputs, half* weights, int* reor
     int M = n_points;
     int N = c_out;
     int K = kernel_size * c_in;
-    __shared__ half shm_A[2 * 64 * 64];
-    __shared__ half shm_B[2 * 32 * 64];
+    __shared__ half shm_A[3 * 64 * 64];
+    __shared__ half shm_B[3 * 32 * 64];
 
     uint32_t reg_A[4 * 4];
     uint32_t reg_B[4 * 2];
     uint32_t reg_C[4 * 4 * 4] = {0};
 
     pipe_load_s3(shm_A, shm_B, inputs, weights, reorder_map, kernel_size, c_in, N, 0, 0);
     __pipeline_commit();
+    pipe_load_s3(shm_A, shm_B, inputs, weights, reorder_map, kernel_size, c_in, N, 1, 1);
+    __pipeline_commit();
     int idx0 = 0;
     int idx1 = 1;
     int loc0 = 0;
     int loc1 = 1;
     int loc2;
 
-    for (int ko = 1; ko < K / 32; ko++) {
+    for (int ko = 2; ko < K / 32; ko++) {
         bool flag = reduced_mask[blockIdx.x] & (1 << (ko * 32 / c_in));
         if (flag) {
             loc2 = (loc1 + 1) % 3;
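This hunk makes the shared-memory staging match the triple-buffer rotation the kernel already performs through loc0/loc1/loc2 (modulo 3): the shm_A/shm_B tiles grow from two stages to three, a second tile is prefetched before the main loop, and ko therefore starts at 2 instead of 1, so the load of tile ko can overlap the compute on tile ko - 2. A minimal sketch of that structure, with hypothetical load_tile/compute_tile stand-ins for pipe_load_s3 and the MMA loop; the wait and barrier placement is an assumption, and the real kernel's reduced_mask skip logic is omitted:

#include <cuda_fp16.h>
#include <cuda_pipeline.h>

constexpr int TILE = 64 * 32;  // elements per stage; size is illustrative

// Hypothetical stand-ins for pipe_load_s3 and the tensor-core inner loop.
__device__ void load_tile(half* dst, int tile_idx) { /* cp.async copies */ }
__device__ void compute_tile(const half* src) { /* ldmatrix + mma.sync */ }

__global__ void pipelined_kernel(int num_tiles) {
    __shared__ half buf[3][TILE];  // three rotating stages

    // Two tiles in flight before the loop, so ko starts at 2.
    load_tile(buf[0], 0); __pipeline_commit();
    load_tile(buf[1], 1); __pipeline_commit();

    int loc0 = 0, loc1 = 1, loc2;
    for (int ko = 2; ko < num_tiles; ko++) {
        loc2 = (loc1 + 1) % 3;                 // next free stage
        load_tile(buf[loc2], ko); __pipeline_commit();
        __pipeline_wait_prior(2);              // oldest of 3 loads is done
        __syncthreads();                       // make the tile block-visible
        compute_tile(buf[loc0]);
        loc0 = loc1; loc1 = loc2;              // rotate stages
    }
    __pipeline_wait_prior(0);                  // drain the last two loads
    __syncthreads();
    compute_tile(buf[loc0]);
    compute_tile(buf[loc1]);
}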
4 changes: 2 additions & 2 deletions torchsparse/backend/pybind_cuda.cu
@@ -25,6 +25,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("reduce_bitmask_cuda", &reduce_bitmask_cuda);
   m.def("flash_conv_cuda", &flash_conv_cuda);
   m.def("flash_conv_sort_cuda", &flash_conv_sort_cuda);
-  m.def("flash_conv_sort_s2_cuda", &flash_conv_sort_cuda);
-  m.def("flash_conv_sort_s3_cuda", &flash_conv_sort_cuda);
+  m.def("flash_conv_sort_s2_cuda", &flash_conv_sort_s2_cuda);
+  m.def("flash_conv_sort_s3_cuda", &flash_conv_sort_s3_cuda);
 }
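The old lines were a copy-paste slip: the module exported the Python names flash_conv_sort_s2_cuda and flash_conv_sort_s3_cuda but bound both to &flash_conv_sort_cuda. Once flash_conv_sort_s2_cuda gained the int num_out_feats parameter in this same commit, the stale binding presumably could not even have accepted the extra argument (pybind11 would reject the call against flash_conv_sort_cuda's tensor-only signature), so this fix is what makes the s2/s3 entry points callable at all. The corrected table, as it reads after the commit:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // ... earlier bindings unchanged ...
  m.def("flash_conv_sort_cuda", &flash_conv_sort_cuda);
  m.def("flash_conv_sort_s2_cuda", &flash_conv_sort_s2_cuda);  // was &flash_conv_sort_cuda
  m.def("flash_conv_sort_s3_cuda", &flash_conv_sort_s3_cuda);  // was &flash_conv_sort_cuda
}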
