Commit: msvc_fixes (#17201)
Summary:
Fixes MSVC errors triggered by the variadic C10_LAUNCH_BOUNDS macro, such as:

```
  D:\pytorch-scripts\caffe2_builders\v141\pytorch\aten\src\THC/THCReduce.cuh(144): error C4002: too many actual parameters for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxproj]
  D:\pytorch-scripts\caffe2_builders\v141\pytorch\aten\src\THC/THCReduce.cuh(259): error C4002: too many actual parameters for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxproj]
  D:/pytorch-scripts/caffe2_builders/v141/pytorch/aten/src/THCUNN/SpatialDilatedMaxPooling.cu(51): error C4002: too many actual parameters for macro 'C10_LAUNCH_BOUNDS_1' [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxproj]
```
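
These come from the variadic dispatch that the old C10_LAUNCH_BOUNDS macro used (removed in the c10/macros/Macros.h hunk at the bottom of this diff). MSVC's traditional preprocessor is known to forward __VA_ARGS__ into a nested macro as a single argument, so the selector picks the one-argument form at two-argument call sites and C10_LAUNCH_BOUNDS_1 ends up invoked with two parameters, which is exactly the C4002 above. A minimal, self-contained sketch of the pattern (the PICK/ONE/TWO names are illustrative, not from the PyTorch tree):

```
// Minimal sketch of the optional-argument macro dispatch (illustrative names).
#include <cstdio>

#define ONE(n)     printf("one-arg bounds: %d\n", (n))
#define TWO(n, m)  printf("two-arg bounds: %d, %d\n", (n), (m))

// Whatever lands in the FUNC slot is what gets expanded; later arguments are ignored.
#define PICK_X(x, a, b, FUNC, ...) FUNC
// gcc/clang/nvcc expand __VA_ARGS__ into separate arguments, so PICK(512, 4)
// places TWO(512, 4) in the FUNC slot.  MSVC's traditional preprocessor forwards
// __VA_ARGS__ as one argument, the later arguments shift left by one slot, and
// ONE(512, 4) is selected instead -- "too many actual parameters for macro 'ONE'".
#define PICK(...)  PICK_X(, ##__VA_ARGS__, TWO(__VA_ARGS__), ONE(__VA_ARGS__), ONE(256))

int main() {
  PICK(1024);   // one argument  -> ONE(1024)
  PICK(512, 4); // two arguments -> TWO(512, 4) on a conforming preprocessor
  return 0;
}
```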

It also fixes Debug-build linking issues with at::Half in pool_op_cudnn.cc, like this one:

```
pool_op_cudnn.obj : error LNK2019: unresolved external symbol "public: bool __cdecl caffe2::MaxPoolFunctor<class caffe2::CUDAContext>::GlobalPoolingBackward<struct c10::Half,2>(int,int,int,struct c10::Half const *,struct c10::Half const *,struct c10::Half const *,struct c10::Half *,class caffe2::CUDAContext *)const " (??$GlobalPoolingBackward@UHalf@c10@@$01@?$MaxPoolFunctor@VCUDAContext@caffe2@@caffe2@QEBA_NHHHPEBUHalf@c10@00PEAU23@PEAVCUDAContext@1@Z) referenced in function "public: bool __cdecl caffe2::`anonymous namespace'::CuDNNMaxPoolFunctor::GlobalPoolingBackward<struct c10::Half,2>(int,int,int,struct c10::Half const *,struct c10::Half const *,struct c10::Half const *,struct c10::Half *,class caffe2::CUDAContext *)const " (??$GlobalPoolingBackward@UHalf@c10@@$01@CuDNNMaxPoolFunctor@?A0xb936404a@caffe2@QEBA_NHHHPEBUHalf@c10@00PEAU34@PEAVCUDAContext@2@Z) [D:\pytorch-scripts\caffe2_builders\v141\pytorch\build\Debug\caffe2\caffe2_gpu.vcxproj]
```
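
The LNK2019 says that pool_op_cudnn.obj references MaxPoolFunctor<CUDAContext>::GlobalPoolingBackward<c10::Half, 2> but no object file in the Debug build provides a definition for that specialization. One standard way to guarantee the definition exists is an explicit instantiation in the file that defines the template; the sketch below uses simplified stand-in types and is not necessarily the change made in this commit:

```
// Sketch only: simplified stand-ins for caffe2::MaxPoolFunctor / c10::Half /
// caffe2::CUDAContext, showing an explicit instantiation that keeps the
// GlobalPoolingBackward<Half, 2> symbol available to other translation units.
#include <cstdint>

struct FakeCUDAContext {};       // stand-in for caffe2::CUDAContext
struct Half { uint16_t bits; };  // stand-in for c10::Half

template <class Context>
struct MaxPoolFunctor {
  template <typename T, int kOrder>
  bool GlobalPoolingBackward(int N, int C, int HxW,
                             const T* dY, const T* X, const T* Y,
                             T* dX, Context* context) const;
};

// The definition lives in a single .cc/.cu file...
template <class Context>
template <typename T, int kOrder>
bool MaxPoolFunctor<Context>::GlobalPoolingBackward(
    int, int, int, const T*, const T*, const T*, T*, Context*) const {
  return true;  // real pooling-backward kernel launch elided
}

// ...so the specializations that other files call must be instantiated there
// explicitly, or the linker reports LNK2019 for them.
template bool MaxPoolFunctor<FakeCUDAContext>::GlobalPoolingBackward<Half, 2>(
    int, int, int, const Half*, const Half*, const Half*, Half*,
    FakeCUDAContext*) const;
```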
Pull Request resolved: pytorch/pytorch#17201

Differential Revision: D14165732

Pulled By: ezyang

fbshipit-source-id: 875fd9a5b2db6f83fc483f6d750d2c011260eb8b
ArutyunovG authored and facebook-github-bot committed Mar 1, 2019
1 parent 06c8aa7 commit 2336f0b
Showing 21 changed files with 41 additions and 84 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -273,7 +273,7 @@ template <typename Op,
int ADims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void kernelPointwiseApply1(detail::TensorInfo<scalar, IndexType> a,
IndexType totalElements, const Op op) {
@@ -357,7 +357,7 @@ template <typename Op,
int ADims, int BDims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply2(detail::TensorInfo<scalar1, IndexType> a,
@@ -466,7 +466,7 @@ template <typename Op,
int ADims, int BDims, int CDims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply3(detail::TensorInfo<scalar1, IndexType> a,
@@ -589,7 +589,7 @@ template <typename Op,
int ADims, int BDims, int CDims, int DDims,
int step>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernelPointwiseApply4(detail::TensorInfo<scalar1, IndexType> a,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Dropout.cu
@@ -35,7 +35,7 @@ template <
typename IndexType,
int ADims>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(256, 8)
C10_LAUNCH_BOUNDS_2(256, 8)
#endif
__global__ void
fused_dropout_kernel(cuda::detail::TensorInfo<scalar_t, IndexType> a,
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cuda/GridSampler.cu
@@ -120,7 +120,7 @@ namespace {
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_kernel(
const int nthreads,
TensorInfo<scalar_t, int> input,
@@ -228,7 +228,7 @@ namespace {
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_kernel(
const int nthreads,
TensorInfo<scalar_t, int> input,
@@ -392,7 +392,7 @@ namespace {
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_2d_backward_kernel(
const int nthreads,
TensorInfo<scalar_t, int> grad_output,
@@ -547,7 +547,7 @@ namespace {
}

template <typename scalar_t>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void grid_sampler_3d_backward_kernel(
const int nthreads,
TensorInfo<scalar_t, int> grad_output,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Loops.cuh
@@ -36,7 +36,7 @@ static constexpr int launch_bound2 = 4;
namespace at { namespace native {

template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS(nt, launch_bound2)
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void elementwise_kernel(int N, func_t f) {
int tid = threadIdx.x;
int nv = nt * vt;
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cuda/LossCTC.cu
@@ -47,7 +47,7 @@ __device__ static inline int64_t get_target_prime(const target_t* __restrict__ t
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel(scalar_t* __restrict__ log_alpha_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
@@ -260,7 +260,7 @@ std::tuple<Tensor, Tensor> ctc_loss_gpu_template(const Tensor& log_probs, const
// alpha kernel above. (As mentioned above, it might make sense do the calculation in the alpha kernel.)
template<typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
@@ -366,7 +366,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
@@ -418,7 +418,7 @@ ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_da
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
8 changes: 4 additions & 4 deletions aten/src/ATen/native/cuda/RNN.cu
@@ -82,7 +82,7 @@ namespace kernel {

template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_forward(
TensorInfo<scalar_t, index_type> input,
@@ -169,7 +169,7 @@ __global__ void lstm_cell_forward(

template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void lstm_cell_backward(
TensorInfo<scalar_t, index_type> storage,
@@ -234,7 +234,7 @@ __global__ void lstm_cell_backward(

template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_forward(
TensorInfo<scalar_t, index_type> Input,
@@ -304,7 +304,7 @@ __global__ void gru_cell_forward(

template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void gru_cell_backward(
TensorInfo<scalar_t, index_type> gradInInput,
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/Reduce.cuh
@@ -172,7 +172,7 @@ struct ReduceConfig {
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);

template<int nt, typename R>
C10_LAUNCH_BOUNDS(nt, 4)
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void reduce_kernel(R reduction) {
reduction.run();
}
@@ -410,7 +410,7 @@ struct ReduceOp {

return is_last_block_done_shared;
}

template <bool can_acc>
C10_DEVICE arg_t accumulate_in_output(
out_scalar_t* out, arg_t value,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/TensorTransformations.cu
@@ -17,7 +17,7 @@ constexpr uint32_t AT_APPLY_BLOCKS_PER_SM = 4;

template <typename scalar_t, typename IndexType>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
C10_LAUNCH_BOUNDS_2(AT_APPLY_THREADS_PER_BLOCK, AT_APPLY_BLOCKS_PER_SM)
#endif
__global__ void
kernel_pointwise_flip_apply2(const cuda::detail::TensorInfo<scalar_t, IndexType> in_tensor_info,
4 changes: 2 additions & 2 deletions aten/src/THC/THCReduce.cuh
@@ -141,7 +141,7 @@ template
typename FinalizeOp,
int ADims, int BDims>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void kernelReduceNoncontigDim_shared
(TensorInfo<T, IndexType> out,
@@ -256,7 +256,7 @@ template <typename T,
typename FinalizeOp,
int ADims, int BDims>
#if __CUDA_ARCH__ >= 350 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(512, 4)
C10_LAUNCH_BOUNDS_2(512, 4)
#endif
__global__ void
kernelReduceNoncontigDim(TensorInfo<T, IndexType> out,
4 changes: 2 additions & 2 deletions aten/src/THC/THCReduceAll.cuh
@@ -27,7 +27,7 @@ template <typename T,
int ADims>
__global__ void
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS(THC_REDUCE_ALL_BLOCK_SIZE)
C10_LAUNCH_BOUNDS_1(THC_REDUCE_ALL_BLOCK_SIZE)
#endif
kernelReduceAll(TensorInfo<T, IndexType> in,
IndexType totalElements,
@@ -299,7 +299,7 @@ bool THC_reduceAll(THCState* state,

/*
Only instantiates the all 1D special case and the fallback all nD case for
large (64-bit indexed) tensors to reduce compilation time.
large (64-bit indexed) tensors to reduce compilation time.
*/
if (inInfo.dims == 1) {
HANDLE_IN_CASE(uint64_t, 1);
2 changes: 1 addition & 1 deletion aten/src/THC/THCSortUtils.cuh
@@ -135,7 +135,7 @@ __device__ inline void bitonicSortKeys(K keys[Power2SortSize],
template <typename K, typename V,
int KeyDims, int ValueDims,
typename Comparator, typename IndexType, int Power2SortSize>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void
bitonicSortKVInPlace(TensorInfo<K, IndexType> keys,
IndexType keySlices,
2 changes: 1 addition & 1 deletion aten/src/THC/THCTensorTopK.cuh
@@ -361,7 +361,7 @@ __device__ void radixSelect(DataType* data,
}

template <typename T, typename IndexType, int Dim, bool Order>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void gatherTopK(TensorInfo<T, IndexType> input,
IndexType inputSliceSize,
IndexType outputSliceSize, // aka `k`
4 changes: 2 additions & 2 deletions aten/src/THCUNN/MultiLabelMarginCriterion.cu
@@ -12,7 +12,7 @@

template <typename Dtype, typename Acctype>
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
#endif
__global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output,
Dtype *input,
@@ -82,7 +82,7 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output

template <typename Dtype, typename Acctype>
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS(MULTILABELMARGIN_THREADS)
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
#endif
__global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput,
Dtype *gradOutput,
2 changes: 1 addition & 1 deletion aten/src/THCUNN/SpatialClassNLLCriterion.cu
@@ -69,7 +69,7 @@ __global__ void SpatialClassNLLCriterion_updateGradInput_no_reduce_kernel(

template <typename T, typename AccumT>
#if defined(__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
#endif
__global__ void cunn_SpatialClassNLLCriterion_updateOutput_kernel(
T *output,
2 changes: 1 addition & 1 deletion aten/src/THCUNN/SpatialCrossMapLRN.cu
@@ -9,7 +9,7 @@
template <typename Dtype, typename Acctype>
__global__ void
#if __CUDA_ARCH__ >= 320 || defined __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
#endif
LRNFillScale(const int nthreads, const Dtype* const in,
const int num, const int channels, const int height,
4 changes: 2 additions & 2 deletions aten/src/THCUNN/SpatialDilatedMaxPooling.cu
@@ -49,9 +49,9 @@ const int BACKWARD_THREADS = 256;

template <typename Dtype, typename AccType>
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 4)
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 4)
#else
C10_LAUNCH_BOUNDS(BACKWARD_THREADS, 8)
C10_LAUNCH_BOUNDS_2(BACKWARD_THREADS, 8)
#endif
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
const int64_t* top_mask, const int num, const int channels,
4 changes: 2 additions & 2 deletions aten/src/THCUNN/VolumetricConvolution.cu
@@ -9,7 +9,7 @@
// Borrowed from Theano
// Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter, Nicolas Ballas
template <typename Dtype>
__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
__global__ void C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
im3d2col_kernel(const int64_t n, const Dtype* data_im,
const int64_t height, const int64_t width, const int64_t depth,
const int64_t kernel_h, const int64_t kernel_w, const int64_t kernel_d,
@@ -88,7 +88,7 @@ void im3d2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
}

template <typename Dtype, typename Acctype>
__global__ void C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
__global__ void C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS) // ensure that at least 1 block can be resident
col2im3d_kernel(const int64_t n, const Dtype* data_col,
const int64_t height, const int64_t width, const int64_t depth,
const int64_t channels,
4 changes: 2 additions & 2 deletions aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu
@@ -13,7 +13,7 @@
#include <c10/macros/Macros.h>

template<typename Dtype, typename Acctype>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
const THCDeviceTensor<Dtype, 5> data1, THCDeviceTensor<Dtype, 5> data2) {
@@ -81,7 +81,7 @@ __global__ void caffe_gpu_interp2_kernel(const int n,

// Backward (adjoint) operation 1 <- 2 (accumulates)
template <typename Dtype, typename Acctype>
C10_LAUNCH_BOUNDS(1024)
C10_LAUNCH_BOUNDS_1(1024)
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
const Acctype rdepth, const Acctype rheight, const Acctype rwidth, const bool align_corners,
THCDeviceTensor<Dtype, 5> data1, const THCDeviceTensor<Dtype, 5> data2){
4 changes: 2 additions & 2 deletions aten/src/THCUNN/im2col.h
@@ -8,7 +8,7 @@
// Kernel for fast unfold+copy
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
template <typename Dtype>
C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void im2col_kernel(const int64_t n, const Dtype* data_im,
const int64_t height, const int64_t width,
const int64_t ksize_h, const int64_t ksize_w,
@@ -60,7 +60,7 @@ void im2col(cudaStream_t stream, const Dtype* data_im, const int64_t channels,
}

template <typename Dtype, typename Acctype>
C10_LAUNCH_BOUNDS(CUDA_NUM_THREADS)
C10_LAUNCH_BOUNDS_1(CUDA_NUM_THREADS)
__global__ void col2im_kernel(const int64_t n, const Dtype* data_col,
const int64_t height, const int64_t width, const int64_t channels,
const int64_t kernel_h, const int64_t kernel_w,
3 changes: 0 additions & 3 deletions c10/macros/Macros.h
@@ -143,12 +143,9 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256;
#define C10_MAX_THREADS_PER_BLOCK(val) (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) : CUDA_THREADS_PER_BLOCK_FALLBACK)
#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) ((((threads_per_block)*(blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) ? (blocks_per_sm) : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / (threads_per_block))))
// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__
// https://stackoverflow.com/a/8814003 snippet to have macro with an optional argument
#define C10_LAUNCH_BOUNDS_0 __launch_bounds__(256, 4) // default launch bounds that should give good occupancy and versatility across all architectures.
#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))))
#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm))))
#define C10_LAUNCH_BOUNDS_X(x,max_threads_per_block,min_blocks_per_sm,FUNC, ...) FUNC
#define C10_LAUNCH_BOUNDS(...) C10_LAUNCH_BOUNDS_X(,##__VA_ARGS__, C10_LAUNCH_BOUNDS_2(__VA_ARGS__), C10_LAUNCH_BOUNDS_1(__VA_ARGS__), C10_LAUNCH_BOUNDS_0(__VA_ARGS__))
#else
#define C10_HOST_DEVICE
#define C10_HOST
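
With the variadic wrapper gone, call sites spell out the arity themselves. A sketch of how the surviving explicit forms behave, assuming the limits behind the clamp helpers shown in this hunk are CUDA_MAX_THREADS_PER_BLOCK = 1024 and CUDA_MAX_THREADS_PER_SM = 2048 (typical values; not visible in the excerpt above):

```
// Assumed limits: CUDA_MAX_THREADS_PER_BLOCK = 1024, CUDA_MAX_THREADS_PER_SM = 2048.
// With those values the explicit macros effectively expand to:
//
//   C10_LAUNCH_BOUNDS_1(1024)    -> __launch_bounds__(1024)     // within limits, no clamping
//   C10_LAUNCH_BOUNDS_2(512, 4)  -> __launch_bounds__(512, 4)   // 512 * 4 = 2048 <= 2048
//   C10_LAUNCH_BOUNDS_2(1024, 4) -> __launch_bounds__(1024, 2)  // 4096 > 2048, clamped to (2048 + 1023) / 1024 = 2
//
// Former C10_LAUNCH_BOUNDS(x) call sites now write C10_LAUNCH_BOUNDS_1(x) and
// C10_LAUNCH_BOUNDS(x, y) sites write C10_LAUNCH_BOUNDS_2(x, y), as in the hunks above.
#include <c10/macros/Macros.h>

template <typename scalar_t>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void example_copy_kernel(const scalar_t* in, scalar_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}
```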