Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Cherry-pick improvements to the vector distribution and update the
lambda interface. That will allow distributing loops with values coming
from above.
  • Loading branch information
ThomasRaoux authored Nov 2, 2022
1 parent 4221b2f commit 29bd703
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 36 deletions.
52 changes: 22 additions & 30 deletions compiler/src/iree/compiler/Codegen/Common/VectorReductionToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,52 +220,44 @@ class VectorReduceToGPUPass
llvm::dbgs() << "\n\n";
});

// 4. Distribute transfer write operations.
// 4. Distribute transfer write operations and propagate vector
// distribution.
{
auto distributionFn = [](vector::TransferWriteOp writeOp) {
// Create a map (d0, d1) -> (d1) to distribute along the inner
// dimension. Once we support n-d distribution we can add more
// complex cases.
int64_t vecRank = writeOp.getVectorType().getRank();
OpBuilder builder(writeOp.getContext());
auto map =
AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
return map;
};
RewritePatternSet patterns(ctx);
vector::populateDistributeTransferWriteOpPatterns(patterns,
distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After Step 4: Distribute transfer write ops ---\n";
funcOp.dump();
llvm::dbgs() << "\n\n";
});

// 5. Propagate vector distribution.
{
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(patterns);
auto getWarpSize = this->getWarpSize ? this->getWarpSize
: [](func::FuncOp) { return 32; };
auto groupReductionFn = [&](Location loc, OpBuilder &builder, Value input,
vector::CombiningKind kind, uint32_t size) {
return groupReduction(loc, builder, input, kind, size,
getWarpSize(funcOp));
};
// Computes the affine map used to distribute `val`. Returns the empty map
// for non-vector values and for 0-d vectors; otherwise returns a map that
// selects the innermost vector dimension.
auto distributionFn = [](Value val) {
  AffineMap map = AffineMap::get(val.getContext());
  auto vecType = val.getType().dyn_cast<VectorType>();
  if (!vecType) return map;
  int64_t vecRank = vecType.getRank();
  // A rank-0 vector has no inner dimension: `vecRank - 1` would be -1 and
  // the resulting map would reference a dimension that does not exist.
  if (vecRank == 0) return map;
  // Create a map (d0, d1) -> (d1) to distribute along the inner
  // dimension. Once we support n-d distribution we can add more
  // complex cases.
  OpBuilder builder(val.getContext());
  map = AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
  return map;
};
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(patterns,
distributionFn);
vector::populateDistributeReduction(patterns, groupReductionFn);
vector::populateDistributeTransferWriteOpPatterns(patterns,
distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After Step 5: Propagate distribution ---\n";
llvm::dbgs() << "\n--- After Step 4: Propagate distribution ---\n";
funcOp.dump();
llvm::dbgs() << "\n\n";
});

// 6. Lower the remaining WarpExecuteOnLane0 ops.
// 5. Lower the remaining WarpExecuteOnLane0 ops.
{
RewritePatternSet patterns(ctx);
vector::WarpExecuteOnLane0LoweringOptions options;
Expand All @@ -279,7 +271,7 @@ class VectorReduceToGPUPass
}

DEBUG_WITH_TYPE(DEBUG_TYPE, {
llvm::dbgs() << "\n--- After Step 6: Lower remaining ops ---\n";
llvm::dbgs() << "\n--- After Step 5: Lower remaining ops ---\n";
funcOp.dump();
llvm::dbgs() << "\n\n";
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,16 @@ static void populateMultiReductionLoweringPatterns(Operation *target,
patterns.add<InsertElementToBroadcast>(target->getContext(), benefit);
}

static AffineMap simpleDistributionFunction(vector::TransferWriteOp writeOp) {
static AffineMap simpleDistributionFunction(Value val) {
AffineMap map = AffineMap::get(val.getContext());
auto vecType = val.getType().dyn_cast<VectorType>();
if (!vecType) return map;
// Create a map (d0, d1) -> (d1) to distribute along the inner
// dimension. Once we support n-d distribution we can add more
// complex cases.
int64_t vecRank = writeOp.getVectorType().getRank();
OpBuilder builder(writeOp.getContext());
auto map = AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
int64_t vecRank = vecType.getRank();
OpBuilder builder(val.getContext());
map = AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
return map;
}

Expand All @@ -480,7 +483,8 @@ static void populatePropagateVectorDistribution(Operation *target,
RewritePatternSet &patterns,
PatternBenefit benefit) {
assert(target->hasTrait<OpTrait::IsIsolatedFromAbove>());
vector::populatePropagateWarpVectorDistributionPatterns(patterns, benefit);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, simpleDistributionFunction, benefit);
vector::populateDistributeReduction(patterns, warpReduction, benefit);
patterns.add<WarpOpLoad, HoistSharedMemoryAlloc>(target->getContext(),
benefit);
Expand Down
2 changes: 1 addition & 1 deletion third_party/llvm-project

0 comments on commit 29bd703

Please sign in to comment.