[AOTI] Introduce DeferredCudaGridLine for cuda cpp wrapper (pytorch#129268)

Summary: Similar to pytorch#129135, use DeferredCudaGridLine to create a deferred grid computation line when generating cpp wrapper.

Differential Revision: [D61800622](https://our.internmc.facebook.com/intern/diff/D61800622)
Pull Request resolved: pytorch#129268
Approved by: https://github.com/angelayi
desertfire authored and pytorchmergebot committed Aug 27, 2024
1 parent 5fd670e commit a4b44dd
Showing 3 changed files with 153 additions and 62 deletions.
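For context on the mechanism: the cpp wrapper cannot know a Triton kernel's launch grid at codegen time, because the winning block sizes only exist after autotuning has run and populated CudaKernelParamCache. The wrapper therefore records callable "deferred line" objects and materializes them into text later. The snippet below is a minimal standalone sketch of that pattern, not the PyTorch implementation; ParamCache, DeferredGridLine, and the kernel name are simplified stand-ins.

import math

# Simplified stand-in for CudaKernelParamCache; autotuning fills it in later.
ParamCache = {}

class DeferredGridLine:
    """A wrapper-code line whose text can only be produced after autotuning."""

    def __init__(self, kernel_name, grid_var, numel):
        self.kernel_name = kernel_name
        self.grid_var = grid_var
        self.numel = numel

    def __call__(self):
        # By materialization time the tuned block size must be in the cache.
        params = ParamCache[self.kernel_name]
        grid_x = math.ceil(self.numel / params["x_block"])
        return f"Grid {self.grid_var} = Grid({grid_x}, 1, 1);"

lines = []
# Codegen time: the grid is unknown, so record a deferred line instead of a string.
lines.append(DeferredGridLine("triton_poi_fused_add_0", "kernel0_grid_0", numel=1024))

# ... autotuning runs and caches the chosen config ...
ParamCache["triton_poi_fused_add_0"] = {"x_block": 128}

# Final pass: materialize every line into real C++ wrapper text.
print("\n".join(line() if callable(line) else line for line in lines))
# prints: Grid kernel0_grid_0 = Grid(8, 1, 1);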
13 changes: 9 additions & 4 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -71,7 +71,7 @@ def __init__(self):

def generate_kernel_call(
self,
name,
kernel_name: str,
call_args,
grid=None,
device_index=None,
@@ -81,6 +81,7 @@ def generate_kernel_call(
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
@@ -94,14 +95,18 @@
"""
if cuda:
return super().generate_kernel_call(
name,
kernel_name,
call_args,
grid,
device_index,
cuda,
triton,
arg_types,
raw_args,
grid_fn,
triton_meta,
autotune_configs,
grid_extra_kwargs,
)
else:
if config.abi_compatible:
@@ -119,9 +124,9 @@
else:
# arg is a scalar
new_args.append(arg)
self.writeline(self.wrap_kernel_call(name, new_args))
self.writeline(self.wrap_kernel_call(kernel_name, new_args))
else:
self.writeline(self.wrap_kernel_call(name, call_args))
self.writeline(self.wrap_kernel_call(kernel_name, call_args))

def write_constant(self, name, hashed):
# include a hash so our code cache gives different constants different files
191 changes: 135 additions & 56 deletions torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -16,7 +16,7 @@
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
from .cpp_utils import DTYPE_TO_CPP
from .cpp_utils import cexpr, DTYPE_TO_CPP
from .cpp_wrapper_cpu import CppWrapperCpu
from .wrapper import SymbolicCallArg

@@ -61,6 +61,98 @@ def _new_line(self, line):
return DeferredCudaKernelLine(self.kernel_name, line, self.keys)


class DeferredCudaDefaultGrid:
"""
A marker to defer computing the default grid until the kernel's autotuned block sizes are available in CudaKernelParamCache
"""

def __init__(
self,
kernel_name: str,
grid,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwargs,
):
self.kernel_name = kernel_name
self.grid = grid
self.grid_callable = grid_callable
self.grid_extra_kwargs = grid_extra_kwargs

def __call__(self):
grid = self.grid
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list or tuple"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_callable = self.grid_callable or default_grid
if not self.grid_extra_kwargs:
grid_fn = grid_callable(*grid)
else:
grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)

params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)


class DeferredCudaGridLine(DeferredLineBase):
"""
When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels
to be tuned and stored as cubin files, so use a deferred line to backfill that information
"""

def __init__(
self,
kernel_name: str,
grid_var: str,
grid,
autotune_configs,
):
super().__init__("")
self.kernel_name = kernel_name
self.grid_var = grid_var
self.grid = grid
self.autotune_configs = autotune_configs

def __call__(self):
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"

if self.autotune_configs is not None:
# This indicates the Triton kernel is a user-defined one.
grid = None
if len(self.grid) == 1:
grid = self.grid[0]
else:
for i, c in enumerate(self.autotune_configs):
if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
grid = self.grid[i]
break
assert grid is not None
elif isinstance(self.grid, DeferredCudaDefaultGrid):
grid = self.grid()
else:
grid = self.grid

assert len(grid) != 0, "Grid can't be empty"
grid_args_str = ", ".join(
[cexpr(V.graph.sizevars.simplify(item)) for item in grid]
)
return f"Grid {self.grid_var} = Grid({grid_args_str});"

def _new_line(self, line):
return DeferredCudaGridLine(
self.kernel_name, self.grid_var, self.grid, self.autotune_configs
)


class CppWrapperCuda(CppWrapperCpu):
"""
Generates cpp wrapper for running on GPU and calls CUDA kernels
@@ -116,28 +208,20 @@ def generate(self, is_inference):
return super().generate(is_inference)

def generate_user_defined_triton_kernel(
self, kernel_name, raw_args, grid, configs, triton_meta, constexprs
self,
kernel_name: str,
raw_args: List[Any],
grid: List[Any],
configs,
triton_meta,
constexprs,
):
# in C++ wrapper, we don't pass constexpr args, as they don't
# get added as parameters to the PTX code compiled from the
# user-defined Triton kernel (only non-constexpr args do)
raw_args = [
raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
]

assert len(grid) != 0
if len(grid) == 1:
grid_decision = grid[0]
else:
meta = CudaKernelParamCache.get(kernel_name)
assert meta is not None
grid_decision = None
for i, c in enumerate(configs):
if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
grid_decision = grid[i]
break
assert grid_decision is not None

args = [self.val_to_arg_str(v) for v in raw_args]
arg_types = [
arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
@@ -147,10 +231,12 @@ def generate_user_defined_triton_kernel(
kernel_name,
args,
arg_types=arg_types,
grid=grid_decision,
raw_args=raw_args,
grid=grid,
cuda=True,
triton=True,
triton_meta=triton_meta,
autotune_configs=configs,
)

@functools.lru_cache(None) # noqa: B019
@@ -228,39 +314,27 @@ def generate_args_decl(self, call_args, arg_types):

def generate_default_grid(
self,
name: str,
kernel_name: str,
grid: List[Any],
cuda: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwargs,
):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
function from triton_heuristics. Because its computation needs
to read the kernel config after autotuning, it is done in a deferred way
using DeferredCudaDefaultGrid.
"""
if not cuda:
return grid
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_callable = grid_callable or default_grid
if not grid_extra_kwargs:
grid_fn = grid_callable(*grid)
else:
grid_fn = grid_callable(*grid, **grid_extra_kwargs)
params = CudaKernelParamCache.get(name)
assert (
params is not None
), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
block_cfg = {
"XBLOCK": params["x_block"],
"YBLOCK": params["y_block"],
"ZBLOCK": params["z_block"],
}
return grid_fn(block_cfg)
return DeferredCudaDefaultGrid(
kernel_name, grid, grid_callable, **grid_extra_kwargs
)

def generate_kernel_call(
self,
kernel_name,
kernel_name: str,
call_args,
grid=None,
device_index=None,
@@ -270,6 +344,7 @@ def generate_kernel_call(
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
assert arg_types is not None and len(call_args) == len(
@@ -279,7 +354,18 @@
if not cuda:
# Even in CppWrapperCuda, we may see cpp kernels
return super().generate_kernel_call(
kernel_name, call_args, grid, device_index, cuda, triton, arg_types
kernel_name,
call_args,
grid,
device_index,
cuda,
triton,
arg_types,
raw_args,
grid_fn,
triton_meta,
autotune_configs,
grid_extra_kwargs,
)

device_index, call_args = self.prepare_triton_kernel_call(
@@ -307,33 +393,26 @@ def generate_kernel_call(
if V.graph.aot_mode
else self.write_get_raw_stream(device_index, V.graph)
)
grid_name = f"{kernel_name}_grid_{next(self.grid_id)}"
assert isinstance(
grid, (list, tuple)
), f"expected grid to be a list or tuple but got: {grid=}"

grid = [V.graph.sizevars.simplify(item) for item in grid]
grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
grid_args = [self.expr_printer(item) for item in grid]
grid_args_str = ", ".join(grid_args)
self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")

if grid_uses_symbolic_shapes:
self.writeline(f"if ({grid_name}.is_non_zero()) {{")

grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
self.writeline(
DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
)

kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
self.writeline(f"if ({grid_var}.is_non_zero()) {{")
self.writeline(
DeferredCudaKernelLine(
kernel_name,
r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
kernel_var_name,
f"{grid_name}.grid_x",
f"{grid_name}.grid_y",
f"{grid_name}.grid_z",
f"{grid_var}.grid_x",
f"{grid_var}.grid_y",
f"{grid_var}.grid_z",
kernel_args_var,
stream,
),
("num_warps", "shared_mem"),
),
)
if grid_uses_symbolic_shapes:
self.writeline("}")
self.writeline("}")
11 changes: 9 additions & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -788,7 +788,13 @@ def generate_extern_kernel_out(
self.writeline(f"{kernel}({', '.join(args)})")

def generate_user_defined_triton_kernel(
self, kernel_name, raw_args, grid, configs, triton_meta, constexprs
self,
kernel_name: str,
raw_args: List[Any],
grid: List[Any],
configs,
triton_meta,
constexprs,
):
grid_fn, code = user_defined_kernel_grid_fn_code(
kernel_name, configs, grid, wrapper=self
@@ -1541,7 +1547,7 @@ def generate_save_uncompiled_kernels(self):

def generate_default_grid(
self,
name: str,
kernel_name: str,
grid: List[Any],
cuda: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
@@ -1632,6 +1638,7 @@ def generate_kernel_call(
raw_args=None,
grid_fn: str = "grid",
triton_meta=None,
autotune_configs=None,
grid_extra_kwargs="",
):
"""
