Skip to content

Commit

Permalink
[DispatchCreation] Extend multi-use producer fusion (iree-org#18551)
Browse files Browse the repository at this point in the history
Fuse even in cases where the most dominant op isn't fusable, but other operations would be legal to fuse. Do this by moving the fusable consumer and all transitive defs before all other consumers (if legal).

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Oct 16, 2024
1 parent c6056d1 commit 206b60c
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 74 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ jobs:
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
Expand All @@ -241,7 +241,7 @@ jobs:
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
Expand Down Expand Up @@ -107,25 +108,6 @@ static bool isEmptyFillContractionDAGRootOp(
return true;
}

/// Returns true when `op` is "horizontal" to `currGroup`, i.e. when no member
/// of the group appears in the backward slice of `op`. Operations that
/// properly dominate `seedOp` are filtered out of the slice to keep it small.
static bool isHorizontalToGroup(Operation *op,
                                const llvm::SetVector<Operation *> &currGroup,
                                const DominanceInfo &dominanceInfo,
                                Operation *seedOp) {
  BackwardSliceOptions sliceOptions;
  // Reject anything that already properly dominates the seed.
  sliceOptions.filter = [&](Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, seedOp);
  };

  llvm::SetVector<Operation *> backwardSlice;
  getBackwardSlice(op, &backwardSlice, sliceOptions);

  // A single group member in the slice is enough to make `op` non-horizontal.
  for (Operation *member : currGroup) {
    if (backwardSlice.contains(member)) {
      return false;
    }
  }
  return true;
}

/// Get user of operation that is a truncate operation.
static std::optional<linalg::GenericOp>
getTruncateOp(Operation *op,
Expand All @@ -149,8 +131,8 @@ getTruncateOp(Operation *op,
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
return std::nullopt;
}
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
seedTruncateOp.value())) {
if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
dominanceInfo, seedTruncateOp.value())) {
return std::nullopt;
}
}
Expand Down Expand Up @@ -226,7 +208,8 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
seedOp)) {
return false;
}
return true;
Expand Down Expand Up @@ -346,40 +329,6 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
}

/// During horizontal fusion, there might be operands of the fused operations
/// whose definitions are interspersed between the fused operations. For groups
/// chosen to fuse horizontally, such operations can be moved before the
/// seed contraction operation (where the fused operation is generated).
///
/// For every operand of every op in `operations`, the backward slice
/// (including the defining op itself) is collected. Ops that properly dominate
/// `insertionPoint` and ops listed in `ignoreOperations` are filtered out. The
/// collected ops are then moved immediately before `insertionPoint` in
/// topological order. Always returns success; whether the move is legal is not
/// checked here.
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
                Operation *insertionPoint, DominanceInfo &dominanceInfo,
                ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
  llvm::DenseSet<Operation *> ignoreOperationsSet;
  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());

  BackwardSliceOptions options;
  options.filter = [&](Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, insertionPoint) &&
           !ignoreOperationsSet.contains(candidate);
  };
  // Set inclusive to true cause the slice is computed from the operand, and
  // we want to include the defining op (which is the point here)
  options.inclusive = true;

  llvm::SetVector<Operation *> slice;
  for (auto op : operations) {
    for (auto operand : op->getOperands()) {
      getBackwardSlice(operand, &slice, options);
    }
  }

  // `mlir::topologicalSort` returns the sorted set instead of sorting its
  // argument in place, so the result must be kept and iterated.
  const llvm::SetVector<Operation *> sortedSlice = mlir::topologicalSort(slice);
  for (Operation *sliceOp : sortedSlice) {
    rewriter.moveOpBefore(sliceOp, insertionPoint);
  }
  return success();
}

/// On finding this pattern
/// ```
/// %0 = linalg.matmul ins(%arg0, %arg1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -45,25 +49,55 @@ static llvm::cl::opt<int64_t> clLinalgMaxConstantFoldElements(
llvm::cl::desc("Maximum number of elements to try to constant fold."),
llvm::cl::init(0));

/// Returns the owner of the first use of `op` that dominates the owners of all
/// uses of `op`, or nullptr when no such user exists.
static Operation *getMostDominantUse(Operation *op,
                                     const DominanceInfo &dominanceInfo) {
  auto allUses = op->getUses();
  for (OpOperand &candidateUse : allUses) {
    Operation *candidate = candidateUse.getOwner();

    bool dominatesEveryUser = true;
    for (OpOperand &otherUse : allUses) {
      if (!dominanceInfo.dominates(candidate, otherUse.getOwner())) {
        dominatesEveryUser = false;
        break;
      }
    }
    if (dominatesEveryUser) {
      return candidate;
    }
  }
  return nullptr;
}

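/// Returns a user of `op` that dominates every `linalg.generic` user of `op`.
/// If that user is not the most dominant user of `op`, it must also be
/// horizontal to the other users (see `isHorizontalToGroup`). Returns nullptr
/// when no such user exists or when `op` has no most dominant user.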
static std::optional<OpOperand *> getFusableUse(Operation *op,
DominanceInfo &dominanceInfo) {
static Operation *getFusableUse(Operation *op,
const DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
Operation *fusableUse = nullptr;
for (OpOperand &source : uses) {
Operation *sourceOp = source.getOwner();
bool dominatesAllUsers = true;
for (OpOperand &target : uses) {

bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
Operation *targetOp = target.getOwner();
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
dominatesAllUsers = false;
break;
}
}
if (dominatesAllUsers) {
return &source;
return !isa<linalg::GenericOp>(targetOp) ||
dominanceInfo.dominates(sourceOp, targetOp);
});
if (dominatesAllFusableOps) {
fusableUse = sourceOp;
break;
}
}
return std::nullopt;
Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
if (!fusableUse || !mostDominantOp) {
return nullptr;
}

// If `fusableUse` dominates all other users, there's nothing else to do.
if (fusableUse == mostDominantOp) {
return fusableUse;
}

SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
? fusableUse
: nullptr;
}

static OpOperand *getFirstUseInConsumer(Operation *producer,
Expand Down Expand Up @@ -91,6 +125,7 @@ static SmallVector<OpOperand *> getAllUsesInConsumer(Operation *producer,
/// using elementwise fusion.
static LogicalResult doMultiUseFusion(Operation *rootOp,
llvm::SetVector<Operation *> &fusableOps,
const DominanceInfo &dominanceInfo,
RewriterBase &rewriter) {
assert(rootOp && "root op cant be null");

Expand All @@ -112,11 +147,20 @@ static LogicalResult doMultiUseFusion(Operation *rootOp,
Operation *consumerOp = rootOp;
OpBuilder::InsertionGuard g(rewriter);
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
// Fuse all uses from producer -> consumer. It has been checked
// before that all uses are fusable.
while (OpOperand *fusedOperand =
getFirstUseInConsumer(producerOp, consumerOp)) {
rewriter.setInsertionPoint(consumerOp);

if (consumerOp != mostDominantUser &&
failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
mostDominantUser, dominanceInfo))) {
return rewriter.notifyMatchFailure(consumerOp,
"failed to move operand defs");
}
rewriter.moveOpBefore(consumerOp, mostDominantUser);
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
linalg::fuseElementwiseOps(rewriter, fusedOperand);
if (failed(fusionResult)) {
Expand Down Expand Up @@ -190,9 +234,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
}

// 6. Check that the `genericOp` dominates all uses of `producer`.
std::optional<OpOperand *> fusableUse =
getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
Operation *fusableUse = getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse != genericOp) {
continue;
}

Expand Down Expand Up @@ -232,7 +275,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,

IRRewriter rewriter(context);
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
if (failed(
doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
return funcOp->emitOpError("failed multi use fusion");
}
}
Expand Down
33 changes: 33 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler::DispatchCreation {

Expand Down Expand Up @@ -97,4 +101,33 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
return true;
}

/// Returns true if no op of `currGroup` is part of the backward slice of `op`
/// and no op in that slice has a region that uses values defined above it.
/// Ops that properly dominate `seedOp` are filtered out of the slice.
/// Asserts that `seedOp` properly dominates `op` and that both share a parent
/// region.
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
                         const DominanceInfo &dominanceInfo,
                         Operation *seedOp) {
  assert(dominanceInfo.properlyDominates(seedOp, op) &&
         op->getParentRegion() == seedOp->getParentRegion());

  BackwardSliceOptions sliceOptions;
  // Cut the slice off at the seed so that it stays small.
  sliceOptions.filter = [&](Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, seedOp);
  };
  llvm::SetVector<Operation *> backwardSlice;
  getBackwardSlice(op, &backwardSlice, sliceOptions);

  // `getBackwardSlice` doesn't track uses from within an op's regions, so
  // give up as soon as a sliced op uses a value defined above its regions.
  const auto usesValuesFromAbove = [](Operation *sliceOp) {
    bool found = false;
    mlir::visitUsedValuesDefinedAbove(sliceOp->getRegions(),
                                      [&](OpOperand *) { found = true; });
    return found;
  };
  if (llvm::any_of(backwardSlice, usesValuesFromAbove)) {
    return false;
  }

  return llvm::none_of(currGroup, [&](Operation *groupedOp) {
    return backwardSlice.contains(groupedOp);
  });
}

} // namespace mlir::iree_compiler::DispatchCreation
44 changes: 44 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler::DispatchCreation {
Expand All @@ -19,4 +23,44 @@ namespace mlir::iree_compiler::DispatchCreation {
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the program slice of the operation (from op back to seedOp)
/// does not contain any op from the group.
///
/// The implementation asserts that `seedOp` properly dominates `op` and that
/// both have the same parent region. It also returns false when any op in the
/// slice has a region that uses values defined above it, because the backward
/// slice does not track such uses.
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo, Operation *seedOp);

/// Moves the operands and transitive defs for each op in `operations` directly
/// before `insertionPoint`. Note: this does not check if it is legal to move
/// the operands.
///
/// Ops that properly dominate `insertionPoint` and ops listed in
/// `ignoreOperations` are filtered out of the slice and therefore not moved.
/// Every op in `operations` is asserted to be in the same block as
/// `insertionPoint`. The collected ops are moved in topological order. Always
/// returns success.
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
                Operation *insertionPoint, const DominanceInfo &dominanceInfo,
                ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
  llvm::DenseSet<Operation *> ignoreOperationsSet;
  ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());

  BackwardSliceOptions options;
  options.filter = [&](Operation *candidate) {
    return !dominanceInfo.properlyDominates(candidate, insertionPoint) &&
           !ignoreOperationsSet.contains(candidate);
  };
  // Set inclusive to true cause the slice is computed from the operand, and
  // we want to include the defining op (which is the point here)
  options.inclusive = true;

  llvm::SetVector<Operation *> slice;
  for (auto op : operations) {
    assert(insertionPoint->getBlock() == op->getBlock());
    for (auto operand : op->getOperands()) {
      getBackwardSlice(operand, &slice, options);
    }
  }

  // `mlir::topologicalSort` returns the sorted set instead of sorting its
  // argument in place, so the result must be kept and iterated.
  const llvm::SetVector<Operation *> sortedSlice = mlir::topologicalSort(slice);
  for (Operation *sliceOp : sortedSlice) {
    rewriter.moveOpBefore(sliceOp, insertionPoint);
  }
  return success();
}

} // namespace mlir::iree_compiler::DispatchCreation
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,28 @@ util.func public @math_sin() {
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,

// -----

// The first `linalg.generic` (%4) has two users: a `tensor.collapse_shape`
// that comes first, and a second `linalg.generic` (%5) that comes after it.
// The FileCheck directives at the end expect that no second `linalg.generic`
// follows the first one in the output.
#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 2.000000e+00 : f32
%cst_1 = arith.constant 3.000000e+00 : f32
%4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%8 = arith.addf %arg2, %cst : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
// expected-note @below {{prior use here}}
%collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%8 = arith.subf %arg2, %cst_0 : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
}
// CHECK-LABEL: util.func public @fuse_by_moving_consumer
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic

0 comments on commit 206b60c

Please sign in to comment.