Skip to content

Commit

Permalink
Fix mod semantics for Z3Ops. (pytorch#104827)
Browse files Browse the repository at this point in the history
Python `mod` semantics is not the same as the mathematical modulus operation. According to
the Python reference: `a = floor(a / b) * b + a % r`.

In other words: `a % b = a - floor(a / b) * b`.

This PR fixes the old implementation which used SMT-LIB2 semantics for `mod`. In short, it
only worked with integers and had the following guarantee: `0 <= a % b < b`.

In summary, the changes are:
- `a % b = a - floordiv(a, b) * b`
- `a` and `b` can be both integer or real
- The result will be real if any of the arguments is real. Otherwise, it will be integer

Pull Request resolved: pytorch#104827
Approved by: https://github.com/lezcano
  • Loading branch information
ysiraichi authored and pytorchmergebot committed Jul 10, 2023
1 parent 951b9a6 commit d5dbe77
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
14 changes: 11 additions & 3 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5942,7 +5942,6 @@ def test_sympy_to_z3_translation(self):
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mod,
operator.mul,
operator.pow,
)
Expand All @@ -5968,9 +5967,18 @@ def test_sympy_to_z3_translation(self):
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(s2 % (s0 / s1), z2 % z3.ToInt(z3.ToReal(z0) * (z1**-1))),
(s2 % (s0**3), z2 % z3.ToInt(z0**3)),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(s0 % s1, z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
s2 % (s0 / s1),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
s2 % (s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]

toZ3 = SympyToZ3(validator)
Expand Down
5 changes: 3 additions & 2 deletions torch/fx/experimental/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
return z3.If(a < b, a, b) # type: ignore[return-value]

# Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q
# It should work with both integer and reals.
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef:
self.validator.add_assertion(q != 0) # type: ignore[arg-type]
return Z3Ops.to_int(p) % Z3Ops.to_int(q)
return p - self.floordiv(p, q) * q

def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef:
# Z3 can't handle complex numbers very well.
Expand Down

0 comments on commit d5dbe77

Please sign in to comment.