Skip to content

Commit

Permalink
[AOT] Introduce checks for return values from operators (apache#10424)
Browse files Browse the repository at this point in the history
This matches the lowering of `call_cpacked` which checks only for an
operator return of `0` in the main flow:

https://github.com/apache/tvm/blob/bd14a4d36e0d364ef9bd34b2ee96cc09ce64d4b3/src/target/source/codegen_c_host.cc#L207-L231

This replaces:
```c
(void)tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var);
```
with:
```c
if (tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var) != 0 ) return -1;
```

when AOT generates the C output.
  • Loading branch information
Mousius authored Mar 9, 2022
1 parent f9d3918 commit 060d9d2
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 31 deletions.
8 changes: 4 additions & 4 deletions apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cm
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }
14 changes: 14 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,20 @@ TVM_DLL const Op& tvm_call_cpacked();
*/
TVM_DLL const Op& tvm_call_trace_packed();

/*!
* \brief Checks the return value of another call is correct or returns a given value.
*
* \note This is meant to serve a specific case for AOT code generator whilst this
* cannot be fully represented in TIR.
*
* Type tvm_check_return(expected, return_unexpected, nested_call) {
* if (nested_call() != expected) {
* return return_unexpected;
* }
* }
*/
TVM_DLL const Op& tvm_check_return();

/*!
* \brief See pesudo code
* Mark the content as thread local context, can get optimized
Expand Down
32 changes: 25 additions & 7 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
* \param num the number to convert
* \return PrimExpr representing num
*/
inline PrimExpr ConstInt32(size_t num) {
inline PrimExpr ConstInt32(int32_t num) {
ICHECK_LE(num, std::numeric_limits<int>::max());
return tir::make_const(DataType::Int(32), static_cast<int>(num));
}
Expand Down Expand Up @@ -333,6 +333,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
args->insert(args->end(), sids.begin(), sids.end());
}

/*
* Wraps a call_extern with a tvm_check_return annotation if required otherwise
* returns the passed Call
*/
tir::Call AddCheckReturn(tir::Call existing_call) {
if (use_unpacked_api_) {
Array<PrimExpr> args = {ConstInt32(0), ConstInt32(-1), existing_call};
return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args);
}

return existing_call;
}

/*!
* brief Create a function call
* \param call_lowered_props The lowered function and the arguments to call it with
Expand All @@ -343,6 +356,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::string func_name = call_lowered_props.lowered_func->name_hint;
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;

// Pack the inputs
for (const Expr& arg : call_lowered_props.arguments) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
Expand Down Expand Up @@ -394,7 +408,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
tir::Var context = device_contexts_.Get(global_var).value();
args.push_back(context);

tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
tir::Evaluate func_call(
AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));
create_func_call_stmts.push_back(tir::SeqStmt({
GenerateDeviceHook(context, "Open"),
func_call,
Expand All @@ -407,7 +422,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
create_func_call_stmts.push_back(func_call);
} else {
// call_extern calling convention without context
tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern, args));
tir::Evaluate func_call(
AddCheckReturn(tvm::tir::Call(DataType::Int(32), calling_pattern, args)));
create_func_call_stmts.push_back(func_call);
}

Expand Down Expand Up @@ -482,8 +498,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Array<String> sections = {"Device", device_name, hook};
String device_hook_name = ToCFunctionStyle(PrefixName(sections));

tir::Evaluate device_hook(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook_name), context}));
tir::Evaluate device_hook(
AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook_name), context})));
device_hooks.push_back(device_hook);
}
return tir::SeqStmt(device_hooks);
Expand All @@ -503,8 +520,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Array<String> sections = {"Device", device_name, hook};
String device_hook = ToCFunctionStyle(PrefixName(sections));

return tir::Evaluate(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook), context}));
return tir::Evaluate(
AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
{tvm::tir::StringImm(device_hook), context})));
}

/*!
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ int32_t TVMEthosULaunch(tvm_device_ethos_u_t* context, void* cms_data, size_t cm
return 0;
}

int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) {}
int32_t TVMDeviceEthosUActivate(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUOpen(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUClose(tvm_device_ethos_u_t* context) { return 0; }
int32_t TVMDeviceEthosUDeactivate(tvm_device_ethos_u_t* context) { return 0; }
12 changes: 10 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (auto* ptr_op = op->op.as<OpNode>()) {
auto call_op = GetRef<Op>(ptr_op);

if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
if (op->op.same_as(builtin::tvm_check_return())) {
const CallNode* call = op->args[2].as<CallNode>();
os << "if (";
VisitExpr_(call, os);
os << " != ";
PrintExpr(op->args[0], os);
os << " ) return ";
PrintExpr(op->args[1], os);
} else if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) {
ICHECK_GE(op->args.size(), 1U);
auto func = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os);
Expand Down Expand Up @@ -971,7 +979,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) {
std::string vid = this->PrintExpr(op->value);
if (vid != "") {
this->PrintIndent();
this->stream << "(void)" << vid << ";\n";
this->stream << vid << ";\n";
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked)
TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_check_return)
.set_num_inputs(3)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
Expand Down
40 changes: 26 additions & 14 deletions tests/python/relay/aot/test_c_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,37 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
# Activate Device
assert (
str(main_func.body[0])
== "tir.call_extern(" + '"TVMDeviceEthosUActivate",' + " device_context_ethos_u)\n"
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUActivate",'
+ " device_context_ethos_u))\n"
)
# Open Device
assert (
str(main_func.body[1][0][0][0])
== "tir.call_extern(" + '"TVMDeviceEthosUOpen",' + " device_context_ethos_u)\n"
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUOpen",'
+ " device_context_ethos_u))\n"
)
# Device Call
assert (
str(main_func.body[1][0][0][1])
== 'tir.call_extern("tvmgen_default_ethos_u_main_0", x_int8_buffer_var, output_buffer_var, device_context_ethos_u)\n'
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"tvmgen_default_ethos_u_main_0",'
+ " x_int8_buffer_var, output_buffer_var, device_context_ethos_u))\n"
)
# Close Device
assert (
str(main_func.body[1][0][0][2])
== "tir.call_extern(" + '"TVMDeviceEthosUClose",' + " device_context_ethos_u)\n"
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUClose",'
+ " device_context_ethos_u))\n"
)
# Deactivate Device
assert (
str(str(main_func.body[2]))
== "tir.call_extern(" + '"TVMDeviceEthosUDeactivate",' + " device_context_ethos_u)\n"
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUDeactivate",'
+ " device_context_ethos_u))\n"
)


Expand All @@ -171,18 +181,18 @@ def test_device_api_hooks_packed_api(device_api_main_func):
# Activate Device
assert (
str(main_func.body[0][0].value)
== "@tir.call_extern("
== "@tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUActivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
+ " dtype=int32))"
)
# Open Device
assert (
str(main_func.body[1].body.body[0][0][0].value)
== "@tir.call_extern("
== "@tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUOpen",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
+ " dtype=int32))"
)
# Device Call
assert (
Expand All @@ -196,18 +206,18 @@ def test_device_api_hooks_packed_api(device_api_main_func):
# Close Device
assert (
str(main_func.body[1].body.body[0][0][2].value)
== "@tir.call_extern("
== "@tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUClose",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
+ " dtype=int32))"
)
# Deactivate Device
assert (
str(main_func.body[2][0].value)
== "@tir.call_extern("
== "@tir.tvm_check_return(0, -1, tir.call_extern("
+ '"TVMDeviceEthosUDeactivate",'
+ " device_context_ethos_u: handle,"
+ " dtype=int32)"
+ " dtype=int32))"
)


Expand All @@ -217,7 +227,9 @@ def test_without_device_api_unpacked_api(non_device_api_main_func):
main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True)
assert (
str(main_func.body)
== 'tir.call_extern("tvmgen_default_fused_multiply", x_buffer_var, y_buffer_var, output_buffer_var)\n'
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"tvmgen_default_fused_multiply",'
+ " x_buffer_var, y_buffer_var, output_buffer_var))\n"
)


Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/aot/test_crt_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,5 +920,36 @@ def test_workspace_calculation_cmsis_nn():
assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 9904


def test_aot_codegen_checks_returns():
"""This test checks whether AoT lowering creates calls that check the return value correctly"""
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
func = relay.Function([x, y], z)

compiled_test_mods = compile_models(
models=AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None),
interface_api="c",
use_unpacked_api=True,
)
source = compiled_test_mods[0].executor_factory.lib.imported_modules[0].get_source()

main_ir_module = compiled_test_mods[0].executor_factory.lowered_ir_mods.items()[0][1]
main_func = main_ir_module["__tvm_main__"]

# Check operator call is wrapped properly
assert (
str(main_func.body[1])
== "tir.tvm_check_return(0, -1, tir.call_extern("
+ '"tvmgen_default_fused_add",'
+ " x_buffer_var, y_buffer_var, output_buffer_var))\n"
)
# TODO(Mousius) - Create a better place for C codegen tests
assert (
"if (tvmgen_default_fused_add(x_buffer_var, y_buffer_var, output_buffer_var) != 0 ) return -1;"
in source
)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 060d9d2

Please sign in to comment.