Skip to content

Commit

Permalink
[TIR] Preserve loop annotation after loop partitioning (apache#13292)
Browse files Browse the repository at this point in the history
Preserve loop annotations when a loop gets partitioned. Also bind the loop region info to the analyzer, because in some cases a partition condition could not be solved due to an unknown (but trivial) loop region.
  • Loading branch information
wrongtest-intellif authored Nov 5, 2022
1 parent 1e79364 commit 732e34f
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 20 deletions.
6 changes: 4 additions & 2 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ class LoopPartitioner : public StmtMutator {
}

Stmt VisitStmt_(const ForNode* op) final {
analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true);
auto fs = GetRef<Stmt>(op);
if (selector.candidates.count(fs)) {
Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
Expand Down Expand Up @@ -697,12 +698,13 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
const ForNode* for_node = static_cast<const ForNode*>(node);
ICHECK(for_node);
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
!no_unroll_loop_with_extent_one_) {
!no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
ICHECK(for_node->kind != ForKind::kThreadBinding);
return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body);
return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body,
for_node->thread_binding, for_node->annotations);
}
}

Expand Down
128 changes: 110 additions & 18 deletions tests/python/unittest/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,17 @@ def test_explicit_partition_hint():
assert tvm.ir.structural_equal(mod["main"], partitioned_concat)


def partition_from_scheduled_tir(prim_func, pass_cfg):
    """Lower a scheduled PrimFunc and run loop partitioning on it.

    Parameters
    ----------
    prim_func
        The function wrapped into an IRModule via ``IRModule.from_expr``.
    pass_cfg
        Configuration passed as ``config`` to ``tvm.transform.PassContext``;
        every pass below runs inside that context.

    Returns
    -------
    The module after LowerOpaqueBlock, FlattenBuffer, LoopPartition,
    Simplify and RemoveNoOp have been applied, in that order.
    """
    # Pass constructors, listed in application order. Each one is
    # instantiated inside the PassContext, right before it is applied.
    pipeline = (
        tvm.tir.transform.LowerOpaqueBlock,
        tvm.tir.transform.FlattenBuffer,
        tvm.tir.transform.LoopPartition,
        tvm.tir.transform.Simplify,
        tvm.tir.transform.RemoveNoOp,
    )
    with tvm.transform.PassContext(config=pass_cfg):
        lowered = IRModule.from_expr(prim_func)
        for make_pass in pipeline:
            lowered = make_pass()(lowered)
    return lowered


@T.prim_func
def partitioned_concat_3(
placeholder: T.Buffer[(50176,), "int8"],
Expand Down Expand Up @@ -609,13 +620,9 @@ def concat_func_3(


def test_condition_mutually_exclusive():
mod = IRModule.from_expr(concat_func_3)
with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = partition_from_scheduled_tir(
concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}}
)
assert tvm.ir.structural_equal(mod["main"], partitioned_concat_3)


Expand Down Expand Up @@ -650,23 +657,108 @@ def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) ->
if ax2 < 5 and ax3 < 3:
B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 + 219]

mod = tvm.ir.module.IRModule.from_expr(main)
with tvm.transform.PassContext(
config={
mod = partition_from_scheduled_tir(
main,
{
"tir.LoopPartition": {
"partition_const_loop": True,
"unroll_loop_with_partition_hint_no_interval": True,
}
}
):
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
},
)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
assert tvm.ir.structural_equal(mod["main"], partitioned_main)


def test_loop_partition_keep_loop_annotations():
    """Check that loop annotations are carried over to the partitioned loops.

    The input loop carries both the partition hint and a user annotation
    ("key": "value"). The expected output has three loops, one per branch of
    the original condition chain, each still annotated with "key": "value".
    """

    # Input: a single loop of extent 160 whose body branches on the loop
    # variable at 10 and 150.
    @T.prim_func
    def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None:
        for i in T.serial(
            160,
            annotations={"pragma_loop_partition_hint": True, "key": "value"},
        ):
            if i < 10:
                B[i] = A[i] + 1
            elif 10 <= i and i < 150:
                B[i] = A[i] + 2
            else:
                B[i] = A[i] + 3

    # Expected: loops of extent 10, 140 and 10, each starting from 0 with the
    # original offset folded into the buffer indices. The expected loops carry
    # only the "key" annotation; the partition hint is not present on them.
    @T.prim_func
    def after(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None:
        T.preflattened_buffer(A, [160], dtype="int32", data=A.data)
        T.preflattened_buffer(B, [160], dtype="int32", data=B.data)
        for i in T.serial(10, annotations={"key": "value"}):
            B[i] = A[i] + 1
        for i in T.serial(140, annotations={"key": "value"}):
            B[i + 10] = A[i + 10] + 2
        for i in T.serial(10, annotations={"key": "value"}):
            B[i + 150] = A[i + 150] + 3

    mod = partition_from_scheduled_tir(
        before,
        {
            "tir.LoopPartition": {
                "partition_const_loop": True,
            }
        },
    )
    assert tvm.ir.structural_equal(mod["main"], after)


def test_loop_partition_with_unit_loop_in_condition():
    """Check partitioning when the conditions reference an enclosing unit loop.

    The branch conditions on ``i1`` are written in terms of ``k * 128 + i1``,
    where ``k`` is the variable of an outer loop of extent 1. The expected
    output splits the ``i1`` loop into pieces of extent 64, 32 and 32 and
    keeps the outer loop together with its annotation.
    """

    # Input: the outer extent-1 loop is annotated with "preserve_unit_loop";
    # only the i1 loop carries the partition hint.
    @T.prim_func
    def before(
        placeholder: T.Buffer[(50176,), "int8"],
        placeholder_1: T.Buffer[(25088,), "int8"],
        placeholder_2: T.Buffer[(25088,), "int8"],
        T_concat: T.Buffer[(100352,), "int8"],
    ) -> None:
        for k in range(1, annotations={"preserve_unit_loop": True}):
            for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
                for i2, i3 in T.grid(28, 28):
                    if 96 <= k * 128 + i1:
                        T_concat[k * i1 * 784 + i2 * 28 + i3] = placeholder_2[
                            i1 * 784 + i2 * 28 + i3 - 75264
                        ]
                    if 64 <= k * 128 + i1 and k * 128 + i1 < 96:
                        T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[
                            i1 * 784 + i2 * 28 + i3 - 50176
                        ]
                    if k * 128 + i1 < 64:
                        T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]

    # Expected: no conditions remain; each partition is a plain loop nest
    # under the preserved unit loop.
    @T.prim_func
    def after(
        placeholder: T.Buffer[50176, "int8"],
        placeholder_1: T.Buffer[25088, "int8"],
        placeholder_2: T.Buffer[25088, "int8"],
        T_concat: T.Buffer[100352, "int8"],
    ) -> None:
        T.preflattened_buffer(placeholder, [50176], dtype="int8", data=placeholder.data)
        T.preflattened_buffer(placeholder_1, [25088], dtype="int8", data=placeholder_1.data)
        T.preflattened_buffer(placeholder_2, [25088], dtype="int8", data=placeholder_2.data)
        T.preflattened_buffer(T_concat, [100352], dtype="int8", data=T_concat.data)
        for _ in T.serial(1, annotations={"preserve_unit_loop": True}):
            for i1, i2, i3 in T.grid(64, 28, 28):
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
            for i1, i2, i3 in T.grid(32, 28, 28):
                T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3]
            for i1, i2, i3 in T.grid(32, 28, 28):
                T_concat[i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3]

    mod = partition_from_scheduled_tir(
        before,
        {
            "tir.LoopPartition": {
                "partition_const_loop": True,
            }
        },
    )
    assert tvm.ir.structural_equal(mod["main"], after)


# Script entry point: hands control to tvm.testing.main() when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit 732e34f

Please sign in to comment.