[Wrap] Use a ctypes-based kernel wrapper instead of dlpack for runtime efficiency (#95)

* bump version to v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync

* [Feature] Add WebGPU code generation support in TileLang

- Implement WebGPU code generator (codegen_webgpu.cc and codegen_webgpu.h)
- Add WebGPU target support in lower.py and target.py
- Update CMakeLists.txt to include WebGPU codegen source files
- Introduce WebGPU-specific code generation for WGSL shader language

* [Refactor] Improve WebGPU code generation formatting and readability

- Enhance code formatting in codegen_webgpu.cc and codegen_webgpu.h
- Standardize pointer type spacing and indentation
- Improve line breaks and reduce line length for better readability
- Minor code style improvements in WebGPU code generation

* [Test] Add WebGPU matrix multiplication code generation test

- Implement test_webgpu_codegen.py for WebGPU matrix multiplication
- Add assert_gemm_codegen function to validate WebGPU code generation
- Include basic matrix multiplication kernel test case

* Update README with WebGPU codegen support announcement

* Support multi-version PyPI package builds via tox

* Add support for CPU device backend with C code generation

- Introduce `is_cpu_device_backend` function to detect CPU backend with C code generation
- Modify `lower` function to handle special case of CPU device backend
- Update host and device call filtering for CPU backend
- Add conditional source code generation for C host target
- Extend JITKernel to support optional target_host parameter

* lint fix

* Enhance JIT kernel adapters with CTypes and Torch C++ backends

- Add CtypesKernelAdapter with dynamic library generation and kernel wrapping
- Implement TorchCPPKernelAdapter for CUDA kernel compilation
- Refactor BaseKernelAdapter to support more flexible initialization
- Improve error handling and argument processing in kernel adapters
- Update adapter initialization to support various execution backends

* Refactor and clean up code style in JIT CTypes adapter modules

- Apply consistent code formatting and whitespace in CTypes adapter files
- Remove unused imports and improve import organization
- Enhance readability of code in adapter, libgen, and wrapper modules
- Add missing whitespace and improve line breaks
- Minor linting and code style improvements across CTypes adapter files

* Add test for TileLang JIT GEMM with CTypes backend

- Implement comprehensive test for matrix multiplication using CTypes execution backend
- Create test functions for GEMM with float16 data type
- Add kernel source verification with custom callback
- Implement reference implementation using PyTorch for result validation
- Support various matrix multiplication configurations (transposition, block sizes)

* test fix

* Update TileLang JIT callback registration with override parameter

- Modify tilelang_callback_cuda_postproc to use @tvm.register_func(override=True)
- Ensure proper function registration with ability to replace existing implementations
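
The central change is the call path: instead of converting tensors through DLPack on every invocation, the compiled kernel is loaded and invoked through ctypes. A minimal sketch of the idea, assuming a compiled shared library that exposes a `call` entry point taking raw pointers (the library name and signature here are illustrative, not the actual TileLang wrapper):

import ctypes
import torch

# Illustrative only: load a compiled kernel library and call its entry point
# directly via ctypes, avoiding per-call DLPack capsule creation.
lib = ctypes.CDLL("./kernel_lib.so")   # assumed library name
lib.call.restype = ctypes.c_int        # assumed entry-point signature

def run_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, stream: int = 0):
    # Pass raw data pointers (and the CUDA stream) as plain integers.
    ret = lib.call(
        ctypes.c_void_p(A.data_ptr()),
        ctypes.c_void_p(B.data_ptr()),
        ctypes.c_void_p(C.data_ptr()),
        ctypes.c_void_p(stream),
    )
    assert ret == 0, "kernel launch failed"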
LeiWang1999 authored Feb 19, 2025
1 parent e5f5ca6 commit fca18c4
Showing 13 changed files with 1,075 additions and 55 deletions.
2 changes: 1 addition & 1 deletion testing/python/jit/test_tilelang_jit_callback.py
@@ -88,7 +88,7 @@ def run_gemm(

stramp = "&*(XS)"

@tvm.register_func
@tvm.register_func(override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code
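For context on the one-line change above: `override=True` allows re-registering a global function name that is already registered (for example when a test re-registers the CUDA post-processing callback); without it, TVM raises an error on duplicate registration. A minimal sketch with an explicit name, assuming the standard `tvm.register_func` signature:

from tilelang import tvm as tvm

@tvm.register_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
    # Prepend a marker comment so tests can check that the callback ran.
    return "// postproc marker\n" + code
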
239 changes: 239 additions & 0 deletions testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -0,0 +1,239 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.testing
import tilelang
import torch


def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

stramp = "&*(XS)"

@tvm.register_func(override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code

matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")

kernel_source = matmul_kernel.get_kernel_source()

assert stramp in kernel_source, f"Expected {stramp} in the kernel source"


def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)


def matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tilelang.language as T

@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_jit_kernel(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")

A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()

if trans_A:
A = A.T
if trans_B:
B = B.T

def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C

ref_C = ref_program(A, B)
C = matmul_kernel(A, B)

tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_gemm_jit_kernel():
run_gemm_jit_kernel(
512,
1024,
768,
False,
False,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)


if __name__ == "__main__":
tilelang.testing.main()
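
Distilled from the test above, end-to-end usage of the new backend comes down to a single flag on `JITKernel`. This sketch reuses the `matmul` program builder defined in the test and assumes a CUDA-capable device:

import torch
import tilelang

# Build the TIR program with the test's matmul() builder
# (M, N, K, block sizes, transpose flags, dtypes, pipeline stages, threads).
program = matmul(512, 1024, 768, 128, 256, 32, False, False,
                 "float16", "float16", "float16", 2, 128)

# execution_backend="ctypes" selects the new ctypes-based wrapper; out_idx=-1
# treats the last buffer (C) as the returned output, as in the test above.
kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")

A = torch.randn(512, 768, dtype=torch.float16).cuda()
B = torch.randn(768, 1024, dtype=torch.float16).cuda()
C = kernel(A, B)  # output allocated and returned by the wrapper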
61 changes: 45 additions & 16 deletions tilelang/engine/lower.py
@@ -5,7 +5,7 @@
import tilelang as tl
import os
import os.path as osp
from typing import Union, Optional
from typing import Union, Optional, Callable
from tilelang import tvm as tvm
from tvm import tir, relay
from tvm.ir import CallingConv
@@ -14,21 +14,36 @@
from tilelang.utils.target import determine_target


def is_device_call(func: tir.PrimFunc):
def is_cpu_device_backend(target: Target):
return target.kind.name == "c"


def has_device_kernel_launch(attrs) -> bool:
"""Check if the attributes indicate a device kernel launch."""
return bool(attrs and "calling_conv" in attrs and
attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)


def is_device_call_c_device(func: tir.PrimFunc):
attrs = func.attrs

# consider c source as a device call
if "target" in attrs:
target = attrs["target"]
if target.kind.name == "c":
return True
# Check if it's a C target
if "target" in attrs and attrs["target"].kind.name == "c":
return True

return has_device_kernel_launch(attrs)

return bool(func.attrs and "calling_conv" in func.attrs and
func.attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)

def is_device_call(func: tir.PrimFunc):
return has_device_kernel_launch(func.attrs)


def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return is_device_call_c_device if is_device_c else is_device_call

def is_host_call(func: tir.PrimFunc):
return not is_device_call(func)

def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
return lambda func: not get_device_call(is_device_c)(func)


@tvm.register_func("tilelang_callback_cuda_compile", override=True)
@@ -134,6 +149,9 @@ def lower(
target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host)

_is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
_is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))

mod = tir.transform.BindTarget(target)(mod)

mod = tl.transform.FrontendLegalize()(mod)
@@ -196,7 +214,7 @@ def lower(

mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
host_mod = tir.transform.Filter(is_host_call)(mod)
host_mod = tir.transform.Filter(_is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod)
host_mod = tir.transform.FP8StorageLegalize()(host_mod)
host_mod = tir.transform.BF16StorageLegalize()(host_mod)
@@ -209,11 +227,14 @@ def lower(
if target_host.kind.name == "llvm":
host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
elif target_host.kind.name == "c":
host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
if is_cpu_device_backend(target):
host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
else:
host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
else:
raise ValueError("Target host is not supported")
raise ValueError(f"Target host {target_host.kind.name} is not supported")

device_mod = tir.transform.Filter(is_device_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod)
device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tir.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
@@ -231,10 +252,18 @@ def lower(
elif target.kind.name == "webgpu":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
else:
raise ValueError("Target is not supported")
raise ValueError(f"Target {target.kind.name} is not supported")

host_mod.import_module(device_mod)

if target_host.kind.name == "c":
# cpu host should be recompiled
# TODO(lei): this is a hack to make the C host backend work
temp_dir = tvm.contrib.utils.tempdir()
tmp_lib_path = temp_dir.relpath("tmp.so")
host_mod.export_library(tmp_lib_path)
host_mod = tvm.runtime.load_module(tmp_lib_path)

if runtime_only is True:
return host_mod
else:
3 changes: 2 additions & 1 deletion tilelang/jit/adapter/__init__.py
@@ -3,4 +3,5 @@

from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .torch_cpp import TorchCPPKernelAdapter # noqa: F401
from .torchcpp import TorchCPPKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
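
The new `CtypesKernelAdapter` export sits alongside the existing adapters. A hypothetical dispatch sketch (not the actual TileLang code) shows how an `execution_backend` string such as the "ctypes" value used in the tests might map to an adapter class; the "dlpack" and "torch_cpp" keys are assumed names:

from tilelang.jit.adapter import (
    CtypesKernelAdapter,
    TorchCPPKernelAdapter,
    TorchDLPackKernelAdapter,
)

# Hypothetical mapping from backend name to adapter class; only "ctypes"
# is confirmed by the tests in this commit, the other keys are assumed.
_ADAPTERS = {
    "dlpack": TorchDLPackKernelAdapter,
    "torch_cpp": TorchCPPKernelAdapter,
    "ctypes": CtypesKernelAdapter,
}

def select_adapter(execution_backend: str) -> type:
    try:
        return _ADAPTERS[execution_backend]
    except KeyError as err:
        raise ValueError(f"Unknown execution backend: {execution_backend}") from err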