Skip to content

Commit

Permalink
Disable cudagraphs by default when dynamic shape is enabled. (pytorch…
Browse files Browse the repository at this point in the history
…#104448)

Disable cudagraphs when dynamic shape is enabled (via torch.compile(dynamic=True)).
Otherwise, Inductor recompiles for each new shape, which doesn't seem to be very reasonable.

Pull Request resolved: pytorch#104448
Approved by: https://github.com/jansel, https://github.com/ezyang
  • Loading branch information
ipiszy authored and pytorchmergebot committed Jul 11, 2023
1 parent 3279f06 commit e940d5d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
6 changes: 6 additions & 0 deletions test/inductor/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def test_api_options(self):
self.assertEqual(max_autotune_opts["max_autotune"], True)
self.assertEqual(max_autotune_opts["triton.cudagraphs"], True)

max_autotune_opts = torch._inductor.list_mode_options(
"max-autotune", dynamic=True
)
self.assertEqual(max_autotune_opts["max_autotune"], True)
self.assertEqual(max_autotune_opts["triton.cudagraphs"], False)

max_autotune_no_cudagraphs_opts = torch._inductor.list_mode_options(
"max-autotune-no-cudagraphs"
)
Expand Down
9 changes: 8 additions & 1 deletion torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,14 @@ def apply_mode(self, mode: Optional[str]):
pass
elif mode in ("reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"):
from torch._inductor import list_mode_options
self.apply_options(list_mode_options(mode))
if mode == "reduce-overhead" and self.dynamic:
raise RuntimeError(
"mode=reduce-overhead cannot be used together with dynamic=True! "
"reduce-overhead enables cudagraph. dynamic=True forces recompiliation "
"for each new shape, which defeats the purpose of cudagraph. "
"Please only enable one of them."
)
self.apply_options(list_mode_options(mode, self.dynamic))
else:
raise RuntimeError(
f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
Expand Down
11 changes: 8 additions & 3 deletions torch/_inductor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ def aot_compile(
return lib_path


def list_mode_options(mode: str = None) -> Dict[str, Any]:
def list_mode_options(mode: str = None, dynamic: bool = None) -> Dict[str, Any]:
r"""Returns a dictionary describing the optimizations that each of the available
modes passed to `torch.compile()` performs.
Args:
mode (str, optional): The mode to return the optimizations for.
If None, returns optimizations for all modes
dynamic (bool, optional): Whether dynamic shape is enabled.
When dynamic_shape is enabled, cuda graph will be disabled.
Example::
>>> torch._inductor.list_mode_options()
Expand All @@ -76,10 +78,13 @@ def list_mode_options(mode: str = None) -> Dict[str, Any]:
"max-autotune-no-cudagraphs": {
"max_autotune": True,
},
# enable both cuda-graphs and max-autotune
# enable max-autotune
# enable cudagraphs when dynamic is not set
# otherwise, if both cudagraphs and dynamic are enabled, Inductor
# recompiles for each new shape
"max-autotune": {
"max_autotune": True,
"triton.cudagraphs": True,
"triton.cudagraphs": (dynamic is not True),
},
}
return mode_options[mode] if mode else mode_options
Expand Down

0 comments on commit e940d5d

Please sign in to comment.