Skip to content

Commit

Permalink
Merge pull request dmlc#103 from piiswrong/master
Browse files Browse the repository at this point in the history
FP16 compatibility
  • Loading branch information
tqchen committed Mar 20, 2016
2 parents 62ca0b7 + db0fb96 commit bbee465
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 13 deletions.
4 changes: 4 additions & 0 deletions make/mshadow.mk
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ MSHADOW_NVCCFLAGS += --std=c++11
else
MSHADOW_CFLAGS+= -DMSHADOW_DIST_PS=0
endif

# Set MSDHADOW_USE_PASCAL to one to enable nvidia pascal gpu features.
# Like cublasHgemm
MSHADOW_CFLAGS += -DMSDHADOW_USE_PASCAL=0
20 changes: 20 additions & 0 deletions mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,26 @@ struct DataType;
template<>
struct DataType<float> {
static const int kFlag = kFloat32;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_FLOAT;
typedef float ScaleType;
#endif
};
template<>
struct DataType<double> {
static const int kFlag = kFloat64;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_DOUBLE;
typedef double ScaleType;
#endif
};
template<>
struct DataType<half::half_t> {
static const int kFlag = kFloat16;
#if (MSHADOW_USE_CUDA && MSHADOW_USE_CUDNN == 1)
static const cudnnDataType_t kCudnnFlag = CUDNN_DATA_HALF;
typedef float ScaleType;
#endif
};
template<>
struct DataType<uint8_t> {
Expand Down Expand Up @@ -458,7 +470,11 @@ struct maximum {
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
using namespace std;
#ifdef __CUDACC__
dst = ::max(dst, src);
#else
dst = max(dst, src);
#endif // __CUDACC__
}
/*!
* \brief calculate gradient of redres with respect to redsrc,
Expand All @@ -482,7 +498,11 @@ struct minimum {
template<typename DType>
MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
using namespace std;
#ifdef __CUDACC__
dst = ::min(dst, src);
#else
dst = min(dst, src);
#endif // __CUDACC__
}
/*!
* \brief calculate gradient of redres with respect to redsrc,
Expand Down
53 changes: 53 additions & 0 deletions mshadow/dot_engine-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,59 @@ struct BLASEngine<cpu, double> {
// CuBLAS redirect code
#if MSHADOW_USE_CUDA
// All CuBLAS goes to here, use legacy API: not threadsafe
template<>
struct BLASEngine<gpu, half::half_t> {
inline static cublasOperation_t GetT(bool t) {
return t ? CUBLAS_OP_T : CUBLAS_OP_N;
}
inline static void SetStream(Stream<gpu> *stream) {
cublasStatus_t err = cublasSetStream(Stream<gpu>::GetBlasHandle(stream),
Stream<gpu>::GetStream(stream));
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas set stream fail";
}
inline static void gemm(Stream<gpu> *stream,
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *B, int ldb, half::half_t beta,
half::half_t *C, int ldc) {
#if MSHADOW_USE_PASCAL == 1
cublasStatus_t err = cublasHgemm(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
A, lda, B, ldb, &beta, C, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas Hgemm fail";
#else
float alpha_f = float(alpha); // NOLINT(*)
float beta_f = float(beta); // NOLINT(*)
cublasStatus_t err = cublasSgemmEx(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha_f,
A, CUBLAS_DATA_HALF, lda, B, CUBLAS_DATA_HALF,
ldb, &beta_f, C, CUBLAS_DATA_HALF, ldc);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas SgemmEx fail";
#endif // MSHADOW_USE_PASCAL == 1
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, half::half_t alpha,
const half::half_t *A, int lda,
const half::half_t *X, int incX, half::half_t beta,
half::half_t *Y, int incY) {
LOG(FATAL) << "Not implmented!";
}
inline static void ger(Stream<gpu> *stream,
int m, int n, half::half_t alpha,
const half::half_t *X, int incX,
const half::half_t *Y, int incY, half::half_t *A, int lda) {
LOG(FATAL) << "Not implmented!";
}
inline static void dot(Stream<gpu> *stream,
int n,
const half::half_t* X, int incX,
const half::half_t* Y, int incY,
half::half_t *ret) {
LOG(FATAL) << "Not implmented!";
}
};

template<>
struct BLASEngine<gpu, float> {
inline static cublasOperation_t GetT(bool t) {
Expand Down
4 changes: 2 additions & 2 deletions mshadow/extension/reduceto1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ inline ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>
sumall_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, red::sum,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), 1);
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
* \brief reduce over all dimensions, except dimkeep
Expand All @@ -58,7 +58,7 @@ inline ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>
reduce_except_dim(const Exp<SrcExp, DType, etype> &exp) {
return ReduceTo1DExp<SrcExp, DType, Reducer,
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), 1);
ExpInfo<SrcExp>::kDim - dimkeep>(exp.self(), DType(1));
}
/*!
* \brief a expression that sum over rows of a matrix
Expand Down
2 changes: 1 addition & 1 deletion mshadow/extension/unpack_patch2col.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ struct Plan<UnpackPatchToColXExp<SrcExp, DType, srcdim>, DType> {
if (x < i_width_ && y < i_height_) {
return src_.Eval((n * i_channel_ + c) * i_height_ + y, x);
} else {
return 0.0f;
return DType(0.0f);
}
}

Expand Down
76 changes: 69 additions & 7 deletions mshadow/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,29 @@ namespace half {

#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
template<typename T> \
MSHADOW_XINLINE half_t operator AOP (T a) { \
MSHADOW_XINLINE half_t operator AOP (const T& a) { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
} \
template<typename T> \
MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \
return *this = half_t(float(*this) OP float(a)); /* NOLINT(*)*/ \
}

#if MSHADOW_CUDA_HALF
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(__half2float(cuhalf_)); /* NOLINT(*)*/ \
}
#else
#define MSHADOW_HALF_CONVERSIONOP(T) \
MSHADOW_XINLINE operator T() const { \
return T(half2float(half_)); /* NOLINT(*)*/ \
} \
MSHADOW_XINLINE operator T() const volatile { \
return T(half2float(half_)); /* NOLINT(*)*/ \
}
#endif // MSHADOW_CUDA_HALF

Expand All @@ -67,15 +77,12 @@ class half_t {
}
#endif // MSHADOW_CUDA_HALF

MSHADOW_XINLINE explicit half_t(const float& value) { constructor(value); }
MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }

MSHADOW_HALF_CONVERSIONOP(float)
MSHADOW_HALF_CONVERSIONOP(double)
MSHADOW_HALF_CONVERSIONOP(uint8_t)
MSHADOW_HALF_CONVERSIONOP(int32_t)

MSHADOW_HALF_ASSIGNOP(+=, +)
MSHADOW_HALF_ASSIGNOP(-=, -)
Expand All @@ -90,6 +97,26 @@ class half_t {
return half_t(-float(*this)); // NOLINT(*)
}

MSHADOW_XINLINE half_t operator=(const half_t& a) {
half_ = a.half_;
return a;
}

template<typename T>
MSHADOW_XINLINE half_t operator=(const T& a) {
return *this = half_t(a); /* NOLINT(*)*/
}

MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
half_ = a.half_;
return a;
}

template<typename T>
MSHADOW_XINLINE half_t operator=(const T& a) volatile {
return *this = half_t(a); /* NOLINT(*)*/
}

private:
union Bits {
float f;
Expand Down Expand Up @@ -127,7 +154,24 @@ class half_t {
#endif // MSHADOW_CUDA_HALF
};

MSHADOW_XINLINE uint16_t float2half(const float value) const {
MSHADOW_XINLINE uint16_t float2half(const float& value) const {
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
return v.ui | sign;
}

MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile { // NOLINT (*)
Bits v, s;
v.f = value;
uint32_t sign = v.si & signN;
Expand All @@ -144,7 +188,25 @@ class half_t {
return v.ui | sign;
}

MSHADOW_XINLINE float half2float(const uint16_t value) const {
MSHADOW_XINLINE float half2float(const uint16_t& value) const {
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
v.si ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
Bits s;
s.si = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
return v.f;
}

MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile { // NOLINT(*)
Bits v;
v.ui = value;
int32_t sign = v.si & signC;
Expand Down
6 changes: 3 additions & 3 deletions mshadow/tensor_cpu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ inline void Softmax(Tensor<cpu, 1, DType> dst,
for (index_t x = 1; x < dst.size(0); ++x) {
if (mmax < energy[x]) mmax = energy[x];
}
DType sum = 0.0f;
DType sum = DType(0.0f);
for (index_t x = 0; x < dst.size(0); ++x) {
dst[x] = std::exp(energy[x] - mmax);
sum += dst[x];
Expand Down Expand Up @@ -314,7 +314,7 @@ inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
const index_t k = static_cast<int>(label[y][n]);
if (k == static_cast<int>(ignore_label)) {
for (index_t x = 0; x < dst.size(1); ++x) {
dst[y][x][n] = 0.0f;
dst[y][x][n] = DType(0.0f);
}
} else {
for (index_t x = 0; x < dst.size(1); ++x) {
Expand Down Expand Up @@ -348,7 +348,7 @@ inline void Softmax(Tensor<cpu, 3, DType> dst,
for (index_t x = 1; x < dst.size(1); ++x) {
if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
}
DType sum = 0.0f;
DType sum = DType(0.0f);
for (index_t x = 0; x < dst.size(1); ++x) {
dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
sum += dst[y][x][n];
Expand Down

0 comments on commit bbee465

Please sign in to comment.