Skip to content

Commit 3eb6734

Browse files
author
Tristan Konolige
authored
[LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (apache#12364)
* [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite Vtcm allocations were being moved inside loops even if they were originally allocated outside of the loops. Normally PlanAndUpdateBufferAllocationLocation moves allocations as close to use as possible and then StorageRewrite moves them back out as far as possible. However, with Vtcm allocation, PlanAndUpdateBufferAllocationLocation would move the Vtcm allocation close to the compute, then LowerVtcm would convert the allocation to a LetStmt. StorageRewrite would not move this LetStmt as it only handles allocations. Moving LowerVtcmAlloc to after StorageRewrite ensures that the vtcm allocations are in their final spot before converting them to a LetStmt. * fix issues with tagging and storage rewrite
1 parent 779a7ad commit 3eb6734

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/driver/driver_api.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
204204
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
205205
pass_list.push_back(tir::transform::LowerOpaqueBlock());
206206
pass_list.push_back(tir::transform::FlattenBuffer());
207-
pass_list.push_back(tir::transform::LowerVtcmAlloc());
208207
pass_list.push_back(tir::transform::BF16Legalize());
209208
pass_list.push_back(tir::transform::NarrowDataType(32));
210209
pass_list.push_back(tir::transform::Simplify());
@@ -223,6 +222,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
223222
if (!disable_storage_rewrite) {
224223
pass_list.push_back(tir::transform::StorageRewrite());
225224
}
225+
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
226+
pass_list.push_back(tir::transform::LowerVtcmAlloc());
226227
pass_list.push_back(tir::transform::UnrollLoop());
227228

228229
// Add user-defined phase-2 passes

src/tir/transforms/storage_rewrite.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator {
583583
};
584584

585585
// Checks whether the storage_scope is especially tagged for a specific memory.
586+
// Special memory is all combined into a single allocation.
586587
bool IsSpecialTaggedMemory(const StorageScope& scope) {
587-
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace";
588+
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" &&
589+
scope.tag != ".vtcm";
588590
}
589591

590592
// Alllocate entry of node.
@@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator {
655657

656658
if (e->allocs.size() == 1) {
657659
// simply use the original allocation.
658-
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
659-
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
660660
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
661661
e->allocs[0]->condition, Evaluate(0));
662662
if (IsSpecialTaggedMemory(e->scope)) {

0 commit comments

Comments
 (0)