override shfl methods for torch.half
rusty1s committed Jul 28, 2021
1 parent 66bcc36 commit 6fca568
Showing 3 changed files with 14 additions and 12 deletions.
7 changes: 1 addition & 6 deletions csrc/cuda/segment_coo_cuda.cu

@@ -3,7 +3,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
-#include <type_traits>
 
 #include "reducer.cuh"
 #include "utils.cuh"
@@ -26,10 +25,6 @@ segment_coo_kernel(const scalar_t *src_data,
   int lane_idx = row_idx & (32 - 1);
   int D = index_info.sizes[index_info.dims - 1];
 
-  using cuda_scalar_t =
-      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
-                                scalar_t>::type;
-
   if (row_idx < E) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
@@ -41,7 +36,7 @@
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
-      tmp = __shfl_up_sync(FULL_MASK, (cuda_scalar_t)val, i);
+      tmp = __shfl_up_sync(FULL_MASK, val, i);
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
         assert(idx >= next_idx);
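
For context, the loop in this hunk performs a segmented inclusive scan within one warp: each lane pulls the partial result from the lane i positions below it via __shfl_up_sync and folds it in whenever both lanes still map to the same output index. Below is a minimal sketch of the plain, unsegmented form of that shuffle-up pattern; the kernel name is hypothetical and this is not code from the repository:

#define FULL_MASK 0xffffffff

// Warp-wide inclusive prefix sum: out[k] = in[0] + ... + in[k] for the
// 32 lanes of a single warp (launch with one 32-thread block).
__global__ void warp_inclusive_sum(const float *in, float *out) {
  int lane = threadIdx.x & 31;
  float val = in[threadIdx.x];
#pragma unroll
  for (int i = 1; i < 32; i *= 2) {
    // Read the running sum of the lane i positions below this one.
    float tmp = __shfl_up_sync(FULL_MASK, val, i);
    if (lane >= i) // lanes closer than distance i keep their value unchanged
      val += tmp;
  }
  out[threadIdx.x] = val;
}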
7 changes: 1 addition & 6 deletions csrc/cuda/segment_csr_cuda.cu

@@ -26,10 +26,6 @@ segment_csr_kernel(const scalar_t *src_data,
   int row_idx = thread_idx / TB;
   int lane_idx = thread_idx & (TB - 1);
 
-  using cuda_scalar_t =
-      typename std::conditional<std::is_same<scalar_t, at::Half>::value, __half,
-                                scalar_t>::type;
-
   if (row_idx < N) {
     int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
     int64_t row_start = __ldg(indptr_info.data + offset);
@@ -52,8 +48,7 @@
       if (REDUCE == MIN || REDUCE == MAX)
        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
      Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, (cuda_scalar_t)val, i), &arg,
-          arg_tmp);
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
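
The hunk above uses the complementary shuffle-down pattern: the stride halves each step, folding the value of the lane i positions above into the current lane, so lane 0 ends up holding the reduction of the whole group. A minimal sketch follows, with a plain sum standing in for Reducer<scalar_t, REDUCE>::update and fixed at the full 32-lane warp rather than the kernel's TB lanes; not code from the repository:

#define FULL_MASK 0xffffffff

// Warp-wide sum via shuffle-down: after log2(32) = 5 halving steps,
// lane 0 holds the total; the other lanes hold partial sums only.
__device__ float warp_reduce_sum(float val) {
#pragma unroll
  for (int i = 16; i > 0; i /= 2)
    val += __shfl_down_sync(FULL_MASK, val, i);
  return val;
}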
12 changes: 12 additions & 0 deletions csrc/cuda/utils.cuh

@@ -5,3 +5,15 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+
+__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
+                                              const at::Half var,
+                                              const unsigned int delta) {
+  return __shfl_up_sync(mask, (__half)var, delta);
+}
+
+__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
+                                                const at::Half var,
+                                                const unsigned int delta) {
+  return __shfl_down_sync(mask, (__half)var, delta);
+}
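
These two overloads are the point of the commit: the built-in shuffle intrinsics accept __half (and other native types) but not ATen's at::Half wrapper, so kernels instantiated with scalar_t = at::Half previously needed an explicit cast at every call site. With the overloads in scope, resolution picks the at::Half version by exact match and converts through __half internally. A hedged sketch of a templated call site that now compiles for both float and at::Half; it assumes at::Half's device-side operator+ from c10 and is not code from the repository:

// Templated warp sum: scalar_t = float uses the built-in intrinsic,
// scalar_t = at::Half resolves to the utils.cuh overload above.
template <typename scalar_t>
__device__ scalar_t warp_sum(scalar_t val) {
  for (int i = 16; i > 0; i /= 2)
    val = val + __shfl_down_sync(0xffffffff, val, i);
  return val;
}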
