Back out "Add _int_mm to expose cuBLAS int8@int8 -> int32 matmul (pyt…
Browse files Browse the repository at this point in the history
…orch#94339)" (pytorch#96885)

Summary:
Backing out _int_mm to expose cuBLAS int8@int8 -> int32 matmul (pytorch#94339)

Test Plan: CI

Pull Request resolved: pytorch#96885
Approved by: https://github.com/drisspg
cpuhrsch authored and pytorchmergebot committed Mar 16, 2023
1 parent 06054d7 commit 0a53c96
Showing 14 changed files with 73 additions and 440 deletions.
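
For context, the backed-out change (pytorch#94339) had generalized gemm_and_bias so that int8 inputs could produce an int32 result through cuBLASLt. A minimal sketch of the call shape it enabled, reconstructed from the int8 template instantiation deleted in CUDABlas.cpp below, looks roughly like this (m, n, k, the device pointers, and the use_heuristic value are illustrative, not taken from the original PR):

// Sketch only: reconstructed from the deleted int8 instantiation of
// gemm_and_bias; m/n/k and the device pointers are placeholders.
// int8 A (m x k) @ int8 B (k x n) -> int32 C (m x n), no bias, no activation.
at::cuda::blas::gemm_and_bias<int8_t, int32_t, std::nullptr_t>(
    /*transpose_mat1=*/false,
    /*transpose_mat2=*/false,
    m, n, k,
    /*alpha_val=*/1,
    mat1_int8_ptr, /*mat1_ld=*/m,
    mat2_int8_ptr, /*mat2_ld=*/k,
    /*bias=*/nullptr,                // the int8 path carried no bias (BDtype = std::nullptr_t)
    result_int32_ptr, /*result_ld=*/m,
    at::cuda::blas::GEMMAndBiasActivationEpilogue::NONE,
    /*use_heuristic=*/true);         // defaulted to true; the value _int_mm actually passed
                                     // is not shown in this excerpt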
164 changes: 53 additions & 111 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -618,7 +618,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace

template <typename Dtype, typename RDtype, typename BDtype>
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
@@ -630,11 +630,12 @@ void gemm_and_bias(
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const BDtype* bias,
RDtype* result_ptr,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic) {
GEMMAndBiasActivationEpilogue activation) {
using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = 0; // bias is added in epilogue

cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
@@ -653,19 +654,6 @@ void gemm_and_bias(
} else if (std::is_same<Dtype, at::BFloat16>::value) {
abcType = CUDA_R_16BF;
}
cudaDataType_t abType = abcType;
cudaDataType_t cType = abcType;
if (std::is_same<Dtype, int8_t>::value) {
abType = CUDA_R_8I;
cType = CUDA_R_32I;
computeType = CUBLAS_COMPUTE_32I;
scaleType = CUDA_R_32I;
bool valid_rdtype = std::is_same<RDtype, int32_t>::value;
TORCH_CHECK(valid_rdtype, "Expected int32_t for result Tensor if given int8_t mat1, mat2.");
} else {
bool valid_rdtype = std::is_same<RDtype, Dtype>::value;
TORCH_CHECK(valid_rdtype, "Expected result and input dtypes to match.");
}

CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -680,87 +668,64 @@ void gemm_and_bias(
CUBLASLT_MATMUL_DESC_TRANSB,
&transb,
sizeof(transb)));

cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
if (activation == GEMMAndBiasActivationEpilogue::BIAS) {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
if (activation == GEMMAndBiasActivationEpilogue::BIAS_RELU) {
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
}
if (activation == GEMMAndBiasActivationEpilogue::BIAS_GELU) {
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#else
TORCH_CHECK(false, "CUBLASLT_EPILOGUE_GELU_BIAS is an unsupported feature for CUDA version ", CUDA_VERSION);
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
}
if (activation == GEMMAndBiasActivationEpilogue::NONE) {
TORCH_CHECK(bias == nullptr, "Expected bias to be a nullptr.");
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias,
sizeof(Dtype*)));
}
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
sizeof(epilogue)));
TORCH_CUDABLAS_CHECK(cublasLtMatmulDescSetAttribute(
computeDesc.descriptor(),
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias,
sizeof(Dtype*)));

CuBlasLtMatrixLayout Adesc(
abType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
abcType, transpose_mat1 ? k : m, transpose_mat1 ? m : k, mat1_ld);
CuBlasLtMatrixLayout Bdesc(
abType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
abcType, transpose_mat2 ? n : k, transpose_mat2 ? k : n, mat2_ld);
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

CuBlasLtMatmulPreference preference;
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = 1024 * 1024;
void* workspace_data_ptr;
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
preference.descriptor(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize,
sizeof(workspaceSize)));

if (std::is_same<Dtype, int8_t>::value) {
workspaceSize = 0;
}
if (workspaceSize > 0) {
TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(
preference.descriptor(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize,
sizeof(workspaceSize)));

auto workspace = at::empty(
{static_cast<int64_t>(workspaceSize)},
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
workspace_data_ptr = workspace.data_ptr();
}
auto workspace = at::empty(
{static_cast<int64_t>(workspaceSize)},
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));

cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
cublasLtHandle_t ltHandle =
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
if (use_heuristic) {
int returnedResult = 0;
auto heuristic_return_value = cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult);
TORCH_CUDABLAS_CHECK(heuristic_return_value);
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
if (returnedResult == 0) {
TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
}

std::conditional_t<std::is_same<BDtype, std::nullptr_t>::value, float, at::opmath_type<Dtype>> beta_val = 0;
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
@@ -774,8 +739,8 @@ void gemm_and_bias(
Cdesc.descriptor(),
result_ptr,
Cdesc.descriptor(),
use_heuristic ? &heuristicResult.algo : nullptr,
workspaceSize > 0 ? workspace_data_ptr : nullptr,
&heuristicResult.algo,
workspace.data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
@@ -798,10 +763,8 @@ void gemm_and_bias(
mat2_ld,
" result_ld ",
result_ld,
" abType ",
abType,
" cType ",
cType,
" abcType ",
abcType,
" computeType ",
computeType,
" scaleType ",
@@ -822,8 +785,7 @@ template void gemm_and_bias(
const double* bias,
double* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic);
GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
bool transpose_mat1,
@@ -839,8 +801,7 @@ template void gemm_and_bias(
const float* bias,
float* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic);
GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
bool transpose_mat1,
@@ -856,8 +817,7 @@ template void gemm_and_bias(
const at::Half* bias,
at::Half* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic);
GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
bool transpose_mat1,
@@ -873,25 +833,7 @@ template void gemm_and_bias(
const at::BFloat16* bias,
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic);

template void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
int64_t m,
int64_t n,
int64_t k,
at::opmath_type<int8_t> alpha_val,
const int8_t* mat1_ptr,
int64_t mat1_ld,
const int8_t* mat2_ptr,
int64_t mat2_ld,
const std::nullptr_t* bias,
int32_t* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation,
bool use_heuristic);
GEMMAndBiasActivationEpilogue activation);
#endif // !defined(USE_ROCM) && !defined(_MSC_VER)

template <>
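
Since the old and new lines are interleaved in the excerpt above, the restored epilogue selection is easier to read condensed into one place. The helper name below is purely illustrative (the real logic sits inline in gemm_and_bias); the behavior it shows is taken from the re-added lines: bias is always fused, RELU and GELU are layered on top of it, and GELU quietly falls back to plain bias when built against CUDA older than 11.4, matching the NOTE in CUDABlas.h further down.

// Condensed sketch of the restored epilogue selection; pick_epilogue is a
// hypothetical name, the actual code lives inline in gemm_and_bias.
static cublasLtEpilogue_t pick_epilogue(GEMMAndBiasActivationEpilogue activation) {
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;      // bias is always fused
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040
    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif                                                       // pre-11.4: silently keep plain bias
  }
  return epilogue;
}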
16 changes: 7 additions & 9 deletions aten/src/ATen/cuda/CUDABlas.h
@@ -70,15 +70,14 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

#if !defined(USE_ROCM) && !defined(_MSC_VER)
enum GEMMAndBiasActivationEpilogue {
NONE,
BIAS,
BIAS_RELU,
BIAS_GELU,
None,
RELU,
GELU,
};

// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
// do nothing if passed in that case.
template <typename Dtype, typename RDtype, typename BDtype>
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
bool transpose_mat2,
@@ -90,11 +89,10 @@ void gemm_and_bias(
int64_t mat1_ld,
const Dtype* mat2_ptr,
int64_t mat2_ld,
const BDtype* bias,
RDtype* result_ptr,
const Dtype* bias,
Dtype* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::BIAS,
bool use_heuristic = true);
GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
#endif

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) \

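A minimal usage sketch against the restored declaration above: bias and result now share the input dtype, the cuBLASLt heuristic is always consulted, and the activation argument defaults to None. M, N, K, alpha, and the device pointers are placeholders, not code from this commit.

// Sketch only: fused float GEMM + bias + ReLU through the restored interface.
float alpha = 1.0f;
at::cuda::blas::gemm_and_bias<float>(
    /*transpose_mat1=*/false,
    /*transpose_mat2=*/false,
    M, N, K,
    alpha,
    d_mat1, /*mat1_ld=*/M,   // column-major M x K
    d_mat2, /*mat2_ld=*/K,   // column-major K x N
    d_bias,                  // const float*, length M, added in the epilogue
    d_result, /*result_ld=*/M,
    at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU);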