Skip to content

Commit

Permalink
[Hexagon][UnitTest] Disable flaky quantization test (apache#16337)
Browse files Browse the repository at this point in the history
* [Hexagon][UnitTest] Disable flaky quantization test

The `test_pass_fq2i_avg_pool2d.py::test_avgpool_conv2d` test is
sensitive to rounding errors, and failed about a third of the time (42
/ 100 tests).  This was first noticed as CI failures in unrelated
PRs (e.g. https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-hexagon/detail/PR-16184/6/tests).
This commit marks the flaky portions of the test with
`pytest.mark.xfail`, to avoid causing breaking CI for other PRs.

To minimize the extent of the disabled test cases, this commit breaks
up each of the unit tests.  Where previously a single test performed
both hardware/simulation tests and relay graph comparisons, these are
now done in separate test functions.  The hardware/simulation tests
use `tvm.testing.assert_allclose` and
have a tolerance of `1e-02`, while the graph-comparison tests use
`tvm.ir.structural_equal`, and require identical floating-point
values.  Only the graph-comparison test is disabled here.

The other two test cases in `test_pass_fq2i_avg_pool2d.py` do not show
this same sensitivity, with no failures seen in 100 executions.

* Disable pylint for pytest fixture names
  • Loading branch information
Lunderberg authored Jan 3, 2024
1 parent eb15d04 commit 42b4f21
Showing 1 changed file with 69 additions and 46 deletions.
115 changes: 69 additions & 46 deletions tests/python/contrib/test_hexagon/test_pass_fq2i_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,24 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=redefined-outer-name

""" Tests for avg_pool2d fake quantization to integer """

import numpy as np
import pytest

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay
from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
from .infrastructure import quantize_np, build_module, run_module


def compare_graphs(expr, ref_expr):
"""Compares the given graph with the expected graph"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
ref_mod = tvm.IRModule.from_expr(ref_expr)
ref_mod = tvm.relay.transform.InferType()(ref_mod)
assert tvm.ir.structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)


def compare_fq_to_int(hexagon_session, expr, inputs):
"""Compares the float module output with the integer module output"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

mod = build_module(
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)
mod_int = build_module(
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)

hexagon_mod = hexagon_session.get_executor_from_factory(mod)
result = run_module(hexagon_mod, inputs)

hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
result_int = run_module(hexagon_mod, inputs)

tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)
from .infrastructure import quantize_np, build_module, run_module


@tvm.testing.requires_hexagon
def test_avgpool_conv2d(hexagon_session: Session):
def _make_avgpool_conv2d():
"""Test case with avg_pool2d followed by a conv2d"""
dtype = "int8"
shape_x = [1, 2, 9, 9]
Expand Down Expand Up @@ -112,8 +83,6 @@ def test_avgpool_conv2d(hexagon_session: Session):
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant, "weight": weight_quant}

compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op0 = relay.qnn.op.avg_pool2d(
inp,
Expand Down Expand Up @@ -148,11 +117,11 @@ def test_avgpool_conv2d(hexagon_session: Session):
out_dtype="int8",
)
ref_expr = relay.qnn.op.dequantize(op2, out_sc, out_zp)
compare_graphs(expr, ref_expr)

return expr, args, ref_expr

@tvm.testing.requires_hexagon
def test_avgpool_avgpool(hexagon_session: Session):

def _make_avgpool_avgpool():
"""Test case with avg_pool2d followed by an avg_pool2d"""
dtype = "uint8"
shape_x = [1, 2, 9, 9]
Expand Down Expand Up @@ -197,7 +166,6 @@ def test_avgpool_avgpool(hexagon_session: Session):
expr = relay.qnn.op.quantize(op2, out_sc, out_zp, out_dtype=dtype)
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant}
compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op0 = relay.qnn.op.avg_pool2d(
Expand Down Expand Up @@ -227,12 +195,11 @@ def test_avgpool_avgpool(hexagon_session: Session):
count_include_pad=False,
)
ref_expr = relay.qnn.op.dequantize(op1, out_sc, out_zp)
compare_graphs(expr, ref_expr)

return expr, args, ref_expr

@tvm.testing.requires_hexagon
def test_avgpool(hexagon_session: Session):
"""Test case of a regular avg_pool2d"""

def _make_avgpool():
dtype = "int8"
shape_x = [1, 2, 9, 9]
kernel = [3, 3]
Expand Down Expand Up @@ -266,7 +233,6 @@ def test_avgpool(hexagon_session: Session):
expr = relay.qnn.op.quantize(op1, out_sc, out_zp, out_dtype=dtype)
expr = relay.qnn.op.dequantize(expr, out_sc, out_zp)
args = {"input": input_quant}
compare_fq_to_int(hexagon_session, expr, args)

# Expected graph
op = relay.qnn.op.avg_pool2d(
Expand All @@ -283,6 +249,63 @@ def test_avgpool(hexagon_session: Session):
count_include_pad=False,
)
ref_expr = relay.qnn.op.dequantize(op, out_sc, out_zp)

return expr, args, ref_expr


def compare_graphs(expr, ref_expr):
"""Compares the given graph with the expected graph"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
ref_mod = tvm.IRModule.from_expr(ref_expr)
ref_mod = tvm.relay.transform.InferType()(ref_mod)
tvm.ir.assert_structural_equal(mod_int["main"], ref_mod["main"], map_free_vars=True)


def compare_fq_to_int(hexagon_session, expr, inputs):
"""Compares the float module output with the integer module output"""
mod = tvm.IRModule.from_expr(expr)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger()(mod)
assert not tvm.ir.structural_equal(mod, mod_int)

mod = build_module(
mod, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)
mod_int = build_module(
mod_int, tvm.target.Target(HEXAGON_AOT_LLVM_TARGET, host=HEXAGON_AOT_LLVM_TARGET)
)

hexagon_mod = hexagon_session.get_executor_from_factory(mod)
result = run_module(hexagon_mod, inputs)

hexagon_mod = hexagon_session.get_executor_from_factory(mod_int)
result_int = run_module(hexagon_mod, inputs)

tvm.testing.assert_allclose(result, result_int, rtol=1e-02, atol=1e-02)


avgpool_test_case = tvm.testing.parameter(
_make_avgpool,
_make_avgpool_avgpool,
pytest.param(
_make_avgpool_conv2d,
marks=pytest.mark.xfail(
reason="Rounding differences causing mismatch of Constant, difference around 10^-7"
),
),
)


@tvm.testing.requires_hexagon
def test_execution(hexagon_session: Session, avgpool_test_case):
expr, args, _ = avgpool_test_case()
compare_fq_to_int(hexagon_session, expr, args)


def test_quantization(avgpool_test_case):
expr, _, ref_expr = avgpool_test_case()
compare_graphs(expr, ref_expr)


Expand Down

0 comments on commit 42b4f21

Please sign in to comment.