Skip to content

Commit

Permalink
Move memref->buffer conversion out of VM conversion. (iree-org#10015)
Browse files Browse the repository at this point in the history
* Move memref->buffer conversion out of VM conversion.

It is now done as part of VMVX conversion.
  • Loading branch information
stellaraccident authored Aug 6, 2022
1 parent 2906c1d commit b9a39d7
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 66 deletions.
9 changes: 6 additions & 3 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeSupport.h"
Expand Down Expand Up @@ -308,8 +309,9 @@ static bool isValueUsableForOp(Value value, Block *block,
if (value.getDefiningOp()->isBeforeInBlock(&*insertionPoint)) {
return true;
}
} else if (definingBlock->isEntryBlock()) {
// Entry block always dominates - fast path for constants.
} else if (definingBlock->isEntryBlock() &&
llvm::isa<FunctionOpInterface>(definingBlock->getParentOp())) {
// Function entry block always dominates - fast path for constants.
return true;
} else {
// See if block the value is defined in dominates the forOp block.
Expand Down Expand Up @@ -351,8 +353,9 @@ Value SizeAwareTypeInterface::findSizeValue(Value resourceValue, Block *block,
use.getOwner())) {
auto sizeValue = sizeAwareOp.getOperandSize(use.getOperandNumber());
if (sizeValue) {
if (isValueUsableForOp(sizeValue, block, insertionPoint))
if (isValueUsableForOp(sizeValue, block, insertionPoint)) {
return sizeValue;
}
}
}
if (auto tiedOp =
Expand Down
11 changes: 1 addition & 10 deletions compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <tuple>

#include "iree/compiler/Dialect/Util/Conversion/ConversionPatterns.h"
#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/VM/Conversion/ConversionDialectInterface.h"
Expand Down Expand Up @@ -78,7 +77,7 @@ class ConversionPass
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<IREE::Util::UtilDialect, IREE::VM::VMDialect,
func::FuncDialect, mlir::arith::ArithmeticDialect,
math::MathDialect, AffineDialect, memref::MemRefDialect>();
math::MathDialect, AffineDialect>();
}

void runOnOperation() override {
Expand Down Expand Up @@ -126,14 +125,6 @@ class ConversionPass
populateMathToVMPatterns(context, typeConverter, patterns);
populateAffineToStdConversionPatterns(patterns);

// MemRef to Util (to VM) is an A->B->C lowering. We must instruct it
// specifically on what the correct C buffer type is.
auto utilBufferType =
typeConverter.convertType(IREE::Util::BufferType::get(&getContext()));
assert(utilBufferType);
populateMemRefToUtilPatterns(context, conversionTarget, typeConverter,
patterns, utilBufferType);

conversionTarget
.addIllegalDialect<func::FuncDialect, mlir::arith::ArithmeticDialect>();
conversionTarget.addIllegalDialect<AffineDialect>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Dialect/VMVX/IR",
"//compiler/src/iree/compiler/Dialect/VMVX/IR:VMVXDialect",
"//compiler/src/iree/compiler/Utils",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_cc_library(
iree::compiler::Dialect::Util::IR
iree::compiler::Dialect::VMVX::IR
iree::compiler::Dialect::VMVX::IR::VMVXDialect
iree::compiler::Utils
PUBLIC
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "iree/compiler/Dialect/VMVX/IR/VMVXDialect.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXOps.h"
#include "iree/compiler/Dialect/VMVX/IR/VMVXTypes.h"
#include "iree/compiler/Utils/IndexSet.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -71,16 +72,13 @@ LogicalResult updateHALToVMVXEntryFuncOp(func::FuncOp funcOp,
return funcOp.emitError() << "exported functions must have no I/O";
}

auto i8Type = IntegerType::get(funcOp.getContext(), 8);
auto i32Type = IntegerType::get(funcOp.getContext(), 32);
auto memRefI8Type = MemRefType::get({-1}, i8Type);
auto memRefI32Type = MemRefType::get({-1}, i32Type);
auto bindingsType = IREE::Util::ListType::get(memRefI8Type);
auto bufferType = IREE::Util::BufferType::get(funcOp.getContext());
auto bindingsType = IREE::Util::ListType::get(bufferType); // of i8
auto indexType = IndexType::get(funcOp.getContext());
auto newType = FunctionType::get(funcOp.getContext(),
{
/*local_memory=*/memRefI8Type,
/*constants=*/memRefI32Type,
/*local_memory=*/bufferType, // of i8
/*constants=*/bufferType, // of i32
/*bindings=*/bindingsType,
/*workgroup_id_x=*/indexType,
/*workgroup_id_y=*/indexType,
Expand Down Expand Up @@ -177,17 +175,18 @@ struct ConvertHALInterfaceConstantLoadOp
auto constantsArg = op->getParentOfType<mlir::func::FuncOp>().getArgument(
kEntryArgConstants);
assert(constantsArg && "entry point not conforming to requirements");
auto constantType =
constantsArg.getType().cast<MemRefType>().getElementType();

auto constantsSize =
rewriter.create<IREE::Util::BufferSizeOp>(op.getLoc(), constantsArg);
auto resultType = getTypeConverter()->convertType(op.getResult().getType());

auto constantIndex = rewriter.createOrFold<arith::ConstantIndexOp>(
op.getLoc(), op.getIndex().getZExtValue());
auto loadedValue = rewriter.createOrFold<memref::LoadOp>(
op.getLoc(), constantType, constantsArg, ValueRange{constantIndex});
rewriter.replaceOpWithNewOp<mlir::UnrealizedConversionCastOp>(
op, resultType, loadedValue);
auto elementSize =
rewriter.createOrFold<IREE::Util::SizeOfOp>(op.getLoc(), resultType);
auto byteOffset = rewriter.createOrFold<arith::MulIOp>(
op.getLoc(), elementSize, constantIndex);
rewriter.replaceOpWithNewOp<IREE::Util::BufferLoadOp>(
op, resultType, constantsArg, constantsSize, byteOffset);
return success();
}
};
Expand All @@ -211,39 +210,46 @@ struct ConvertHALInterfaceBindingSubspanOp
return op.emitOpError() << "sparse binding sets not yet implemented";
}

IndexSet indexSet(op.getLoc(), rewriter);
auto bindingType =
bindingsArg.getType().cast<IREE::Util::ListType>().getElementType();
auto memrefValue = rewriter
.create<IREE::Util::ListGetOp>(
op.getLoc(), bindingType, bindingsArg,
rewriter.createOrFold<arith::ConstantIndexOp>(
op.getLoc(), op.getBinding().getZExtValue()))
.getResult();
auto sourceBuffer =
rewriter
.create<IREE::Util::ListGetOp>(
op.getLoc(), bindingType, bindingsArg,
rewriter.createOrFold<arith::ConstantIndexOp>(
op.getLoc(), op.getBinding().getZExtValue()))
.getResult();

if (op.getByteOffset() && !matchPattern(op.getByteOffset(), m_Zero())) {
auto memrefType = op.getResult().getType().cast<MemRefType>();
Value elementCount;
if (memrefType.isDynamicDim(0)) {
elementCount = op.getDynamicDims().front();
} else {
elementCount = rewriter.createOrFold<arith::ConstantIndexOp>(
op.getLoc(), memrefType.getDimSize(0));
// Offsetted binding: replace with a BufferSpan.
Value sourceSize = rewriter.createOrFold<IREE::Util::BufferSizeOp>(
op.getLoc(), sourceBuffer);

// Compute the dest size by multiplying the element size by all extents
// (static and dynamic).
auto memRefType = op.getResult().getType().cast<MemRefType>();
Value destSize = rewriter.createOrFold<IREE::Util::SizeOfOp>(
op.getLoc(), memRefType.getElementType());
auto dynamicExtentIt = adaptor.getDynamicDims().begin();
for (int i = 0; i < memRefType.getRank(); ++i) {
Value extent;
if (memRefType.isDynamicDim(i)) {
extent = *dynamicExtentIt;
dynamicExtentIt++;
} else {
extent = indexSet.get(memRefType.getDimSize(i));
}
destSize =
rewriter.createOrFold<arith::MulIOp>(op.getLoc(), destSize, extent);
}
auto byteLength = rewriter.createOrFold<arith::MulIOp>(
op.getLoc(),
rewriter.createOrFold<arith::ConstantIndexOp>(
op.getLoc(), memrefType.getElementTypeBitWidth()),
elementCount);
memrefValue = rewriter.createOrFold<memref::SubViewOp>(
op.getLoc(), memrefValue, ArrayRef<OpFoldResult>{op.getByteOffset()},
ArrayRef<OpFoldResult>{byteLength},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});

rewriter.replaceOpWithNewOp<IREE::Util::BufferSubspanOp>(
op, sourceBuffer, sourceSize, adaptor.getByteOffset(), destSize);
} else {
// Zero offset. Just return the source buffer.
rewriter.replaceOp(op, sourceBuffer);
}
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op,
getTypeConverter()
->convertType(op.getResult().getType())
.cast<MemRefType>(),
memrefValue);
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// RUN: iree-opt --split-input-file --iree-vmvx-conversion --canonicalize %s | FileCheck %s

// CHECK: memref.global "private" constant @__constant_5xi32 : memref<5xi32> = dense<[1, 2, 3, 4, 5]>
// CHECK: util.global private @__constant_5xi32 : !util.buffer
// CHECK: util.initializer {
// CHECK: %[[CST:.*]] = util.buffer.constant : !util.buffer = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
// CHECK: util.global.store %[[CST]], @__constant_5xi32
memref.global "private" constant @__constant_5xi32 : memref<5xi32> = dense<[1, 2, 3, 4, 5]>

// CHECK-LABEL: func.func @entry(
// CHECK-SAME: %[[SCRATCHPAD:.+]]: memref<?xi8>,
// CHECK-SAME: %[[CONSTANTS:.+]]: memref<?xi32>,
// CHECK-SAME: %[[BINDINGS:.+]]: !util.list<memref<?xi8>>,
// CHECK-SAME: %[[SCRATCHPAD:[a-z0-9]+]]: !util.buffer,
// CHECK-SAME: %[[CONSTANTS:[a-z0-9]+]]: !util.buffer,
// CHECK-SAME: %[[BINDINGS:[a-z0-9]+]]: !util.list<!util.buffer>,
// CHECK-SAME: %[[WORKGROUP_X:[a-z0-9]+]]: index,
// CHECK-SAME: %[[WORKGROUP_Y:[a-z0-9]+]]: index,
// CHECK-SAME: %[[WORKGROUP_Z:[a-z0-9]+]]: index,
Expand All @@ -22,11 +25,9 @@ func.func @entry() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.get_global @__constant_5xi32 : memref<5xi32>
// CHECK: %[[BINDING0_RAW:.+]] = util.list.get %[[BINDINGS]][%c0] : !util.list<memref<?xi8>>
// CHECK-NEXT: %[[BINDING0:.+]] = builtin.unrealized_conversion_cast %[[BINDING0_RAW]] : memref<?xi8> to memref<5xf32>
// CHECK: %[[BINDING0:.+]] = util.list.get %[[BINDINGS]][%c0] : !util.list<!util.buffer>
%1 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) : memref<5xf32>
// CHECK: %[[BINDING1_RAW:.+]] = util.list.get %[[BINDINGS]][%c1] : !util.list<memref<?xi8>>
// CHECK-NEXT: %[[BINDING1:.+]] = builtin.unrealized_conversion_cast %[[BINDING1_RAW]] : memref<?xi8> to memref<5xi32>
// CHECK: %[[BINDING1:.+]] = util.list.get %[[BINDINGS]][%c1] : !util.list<!util.buffer>
%2 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) : memref<5xi32>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
#include "iree/compiler/Dialect/Util/Conversion/MemRefToUtil/ConvertMemRefToUtil.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/VM/IR/VMDialect.h"
#include "iree/compiler/Dialect/VMVX/Conversion/HALToVMVX/ConvertHALToVMVX.h"
Expand Down Expand Up @@ -68,11 +69,14 @@ class ConversionPass : public ConversionBase<ConversionPass> {
mlir::arith::ArithmeticDialect>();
conversionTarget.addLegalDialect<mlir::AffineDialect>();
conversionTarget.addLegalDialect<memref::MemRefDialect>();
conversionTarget.addLegalOp<mlir::UnrealizedConversionCastOp>();
conversionTarget.addIllegalOp<mlir::UnrealizedConversionCastOp>();

RewritePatternSet patterns(&getContext());
populateHALToVMVXPatterns(context, patterns, typeConverter);
populateStandardToVMVXPatterns(context, patterns, typeConverter);
populateMemRefToUtilPatterns(context, conversionTarget, typeConverter,
patterns,
IREE::Util::BufferType::get(&getContext()));

// Use the default 64-bit lowering for TOSA's ApplyScale operator:
// This lowering widens integer types to 64-bit an performs the non-fused
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/iree/compiler/Dialect/VMVX/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ void buildVMVXTransformPassPipeline(OpPassManager &passManager) {
// ---------------------------------------------------------------------------

passManager.addNestedPass<mlir::ModuleOp>(createConversionPass());
passManager.nest<mlir::ModuleOp>().addNestedPass<func::FuncOp>(
memref::createFoldSubViewOpsPass());
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createCSEPass());

Expand Down

0 comments on commit b9a39d7

Please sign in to comment.