Skip to content

Commit

Permalink
[LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (apache#…
Browse files Browse the repository at this point in the history
…12364)

* [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite

Vtcm allocations were being moved inside loops even if they were
originally allocated outside of the loops. Normally
PlanAndUpdateBufferAllocationLocation moves allocations as close to use
as possible and then StorageRewrite moves them back out as far as
possible. However, with Vtcm allocation,
PlanAndUpdateBufferAllocationLocation would move the Vtcm allocation
close to the compute, then LowerVtcm would convert the allocation to a
LetStmt. StorageRewrite would not move this LetStmt as it only handles
allocations. Moving LowerVtcmAlloc to after StorageRewrite ensures that
the vtcm allocations are in their final spot before converting them to a
LetStmt.

* fix issues with tagging and storage rewrite
  • Loading branch information
Tristan Konolige authored Aug 12, 2022
1 parent 779a7ad commit 3eb6734
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand All @@ -223,6 +222,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (!disable_storage_rewrite) {
pass_list.push_back(tir::transform::StorageRewrite());
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/storage_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator {
};

// Checks whether the storage_scope is especially tagged for a specific memory.
// Special memory is all combined into a single allocation.
bool IsSpecialTaggedMemory(const StorageScope& scope) {
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace";
return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" &&
scope.tag != ".vtcm";
}

// Alllocate entry of node.
Expand Down Expand Up @@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator {

if (e->allocs.size() == 1) {
// simply use the original allocation.
PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), e->allocs[0]->extents);
e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
e->allocs[0]->condition, Evaluate(0));
if (IsSpecialTaggedMemory(e->scope)) {
Expand Down

0 comments on commit 3eb6734

Please sign in to comment.