Skip to content

Commit

Permalink
[ARITH] support floordiv in deduce bound (apache#13880)
Browse files Browse the repository at this point in the history
* support floordiv in deduce bound

* add rule for (x // -positive)

* leave todo for x // a == b
  • Loading branch information
wrongtest-intellif authored Feb 1, 2023
1 parent 206f085 commit c3fe08f
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 8 deletions.
61 changes: 54 additions & 7 deletions src/arith/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {

void VisitExprDefault_(const Object* op) final { success_ = false; }

SignType GetSignType(const PrimExpr& e) {
if (e.dtype().is_uint()) {
return kPositive;
}
return expr_map_[e].GetSignType();
}

void VisitExpr_(const VarNode* op) final {}

void VisitExpr_(const AddNode* op) final {
Expand All @@ -119,13 +126,7 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
PrimExpr operand = left ? op->b : op->a;
PrimExpr target_var = left ? op->a : op->b;

SignType sign_operand;
if (operand.dtype().is_uint()) {
sign_operand = kPositive;
} else {
sign_operand = expr_map_[operand].GetSignType();
}

SignType sign_operand = GetSignType(operand);
if (sign_operand == SignType::kNegative) {
comp_op = ReverseOp(comp_op);
} else if (sign_operand == SignType::kUnknown) {
Expand Down Expand Up @@ -162,6 +163,52 @@ class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
this->VisitExpr(left ? op->a : op->b);
}

void VisitExpr_(const FloorDivNode* op) final {
if (op->b.get() == path_[iter_]) {
// Skip cases where the var is divisor.
success_ = false;
return;
}
PrimExpr divisor = op->b;
if (analyzer_.CanProveEqual(divisor, 0)) {
// Skip zero divisor
success_ = false;
return;
}

SignType sign_operand = GetSignType(divisor);
if (sign_operand == SignType::kNegative) {
comp_op = ReverseOp(comp_op);
divisor = -divisor;
result_ = -result_;
} else if (sign_operand == SignType::kUnknown) {
// unable to get the sign of operand
success_ = false;
return;
}

if (comp_op == kGreater) {
// (x // 6 >= 4 --> x >= 4 * 6)
result_ = result_ * divisor;
} else if (comp_op == kEqual) {
// The bound is not single directional
// (x // 6 == 4 --> 30 > x >= 24)
// TODO(@wrongtest): support bidirectional bound
success_ = false;
return;
} else {
// (x // 6 <= 4 --> x <= 4 * 6 + 5)
result_ = result_ * divisor + divisor - 1;
}
if (sign_operand == SignType::kNegative) {
// (x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4
// --> (x + 6 - 1) // 6 <= -4
result_ = result_ - divisor + 1;
}

this->VisitExpr(op->a);
}

PrimExpr result_;
CompareOp comp_op{kGreater};
bool success_{true};
Expand Down
38 changes: 37 additions & 1 deletion tests/python/unittest/test_arith_deduce_bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def test_non_support(lhs):
res = tvm.arith.deduce_bound(a, lhs < 10, {}, {})
assert res.is_nothing()

test_non_support(tvm.tir.floordiv(a, 16))
test_non_support(tvm.tir.floormod(a, 16))
test_non_support(tvm.tir.Min(a, 16))
test_non_support(tvm.tir.Max(a, 16))
Expand All @@ -233,5 +232,42 @@ def test_non_support(lhs):
test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a]))


def test_deduce_floordiv():
def do_test(gen_expr, dom_map, expect_min, expect_max):
a = te.var("a")
expr = gen_expr(a)
res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map)
if isinstance(expect_min, str):
assert str(res.min_value) == expect_min
else:
tvm.testing.assert_prim_expr_equal(res.min_value, expect_min)
if isinstance(expect_max, str):
assert str(res.max_value) == expect_max
else:
tvm.testing.assert_prim_expr_equal(res.max_value, expect_max)

# test basic cases
do_test(lambda a: a // 8 > 3, {}, 32, "pos_inf")
do_test(lambda a: a // 8 >= 3, {}, 24, "pos_inf")
do_test(lambda a: a // 8 < 3, {}, "neg_inf", 23)
do_test(lambda a: a // 8 <= 3, {}, "neg_inf", 31)
do_test(lambda a: a // 8 == 3, {}, "pos_inf", "neg_inf")
do_test(lambda a: a // 8 > -3, {}, -16, "pos_inf")
do_test(lambda a: a // 8 >= -3, {}, -24, "pos_inf")
do_test(lambda a: a // -8 > 3, {}, "neg_inf", -32)
do_test(lambda a: a // -8 >= 3, {}, "neg_inf", -24)
do_test(lambda a: a // -8 < 3, {}, -23, "pos_inf")
do_test(lambda a: a // -8 <= 3, {}, -31, "pos_inf")
do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf")

# test nested cases
b = te.var("b")
bs = {b: tvm.arith.IntervalSet(2, 6)}
do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359)
do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367)
do_test(lambda a: b * 3 + a // 8 > 63, bs, 464, "pos_inf")
do_test(lambda a: b * 3 + a // 8 >= 63, bs, 456, "pos_inf")


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c3fe08f

Please sign in to comment.