From 731af42d1b851258746919d590d8ade0a1077e63 Mon Sep 17 00:00:00 2001 From: Christopher Sidebottom Date: Mon, 9 May 2022 09:23:44 +0100 Subject: [PATCH] [CMSIS-NN] Increase partitioning accuracy for pooling (#11229) This ensures that CMSIS-NN is only used when the batch size and layout are correct for the library calls. --- python/tvm/relay/op/contrib/cmsisnn.py | 38 +++++++++- .../contrib/test_cmsisnn/test_pooling.py | 73 +++++++++++++------ 2 files changed, 84 insertions(+), 27 deletions(-) diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py index e39fa034c571..1a06867e5485 100644 --- a/python/tvm/relay/op/contrib/cmsisnn.py +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -31,6 +31,12 @@ def enabled(): return "cmsis-nn" in Target.list_kinds() +def _find_last(pattern): + if hasattr(pattern, "args"): + return _find_last(pattern.args[0]) + return pattern + + def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts): """Partition the graph greedily offloading supported operators on Cortex-M using CMSIS-NN @@ -199,9 +205,20 @@ def qnn_avg_pool2d_pattern(): def check_qnn_avg_pool2d(pattern): """Check if avg pool2d is supported by CMSIS-NN.""" - in_cast = pattern - out_cast = in_cast.args[0].args[0] - return in_cast.checked_type.dtype == "int8" and out_cast.checked_type.dtype == "int32" + output = pattern + input_var = _find_last(pattern) + + if str(pattern.op.name) == "clip": + pooling = pattern.args[0].args[0] + else: + pooling = pattern.args[0] + + return ( + pooling.attrs.layout == "NHWC" + and bool(input_var.checked_type.shape[0] == 1) + and input_var.checked_type.dtype == "int8" + and output.checked_type.dtype == "int8" + ) def qnn_max_pool2d_pattern(): """Matches max pool2d with optional Relu""" @@ -211,7 +228,20 @@ def qnn_max_pool2d_pattern(): def check_qnn_max_pool2d(pattern): """Check if max pool2d is supported by CMSIS-NN.""" - return True + output = pattern + input_var = _find_last(pattern) + + if str(pattern.op.name) == "clip": + pooling = pattern.args[0] + else: + pooling = pattern + + return ( + pooling.attrs.layout == "NHWC" + and bool(input_var.checked_type.shape[0] == 1) + and input_var.checked_type.dtype == "int8" + and output.checked_type.dtype == "int8" + ) def binary_op_pattern(op): """Matches QNN binary operation""" diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index 732fd9bb82ec..cca1288ac2a0 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -44,8 +44,19 @@ ) -def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_point, relu_type): - """Return a model and any parameters it may have""" +def make_model( + pool_op, + shape=(1, 28, 28, 12), + pool_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype="int8", + scale=1, + zero_point=-33, + relu_type="RELU", + layout="NHWC", +): + """Return a model and any parameters it may have, all parameters are defaulted to known good values""" op = relay.var("input", shape=shape, dtype=dtype) pad_ = (0, 0, 0, 0) if padding == "SAME": @@ -60,7 +71,7 @@ def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_p if pool_op == relay.nn.avg_pool2d: op = relay.cast(op, "int32") op = pool_op( - op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout="NHWC" + op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout ) if pool_op == relay.nn.avg_pool2d: op = relay.cast(op, dtype) @@ -68,12 +79,13 @@ def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_p return op +@tvm.testing.requires_corstone300 @tvm.testing.requires_cmsisnn @pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)]) @pytest.mark.parametrize( "pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")] ) -@pytest.mark.parametrize("relu_type", ["RELU"]) +@pytest.mark.parametrize("relu_type", ["NONE", "RELU"]) @pytest.mark.parametrize("pool_type", [relay.nn.max_pool2d, relay.nn.avg_pool2d]) @pytest.mark.parametrize("zero_point, scale", [(-34, 0.0256)]) def test_op_int8( @@ -93,15 +105,14 @@ def test_op_int8( dtype = "int8" model = make_model( - pool_type, - in_shape, - pool_size, - strides, - padding, - dtype, - scale, - zero_point, - relu_type, + pool_op=pool_type, + shape=in_shape, + pool_size=pool_size, + strides=strides, + padding=padding, + scale=scale, + zero_point=zero_point, + relu_type=relu_type, ) orig_mod = make_module(model) @@ -132,17 +143,21 @@ def test_op_int8( @tvm.testing.requires_cmsisnn -def test_invalid_parameters(): +@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) +def test_invalid_datatype(op): + model = make_model(pool_op=op, dtype="int64") + + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + assert_no_external_function(cmsisnn_mod) + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) +def test_invalid_batch_size(op): model = make_model( - pool_op=relay.nn.avg_pool2d, - shape=(1, 28, 28, 12), - pool_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype="uint8", - scale=1, - zero_point=-33, - relu_type="RELU", + pool_op=op, + shape=(2, 28, 28, 12), ) orig_mod = make_module(model) @@ -150,5 +165,17 @@ def test_invalid_parameters(): assert_no_external_function(cmsisnn_mod) +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d]) +def test_invalid_layout(op): + model = make_model(pool_op=op, layout="NCHW") + + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + assert_no_external_function(cmsisnn_mod) + + if __name__ == "__main__": + import sys + sys.exit(pytest.main([__file__] + sys.argv[1:]))