Skip to content

Commit

Permalink
[Codegen] Add consumer fusion (iree-org#18427)
Browse files Browse the repository at this point in the history
Adds consumer fusion to the tile-and-distribute-using-forall pass.
  • Loading branch information
pashu123 authored Sep 11, 2024
1 parent 60843ec commit e2464dd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "tile-and-distribute-to-workgroups-using-forall-op"

namespace mlir::iree_compiler {

#define CEILDIV(a, b) ((a + b - 1) / b)
Expand Down Expand Up @@ -254,6 +256,50 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
// Pass implementation.
//===---------------------------------------------------------------------===//

/// Greedily fuses consumers of `tiledOp` into the surrounding distribution
/// loop. Candidate fusion points are the `tensor.parallel_insert_slice` ops
/// that use results of already-tiled operations; each successful fusion may
/// expose new candidates, which are processed breadth-first via a worklist.
static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
  // Worklist of slices whose consumers may be fusable into the loop.
  std::queue<tensor::ParallelInsertSliceOp> worklist;

  // Push every parallel_insert_slice user of `op`s results onto the worklist.
  auto enqueueSliceUsers = [&worklist](Operation *op) {
    for (auto *user : op->getResults().getUsers()) {
      if (auto insertSlice = llvm::dyn_cast<tensor::ParallelInsertSliceOp>(user))
        worklist.push(insertSlice);
    }
  };

  // Seed the worklist from the root tiled operation.
  enqueueSliceUsers(tiledOp);

  while (!worklist.empty()) {
    // Traverse the slices in BFS fashion.
    tensor::ParallelInsertSliceOp slice = worklist.front();
    worklist.pop();

    // Attempt to fuse the consumer of this slice; skip on failure so the
    // remaining candidates still get a chance.
    FailureOr<scf::SCFFuseConsumerOfSliceResult> result =
        mlir::scf::tileAndFuseConsumerOfSlice(rewriter, slice);
    if (failed(result)) {
      LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
                              << slice << "\n");
      continue;
    }

    // Replace the original consumer operation with the tiled implementation.
    rewriter.replaceOp(result->origConsumerOperand->getOwner(),
                       result->tiledOps.front());

    // Results of the fused consumer may themselves feed further
    // parallel_insert_slice ops; add those to the worklist.
    enqueueSliceUsers(result->tiledAndFusedConsumerOperand->getOwner());
  }
}

void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
auto funcOp = getOperation();
auto *context = &getContext();
Expand Down Expand Up @@ -292,6 +338,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {

// If the `tilableOp` is a `memref` op, then just tile the operation.
SmallVector<LoopLikeOpInterface> tilingLoops;
Operation *rootTiledOp = nullptr;
if (tilableOp->getNumResults() == 0) {
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCF(rewriter, tilableOp, tilingOptions);
Expand All @@ -313,6 +360,7 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
rewriter.replaceAllUsesWith(origValue, replacement);
}
std::swap(tileAndFuseResult->loops, tilingLoops);
rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
}
if (!tilingLoops.empty()) {
if (tilingLoops.size() != 1 || !isa<scf::ForallOp>(tilingLoops[0])) {
Expand All @@ -326,6 +374,10 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
forallOp.emitOpError("failed to drop unit dimensions");
return signalPassFailure();
}

if (rootTiledOp) {
fuseConsumers(rewriter, rootTiledOp);
}
}

// Cleanup patterns for tile and distribute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,82 @@ func.func @generate_no_distribution(%arg0 : tensor<16xf16>) -> tensor<16xf16> {
}
// CHECK-LABEL: func @generate_no_distribution(
// CHECK-NOT: scf.forall

// -----

// Consumer-fusion test: the matmul carries the lowering_config, so it is the
// root op for tile-and-distribute. The CHECK lines below expect the producer
// generics (extf on LHS/RHS) to be fused as producers, and the bias-add and
// relu generics to be fused as consumers, all inside a single scf.forall with
// only the final relu result written back via tensor.parallel_insert_slice.
func.func @matmul_consumer_fusion_test(%arg0 : tensor<?x?xf16>,
    %arg1 : tensor<?x?xf16>, %arg2: tensor<?xf16>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %cst0 = arith.constant 0.0 : f32
  %M = tensor.dim %arg0, %c0 : tensor<?x?xf16>
  %N = tensor.dim %arg1, %c1 : tensor<?x?xf16>
  %K = tensor.dim %arg0, %c1 : tensor<?x?xf16>
  // Widen LHS f16 -> f32 (producer candidate for fusion).
  %empty_lhs = tensor.empty(%M, %K) : tensor<?x?xf32>
  %extf_lhs = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf16>) outs(%empty_lhs : tensor<?x?xf32>) {
    ^bb0(%b0 : f16, %b1 : f32) :
      %0 = arith.extf %b0 : f16 to f32
      linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  // Widen RHS f16 -> f32 (producer candidate for fusion).
  %empty_rhs = tensor.empty(%K, %N) : tensor<?x?xf32>
  %extf_rhs = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg1 : tensor<?x?xf16>) outs(%empty_rhs : tensor<?x?xf32>) {
    ^bb0(%b0 : f16, %b1 : f32) :
      %0 = arith.extf %b0 : f16 to f32
      linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  %empty = tensor.empty(%M, %N) : tensor<?x?xf32>
  %fill = linalg.fill ins(%cst0 : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
  // Root op: carries the lowering_config that drives tiling.
  %matmul = linalg.matmul
      {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 64]]>}
      ins(%extf_lhs, %extf_rhs : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  // First consumer: bias add (note: %arg2 is indexed along d0 here).
  %empty_biasadd = tensor.empty(%M, %N) : tensor<?x?xf32>
  %bias_add = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%matmul,%arg2 : tensor<?x?xf32>, tensor<?xf16>) outs(%empty_biasadd : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1: f16, %b2 : f32) :
      %0 = arith.extf %b1 : f16 to f32
      %1 = arith.addf %b0, %0 : f32
      linalg.yield %1 : f32
  } -> tensor<?x?xf32>
  // Second (transitive) consumer: relu, exercising repeated consumer fusion.
  %empty_relu = tensor.empty(%M, %N) : tensor<?x?xf32>
  %relu = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%bias_add : tensor<?x?xf32>) outs(%empty_relu : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32) :
      %0 = arith.maximumf %b0, %cst0 : f32
      linalg.yield %0 : f32
  } -> tensor<?x?xf32>
  return %relu : tensor<?x?xf32>
}
// CHECK-LABEL: func @matmul_consumer_fusion_test(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf16>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf16>
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?xf16>
//       CHECK:   %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
//       CHECK:     %[[LHS_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
//       CHECK:     %[[LHS:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[LHS_SLICE]] :
//       CHECK:     %[[RHS_SLICE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
//       CHECK:     %[[RHS:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[RHS_SLICE]] :
//       CHECK:     %[[FILL:.+]] = linalg.fill
//       CHECK:     %[[MATMUL:.+]] = linalg.matmul
//  CHECK-SAME:         ins(%[[LHS]], %[[RHS]] :
//  CHECK-SAME:         outs(%[[FILL]] :
//       CHECK:     %[[BIASADD:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[MATMUL]]
//       CHECK:     %[[RELU:.+]] = linalg.generic
//  CHECK-SAME:         ins(%[[BIASADD]] :
//       CHECK:     scf.forall.in_parallel
//       CHECK:       tensor.parallel_insert_slice %[[RELU]]
//       CHECK:   return %[[RESULT]]

0 comments on commit e2464dd

Please sign in to comment.