Skip to content

Commit 3152fea

Browse files
lezcanopytorchmergebot
authored andcommitted
Assert that we can compute the bounds for guards using rational numbers (pytorch#105139)
This makes sure that the bounds are always correct, as we're not losing precision Pull Request resolved: pytorch#105139 Approved by: https://github.com/ezyang
1 parent 34c91a7 commit 3152fea

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

torch/fx/experimental/symbolic_shapes.py

+13
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,16 @@ def eval_guards(gm, *args, ignore_static=True):
471471
def bind_symbols(gm, *args):
472472
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
473473

474+
def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
475+
"""
476+
We assert that the bounds are either Boolean, or not finite, or can be computed
477+
in exact prevision via rational arithmetic.
478+
The only exception to this is the rare case when the user calls `sqrt(s0)`
479+
sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still)
480+
"""
481+
assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr)
482+
assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr)
483+
474484
class DimDynamic(Enum):
475485
"""
476486
Controls how to perform symbol allocation for a dimension. It is always
@@ -2956,6 +2966,8 @@ def replace(expr, repl):
29562966

29572967
# Check if the range can solve it statically
29582968
out = bound_sympy(new_expr, new_range_env)
2969+
_assert_bound_is_rational(new_expr, out)
2970+
29592971
if out.is_singleton():
29602972
return out.lower
29612973

@@ -3415,6 +3427,7 @@ def simplify_until(expr: sympy.Expr, max_iterations: int = 10) -> sympy.Expr:
34153427
lower, upper = vr.lower, vr.upper
34163428

34173429
rhs_vr = bound_sympy(expr.rhs, self.var_to_range)
3430+
_assert_bound_is_rational(expr.rhs, rhs_vr)
34183431
lower_guard, upper_guard = self.var_to_guards.get(symbol, (None, None))
34193432

34203433
# Let's suppose that we have a preexisting range for x [0, 100].

0 commit comments

Comments
 (0)