Skip to content

Commit

Permalink
[TIR] Fix plan buffer allocation location for loop carried dependencies (apache#12757)
Browse files Browse the repository at this point in the history

* Fix plan buffer allocation location for loop carried dependencies

* fix testcase region annotation issue

* fix typo in ut
  • Loading branch information
wrongtest-intellif authored Sep 26, 2022
1 parent 71f25b3 commit a61c1ad
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 15 deletions.
106 changes: 96 additions & 10 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,30 +99,113 @@ class LCADetector : public StmtExprVisitor {
}

ancestor_scopes_.push_back(current_scope);
loop_scope_map_.insert({op->loop_var.get(), current_scope});
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
loop_scope_map_.erase(op->loop_var.get());
}

void VisitStmt_(const BlockNode* op) final {
void VisitStmt_(const BlockRealizeNode* op) final {
const BlockNode* block = op->block.get();
int n = ancestor_scopes_.size();
for (const Buffer& buf : op->alloc_buffers) {
for (const Buffer& buf : block->alloc_buffers) {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, block, n);

ancestor_scopes_.push_back(current_scope);

// For each accessed buffer of the block, update the buffer's lca to
// the lowest inclusive stmt position, which should dominate all loops
// related to the accessed opaque block iter vars in buffer indices.
UpdateDominateScopeOfOpaqueIter(op);

// Update match_buffers
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
UpdateBufferLCA(match_buffer->source->buffer.get());
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
UpdateBufferLCA(match_buffer->source->buffer.get(), ancestor_scopes_.back());
match_buffers_.insert(match_buffer->buffer.get());
}

StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}

// For every buffer read/written by the realized block, raise the buffer's
// LCA to a scope that dominates all loops bound to the block's opaque
// (non-data-parallel, non-reduction) iter vars appearing in the accessed
// regions. This keeps the allocation above loop-carried dependencies
// introduced through opaque iteration (e.g. rolling buffers).
void UpdateDominateScopeOfOpaqueIter(const BlockRealizeNode* block_realize) {
  // map opaque iter var to the scope which dominate all loop carried dependencies.
  std::unordered_map<const VarNode*, const ScopeInfo*> itervar_to_dom_scope;

  // function to collect `itervar_to_dom_scope`, the result scope for each block
  // iter var should be above all loop scopes the opaque iter var binding relates to.
  auto do_collect_itervar_scope = [this, &itervar_to_dom_scope](const IterVar& itervar,
                                                                const PrimExpr& binding) {
    PostOrderVisit(binding, [this, &itervar_to_dom_scope, &itervar](const ObjectRef& obj) {
      if (const VarNode* loop_var = obj.as<VarNode>()) {
        auto it = loop_scope_map_.find(loop_var);
        if (it == loop_scope_map_.end()) {
          return;
        }
        // Use the loop's parent so the result scope sits strictly above the loop.
        const ScopeInfo* scope = it->second->parent_scope_info;
        // find the highest loop scope the iter var binding has related to.
        auto dom_scope_it = itervar_to_dom_scope.find(itervar->var.get());
        if (dom_scope_it == itervar_to_dom_scope.end()) {
          itervar_to_dom_scope.insert(dom_scope_it, {itervar->var.get(), scope});
        } else if (scope->depth < dom_scope_it->second->depth) {
          // Smaller depth means higher in the nest; keep the highest scope seen.
          dom_scope_it->second = scope;
        }
      }
    });
  };

  // function to update lca scope of the buffer with loop carried dependent buffer accesses.
  // the result scope should be above all loop scopes the accessed opaque block iter vars
  // relate to, which is record in `itervar_to_dom_scope`.
  auto do_update = [this, &itervar_to_dom_scope](const BufferRegion& region) {
    const Buffer& buffer = region->buffer;
    const ScopeInfo* scope = ancestor_scopes_.back();

    auto handle_itervar = [&itervar_to_dom_scope, &scope](const ObjectRef& obj) {
      if (const VarNode* iter_var = obj.as<VarNode>()) {
        auto dom_scope_it = itervar_to_dom_scope.find(iter_var);
        if (dom_scope_it == itervar_to_dom_scope.end()) {
          return;
        }
        // find the highest loop scope the accessed buffer index has
        // loop carried dependencies to (via opaque iter var binding).
        if (dom_scope_it->second->depth < scope->depth) {
          scope = dom_scope_it->second;
        }
      }
    };

    // visit region min and max to find the lowest legal lca scope
    for (const Range& range : region->region) {
      PostOrderVisit(range->min, handle_itervar);
      PostOrderVisit(range->min + range->extent - 1, handle_itervar);
    }
    UpdateBufferLCA(buffer.get(), scope);
  };

  // do collect and update
  const Block& block = block_realize->block;
  for (size_t i = 0; i < block_realize->iter_values.size(); ++i) {
    const IterVar& iter_var = block->iter_vars[i];
    // Only opaque iter vars (neither data-parallel nor reduction) can carry
    // cross-iteration dependencies that constrain the allocation site.
    if (iter_var->iter_type != IterVarType::kDataPar &&
        iter_var->iter_type != IterVarType::kCommReduce) {
      do_collect_itervar_scope(iter_var, block_realize->iter_values[i]);
    }
  }
  if (!itervar_to_dom_scope.empty()) {
    for (const auto& read : block->reads) {
      do_update(read);
    }
    for (const auto& write : block->writes) {
      do_update(write);
    }
  }
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const auto* iter = op->node.as<IterVarNode>();
Expand All @@ -136,17 +219,18 @@ class LCADetector : public StmtExprVisitor {
}

// A buffer load pins the buffer's LCA at (or above) the current scope.
void VisitExpr_(const BufferLoadNode* op) final {
  UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
  StmtExprVisitor::VisitExpr_(op);
}

// A buffer store pins the buffer's LCA at (or above) the current scope.
void VisitStmt_(const BufferStoreNode* op) final {
  UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
  StmtExprVisitor::VisitStmt_(op);
}

// A buffer realization both registers the buffer's data var (so later opaque
// var accesses resolve to it) and anchors the buffer's LCA at this scope.
void VisitStmt_(const BufferRealizeNode* op) final {
  buffer_var_map_.emplace(op->buffer->data.get(), op->buffer.get());
  UpdateBufferLCA(op->buffer.get(), ancestor_scopes_.back());
  StmtExprVisitor::VisitStmt_(op);
}

Expand All @@ -165,16 +249,16 @@ class LCADetector : public StmtExprVisitor {
// An opaque access through a buffer's data var counts as a use at the
// current scope, if the var is known to belong to a buffer.
void VisitBufferVar(const VarNode* op) {
  auto it = buffer_var_map_.find(op);
  if (it != buffer_var_map_.end()) {
    UpdateBufferLCA(it->second, ancestor_scopes_.back());
  }
}

void UpdateBufferLCA(const BufferNode* buffer) {
void UpdateBufferLCA(const BufferNode* buffer, const ScopeInfo* scope) {
buffer_var_map_.emplace(buffer->data.get(), buffer);
if (match_buffers_.find(buffer) == match_buffers_.end()) {
// Ingore buffer created by block match_buffer
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
lca = LowestCommonAncestor(lca, scope);
}
}

Expand Down Expand Up @@ -229,6 +313,8 @@ class LCADetector : public StmtExprVisitor {
std::unordered_set<const BufferNode*> match_buffers_ = {};
/*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
std::vector<const ScopeInfo*> blockidx_scopes_ = {};
/*! \brief The map from loop var to the corresponding scope. */
std::unordered_map<const VarNode*, const ScopeInfo*> loop_scope_map_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
from tvm.script import tir as T

Expand Down Expand Up @@ -242,9 +243,107 @@ def test_lower_te():
) # PlanAndUpdateBufferAllocationLocation should do nothing on TE


def test_loop_carried_dependency():
    """The buffer allocation should be above opaque iter var's loop scopes
    such that buffer accesses with loop carried dependencies are covered."""

    @T.prim_func
    def before(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]):
        C = T.alloc_buffer([8, 8, 8], dtype="int32")
        for i in T.serial(8):
            for j in T.serial(8):
                for k in T.serial(8):
                    with T.block("b0"):
                        vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                        C[vi, vj, vk] = A[vi, vj, vk] + 1
                for k in T.serial(8):
                    with T.block("b1"):
                        vi, vk = T.axis.remap("SS", [i, k])
                        # vj is opaque: b1 reads C at both vj and vj - 1,
                        # carrying a dependency across iterations of loop j.
                        vj = T.axis.opaque(8, j)
                        B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
                            0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
                        )

    @T.prim_func
    def after(A: T.Buffer[(8, 8, 8), "int32"], B: T.Buffer[(8, 8, 8), "int32"]) -> None:
        # Expected: C is allocated above loop j (under loop i), not inside j,
        # because of the loop-carried access through the opaque vj.
        for i in T.serial(8):
            with T.block():
                T.reads(A[i, 0:8, 0:8])
                T.writes(B[i, 0:8, 0:8])
                C = T.alloc_buffer([8, 8, 8], dtype="int32")
                for j in T.serial(8):
                    for k in T.serial(8):
                        with T.block("b0"):
                            vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                            C[vi, vj, vk] = A[vi, vj, vk] + 1
                    for k in T.serial(8):
                        with T.block("b1"):
                            vi, vk = T.axis.remap("SS", [i, k])
                            vj = T.axis.opaque(8, j)
                            B[vi, vj, vk] = C[vi, vj, vk] + T.if_then_else(
                                0 < vj, C[vi, vj - 1, vk], 0, dtype="int32"
                            )

    _check(before, after)


def test_1D_cascade_op_rolling_buffer():
    """The intermediate buffer must be allocated above rolling buffer's rolling loop,
    which is marked as opaque in consumer block's iter mappings."""

    @T.prim_func
    def before(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
        # B is the rolling buffer: P1 writes it modulo 6, P2 reads it with
        # the opaque iter vi, so B must live above loop i.
        B = T.alloc_buffer((4, 6), "int32")
        for c in T.serial(4):
            for i in T.serial(0, 2):
                for j in T.serial(0, 6):
                    for k in T.serial(3):
                        with T.block("P1"):
                            T.where(i < 1 or j >= 2)
                            cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
                            if vk == 0:
                                B[cc, T.floormod(vi * 4 + vj, 6)] = 0
                            B[cc, T.floormod(vi * 4 + vj, 6)] = (
                                B[cc, T.floormod(vi * 4 + vj, 6)] + A[cc, vi * 4 + vj + vk]
                            )
                for j in T.serial(0, 4):
                    for k in T.serial(3):
                        with T.block("P2"):
                            # Opaque binding of i marks the rolling dependency.
                            vi = T.axis.opaque(2, i)
                            cc, vj, vk = T.axis.remap("SSR", [c, j, k])
                            if vk == 0:
                                C[cc, vi * 4 + vj] = 0
                            C[cc, vi * 4 + vj] = (
                                C[cc, vi * 4 + vj] + B[cc, T.floormod(vi * 4 + vj + vk, 6)]
                            )

    @T.prim_func
    def after(A: T.Buffer[(4, 16), "int32"], C: T.Buffer[(4, 8), "int32"]):
        # Expected: B allocated under loop c but above loop i (the rolling loop).
        for c in T.serial(4):
            with T.block():
                T.reads(A[c, 0:12], C[c, 0:8])
                T.writes(C[c, 0:8])
                B = T.alloc_buffer([4, 6], dtype="int32")
                for i in T.serial(2):
                    for j, k in T.grid(6, 3):
                        with T.block("P1"):
                            T.where(i < 1 or j >= 2)
                            cc, vi, vj, vk = T.axis.remap("SSSR", [c, i, j, k])
                            if vk == 0:
                                B[cc, (vi * 4 + vj) % 6] = 0
                            B[cc, (vi * 4 + vj) % 6] = (
                                B[cc, (vi * 4 + vj) % 6] + A[cc, vi * 4 + vj + vk]
                            )
                    for j, k in T.grid(4, 3):
                        with T.block("P2"):
                            vi = T.axis.opaque(2, i)
                            cc, vj, vk = T.axis.remap("SSR", [c, j, k])
                            if vk == 0:
                                C[cc, vi * 4 + vj] = 0
                            C[cc, vi * 4 + vj] = C[cc, vi * 4 + vj] + B[cc, (vi * 4 + vj + vk) % 6]

    _check(before, after)


if __name__ == "__main__":
    # Discover and run every test_* function in this file; the stale explicit
    # call list it replaces omitted the two newly added tests.
    tvm.testing.main()

0 comments on commit a61c1ad

Please sign in to comment.