Skip to content

Commit

Permalink
Make cudnn work
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Sep 10, 2015
1 parent 3053f8c commit 208a198
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
6 changes: 2 additions & 4 deletions mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,8 @@ extern "C" {
#include <curand.h>
#endif

#if MSHADOW_USE_CUDNN
#ifdef __CUDACC__
#include <cudnn.h>
#endif
#if MSHADOW_USE_CUDNN == 1
#include <cudnn.h>
#endif

#if MSHADOW_USE_NVML
Expand Down
22 changes: 15 additions & 7 deletions mshadow/stream_gpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ struct Stream<gpu> {
cudaStream_t stream_;
/*! \brief cublas handle */
cublasHandle_t blas_handle_;
/*! \brief cudnn handle */
#if MSHADOW_USE_CUDNN == 1
cudnnHandle_t dnn_handle_;
#endif
/*! \brief cublas handle ownership */
HandleState blas_handle_ownership_;
/*! \brief cudnn handle ownership */
HandleState dnn_handle_ownership_;
#if MSHADOW_USE_CUDNN == 1
/*! \brief cudnn handle */
cudnnHandle_t dnn_handle_;
#endif

Stream(void) : stream_(0),
blas_handle_ownership_(NoHandle),
dnn_handle_ownership_(NoHandle) {}
Expand Down Expand Up @@ -97,7 +98,8 @@ struct Stream<gpu> {
blas_handle_ownership_ = OwnHandle;
utils::Check(err == CUBLAS_STATUS_SUCCESS, "Create cublas handle failed");
}
#if MSHADOW_USE_CUDNN && defined(__CUDACC__)
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
inline static cudnnHandle_t GetDnnHandle(Stream<gpu> *stream) {
if (stream == NULL) {
return 0;
Expand All @@ -109,7 +111,8 @@ struct Stream<gpu> {
}
#endif
inline void DestroyDnnHandle() {
#if MSHADOW_USE_CUDNN && defined(__CUDACC__)
// #if MSHADOW_USE_CUDNN && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
if (dnn_handle_ownership_ == OwnHandle) {
cudnnStatus_t err = cudnnDestroy(dnn_handle_);
utils::Check(err == CUDNN_STATUS_SUCCESS,
Expand All @@ -118,11 +121,16 @@ struct Stream<gpu> {
#endif
}
inline void CreateDnnHandle() {
#if MSHADOW_USE_CUDNN && defined(__CUDACC__)
// #if MSHADOW_USE_CUDNN == 1 && defined(__CUDACC__)
#if MSHADOW_USE_CUDNN == 1
this->DestroyDnnHandle();
cudnnStatus_t err = cudnnCreate(&dnn_handle_);
utils::Check(err == CUDNN_STATUS_SUCCESS,
"Create cudnn handle failed");
err = cudnnSetStream(dnn_handle_, stream_);
utils::Check(err == CUDNN_STATUS_SUCCESS,
"Set cudnn handle stream failed");
this->dnn_handle_ownership_ = OwnHandle;
#endif
}
};
Expand Down

0 comments on commit 208a198

Please sign in to comment.