Plumb dynamic shape support through for Vulkan and VMVX (iree-org#6917)
This commit plumbs dynamic shape support through for both Vulkan and
VMVX. Both backends rely on 1-D MemRefs and on running
`FlattenMemRefSubspanPass` in advance, instead of on MemRef descriptors.
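
To make the mechanics concrete, here is a hand-written sketch (not taken
from the actual tests) of what the flattening does: a shaped view with
multi-dimensional accesses becomes a 1-D view with linearized accesses.

    // Before flattening: a 2-D view; the inner dimension is static.
    %v = memref.load %view[%i, %j] : memref<?x8xf32>

    // After flattening: a 1-D view; the index is linearized as i * 8 + j.
    %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
    %v = memref.load %flat[%idx] : memref<?xf32>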

In order to enable dynamic shape support, we need to carry the
SSA values for dynamic dimensions down the CodeGen pipeline so that
we can linearize the index calculation in `FlattenMemRefSubspanPass`.
We have such information tightly associated with various ops at
the Flow level, but when outlining executables and materializing
the HAL interface, that association is broken; instead, `tie_shape`
ops are used to carry the information, which is structurally
difficult to maintain and convert.
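
Schematically, the old association looked like the following (a
hand-written sketch with made-up symbol names, not verbatim compiler
output):

    // Old style: the dynamic size is tied to the subspan result through
    // separate shapex ops that must be kept in sync across transformations.
    %0 = hal.interface.binding.subspan @io::@arg0[%offset] : memref<?x8xf32>
    %shape = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,8]>
    %1 = shapex.tie_shape %0, %shape : memref<?x8xf32>, !shapex.ranked_shape<[?,8]>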

So this commit changes the `hal.interface.binding.subspan` op to carry
the dynamic dimension SSA values itself, like many other ops
in Flow/HAL. It's a natural change that simplifies lots of analysis
and transformation. For example, we no longer need to maintain the
two-step conversion on the CPU side (first generating an undefined
MemRef descriptor when handling the `subspan` op, and then filling in
its contents when handling the `tie_shape` op). It also makes the
internals of HAL more akin to Flow on this front.
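
After this change, the op carries its dynamic sizes directly (again a
sketch with made-up names; the exact assembly may differ):

    // New style: the subspan op itself takes the dynamic dimension sizes
    // as trailing operands.
    %dim = hal.interface.load.constant offset = 0 : index
    %0 = hal.interface.binding.subspan @io::@arg0[%offset] : memref<?x8xf32>{%dim}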

Other changes mostly follow naturally from that:

* `MaterializeInterfaces` picks up the information from `tie_shape`
  ops and attaches it to `subspan` ops.
* `FlattenBindingSubspan` reads the dynamic dimensions to perform
  index linearization.
* `ConvertToLLVM` now generates the full MemRef descriptor from
  `subspan` ops.
* A new pass is added to fold `memref.dim`/`tensor.dim` ops over
  shape-carrying ops (sketched right below).
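
For instance, the folding in the last bullet rewrites patterns of the
following shape (illustrative only):

    // Before: the dimension query goes through memref.dim.
    %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<?x8xf32>{%dim}
    %d = memref.dim %0, %c0 : memref<?x8xf32>

    // After: uses of %d are replaced with %dim, the SSA value the
    // shape-carrying subspan op already holds for that dimension.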

This puts IREE CodeGen dynamic shape support for Vulkan/VMVX in
a very nice state:

Because we run `FoldSubViewOpsPass` in advance, there won't be
intermediate MemRefs (coming from `subview` ops), so loads/stores
directly take in HAL `subspan` ops. By definition, buffers in IREE
are tightly packed, so all MemRefs coming from subspans have strides
equal to the products of their inner dimension sizes. Symbolic
strides in subspan ops' AffineMaps therefore correspond to SSA values
for dimension sizes (or products thereof). Offsets are attached to
subspan ops as SSA values, but are then "transferred" to load/store
ops during MemRef flattening, by becoming part of the index
linearization calculation.
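
Concretely, the linearization folds one dimension at a time as
linear = linear * size + index, mirroring the pass's three-symbol
mul-add affine map; a sketch for indices (%i, %j) into a tightly packed
memref<?x?xf32> with sizes (%d0, %d1), using made-up SSA names:

    #muladd = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>

    // linear index = %i * %d1 + %j; %d1 feeds in as a symbol, which is how
    // symbolic strides end up corresponding to dimension-size SSA values.
    %idx = affine.apply #muladd()[%i, %d1, %j]
    %v = memref.load %flat[%idx] : memref<?xf32>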
antiagainst authored Sep 2, 2021
1 parent 64e5225 commit bfd507f
Showing 54 changed files with 895 additions and 425 deletions.
1 change: 0 additions & 1 deletion iree/compiler/Codegen/Common/BUILD
@@ -40,7 +40,6 @@ cc_library(
         "LinalgBufferizePass.cpp",
         "OptimizeVectorTransferPass.cpp",
         "SetNumWorkgroupsPass.cpp",
-        "ShapeToLLVMConversion.cpp",
         "VectorizeConv.cpp",
         "VectorizeMMT4d.cpp",
     ],
1 change: 0 additions & 1 deletion iree/compiler/Codegen/Common/CMakeLists.txt
@@ -31,7 +31,6 @@ iree_cc_library(
     "LinalgBufferizePass.cpp"
     "OptimizeVectorTransferPass.cpp"
     "SetNumWorkgroupsPass.cpp"
-    "ShapeToLLVMConversion.cpp"
     "VectorizeConv.cpp"
     "VectorizeMMT4d.cpp"
   DEPS
20 changes: 15 additions & 5 deletions iree/compiler/Codegen/Common/CleanupBufferAllocViewPass.cpp
@@ -18,6 +18,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -51,6 +52,13 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
 
   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
+    // TODO(antiagainst): enable dynamic shape support once it is needed.
+    auto reshapeSrcType = reshapeOp.src().getType().template cast<ShapedType>();
+    auto reshapeDstType = reshapeOp.getType().template cast<ShapedType>();
+    if (!reshapeSrcType.hasStaticShape() || !reshapeDstType.hasStaticShape()) {
+      return failure();
+    }
+
     auto loadOp =
         reshapeOp.src()
             .template getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
@@ -66,16 +74,18 @@ struct FoldReshapeIntoInterfaceTensorLoad : OpRewritePattern<TensorReshapeOp> {
         loadOp.source()
             .template getDefiningOp<IREE::HAL::InterfaceBindingSubspanOp>();
     if (!subspanOp) return failure();
+    assert(subspanOp.dynamic_dims().empty());
 
+    auto tensorAccess = subspanOp.getType()
+                            .template cast<IREE::Flow::DispatchTensorType>()
+                            .getAccess();
     auto newSubspanType = IREE::Flow::DispatchTensorType::get(
-        subspanOp.getType()
-            .template cast<IREE::Flow::DispatchTensorType>()
-            .getAccess(),
-        reshapeOp.getResultType());
+        tensorAccess, reshapeOp.getResultType());
 
     Value newSubspanOp = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
         subspanOp.getLoc(), newSubspanType, subspanOp.binding(),
-        subspanOp.byte_offset(), subspanOp.byte_length());
+        subspanOp.byte_offset(), subspanOp.byte_length(),
+        subspanOp.dynamic_dims());
 
     rewriter.replaceOpWithNewOp<IREE::Flow::DispatchTensorLoadOp>(
         reshapeOp, reshapeOp.getResultType(), newSubspanOp);
107 changes: 76 additions & 31 deletions iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -36,11 +36,15 @@
 #include "iree/compiler/Codegen/PassDetail.h"
 #include "iree/compiler/Codegen/Passes.h"
 #include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -51,10 +55,13 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#define DEBUG_TYPE "iree-flatten-memref-subspan"
+
 namespace mlir {
 namespace iree_compiler {
 
 namespace {
+
 //===----------------------------------------------------------------------===//
 // Type Conversion
 //===----------------------------------------------------------------------===//
Expand All @@ -78,11 +85,6 @@ struct FlattenMemRefTypeConverter final : public TypeConverter {
// 1-D MemRef types are okay.
if (isRankZeroOrOneMemRef(type)) return type;

// We can only handle static strides and offsets for now; fail the rest.
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(type, strides, offset))) return Type();

// Convert to a MemRef with unknown dimension. This is actually more akin
// to how IREE uses memref types: they are for representing a view from a
// byte buffer with potentially unknown total size, as transformation
@@ -168,39 +170,83 @@ struct FlattenBindingSubspan final
     // layout maps.
     if (!oldType || !oldType.getAffineMaps().empty()) return failure();
 
+    Value dynamicDim;
+    if (oldType.hasStaticShape()) {
+      // Because we always convert to 1-D dynamic memref, we still need to
+      // provide a "dynamic" dimension SSA value even if the old type is
+      // fully static.
+      dynamicDim = rewriter.create<ConstantIndexOp>(subspanOp.getLoc(),
+                                                    oldType.getNumElements());
+    } else {
+      ArrayRef<int64_t> oldShape = oldType.getShape();
+      MLIRContext *context = rewriter.getContext();
+      Location loc = subspanOp.getLoc();
+
+      int dynamicDimIndex = 0;
+      SmallVector<Value, 4> dims;
+      AffineExpr sizeExpr = getAffineConstantExpr(1, context);
+      for (int i = 0; i < oldType.getRank(); ++i) {
+        sizeExpr = sizeExpr * getAffineSymbolExpr(i, context);
+        if (ShapedType::isDynamic(oldShape[i])) {
+          dims.push_back(subspanOp.dynamic_dims()[dynamicDimIndex++]);
+        } else {
+          dims.push_back(rewriter.create<ConstantIndexOp>(loc, oldShape[i]));
+        }
+      }
+      dynamicDim = makeComposedAffineApply(rewriter, loc, sizeExpr, dims);
+    }
+
     Type newType = getTypeConverter()->convertType(oldType);
 
     rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceBindingSubspanOp>(
         subspanOp, newType, subspanOp.binding(), subspanOp.byte_offset(),
-        subspanOp.byte_length());
+        subspanOp.byte_length(), dynamicDim);
     return success();
   }
 };
 
 /// Generates IR to perform index linearization with the given `indices`
-/// indexing into the given memref `sourceType`.
-static Value linearizeIndices(MemRefType sourceType, ValueRange indices,
+/// indexing into the given memref `sourceValue`.
+static Value linearizeIndices(Value sourceValue, ValueRange indices,
                               Location loc, OpBuilder &builder) {
+  MemRefType sourceType = sourceValue.getType().cast<MemRefType>();
   assert(sourceType.hasRank() && sourceType.getRank() != 0);
+  int64_t rank = sourceType.getRank();
 
+  // First try to get the strides from the MemRef type itself. This applies to
+  // cases where we have static shapes and only the leading dimension is
+  // dynamic.
+  if (AffineMap linearLayoutMap = getStridedLinearLayoutMap(sourceType)) {
+    // Dynamic strides/offset will create symbols. There should be none for the
+    // static case.
+    if (linearLayoutMap.getNumSymbols() == 0) {
+      return makeComposedAffineApply(builder, loc, linearLayoutMap, indices);
+    }
+  }
+
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(sourceType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
-      offset == MemRefType::getDynamicStrideOrOffset()) {
-    return nullptr;
-  }
+  // Then try to see if the source op carries the dynamic dimensions itself.
+  // If so we can still get the strides for dimensions to linearize.
+  Operation *sourceOp = sourceValue.getDefiningOp();
+  auto shapeCarryOp = dyn_cast<ShapeCarryingInterface>(sourceOp);
+  if (!shapeCarryOp) return nullptr;
+
+  Value shapeOp =
+      shapeCarryOp.buildResultValueRankedShape(sourceValue, builder);
+  SmallVector<Value, 4> dims;
+  dims.reserve(rank);
+  for (int i = 0; i < rank; ++i) {
+    dims.push_back(builder.create<Shape::RankedDimOp>(loc, shapeOp, i));
+  }
 
   AffineExpr sym0, sym1, sym2;
   bindSymbols(builder.getContext(), sym0, sym1, sym2);
   MLIRContext *context = builder.getContext();
   auto mulAddMap = AffineMap::get(0, 3, {sym0 * sym1 + sym2}, context);
 
-  Value linearIndex = builder.create<ConstantIndexOp>(loc, offset);
-  for (auto pair : llvm::zip(indices, strides)) {
-    Value stride = builder.create<ConstantIndexOp>(loc, std::get<1>(pair));
+  Value linearIndex = indices.front();
+  for (int i = 1; i < indices.size(); ++i) {
     linearIndex = builder.create<AffineApplyOp>(
-        loc, mulAddMap, ValueRange{std::get<0>(pair), stride, linearIndex});
+        loc, mulAddMap, ValueRange{linearIndex, dims[i], indices[i]});
   }
   return linearIndex;
 }
@@ -218,8 +264,8 @@ struct LinearizeLoadIndices final : public OpConversionPattern<memref::LoadOp> {
           loadOp, "expected converted memref of rank <= 1");
     }
 
-    Value linearIndex = linearizeIndices(
-        loadOp.getMemRefType(), loadOp.getIndices(), loadOp.getLoc(), rewriter);
+    Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(),
+                                         loadOp.getLoc(), rewriter);
     if (!linearIndex) {
       return loadOp.emitOpError() << "failed to linearize index";
     }
@@ -244,9 +290,8 @@ struct LinearizeStoreIndices final
           storeOp, "expected converted memref of rank <= 1");
     }
 
-    Value linearIndex =
-        linearizeIndices(storeOp.getMemRefType(), storeOp.getIndices(),
-                         storeOp.getLoc(), rewriter);
+    Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(),
+                                         storeOp.getLoc(), rewriter);
     if (!linearIndex) {
       return storeOp.emitOpError() << "failed to linearize index";
     }
@@ -275,9 +320,9 @@ struct LinearizeTransferReadIndices final
       return rewriter.notifyMatchFailure(
           transferReadOp, "expected converted memref of rank <= 1");
     }
-    Value linearIndex = linearizeIndices(
-        transferReadOp.getShapedType().cast<MemRefType>(),
-        transferReadOp.indices(), transferReadOp.getLoc(), rewriter);
+    Value linearIndex =
+        linearizeIndices(transferReadOp.source(), transferReadOp.indices(),
+                         transferReadOp.getLoc(), rewriter);
     if (!linearIndex) {
       return transferReadOp.emitOpError() << "failed to linearize index";
     }
@@ -308,9 +353,9 @@ struct LinearizeTransferWriteIndices final
       return rewriter.notifyMatchFailure(
           transferWriteOp, "expected converted memref of rank <= 1");
     }
-    Value linearIndex = linearizeIndices(
-        transferWriteOp.getShapedType().cast<MemRefType>(),
-        transferWriteOp.indices(), transferWriteOp.getLoc(), rewriter);
+    Value linearIndex =
+        linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(),
+                         transferWriteOp.getLoc(), rewriter);
     if (!linearIndex) {
       return transferWriteOp.emitOpError() << "failed to linearize index";
     }
@@ -397,7 +442,7 @@ struct FoldSubspanOffsetIntoLoadStore final : public OpRewritePattern<OpType> {
     Value zero = rewriter.create<ConstantIndexOp>(op.memref().getLoc(), 0);
     Value newSubspan = rewriter.create<IREE::HAL::InterfaceBindingSubspanOp>(
         op.memref().getLoc(), subspanOp.getType(), subspanOp.binding(), zero,
-        subspanOp.byte_length());
+        subspanOp.byte_length(), subspanOp.dynamic_dims());
 
     MLIRContext *context = rewriter.getContext();
     AffineExpr sym0, sym1;
@@ -438,7 +483,7 @@ struct FlattenMemRefSubspanPass
   FlattenMemRefSubspanPass(const FlattenMemRefSubspanPass &pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<AffineDialect, memref::MemRefDialect>();
+    registry.insert<AffineDialect, memref::MemRefDialect, ShapeDialect>();
   }
 
   void runOnOperation() override {
2 changes: 1 addition & 1 deletion iree/compiler/Codegen/Common/LinalgBufferizePass.cpp
@@ -1527,7 +1527,7 @@ void LinalgBufferizePass::runOnOperation() {
     auto memRefType = getMemrefTypeForTensor(tensorType);
     auto baseBuffer = b.create<IREE::HAL::InterfaceBindingSubspanOp>(
         op->getLoc(), memRefType, op.binding(), op.byte_offset(),
-        op.byte_length());
+        op.byte_length(), op.dynamic_dims());
     bvm.map(op, baseBuffer);
   });
 
104 changes: 0 additions & 104 deletions iree/compiler/Codegen/Common/ShapeToLLVMConversion.cpp

This file was deleted.
