[Relax][Analysis] Validate global_symbol on non-Relax functions #17203

Open · wants to merge 2 commits into base: main
10 changes: 5 additions & 5 deletions src/relax/analysis/well_formed.cc
@@ -90,11 +90,11 @@ class WellFormedChecker : public relax::ExprVisitor,
WellFormedChecker(obj.as<IRModule>(), check_struct_info);

if (const auto* mod = obj.as<IRModuleNode>()) {
for (const auto& it : mod->functions) {
for (const auto& [gvar, base_func] : mod->functions) {
well_formed_checker.CheckGlobalVarAndGsymbolConsistency(gvar, base_func);
// visit relax.Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
if (auto opt = base_func.as<Function>()) {
Function func = opt.value();
well_formed_checker.VisitExpr(func);
}
}
@@ -133,7 +133,7 @@ class WellFormedChecker : public relax::ExprVisitor,
LOG(WARNING) << "This IR is not well formed: " << diag->message;
}

void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) {
void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, BaseFunc func) {
// the uniqueness of all global vars are ensured by IRModule->global_var_map_, so do not need
// to check again

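For context, the checker now applies CheckGlobalVarAndGsymbolConsistency to every function in the module rather than only to Relax functions. A minimal sketch of the expected effect, assuming the tvm.relax.analysis.well_formed API and using hypothetical function names:

```python
# Sketch: the PrimFunc is registered under the GlobalVar "subroutine" but
# declares global_symbol "main". With the extended check, well_formed is
# expected to log a warning and return False instead of accepting the module.
import tvm
from tvm import relax
from tvm.script import ir as I, tir as T


@I.ir_module
class Inconsistent:
    @T.prim_func
    def subroutine(a_ptr: T.handle("float32")):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.decl_buffer(1, "float32", data=a_ptr)
        A[0] = 0.0


assert not relax.analysis.well_formed(Inconsistent)
```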
5 changes: 3 additions & 2 deletions tests/python/tir-base/test_debug_info.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test line-level debug info for TIR"""

import tvm
import tvm.testing
from tvm import tir
@@ -104,7 +105,7 @@ def find_span(m):
class module_before:
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")})
T.func_attr({"tir.noalias": True, "target": T.target("llvm")})
A = T.match_buffer(a, (8,), dtype="float32")
B = T.match_buffer(b, (8,), dtype="float32")
for i in range(8):
@@ -114,7 +115,7 @@ def main(a: T.handle, b: T.handle):

@T.prim_func
def subroutine(a_ptr: T.handle("float32"), b_ptr: T.handle("float32")):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.func_attr({"tir.noalias": True})
A = T.decl_buffer(1, "float32", data=a_ptr)
B = T.decl_buffer(1, "float32", data=b_ptr)
B[0] = A[1] + 1.0
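The removed attributes are redundant rather than lost: when a non-private PrimFunc inside an ir_module omits "global_symbol", the TVMScript parser is expected to derive it from the function name, so the GlobalVar and the symbol stay consistent by construction. A small sketch of that assumption:

```python
# Sketch (illustrative module): no explicit "global_symbol" is written, and the
# parsed function is expected to carry global_symbol equal to its GlobalVar
# name, i.e. "main".
import tvm
from tvm.script import ir as I, tir as T


@I.ir_module
class Example:
    @T.prim_func
    def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
        T.func_attr({"tir.noalias": True})
        for i in range(8):
            B[i] = A[i] + 1.0


assert Example["main"].attrs["global_symbol"] == "main"
```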
4 changes: 1 addition & 3 deletions tests/python/tir-base/test_tir_host_func.py
@@ -33,7 +33,6 @@ def main(
):
T.func_attr(
{
"global_symbol": "test",
"target": tvm.target.Target("llvm", host="llvm"),
"tir.noalias": True,
}
@@ -59,12 +58,11 @@ def test_host_func():
func = tvm.te.create_prim_func(
te_workload.matmul(729, 729, 729, in_dtype="float32", out_dtype="float32")
)
mod = tvm.ir.IRModule({"main": func})
mod = tvm.ir.IRModule({"main": func.with_attr("global_symbol", "main")})
target = tvm.target.Target("cuda")
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
"global_symbol": "test",
"tir.is_host_func": 1,
}
)
107 changes: 54 additions & 53 deletions tests/python/tir-base/test_tir_intrin.py
@@ -19,7 +19,7 @@
from tvm import te, tir
from tvm import topi
from tvm.contrib import utils, clang
from tvm.script import tir as T
from tvm.script import ir as I, tir as T
import numpy as np
import ctypes
import math
@@ -187,59 +187,60 @@ def clz_np(x, dtype):
np.testing.assert_equal(b.numpy(), ref)


@tvm.script.ir_module
class Module:
@T.prim_func
def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
stride_2 = T.int32()
stride_3 = T.int32()
A_1 = T.match_buffer(
A,
[n],
strides=[stride],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
B_1 = T.match_buffer(
B,
[n],
strides=[stride_1],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
C_1 = T.match_buffer(
C,
[n],
strides=[stride_2],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
d_1 = T.match_buffer(
d,
[n],
strides=[stride_3],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
# body
for i in T.serial(0, n):
d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[(i * stride_2)]


def test_fma():
@I.ir_module
class Module:
@T.prim_func
def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None:
# function attr dict
T.func_attr({"tir.noalias": True})
n = T.int32()
stride = T.int32()
stride_1 = T.int32()
stride_2 = T.int32()
stride_3 = T.int32()
A_1 = T.match_buffer(
A,
[n],
strides=[stride],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
B_1 = T.match_buffer(
B,
[n],
strides=[stride_1],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
C_1 = T.match_buffer(
C,
[n],
strides=[stride_2],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
d_1 = T.match_buffer(
d,
[n],
strides=[stride_3],
elem_offset=0,
align=64,
offset_factor=1,
buffer_type="auto",
)
# body
for i in T.serial(0, n):
d_1[(i * stride_3)] = (A_1[(i * stride)] * B_1[(i * stride_1)]) + C_1[
(i * stride_2)
]

opt = tvm.transform.Sequential(
[
tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))),
(changes to another file; the file name is not shown in this capture)
@@ -43,7 +43,7 @@ def main(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm")})
mod.kernel(A.data)

@T.prim_func
@T.prim_func(private=True)
def kernel(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda")})
A = T.decl_buffer(1, dtype="float32", data=A_data)
@@ -66,7 +66,6 @@ def kernel(A_data: T.handle("float32")):
"target": T.target("cuda"),
"calling_conv": 2,
"tir.kernel_launch_params": [],
"global_symbol": "kernel",
"tir.is_global_func": True,
}
)
@@ -99,7 +98,7 @@ def main(A: T.Buffer(1, "float32")):

@T.prim_func
def kernel(A_data: T.handle("float32")):
T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel_by_another_name"})
T.func_attr({"target": T.target("cuda")})
A = T.decl_buffer(1, dtype="float32", data=A_data)
A[0] = 0.0

@@ -111,7 +110,7 @@ class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
T.func_attr({"target": T.target("llvm")})
T.call_packed("kernel_by_another_name", A.data)
T.call_packed("kernel", A.data)

@T.prim_func
def kernel(A_data: T.handle("float32")):
@@ -120,7 +119,6 @@ def kernel(A_data: T.handle("float32")):
"target": T.target("cuda"),
"calling_conv": 2,
"tir.kernel_launch_params": [],
"global_symbol": "kernel_by_another_name",
"tir.is_global_func": True,
}
)
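Marking the kernel private drops its "global_symbol" entirely, so the new consistency check has nothing to compare against the GlobalVar, and callers reach the kernel through the GlobalVar itself. A minimal sketch under that assumption, with illustrative names:

```python
# Sketch: the private kernel carries no "global_symbol"; the module is
# expected to remain well formed under the extended check.
import tvm
from tvm import relax
from tvm.script import ir as I, tir as T


@I.ir_module
class Mod:
    @T.prim_func
    def main(A: T.Buffer(1, "float32")):
        T.func_attr({"target": T.target("llvm")})
        Mod.kernel(A.data)

    @T.prim_func(private=True)
    def kernel(A_data: T.handle("float32")):
        T.func_attr({"target": T.target("cuda")})
        A = T.decl_buffer(1, dtype="float32", data=A_data)
        A[0] = 0.0


assert relax.analysis.well_formed(Mod)
```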
(changes to another file; the file name is not shown in this capture)
@@ -28,7 +28,7 @@ class MatmulBefore:
@T.prim_func
def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None:
# function attr dict
T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
T.func_attr({"tir.noalias": True})
# body
# with T.block("root")
for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
@@ -69,7 +69,7 @@ class MatmulAfter:
@T.prim_func
def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None:
# function attr dict
T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
T.func_attr({"tir.noalias": True})
# body
# with T.block("root")
for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
(changes to another file; the file name is not shown in this capture)
@@ -77,7 +77,7 @@ def matmul(
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
T.func_attr({"global_symbol": "main", "tir.noalias": True})
T.func_attr({"tir.noalias": True})
# with T.block("root"):
for i, j, k in T.grid(32, 32, 32):
with T.block("C"):
@@ -94,8 +94,7 @@ def matmul_gpu(
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
T.func_attr({"global_symbol": "main",
"target": T.target({"arch": "sm_86",
T.func_attr({"target": T.target({"arch": "sm_86",
"keys": ["cuda", "gpu"],
"kind": "cuda",
"max_num_threads": 1024,
@@ -118,8 +117,7 @@ def matmul_cpu(
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
T.func_attr({"global_symbol": "main",
"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}),
"tir.noalias": True})
# with T.block("root"):
for i, j, k in T.grid(32, 32, 32):
@@ -139,7 +137,7 @@ def matmul(
B: T.Buffer((32, 32), "float16"),
C: T.Buffer((32, 32), "float16"),
):
T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", "tir.noalias": True})
T.func_attr({"tir.is_scheduled": True, "tir.noalias": True})
# with T.block("root"):
for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
@@ -160,7 +158,7 @@ def matmul(

@T.prim_func
def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
# with T.block("root"):
for i, j, k in T.grid(32, 32, 32):
with T.block("C"):
@@ -173,7 +171,7 @@ def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"

@T.prim_func
def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")):
T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
# with T.block("root"):
for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
2 changes: 1 addition & 1 deletion tests/python/tir-usmp/test_tir_usmp_algo.py
@@ -350,7 +350,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
@T.prim_func
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
(changes to another file; the file name is not shown in this capture)
@@ -105,6 +105,7 @@ def _assign_targets_to_primfuncs_irmodule(mod, target):
# These are test IRModules that contains varied topologies of operator graphs
# that includes a main TIR function that includes call to such operators.


# fmt: off
@tvm.script.ir_module
class LinearStructure:
@@ -163,7 +164,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
@T.prim_func
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
@@ -238,7 +239,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol
@T.prim_func
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
@@ -278,7 +279,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol
@T.prim_func
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
@@ -618,7 +619,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_1(placehol
@T.prim_func
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
# body
T.attr("default", "device_id", 0)
T.attr("default", "device_type", 1)
@@ -1334,7 +1335,7 @@ def tvmgen_default_fused_nn_softmax_add(placeholder_26: T.handle, placeholder_27
@T.prim_func
def run_model(data: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
T.func_attr({"runner_function": True})
data_buffer = T.match_buffer(data, [864], dtype="float32", align=16)
output_buffer = T.match_buffer(output, [864], dtype="float32", align=16)
# body