CUDA BFloat16 and other improvements on abs (pytorch#44804)
Summary:
Not sure if ROCm supports `std::abs` today; let's see what the CI says.

Pull Request resolved: pytorch#44804

Reviewed By: mruberry

Differential Revision: D23748837

Pulled By: ngimel

fbshipit-source-id: ccf4e63279f3e5927a85d8d8f70ba4b8c334156b
zasdfgbnm authored and facebook-github-bot committed Sep 17, 2020
1 parent ba6534a commit 34331b0
Showing 2 changed files with 4 additions and 24 deletions.
25 changes: 2 additions & 23 deletions aten/src/ATen/native/cuda/AbsKernel.cu
@@ -6,31 +6,10 @@
 
 namespace at { namespace native {
 
-// We manually overload abs because std::abs does not work with thrust::complex types and ROCm.
-template<typename scalar_t>
-__host__ __device__ static inline scalar_t abs_wrapper(scalar_t v) {
-  return ::abs(v);
-}
-
-template<typename T>
-__host__ __device__ static inline c10::complex<T> abs_wrapper(c10::complex<T> v) {
-  return std::abs(v);
-}
-
-__host__ __device__ static inline uint8_t abs_wrapper(uint8_t v) {
-  return v;
-}
-
-__host__ __device__ static inline bool abs_wrapper(bool v) {
-  return v;
-}
-
 void abs_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, iter.dtype(), "abs_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "abs_cuda", [&] {
-      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return abs_wrapper(a);
-      });
+    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return std::abs(a);
     });
   });
 }
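
For orientation (an aside, not part of the commit): with the manual abs_wrapper overloads and the AT_SKIP_BFLOAT16_IF_NOT_ROCM guard removed, the kernel calls std::abs directly for every dispatched dtype, so abs on BFloat16 CUDA tensors is no longer skipped on non-ROCm builds. A minimal usage sketch, assuming a CUDA-enabled PyTorch build that includes this change:

import torch

# Hedged sketch: BFloat16 abs on CUDA, which the removed
# AT_SKIP_BFLOAT16_IF_NOT_ROCM guard previously skipped on non-ROCm builds.
if torch.cuda.is_available():
    x = torch.tensor([-1.5, 2.0, -0.25], device="cuda", dtype=torch.bfloat16)
    print(torch.abs(x))  # expected magnitudes [1.5, 2.0, 0.25] in bfloat16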
3 changes: 2 additions & 1 deletion test/test_torch.py
@@ -20071,7 +20071,8 @@ def inner(self, device, dtype):
         1e-5, 1e-5, 3e-4, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]),
     ('eig', 'with_eigvec', _new_t((10, 10)), lambda t, d: [True],
         1e-5, 1e-5, 1e-5, _float_types_no_half, _cpu_types, False, [skipCUDAIfNoMagma]),
-    ('abs', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types2, [torch.bfloat16]),
+    ('abs', '', _small_3d, lambda t, d: [], 1e-5, 1e-5, 1e-5,
+        torch.testing.get_all_dtypes(include_complex=False, include_bool=False), [torch.bfloat16]),
     ('sign', '', _small_3d, lambda t, d: []),
     ('log', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types2, [torch.bfloat16]),
     ('log10', '', _small_3d, lambda t, d: [], 1e-2, 1e-2, 1e-5, _float_types2, [torch.bfloat16]),
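
Aside (not part of the diff): the test change swaps the hand-maintained _types2 dtype list for torch.testing.get_all_dtypes(include_complex=False, include_bool=False), so the 'abs' entry now covers every non-complex, non-bool dtype. A quick sketch to inspect that list, assuming the torch.testing.get_all_dtypes helper available at the time of this commit (it may not be exposed in newer releases):

import torch

# List the dtypes the updated 'abs' test entry iterates over.
# Assumes torch.testing.get_all_dtypes exists, as it did when this commit landed.
dtypes = torch.testing.get_all_dtypes(include_complex=False, include_bool=False)
print(dtypes)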
