[runtime asserts] deduplicate runtime asserts & CSE (#128599)
This PR adds deduplication and common subexpression elimination (CSE) for runtime asserts. Existing size computations in the graph are CSE'd together with the newly added runtime asserts, and redundant asserts are removed. Where possible, shape calls on intermediate tensors are also rewritten as computations on input sizes, allowing the intermediate tensors to be freed earlier. For example:
```
z = torch.cat([x, x], dim=0)  # 2*s0
w = z.repeat(y.shape[0])  # 2*s0*s1
_w = w.shape[0]
# something with _w ...

# turns into ->
s0 = x.shape[0]
s1 = y.shape[0]
_w0 = 2 * s0
_w = _w0 * s1
```
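
As a rough illustration (a minimal sketch using sympy, not the actual FX-graph/ShapeEnv machinery this PR modifies), sympy's cse helper shows the same idea: the shared 2*s0 subterm can be hoisted once and reused by every size expression that needs it.
```
import sympy

# Symbolic input sizes, mirroring s0 = x.shape[0] and s1 = y.shape[0] above.
s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)

# z.shape[0] == 2*s0 and w.shape[0] == 2*s0*s1 share the 2*s0 factor.
replacements, reduced = sympy.cse([2 * s0, 2 * s0 * s1])
for sym, expr in replacements:
    print(f"{sym} = {expr}")  # hoisted subexpression, analogous to _w0 = 2 * s0 above
print(reduced)  # the original size expressions, rewritten in terms of the hoisted symbol
```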

Additionally, constrain_range calls are deduplicated. Single-symbol bound checks for unbacked symbols (e.g. u0 >= 0, u0 <= 5) and sym_constrain_range.default calls are also removed, since they accumulate range info in the ShapeEnv; each symbol's bounds are instead checked by two _assert_scalar.default calls for the min and max. For example:
```
torch.sym_constrain_range_for_size(n, min=2, max=16)
torch.sym_constrain_range(n, min=4, max=20)
torch._check(n >= 0)
torch._check(n >= 3)
torch._check(n <= 14)

# turns into
torch.sym_constrain_range_for_size(n)
torch._check(n >= 4)
torch._check(n <= 14)
```
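
As a plain-Python sketch of the range-merging step (not the ShapeEnv code itself), intersecting all of a symbol's bounds into one interval is what leaves a single lower-bound and a single upper-bound assert:
```
import math

def merge_ranges(constraints):
    """Intersect (lo, hi) bounds for one symbol; raise if the result is empty."""
    lo, hi = -math.inf, math.inf
    for c_lo, c_hi in constraints:
        lo, hi = max(lo, c_lo), min(hi, c_hi)
    if lo > hi:
        raise ValueError(f"unsatisfiable range [{lo}, {hi}]")
    return lo, hi

# Bounds from the example above: constrain_range_for_size(min=2, max=16),
# constrain_range(min=4, max=20), n >= 0, n >= 3, n <= 14.
bounds = [(2, 16), (4, 20), (0, math.inf), (3, math.inf), (-math.inf, 14)]
print(merge_ranges(bounds))  # -> (4, 14): only n >= 4 and n <= 14 survive as asserts
```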

Pull Request resolved: pytorch/pytorch#128599
Approved by: https://github.com/ezyang
pianpwk authored and pytorchmergebot committed Jul 6, 2024
1 parent 7c43f59 commit 0267b2d
Showing 18 changed files with 593 additions and 323 deletions.
11 changes: 5 additions & 6 deletions test/dynamo/test_export.py
@@ -2042,7 +2042,7 @@ def f(x):
if node.op == "call_function" and node.target == operator.getitem:
count += 1

self.assertEqual(count, 3)
self.assertEqual(count, 1)
self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

def test_dynamic_slicing_invalid(self):
@@ -3489,16 +3489,15 @@ def forward(self, x):
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sub = sym_size_int - 1
slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 2)
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int_1); slice_2 = None
slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None
slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None
sub_1 = sym_size_int - 2
slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int_1); slice_5 = None
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int); slice_5 = None
slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None
sub_2 = sym_size_int - 3; sym_size_int = None
sub_2 = sym_size_int - 3
slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int_1); slice_8 = sym_size_int_1 = None
slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int); slice_8 = sym_size_int = None
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None
return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
)
8 changes: 3 additions & 5 deletions test/dynamo/test_higher_order_ops.py
@@ -294,7 +294,7 @@ def f(x):
f,
default_args_generator((x,)),
ifdynstaticdefault(2, 3),
expected_opcount=ifdynstaticdefault(2, 3),
expected_opcount=2,
)

def test_wrap_pytree_args_nested(self):
@@ -356,7 +356,7 @@ def f(x, y):
f,
default_args_generator((x, y)),
ifdynstaticdefault(2, 3),
expected_opcount=ifdynstaticdefault(2, 3),
expected_opcount=2,
return_graph=True,
)
if torch._dynamo.config.assume_static_by_default:
@@ -387,10 +387,8 @@ class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
l_x_ = L_x_
size: "Sym(s0)" = l_x_.size(0)
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, size); wrap_body_0 = l_x_ = size = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, s0); wrap_body_0 = l_x_ = s0 = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
1 change: 1 addition & 0 deletions test/dynamo/test_logging.py
@@ -687,6 +687,7 @@ def fn(a):
"recompiles_verbose",
"graph_breaks",
"graph",
"graph_code",
"graph_sizes",
"ddp_graphs",
"perf_hints",
30 changes: 14 additions & 16 deletions test/dynamo/test_misc.py
@@ -371,7 +371,6 @@ def unpack4(a, b):
unpack4,
2,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(5, 7),
)

def test_unpack5(self):
@@ -388,7 +387,6 @@ def unpack5(a, b):
unpack5,
2,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(5, 7),
)

def test_matmul1(self):
@@ -412,7 +410,7 @@ def fn(x):
return x + y

torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 11)
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 9)
)

@torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
@@ -959,7 +957,7 @@ def fn(x):
return x + p

torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 10)
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 6)
)

def test_int_shape_inplace_binops(self):
@@ -983,7 +981,7 @@ def fn(x):
return x + y

torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 4)
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 2)
)

def test_int_int_comparisons(self):
@@ -1085,7 +1083,7 @@ def forward(self, x):
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(counts.op_count, """1""")
else:
self.assertExpectedInline(counts.op_count, """11""")
self.assertExpectedInline(counts.op_count, """9""")

def test_user_defined_binop(self):
class MyClass:
@@ -1112,7 +1110,7 @@ def fn(x, c):
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(counts.op_count, """1""")
else:
self.assertExpectedInline(counts.op_count, """4""")
self.assertExpectedInline(counts.op_count, """2""")

def test_user_defined_iter(self):
class Mod:
@@ -1415,14 +1413,14 @@ def _fn(a, b, func=func):
get_test_fn(func=min),
2,
expected_ops=1,
expected_ops_dynamic=ifdynstaticdefault(1, 14),
expected_ops_dynamic=ifdynstaticdefault(1, 10),
)
torch._dynamo.testing.standard_test(
self,
get_test_fn(func=max),
2,
expected_ops=1,
expected_ops_dynamic=ifdynstaticdefault(1, 17),
expected_ops_dynamic=ifdynstaticdefault(1, 5),
)

@torch._dynamo.config.patch(capture_scalar_outputs=True)
@@ -1755,7 +1753,7 @@ def fn(a):
fn=fn,
nargs=1,
expected_ops=3,
expected_ops_dynamic=ifdynstaticdefault(3, 6),
expected_ops_dynamic=ifdynstaticdefault(3, 4),
)

def test_pair(self):
@@ -1771,7 +1769,7 @@ def fn(a):
fn=fn,
nargs=1,
expected_ops=5,
expected_ops_dynamic=ifdynstaticdefault(5, 8),
expected_ops_dynamic=5,
)

@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
@@ -1947,7 +1945,7 @@ def fn(a):

# expect 3 ops post folding for dynamic case: size, index, add
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(1, 3)
self, fn, 1, expected_ops=1, expected_ops_dynamic=1
)

def test_tuple_iadd_with_shape(self):
@@ -1959,9 +1957,9 @@ def fn(a):
output += (2, 3)
return output

# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
# expect 4 add / subs for static
torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(4, 12)
self, fn, 1, expected_ops=4, expected_ops_dynamic=4
)

def test_list_iadd_with_shape(self):
@@ -1973,10 +1971,10 @@ def fn(a):
output += (a + a.shape[0], a - a.shape[0])
return output

# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
# expect 6 add / subs for static

torch._dynamo.testing.standard_test(
self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(6, 18)
self, fn, 1, expected_ops=6, expected_ops_dynamic=6
)

def test_list_iadd_side_effect(self):
18 changes: 7 additions & 11 deletions test/dynamo/test_repros.py
@@ -967,7 +967,7 @@ def test_do_paste_mask(self):
)
# (dynamic shapes, static shapes)
self.assertIn(cnt.frame_count, (5, 7))
self.assertIn(cnt.op_count, (104, 106, 127))
self.assertIn(cnt.op_count, (94, 106, 121))

def test_convert_boxes_to_pooler_format(self):
boxes1 = [
@@ -1010,7 +1010,7 @@ def fn(boxes):
self.assertExpectedInline(cnt.op_count, """1""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """6""")
self.assertExpectedInline(cnt.op_count, """2""")

def _reformer(self, nopython):
input = torch.randn([1, 64, 256])
@@ -1238,13 +1238,13 @@ def test_longformer_chunk(self):
if torch._dynamo.config.assume_static_by_default:
if torch._dynamo.config.automatic_dynamic_shapes:
self.assertExpectedInline(cnt.frame_count, """2""")
self.assertExpectedInline(cnt.op_count, """14""")
self.assertExpectedInline(cnt.op_count, """8""")
else:
self.assertExpectedInline(cnt.frame_count, """2""")
self.assertExpectedInline(cnt.op_count, """4""")
else:
self.assertExpectedInline(cnt.frame_count, """2""")
self.assertExpectedInline(cnt.op_count, """35""")
self.assertExpectedInline(cnt.op_count, """21""")

def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
@@ -1609,12 +1609,8 @@ def fn(cfg):
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
self.assertEqual(opt_fn(cfg), 64)
# With unspec int, maximum computation is preserved
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """3""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """4""")
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """3""")

def test_reformer_sorting(self):
x = torch.zeros([1, 12, 4096], dtype=torch.int64)
@@ -1629,7 +1625,7 @@ def test_reformer_sorting(self):
self.assertExpectedInline(cnt.op_count, """14""")
else:
self.assertExpectedInline(cnt.frame_count, """1""")
self.assertExpectedInline(cnt.op_count, """27""")
self.assertExpectedInline(cnt.op_count, """16""")

def test_recursive_map(self):
# https://github.com/pytorch/torchdynamo/issues/132
4 changes: 2 additions & 2 deletions test/dynamo/test_subgraphs.py
@@ -310,7 +310,7 @@ def fn(a, b):
x = torch.add(unsupported(x, x), 1)
return a * x + len_(b)

self._common(fn, 2, ifdynstaticdefault(4, 5))
self._common(fn, 2, 4)

def test_restore_range(self):
def fn(a, b):
@@ -587,7 +587,7 @@ def fn(a, b):
b = b + x * i
return b

self._common(fn, 1, ifdynstaticdefault(2, 7))
self._common(fn, 1, ifdynstaticdefault(2, 3))


if __name__ == "__main__":