Refactor cub namespace handling (pytorch#66219)
Summary:
This PR updates PyTorch to account for the following cub changes:
- Starting with cub 1.13.1, cub requires users to define `CUB_NS_QUALIFIER` whenever `CUB_NS_PREFIX` is defined. In addition, cub introduces a new `CUB_WRAPPED_NAMESPACE` mechanism (see the sketch after the list below).

This PR makes the following changes to PyTorch:
- Starting with CUDA 11.5, define `CUB_WRAPPED_NAMESPACE` globally as an nvcc flag.
- Fix the caffe2 build failures caused by the above change.
- Add an `aten/src/ATen/cuda/cub_definitions.cuh` header that defines helper macros for feature availability.
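
For reference, defining `CUB_WRAPPED_NAMESPACE=foo` makes cub place all of its symbols inside `namespace foo { namespace cub { ... } }`, so callers must qualify them accordingly. A minimal sketch of the effect, assuming the flag is set to `at_cuda_detail` as the aliases added to `cub.cuh` below indicate (the snippet is illustrative, not part of this commit):

// Any translation unit built with: nvcc -DCUB_WRAPPED_NAMESPACE=at_cuda_detail ...
// sees cub under ::at_cuda_detail::cub instead of the usual ::cub.
#include <cub/cub.cuh>
#include <cstdio>

int main() {
  ::at_cuda_detail::cub::Sum sum;   // plain ::cub::Sum would no longer resolve
  std::printf("%d\n", sum(1, 2));   // prints 3
  return 0;
}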

Pull Request resolved: pytorch#66219

Reviewed By: bdhirsh

Differential Revision: D31626931

Pulled By: ngimel

fbshipit-source-id: 97ebf5ef671ade8bf46d0860edc317f22660f26d
zasdfgbnm authored and facebook-github-bot committed Oct 25, 2021
1 parent 700b39a commit b8dfb45
Showing 38 changed files with 148 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -97,7 +97,7 @@ jobs:
- name: Ensure no direct cub include
if: always()
run: |
(! git --no-pager grep -I -no $'#include <cub/' -- ./aten ':(exclude)aten/src/ATen/cuda/cub.cuh' || (echo "The above files have direct cub include; please include ATen/cuda/cub.cuh instead and wrap your cub calls in at::native namespace if necessary"; false))
(! git --no-pager grep -I -no $'#include <cub/' -- ./aten ':(exclude)aten/src/ATen/cuda/cub*.cuh' || (echo "The above files have direct cub include; please include ATen/cuda/cub.cuh instead and wrap your cub calls in at::native namespace if necessary"; false))
- name: Ensure no raw cuda api calls
if: always()
run: |
100 changes: 55 additions & 45 deletions aten/src/ATen/cuda/cub.cuh
@@ -5,15 +5,28 @@
#include <iterator>
#include <limits>

#include <ATen/cuda/cub_definitions.cuh>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()

#include <cub/cub.cuh>

#else

// include cub in a safe manner, see:
// https://github.com/pytorch/pytorch/pull/55292
#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace at { namespace cuda { namespace detail {
#define CUB_NS_POSTFIX }}}
#undef CUB_NS_QUALIFIER
#define CUB_NS_PREFIX namespace at_cuda_detail {
#define CUB_NS_POSTFIX }
#define CUB_NS_QUALIFIER ::at_cuda_detail::cub
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
#undef CUB_NS_QUALIFIER

#endif

#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDACachingAllocator.h>
@@ -33,16 +46,40 @@
#define NO_ROCM(x)
#else
#define NO_ROCM(x) x
#endif

namespace at { namespace native {
#if !defined(USE_ROCM) && !CUB_SUPPORTS_NV_BFLOAT16()

namespace at_cuda_detail {
// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16

template <>
struct cub::FpLimits<c10::BFloat16>
{
static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
unsigned short max_word = 0x7F7F;
return reinterpret_cast<c10::BFloat16&>(max_word);
}

static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
unsigned short lowest_word = 0xFF7F;
return reinterpret_cast<c10::BFloat16&>(lowest_word);
}
};

namespace cub = at::cuda::detail::cub;
template <> struct cub::NumericTraits<c10::BFloat16>: cub::BaseTraits<cub::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
}
#endif

#if !defined(USE_ROCM)
namespace at { namespace native {
namespace cub = ::at_cuda_detail::cub;
}}
#endif

namespace at {
namespace cuda {
namespace cub {

namespace detail {

@@ -55,44 +92,17 @@ struct cuda_type<c10::Half> {
using type = __half;
};

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11050
// cub sort support for __nv_bfloat16 is added to cub 1.13 in
// https://github.com/NVIDIA/cub/pull/306 and according to
// https://github.com/NVIDIA/cub#releases, 1.13 is included in
// CUDA Toolkit 11.5
#if CUB_SUPPORTS_NV_BFLOAT16()

// waiting for https://github.com/NVIDIA/cub/pull/306 to land on CUDA
template<>
struct cuda_type<c10::BFloat16> {
using type = __nv_bfloat16;
};

#elif !defined(USE_ROCM)

// backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16

template <>
struct cub::FpLimits<c10::BFloat16>
{
static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
unsigned short max_word = 0x7F7F;
return reinterpret_cast<c10::BFloat16&>(max_word);
}

static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
unsigned short lowest_word = 0xFF7F;
return reinterpret_cast<c10::BFloat16&>(lowest_word);
}
};

template <> struct cub::NumericTraits<c10::BFloat16>: cub::BaseTraits<cub::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};

#endif

} // namespace detail

namespace cub {

inline int get_num_bits(uint64_t max_key) {
int num_bits = 1;
while (max_key > 1) {
@@ -115,11 +125,11 @@ static inline void radix_sort_keys(
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeysDescending,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
keys_in_, keys_out_, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortKeys,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
keys_in_, keys_out_, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
@@ -147,11 +157,11 @@ static inline void radix_sort_pairs(
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairsDescending,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
keys_in_, keys_out_, values_in, values_out, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceRadixSort::SortPairs,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
keys_in_, keys_out_, values_in, values_out, n,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
}
@@ -183,12 +193,12 @@ static inline void segmented_sort_pairs(
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);

if (descending) {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
} else {
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSegmentedRadixSort::SortPairs,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
keys_in_, keys_out_, values_in, values_out,
num_elements, num_segments, begin_offsets, end_offsets,
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
@@ -240,7 +250,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceScan::InclusiveScan,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input,
output,
scan_op,
@@ -260,7 +270,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
first_elem_ptr,
scan_op);
C10_CUDA_KERNEL_LAUNCH_CHECK();
using ArgIndexInputIterator = NO_ROCM(detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using ArgIndexInputIterator = NO_ROCM(at_cuda_detail)::cub::ArgIndexInputIterator<InputIteratorT>;
using tuple = typename ArgIndexInputIterator::value_type;
auto input_iter_transform = [=] __device__ (const tuple &x)->input_t {
if (x.key == 0) {
@@ -269,9 +279,9 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
return x.value;
}
};
auto input_ = NO_ROCM(detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
auto input_ = NO_ROCM(at_cuda_detail)::cub::TransformInputIterator<input_t, decltype(input_iter_transform), ArgIndexInputIterator>(
ArgIndexInputIterator(input + i), input_iter_transform);
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceScan::InclusiveScan,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
@@ -287,7 +297,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
// so split at int_max/2
constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
int size_cub = std::min<int64_t>(num_items, max_cub_size);
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceScan::ExclusiveScan,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
input,
output,
scan_op,
@@ -309,7 +319,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
C10_CUDA_KERNEL_LAUNCH_CHECK();
auto input_ = impl::chained_iterator<InitValueT, InputIteratorT>{
input + i, first_elem_ptr};
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceScan::InclusiveScan,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
input_,
output + i,
scan_op,
@@ -322,7 +332,7 @@ template<typename InputIteratorT , typename OutputIteratorT , typename NumSelect
inline void unique(InputIteratorT input, OutputIteratorT output, NumSelectedIteratorT num_selected_out, int64_t num_items) {
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub unique does not support more than INT_MAX elements");
CUB_WRAPPER(NO_ROCM(detail)::cub::DeviceSelect::Unique,
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
}

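With the wrapper in place, every device-wide cub call in `cub.cuh` is spelled `NO_ROCM(at_cuda_detail)::cub::...`: on CUDA the macro keeps the `at_cuda_detail` prefix so the wrapped namespace is used, while on ROCm it drops the prefix (see the `NO_ROCM` definition above). The `CUB_WRAPPER` macro those calls go through is defined elsewhere in `cub.cuh` and is not shown in this diff; a rough, hand-rolled equivalent of one call site, illustrating the standard cub two-phase temp-storage pattern it hides (function name and key type below are made up):

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <cub/cub.cuh>  // built with -DCUB_WRAPPED_NAMESPACE=at_cuda_detail

void radix_sort_keys_sketch(const int* keys_in, int* keys_out, int n) {
  auto stream = at::cuda::getCurrentCUDAStream();
  size_t temp_bytes = 0;
  // Pass 1: null temp storage, cub only reports how many scratch bytes it needs.
  ::at_cuda_detail::cub::DeviceRadixSort::SortKeys(
      nullptr, temp_bytes, keys_in, keys_out, n, 0, 32, stream);
  auto temp = c10::cuda::CUDACachingAllocator::get()->allocate(temp_bytes);
  // Pass 2: the actual sort, reusing the same arguments with real scratch space.
  ::at_cuda_detail::cub::DeviceRadixSort::SortKeys(
      temp.get(), temp_bytes, keys_in, keys_out, n, 0, 32, stream);
}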
29 changes: 29 additions & 0 deletions aten/src/ATen/cuda/cub_definitions.cuh
@@ -0,0 +1,29 @@
#pragma once

#if !defined(USE_ROCM)
#include <cuda.h> // for CUDA_VERSION
#endif

#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include <cub/version.cuh>
#else
#define CUB_VERSION 0
#endif

// cub sort support for __nv_bfloat16 is added to cub 1.13 in:
// https://github.com/NVIDIA/cub/pull/306
#if CUB_VERSION >= 101300
#define CUB_SUPPORTS_NV_BFLOAT16() true
#else
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif

// cub sort support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
// starting from CUDA 11.5
#if defined(CUB_WRAPPED_NAMESPACE) || defined(THRUST_CUB_WRAPPED_NAMESPACE)
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() true
#else
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif
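
Both helpers are ordinary function-like macros that expand to `true` or `false`, so they can be used directly in preprocessor conditions. A minimal sketch of the intended consumption, mirroring what `cub.cuh` above does:

#include <ATen/cuda/cub_definitions.cuh>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
// CUDA 11.5+ builds: CUB_WRAPPED_NAMESPACE is already set on the nvcc command
// line, so including cub directly puts it in the wrapped namespace.
#include <cub/cub.cuh>
#else
// Older toolkits: wrap cub manually via CUB_NS_PREFIX/POSTFIX/QUALIFIER.
#endif

#if !CUB_SUPPORTS_NV_BFLOAT16()
// cub < 1.13: keep the FpLimits/NumericTraits backport for c10::BFloat16
// (NVIDIA/cub#306); with cub 1.13+ the bfloat16 support ships with cub itself.
#endif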
1 change: 1 addition & 0 deletions caffe2/core/context_gpu.cu
@@ -21,6 +21,7 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/string_utils.h"
#include "caffe2/utils/cub_namespace.cuh"

C10_DEFINE_string(
caffe2_cuda_memory_pool,
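The newly included `caffe2/utils/cub_namespace.cuh` is not shown in this excerpt. Since the caffe2 sources below keep using unqualified `cub::` inside `namespace caffe2`, it presumably re-exposes the wrapped cub there. A hypothetical sketch of such a header (an assumption about its shape, not its actual contents):

#pragma once

// Hypothetical: when cub is globally wrapped (CUDA 11.5+ builds pass
// -DCUB_WRAPPED_NAMESPACE=at_cuda_detail), alias it back into caffe2 so the
// existing unqualified cub:: references keep resolving.
#if defined(CUB_WRAPPED_NAMESPACE)
namespace caffe2 {
namespace cub = ::CUB_WRAPPED_NAMESPACE::cub;
}  // namespace caffe2
#endif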
1 change: 1 addition & 0 deletions caffe2/operators/accuracy_op.cu
@@ -3,6 +3,7 @@
#include "caffe2/utils/GpuAtomics.cuh"
#include "caffe2/utils/math.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

namespace caffe2 {
1 change: 1 addition & 0 deletions caffe2/operators/affine_channel_op.cu
@@ -1,5 +1,6 @@
#include "caffe2/operators/affine_channel_op.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

#include "caffe2/core/context_gpu.h"
2 changes: 1 addition & 1 deletion caffe2/operators/arg_ops.cu
@@ -2,8 +2,8 @@

#include <limits>

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
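Most of the caffe2 changes in this commit repeat the recipe visible here: include `caffe2/utils/cub_namespace.cuh` alongside the `<cub/...>` headers so that the unqualified `cub::` spelled inside `namespace caffe2` still resolves once cub lives in a wrapped namespace. A representative, hypothetical kernel in the style of these operators (launched with a single block of kBlockSize threads):

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

namespace caffe2 {

// Unqualified cub::BlockReduce resolves through the alias provided by
// cub_namespace.cuh when CUB_WRAPPED_NAMESPACE is in effect.
template <int kBlockSize>
__global__ void SumKernelSketch(const float* x, int n, float* out) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float val = 0.0f;
  for (int i = threadIdx.x; i < n; i += kBlockSize) {
    val += x[i];
  }
  const float sum = BlockReduce(temp_storage).Sum(val);  // valid in thread 0 only
  if (threadIdx.x == 0) {
    *out = sum;
  }
}

}  // namespace caffe2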
1 change: 1 addition & 0 deletions caffe2/operators/batch_moments_op.cu
@@ -1,5 +1,6 @@
#include "caffe2/operators/batch_moments_op.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

#include "caffe2/core/context_gpu.h"
1 change: 1 addition & 0 deletions caffe2/operators/batch_sparse_to_dense_op.cu
@@ -1,5 +1,6 @@
#include "caffe2/operators/batch_sparse_to_dense_op.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/device/device_scan.cuh>

#include "caffe2/core/context_gpu.h"
2 changes: 1 addition & 1 deletion caffe2/operators/boolean_mask_ops.cu
@@ -2,8 +2,8 @@

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/boolean_mask_ops.h"

#include <cub/cub.cuh>
#include "caffe2/utils/cub_namespace.cuh"

namespace caffe2 {

1 change: 1 addition & 0 deletions caffe2/operators/cross_entropy_op.cu
@@ -4,6 +4,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/cross_entropy_op.h"
#include "caffe2/operators/operator_fallback_gpu.h"
#include "caffe2/utils/cub_namespace.cuh"

namespace caffe2 {

1 change: 1 addition & 0 deletions caffe2/operators/distance_op.cu
@@ -4,6 +4,7 @@
#include "caffe2/operators/distance_op.h"
#include "caffe2/utils/conversions.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

namespace caffe2 {
2 changes: 1 addition & 1 deletion caffe2/operators/elementwise_div_op.cu
@@ -3,8 +3,8 @@
#include <algorithm>
#include <functional>

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/elementwise_ops_utils.h"
1 change: 1 addition & 0 deletions caffe2/operators/elementwise_linear_op.cu
@@ -5,6 +5,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/operator_fallback_gpu.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>

namespace caffe2 {
2 changes: 1 addition & 1 deletion caffe2/operators/elementwise_mul_op.cu
@@ -3,8 +3,8 @@
#include <algorithm>
#include <functional>

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/elementwise_ops_utils.h"
1 change: 1 addition & 0 deletions caffe2/operators/elementwise_ops.cu
@@ -1,5 +1,6 @@
#include "caffe2/operators/elementwise_ops.h"

#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_load.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/device/device_reduce.cuh>
1 change: 1 addition & 0 deletions caffe2/operators/find_op.cu
@@ -1,6 +1,7 @@
#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/find_op.h"
#include "caffe2/utils/cub_namespace.cuh"

namespace caffe2 {

1 change: 1 addition & 0 deletions caffe2/operators/generate_proposals_op.cu
@@ -5,6 +5,7 @@
#include "caffe2/operators/generate_proposals_op_util_boxes.h" // BBOX_XFORM_CLIP_DEFAULT
#include "caffe2/operators/generate_proposals_op_util_nms.h"
#include "caffe2/operators/generate_proposals_op_util_nms_gpu.h"
#include "caffe2/utils/cub_namespace.cuh"

#if defined(USE_ROCM)
#include <cfloat>
1 change: 1 addition & 0 deletions caffe2/operators/normalize_ops.cu
@@ -5,6 +5,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/normalize_l1_op.h"
#include "caffe2/operators/normalize_op.h"
#include "caffe2/utils/cub_namespace.cuh"

namespace caffe2 {

1 change: 1 addition & 0 deletions caffe2/operators/one_hot_ops.cu
@@ -2,6 +2,7 @@

#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/one_hot_ops.h"
#include "caffe2/utils/cub_namespace.cuh"

namespace caffe2 {

(Diff for the remaining changed files not shown.)
