Skip to content

Commit

Permalink
[CMSIS-NN] Increase partitioning accuracy for pooling (apache#11229)
Browse files Browse the repository at this point in the history
This ensures that CMSIS-NN is only used when the batch size and layout are correct for the library calls.
  • Loading branch information
Mousius authored May 9, 2022
1 parent e854c0a commit 731af42
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 27 deletions.
38 changes: 34 additions & 4 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ def enabled():
return "cmsis-nn" in Target.list_kinds()


def _find_last(pattern):
    """Descend the first-argument chain of *pattern* and return the leaf node.

    Repeatedly follows ``args[0]`` until an expression without ``args`` is
    reached (e.g. the input variable of a matched operator chain).
    """
    current = pattern
    while hasattr(current, "args"):
        current = current.args[0]
    return current


def partition_for_cmsisnn(mod, params=None, mod_name="default", **opts):
"""Partition the graph greedily offloading supported
operators on Cortex-M using CMSIS-NN
Expand Down Expand Up @@ -199,9 +205,20 @@ def qnn_avg_pool2d_pattern():

def check_qnn_avg_pool2d(pattern):
    """Check if avg pool2d is supported by CMSIS-NN.

    The CMSIS-NN pooling calls only handle NHWC layout, batch size 1
    and int8 tensors, so reject everything else.
    """
    # NOTE: the stale pre-refactor lines (in_cast/out_cast early return)
    # left over from the diff have been removed — they made this new,
    # stricter check unreachable.
    output = pattern
    input_var = _find_last(pattern)

    # The matched chain may end in an optional clip (fused ReLU); the
    # pooling call then sits one level deeper, below the requantizing cast.
    if str(pattern.op.name) == "clip":
        pooling = pattern.args[0].args[0]
    else:
        pooling = pattern.args[0]

    return (
        pooling.attrs.layout == "NHWC"
        and bool(input_var.checked_type.shape[0] == 1)
        and input_var.checked_type.dtype == "int8"
        and output.checked_type.dtype == "int8"
    )

def qnn_max_pool2d_pattern():
"""Matches max pool2d with optional Relu"""
Expand All @@ -211,7 +228,20 @@ def qnn_max_pool2d_pattern():

def check_qnn_max_pool2d(pattern):
    """Check if max pool2d is supported by CMSIS-NN.

    The CMSIS-NN pooling calls only handle NHWC layout, batch size 1
    and int8 tensors, so reject everything else.
    """
    # NOTE: the stale ``return True`` left over from the diff has been
    # removed — it made this new, stricter check unreachable.
    output = pattern
    input_var = _find_last(pattern)

    # An optional trailing clip (fused ReLU) wraps the pooling call.
    if str(pattern.op.name) == "clip":
        pooling = pattern.args[0]
    else:
        pooling = pattern

    return (
        pooling.attrs.layout == "NHWC"
        and bool(input_var.checked_type.shape[0] == 1)
        and input_var.checked_type.dtype == "int8"
        and output.checked_type.dtype == "int8"
    )

def binary_op_pattern(op):
"""Matches QNN binary operation"""
Expand Down
73 changes: 50 additions & 23 deletions tests/python/contrib/test_cmsisnn/test_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,19 @@
)


def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_point, relu_type):
"""Return a model and any parameters it may have"""
def make_model(
pool_op,
shape=(1, 28, 28, 12),
pool_size=(3, 3),
strides=(2, 2),
padding="VALID",
dtype="int8",
scale=1,
zero_point=-33,
relu_type="RELU",
layout="NHWC",
):
"""Return a model and any parameters it may have, all parameters are defaulted to known good values"""
op = relay.var("input", shape=shape, dtype=dtype)
pad_ = (0, 0, 0, 0)
if padding == "SAME":
Expand All @@ -60,20 +71,21 @@ def make_model(pool_op, shape, pool_size, strides, padding, dtype, scale, zero_p
if pool_op == relay.nn.avg_pool2d:
op = relay.cast(op, "int32")
op = pool_op(
op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout="NHWC"
op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout
)
if pool_op == relay.nn.avg_pool2d:
op = relay.cast(op, dtype)
op = make_qnn_relu(op, relu_type, scale, zero_point, dtype)
return op


@tvm.testing.requires_corstone300
@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("in_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
@pytest.mark.parametrize(
"pool_size, strides, padding", [((3, 3), (2, 2), "SAME"), ((2, 2), (1, 1), "VALID")]
)
@pytest.mark.parametrize("relu_type", ["RELU"])
@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
@pytest.mark.parametrize("pool_type", [relay.nn.max_pool2d, relay.nn.avg_pool2d])
@pytest.mark.parametrize("zero_point, scale", [(-34, 0.0256)])
def test_op_int8(
Expand All @@ -93,15 +105,14 @@ def test_op_int8(
dtype = "int8"

model = make_model(
pool_type,
in_shape,
pool_size,
strides,
padding,
dtype,
scale,
zero_point,
relu_type,
pool_op=pool_type,
shape=in_shape,
pool_size=pool_size,
strides=strides,
padding=padding,
scale=scale,
zero_point=zero_point,
relu_type=relu_type,
)
orig_mod = make_module(model)

Expand Down Expand Up @@ -132,23 +143,39 @@ def test_op_int8(


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d])
def test_invalid_datatype(op):
    """Pooling on a non-int8 dtype must not be offloaded to CMSIS-NN."""
    # NOTE: removed the stale ``def test_invalid_parameters():`` line left
    # over from the diff, which sat between the two decorators and broke
    # the module syntax.
    model = make_model(pool_op=op, dtype="int64")

    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
    assert_no_external_function(cmsisnn_mod)


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d])
def test_invalid_batch_size(op):
    """Pooling with batch size > 1 must not be offloaded to CMSIS-NN."""
    # NOTE: removed the stale pre-refactor keyword arguments left over from
    # the diff (duplicate ``pool_op``/``shape`` entries were a SyntaxError);
    # all other parameters use make_model's known-good defaults.
    model = make_model(
        pool_op=op,
        shape=(2, 28, 28, 12),
    )

    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
    assert_no_external_function(cmsisnn_mod)


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("op", [relay.nn.avg_pool2d, relay.nn.max_pool2d])
def test_invalid_layout(op):
    """An NCHW pooling model must not be offloaded to CMSIS-NN (NHWC only)."""
    mod = make_module(make_model(pool_op=op, layout="NCHW"))
    partitioned = cmsisnn.partition_for_cmsisnn(mod)
    assert_no_external_function(partitioned)


if __name__ == "__main__":
    import sys

    # Forward any extra command-line arguments straight to pytest and
    # propagate its exit status.
    args = [__file__] + sys.argv[1:]
    sys.exit(pytest.main(args))

0 comments on commit 731af42

Please sign in to comment.