Skip to content

[CPU][FP8][Inductor] How to support fp8 quant for inductor on CPU #2896

@shiyang-weng

Description

@shiyang-weng

We want to enable FP8 quantization in PyTorch. As with INT8 quantization, this requires inserting quantize and dequantize operations into the computational graph. To reuse the existing INT8 pattern-matching logic, we need to register FP8 quantize and dequantize ops.

To address this, we attempted to register the quantize op in #2379, but that PR was reverted in #2672 because it caused a performance regression on H100 GPUs.

Finding the root cause of the GPU regression will take considerable effort.
Alternatively, we could register the quantize op only for CPU, but this would require defining and registering a separate CPU-specific function.
@jerryzh168 @vkuzo Do you have any suggestions?
cc @Xia-Weiwen

I created the following test to reproduce the issue.

import os

# Environment knobs for the repro. They are set before `import torch` below
# so they are already present in the environment when torch is imported.
os.environ["OMP_NUM_THREADS"] = "1"  # limit OpenMP to a single thread
os.environ["TORCHINDUCTOR_FREEZING"] = "1"  # inductor freezing (inductor_config.freezing is also set later)
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # enable torch.compile debug output
os.environ["TORCHDYNAMO_PRINT_GUARD_FAILS"] = "1"  # report dynamo guard failures

import torch
import torchao

# Floating-point dtype passed to the dequantize helper for the weight and input.
dtype = torch.float
# FP8 dtype the weight is cast to in FP8QDQLinear.__init__.
qtype = torch.float8_e4m3fn

def dequantize_per_tensor(
    tensor: torch.Tensor,
    scale: float,
    output_dtype: torch.dtype
) -> torch.Tensor:
    """Dequantize ``tensor`` using a single per-tensor ``scale``.

    Thin wrapper around torchao's ``_dequantize_affine_float8``. In this
    script it is called on the FP8 weight and on the output of
    ``quantize_per_tensor``.

    Args:
        tensor: Quantized tensor to dequantize (presumably FP8; verify
            against callers).
        scale: Per-tensor scale. It is wrapped in a one-element tensor because
            the torchao primitive takes the scale as a tensor.
        output_dtype: dtype requested from the torchao primitive for the
            dequantized result.

    Returns:
        The tensor returned by ``_dequantize_affine_float8``.
    """
    # Forward the caller's ``output_dtype``. The previous version ignored this
    # argument and always requested torch.float.
    return torchao.quantization.quant_primitives._dequantize_affine_float8(
        tensor=tensor,
        scale=torch.tensor([scale]),
        output_dtype=output_dtype,
    )

def quantize_per_tensor(
    tensor: torch.Tensor,
    scale: float,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """Quantize ``tensor`` to an FP8 dtype using a single per-tensor ``scale``.

    Thin wrapper around torchao's ``_quantize_affine_float8``.

    Args:
        tensor: Floating-point tensor to quantize.
        scale: Per-tensor scale. It is wrapped in a one-element tensor because
            the torchao primitive takes the scale as a tensor.
        float8_dtype: Target FP8 dtype. It defaults to ``torch.float8_e4m3fn``,
            the value the original version hard-coded, so existing callers are
            unaffected.

    Returns:
        The tensor returned by ``_quantize_affine_float8``.
    """
    return torchao.quantization.quant_primitives._quantize_affine_float8(
        tensor=tensor,
        scale=torch.tensor([scale]),
        float8_dtype=float8_dtype,
    )


class FP8QDQLinear(torch.nn.Module):
    """Linear layer expressed as explicit FP8 quantize/dequantize (QDQ) ops.

    The weight is stored already cast to ``qtype``. On every call the weight
    is dequantized, and the input goes through a quantize -> dequantize round
    trip. The two dequantized operands then feed a plain ``F.linear``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Random weight cast to the FP8 storage dtype. It is kept as a plain
        # tensor attribute, not registered as a Parameter or buffer.
        weight_shape = (out_features, in_features)
        self.weight = torch.randn(weight_shape).to(qtype)
        # Unit scales for the weight and for the activation; no bias term.
        self.weight_scale = 1.0
        self.scale = 1.0
        self.bias = None

    def forward(self, input):
        w_deq = dequantize_per_tensor(self.weight.data, self.weight_scale, dtype)
        # Activation QDQ pair: quantize, then immediately dequantize.
        x_q = quantize_per_tensor(input, self.scale)
        x_deq = dequantize_per_tensor(x_q, self.scale, dtype)
        return torch.nn.functional.linear(x_deq, w_deq, self.bias)

from torch._inductor import config as inductor_config
from torch._dynamo import config

# Raise instead of silently recompiling, so guard problems surface immediately.
config.error_on_recompile = True
# Toggle to exercise the C++ wrapper code path instead of the Python wrapper.
#inductor_config.cpp_wrapper = True
inductor_config.max_autotune = False
# Freezing treats module attributes as constants. The FP8 weight appears as
# `_frozen_param0` in the post-freezing graph.
inductor_config.freezing = True

inductor_config.aot_inductor.debug_compile = False


model = FP8QDQLinear(13, 16)
example_inputs = (torch.randn(128, 13),)

with torch.no_grad():
    # Eager reference, computed before the module is compiled.
    refe = model(*example_inputs)
    model = torch.compile(model)
    # The first call triggers compilation; the second runs the compiled code.
    model(*example_inputs)
    test = model(*example_inputs)

# Verify that compilation preserved numerics. Previously `test` was computed
# but never compared, and a second eager call duplicated `refe`.
torch.testing.assert_close(test, refe, rtol=1e-4, atol=1e-4)

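Logging the graph in freezing_patterns.py shows that the FP8 quantize op has been decomposed into clamp_min, clamp_max, and convert_element_type. Only dequantize remains a single op (torch.ops.torchao.dequantize_affine_float8). Because quantize no longer appears as one node, the INT8-style quant/dequant pattern matching cannot be reused as-is.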

# print(gm)
<lambda>()



def forward(self, arg1_1):
    arg0_1 = self._frozen_param0
    full_default = torch.ops.aten.full.default([1], 1.0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    dequantize_affine_float8 = torch.ops.torchao.dequantize_affine_float8.default(arg0_1, full_default);  arg0_1 = None
    clamp_min = torch.ops.aten.clamp_min.default(arg1_1, -448.0);  arg1_1 = None
    clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 448.0);  clamp_min = None
    convert_element_type = torch.ops.prims.convert_element_type.default(clamp_max, torch.float8_e4m3fn);  clamp_max = None
    dequantize_affine_float8_1 = torch.ops.torchao.dequantize_affine_float8.default(convert_element_type, full_default);  convert_element_type = full_default = None
    permute = torch.ops.aten.permute.default(dequantize_affine_float8, [1, 0]);  dequantize_affine_float8 = None
    mm = torch.ops.aten.mm.default(dequantize_affine_float8_1, permute);  dequantize_affine_float8_1 = permute = None
    return (mm,)

For comparison, here is the corresponding graph for INT8. There, quantize is kept as a separate operator (torch.ops.quantized_decomposed.quantize_per_tensor.default).

def forward(self, arg4_1):
    arg0_1 = self._frozen_param0
    arg1_1 = self._frozen_param1
    arg2_1 = self._frozen_param2
    arg3_1 = self._frozen_param3
    dequantize_per_channel = torch.ops.quantized_decomposed.dequantize_per_channel.default(arg3_1, arg1_1, arg2_1, 0, -128, 127, torch.int8);  arg3_1 = arg1_1 = arg2_1 = None
    quantize_per_tensor = torch.ops.quantized_decomposed.quantize_per_tensor.default(arg4_1, 0.027873406186699867, 128, 0, 255, torch.uint8);  arg4_1 = None
    dequantize_per_tensor = torch.ops.quantized_decomposed.dequantize_per_tensor.default(quantize_per_tensor, 0.027873406186699867, 128, 0, 255, torch.uint8);  quantize_per_tensor = None
    permute = torch.ops.aten.permute.default(dequantize_per_channel, [1, 0]);  dequantize_per_channel = None
    addmm = torch.ops.aten.addmm.default(arg0_1, dequantize_per_tensor, permute);  arg0_1 = dequantize_per_tensor = permute = None
    relu = torch.ops.aten.relu.default(addmm);  addmm = None
    return (relu,)

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions