[dynamo] Eagerly install guards (pytorch#111415)
Pull Request resolved: pytorch#111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: pytorch#111306
jansel authored and pytorchmergebot committed Nov 7, 2023
1 parent 2964682 commit 9664190
Showing 30 changed files with 333 additions and 622 deletions.
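
The common thread of the diff: instead of accumulating guards on each VariableTracker's `options["guards"]` set and merging them into the output graph later (the `self.tx.output.guards.update(value.guards)` and `self.guards.update(output.guards)` calls removed from codegen.py and output_graph.py below), guards are now installed eagerly on the active tracing context at the point where they are created. A minimal runnable sketch of the two patterns; FakeGuard and all surrounding scaffolding here are invented purely for illustration and are not Dynamo's real Guard/TracingContext types:

# Illustration only: FakeGuard and the containers below are stand-ins.
from typing import Set

class FakeGuard:
    def __init__(self, name: str, kind: str):
        self.name, self.kind = name, kind

    def __repr__(self) -> str:
        return f"Guard({self.name!r}, {self.kind})"

# Old pattern: each VariableTracker carries its own guard set, which is
# merged into the output graph much later (at codegen / graph-output time).
options = {"guards": set()}
options["guards"].add(FakeGuard("L['x']", "TENSOR_MATCH"))
output_graph_guards: Set[FakeGuard] = set()
output_graph_guards.update(options["guards"])  # deferred merge step

# New pattern: the guard is recorded on the ambient tracing context the
# moment it is derived, so the deferred merge disappears.
tracing_context_guards: Set[FakeGuard] = set()

def install_guard(*guards: FakeGuard) -> None:
    tracing_context_guards.update(guards)

install_guard(FakeGuard("L['x']", "TENSOR_MATCH"))
print(output_graph_guards, tracing_context_guards)
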
4 changes: 3 additions & 1 deletion test/dynamo/test_export.py
@@ -3292,7 +3292,9 @@ def forward(self, pred, x):
cos = l_x_.cos(); l_x_ = None
return pytree.tree_unflatten([cos], self._out_spec)
"""
true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
true_guard_code = [
"cast_symbool_to_symint_guardless(L['pred']) == 1",
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
44 changes: 20 additions & 24 deletions test/dynamo/test_higher_order_ops.py
@@ -297,25 +297,25 @@ def my_args_generator(t):
actual_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor, L_z_ : torch.Tensor):
l_x_ = L_x_
l_y_ = L_y_
l_z_ = L_z_
def forward(self, L_d_x_ : torch.Tensor, L_d_y_0_ : torch.Tensor, L_d_y_1_2_ : torch.Tensor):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, l_y_, l_z_); wrap_body_0 = l_x_ = l_y_ = l_z_ = None
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
getitem = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_, l_z_):
sin = l_x_.sin(); l_x_ = None
cos = l_y_.cos(); l_y_ = None
def forward(self, l_d_x_, l_d_y_0_, l_d_y_1_2_):
sin = l_d_x_.sin(); l_d_x_ = None
cos = l_d_y_0_.cos(); l_d_y_0_ = None
add = sin + cos; sin = cos = None
sin_1 = l_z_.sin(); l_z_ = None
sin_1 = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sub = add - sin_1; add = sin_1 = None
return (sub,)
""",
""", # NOQA: B950
)

def test_wrap_pytree_args_with_symint_constant(self):
@@ -3005,9 +3005,9 @@ def fn(x):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_x_ = L_x_
def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
child = L_y_
l_x_ = L_x_
_check_randomness_arg = torch._functorch.vmap._check_randomness_arg('error')
_check_randomness_arg_1 = torch._functorch.vmap._check_randomness_arg('error')
@@ -3269,16 +3269,14 @@ def wrapper_fn(x, in_dims):
return torch.func.vmap(torch.sum, in_dims)(x)

x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# The third invocation of `opt` makes `in_dims` a SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.vmap: in_dims is not an int or tuple variable.": 2},
)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)

def test_vmap_multiple_invocation_out_dims(self):
counters.clear()
@@ -3287,16 +3285,14 @@ def wrapper_fn(x, out_dims):
return torch.func.vmap(lambda x: torch.sum(x, 0), out_dims=out_dims)(x)

x = torch.randn(3, 3, 3, 3)
opt = torch.compile(wrapper_fn, backend="eager", fullgraph=False, dynamic=True)
cnt = CompileCounter()
opt = torch.compile(wrapper_fn, backend=cnt, fullgraph=False, dynamic=True)
expected = wrapper_fn(x, 0), wrapper_fn(x, 1), wrapper_fn(x, 2)
# The third invocation of `opt` makes `out_dims` a SymInt.
actual = opt(x, 0), opt(x, 1), opt(x, 2)
self.assertEqual(expected, actual)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.vmap: out_dims is not an int or tuple variable.": 2},
)
self.assertEqual(cnt.frame_count, 3)
self.assertEqual(cnt.op_count, 9)

def test_vmap_new_tensor_in_body(self):
def fn(x):
2 changes: 1 addition & 1 deletion test/dynamo/test_misc.py
@@ -1541,7 +1541,7 @@ def fn(x, a, b):
args = [torch.randn(10), 4096, np.int64(8)]
correct = fn(*args)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, dynamic=True)(fn)
opt_fn = torch._dynamo.optimize(cnts, dynamic=True, nopython=True)(fn)
self.assertTrue(same(opt_fn(*args), correct))
self.assertTrue(same(opt_fn(*args), correct))
self.assertEqual(cnts.frame_count, 1)
85 changes: 5 additions & 80 deletions test/dynamo/test_repros.py
@@ -814,9 +814,8 @@ def fn(self, tensor):
class ReproTests(torch._dynamo.test_case.TestCase):
def test_do_paste_mask(self):
torch._dynamo.utils.counters.clear()
opt__do_paste_mask = torch._dynamo.optimize(
torch._dynamo.testing.CompileCounter()
)(_do_paste_mask)
cnt = torch._dynamo.testing.CompileCounter()
opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
opt__do_paste_mask(
torch.randn(1, 1, 28, 28),
torch.tensor([[0.0, 1, 2, 4]]) * 1,
@@ -852,12 +851,9 @@ def test_do_paste_mask(self):
640,
False,
)

self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3)
self.assertEqual(
torch._dynamo.utils.counters["frames"]["total"],
torch._dynamo.utils.counters["frames"]["ok"] + 1,
)
# (dynamic shapes, static shapes)
self.assertIn(cnt.frame_count, (5, 7))
self.assertIn(cnt.op_count, (106, 127))

def test_convert_boxes_to_pooler_format(self):
boxes1 = [
@@ -2451,77 +2447,6 @@ def f(x, y):
self.assertEqual(f(x, x), opt_f(x, x))
self.assertEqual(f(x, y), opt_f(x, y))

def test_reformer_remove_unused_args(self):
# This test case is very interesting. First, let's describe
# the bug this is testing for. The bug we fixed is twofold:
#
# - We prune GraphArgs that aren't used in the output graph.
# However, sometimes it is possible for those GraphArgs to be
# utilized in shape guards (you could imagine this happening if
# dynamo poked some shape variables without recording them in the
# graph.) If we prune those GraphArgs, we get a
# "s1 not in ..." error as we can no longer codegen the
# requested guards.
#
# - But in practice, Dynamo usually traces size accesses into the
# graph, preventing the GraphArg from getting pruned. So how
# come we were running into this in practice with hf_Reformer?
# The answer is checkpointing!
#
# This brings us to the following test case. Here's what it does:
#
# 1. It traces some operations, and then checkpoints before inlining
# the function call to g
#
# 2. g traces some more operations (triggering the shape guard
# to be created), but then it graph breaks
#
# 3. Because you can't graph break in an inlining function, we roll
# back to the outer checkpoint ("undoing" the operation that
# induced the shape guard) and then immediately generate a
# subgraph at that point.
#
# If we failed to checkpoint the ShapeEnv, it can still have guards
# from the aborted speculation, which we will then still attempt to
# codegen.
#
# There's an additional nuance: suppose x is used but y is not.
# If you create a guard like y == x * 2, you will accidentally avoid
# the "s1 not in ..." error, as y will get substituted with x * 2,
# but x is still a GraphArg (it's used) and you don't end up with
# the error. This is why we must show y + y == x, not vice versa.
# Similarly, it is also why we must not do a simple guard like x == y
#
# Can we actually demonstrate that checkpointing the ShapeEnv is
# necessary? It's not so easy to induce this case. Dynamo is very
# eager about adding locals to GraphArgs; any local that is in scope,
# even if it isn't used, is added to GraphArgs (see also
# https://github.com/pytorch/torchdynamo/issues/1925 ). So long
# as Dynamo eagerly guards in this way, we have an invariant that
# all locals are guaranteed to show up in GraphArgs before the
# inlining function call, in which case we will always have enough
# information to codegen our guards so long as we don't prune the
# unused GraphArgs away (and indeed, the direct fix for this bug
# was to make sure we use original GraphArgs). Non locals,
# conversely, typically are static, and so won't have guards allocated
# for them. That being said, there may still be a way to trigger
# this error.

def g(x, y):
r = torch.cat((y, y)) + x
print("foo")
return r

def f(x, y):
x = x * 3
return g(x, y)

opt_f = torch._dynamo.optimize("aot_eager")(f)

x = torch.randn(4)
y = torch.randn(2)
self.assertEqual(f(x, y), opt_f(x, y))

def test_swin_base_tensor_attr(self):
class Foo(torch.nn.Module):
def __init__(self):
9 changes: 1 addition & 8 deletions test/dynamo/test_subclasses.py
@@ -271,7 +271,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)

@torch.compile(backend="eager", fullgraph=True)
@torch.compile(backend="eager")
def fn(x):
return x.sigmoid()

@@ -819,13 +819,6 @@ def binary(nt1, nt2):
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

def test_binary_recompiles_due_to_duck_sizing(self):
# Even though the input is unused, we still guard due to duck sizing
nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 3), None)
nt2, _ = self._get_jagged_tensor(((2, 3, 4), 3), offsets)
nt3, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
self._check_recompiles(lambda nt1, nt2: nt1.sin(), (nt1, nt2), (nt1, nt3), True)

# TODO: cannot parametrize this test class with device for some reason
def _test_autograd(self, backend):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
4 changes: 2 additions & 2 deletions test/dynamo/test_subgraphs.py
@@ -477,8 +477,8 @@ def fn(x, c1, c2, c3):
opt_fn(v1, a, b, c)

# checking here we don't create 2^n graphs
self.assertEqual(cnt.frame_count, 12)
self.assertEqual(cnt.op_count, 16)
self.assertEqual(cnt.frame_count, 7)
self.assertEqual(cnt.op_count, 10)

def test_resume_with_no_grad1(self):
def fn(a, b):
4 changes: 3 additions & 1 deletion test/test_pruning_op.py
@@ -4,7 +4,7 @@
from hypothesis import given
import numpy as np
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()

@@ -56,6 +56,7 @@ def get_reference_result(embedding_weights, mask, indices_type):
self.assertEqual(pt_compressed_indices_map.dtype, indices_type)


@skipIfTorchDynamo()
@given(
embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100),
@@ -67,6 +68,7 @@ def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, we
self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype)


@skipIfTorchDynamo()
@given(
embedding_rows=st.integers(1, 100),
embedding_dims=st.integers(1, 100),
2 changes: 2 additions & 0 deletions test/test_sparse_csr.py
@@ -2530,6 +2530,7 @@ def run_test(m, n, k, nnz, train):
run_test(4, 5, 4, 10, False)
run_test(4, 4, 4, 16, True)

@skipIfTorchDynamo()
@onlyCPU
@dtypes(torch.float32, torch.float64, torch.bfloat16)
@precisionOverride({torch.bfloat16: 0.01})
@@ -2894,6 +2895,7 @@ def run_test(shape, nnz, index_type):
run_test(shape, max(shape), index_dtype)
run_test(shape, shape[0] * shape[1], index_dtype)

@skipIfTorchDynamo()
@skipMeta
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@all_sparse_compressed_layouts()
2 changes: 0 additions & 2 deletions torch/_dynamo/codegen.py
@@ -76,8 +76,6 @@ def __call__(self, value, allow_cache=True):
self.clear_tos()
return

self.tx.output.guards.update(value.guards)

assert isinstance(value, VariableTracker)
output = self._output
graph_outputs = self.graph_outputs
18 changes: 17 additions & 1 deletion torch/_dynamo/guards.py
@@ -586,7 +586,7 @@ def BACKEND_MATCH(self, guard: Guard):
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
)
code = [
f"___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id})"
f"(___skip_backend_check() or ___current_backend() == ___lookup_backend({backend_id}))"
]
self._produce_guard_code(guard, code)

@@ -1366,3 +1366,19 @@ def make_dupe_guard(obj_source, dupe_source):
# However, this should always be a sound guard to add here.
return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
return None


def install_guard(*guards, skip=0):
"""
Add dynamo guards to the current tracing context.
Args:
guards: guard(s) to add
skip: number of stack frames to ignore for debug stack trace
"""
from torch._guards import TracingContext

add = TracingContext.get().guards_context.dynamo_guards.add
for guard in guards:
assert isinstance(guard, Guard)
add(guard, skip=skip + 1)
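
install_guard expects an active TracingContext, so it cannot be exercised in isolation; the call sites changed in output_graph.py below, e.g. `install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))`, show the intended use. To observe the guards that end up installed for a compiled frame from the outside, something like the following should work (assuming the `torch._dynamo.explain` API of this era, whose ExplainOutput exposes an `out_guards` list):

import torch
import torch._dynamo as dynamo

def f(x):
    return x.sin() + 1

# Guards such as TENSOR_MATCH on L['x'] are installed eagerly on the
# TracingContext while Dynamo traces f; explain() surfaces them afterwards.
explanation = dynamo.explain(f)(torch.randn(4))
for guard in explanation.out_guards:
    print(guard)
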
18 changes: 9 additions & 9 deletions torch/_dynamo/output_graph.py
@@ -62,7 +62,7 @@
unimplemented,
unimplemented_with_warning,
)
from .guards import GuardBuilder
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
@@ -561,7 +561,11 @@ def restore_graphstate(self, state: OutputGraphState):
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
removed_nodes = 0
for node in reversed(list(self.graph.nodes)):
if node.meta["creation_timestamp"] > self.timestamp:
if (
node.meta["creation_timestamp"] > self.timestamp
# placeholders here may have been lazily added by existing objects
and node.op != "placeholder"
):
# Erasing node alone does not remove the meta information
# So, remove the help tensor explicitly
if "example_value" in node.meta:
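
Per the new comment above, placeholder nodes may be lazily added for objects that already existed before the checkpoint, so the rollback loop now skips them even when their `creation_timestamp` is newer. A toy version of just that filtering condition, with a stand-in node type invented for illustration:

from dataclasses import dataclass

@dataclass
class FakeNode:  # stand-in for torch.fx.Node, illustration only
    op: str
    creation_timestamp: int

checkpoint_timestamp = 3
nodes = [
    FakeNode("placeholder", 5),    # lazily added input: kept on rollback
    FakeNode("call_function", 5),  # traced after the checkpoint: removed
    FakeNode("call_function", 2),  # traced before the checkpoint: kept
]

kept = [
    n
    for n in nodes
    if not (n.creation_timestamp > checkpoint_timestamp and n.op != "placeholder")
]
assert [n.op for n in kept] == ["placeholder", "call_function"]
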
@@ -670,7 +674,6 @@ def register_attr_or_module(
return variables.UnspecializedNNModuleVariable(target, **options)

options = dict(options)
options["guards"] = set(options.get("guards", []))
assert "source" in options
source = options["source"]
assert not isinstance(source, ParamBufferSource)
@@ -692,10 +695,10 @@ def register_attr_or_module(
tracer = self.root_tracer

if not is_constant_source(source):
options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

if get_static_address_type(target) == "guarded":
options["guards"].add(source.make_guard(GuardBuilder.DATA_PTR_MATCH))
install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))

def wrap_name(module_key):
assert self.param_name_to_source is not None
@@ -711,7 +714,7 @@ def wrap_name(module_key):
elif isinstance(target, torch.nn.Module):
assert isinstance(target, torch.nn.Module)

options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
install_guard(source.make_guard(GuardBuilder.NN_MODULE))

def wrap_name(module_key):
return NNModuleVariable(type(target), module_key, **options)
@@ -1005,9 +1008,6 @@ def compile_and_call_fx_graph(self, tx, rv, root):

assert isinstance(rv, list)
assert isinstance(root, FakeRootModule)
for output in rv:
self.guards.update(output.guards)

self.create_node(
"output",
"output",