Revert D22517785: [pytorch][PR] Enable TF32 support for cuBLAS
Test Plan: revert-hammer

Differential Revision: D22517785 (pytorch/pytorch@288ece8)

Original commit changeset: 87334c893561

fbshipit-source-id: 0a0674f49c1bcfc98f7f88af5a8c7de93b76e458
mrshenli authored and facebook-github-bot committed Jul 15, 2020
1 parent 8548a21 commit 3a63a93
Showing 12 changed files with 25 additions and 248 deletions.
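
At the user level, the change being reverted exposed a single global switch, `torch.backends.cuda.matmul.allow_tf32`, backed by the `torch._C._get_cublas_allow_tf32` / `torch._C._set_cublas_allow_tf32` bindings exercised by the deleted test and documentation below. A rough sketch of how that switch was driven, assuming a CUDA build with the original change still applied; after this revert the `matmul` attribute no longer exists on `torch.backends.cuda`:

    import torch

    # Read the global cuBLAS TF32 setting (it defaulted to True) so it can be restored.
    orig = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False  # force full-FP32 cuBLAS matmuls
        a = torch.randn(1024, 1024, device='cuda')
        b = torch.randn(1024, 1024, device='cuda')
        ab_fp32 = a @ b  # runs with the default math mode instead of TF32 tensor ops
    finally:
        torch.backends.cuda.matmul.allow_tf32 = orig  # put the global state back
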
8 changes: 0 additions & 8 deletions aten/src/ATen/Context.cpp
@@ -86,14 +86,6 @@ void Context::setBenchmarkCuDNN(bool b) {
benchmark_cudnn = b;
}

bool Context::allowTF32CuBLAS() const {
return allow_tf32_cublas;
}

void Context::setAllowTF32CuBLAS(bool b) {
allow_tf32_cublas = b;
}

bool Context::hasMKL() const {
#if AT_MKL_ENABLED()
return true;
3 changes: 0 additions & 3 deletions aten/src/ATen/Context.h
@@ -109,8 +109,6 @@ class CAFFE2_API Context {
bool deterministic() const;
void setDeterministic(bool);
void alertNotDeterministic(c10::string_view const& caller);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
const std::vector<at::QEngine>& supportedQEngines() const;
@@ -138,7 +136,6 @@ class CAFFE2_API Context {
bool deterministic_cudnn = false;
bool _deterministic = false;
bool benchmark_cudnn = false;
bool allow_tf32_cublas = true;
bool enabled_mkldnn = true;
#ifdef C10_MOBILE
bool release_original_weights = true;
8 changes: 0 additions & 8 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -233,11 +233,7 @@ void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 5) {
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
TORCH_CUDABLAS_CHECK(cublasGemmEx(
handle,
opa,
@@ -258,11 +254,7 @@
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
} else {
TORCH_CUDABLAS_CHECK(cublasSgemmEx(
handle,
10 changes: 0 additions & 10 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -41,16 +41,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
auto handle = myPoolWindow->reserve(device);
auto stream = c10::cuda::getCurrentCUDAStream();
TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if CUDA_VERSION >= 11000
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
if (at::globalContext().allowTF32CuBLAS()) {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
} else {
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
#endif
return handle;
}

18 changes: 1 addition & 17 deletions aten/src/ATen/native/cuda/MiscUtils.h
@@ -25,17 +25,10 @@ struct MAGMAQueue {
// Constructor
explicit MAGMAQueue(int64_t device_id) {
auto& context = at::globalContext();
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
// MAGMA operations are numerically sensitive, so TF32 should be off
// regardless of the global flag.
TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
magma_queue_create_from_cuda(
device_id,
at::cuda::getCurrentCUDAStream(),
handle,
at::cuda::getCurrentCUDABlasHandle(),
at::cuda::getCurrentCUDASparseHandle(),
&magma_queue_);
}
@@ -45,20 +38,11 @@

// Destructor
~MAGMAQueue() {
#if CUDA_VERSION >= 11000
// We've manually set the math mode to CUBLAS_DEFAULT_MATH, now we
// should restore the original math mode back
cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
cublasSetMathMode(handle, original_math_mode);
#endif
magma_queue_destroy(magma_queue_);
}

private:
magma_queue_t magma_queue_;
#if CUDA_VERSION >= 11000
cublasMath_t original_math_mode;
#endif
};

static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
8 changes: 0 additions & 8 deletions aten/src/THC/THCBlas.cu
@@ -185,22 +185,14 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i
(int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
0, 0));
#else
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#endif // CUDA_VERSION < 11000
THCublasCheck(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB,
(void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
(int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
#if defined(CUDA_VERSION) && CUDA_VERSION < 11000
// On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH
// manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required.
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif // CUDA_VERSION < 11000
#endif // __HIP_PLATFORM_HCC__
}
#endif // CUDA_VERSION or __HIP_PLATFORM_HCC__
65 changes: 0 additions & 65 deletions docs/source/notes/cuda.rst
@@ -54,71 +54,6 @@ Below you can find a small example showcasing this::
f = torch.randn(2).cuda(cuda2)
# d.device, e.device, and f.device are all device(type='cuda', index=2)

.. _tf32_on_ampere:

TensorFloat-32 (TF32) on Ampere devices
----------------------------------------

Starting in PyTorch 1.7, there is a new flag called `allow_tf32` which defaults to true.
This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores,
available on new NVIDIA GPUs since Ampere, internally to compute matmul (matrix multiplies
and batched matrix multiplies) and convolutions.

TF32 tensor cores are designed to achieve better performance on matmul and convolutions on
`torch.float32` tensors by truncating input data to have 10 bits of mantissa, and accumulating
results with FP32 precision, maintaining FP32 dynamic range.

matmul and convolutions are controlled separately, and their corresponding flags can be accessed with:

.. code:: python

    # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
    torch.backends.cuda.matmul.allow_tf32 = True

    # The allow_tf32 flag for convolutions is not implemented yet

To get an idea of the precision and speed, see the example code below:

.. code:: python

    a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
    b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
    ab_full = a_full @ b_full
    mean = ab_full.abs().mean() # 80.7277

    a = a_full.float()
    b = b_full.float()

    # Do matmul at TF32 mode.
    ab_tf32 = a @ b # takes 0.016s on GA100
    error = (ab_tf32 - ab_full).abs().max() # 0.1747
    relative_error = error / mean # 0.0022

    # Do matmul with TF32 disabled.
    torch.backends.cuda.matmul.allow_tf32 = False
    ab_fp32 = a @ b # takes 0.11s on GA100
    error = (ab_fp32 - ab_full).abs().max() # 0.0031
    relative_error = error / mean # 0.000039

From the above example we can see that, with TF32 enabled, the matmul runs roughly 7x faster, while the relative error
compared to double precision is about two orders of magnitude larger. If full FP32 precision is needed, users can
disable TF32 with:

.. code:: python

    torch.backends.cuda.matmul.allow_tf32 = False

    # disabling of TF32 for cuDNN is not implemented yet

For more information about TF32, see:

- `TensorFloat-32`_
- `CUDA 11`_
- `Ampere architecture`_

.. _TensorFloat-32: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
.. _CUDA 11: https://devblogs.nvidia.com/cuda-11-features-revealed/
.. _Ampere architecture: https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/

Asynchronous execution
----------------------

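A usage note on the section removed above: the documentation toggles the global flag in place. A small hypothetical helper (not part of PyTorch) that scopes the change and restores the previous value could look like the sketch below, again assuming the pre-revert `torch.backends.cuda.matmul.allow_tf32` flag is available:

    import contextlib

    import torch

    @contextlib.contextmanager
    def tf32_allowed(enabled):
        # Hypothetical convenience wrapper around the (reverted) global flag.
        prev = torch.backends.cuda.matmul.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = enabled
        try:
            yield
        finally:
            torch.backends.cuda.matmul.allow_tf32 = prev

    a = torch.randn(10240, 10240, device='cuda')
    b = torch.randn(10240, 10240, device='cuda')
    with tf32_allowed(False):
        ab_fp32 = a @ b  # full FP32 matmul
    with tf32_allowed(True):
        ab_tf32 = a @ b  # may use TF32 tensor cores on Ampere with CUDA >= 11
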
7 changes: 0 additions & 7 deletions test/test_cuda.py
@@ -529,13 +529,6 @@ def test_serialization_array_with_storage(self):
q_copy[1].fill_(10)
self.assertTrue(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

def test_allow_tf32_get_set(self):
orig = torch.backends.cuda.matmul.allow_tf32
self.assertEqual(torch._C._get_cublas_allow_tf32(), orig)
torch.backends.cuda.matmul.allow_tf32 = not orig
self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
torch.backends.cuda.matmul.allow_tf32 = orig

def test_type_conversions(self):
x = torch.randn(5, 5)
self.assertIsInstance(x.float(), torch.FloatTensor)
32 changes: 11 additions & 21 deletions test/test_torch.py
@@ -37,7 +37,6 @@
from typing import Dict, List, Tuple, Union
import torch.backends.quantized
import torch.testing._internal.data
from torch.testing._internal.common_cuda import tf32_on_and_off


# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
@@ -8441,7 +8440,6 @@ def dims_full_for_fn():
r1 = fntorch(t0_full, t1, t2)
self.assertEqual(r0, r1)

@tf32_on_and_off(0.001)
def test_broadcast_batched_matmul(self, device):
n_dim = random.randint(1, 8)
m_dim = random.randint(1, 8)
@@ -10431,7 +10429,6 @@ def check_norm(a, b, expected_norm, gels_result):

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@tf32_on_and_off(0.001)
def test_qr(self, device):
def run_test(tensor_dims, some):
A = torch.randn(*tensor_dims, device=device)
@@ -11511,7 +11508,6 @@ def test_cdist_norm_batch(self, device):
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)

@tf32_on_and_off(0.005)
def test_cdist_large(self, device):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(1000, 10, device=device)
@@ -11521,7 +11517,6 @@
self.assertEqual(expected, actual)

@slowTest
@tf32_on_and_off(0.01)
def test_cdist_large_batch(self, device):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 1000, 10, device=device)
@@ -11530,7 +11525,6 @@
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual)

@tf32_on_and_off(0.005)
def test_cdist_non_contiguous(self, device):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(5, 7, device=device).transpose(-1, -2)
@@ -11557,7 +11551,6 @@
self.assertTrue(y.is_contiguous())
self.assertEqual(expected, actual)

@tf32_on_and_off()
def test_cdist_non_contiguous_batch(self, device):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2)
@@ -12394,7 +12387,6 @@ def test_empty_tensor_props(self, device):
self.assertEqual(x.stride(), y.stride())

@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@tf32_on_and_off(0.005)
def test_tensordot(self, device):
a = torch.arange(60., device=device).reshape(3, 4, 5)
b = torch.arange(24., device=device).reshape(4, 3, 2)
@@ -16478,7 +16470,6 @@ def test_addmm(self, device):
@dtypes(torch.float, torch.double)
@dtypesIfCUDA(*([torch.float, torch.double] +
([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
@tf32_on_and_off(0.005)
def test_addmm_sizes(self, device, dtype):
for m in [0, 1, 25]:
for n in [0, 1, 10]:
@@ -16928,7 +16919,6 @@ def test_remainder_edge_cases(self, device, dtype):
@onlyOnCPUAndCUDA
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
@dtypesIfCUDA(torch.float32, torch.float64)
@tf32_on_and_off(0.01)
def test_mm(self, device, dtype):
def _test_mm(n, m, p, dtype, genf):
# helper function
@@ -17974,7 +17964,6 @@ def test_pickle_gradscaler(self, device):
self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)

@onlyCUDA
@tf32_on_and_off(0.005)
def test_mv_stride_0(self, device):
# Reference: https://github.com/pytorch/pytorch/issues/38315
mat = torch.randn(2, 2, device=device)
@@ -18930,6 +18919,8 @@ def test_split_view(self, device):

_float_types_no_half = [torch.float, torch.double]

_complex_types = [torch.cfloat, torch.cdouble]

# _float_types2 adds bfloat16 type to _float_types only on ROCm. Should eventually be unified
# with _float_types when bfloat16 bringup is complete on all platforms
_float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types
@@ -19104,13 +19095,13 @@ def inner(self, device, dtype):
('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()],
1e-1, 1e-1, 1e-5, _float_types2),
('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
1e-1, 1e-1, 1e-4, _float_types2),
('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
[_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
[_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2),
('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
@@ -19135,26 +19126,25 @@
1e-1, 1e-5, _types2, _cpu_types, True,
[_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
1e-1, 1e-1, 1e-4, _float_types2),
('addmm', 'scalar', _medium_2d,
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
[_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
('addmm', 'two_scalars', _medium_2d,
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
[_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types,
True, [tf32_on_and_off(0.005)]),
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm),
('addmv', 'scalar', _medium_1d,
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
[_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
('addmv', 'two_scalars', _medium_1d,
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
[_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2),
('addr', 'scalar', _medium_2d,
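The `@tf32_on_and_off(...)` decorators stripped from test_torch.py above come from `torch.testing._internal.common_cuda`. Judging from the call sites, they run a CUDA test twice, once with TF32 disallowed and once allowed, and loosen the comparison tolerance for the TF32 pass. The sketch below illustrates that intent only; it is not the actual implementation:

    import functools

    import torch

    def tf32_on_and_off_sketch(tf32_precision=1e-5):
        # Illustrative only: run the wrapped test with TF32 off, then on,
        # using the looser tf32_precision tolerance for the TF32 pass.
        def decorator(test_fn):
            @functools.wraps(test_fn)
            def wrapper(self, device, *args, **kwargs):
                if torch.device(device).type != 'cuda':
                    return test_fn(self, device, *args, **kwargs)
                old_allow = torch.backends.cuda.matmul.allow_tf32
                old_precision = self.precision
                try:
                    for allow_tf32, tol in ((False, old_precision), (True, tf32_precision)):
                        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
                        self.precision = tol
                        test_fn(self, device, *args, **kwargs)
                finally:
                    torch.backends.cuda.matmul.allow_tf32 = old_allow
                    self.precision = old_precision
            return wrapper
        return decorator
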
25 changes: 13 additions & 12 deletions torch/backends/cuda/__init__.py
@@ -83,15 +83,16 @@ def __setattr__(self, name, value):
return super(cuFFTPlanCacheManager, self).__setattr__(name, value)


class cuBLASModule:
def __getattr__(self, name):
assert name == "allow_tf32", "Unknown attribute " + name
return torch._C._get_cublas_allow_tf32()

def __setattr__(self, name, value):
assert name == "allow_tf32", "Unknown attribute " + name
return torch._C._set_cublas_allow_tf32(value)


cufft_plan_cache = cuFFTPlanCacheManager()
matmul = cuBLASModule()
class CUDAModule(object):
def __init__(self, m):
self.__dict__ = m.__dict__
# You have to retain the old module, otherwise it will
# get GC'ed and a lot of things will break. See:
# https://stackoverflow.com/questions/47540722/how-do-i-use-the-sys-modules-replacement-trick-in-init-py-on-python-2
self.__old_mod = m

cufft_plan_cache = cuFFTPlanCacheManager()

# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CUDAModule(sys.modules[__name__])
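
The `CUDAModule` wrapper restored above implements the "getattr on a module" trick from the linked Stack Overflow answers: installing a class instance in place of the module object lets attribute access run code (descriptors, properties, `__getattr__`). A minimal standalone sketch, with an illustrative property that is not part of PyTorch:

    import sys

    class _ModuleWrapper(object):
        def __init__(self, real_module):
            self.__dict__ = real_module.__dict__
            # Keep a reference so the original module object is not garbage collected.
            self.__old_mod = real_module

        @property
        def answer(self):
            # Behaves like a computed module attribute: `module.answer` runs this code.
            return 42

    sys.modules[__name__] = _ModuleWrapper(sys.modules[__name__])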