remove amax_ptr from scaled_gemm (#135421)
amax was removed from _scaled_mm by #128683. Remove it from the internal at::cuda::blas::scaled_gemm as well. This allows hipBLASLt to find additional solutions, rather than forcing an amax to be computed and then discarding the result.
Pull Request resolved: pytorch/pytorch#135421
Approved by: https://github.com/drisspg, https://github.com/eqy
jeffdaily authored and pytorchmergebot committed Sep 9, 2024
1 parent b4feec9 commit 39a6179
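For context on the op this touches: per the commit message, #128683 removed the amax output from _scaled_mm, so nothing downstream consumes an output amax any more, and this commit propagates that simplification into the cuBLAS wrapper. The sketch below is illustrative only and is not part of this commit; the scaled_mm_example helper, the tensor shapes, and the dtype choices are assumptions made for the example, and it presumes a CUDA build and an FP8-capable GPU.

// Illustrative sketch, not from this commit: the ATen op that sits above
// at::cuda::blas::scaled_gemm, after #128683 removed its amax output.
// Assumes a CUDA build and an FP8-capable GPU; names and shapes are invented.
#include <ATen/ATen.h>

at::Tensor scaled_mm_example() {
  auto opts = at::dtype(at::kFloat).device(at::kCUDA);
  // Row-major FP8 mat1 and column-major FP8 mat2, sizes divisible by 32.
  auto a = at::randn({32, 64}, opts).to(at::kFloat8_e4m3fn);
  auto b = at::randn({32, 64}, opts).to(at::kFloat8_e4m3fn).t();
  // Tensor-wise scales for the two inputs.
  auto scale_a = at::ones({}, opts);
  auto scale_b = at::ones({}, opts);
  // Single-tensor return: no amax is produced, so the backend no longer has
  // to request one from cuBLASLt/hipBLASLt.
  return at::_scaled_mm(a, b, scale_a, scale_b,
                        /*bias=*/{}, /*scale_result=*/{},
                        /*out_dtype=*/at::kBFloat16,
                        /*use_fast_accum=*/false);
}

With the amax output gone, nothing in this path needs at::cuda::blas::scaled_gemm to accept an amax_ptr, which is exactly what the diffs below remove.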
Showing 4 changed files with 4 additions and 25 deletions.
9 changes: 2 additions & 7 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -1408,7 +1408,6 @@ void scaled_gemm(
     const void *result_scale_ptr,
     int64_t result_ld,
     ScalarType result_dtype,
-    void* amax_ptr,
     bool use_fast_accum) {
 #if CUDA_VERSION >= 11080 || defined(USE_ROCM)
   const auto computeType = CUBLAS_COMPUTE_32F;
@@ -1421,13 +1420,9 @@ void scaled_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
-#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200)
-  // Amax support in ROCm as of 6.2
-  if (isFloat8Type(result_dtype)) {
-    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
+  if (result_scale_ptr != nullptr) {
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
   }
-#endif
 #ifndef USE_ROCM
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
 #endif
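For readers mapping the hunk above onto the cuBLASLt C API: the change means the D-scale attribute is set only when a result scale is actually provided, and no amax-of-D attribute is requested at all, which is what frees hipBLASLt's heuristic to consider solutions that skip the amax computation. A minimal sketch under those assumptions follows; the configure_output_scaling helper is invented for illustration, and PyTorch itself goes through the computeDesc descriptor wrapper shown above rather than calling the C API directly.

// Illustrative sketch, not repository code: the attribute handling after this
// commit, expressed against the raw cuBLASLt C API.
#include <cublasLt.h>

cublasStatus_t configure_output_scaling(cublasLtMatmulDesc_t desc,
                                        const void* result_scale_ptr) {
  // Attach the D (output) scale only when the caller supplied one.
  if (result_scale_ptr != nullptr) {
    return cublasLtMatmulDescSetAttribute(
        desc,
        CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
        &result_scale_ptr,
        sizeof(result_scale_ptr));
  }
  // No CUBLASLT_MATMUL_DESC_AMAX_D_POINTER is set in either case.
  return CUBLAS_STATUS_SUCCESS;
}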
1 change: 0 additions & 1 deletion aten/src/ATen/cuda/CUDABlas.h
@@ -140,7 +140,6 @@ void scaled_gemm(
     const void* result_scale_ptr,
     int64_t result_ld,
     ScalarType result_dtype,
-    void* amax_ptr,
     bool use_fast_accum);
 
 #define CUDABLAS_BGEMM_ARGTYPES(Dtype)  \
1 change: 0 additions & 1 deletion aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -104,7 +104,6 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
         params->c_scale_ptr,
         params->ldc,
         params->c_dtype,
-        params->amax_ptr,
         params->use_fast_accum);
     return OK;
   }
18 changes: 2 additions & 16 deletions aten/src/ATen/native/cuda/Blas.cpp
@@ -964,9 +964,9 @@ ScalingType get_scaling_type(
 
 } // namespace
 
-// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
+// Computes matrix multiply + bias while applying scaling to input and output matrices
 // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
-// If output matrix type is 16 or 32-bit type, neither scale_result is applied nor amax is computed.
+// If output matrix type is 16 or 32-bit type, scale_result is not applied.
 // Known limitations:
 //  - Only works if mat1 is row-major and mat2 is column-major
 //  - Only works if matrices sizes are divisible by 32
@@ -1068,9 +1068,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
 
-  // Some scaled_gemms require an amax to populate lets create one here
-  Tensor amax = at::empty({0}, mat1.options().dtype(ScalarType::Float));
-
 #ifdef USE_ROCM
   auto tuning_ctx = at::cuda::tunable::getTuningContext();
   if (tuning_ctx->IsTunableOpEnabled()) {
@@ -1126,7 +1123,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
     params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr;
     params.ldc = args.result_ld;
     params.c_dtype = out_dtype_;
-    params.amax_ptr = amax.data_ptr();
     params.use_fast_accum = use_fast_accum;
     if (transa_ && transb_) {
       TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
@@ -1150,11 +1146,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   else
 #endif
   {
-#if defined(USE_ROCM) && ROCM_VERSION >= 60200
-    // hipBlasLT requires scaleD to be set to something in order to use AMAX
-    auto dummy_options = TensorOptions().dtype(kFloat).device(kCUDA);
-    auto dummy_scale = at::ones(1, dummy_options);
-#endif
     at::cuda::blas::scaled_gemm(
         args.transa,
         args.transb,
@@ -1172,14 +1163,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         bias ? bias->data_ptr(): nullptr,
         bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
         args.result->data_ptr(),
-#if defined(USE_ROCM) && ROCM_VERSION >= 60200
-        scale_result ? scale_result->data_ptr() : dummy_scale.data_ptr(),
-#else
         scale_result ? scale_result->data_ptr() : nullptr,
-#endif
         args.result_ld,
         out_dtype_,
-        amax.data_ptr(),
         use_fast_accum);
   }
 
