Skip to content

Commit

Permalink
do not raise when flatten_fn_with_keys not found when suggesting fixe…
Browse files Browse the repository at this point in the history
…s (#135518)

Test Plan: added test

Differential Revision: D62395371

Pull Request resolved: pytorch/pytorch#135518
Approved by: https://github.com/zhxchen17
  • Loading branch information
avikchaudhuri authored and pytorchmergebot committed Sep 10, 2024
1 parent 1d9feff commit 6546c61
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
30 changes: 30 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,36 @@ def forward(self, xs, y):
strict=strict,
)

class Box:
def __init__(self, content):
self.content = content

from torch.utils._pytree import register_pytree_node

register_pytree_node(
Box,
lambda box: ([box.content], None), # flatten_fn
lambda contents, _context: Box(*contents), # unflatten_fn
flatten_with_keys_fn=None, # unflatten_fn
serialized_type_name="test_no_suggested_fixes_for_data_dependent_errors.Box",
)

class cf_stacklist_udd(torch.nn.Module):
def forward(self, xs, y):
box = Box(y.item())
# box.content is not a local, so we can't suggest a fix
return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze()

with self.assertRaisesRegex(
error_type,
"Could not guard on data-dependent expression u0 < 0",
):
export(
cf_stacklist_udd(),
([torch.ones(5) * i for i in range(10)], torch.tensor(2)),
strict=strict,
)

def test_tolist(self):
class M(torch.nn.Module):
def forward(self, x):
Expand Down
11 changes: 10 additions & 1 deletion torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5588,8 +5588,17 @@ def _suggest_fixes_for_data_dependent_error_non_strict(e):
# map symbol names reachable via frame locals to their source-level names
src_map = defaultdict(list)
for var, val in frame.f_locals.items():
try:
tree_leaves_with_path = pytree.tree_leaves_with_path(val)
except ValueError:
log.warning(
"pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}",
type(val),
var,
)
continue
# figure out how to access any symbol inside `val` through `var`
for path, leaf in pytree.tree_leaves_with_path(val):
for path, leaf in tree_leaves_with_path:
name = var + pytree.keystr(path)
if isinstance(leaf, torch.SymInt):
src_map[str(leaf.node.expr)].append(name)
Expand Down

0 comments on commit 6546c61

Please sign in to comment.