Skip to content

Commit

Permalink
[quant][pt2] Fix no conv bias in convert QAT (pytorch#103298)
Browse files Browse the repository at this point in the history
Summary:
Previously, the QAT pattern for conv + bn with no conv
bias was not actually replaced in convert. This commit adds an
extra pattern in the convert path for this case and the numerics
now match FX's.

Test Plan: python test/test_quantization.py TestQuantizePT2E.test_prepare_qat_conv_bn_fusion_no_conv_bias

Reviewed By: jerryzh168

Differential Revision: D46382819

Pull Request resolved: pytorch#103298
Approved by: https://github.com/jerryzh168
  • Loading branch information
andrewor14 authored and pytorchmergebot committed Jun 16, 2023
1 parent a52b6f0 commit dad29f9
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 36 deletions.
15 changes: 7 additions & 8 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,12 +1400,11 @@ def forward(self, x):
M(), example_inputs, is_per_channel=True
)

def test_prepare_qat_conv_bn_fusion_no_conv_bias(self):
def test_qat_conv_bn_fusion_no_conv_bias(self):
class M2(torch.nn.Module):
"""
Mixed conv + BN with and without conv bias.
"""

def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3, bias=False)
Expand All @@ -1423,25 +1422,25 @@ def forward(self, x):
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
example_inputs = (torch.randn(3, 3, 5, 5),)
self._verify_symmetric_qnnpack_qat_graph(
m1, example_inputs, is_per_channel=False, has_relu=False, has_bias=False
m1, example_inputs, is_per_channel=False, has_relu=False, has_bias=False,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_graph(
m1, example_inputs, is_per_channel=True, has_relu=False, has_bias=False
m1, example_inputs, is_per_channel=True, has_relu=False, has_bias=False,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_numerics(
m1, example_inputs, is_per_channel=False
m1, example_inputs, is_per_channel=False, verify_convert=True,
)
m1 = TestHelperModules.ConvWithBNRelu(relu=False, bias=False)
self._verify_symmetric_qnnpack_qat_numerics(
m1, example_inputs, is_per_channel=True
m1, example_inputs, is_per_channel=True, verify_convert=True,
)
self._verify_symmetric_qnnpack_qat_numerics(
M2(), example_inputs, is_per_channel=False
M2(), example_inputs, is_per_channel=False, verify_convert=True,
)
self._verify_symmetric_qnnpack_qat_numerics(
M2(), example_inputs, is_per_channel=True
M2(), example_inputs, is_per_channel=True, verify_convert=True,
)

def test_prepare_qat_conv_bn_relu_fusion(self):
Expand Down
1 change: 1 addition & 0 deletions test/quantization/pt2e/test_quantize_pt2e_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def forward(self, x):
self.assertEqual(ref_result, inductor_res, atol=5e-2, rtol=5e-2)

@skipIfNoX86
@unittest.skip("Fails due to small numerics mismatch, reenable this with the new API in the future")
def test_inductor_qconv_lowering(self):
dim_to_module = {
1: nn.Conv1d,
Expand Down
69 changes: 44 additions & 25 deletions torch/ao/quantization/_pt2e/qat_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import itertools
import operator
from typing import Any, Callable, List, Tuple
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch.fx import Graph, GraphModule, Node
Expand All @@ -25,14 +25,32 @@
_quantized_conv2d_bn_pattern_example_inputs = (
torch.randn(1, 1, 3, 3), # x
torch.randn(1, 1, 1, 1), # conv_weight
torch.randn(1), # conv_bias
torch.randn(1), # bn_weight
torch.randn(1), # bn_bias
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
_weight_scale = torch.tensor([1], dtype=torch.float)
_weight_zero_point = torch.tensor([0], dtype=torch.int)

def _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(
is_per_channel: bool,
has_bias: bool,
) -> Dict[str, Any]:
"""
Optional example inputs for both `_quantized_qat_conv2d_bn_pattern`
and `_folded_quantized_qat_conv2d_bn_pattern`, expressed as kwargs.
Note that weight_scale and weight_zero_point are only used when
`is_per_channel` is True. This is because for per tensor quantization,
scale and zero point are hard coded into quantize/dequantize ops
in the pattern.
"""
kwargs = {}
if is_per_channel:
kwargs["weight_scale"] = torch.tensor([1], dtype=torch.float)
kwargs["weight_zero_point"] = torch.tensor([0], dtype=torch.int)
if has_bias:
kwargs["conv_bias"] = torch.randn(1)
return kwargs

def _conv2d_bn_pattern(
x: torch.Tensor,
Expand All @@ -47,6 +65,7 @@ def _conv2d_bn_pattern(
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True)
return x

# TODO: merge this with the `no_conv_bias` case
def _qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
Expand Down Expand Up @@ -152,7 +171,7 @@ def _input_output_quantized_filter(
return _input_output_quantized_filter


def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool, has_bias: bool):
"""
Return the quantized version of QAT conv + BN pattern.
This is based on `nniqat.ConvBn2d._forward_approximate`,
Expand All @@ -169,7 +188,6 @@ def _get_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
Expand All @@ -183,7 +201,6 @@ def _quantized_qat_conv2d_bn_pattern(
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
zero_bias = torch.zeros_like(conv_bias, dtype=x.dtype)
if is_per_channel:
scaled_weight = torch.ops.quantized_decomposed.quantize_per_channel(
scaled_weight, kwargs['weight_scale'], kwargs['weight_zero_point'], per_channel_axis,
Expand All @@ -200,16 +217,21 @@ def _quantized_qat_conv2d_bn_pattern(
scaled_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
scaled_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
)
x = F.conv2d(x, scaled_weight, zero_bias)
if has_bias:
zero_bias = torch.zeros_like(kwargs["conv_bias"], dtype=x.dtype)
x = F.conv2d(x, scaled_weight, zero_bias)
else:
x = F.conv2d(x, scaled_weight, None)
x = x / scale_factor.reshape(bias_shape)
x = x + conv_bias.reshape(bias_shape)
if has_bias:
x = x + kwargs["conv_bias"].reshape(bias_shape)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_relu:
x = F.relu(x)
return x
return _quantized_qat_conv2d_bn_pattern

def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool):
def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu: bool, has_bias: bool):
"""
Quantized QAT conv - bn pattern with bn weights being folded into conv.
"""
Expand All @@ -222,7 +244,6 @@ def _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel: bool, has_relu:
def _folded_quantized_qat_conv2d_bn_pattern(
x: torch.Tensor,
conv_weight: torch.Tensor,
conv_bias: torch.Tensor,
bn_weight: torch.Tensor,
bn_bias: torch.Tensor,
bn_running_mean: torch.Tensor,
Expand All @@ -245,7 +266,10 @@ def _folded_quantized_qat_conv2d_bn_pattern(
conv_weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
conv_weight, 1.0, int(0), weight_quant_min, weight_quant_max, torch.int8,
)
x = F.conv2d(x, conv_weight, conv_bias)
if has_bias:
x = F.conv2d(x, conv_weight, kwargs["conv_bias"])
else:
x = F.conv2d(x, conv_weight, None)
x = F.batch_norm(x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True, eps=bn_eps)
if has_relu:
x = F.relu(x)
Expand Down Expand Up @@ -478,20 +502,15 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
replacement_options = itertools.product(
[True, False], # is_per_channel
[True, False], # has_relu
[True, False], # has_bias
)
for is_per_channel, has_relu in replacement_options:
for is_per_channel, has_relu, has_bias in replacement_options:
example_inputs = _quantized_conv2d_bn_pattern_example_inputs
kwargs_args = {}
# Note that weight_scale and weight_zero_point are only used when is_per_channel is True
# This is because for per tensor quantization, scale and zero point are hard coded
# into quantize/dequantize ops in the pattern.
if is_per_channel:
kwargs_args['weight_scale'] = _weight_scale
kwargs_args['weight_zero_point'] = _weight_zero_point
match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu)
match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs_args)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu)
replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs_args)
kwargs = _get_quantized_conv2d_bn_pattern_example_inputs_kwargs(is_per_channel, has_bias)
match_pattern = _get_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu, has_bias)
match_pattern = _get_aten_graph_module(match_pattern, example_inputs, **kwargs)
replacement_pattern = _get_folded_quantized_qat_conv2d_bn_pattern(is_per_channel, has_relu, has_bias)
replacement_pattern = _get_aten_graph_module(replacement_pattern, example_inputs, **kwargs)
replacements.extend(
replace_pattern_with_filters(
m,
Expand Down Expand Up @@ -526,7 +545,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
assert isinstance(conv_weight, Node)
assert conv_weight.op == "get_attr"
conv_bias = conv_node.args[2]
assert isinstance(conv_bias, Node)
assert conv_bias is None or isinstance(conv_bias, Node)

(weight_q_node, weight_dq_node) = _get_fused_convbn_q_dq_nodes(r.replacements)
original_weight_q_node = None
Expand Down
7 changes: 4 additions & 3 deletions torch/ao/quantization/_pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
_is_activation_post_process_node,
)
import operator
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple


def _get_tensor_constant_from_node(node, m):
Expand All @@ -32,7 +32,7 @@ def _get_all_arguments(orig_args, orig_kwargs, args_schema):
def _fold_bn_weights_into_conv_node(
conv_node: Node,
conv_weight_node: Node,
conv_bias_node: Node,
conv_bias_node: Optional[Node],
bn_node: Node,
m: GraphModule
) -> None:
Expand Down Expand Up @@ -63,7 +63,8 @@ def _fold_bn_weights_into_conv_node(
conv_args = list(conv_node.args)
# calling data since the fused_weight and fused_bias are nn.Parameter
weight_attr_name = conv_weight_node.target
setattr(m, weight_attr_name, fused_weight) # type: ignore[arg-type]
assert isinstance(weight_attr_name, str)
setattr(m, weight_attr_name, fused_weight)
if conv_bias_node is not None:
bias_attr_name = conv_bias_node.target
else:
Expand Down

0 comments on commit dad29f9

Please sign in to comment.