[Quantization] int8 compute with PostTraining Quantization (W8A8) #2432

Open
ptc-hacharya opened this issue Jan 14, 2025 · 4 comments
Labels
question Response providing clarification needed. Will not be assigned to a release. (type)

Comments

@ptc-hacharya

ptc-hacharya commented Jan 14, 2025

❓Question

I'm trying to understand the runtime behavior of W8A8-quantized networks on Core ML. I have written a very simple model as follows:

import torch
import torch.nn as nn

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

I convert this model to W8A8 format using post-training quantisation and export it to an .mlpackage with the following code:

import torch

def quantise(
    model: torch.nn.Module,
) -> torch.nn.Module:
    import tqdm
    from coremltools.optimize.torch.quantization import (
        LinearQuantizer,
        LinearQuantizerConfig,
        ModuleLinearQuantizerConfig,
    )

    # make dataloader for calibration
    dataloader = ...

    config = LinearQuantizerConfig(
        global_config=ModuleLinearQuantizerConfig(
            weight_dtype="qint8",
            activation_dtype="quint8",
            quantization_scheme="symmetric",
            milestones=[0, 1000, 1000, 0],
        )
    )
    quantizer = LinearQuantizer(model, config)

    example_inputs = next(iter(dataloader))
    # prepare() takes a tuple of example inputs
    quantizer.prepare(example_inputs=(example_inputs,), inplace=True)

    quantizer.step()

    # Do a forward pass through the model with calibration data
    for data in tqdm.tqdm(dataloader, desc="calibrating"):
        with torch.no_grad():
            model(data)

    quantized_model = quantizer.finalize()
    return quantized_model

def export_super_simple():
    import coremltools as ct
    import torch

    torch_model = SimpleModel()
    torch_model.eval()
    torch_model.cpu()

    # quantise model
    torch_model = quantise(torch_model)

    # trace model for export
    inputs = torch.rand(1, 1, 480, 640)
    input_shapes = [ct.TensorType(shape=inputs.shape)]
    outputs = [
        ct.TensorType(name="scores"),
    ]
    export_model = torch.jit.trace(torch_model, inputs, strict=True)
    _ = export_model(inputs) 

    # coreml export
    model = ct.convert(
        export_model,
        inputs=input_shapes,
        minimum_deployment_target=ct.target.iOS18,
        debug=True,
        outputs=outputs,
    )
    model.save("super_simple_model.mlpackage")

From the graph of the exported Core ML model:
[Image: graph of the exported Core ML model]

  • the inputs and weights of the convolution operator are dequantised to fp16
  • conv + relu run in fp16
  • relu outputs are quantized again to int8
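The three bullets above describe a quantize/dequantize round trip around an fp16 op. A minimal plain-Python sketch of what those q/dq ops compute, assuming symmetric per-tensor int8 with zero point 0 (the scale and sample values are hypothetical, not read from the exported model):

```python
# Minimal sketch of the quantize / dequantize ops around the fp16 conv,
# assuming symmetric per-tensor int8 quantization (zero point = 0).
# The scale and sample values are hypothetical, for illustration only.

QMIN, QMAX = -128, 127

def quantize(x: float, scale: float) -> int:
    """Round to the nearest int8 step and clamp to the int8 range."""
    q = round(x / scale)
    return max(QMIN, min(QMAX, q))

def dequantize(q: int, scale: float) -> float:
    """Map an int8 code back to a float (fp16 on device, Python float here)."""
    return q * scale

scale = 0.05  # hypothetical per-tensor scale
values = [0.0, 0.12, -1.3, 3.14159, 10.0]
codes = [quantize(v, scale) for v in values]
restored = [dequantize(q, scale) for q in codes]

for v, q, r in zip(values, codes, restored):
    print(f"{v:9.5f} -> {q:5d} -> {r:9.5f}")
```

In-range values come back within half a quantization step (scale / 2) of the original; out-of-range values such as 10.0 saturate at the int8 limit.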

My questions are:

  • Is this the expected behavior, with dequant/quant ops at operator boundaries while the actual compute runs in fp16? From the coremltools documentation (link), it looks like one can expect quite a bit of speedup with W8A8 models. Do those W8A8 models run compute in fp16? If so, does the speedup come from moving less activation data between layers?
  • How do I get the exported model to run integer computation? That is, instead of having these dequant -> quant operators before and after each op, I would like to see the network inputs quantized to int8 right at the beginning, the whole pipeline run only on int8 tensors, and the final network outputs dequantised back to fp16. My goal is to see whether pure integer-only compute gives me any inference speedup.

ptc-hacharya added the "question" label on Jan 14, 2025
@pulkital
Collaborator

The compiler further lowers the CoreML model. It recognizes patterns like dequant -> conv -> relu -> quant and fuses them so that the hardware actually uses int8-int8 compute to execute the sequence of ops.
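A minimal sketch of why that fusion is arithmetically possible, assuming symmetric int8 (zero point 0), a single conv output modelled as a 3x3 dot product with no bias, and hypothetical power-of-two scales so the float arithmetic here is exact. The scales factor out of the sum, and because their product is positive, relu can be applied to the integer accumulator before a single rescale. This only illustrates the arithmetic; it is not how the Neural Engine compiler is implemented.

```python
# Sketch: dequant -> conv -> relu -> quant vs. an int8 x int8 fused kernel.
# Assumptions (not from the thread): symmetric int8, zero point 0, one conv
# output written as a dot product, no bias, hypothetical power-of-two scales.

QMIN, QMAX = -128, 127

def clamp_int8(v: int) -> int:
    return max(QMIN, min(QMAX, v))

def float_path(x_q, w_q, s_x, s_w, s_y):
    """What the exported graph literally describes: dequant, fp op, relu, quant."""
    x = [q * s_x for q in x_q]                  # dequantize activations
    w = [q * s_w for q in w_q]                  # dequantize weights
    y = sum(a * b for a, b in zip(x, w))        # conv output as a float dot product
    y = max(0.0, y)                             # relu
    return clamp_int8(round(y / s_y))           # quantize the output

def integer_path(x_q, w_q, s_x, s_w, s_y):
    """Fused kernel: integer accumulate, relu, then one requantization."""
    acc = sum(a * b for a, b in zip(x_q, w_q))  # integer accumulator (int32 on HW)
    acc = max(0, acc)                           # relu commutes with a positive scale
    multiplier = (s_x * s_w) / s_y              # folded requantization factor
    return clamp_int8(round(acc * multiplier))

x_q = [12, -7, 33, 90, -128, 5, 64, 1, -20]   # 3x3 input patch (int8 codes)
w_q = [3, 14, -2, 25, 7, -9, 11, 0, 6]        # 3x3 kernel (int8 codes)
s_x, s_w, s_y = 1 / 16, 1 / 32, 1 / 8

print(float_path(x_q, w_q, s_x, s_w, s_y), integer_path(x_q, w_q, s_x, s_w, s_y))
```

With arbitrary (non power-of-two) scales the two paths can differ by float rounding, which is why the hypothetical scales were chosen to keep this comparison exact.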

@NehalBhandari
Collaborator

The quant / dequant layers are expected. These are inserted in the torch model during the “prepare” stage to simulate quantization effects during training, allowing the model to adapt to reduced precision while still training in fp16. During training, we also estimate the scales and zero points used for the actual quantization.
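The scales and zero points mentioned above come from statistics gathered during the calibration forward passes. A minimal sketch of the common min/max formulas, assuming a simple running min/max over all calibration data; the observer coremltools actually uses may differ, and the calibration values are hypothetical:

```python
# Sketch of min/max calibration. These are textbook formulas; the exact
# observer used by coremltools may differ (e.g. moving averages, per-channel).

def symmetric_int8_params(xs):
    """Symmetric signed int8: zero point fixed at 0, scale covers max |x|."""
    max_abs = max(abs(v) for v in xs)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    return scale, 0

def affine_uint8_params(xs):
    """Affine unsigned 8-bit: [min, max] (extended to include 0) maps to [0, 255]."""
    lo = min(min(xs), 0.0)
    hi = max(max(xs), 0.0)
    scale = (hi - lo) / 255 if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, max(0, min(255, zero_point))

# Hypothetical calibration activations gathered over a few batches.
batches = [[0.0, 0.4, 2.1], [0.05, 3.0, 1.2], [0.9, 0.0, 2.55]]
observed = [v for batch in batches for v in batch]

print(symmetric_int8_params(observed))
print(affine_uint8_params(observed))
```

For non-negative data such as post-ReLU activations, the affine unsigned mapping ends up with zero point 0 and spends all 256 codes on the observed range, while the symmetric signed mapping leaves the negative codes unused.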

The speedup for W8A8 models comes from int8-int8 compute on the Neural Engine (https://apple.github.io/coremltools/docs-guides/source/opt-quantization-perf.html#performance). The NE compiler automatically detects and fuses dequant -> op(s) -> quant patterns in the model to run int8-int8 compute wherever possible.

I believe your current quantized model should run integer only compute when run on NE.

@ptc-hacharya
Author

ptc-hacharya commented Jan 15, 2025

Thanks for the quick response! Okay, it makes sense that the Core ML compiler further fuses a dequant -> conv -> relu -> quant pattern so that it finally runs as integer compute.

I ran a profiler on the two networks:
f16
[Image: profiler results for the f16 model]

w8a8
[Image: profiler results for the W8A8 model]

I can see that the conv operator has gotten faster (164 µs -> 63 µs), but is it really the case that dequant -> conv -> relu -> quant was replaced by conv8 -> relu8? If that were the case, I wouldn't have expected to see profiler entries for the dq and q ops at #5 and #16, which gives me the impression that they very much exist in the final compute graph.

My follow-up question: I see a huge profiling difference if I modify the code very slightly to accept and return (b, h, w, c) tensors:

from typing import Annotated

def forward(self, x: Annotated[torch.Tensor, "(b, h, w, c)"]) -> Annotated[torch.Tensor, "(b, h, w, c)"]:
    x = torch.permute(x, (0, 3, 1, 2))  # convert to (b, c, h, w)
    x = self.conv(x)
    x = self.relu(x)
    x = torch.permute(x, (0, 2, 3, 1))  # back to (b, h, w, c)
    return x

I find that the runtime of the conv operator (W8A8) shoots up to 544 µs (from 63 µs, as seen above):

[Image: profiler results for the (b, h, w, c) W8A8 model]

I can't follow why this has such a big impact; the conv operator still works on (b, c, h, w) tensors. Is it still doing int8 compute, or is something going wrong? What's also quite strange is that my input has a single channel, so it goes from (1, 480, 640, 1) to (1, 1, 480, 640), and at least memory-layout-wise those tensors should be the same. In general, are there recommendations on the channel ordering of input tensors (or anything else input/output-wise) that affect inference speed on the Neural Engine?
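One way to check the layout intuition is to compute, with plain index arithmetic, which source offset each element of the permuted tensor reads from a contiguous row-major buffer. This is a sketch with small hypothetical stand-in sizes for 480 x 640. It confirms that the input permute is a no-op in memory when c = 1, but the conv has 3 output channels (nn.Conv2d(1, 3, ...)), so the permute back to (b, h, w, c) genuinely reorders data. It does not by itself explain the slower conv time reported by the profiler.

```python
# Sketch: does a permute change the flat memory order of a contiguous tensor?
from itertools import product

def strides(shape):
    """Row-major (C-contiguous) strides, in elements."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return list(reversed(out))

def flat_order_after_permute(shape, perm):
    """For each element of the permuted tensor (in row-major order), return the
    flat offset it reads from the original contiguous buffer."""
    src_strides = strides(shape)
    new_shape = [shape[p] for p in perm]
    order = []
    for idx in product(*(range(d) for d in new_shape)):
        # Map the index in the permuted tensor back to the source axes.
        src_idx = [0] * len(shape)
        for new_axis, src_axis in enumerate(perm):
            src_idx[src_axis] = idx[new_axis]
        order.append(sum(i * s for i, s in zip(src_idx, src_strides)))
    return order

def is_identity(order):
    return order == list(range(len(order)))

H, W = 4, 5  # hypothetical small stand-ins for 480 x 640
# Input: (b, h, w, c) with c = 1, permuted to (b, c, h, w).
print(is_identity(flat_order_after_permute((1, H, W, 1), (0, 3, 1, 2))))
# Output: conv produces (b, c, h, w) with c = 3, permuted back to (b, h, w, c).
print(is_identity(flat_order_after_permute((1, 3, H, W), (0, 2, 3, 1))))
```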

@dessatel

It could be this, due to the "1" channel dimension in the conv:

https://machinelearning.apple.com/research/neural-engine-transformers

For example, if the last axis is used as a singleton one by the model implementation’s data format, it will be padded to 64 bytes, which results in 32 times the memory cost in 16-bit and 64 times the memory cost in 8-bit precision. Such an increase in buffer size will significantly reduce the chance of L2 cache residency and increase the chance of hitting DRAM. This is not desirable from a power and latency perspective.
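Applying the quoted padding rule to the input shapes in this issue gives a rough sense of the cost. A back-of-the-envelope sketch, assuming each buffer's last axis is padded up to a multiple of 64 bytes (the quote only states 64-byte padding for a singleton last axis; generalizing to multiples of 64 is an assumption) and ignoring any other alignment or tiling the Neural Engine may apply:

```python
# Back-of-the-envelope sketch of last-axis padding to 64 bytes.
# Only the 64-byte figure comes from the quoted article; the rest is illustrative.

ALIGN_BYTES = 64

def padded_bytes(shape, bytes_per_elem):
    """Buffer size if the last axis is padded up to a multiple of 64 bytes."""
    *outer, last = shape
    row = last * bytes_per_elem
    padded_row = -(-row // ALIGN_BYTES) * ALIGN_BYTES  # ceil to a multiple of 64
    n_rows = 1
    for d in outer:
        n_rows *= d
    return n_rows * padded_row

def dense_bytes(shape, bytes_per_elem):
    """Buffer size with no padding at all."""
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem

for name, shape in [("(b, h, w, c)", (1, 480, 640, 1)), ("(b, c, h, w)", (1, 1, 480, 640))]:
    for dtype, size in [("fp16", 2), ("int8", 1)]:
        ratio = padded_bytes(shape, size) / dense_bytes(shape, size)
        print(f"{name} {dtype}: {padded_bytes(shape, size):>10} bytes, {ratio:g}x")
```

Under these assumptions the (1, 480, 640, 1) input comes out 32x larger in fp16 and 64x larger in int8, matching the factors in the quote, while the (1, 1, 480, 640) layout has a 640-element last axis that is already a multiple of 64 bytes.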
