[JIT] Support Cython jit and make cython a default execution backend (#102)

* [Feature] Add CTypes JIT kernel support for dynamic shapes and multi-stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in CTypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add symbolic shape utility function in tilelang.language
- Update profiler to improve flexibility in benchmark selection
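
A hedged sketch of how this commit's dynamic-shape and multi-stream support might be exercised end to end. It assumes a `T.symbolic` helper (the exact name of the new symbolic-shape utility may differ) and reuses the `matmul` kernel factory from the repository's GEMM examples:

```python
import torch
import tilelang
import tilelang.language as T

# Dynamic M dimension via the new symbolic-shape helper (name assumed).
M = T.symbolic("m")
func = matmul(M, 1024, 1024, 128, 128, 32)  # `matmul` factory from examples/gemm

kernel = tilelang.compile(func, out_idx=[2], execution_backend="ctypes")

a = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Launch on a user-managed CUDA stream; the ctypes wrapper follows the
# current stream instead of always using the default one.
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    c = kernel(a, b)
torch.cuda.synchronize()
```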

* Remove redundant thread binding in GEMM kernel implementations

- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation

* Fix indentation in int4 GEMM kernel test file

- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability

* [Feature] Add Cython JIT kernel support for dynamic shapes and multi-stream execution

- Implement CythonKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in Cython backend
- Create comprehensive test suite for Cython GEMM kernel in test_tilelang_jit_gemm_cython.py
- Update JITKernel to include "cython" as a valid execution backend
- Add Cython-specific wrapper and library generation modules
- Update .gitignore to exclude Cython cache directory
- Modify setup.py to include Cython source files in package data
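
A minimal usage sketch for selecting the new backend, mirroring the ctypes test suite; `matmul` is the kernel factory from the repository's GEMM examples, and the kernel is constructed via `JITKernel` as this commit describes:

```python
import torch
import tilelang

func = matmul(1024, 1024, 1024, 128, 128, 32)  # `matmul` factory from examples/gemm
kernel = tilelang.JITKernel(func, out_idx=[2], execution_backend="cython")

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)
torch.testing.assert_close(c, (a @ b).to(c.dtype), rtol=1e-2, atol=1e-2)
```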

* lint fix

* [Refactor] Replace JITKernel with compile() function for kernel compilation

- Add new `compile()` function in tilelang/jit/__init__.py as a wrapper for JITKernel
- Update multiple test files and examples to use `tilelang.compile()` instead of `tilelang.JITKernel()`
- Modify kernel adapters to support optional kernel-only source retrieval
- Update `__init__.py` to import the new `compile()` function
- Improve kernel source retrieval for different execution backends
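
The user-facing change, as reflected in the README and example diffs further down:

```python
import tilelang

func = matmul(1024, 1024, 1024, 128, 128, 32)  # `matmul` factory from examples/gemm

# before: construct the kernel class directly
jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")

# after: compile() is a wrapper around JITKernel with the same arguments
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
```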

* lint fix

* remove debug print

* Add C/C++ compiler utility module and update Cython JIT kernel support

- Introduce new `tilelang/contrib/cc.py` module with cross-platform C/C++ compiler utilities
- Add functions to detect and retrieve system C/C++ compilers
- Implement cross-compilation and shared library creation support
- Update Cython JIT kernel to validate C++ compiler availability
- Modify Cython adapter to use detected C++ compiler for library generation
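
A hedged sketch of the kind of detection such a module provides; the function name and candidate list below are illustrative, not the actual `tilelang/contrib/cc.py` API:

```python
import shutil

def find_cpp_compiler():
    """Return the path of the first C++ compiler found on PATH, or None."""
    for candidate in ("g++", "clang++", "c++"):
        path = shutil.which(candidate)
        if path is not None:
            return path
    return None

compiler = find_cpp_compiler()
if compiler is None:
    raise RuntimeError(
        "No C++ compiler found; the Cython backend needs one to build its shared library.")
```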

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray
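
A hedged sketch of the localization described above; the dtype spellings in the map are assumptions based on common torch/TVM float8 names, and the actual conversion is elided:

```python
import torch

def adapt_torch2tvm(torch_tensor):
    # The float8 mapping now lives inside the function instead of at module
    # scope; behavior is unchanged.
    float8_dtype_map = {
        torch.float8_e4m3fn: "float8_e4m3",
        torch.float8_e5m2: "float8_e5m2",
    }
    tvm_dtype = float8_dtype_map.get(torch_tensor.dtype)
    ...  # reinterpret with `tvm_dtype` (when set) and build the TVM ndarray as before
```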

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray

* revert

* Enhance Cython JIT adapter with Cython compiler detection

- Add `get_cython_compiler()` function to dynamically locate Cython executable
- Update Cython adapter to use detected Cython compiler instead of hardcoded command
- Raise an exception if no Cython compiler is found
- Update requirements.txt to specify minimum PyTorch version (>=2.2.0)
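
A hedged sketch of such a lookup; the actual `get_cython_compiler()` may probe different executables:

```python
import shutil

def get_cython_compiler():
    """Return the path to an available Cython executable, or None."""
    for candidate in ("cython", "cython3"):
        path = shutil.which(candidate)
        if path is not None:
            return path
    return None

cython_bin = get_cython_compiler()
if cython_bin is None:
    raise RuntimeError("Cython compiler not found; install it via `pip install cython`.")
```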

* Fix Cython kernel wrapper stream handling and type annotations

- Update stream parameter type to int64_t for better compatibility
- Directly use torch.cuda.current_stream().cuda_stream instead of casting
- Improve type safety and precision in Cython kernel wrapper
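
On the caller side, PyTorch's stream handle is already a plain Python integer, so it can be forwarded to the wrapper's `int64_t` stream parameter without casting; a small illustration (the `stream` keyword on the wrapper call is hypothetical):

```python
import torch

# torch exposes the raw CUDA stream handle as an int
stream_handle = torch.cuda.current_stream().cuda_stream
assert isinstance(stream_handle, int)

# hypothetical wrapper call: the handle is passed through as int64_t
# kernel(a, b, stream=stream_handle)
```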
LeiWang1999 authored Feb 21, 2025
1 parent 9285e38 commit 2e53bd0
Showing 26 changed files with 1,740 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -85,3 +85,6 @@ tilelang/lib

# tox
.tox/

# cython
tilelang/jit/adapter/cython/.cycache
2 changes: 1 addition & 1 deletion README.md
@@ -154,7 +154,7 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch
2 changes: 1 addition & 1 deletion docs/deeplearning_operators/matmul.md
@@ -109,7 +109,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
func = matmul(1024, 1024, 1024, 128, 128, 32)

# 2. JIT-compile the kernel for NVIDIA GPU
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

import torch

2 changes: 1 addition & 1 deletion examples/quickstart.py
@@ -63,7 +63,7 @@ def main(
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ attrs
cloudpickle
ml_dtypes
psutil
-torch
+torch>=2.2.0
19 changes: 17 additions & 2 deletions setup.py
@@ -151,7 +151,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"


package_data = {
"tilelang": ["py.typed"],
"tilelang": ["py.typed", "*pyx"],
}

LLVM_VERSION = "10.0.1"
@@ -227,7 +227,22 @@ def run(self):
ext_output_dir = os.path.dirname(extdir)
print(f"Extension output directory (parent): {ext_output_dir}")
print(f"Build temp directory: {build_temp_dir}")

+        # copy cython files
+        CYTHON_SRC = [
+            "tilelang/jit/adapter/cython/cython_wrapper.pyx",
+        ]
+        for item in CYTHON_SRC:
+            source_dir = os.path.join(ROOT_DIR, item)
+            target_dir = os.path.join(self.build_lib, item)
+            if os.path.isdir(source_dir):
+                self.mkpath(target_dir)
+                distutils.dir_util.copy_tree(source_dir, target_dir)
+            else:
+                target_dir = os.path.dirname(target_dir)
+                if not os.path.exists(target_dir):
+                    os.makedirs(target_dir)
+                shutil.copy2(source_dir, target_dir)
        # copy the tl_templates
        TILELANG_SRC = [
            "src/tl_templates",
        ]
10 changes: 5 additions & 5 deletions testing/python/debug/test_tilelang_debug_print.py
@@ -14,7 +14,7 @@ def program(Q: T.Buffer((M, N), dtype)):
shared_buf = T.alloc_shared([M, N], dtype)
T.print(shared_buf)

-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()

@@ -34,7 +34,7 @@ def program(Q: T.Buffer((M, N), dtype)):
if bx == 0 and by == 0 and bz == 0:
T.print(shared_buf)

-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()

@@ -53,7 +53,7 @@ def program(Q: T.Buffer((M, N), dtype)):
if tid == 0:
T.print(bx + by + bz)

-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()

@@ -72,7 +72,7 @@ def program(Q: T.Buffer((M, N), dtype)):
for i, j in T.Parallel(M, N):
T.print(register_buf[i, j])

-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()

@@ -91,7 +91,7 @@ def program(Q: T.Buffer((M, N), dtype)):
if tid == 0:
T.print(bx + by + bz, msg="hello world")

-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
profiler = jit_kernel.get_profiler()
profiler.run_once()

2 changes: 1 addition & 1 deletion testing/python/issue/test_tilelang_issue_96.py
@@ -42,7 +42,7 @@ def main(

def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
func = matmul(N, N, N, block_M, block_N, block_K)
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

torch.manual_seed(0)
a = torch.randn(N, N, device="cuda", dtype=torch.float16)
4 changes: 2 additions & 2 deletions testing/python/jit/test_tilelang_jit_callback.py
@@ -93,7 +93,7 @@ def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code

-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")

kernel_source = matmul_kernel.get_kernel_source()

Expand Down Expand Up @@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")

A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
2 changes: 1 addition & 1 deletion testing/python/jit/test_tilelang_jit_gemm.py
@@ -206,7 +206,7 @@ def run_gemm_jit_kernel(
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")

A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
10 changes: 5 additions & 5 deletions testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -92,7 +92,7 @@ def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code

-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")

kernel_source = matmul_kernel.get_kernel_source()

@@ -195,7 +195,7 @@ def run_gemm_jit_kernel(
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")

A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -263,7 +263,7 @@ def run_ctypes_kernel_do_bench(M,
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")

profiler = matmul_kernel.get_profiler()

@@ -312,7 +312,7 @@ def run_ctypes_kernel_multi_stream(M,
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")

tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -364,7 +364,7 @@ def run_ctypes_dynamic_shape(M,
num_threads,
)

-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
if isinstance(M, T.Var):
M = 1024
if isinstance(N, T.Var):