Workaround for cublas bug for 45724 (pytorch#46001)
Summary:
Fixes pytorch#45724

Pull Request resolved: pytorch#46001

Reviewed By: mruberry

Differential Revision: D24184058

Pulled By: ngimel

fbshipit-source-id: 7d2bab3206ddbc10a7cae3efd9b5e253f38400a9
zasdfgbnm authored and facebook-github-bot committed Oct 8, 2020
1 parent 8d14b50 commit b2bff9e
Showing 2 changed files with 62 additions and 2 deletions.
54 changes: 52 additions & 2 deletions aten/src/THC/THCBlas.cu
@@ -133,6 +133,56 @@ void THCudaBlas_Dgemm(THCState *state, char transa, char transb, int64_t m, int6
  at::cuda::blas::gemm<double>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

#ifndef __HIP_PLATFORM_HCC__
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11200
#define cublasGemmStridedBatchedExFix cublasGemmStridedBatchedEx
#else
// Workaround for https://github.com/pytorch/pytorch/issues/45724
cublasStatus_t cublasGemmStridedBatchedExFix(cublasHandle_t &handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const void *alpha,
    const void *A,
    cudaDataType Atype,
    int lda,
    long long int strideA,
    const void *B,
    cudaDataType Btype,
    int ldb,
    long long int strideB,
    const void *beta,
    void *C,
    cudaDataType Ctype,
    int ldc,
    long long int strideC,
    int64_t batchCount,
    cudaDataType computeType,
    cublasGemmAlgo_t algo)
{
  // Only compute capability 7.x devices need the workaround; on other
  // architectures, forward the call unchanged.
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major != 7) {
    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, strideA, B, Btype, ldb, strideB, beta, C, Ctype, ldc, strideC, batchCount, computeType, algo);
  }
  cublasStatus_t result;
  // Issue the GEMM in slices of at most 63 * 1024 batches so that no single
  // cuBLAS call sees the very large batch counts that produce wrong results
  // (see the linked issue above).
  constexpr int64_t split = 63 * 1024;
  for (int64_t i = 0; i < batchCount; i += split) {
    int64_t count = std::min<int64_t>(split, batchCount - i);
    // Advance A/B/C by i batches. The strides are in elements and this helper
    // is only called for 2-byte types (FP16/BF16) below, hence the factor of 2
    // in the byte offsets applied to the char* pointers.
    result = cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, alpha,
        (char *)A + i * strideA * 2, Atype, lda, strideA,
        (char *)B + i * strideB * 2, Btype, ldb, strideB,
        beta,
        (char *)C + i * strideC * 2, Ctype, ldc, strideC,
        (int)count, computeType, algo);
    THCublasCheck(result);
  }
  return result;
}
#endif
#endif

void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
at::Half alpha, const at::Half *a, int64_t lda, int64_t strideA, const at::Half *b, int64_t ldb, int64_t strideB,
at::Half beta, at::Half *c, int64_t ldc, int64_t strideC, int64_t batchCount)
@@ -167,7 +217,7 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
-  THCublasCheck(cublasGemmStridedBatchedEx(handle,
+  THCublasCheck(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB,
@@ -207,7 +257,7 @@ void THCudaBlas_BgemmStridedBatched(THCState *state, char transa, char transb, i
if (prop->major < 8) {
TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU");
}
-  THCublasCheck(cublasGemmStridedBatchedEx(handle,
+  THCublasCheck(cublasGemmStridedBatchedExFix(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16BF, (int)lda, strideA,
b, CUDA_R_16BF, (int)ldb, strideB,
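The heart of the change is the chunking loop in cublasGemmStridedBatchedExFix above: on compute capability 7.x devices the strided-batched GEMM is issued in slices of at most 63 * 1024 batches instead of a single call. The sketch below illustrates the same splitting idea at the Python level; it is not the code path the fix takes, and the names chunked_bmm and CHUNK are invented for the example. On builds without this fix it could also serve as a user-side mitigation.

import torch

CHUNK = 63 * 1024  # mirrors the `split` constant used in THCBlas.cu above

def chunked_bmm(a, b):
    # Batched matmul computed in sub-batches of at most CHUNK along dim 0, so
    # that no single underlying cuBLAS call sees a huge batch count.
    out = torch.empty(a.shape[0], a.shape[1], b.shape[2], dtype=a.dtype, device=a.device)
    for start in range(0, a.shape[0], CHUNK):
        end = min(start + CHUNK, a.shape[0])
        torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out

a = torch.rand(65537, 22, 64, device='cuda', dtype=torch.half)
b = torch.rand(65537, 64, 22, device='cuda', dtype=torch.half)
c = chunked_bmm(a, b)  # same result as torch.bmm(a, b), computed in slices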
10 changes: 10 additions & 0 deletions test/test_torch.py
@@ -16814,6 +16814,16 @@ def test_addmm_sizes(self, device, dtype):
                    m2 = torch.randn(k, m, device=device).to(dtype)
                    self._test_addmm_addmv(torch.addmm, M, m1, m2)

    @onlyCUDA
    def test_matmul_45724(self, device):
        # https://github.com/pytorch/pytorch/issues/45724
        a = torch.rand(65537, 22, 64).cuda().half()
        b = torch.rand(65537, 64, 22).cuda().half()
        c = torch.full((65537, 22, 22), math.nan, dtype=torch.half, device='cuda')
        cpu_result = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()
        torch.matmul(a, b, out=c)
        self.assertEqual(c, cpu_result)

    def _test_dot_vdot_vs_numpy(self, device, dtype, torch_fn, np_fn):
        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
            y_np = y.cpu().numpy()
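For reference, the new regression test above can be reproduced outside the test suite. The snippet below is a standalone sketch of the same check; it assumes a CUDA build with at least one visible GPU, and the allclose tolerances are illustrative rather than the ones assertEqual applies.

import torch

a = torch.rand(65537, 22, 64, dtype=torch.half, device='cuda')
b = torch.rand(65537, 64, 22, dtype=torch.half, device='cuda')

# Float32 reference computed on the CPU, then cast back to half, exactly as
# the unit test does.
expected = torch.matmul(a.cpu().float(), b.cpu().float()).cuda().half()

c = torch.matmul(a, b)
# On builds affected by pytorch#45724 the CUDA result was wrong; with the
# workaround it matches the reference up to half-precision tolerance.
print(torch.allclose(c.float(), expected.float(), rtol=1e-2, atol=1e-2))

The test itself can be selected on its own, for example with pytest test/test_torch.py -k test_matmul_45724 in a checkout where pytest collection works.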
