[quant][core][gpu][feature] Implemented quantized cuda gelu
Summary:
Support for gelu on quantized CUDA tensors is provided by composing existing
kernels: `dequantize -> fp32 cuda gelu kernel -> quantize`. This is not
mathematically equivalent to a gelu computed directly in int8, but we have
opted for this approach for now. It may be possible to write an int8 gelu
variant that matches `dequantize -> fp32 cuda gelu kernel -> quantize`;
that is left as future work.
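
For illustration only (not part of this commit's diff): the composition above
corresponds roughly to the following Python-level reference. The input shape
and the scale/zero_point values are arbitrary, and a CUDA-enabled build is
assumed.

```python
import torch
import torch.nn.functional as F

# Arbitrary quantized CUDA input (illustrative scale/zero_point).
x = torch.randn(4, 8, device="cuda")
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

# dequantize -> fp32 cuda gelu -> quantize, reusing the input's
# quantization parameters; this mirrors what the new kernel does.
y_ref = torch.quantize_per_tensor(
    F.gelu(qx.dequantize()),
    scale=qx.q_scale(),
    zero_point=qx.q_zero_point(),
    dtype=torch.quint8,
)
```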

The test function `test_qgelu` was amended to also exercise gelu on the
quantized CUDA backend.

Test Plan:
```
python test/test_quantization.py -k test_qgelu
```

Pull Request resolved: pytorch#77212

Approved by: https://github.com/jerryzh168
dzdang authored and pytorchmergebot committed May 24, 2022
1 parent b7bb34d commit 2aad28a
Showing 4 changed files with 38 additions and 12 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -213,6 +213,7 @@ filegroup(
"aten/src/ATen/native/cudnn/*.cpp",
"aten/src/ATen/native/miopen/*.cpp",
"aten/src/ATen/native/nested/cuda/*.cpp",
"aten/src/ATen/native/quantized/cuda/*.cpp",
"aten/src/ATen/native/quantized/cudnn/*.cpp",
"aten/src/ATen/native/sparse/cuda/*.cpp",
"aten/src/ATen/native/transformers/cuda/*.cpp",
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3987,6 +3987,7 @@
dispatch:
MkldnnCPU: mkldnn_gelu
QuantizedCPU: gelu_quantized_cpu
QuantizedCUDA: gelu_quantized_cuda
NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu

- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)
21 changes: 21 additions & 0 deletions aten/src/ATen/native/quantized/cuda/Activation.cpp
@@ -0,0 +1,21 @@
#include <c10/util/Exception.h>
#include <ATen/ATen.h>

namespace at {
namespace native {

// this kernel is currently implemented with dequantize -> fp32 gelu -> quantize, which is not equivalent to int8 gelu
// It might be possible to write a variant of the int8 gelu that's equivalent to dequantize -> fp32 cuda gelu kernel -> quantize,
// which can be a topic for future work.
Tensor gelu_quantized_cuda(const Tensor& qx, c10::string_view approximate) {
  (void)approximate; // suppress unused variable lint warning
  if (qx.numel() == 0) {
    return Tensor{};
  }
  auto x_fp32 = at::dequantize(qx);
  auto result_fp32 = at::gelu(x_fp32);
  return at::quantize_per_tensor(result_fp32, qx.q_scale(), qx.q_zero_point(), qx.scalar_type());
}

} // namespace native
} // namespace at
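
As a usage sketch (illustrative rather than part of the diff, and assuming a
CUDA-enabled build): after this change, calling gelu directly on a
per-tensor-quantized CUDA tensor should dispatch to `gelu_quantized_cuda` and
match the dequantize/requantize reference described in the summary.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, device="cuda")
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

qy = F.gelu(qx)  # routed via the QuantizedCUDA dispatch key after this commit
ref = torch.quantize_per_tensor(F.gelu(qx.dequantize()), scale=qx.q_scale(),
                                zero_point=qx.q_zero_point(), dtype=torch.quint8)
torch.testing.assert_close(qy.dequantize(), ref.dequantize())
```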
27 changes: 15 additions & 12 deletions test/quantization/core/test_quantized_op.py
@@ -31,7 +31,7 @@
     qengine_is_onednn,
 )
 from torch.ao.quantization import PerChannelMinMaxObserver
-from torch.testing._internal.common_cuda import TEST_CUDNN
+from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDA
 import torch.backends.xnnpack

 from typing import Optional
@@ -447,24 +447,27 @@ def test_qgelu(self):
         memory_formats = (torch.channels_last, torch.contiguous_format)
         approximation = ['none', 'tanh']
         test_cases = itertools.product(shapes, dtypes, memory_formats, approximation)
+        devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
         for shape, dtype, memory_format, approximate in test_cases:
             if memory_format == torch.channels_last and len(shape) != 4:
                 continue

             X, scale, zero_point, torch_type = \
                 torch.randn(*shape), 0.1, 0, dtype
             X = X.to(memory_format=memory_format)
+            for device in devices:
+                X = X.to(device=device)
+                qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
+                                               dtype=torch_type)
+                dqX = qX.dequantize()
+
-            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
-                                           dtype=torch_type)
-            dqX = qX.dequantize()
-
-            op = torch.nn.functional.gelu
-            dqY = op(dqX, approximate=approximate)
-            qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
-                                           dtype=torch_type)
-            qY_hat = op(qX)
-            self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
-                             msg="F.gelu failed ({} vs {})".format(qY, qY_hat))
+                op = torch.nn.functional.gelu
+                dqY = op(dqX, approximate=approximate)
+                qY = torch.quantize_per_tensor(dqY, scale=scale, zero_point=zero_point,
+                                               dtype=torch_type)
+                qY_hat = op(qX)
+                self.assertEqual(qY.dequantize(), qY_hat.dequantize(),
+                                 msg="F.gelu failed ({} vs {})".format(qY, qY_hat))

     """Tests the correctness of the quantized::qlayer_norm op."""
     @skipIfNoFBGEMM
