Skip to content

Commit

Permalink
simplify code to help MSVC 19.10 and lower
Browse files Browse the repository at this point in the history
  • Loading branch information
YashasSamaga committed Dec 30, 2019
1 parent 7b12cbd commit 48eecaf
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
42 changes: 23 additions & 19 deletions modules/dnn/src/cuda/grid_stride_range.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,29 @@
namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace device {

namespace detail {
template <int> __device__ auto getGridDim()->decltype(dim3::x);
template <> inline __device__ auto getGridDim<0>()->decltype(dim3::x) { return gridDim.x; }
template <> inline __device__ auto getGridDim<1>()->decltype(dim3::x) { return gridDim.y; }
template <> inline __device__ auto getGridDim<2>()->decltype(dim3::x) { return gridDim.z; }

template <int> __device__ auto getBlockDim()->decltype(dim3::x);
template <> inline __device__ auto getBlockDim<0>()->decltype(dim3::x) { return blockDim.x; }
template <> inline __device__ auto getBlockDim<1>()->decltype(dim3::x) { return blockDim.y; }
template <> inline __device__ auto getBlockDim<2>()->decltype(dim3::x) { return blockDim.z; }

template <int> __device__ auto getBlockIdx()->decltype(uint3::x);
template <> inline __device__ auto getBlockIdx<0>()->decltype(uint3::x) { return blockIdx.x; }
template <> inline __device__ auto getBlockIdx<1>()->decltype(uint3::x) { return blockIdx.y; }
template <> inline __device__ auto getBlockIdx<2>()->decltype(uint3::x) { return blockIdx.z; }

template <int> __device__ auto getThreadIdx()->decltype(uint3::x);
template <> inline __device__ auto getThreadIdx<0>()->decltype(uint3::x) { return threadIdx.x; }
template <> inline __device__ auto getThreadIdx<1>()->decltype(uint3::x) { return threadIdx.y; }
template <> inline __device__ auto getThreadIdx<2>()->decltype(uint3::x) { return threadIdx.z; }
using dim3_member_type = decltype(dim3::x);

template <int> __device__ dim3_member_type getGridDim();
template <> inline __device__ dim3_member_type getGridDim<0>() { return gridDim.x; }
template <> inline __device__ dim3_member_type getGridDim<1>() { return gridDim.y; }
template <> inline __device__ dim3_member_type getGridDim<2>() { return gridDim.z; }

template <int> __device__ dim3_member_type getBlockDim();
template <> inline __device__ dim3_member_type getBlockDim<0>() { return blockDim.x; }
template <> inline __device__ dim3_member_type getBlockDim<1>() { return blockDim.y; }
template <> inline __device__ dim3_member_type getBlockDim<2>() { return blockDim.z; }

using uint3_member_type = decltype(uint3::x);

template <int> __device__ uint3_member_type getBlockIdx();
template <> inline __device__ uint3_member_type getBlockIdx<0>() { return blockIdx.x; }
template <> inline __device__ uint3_member_type getBlockIdx<1>() { return blockIdx.y; }
template <> inline __device__ uint3_member_type getBlockIdx<2>() { return blockIdx.z; }

template <int> __device__ uint3_member_type getThreadIdx();
template <> inline __device__ uint3_member_type getThreadIdx<0>() { return threadIdx.x; }
template <> inline __device__ uint3_member_type getThreadIdx<1>() { return threadIdx.y; }
template <> inline __device__ uint3_member_type getThreadIdx<2>() { return threadIdx.z; }
}

template <int dim, class index_type = device::index_type, class size_type = device::size_type>
Expand Down
7 changes: 4 additions & 3 deletions modules/dnn/src/cuda4dnn/csl/cudnn/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
}

/** get_data_type<T> returns the equivalent cudnn enumeration constant for type T */
template <class> auto get_data_type()->decltype(CUDNN_DATA_FLOAT);
template <> inline auto get_data_type<half>()->decltype(CUDNN_DATA_HALF) { return CUDNN_DATA_HALF; }
template <> inline auto get_data_type<float>()->decltype(CUDNN_DATA_FLOAT) { return CUDNN_DATA_FLOAT; }
using cudnn_data_enum_type = decltype(CUDNN_DATA_FLOAT);
template <class> cudnn_data_enum_type get_data_type();
template <> inline cudnn_data_enum_type get_data_type<half>() { return CUDNN_DATA_HALF; }
template <> inline cudnn_data_enum_type get_data_type<float>() { return CUDNN_DATA_FLOAT; }
}

/** @brief noncopyable cuDNN smart handle
Expand Down

0 comments on commit 48eecaf

Please sign in to comment.