Skip to content

Commit

Permalink
add volatile operator to half
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Mar 18, 2016
1 parent b0381d7 commit 6f63450
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 6 deletions.
4 changes: 4 additions & 0 deletions make/mshadow.mk
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ MSHADOW_NVCCFLAGS += --std=c++11
else
MSHADOW_CFLAGS+= -DMSHADOW_DIST_PS=0
endif

# Set MSDHADOW_USE_PASCAL to one to enable nvidia pascal gpu features.
# Like cublasHgemm
MSHADOW_CFLAGS += -DMSDHADOW_USE_PASCAL=0
52 changes: 52 additions & 0 deletions mshadow/dot_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,58 @@ struct BLASEngine<cpu, double> {
// CuBLAS redirect code
#if MSHADOW_USE_CUDA
// All CuBLAS goes to here, use legacy API: not threadsafe
template<>
struct BLASEngine<gpu, half::half_t> {
inline static cublasOperation_t GetT(bool t) {
return t ? CUBLAS_OP_T : CUBLAS_OP_N;
}
inline static void SetStream(Stream<gpu> *stream) {
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
Stream<gpu>::GetStream(stream));
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
}
inline static void gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *B, int ldb, half::half_t beta,
half::half_t *C, int ldc) {
#if MSHADOW_USE_PASCAL == 1
cublasStatus_t err = cublasHgemm(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
A, lda, B, ldb, &beta, C, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Hgemm fail";
#else
float alpha_f = float(alpha);
float beta_f = float(beta);
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha_f,
A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF, ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
#endif // MSHADOW_USE_PASCAL == 1
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *X, int incX, half::half_t beta,
half::half_t *Y, int incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<gpu> *stream,
int m, int n, half::half_t alpha,
const half::half_t *X, int incX,
const half::half_t *Y, int incY, half::half_t *A, int lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<gpu> *stream,
int n,
const half::half_t* X, int incX,
const half::half_t* Y, int incY,
half::half_t *ret) {
LOG(FATAL) << "Not implmented!";
}
};

template<>
struct BLASEngine<gpu, float> {
inline static cublasOperation_t GetT(bool t) {
Expand Down
4 changes: 2 additions & 2 deletions mshadow/extension/reduceto1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inline ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>
sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), 1);
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
* \brief reduce over all dimensions, except dimkeep
Expand All @@ -58,7 +58,7 @@ inline ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>
reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), 1);
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
* \brief a expression that sum over rows of a matrix
Expand Down
63 changes: 59 additions & 4 deletions mshadow/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,29 @@ namespace half {

#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
MSHADOW_XINLINE half_t operator AOP (T a) { \
MSHADOW_XINLINE half_t operator AOP (const T& a) { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
} \
template<typename T> \
MSHADOW_XINLINE volatile half_t operator AOP (const volatile T& a) volatile { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
}

#if MSHADOW_CUDA_HALF
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
}
#else
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(half2float(half_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(half2float(half_)); /* NOLINT(*)*/ \
}
#endif // MSHADOW_CUDA_HALF

Expand Down Expand Up @@ -91,7 +101,17 @@ class half_t {
}

template<typename T>
MSHADOW_XINLINE half_t operator=(T a) {
MSHADOW_XINLINE half_t operator=(const T& a) {
return *this = half_t(a); /* NOLINT(*)*/
}

template<typename T>
MSHADOW_XINLINE half_t operator=(const volatile T& a) volatile {
return *this = half_t(a); /* NOLINT(*)*/
}

template<typename T>
MSHADOW_XINLINE half_t operator=(const T& a) volatile {
return *this = half_t(a); /* NOLINT(*)*/
}

Expand Down Expand Up @@ -132,7 +152,24 @@ class half_t {
#endif // MSHADOW_CUDA_HALF
};

MSHADOW_XINLINE uint16_t float2half(const float value) const {
MSHADOW_XINLINE uint16_t float2half(const float& value) const {
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
}

MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile {
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
Expand All @@ -149,7 +186,25 @@ class half_t {
return v.ui | sign;
}

MSHADOW_XINLINE float half2float(const uint16_t value) const {
MSHADOW_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}

MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
Expand Down

0 comments on commit 6f63450

Please sign in to comment.