Skip to content

Commit

Permalink
[Bugfix] Wrap cub with CUB_NS_PREFIX and remove dependency on Thrust …
Browse files Browse the repository at this point in the history
…to linking issues with Torch 1.8 (dmlc#2758)

* Wrap cub with prefixes and remove thrust

* Using counting iterator

Co-authored-by: Zihao Ye <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
3 people authored Mar 22, 2021
1 parent 9aac93f commit 0ff7127
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/array/cuda/array_cumsum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
* \brief Array cumsum GPU implementation
*/
#include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {
using runtime::NDArray;
Expand Down
73 changes: 49 additions & 24 deletions src/array/cuda/array_nonzero.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,71 @@
* \file array/cpu/array_nonzero.cc
* \brief Array nonzero CPU implementation
*/
#include <thrust/iterator/counting_iterator.h>
#include <thrust/copy.h>
#include <thrust/functional.h>
#include <thrust/device_vector.h>

#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {

template <typename IdType>
struct IsNonZero {
__device__ bool operator() (const IdType val) {
return val != 0;
struct IsNonZeroIndex {
explicit IsNonZeroIndex(const IdType * array) : array_(array) {
}

__device__ bool operator() (const int64_t index) {
return array_[index] != 0;
}

const IdType * array_;
};

template <DLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);

const int64_t len = array->shape[0];
IdArray ret = NewIdArray(len, array->ctx, 64);
thrust::device_ptr<IdType> in_data(array.Ptr<IdType>());
thrust::device_ptr<int64_t> out_data(ret.Ptr<int64_t>());
// TODO(minjie): should take control of the memory allocator.
// See PyTorch's implementation here:
// https://github.com/pytorch/pytorch/blob/1f7557d173c8e9066ed9542ada8f4a09314a7e17/
// aten/src/THC/generic/THCTensorMath.cu#L104
auto startiter = thrust::make_counting_iterator<int64_t>(0);
auto enditer = startiter + len;
auto indices_end = thrust::copy_if(thrust::cuda::par.on(thr_entry->stream),
startiter,
enditer,
in_data,
out_data,
IsNonZero<IdType>());
const int64_t num_nonzeros = indices_end - out_data;
IdArray ret = NewIdArray(len, ctx, 64);

cudaStream_t stream = 0;

const IdType * const in_data = static_cast<const IdType*>(array->data);
int64_t * const out_data = static_cast<int64_t*>(ret->data);

IsNonZeroIndex<IdType> comp(in_data);
cub::CountingInputIterator<int64_t> counter(0);

// room for cub to output on GPU
int64_t * d_num_nonzeros = static_cast<int64_t*>(
device->AllocWorkspace(ctx, sizeof(int64_t)));

size_t temp_size = 0;
cub::DeviceSelect::If(nullptr, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
void * temp = device->AllocWorkspace(ctx, temp_size);
cub::DeviceSelect::If(temp, temp_size, counter, out_data,
d_num_nonzeros, len, comp, stream);
device->FreeWorkspace(ctx, temp);

// copy number of selected elements from GPU to CPU
int64_t num_nonzeros;
device->CopyDataFromTo(
d_num_nonzeros, 0,
&num_nonzeros, 0,
sizeof(num_nonzeros),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, 64, 1},
stream);
device->FreeWorkspace(ctx, d_num_nonzeros);
device->StreamSync(ctx, stream);

// truncate array to size
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
}

Expand Down
2 changes: 1 addition & 1 deletion src/array/cuda/array_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
* \brief Array sort GPU implementation
*/
#include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {
using runtime::NDArray;
Expand Down
2 changes: 1 addition & 1 deletion src/array/cuda/csr_sort.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
* \brief Sort CSR index
*/
#include <dgl/array.h>
#include <cub/cub.cuh>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./dgl_cub.cuh"

namespace dgl {

Expand Down
17 changes: 17 additions & 0 deletions src/array/cuda/dgl_cub.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*!
* Copyright (c) 2021 by Contributors
* \file cuda_common.h
* \brief Wrapper to place cub in dgl namespace.
*/

#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_
#define DGL_ARRAY_CUDA_DGL_CUB_CUH_

// include cub in a safe manner
#define CUB_NS_PREFIX namespace dgl {
#define CUB_NS_POSTFIX }
#include "cub/cub.cuh"
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX

#endif
2 changes: 1 addition & 1 deletion src/array/cuda/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

#include "./utils.h"
#include <cub/cub.cuh>
#include "./dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/cuda/cuda_hashtable.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
* \brief Device level functions for within cuda kernels.
*/

#include <cub/cub.cuh>
#include <cassert>

#include "cuda_hashtable.cuh"
#include "../../kernel/cuda/atomic.cuh"
#include "../../array/cuda/dgl_cub.cuh"

using namespace dgl::kernel::cuda;

Expand Down

0 comments on commit 0ff7127

Please sign in to comment.