Skip to content

Commit

Permalink
fix gridencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Sep 5, 2022
1 parent f37c5ca commit 3c14ad5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions gridencoder/src/gridencoder.cu
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,12 @@ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
}

}
Expand Down Expand Up @@ -409,11 +410,12 @@ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, con
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
}
}

Expand Down

0 comments on commit 3c14ad5

Please sign in to comment.