Add Inductor config for default stride behavior (#135238)
By default, Inductor is allowed to manipulate the layout
(strides+storage offset) of input tensors to custom operators.

We want to change the default so that Inductor respects the stride order of
input tensors to custom operators.

This PR adds a config to toggle the behavior; the next PR up in the stack
changes the default. We also make the following change:
- We add a new operator tag (flexible_layout), which means that
Inductor is allowed to manipulate the layout. When we flip the default,
users can specify that they want the old behavior by using this tag
(see the sketch below).
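
As a rough sketch of that opt-out (illustrative only; "mylib::foobar" is borrowed from the new test below and is not an API added by this PR), a user could tag their op when defining it:

import torch

# Hypothetical custom op. Tagging it flexible_layout tells Inductor it may
# keep changing the strides/storage_offset of the op's inputs.
lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foobar(Tensor x) -> Tensor", tags={torch._C.Tag.flexible_layout})
lib.impl("foobar", lambda x: x.clone(), "CompositeExplicitAutograd")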

This is a reland of pytorch/pytorch#126986,
which was previously reverted due to silent incorrectness. We have since
fixed the silent incorrectness (pytorch/pytorch#133639).

Test Plan:
- new test

Pull Request resolved: pytorch/pytorch#135238
Approved by: https://github.com/albanD
zou3519 authored and pytorchmergebot committed Sep 6, 2024
1 parent 3a9e33d commit ad29a2c
Showing 4 changed files with 60 additions and 5 deletions.
9 changes: 9 additions & 0 deletions aten/src/ATen/native/tags.yaml
@@ -46,6 +46,15 @@
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same stride permutation as observed in eager when compiled in inductor.
    Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
    multiple are assigned then we assume the most restrictive one.
- tag: flexible_layout
  desc: |
    This tag indicates that the custom operator can accept inputs with varying
    strides/storage_offset and that when compiled, Inductor is allowed to change
    the strides/storage_offset of inputs to the custom operator.
    Only one of {needs_fixed_stride_order, flexible_layout} can apply; if
    multiple are assigned then we assume the most restrictive one.
# NOTE [Core ATen Ops]
- tag: core
20 changes: 20 additions & 0 deletions test/test_custom_ops.py
@@ -3428,6 +3428,26 @@ def vmap(info, in_dims, w, x=2, *, y=3, z):
        self.assertTrue(called)
        self.assertEqual(result, w * 2 * 3 * 42)

    def test_layout_constraint_tags(self):
        needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order
        flexible_layout = torch._C.Tag.flexible_layout
        # (tags, the result of the tag inference)
        tests = [
            ({needs_fixed_stride_order}, needs_fixed_stride_order),
            ({flexible_layout}, flexible_layout),
            # If no tags are provided, then the following is the default
            (set(), flexible_layout),
            # If multiple tags are provided, then we use the most constrained tag.
            ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order),
        ]
        from torch._inductor.lowering import get_layout_constraint_tag

        for tags, expected in tests:
            with torch.library._scoped_library("mylib", "FRAGMENT") as m:
                m.define("foobar(Tensor x) -> Tensor", tags=tags)
                result = get_layout_constraint_tag(torch.ops.mylib.foobar.default)
                self.assertEqual(result, expected)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_vmap(self):
        for mode in ["function", "qualname", "opoverload", "c_opdef"]:
7 changes: 7 additions & 0 deletions torch/_inductor/config.py
@@ -65,6 +65,13 @@ def autotune_remote_cache_default() -> Optional[bool]:
# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None

# The default layout constraint for custom operators.
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}).
# If the custom op does not already have a layout constraint tag,
# then we assume the following applies.
custom_op_default_layout_constraint = "flexible_layout"

# use cpp wrapper instead of python wrapper
cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"

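As a minimal sketch (not part of this diff), a user could flip this knob today to preview the behavior that the next PR in the stack makes the default:

import torch._inductor.config as inductor_config

# Require Inductor to preserve the eager stride order of inputs to custom ops
# that carry no explicit layout constraint tag.
inductor_config.custom_op_default_layout_constraint = "needs_fixed_stride_order"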
29 changes: 24 additions & 5 deletions torch/_inductor/lowering.py
@@ -101,11 +101,30 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
         _maybe_layout_constraints[fn] = None
         return None
     # We lazily register tag-based layout constraints.
-    if torch._C.Tag.needs_fixed_stride_order in fn.tags:
-        _maybe_layout_constraints[fn] = constrain_to_fx_strides
-        return _maybe_layout_constraints[fn]
-    _maybe_layout_constraints[fn] = None
-    return None
+
+    def handle_layout_constraint_tag(tag):
+        if tag is torch._C.Tag.needs_fixed_stride_order:
+            _maybe_layout_constraints[fn] = constrain_to_fx_strides
+            return _maybe_layout_constraints[fn]
+        elif tag is torch._C.Tag.flexible_layout:
+            _maybe_layout_constraints[fn] = None
+            return None
+        else:
+            raise AssertionError(f"Unknown layout constraint tag: {tag}")
+
+    tag = get_layout_constraint_tag(fn)
+    return handle_layout_constraint_tag(tag)
+
+
+def get_layout_constraint_tag(fn):
+    tags_by_priority = [
+        torch._C.Tag.needs_fixed_stride_order,
+        torch._C.Tag.flexible_layout,
+    ]
+    for tag in tags_by_priority:
+        if tag in fn.tags:
+            return tag
+    return getattr(torch._C.Tag, config.custom_op_default_layout_constraint)


 def assert_nyi(cond, msg):
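A small illustrative check of the priority rule implemented above, reusing the hypothetical "mylib::foobar" op from the new test: when both tags are present, the more restrictive needs_fixed_stride_order wins, and an untagged op falls back to config.custom_op_default_layout_constraint.

import torch
from torch._inductor.lowering import get_layout_constraint_tag

with torch.library._scoped_library("mylib", "FRAGMENT") as m:
    m.define(
        "foobar(Tensor x) -> Tensor",
        tags={torch._C.Tag.flexible_layout, torch._C.Tag.needs_fixed_stride_order},
    )
    # needs_fixed_stride_order has higher priority than flexible_layout.
    assert (
        get_layout_constraint_tag(torch.ops.mylib.foobar.default)
        is torch._C.Tag.needs_fixed_stride_order
    )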
