[RUNTIME][CLML] Fix for CLML ops and enable more test cases (apache#15896)

* [RUNTIME][CLML] Fix for a few CLML ops

Fixed the dense operator and enhanced the CLML network testcase

* [RUNTIME][CLML] Fix for dense layer and float16

Fixed the dense layer issue at the network level and improved
coverage of the dense layer with CLML.
Fixed a float16 crash error.

* Update comment for dense pattern

* Fixes in CLML test cases

* Enable more test cases and a few fixes

* Fix the import error

* Fix the import error

* Fix in batchnorm testcase

* Restructure CLML test cases and enable the VM executor

* Fix the import error in the CLML test network

* Fix test failures for VM tests

* Update clml.py
krishnaraj36 authored Dec 20, 2023
1 parent 759ee12 commit 3a57a40
Showing 8 changed files with 1,332 additions and 764 deletions.
118 changes: 89 additions & 29 deletions python/tvm/relay/op/contrib/clml.py
@@ -18,6 +18,7 @@
"""CLML Library supported operators."""
import json
from string import Template
import numpy as np
import tvm

from tvm import relay
@@ -27,7 +28,7 @@
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay import function as _function
 from tvm.relay.expr_functor import ExprMutator
-from tvm.relay.expr import Call, TupleGetItem
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant

 from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
@@ -81,34 +82,61 @@ def transform_function(
         return RemoveDropout().visit(func)


-class BroadcastInputs(ExprMutator):
+class OptimizeBatchnorm(ExprMutator):
     """
-    Binary operators need broadcasting for CLML.
+    Fuse Conv+Batchnorm and constant-fold to generate Conv+Add.
     """

-    def visit_call(self, call):
-        if call.op.name in ["add", "subtract", "multiply", "divide", "maximum", "minimum"]:
-            new_fn = self.visit(call.op)
-            call_shape = call.checked_type.shape
-            lhs = call.args[0]
-            rhs = call.args[1]
-            lhs_shape = lhs.checked_type.shape
-            rhs_shape = rhs.checked_type.shape
-            if list(call_shape) != list(lhs_shape):
-                lhs = relay.broadcast_to(self.visit(lhs), call_shape)
-            if list(call_shape) != list(rhs_shape):
-                rhs = relay.broadcast_to(self.visit(rhs), call_shape)
-            args = [lhs, rhs]
-            return Call(new_fn, args, call.attrs)
-        return super().visit_call(call)
+    def visit_call(self, call) -> relay.expr.Expr:
+        new_args = []
+        for arg in call.args:
+            if (
+                not isinstance(arg, (Var, Constant))
+                and isinstance(arg, tvm.relay.TupleGetItem)
+                and arg.tuple_value.op.name == "nn.batch_norm"
+                and (not isinstance(arg.tuple_value.args[0], (Var, Constant)))
+                and arg.tuple_value.args[0].op.name == "nn.conv2d"
+            ):
+                ep = arg.tuple_value.attrs["epsilon"]
+                wt = arg.tuple_value.args[1].data.numpy()
+                bs = arg.tuple_value.args[2].data.numpy()
+                mn = arg.tuple_value.args[3].data.numpy()
+                vr = arg.tuple_value.args[4].data.numpy() + ep
+                dino = np.sqrt(vr)
+                wt = wt / dino
+                bs = bs - mn * wt
+                conv_op = arg.tuple_value.args[0]
+                conv_args = list(conv_op.args)
+                wt_conv = conv_args[1].data.numpy()
+                if conv_op.attrs["kernel_layout"] == "OIHW":
+                    wt = wt.reshape(wt.shape[0], 1, 1, 1)
+                elif conv_op.attrs["kernel_layout"] == "IOHW":
+                    wt = wt.reshape(1, wt.shape[0], 1, 1)
+                else:
+                    raise ValueError("Unsupported Conv2d kernel layout")
+                wt_conv = wt_conv * wt
+                conv_args[1] = relay.const(tvm.nd.array(wt_conv))
+                bs_args = relay.const(tvm.nd.array(bs.reshape(-1, bs.shape[0], 1, 1)))
+                conv_out = Call(
+                    arg.tuple_value.args[0].op, conv_args, arg.tuple_value.args[0].attrs
+                )
+                mod = tvm.relay.add(conv_out, bs_args)
+                new_args.append(mod)
+            else:
+                new_args.append(arg)
+
+        call = Call(call.op, new_args, call.attrs)
+        args = [self.visit(arg) for arg in call.args]
+
+        return Call(call.op, args, call.attrs)


 @transform.function_pass(opt_level=0)
-class BinaryOpBroadcaster:
+class OptimizeBatchnormPass:
     def transform_function(
         self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
     ) -> relay.function.Function:
-        return BroadcastInputs().visit(func)
+        return OptimizeBatchnorm().visit(func)

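For reference, the arithmetic behind this fold: inference-mode batch_norm computes gamma * (x - mean) / sqrt(var + eps) + beta, so scaling the conv weights per output channel by s = gamma / sqrt(var + eps) and adding b = beta - mean * s (what wt / dino and bs - mn * wt compute above) reproduces it exactly. A minimal numpy sketch of that algebra; shapes and names are illustrative, not from the commit:

import numpy as np

rng = np.random.default_rng(0)
C = 4                                    # output channels (OIHW layout)
x = rng.normal(size=(1, C, 8, 8))        # stand-in for the conv output, NCHW
gamma, beta = rng.normal(size=C), rng.normal(size=C)
mean, var = rng.normal(size=C), rng.uniform(0.5, 2.0, size=C)
eps = 1e-5

# Reference inference-mode batch_norm applied to the conv output.
bn = (x - mean.reshape(1, C, 1, 1)) / np.sqrt(var + eps).reshape(1, C, 1, 1)
bn = bn * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)

# Folded form: per-channel scale (pushed into the conv weights by the pass,
# equivalent by linearity of convolution) plus a bias add.
s = gamma / np.sqrt(var + eps)           # matches wt / dino in the pass
b = beta - mean * s                      # matches bs - mn * wt
folded = x * s.reshape(1, C, 1, 1) + b.reshape(1, C, 1, 1)

assert np.allclose(bn, folded)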
@@ -134,8 +162,8 @@ def partition_for_clml(mod, params=None, **opts):
         [
             transform.InferType(),
             RemoveDropoutPass(),
-            BinaryOpBroadcaster(),
             transform.FoldConstant(),
+            OptimizeBatchnormPass(),
             transform.MergeComposite(clml_pattern_table()),
             transform.AnnotateTarget("clml", False),
             transform.MergeCompilerRegions(),
@@ -289,8 +317,15 @@ def concat_pattern():

         return pattern

-    def dense_pattern():
-        """Create a dense pattern."""
+    def dense1d_pattern():
+        """Create a dense pattern for 1d vector-to-matrix multiplication."""
+        pattern = is_op("nn.dense")(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        return pattern
+
+    def dense2d_pattern():
+        """Create a dense pattern for 2d matrix-to-matrix multiplication."""
         pattern = is_op("nn.dense")(wildcard(), is_constant())
         return pattern

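For context, dense1d_pattern targets the single-row case vetted by check_dense1d_op (added further down), while dense2d_pattern takes the bare matmul. A hypothetical relay snippet of the kind of graph dense1d is meant to offload; sizes and names are illustrative, not from the commit:

import numpy as np
import tvm
from tvm import relay

# A 1 x 512 input through nn.dense plus nn.bias_add, the classifier-head
# shape that the "clml.dense1d" composite is intended to cover.
x = relay.var("x", shape=(1, 512), dtype="float32")
w = relay.const(np.zeros((10, 512), dtype="float32"))
b = relay.const(np.zeros((10,), dtype="float32"))
y = relay.nn.bias_add(relay.nn.dense(x, w), b)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# With a batch dimension > 1 (or no trailing bias), the same nn.dense would
# instead be matched by dense2d_pattern.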
@@ -377,6 +412,9 @@ def check_binary_op(extract):
         if len(call.args[1].checked_type.shape) == 0:
             return False

+        if tuple(call.args[0].checked_type.shape) != tuple(call.args[1].checked_type.shape):
+            return False
+
         for arg in call.args:
             # Avoid any operators with dtype Int64
             if arg.checked_type.dtype == "int64":
@@ -436,11 +474,33 @@ def check_batch_matmul_op(extract):
             return False
         return True

+    def check_dense1d_op(extract):
+        call = extract
+        # Only support single Matmul
+        if call.args[0].checked_type.shape[0] > 1:
+            return False
+        if not (call.op.name in ["nn.bias_add", "add"] and call.args[0].op.name == "nn.dense"):
+            return False
+        return True
+
+    def check_reshape(extract):
+        call = extract
+        call_shape = call.checked_type.shape
+        # Only support batch dim = 1
+        if call_shape[0] > 1:
+            return False
+        # Checking buffer indexing limit
+        for shape in call_shape:
+            if shape > 32768:
+                return False
+        return True
+
     return [
         ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
         ("clml.conv2d_transpose", conv_transpose_pattern(), check_conv_transpose),
-        ("clml.dense", dense_pattern(), check_default_op),
+        ("clml.dense1d", dense1d_pattern(), check_dense1d_op),
+        ("clml.dense2d", dense2d_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
         ("clml.concat", concat_pattern(), check_concat_op),
         ("clml.batch_norm", batch_norm_pattern(), check_default_op),
@@ -451,7 +511,7 @@ def check_batch_matmul_op(extract):
("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
# ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
("clml.reshape", is_op("reshape")(wildcard()), check_reshape),
("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
("clml.max_pool2d", is_op("nn.max_pool2d")(wildcard()), check_default_op),
("clml.global_avg_pool2d", is_op("nn.global_avg_pool2d")(wildcard()), check_default_op),
@@ -807,7 +867,7 @@ def make_output_tensor(
             elif activation == "relu6":
                 activation = "CL_ACTIVATION_RELU6"
             else:
-                RuntimeError("Unknown activation:" + activation)
+                raise RuntimeError("Unknown activation:" + activation)
             has_bias = bool((node["inputs"] == 3) or (node["inputs"] == 7))
             has_bn = bool((node["inputs"] == 6) or (node["inputs"] == 7))
             input_tensor = get_tensor_from_map(node["inputs"][0][0])
@@ -907,8 +967,8 @@ def make_output_tensor(
                     )
                 )
             elif node["name"] == "nn.batch_norm":
-                bn_attrs = tuple(node["attrs"]["batchnorm"][0][0])
-                axis = bn_attrs[0]
+                bn_attrs = tuple(node["attrs"]["axis"])
+                axis = int(bn_attrs[0][0])
                 bn_shape = [1, 1, 1, 1]
                 bn_node = self.nodes[node["inputs"][0][0]]
                 bn_shape[axis] = bn_node["attrs"]["shape"][0][0]
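The fix replaces a bad "batchnorm" attribute lookup with the per-op "axis" attribute; JSON-graph attribute values arrive as nested lists of strings, hence the int(bn_attrs[0][0]) conversion. A toy illustration of the access pattern; the node dict literal is fabricated for the example:

# Hypothetical JSON-graph node fragment: attrs values are nested string lists.
node = {"attrs": {"axis": [["1"]]}}

bn_attrs = tuple(node["attrs"]["axis"])   # (["1"],)
axis = int(bn_attrs[0][0])                # 1, the channel axis in NCHW
bn_shape = [1, 1, 1, 1]
bn_shape[axis] = 64                       # channel count taken from the input node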
@@ -1094,7 +1154,7 @@ def make_output_tensor(
                     )
                 )
             else:
-                RuntimeError("Unsupported Op:" + node["name"])
+                raise RuntimeError("Unsupported Op:" + node["name"])
             self.clml_code.append(
                 self.MapInsert.substitute(nid=node_out_name, tensor_desc=node_out_name)
             )
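For background, self.MapInsert is a string.Template, and substitute() fills its $-placeholders when the generator emits standalone CLML source. This sketch shows only the call pattern; the template text below is invented, not the one defined in clml.py:

from string import Template

MapInsert = Template('  this->storage_map.insert({"$nid", $tensor_desc});')
line = MapInsert.substitute(nid="node_3", tensor_desc="node_3")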
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/clml/codegen.cc
@@ -87,7 +87,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "clml.batch_norm") {
       json_node = CreateBatchNormJSONNode(cn);
-    } else if (name == "clml.dense") {
+    } else if (name == "clml.dense1d" || name == "clml.dense2d") {
       json_node = CreateDenseJSONNode(cn);
     } else if (name == "clml.pad") {
       json_node = CreatePadJSONNode(cn);
