Skip to content

Commit

Permalink
[DispatchCreation] CollapseDimensions patch (iree-org#18424)
Browse files Browse the repository at this point in the history
Fixes the case where parallel and reduction iterators (both of which are collapsable) are adjacent in the consumer. The corresponding dimensions cannot be collapsed into each other in the producer, because the consumer keeps its parallel and reduction dimensions separate.
  • Loading branch information
IanWood1 authored Sep 30, 2024
1 parent a9c7ec1 commit f5dc573
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 46 deletions.
177 changes: 131 additions & 46 deletions compiler/src/iree/compiler/DispatchCreation/CollapseDimensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Utils.h"
Expand Down Expand Up @@ -280,10 +282,15 @@ class CollapseInfo {
// Debug print the current operation & reassociation indices
void dump() const;

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
// Update CollapseInfo to ensure that all dimensions collapsable in `this` are
// also collapsable in `consumerInfo`. This means:
// 1. Any dimension not collapsable in `consumerInfo` should not be
// collapsable in `this`
// 2. For any pair of dimensions in `this`, if they are collapsable in
// `consumerInfo`, they must be collapsable into the same dimension in
// `consumerInfo` to be collapsable into the same dimension in `this`.
// Returns true if the operation modified the number of collapsable loops.
bool updateCollapseViaIntersect(const CollapsableLoopsSet &otherCollapsable);
bool updateFromConsumer(OpOperand *operand, const CollapseInfo &consumerInfo);

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
// reassociation indices accordingly.
Expand All @@ -293,13 +300,18 @@ class CollapseInfo {
// Get `collapsableLoops` after applying the transformation provided by `map`.
// Note: doesn't modify `collapsableLoops`; the transformation is applied to a
// copy.
FailureOr<CollapsableLoopsSet>
getTransformedCollapsableLoops(AffineMap map) const;
CollapsableLoopsSet getTransformedCollapsableLoops(AffineMap map) const;

// Clear internal data
void clear() {
// Get `reassociation` after applying the transformation provided by `map`.
SmallVector<ReassociationIndices>
getTransformedReassociation(AffineMap map) const;

// Clear internal data. Returns true if anything changed, i.e. if there was a
// reassociation group or a collapsable loop to remove before the call.
bool clear() {
  // Capture the state *before* clearing. The result feeds the fixed-point
  // `didChange` tracking of the callers, so it must be false when the info
  // was already empty.
  const bool hadData = !reassociation.empty() || !collapsableLoops.empty();
  reassociation.clear();
  collapsableLoops.clear();
  return hadData;
}

const CollapsableLoopsSet &getCollapsibleLoops() const {
Expand Down Expand Up @@ -386,12 +398,8 @@ void CollapseInfo::updateReassociation() {
// map = affine_map<(d0, d1, d2) -> (d1, d2, d5)>
//
// Therefore, the collapsable loops with respect to the consumer is {1, 2, 5}.
FailureOr<CollapseInfo::CollapsableLoopsSet>
CollapseInfo::CollapsableLoopsSet
CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
if (!map) {
return failure();
}

CollapsableLoopsSet transformedLoops;
for (auto index : collapsableLoops) {
assert(index < map.getNumResults() && "index has no valid mapping");
Expand All @@ -405,19 +413,114 @@ CollapseInfo::getTransformedCollapsableLoops(AffineMap map) const {
return transformedLoops;
}

// Update `collapsableLoops` by taking the set intersection with
// `otherCollapsable` and update the reassociation indicies accordingly.
bool CollapseInfo::updateCollapseViaIntersect(
const CollapsableLoopsSet &otherCollapsable) {
CollapsableLoopsSet toRemove;
for (auto elem : collapsableLoops) {
if (!otherCollapsable.contains(elem)) {
toRemove.insert(elem);
SmallVector<ReassociationIndices>
CollapseInfo::getTransformedReassociation(AffineMap map) const {
SmallVector<ReassociationIndices> transformedReassociation(
reassociation.size());
for (const auto &[i, indicies] : llvm::enumerate(reassociation)) {
for (auto elem : indicies) {
auto dimExpr = dyn_cast<AffineDimExpr>(map.getResult(elem));
if (!dimExpr) {
break;
}
transformedReassociation[i].push_back(dimExpr.getPosition());
}
}
collapsableLoops.set_subtract(toRemove);
updateReassociation();
return toRemove.size();
return transformedReassociation;
}

// Restricts the collapse information of `this` (the producer of `operand`) so
// that it agrees with `consumerInfo`, the collapse information of the op that
// consumes `operand`:
//   1. A loop stays collapsable in `this` only if it is also collapsable in
//      the consumer, after the consumer's loops are mapped into the producer's
//      iteration space.
//   2. Loops that are grouped together in `this` stay grouped only if the
//      consumer collapses them into the same loop.
// If no consumer-to-producer loop map can be computed for `operand`, all
// collapse information of `this` is cleared and the result of `clear()` is
// returned. Otherwise returns true if `collapsableLoops` or `reassociation`
// was modified.
bool CollapseInfo::updateFromConsumer(OpOperand *operand,
                                      const CollapseInfo &consumerInfo) {
  FailureOr<AffineMap> consumerToProducerMap =
      getConsumerLoopToProducerLoopsMap(*operand);
  if (failed(consumerToProducerMap)) {
    return this->clear();
  }

  // Express the consumer's collapse information in terms of producer loops.
  CollapsableLoopsSet consumerCollapsable =
      consumerInfo.getTransformedCollapsableLoops(
          consumerToProducerMap.value());

  SmallVector<ReassociationIndices> consumerReassoc =
      consumerInfo.getTransformedReassociation(consumerToProducerMap.value());

  // Map each (transformed) loop index to the position of the consumer
  // reassociation group it gets collapsed into.
  llvm::DenseMap<int64_t, int64_t> consumerCollapseMap;
  for (const auto &[idx, indicies] : llvm::enumerate(consumerReassoc)) {
    for (const auto elem : indicies) {
      consumerCollapseMap[elem] = idx;
    }
  }

  // Remove all collapsable loops in the producer that are not collapsable in
  // the consumer (set intersection).
  bool didChange = collapsableLoops.remove_if([&](int64_t elem) -> bool {
    return !consumerCollapsable.contains(elem);
  });

  // Rebuild the reassociation indices from the updated `collapsableLoops` and
  // `consumerCollapseMap`. Every existing group is walked in order and, for
  // each `index` in it:
  //
  // (1) If `index` IS NOT in `collapsableLoops`, the group being built is
  //     closed and `index` is dropped.
  //
  // (2) If `index` IS in `collapsableLoops` and is the first loop of the group
  //     being built, it starts the group and fixes the consumer loop that the
  //     group collapses into.
  //
  // (3) If `index` IS in `collapsableLoops` but the consumer collapses it into
  //     a different loop than the rest of the group being built, the group is
  //     closed and `index` starts a new one.
  //
  //     For example:
  //       producer reassociation = [[0, 1, 2, 3]]
  //       consumer reassociation = [[0, 1], [2, 3]]
  //     then the producer reassociation gets updated to [[0, 1], [2, 3]]
  //     because the consumer collapses [0, 1] and [2, 3] into different loops.
  //
  // (4) Otherwise `index` is appended to the group being built.
  //
  // A closed group is kept only if it holds more than one loop; a single loop
  // has nothing to collapse with.
  constexpr int64_t kUninitialized = -1;
  SmallVector<ReassociationIndices> newReassociation;
  for (ReassociationIndicesRef indicies : reassociation) {
    // Position of the consumer group that the loops in `newIndicies` get
    // collapsed into, or `kUninitialized` while `newIndicies` is empty.
    int64_t collapseIntoIdx = kUninitialized;

    // Holds dimensions that should be collapsed together.
    ReassociationIndices newIndicies;

    // Closes the group being built and resets the tracking state.
    auto closeGroup = [&]() {
      if (newIndicies.size() > 1) {
        newReassociation.push_back(std::move(newIndicies));
      }
      // `clear()` restores a well-defined empty state after the move.
      newIndicies.clear();
      collapseIntoIdx = kUninitialized;
    };

    for (int64_t index : indicies) {
      if (!collapsableLoops.contains(index)) {
        // (1) `index` isn't collapsable, so the loops gathered so far are no
        // longer adjacent to the upcoming ones.
        didChange = true;
        closeGroup();
        continue;
      }

      // Single, non-mutating lookup shared by cases (2)-(4).
      const int64_t consumerGroup = consumerCollapseMap.at(index);
      if (collapseIntoIdx != kUninitialized &&
          consumerGroup != collapseIntoIdx) {
        // (3) Collapsable, but not into the same consumer loop as the group
        // being built: split here.
        didChange = true;
        closeGroup();
      }

      // (2)/(3)/(4) `index` joins the (possibly fresh) group.
      collapseIntoIdx = consumerGroup;
      newIndicies.push_back(index);
    }

    if (newIndicies.size() > 1) {
      newReassociation.push_back(std::move(newIndicies));
    }
  }
  reassociation = std::move(newReassociation);
  return didChange;
}

// Update `collapsableLoops` by subtracting `uncollapsable` and update the
Expand Down Expand Up @@ -679,12 +782,10 @@ static bool updateConsumersFromProducers(
continue;
}

CollapseInfo &producerInfo = opMap.find(producerOp)->second;
FailureOr<CollapseInfo::CollapsableLoopsSet> producerCollapsable =
const CollapseInfo &producerInfo = opMap.at(producerOp);
CollapseInfo::CollapsableLoopsSet producerCollapsable =
producerInfo.getTransformedCollapsableLoops(mapping.value());
if (!failed(producerCollapsable)) {
producerUncollapsable.set_subtract(producerCollapsable.value());
}
producerUncollapsable.set_subtract(producerCollapsable);

didChange |=
consumerInfo.updateCollapseViaSubtract(producerUncollapsable);
Expand All @@ -707,7 +808,7 @@ static bool updateProducersFromConsumers(
for (auto op : llvm::reverse(slice)) {
auto genericConsumer = cast<linalg::GenericOp>(op);
assert(opMap.contains(genericConsumer));
const CollapseInfo &consumerInfo = opMap.find(genericConsumer)->second;
const CollapseInfo &consumerInfo = opMap.at(genericConsumer);

for (auto operand : genericConsumer.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
Expand All @@ -721,26 +822,10 @@ static bool updateProducersFromConsumers(

// Get a mapping from the consumer's iteration space to the producer's.
CollapseInfo &producerInfo = opMap.find(genericProducer)->second;
FailureOr<AffineMap> consumerToProducerMap =
getConsumerLoopToProducerLoopsMap(*operand);
if (failed(consumerToProducerMap)) {
didChange |= !producerInfo.getCollapsibleLoops().empty();
producerInfo.clear();
continue;
}

// Use the map to get the consumer's collapsable loops in terms of the
// producer.
auto consumerCollapsable = consumerInfo.getTransformedCollapsableLoops(
consumerToProducerMap.value());
if (failed(consumerCollapsable)) {
producerInfo.clear();
continue;
}
// Only loops collapsable in both the consumer and producer may be
// collapsed.
didChange |=
producerInfo.updateCollapseViaIntersect(consumerCollapsable.value());
didChange |= producerInfo.updateFromConsumer(operand, consumerInfo);
}
}
return didChange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,44 @@ util.func public @propagate_uncollapsable(%arg0: tensor<2x320x128x128xf32>) -> t
// CHECK-SAME: ins(%[[VAL2]], %[[VAL1]] : tensor<2x320x128x128xf32>, tensor<2x320x128x128xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<2x320x128x128xf32>)
// CHECK: flow.return %[[VAL3]]

// -----

// Dispatch region with an elementwise f16 -> f32 extension whose four loops
// are all parallel, feeding a generic op whose loops are
// [parallel, parallel, reduction, reduction]. The expectations that follow
// this function require both ops to be collapsed to two loops (64 x 163840),
// i.e. the producer's loops d1 and d2 are not merged across the consumer's
// parallel/reduction boundary.
util.func public @dequant_contraction(%arg0: tensor<2x32xf32>, %arg1: tensor<2x32x10x16384xf16>) -> tensor<2x32xf32> {
%0 = flow.dispatch.region -> (tensor<2x32xf32>) {
%1 = tensor.empty() : tensor<2x32xf32>
%cst = arith.constant 0.000000e+00 : f32
%2 = tensor.empty() : tensor<2x32x10x16384xf32>
// Producer: elementwise extf, every iterator is parallel.
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<2x32x10x16384xf16>) outs(%2 : tensor<2x32x10x16384xf32>) {
^bb0(%in: f16, %out: f32):
%6 = arith.extf %in : f16 to f32
linalg.yield %6 : f32
} -> tensor<2x32x10x16384xf32>
%4 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32>
// Consumer: accumulates (%in - %in_0)^2 into the zero-filled output over the
// two reduction loops d2 and d3.
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%3, %arg0 : tensor<2x32x10x16384xf32>, tensor<2x32xf32>) outs(%4 : tensor<2x32xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%6 = arith.subf %in, %in_0 : f32
%7 = arith.mulf %6, %6 : f32
%8 = arith.addf %7, %out : f32
linalg.yield %8 : f32
} -> tensor<2x32xf32>
flow.return %5 : tensor<2x32xf32>
}
util.return %0 : tensor<2x32xf32>
}

// CHECK-LABEL: util.func public @dequant_contraction
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x32xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<2x32x10x16384xf16>
// CHECK-DAG: %[[COLLAPSED_ARG0:.+]] = tensor.collapse_shape %[[ARG0]]
// CHECK-DAG: %[[COLLAPSED_ARG1:.+]] = tensor.collapse_shape %[[ARG1]]
// CHECK: flow.dispatch.region
// CHECK: %[[VAL0:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[COLLAPSED_ARG1]] : tensor<64x163840xf16>)
// CHECK-SAME: outs(%{{.*}} : tensor<64x163840xf32>)
// CHECK: %[[VAL1:.*]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "reduction"]
// CHECK-SAME: ins(%[[VAL0]], %[[COLLAPSED_ARG0]] : tensor<64x163840xf32>, tensor<64xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<64xf32>)
// CHECK: flow.return %[[VAL1]]

0 comments on commit f5dc573

Please sign in to comment.