Skip to content

Commit

Permalink
Decouple Transform dialect usage in IREE from iree-dialects. (iree-or…
Browse files Browse the repository at this point in the history
…g#9745)

This revision creates a transform dialect interpreter pass in IREE with
the proper dialect registrations to allow end-to-end examples from both
iree-run-mlir and iree-opt.

In the future, when the layering is right, only a single interpreter will
be needed for both codegen and non-codegen rewrites, which will allow
retiring the specialized interpreter that is used for dispatch region
creation with the transform dialect.

For now, the iree-dialects interpreter remain as a way to separate concerns
between patterns and transform ops that are IREE-specific from one that
will be upstreamed in the fullness of time.

The GPU-specific transforms are relaxed to allow targeting either hal.executable or hal.executable.variant
which lets them apply with either an iree-run-mlir or iree-opt flow.
  • Loading branch information
nicolasvasilache authored Jul 12, 2022
1 parent a9c91cd commit f6b6335
Show file tree
Hide file tree
Showing 12 changed files with 323 additions and 21 deletions.
63 changes: 63 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,68 @@ iree_gentbl_cc_library(
],
)

# TODO: If the layering causes concerns then the transform dialect interpreter
# should be one level above everything: it is a mechanism by which
# transformations are applied to any IR and needs to register all the dialects
# that may be produced.
# In particular, a single IREE-side transform interpreter is enough to perform
# all kind of transformations and not just codegen.
# This is an opportunity to retire the specific interpreter that is used for
# creating dispatch regions with the transform dialect, but only once the
# layering is correct.
iree_compiler_cc_library(
name = "TransformDialectInterpreterPass",
srcs = [
"TransformDialectInterpreterPass.cpp",
],
deps = [
# Dialects
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransformOps",
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:AsyncDialect",
"@llvm-project//mlir:ArithmeticDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:BufferizationTransforms",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:PDLDialect",
"@llvm-project//mlir:PDLInterpDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:VectorDialect",
# IR
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
# Interfaces
# Transforms (needed mostly for the BufferizableOpInterfaceImpl)
"@llvm-project//mlir:ArithmeticTransforms",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:SCFTransforms",
"@llvm-project//mlir:TensorTransforms",
"@llvm-project//mlir:VectorTransforms",
# Other Stuff
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Support",
# TransformExtensions
"//compiler/src/iree/compiler/Codegen/Common/TransformExtensions:CommonExtensions",
"//compiler/src/iree/compiler/Dialect/Flow/TransformExtensions:FlowExtensions",
"//compiler/src/iree/compiler/Codegen/LLVMCPU/TransformExtensions:LLVMCPUExtensions",
"//compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions:LLVMGPUExtensions",
"@llvm-project//mlir:LinalgTransformOps",
],
)

iree_compiler_cc_library(
name = "Common",
srcs = [
Expand Down Expand Up @@ -66,6 +128,7 @@ iree_compiler_cc_library(
"Transforms.h",
],
deps = [
":TransformDialectInterpreterPass",
"//compiler/src/iree/compiler/Codegen:PassHeaders",
"//compiler/src/iree/compiler/Codegen/Common:FoldTensorExtractOpIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect:IREECodegenDialect",
Expand Down
48 changes: 48 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,53 @@ iree_tablegen_library(
--gen-rewriters FoldTensorExtractOp.cpp.inc
)

iree_cc_library(
NAME
TransformDialectInterpreterPass
SRCS
"TransformDialectInterpreterPass.cpp"
DEPS
IREELinalgExtDialect
IREELinalgExtTransformOps
IREELinalgTransformDialect
LLVMSupport
MLIRAffineDialect
MLIRAffineUtils
MLIRArithmeticDialect
MLIRArithmeticTransforms
MLIRAsyncDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRFuncDialect
MLIRGPUOps
MLIRIR
MLIRLLVMDialect
MLIRLinalgDialect
MLIRLinalgTransformOps
MLIRLinalgTransforms
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRParser
MLIRPass
MLIRRewrite
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
MLIRSupport
MLIRTensorDialect
MLIRTensorTransforms
MLIRTransformDialect
MLIRVectorDialect
MLIRVectorTransforms
iree::compiler::Codegen::Common::TransformExtensions::CommonExtensions
iree::compiler::Codegen::LLVMCPU::TransformExtensions::LLVMCPUExtensions
iree::compiler::Codegen::LLVMGPU::TransformExtensions::LLVMGPUExtensions
iree::compiler::Codegen::PassHeaders
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::Flow::TransformExtensions::FlowExtensions
PUBLIC
)

iree_cc_library(
NAME
Common
Expand Down Expand Up @@ -52,6 +99,7 @@ iree_cc_library(
"VectorizeMMT4d.cpp"
"WorkGroupSwizzle.cpp"
DEPS
::TransformDialectInterpreterPass
IREELinalgExtDialect
IREELinalgExtPasses
IREELinalgExtTransforms
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h"
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h"
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h"
#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/FileUtilities.h"

#define DEBUG_TYPE "iree-transform-dialect-interpreter"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;

namespace {

/// Pass declaration.
/// Interpreter pass that applies transform dialect ops for codegen.
/// This needs to be its own pass because the registration mechanism and ops
/// available are different than for other interpreters.
class TransformDialectInterpreterPass
: public iree_compiler::TransformDialectInterpreterBase<
TransformDialectInterpreterPass> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
// TODO: this is only necessary to make registry subset happy when running
// the lowering to LLVM. The lowering should be changed to stop using the
// nested pass manager and this will go away.

// clang-format off
registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect,
mlir::iree_compiler::IREE::Flow::FlowDialect,
arith::ArithmeticDialect,
AffineDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
gpu::GPUDialect,
linalg::LinalgDialect,
linalg::transform::LinalgTransformDialect,
LLVM::LLVMDialect,
pdl::PDLDialect,
pdl_interp::PDLInterpDialect,
scf::SCFDialect,
tensor::TensorDialect,
transform::TransformDialect,
vector::VectorDialect
// clang-format on
>();

// TODO: these should be registered by the extension instead, but there is
// no support for it in core currently.
arith::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
tensor::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);

registry.addExtensions<
mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension,
transform_ext::StructuredTransformOpsExtension>();
iree_compiler::registerTransformDialectCommonExtension(registry);
iree_compiler::registerTransformDialectFlowExtension(registry);
iree_compiler::registerTransformDialectLLVMCPUExtension(registry);
iree_compiler::registerTransformDialectLLVMGPUExtension(registry);
linalg::registerTransformDialectExtension(registry);
}

TransformDialectInterpreterPass(StringRef transformFileName = StringRef()) {
this->transformFileName = transformFileName.str();
}
TransformDialectInterpreterPass(const TransformDialectInterpreterPass &pass) {
this->transformFileName = pass.transformFileName;
// TODO: if we really don't like shared_ptr, we could also clone the
// transformModule here.
sharedTransformModule = pass.sharedTransformModule;
}

LogicalResult initialize(MLIRContext *context) override {
OwningOpRef<ModuleOp> module;
if (failed(transform::parseTransformModuleFromFile(
context, transformFileName, module)))
return failure();
sharedTransformModule =
std::make_shared<OwningOpRef<ModuleOp>>(std::move(module));
return success();
}

void runOnOperation() override {
Operation *target = getOperation();
bool parsedTransform = (sharedTransformModule && *sharedTransformModule);
assert(parsedTransform || (target->getNumRegions() == 1 &&
target->getRegion(0).getBlocks().size() == 1) &&
"Cannot extract transform from op");
Region &transformRegion = parsedTransform
? (*sharedTransformModule)->getRegion()
: target->getRegion(0);
if (failed(transform::applyTransformsInRegion(transformRegion, target))) {
target->emitOpError() << "transform dialect interpreter failed";
return signalPassFailure();
}
}

private:
// The parsed transform module to be used for transformations.
// TODO: Figure a better way to build a transform module and transport it in
// the proper places in the IR as it is transformed by IREE so that it is
// available with better ownership semantics.
// Note: we wrap the OwningOpRef to get the desired destruction mechanism.
// Note: shared_ptr is not great but we know the sharedTransformModule is
// readonly.
// Alternatives comprise:
// 1. no shared_ptr but copying the module with every pass clone that the
// OpPassManager decides to perform.
// 2. lifting ownership of the parsed transform module higher up in the
// IREE stack. This may be only shift the problem as we have passes
// building pass managers in IREE.
// 3. build better support to embed the transformation module in the
// input IR and transport it to the place of use in IREE. This is deemed
// too intrusive atm.
// 4. (future) config/resources mechanism that is being proposed in core?
std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule;
};
} // namespace

namespace mlir {
namespace iree_compiler {
/// Create a Transform dialect interpreter pass.
std::unique_ptr<Pass> createTransformDialectInterpreterPass(
llvm::StringRef transformFileName) {
return std::make_unique<TransformDialectInterpreterPass>(transformFileName);
}
} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s -transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s

// CHECK-LABEL: @select_cmp_eq_select
// CHECK: return %arg1
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,9 @@ void addCPUDefaultPassPipeline(OpPassManager &passManager) {

void addTransformDialectInterpreterPasses(OpPassManager &passManager) {
// Give control to the transform dialect.
passManager.addPass(createTransformDialectInterpreterPass(
clCPUCodegenTransformDialectFileName));
passManager.addPass(
mlir::iree_compiler::createTransformDialectInterpreterPass(
clCPUCodegenTransformDialectFileName));

// Dropping the schedule is only needed if we want to embed the transform in
// the module: we should drop the schedule once applied.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s -transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s

#device_target_cpu = #hal.device.target<"cpu", {executable_targets = [#hal.executable.target<"llvm", "embedded-elf-x86_64", {cpu_features = "", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", native_vector_size = 16 : index, target_triple = "x86_64-unknown-unknown-eabi-elf"}>]}>
#executable_layout = #hal.executable.layout<push_constants = 0, sets = [#hal.descriptor_set.layout<0, bindings = [#hal.descriptor_set.binding<0, storage_buffer>, #hal.descriptor_set.binding<1, storage_buffer>, #hal.descriptor_set.binding<2, storage_buffer>]>]>
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ extern llvm::cl::opt<std::string> clGPUCodegenTransformDialectFileName;

void addGPUTransformDialectInterpreterPasses(OpPassManager &passManager) {
// Give control to the transform dialect.
passManager.addPass(createTransformDialectInterpreterPass(
clGPUCodegenTransformDialectFileName));
passManager.addPass(
mlir::iree_compiler::createTransformDialectInterpreterPass(
clGPUCodegenTransformDialectFileName));

// Dropping the schedule is only needed if we want to embed the transform in
// the module: we should drop the schedule once applied.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ DiagnosedSilenceableFailure
transform_dialect::ForeachThreadToGpuAndTranslationInfo::applyToOne(
func::FuncOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableVariantOp>(state.getTopLevel())) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
"requires HAL::ExecutableVariantOp toplevel");
"requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel to "
"attach the workgroup size information to a nested ExecutableExportOp");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}

Expand Down Expand Up @@ -353,11 +354,12 @@ DiagnosedSilenceableFailure
transform_dialect::VectorWarpExecuteOnLane0Op::applyToOne(
scf::IfOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
if (!isa<HAL::ExecutableVariantOp>(state.getTopLevel())) {
if (!isa<HAL::ExecutableOp, HAL::ExecutableVariantOp>(state.getTopLevel())) {
state.getTopLevel()->emitOpError(
"requires HAL::ExecutableVariantOp toplevel so that IR is properly "
"isolated. This is required so we can safely inspect the "
"HAL::ExecutableExportOp under multi-threaded pass assumptions.");
"requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel so "
"that "
"IR is properly isolated. This is required so we can safely inspect "
"the HAL::ExecutableExportOp under multi-threaded pass assumptions.");
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s -transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s
// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s

func.func @pad_matmul_static_dispatch_0(%arg0: tensor<250x500xf32>, %arg1: tensor<500x1020xf32>) -> tensor<250x1020xf32> {
%c0 = arith.constant 0 : index
Expand Down
Loading

0 comments on commit f6b6335

Please sign in to comment.