-
Notifications
You must be signed in to change notification settings - Fork 322
Description
We want to enable FP8 quantization in PyTorch. As with INT8 quantization, this requires inserting quantize and dequantize operations into the computational graph. To reuse the INT8 pattern-matching logic, we need to register the FP8 quant and dequant ops.
To address this, we attempted to register quant in #2379, but the PR was reverted in #2672 because it caused a performance regression on H100 GPUs.
Finding the root cause of the GPU regression will take a lot of effort.
Alternatively, we could register quant specifically for CPU, but this requires defining and registering a separate function for CPU.
@jerryzh168 @vkuzo Do you have any suggestions on how to proceed?
cc @Xia-Weiwen
I created the following test to demonstrate the issue.
import os
# The environment variables are set before `import torch` below —
# presumably so they are already in place when torch reads them.
# Restrict OpenMP to a single thread.
os.environ["OMP_NUM_THREADS"] = "1"
# Turn on Inductor freezing (also set via inductor_config.freezing later in the script).
os.environ["TORCHINDUCTOR_FREEZING"] = "1"
# Ask torch.compile to emit its debug artifacts.
os.environ["TORCH_COMPILE_DEBUG"] = "1"
# Ask Dynamo to report guard failures.
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "1"
import torch
import torchao
# dtype: passed as the dequantization output dtype in FP8QDQLinear.forward.
dtype = torch.float
# qtype: float8 dtype the random weight is cast to in FP8QDQLinear.__init__.
qtype = torch.float8_e4m3fn
def dequantize_per_tensor(
    tensor: torch.Tensor,
    scale: float,
    output_dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize ``tensor`` with a single per-tensor scale.

    Args:
        tensor: Quantized tensor to dequantize (float8 in this script).
        scale: Per-tensor scale; wrapped in a 1-element tensor for the call.
        output_dtype: dtype requested for the dequantized result.

    Returns:
        The tensor returned by torchao's ``_dequantize_affine_float8``.
    """
    # Forward the caller's requested dtype instead of a fixed torch.float,
    # so the ``output_dtype`` argument is actually honoured.
    return torchao.quantization.quant_primitives._dequantize_affine_float8(
        tensor=tensor,
        scale=torch.tensor([scale]),
        output_dtype=output_dtype,
    )
def quantize_per_tensor(
    tensor: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Quantize ``tensor`` to ``torch.float8_e4m3fn`` with one per-tensor scale.

    Args:
        tensor: Tensor to quantize.
        scale: Per-tensor scale; wrapped in a 1-element tensor for the call.

    Returns:
        The tensor returned by torchao's ``_quantize_affine_float8``.
    """
    quantize_fp8 = torchao.quantization.quant_primitives._quantize_affine_float8
    scale_tensor = torch.tensor([scale])
    return quantize_fp8(
        tensor=tensor,
        scale=scale_tensor,
        float8_dtype=torch.float8_e4m3fn,
    )
class FP8QDQLinear(torch.nn.Module):
    """Bias-free linear layer wrapped in FP8 quantize/dequantize steps.

    The weight is a random tensor cast to the module-level ``qtype``. On each
    forward pass the weight is dequantized, the input is quantized and then
    dequantized again, and the two results are fed to
    ``torch.nn.functional.linear``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Stored as a plain tensor attribute (not a Parameter or buffer).
        random_weight = torch.randn(out_features, in_features)
        self.weight = random_weight.to(qtype)
        # Both scales are fixed at 1.0; there is no bias term.
        self.weight_scale = 1.0
        self.scale = 1.0
        self.bias = None

    def forward(self, input):
        # Weight path: dequantize only.
        fp_weight = dequantize_per_tensor(self.weight.data, self.weight_scale, dtype)
        # Activation path: quantize, then immediately dequantize (Q/DQ pair).
        fp8_input = quantize_per_tensor(input, self.scale)
        fp_input = dequantize_per_tensor(fp8_input, self.scale, dtype)
        return torch.nn.functional.linear(fp_input, fp_weight, self.bias)
from torch._inductor import config as inductor_config
from torch._dynamo import config
# Make Dynamo raise instead of recompiling.
config.error_on_recompile = True
#inductor_config.cpp_wrapper = True
# Inductor settings for this repro: no max-autotune, freezing on,
# AOT-Inductor debug compilation off.
inductor_config.max_autotune = False
inductor_config.freezing = True
inductor_config.aot_inductor.debug_compile = False
# A 13 -> 16 FP8 Q/DQ linear layer and one batch of 128 random inputs.
model = FP8QDQLinear(13, 16)
example_inputs = (torch.randn(128, 13),)
with torch.no_grad():
    # Two eager runs on the same inputs.
    refe = model(*example_inputs)
    test_eager = model(*example_inputs)
    # Compile, run once (result discarded), then run again and keep the output.
    model = torch.compile(model)
    model(*example_inputs)
    test = model(*example_inputs)
Printing the graph in freezing_patterns.py shows that the quant op has been decomposed into clamp_min, clamp_max, and convert_element_type.
# print(gm)
<lambda>()
def forward(self, arg1_1):
arg0_1 = self._frozen_param0
full_default = torch.ops.aten.full.default([1], 1.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8.default(arg0_1, full_default); arg0_1 = None
clamp_min = torch.ops.aten.clamp_min.default(arg1_1, -448.0); arg1_1 = None
clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 448.0); clamp_min = None
convert_element_type = torch.ops.prims.convert_element_type.default(clamp_max, torch.float8_e4m3fn); clamp_max = None
dequantize_affine_float8_1 = torch.ops.torchao.dequantize_affine_float8.default(convert_element_type, full_default); convert_element_type = full_default = None
permute = torch.ops.aten.permute.default(dequantize_affine_float8, [1, 0]); dequantize_affine_float8 = None
mm = torch.ops.aten.mm.default(dequantize_affine_float8_1, permute); dequantize_affine_float8_1 = permute = None
return (mm,)
For comparison, here is the INT8 result, where quant is kept as a separate operator (torch.ops.quantized_decomposed.quantize_per_tensor.default).
def forward(self, arg4_1):
arg0_1 = self._frozen_param0
arg1_1 = self._frozen_param1
arg2_1 = self._frozen_param2
arg3_1 = self._frozen_param3
dequantize_per_channel = torch.ops.quantized_decomposed.dequantize_per_channel.default(arg3_1, arg1_1, arg2_1, 0, -128, 127, torch.int8); arg3_1 = arg1_1 = arg2_1 = None
quantize_per_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg4_1, 0.027873406186699867, 128, 0, 255, torch.uint8); arg4_1 = None
dequantize_per_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 0.027873406186699867, 128, 0, 255, torch.uint8); quantize_per_tensor = None
permute = torch.ops.aten.permute.default(dequantize_per_channel, [1, 0]); dequantize_per_channel = None
addmm = torch.ops.aten.addmm.default(arg0_1, dequantize_per_tensor, permute); arg0_1 = dequantize_per_tensor = permute = None
relu = torch.ops.aten.relu.default(addmm); addmm = None
return (relu,)