[Unity][Dlight] Matmul Rules (apache#15191)
This PR introduces default schedule rules for matmul kernels. Note that
GEMV-like kernels are skipped, as they will be covered by a separate rule.
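
For context, these rules are consumed through dlight's module pass. A minimal sketch of the intended usage, assuming `ApplyDefaultSchedule` is the pass wrapper exported from `tvm.dlight` and `mod` is an IRModule holding unscheduled PrimFuncs:

    import tvm
    from tvm import dlight as dl

    # Sketch: try each rule in order and keep the first schedule that
    # applies; GEMV-like kernels are left to a separate rule, so the
    # Fallback rule picks up whatever Matmul declines.
    with tvm.target.Target("cuda"):
        mod = dl.ApplyDefaultSchedule(
            dl.gpu.Matmul(),
            dl.gpu.Fallback(),
        )(mod)  # `mod` is assumed to be an IRModule with unscheduled PrimFuncs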
Hzfengsy authored Jul 2, 2023
1 parent 741ca41 commit 918fc4e
Showing 7 changed files with 670 additions and 16 deletions.
40 changes: 40 additions & 0 deletions python/tvm/dlight/base/analysis.py
@@ -21,6 +21,9 @@
 
 from tvm import tir
 from tvm._ffi import get_global_func
+from tvm.target.target import Target
+from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV
 
 
 class IterInfo:
@@ -146,3 +149,40 @@ def _iter_kind(i: tir.IterVar) -> str:
             )
         )
     return blocks
+
+
+def _assert_gpu_target(target: Target):
+    if "gpu" not in target.keys:
+        raise ValueError(f"Expect a GPU target, but got {target}")
+
+
+def get_max_threads_per_block(target: Target) -> int:
+    _assert_gpu_target(target)
+    max_threads_per_block = None
+    for name in ["max_threads_per_block", "max_num_threads"]:
+        if max_threads_per_block is None:
+            max_threads_per_block = target.attrs.get(name, None)
+    if max_threads_per_block is None:
+        max_threads_per_block = 64
+    return int(max_threads_per_block)
+
+
+def get_max_shared_memory_per_block(target: Target) -> int:
+    _assert_gpu_target(target)
+    max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
+    if max_shared_memory_per_block is None:
+        raise ValueError(
+            f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually"
+        )
+    return int(max_shared_memory_per_block)
+
+
+def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
+    try:
+        block = sch.mod[func_name].body.block
+    except AttributeError as err:
+        raise ValueError(
+            f"The function body is expected to be the root block, but got:\n"
+            f"{sch.mod[func_name].body}"
+        ) from err
+    return sch.get_block(block.name_hint)
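
Taken together, a rule would use these helpers roughly as follows. A sketch, assuming `mod` holds an unscheduled PrimFunc named "main"; the shared-memory annotation is illustrative, since a bare "cuda" target may not define that attribute:

    from tvm import tir
    from tvm.target import Target

    # Stock CUDA targets carry `max_num_threads`, which the threads helper
    # falls back to; shared memory is spelled out on the target string here.
    target = Target("cuda -max_shared_memory_per_block=49152")
    threads = get_max_threads_per_block(target)      # e.g. 1024
    smem = get_max_shared_memory_per_block(target)   # 49152

    # Schedule rules start navigating from the root block of the function.
    sch = tir.Schedule(mod)
    root = get_root_block(sch)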
2 changes: 1 addition & 1 deletion python/tvm/dlight/base/transform.py
@@ -60,7 +60,7 @@ def transform_module(  # pylint: disable=missing-function-docstring
         target = Target.current(allow_none=False)
         updated_functions = {}
         for g_var, func in mod.functions.items():
-            if not _is_scheduled(func):
+            if isinstance(func, tir.PrimFunc) and not _is_scheduled(func):
                 sch = _apply_rules(func, target, self.rules, tunable=False)
                 if sch is not None:
                     assert len(sch) == 1
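The added `isinstance` guard matters because a Unity IRModule can mix Relax functions with TIR PrimFuncs, and only the latter can be wrapped in a `tir.Schedule`. A sketch of the filtering, with `mod` as a hypothetical mixed module:

    from tvm import tir

    # Relax functions in the same IRModule must pass through untouched;
    # only TIR PrimFuncs are candidates for the schedule rules.
    prim_funcs = {
        g_var: func
        for g_var, func in mod.functions.items()
        if isinstance(func, tir.PrimFunc)
    }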
1 change: 1 addition & 0 deletions python/tvm/dlight/gpu/__init__.py
@@ -21,3 +21,4 @@
 from .fallback import Fallback
 from .decode_gemv import DecodeGEMV
 from .reduction import Reduction
+from .matmul import Matmul
14 changes: 2 additions & 12 deletions python/tvm/dlight/gpu/fallback.py
@@ -21,17 +21,7 @@
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, normalize_prim_func, try_inline
-
-
-def _max_threads_per_block(target: Target) -> int:
-    max_threads_per_block = None
-    for name in ["max_threads_per_block", "max_num_threads"]:
-        if max_threads_per_block is None:
-            max_threads_per_block = target.attrs.get(name, None)
-    if max_threads_per_block is None:
-        max_threads_per_block = 64
-    return int(max_threads_per_block)
+from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
 
 
 class Fallback(ScheduleRule):
@@ -46,7 +36,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = _max_threads_per_block(target)
+        max_threads_per_block = analysis.get_max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
(The remaining three changed files, including the new python/tvm/dlight/gpu/matmul.py rule itself, did not render in this view.)
