Convert func->util as part of input conversion. (iree-org#16411)
We now use `util.func` in place of `func.func` in all host code in the
compiler. flow/stream/hal executables continue to use `func.func` as
before, both for compatibility with upstream code and because the util
ops offer fewer benefits there. Most code is still written against the
function/callable/call op interfaces so that we support initializers and
other future function types we may add. All tests have been updated to
use `util.func` for consistency even where the passes still work with
`func.func`.
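
As a minimal before/after sketch of what this means for host IR (the
function name and shapes here are illustrative, not taken from this
commit):

// Before: host functions used the upstream func dialect.
func.func @add(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = arith.addi %arg0, %arg1 : tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}

// After: the same function as a util.func, which uses util.return as
// its terminator.
util.func public @add(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = arith.addi %arg0, %arg1 : tensor<2x2xi32>
  util.return %0 : tensor<2x2xi32>
}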

There are a few TODOs around better supporting tied function operands in
IPO and other passes, but since we don't currently ever produce functions
with tied operands those passes are hacked to bail when they encounter
them (IPO doesn't act on functions/calls with tied operands, etc.).
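
For context, a result is "tied" when it aliases one of the function's
operands and may reuse its storage in place. A hedged sketch of how that
might look on a util.func declaration, assuming the array attribute read
via getTiedOperandsAttr() in the diff below prints as `tied_operands`
(the attribute name and printed form are assumptions, not confirmed by
this commit; a kUntiedIndex entry marks an untied result):

// Result 0 is tied to (aliases) operand 0.
util.func private @scale_in_place(tensor<4xf32>) -> tensor<4xf32>
    attributes {tied_operands = [0 : index]}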
benvanik authored Feb 15, 2024
1 parent 1ee6007 commit 045bca1
Showing 320 changed files with 6,186 additions and 5,947 deletions.
@@ -2,7 +2,7 @@

// Check that the auto input conversion pipeline uses this plugin.

-// CHECK-LABEL: func.func @simple_add_stablehlo
+// CHECK-LABEL: util.func public @simple_add_stablehlo
// CHECK: arith.addi
func.func @simple_add_stablehlo(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@@ -2,7 +2,7 @@

// Check that the auto input conversion pipeline uses this plugin.

-// CHECK-LABEL: func.func @simple_add_tosa
+// CHECK-LABEL: util.func public @simple_add_tosa
// CHECK: arith.addi
func.func @simple_add_tosa(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = tosa.add %arg0, %arg1 : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@@ -2,7 +2,7 @@

// Check that the auto input conversion pipeline uses this plugin.

-// CHECK-LABEL: func.func @simple_add_torch
+// CHECK-LABEL: util.func public @simple_add_torch
// CHECK: arith.addf
func.func @simple_add_torch(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> {
%int1 = torch.constant.int 1
@@ -12,7 +12,7 @@ func.func @simple_add_torch

// -----

-// CHECK-LABEL: func.func @simple_add_onnx
+// CHECK-LABEL: util.func public @simple_add_onnx
// CHECK: arith.addi
func.func @simple_add_onnx(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
%0 = torch.operator "onnx.Add"(%arg0, %arg1) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64>
@@ -29,7 +29,6 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -22,7 +22,6 @@ iree_cc_library(
DEPS
LLVMSupport
MLIRAffineUtils
-MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
MLIRPass
@@ -7,7 +7,8 @@
#include "iree/compiler/Bindings/Native/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
@@ -80,7 +81,7 @@ verifyResultDimsFunc(FunctionType functionType, int requiredResultDims,
// Converts a func.func with the iree.abi.streamable attribute into a flow.func
// and fixes all func.call ops to be flow.call across the module.
static std::optional<StreamableFunc>
-convertStreamableFunc(mlir::ModuleOp moduleOp, func::FuncOp funcOp,
+convertStreamableFunc(mlir::ModuleOp moduleOp, IREE::Util::FuncOp funcOp,
SymbolTable &symbolTable) {
OpBuilder moduleBuilder(funcOp);
auto functionType = funcOp.getFunctionType();
@@ -137,8 +138,18 @@ convertStreamableFunc(mlir::ModuleOp moduleOp, func::FuncOp funcOp,
}
}

+bool anyTiedOperands = false;
streamableFunc.tiedOperands.resize(functionType.getNumResults(),
IREE::Util::TiedOpInterface::kUntiedIndex);
+if (auto tiedOperandsAttr = funcOp.getTiedOperandsAttr()) {
+for (auto [resultIndex, tiedAttr] : llvm::enumerate(
+funcOp.getTiedOperandsAttr().getAsRange<IntegerAttr>())) {
+if (tiedAttr.getInt() != IREE::Util::TiedOpInterface::kUntiedIndex) {
+streamableFunc.tiedOperands[resultIndex] = tiedAttr.getInt();
+anyTiedOperands = true;
+}
+}
+}
SmallVector<DictionaryAttr> funcResAttrs;
for (auto [i, resultType] : llvm::enumerate(functionType.getResults())) {
// Tensor results need to have their dynamic dimensions specified.
@@ -157,8 +168,8 @@ convertStreamableFunc(mlir::ModuleOp moduleOp, func::FuncOp funcOp,
if (auto oldResAttrs = funcOp.getResultAttrDict(i)) {
// First check if the result is tied to an argument.
// We can use this to source the initial set of dynamic dimensions.
-if (auto tiedAttr = oldResAttrs.getAs<IntegerAttr>("iree.abi.tied")) {
-streamableFunc.tiedOperands[i] = tiedAttr.getInt();
+int64_t tiedIndex = streamableFunc.tiedOperands[i];
+if (tiedIndex != IREE::Util::TiedOpInterface::kUntiedIndex) {
if (!streamableFunc.resultDimsFunc &&
shapedType == functionType.getInput(i)) {
// Tied types match and we can infer the shape from that. This may
@@ -195,8 +206,7 @@ convertStreamableFunc(mlir::ModuleOp moduleOp, func::FuncOp funcOp,

// Pass-through all other attrs we don't care about.
for (auto resAttr : oldResAttrs) {
-if (resAttr.getName() == "iree.abi.tied" ||
-resAttr.getName() == "iree.abi.dims") {
+if (resAttr.getName() == "iree.abi.dims") {
continue;
}
newResAttrs.push_back(resAttr);
@@ -221,10 +231,13 @@
}

// Create the new streamable flow.func op at the same place as the original.
+auto tiedOperandsAttr =
+anyTiedOperands
+? moduleBuilder.getIndexArrayAttr(streamableFunc.tiedOperands)
+: ArrayAttr{};
streamableFunc.funcOp = moduleBuilder.create<IREE::Flow::FuncOp>(
-funcOp.getLoc(), funcOp.getName(), functionType,
-moduleBuilder.getIndexArrayAttr(streamableFunc.tiedOperands), funcAttrs,
-funcArgAttrs, funcResAttrs);
+funcOp.getLoc(), funcOp.getName(), functionType, tiedOperandsAttr,
+funcAttrs, funcArgAttrs, funcResAttrs);

// Swap out the symbol in the symbol table.
symbolTable.erase(funcOp);
@@ -234,7 +247,7 @@
}

static LogicalResult convertStreamableCall(StreamableFunc &streamableFunc,
-func::CallOp callOp) {
+IREE::Util::CallOp callOp) {
OpBuilder builder(callOp);

// Capture all argument dynamic dimensions.
@@ -253,9 +266,10 @@
// It should return the required number of dynamic dimensions.
SmallVector<Type> resultDimTypes(streamableFunc.requiredResultDims,
builder.getIndexType());
-auto calculateCallOp = builder.create<func::CallOp>(
-callOp.getLoc(), streamableFunc.resultDimsFunc, resultDimTypes,
-callOp.getOperands());
+auto calculateCallOp = builder.create<IREE::Util::CallOp>(
+callOp.getLoc(), resultDimTypes,
+streamableFunc.resultDimsFunc.getLeafReference().getValue(),
+callOp.getOperands(), ArrayAttr{});
llvm::append_range(resultDims, calculateCallOp.getResults());
} else {
// Get the shape dimensions from existing call arguments or tied operands.
@@ -301,7 +315,7 @@
static LogicalResult
convertStreamableCalls(mlir::ModuleOp moduleOp,
DenseMap<StringRef, StreamableFunc> &streamableFuncs) {
-auto walkResult = moduleOp.walk([&](func::CallOp callOp) {
+auto walkResult = moduleOp.walk([&](IREE::Util::CallOp callOp) {
auto it = streamableFuncs.find(callOp.getCallee());
if (it != streamableFuncs.end()) {
if (failed(convertStreamableCall(it->second, callOp))) {
@@ -320,8 +334,8 @@ class ConvertStreamableOpsPass
ConvertStreamableOpsPass(const ConvertStreamableOpsPass &pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<func::FuncDialect, mlir::tensor::TensorDialect,
-IREE::Flow::FlowDialect>();
+registry.insert<mlir::tensor::TensorDialect, IREE::Flow::FlowDialect,
+IREE::Util::UtilDialect>();
}

StringRef getArgument() const override {
@@ -337,8 +351,8 @@
auto moduleOp = getOperation();

// Gather functions that need wrapping.
-SmallVector<func::FuncOp> originalFuncOps;
-for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+SmallVector<IREE::Util::FuncOp> originalFuncOps;
+for (auto funcOp : moduleOp.getOps<IREE::Util::FuncOp>()) {
// Ignore functions already marked as having their ABI goo handled.
if (funcOp->hasAttr("iree.abi.streamable")) {
if (!funcOp.isExternal()) {
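To make the streamable conversion above concrete, a hedged before/after
sketch based on the pass comment (a function marked iree.abi.streamable
becomes a flow.func and its call sites become flow.call); the exact
flow.func/flow.call assembly below is abbreviated and illustrative:

// Before: a streamable external function declared against the util dialect.
util.func private @external_fn(tensor<2x2xi32>) -> tensor<2x2xi32>
    attributes {iree.abi.streamable}
util.func public @caller(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = util.call @external_fn(%arg0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
  util.return %0 : tensor<2x2xi32>
}

// After: the declaration is a flow.func and the call site is a flow.call.
flow.func private @external_fn(tensor<2x2xi32>) -> tensor<2x2xi32>
util.func public @caller(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = flow.call @external_fn(%arg0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
  util.return %0 : tensor<2x2xi32>
}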
@@ -10,15 +10,14 @@

#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Utils/PassUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassOptions.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler::IREE::ABI {

using FunctionLikeNest =
-MultiOpNest<func::FuncOp, IREE::Util::InitializerOp, IREE::Util::FuncOp>;
+MultiOpNest<IREE::Util::InitializerOp, IREE::Util::FuncOp>;

void buildTransformPassPipeline(OpPassManager &passManager,
const InvocationOptions &invocationOptions) {