Skip to content

Commit

Permalink
Make it possible to use TF32 accumulation in F32 matmuls. (huggingfac…
Browse files Browse the repository at this point in the history
…e#2178)

* Allow the use of tf32 accumulation in matmul.

* Better timings.

* Dummy versions for use when cuda is not enabled.
  • Loading branch information
LaurentMazare authored May 11, 2024
1 parent d9bc5ec commit 9cff7bc
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 30 deletions.
42 changes: 18 additions & 24 deletions candle-core/examples/cuda_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,26 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::Result;
use candle_core::{Device, Module, Tensor};

use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");

let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");

let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}");
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}
67 changes: 61 additions & 6 deletions candle-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1615,12 +1615,8 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
Expand Down Expand Up @@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);

/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
}

/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(b: bool) {
MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}

/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
Expand All @@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}

unsafe fn gemm_strided_batched_f32(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f32>,
a: &cudarc::driver::CudaView<f32>,
b: &cudarc::driver::CudaView<f32>,
c: &mut CudaSlice<f32>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;

let compute_type = if gemm_reduced_precision_f32() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
let beta = &cfg.gemm.beta as *const f32 as *const _;

cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.lda,
cfg.stride_a,
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldb,
cfg.stride_b,
beta,
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}

unsafe fn gemm_strided_batched_f16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f16>,
Expand Down
10 changes: 10 additions & 0 deletions candle-core/src/dummy_cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,13 @@ pub fn gemm_reduced_precision_bf16() -> bool {
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}

/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
true
}

/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// allowed with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

0 comments on commit 9cff7bc

Please sign in to comment.