Skip to content

Commit

Permalink
[LoopPartition] Fix a bug of LoopPartition in single point scenarioes (
Browse files Browse the repository at this point in the history
…apache#16104)

Fix a bug of LoopPartition in single point scenarioes.

Co-authored-by: lightzhan-intellif <[email protected]>
  • Loading branch information
lightzhan-intellif and lightzhan-intellif authored Dec 15, 2023
1 parent b3eec91 commit 870246a
Show file tree
Hide file tree
Showing 2 changed files with 263 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,12 +588,47 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
}
}

bool all_singlepoints_outside = true;

// Check all partitions to see if they are single points and outside `for_interval`
for (const auto& partition : finder.partitions) {
const auto& intset = partition.second;
// Only proceed if the interval set is a single point
if (intset.IsSinglePoint()) {
auto single_point = intset.PointValue();
// Check if the single point is outside the `for_interval`
bool is_inside = analyzer_.CanProve(single_point >= for_interval.min()) &&
analyzer_.CanProve(single_point <= for_interval.max());
if (is_inside) {
// If any single point is inside, this is an error condition
LOG(ERROR) << "unexpected case happened.";
all_singlepoints_outside = false;
break;
}
} else {
// If there is any intset that is not a single point, follow default logic
// For now, we set all_singlepoints_outside to false to indicate default logic was used
all_singlepoints_outside = false;
break;
}
}

if (all_singlepoints_outside) {
// If all single points are outside `for_interval`, return a nothing interval and false
return {IntSet::Nothing(), ExpressionSet(), false};
}

// we couldn't find an interval in which the conditions are
// provably true or false. Therefore, we can't partition the loop
// based on those conds
return {{}, {}, std::nullopt};
}();

if (middle_interval.IsNothing() && opt_cond_value == false) {
// Return loop directly as it can be simplified.
return stmt;
}

if (!opt_cond_value.has_value()) {
if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
analyzer_.CanProve(max - min > 0)) {
Expand Down
228 changes: 228 additions & 0 deletions tests/python/tir-transform/test_tir_transform_loop_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import tvm
import tvm.testing
from tvm import te
Expand Down Expand Up @@ -834,5 +835,232 @@ def after(
assert tvm.ir.structural_equal(mod["main"], after.with_attr("global_symbol", "main"))


@T.prim_func
def concat_func_single_point(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 > 63:
T_concat[i0, i1] = placeholder[i0, i1 - 64]
elif i1 == 63:
T_concat[i0, i1] = placeholder_1[i0, i1 - 63]
else:
T_concat[i0, i1] = placeholder_2[i0, i1]


@T.prim_func
def expected_partitioned_concat_single_point(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
for i1 in range(63):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]


@T.prim_func
def concat_func_start_point_equality(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 0:
# Special case for i1 == 0
T_concat[i0, i1] = placeholder_1[i0, 0]
elif i1 < 64:
# Normal case for i1 in [1, 63]
T_concat[i0, i1] = placeholder_2[i0, i1]
else:
# Case for i1 in [64, 127]
T_concat[i0, i1] = placeholder[i0, i1 - 64]


@T.prim_func
def concat_func_start_point_equality_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128] = placeholder_1_1[i0]
for i1 in range(63):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1 + 1]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]


@T.prim_func
def concat_func_end_point_equality(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 127:
# Explicit equality check for the end point i1 == 127
T_concat[i0, i1] = placeholder_1[i0, 0]
elif i1 >= 64:
# Case for i1 in [64, 126]
T_concat[i0, i1] = placeholder[i0, i1 - 64]
else:
# Case for i1 in [0, 63]
T_concat[i0, i1] = placeholder_2[i0, i1]


@T.prim_func
def concat_func_end_point_equality_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 63), "int8"),
T_concat: T.Buffer((28, 128), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data)
for i1 in range(64):
placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1]
for i1 in range(63):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0]


@T.prim_func
def concat_func_edge_equalities(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 1), "int8"),
T_concat: T.Buffer((28, 66), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(
66, annotations={"pragma_loop_partition_hint": 1}
): # Loop from 0 to 65 inclusive
if i1 == 0:
# Handle equality at the start of the range: i1 == 0
T_concat[i0, i1] = placeholder_2[i0, 0]
elif i1 == 65:
# Handle equality at the end of the range: i1 == 65
T_concat[i0, i1] = placeholder_1[i0, 0]
else:
# Copying from placeholder (from 0 to 63)
T_concat[i0, i1] = placeholder[i0, i1 - 1]


@T.prim_func
def concat_func_edge_equalities_expected(
placeholder: T.Buffer((28, 64), "int8"),
placeholder_1: T.Buffer((28, 1), "int8"),
placeholder_2: T.Buffer((28, 1), "int8"),
T_concat: T.Buffer((28, 66), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data)
placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data)
T_concat_1[i0 * 66] = placeholder_2_1[i0]
for i1 in range(64):
placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data)
T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1]
placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data)
T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0]


@T.prim_func
def concat_five_buffers_with_equalities(
buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0
buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63
buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64
buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128
buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129
T_concat: T.Buffer((28, 129), "int8"),
) -> None:
for i0 in range(28):
for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}):
if i1 == 0:
T_concat[i0, i1] = buffer_a[i0, 0]
elif i1 == 64:
T_concat[i0, i1] = buffer_c[i0, 0]
elif i1 == 129:
T_concat[i0, i1] = buffer_e[i0, 0]
elif i1 < 64:
T_concat[i0, i1] = buffer_b[i0, i1 - 1]
else: # i1 > 64 and i1 < 128
T_concat[i0, i1] = buffer_d[i0, i1 - 65]


@T.prim_func
def concat_five_buffers_with_equalities_expected(
buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0
buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63
buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64
buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128
buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129
T_concat: T.Buffer((28, 129), "int8"),
):
for i0 in range(28):
T_concat_1 = T.Buffer((3612,), "int8", data=T_concat.data)
buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data)
T_concat_1[i0 * 129] = buffer_a_1[i0]
for i1 in range(63):
buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data)
T_concat_1[i0 * 129 + i1 + 1] = buffer_b_1[i0 * 63 + i1]
buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data)
T_concat_1[i0 * 129 + 64] = buffer_c_1[i0]
for i1 in range(64):
buffer_d_1 = T.Buffer((1764,), "int8", data=buffer_d.data)
T_concat_1[i0 * 129 + i1 + 65] = buffer_d_1[i0 * 63 + i1]
buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data)
T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]


@pytest.mark.parametrize(
"origin,expected",
[
(concat_func_single_point, expected_partitioned_concat_single_point),
(concat_func_start_point_equality, concat_func_start_point_equality_expected),
(concat_func_end_point_equality, concat_func_end_point_equality_expected),
(concat_func_edge_equalities, concat_func_edge_equalities_expected),
(concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
],
)
def test_single_point_partition(origin, expected):
origin = origin.with_attr({"global_symbol": "main"})
expected = expected.with_attr({"global_symbol": "main"})
mod = partition_from_scheduled_tir(
origin,
{
"tir.LoopPartition": {
"partition_const_loop": True,
"unroll_loop_with_partition_hint_no_interval": True,
}
},
)
assert tvm.ir.structural_equal(mod["main"], expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 870246a

Please sign in to comment.