forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Decouple Transform dialect usage in IREE from iree-dialects. (iree-or…
…g#9745) This revision creates a transform dialect interpreter pass in IREE with the proper dialect registrations to allow end-to-end examples from both iree-run-mlir and iree-opt. In the future, when the layering is right, only a single interpreter will be needed for both codegen and non-codegen rewrites, which will allow retiring the specialized interpreter that is used for dispatch region creation with the transform dialect. For now, the iree-dialects interpreter remain as a way to separate concerns between patterns and transform ops that are IREE-specific from one that will be upstreamed in the fullness of time. The GPU-specific transforms are relaxed to allow targeting either hal.executable or hal.executable.variant which lets them apply with either an iree-run-mlir or iree-opt flow.
- Loading branch information
1 parent
a9c91cd
commit f6b6335
Showing
12 changed files
with
323 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
173 changes: 173 additions & 0 deletions
173
compiler/src/iree/compiler/Codegen/Common/TransformDialectInterpreterPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// Copyright 2021 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" | ||
#include "iree-dialects/Dialect/LinalgExt/TransformOps/LinalgExtTransformOps.h" | ||
#include "iree-dialects/Dialect/LinalgTransform/LinalgTransformOps.h" | ||
#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h" | ||
#include "iree-dialects/Dialect/LinalgTransform/TransformInterpreterUtils.h" | ||
#include "iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.h" | ||
#include "iree/compiler/Codegen/LLVMCPU/TransformExtensions/LLVMCPUExtensions.h" | ||
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h" | ||
#include "iree/compiler/Codegen/PassDetail.h" | ||
#include "iree/compiler/Codegen/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h" | ||
#include "iree/compiler/Dialect/Flow/TransformExtensions/FlowExtensions.h" | ||
#include "llvm/ADT/STLExtras.h" | ||
#include "llvm/ADT/ScopeExit.h" | ||
#include "llvm/Support/SourceMgr.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | ||
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" | ||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/PDL/IR/PDL.h" | ||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Dialect/Transform/IR/TransformDialect.h" | ||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "mlir/Parser/Parser.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassRegistry.h" | ||
#include "mlir/Support/FileUtilities.h" | ||
|
||
#define DEBUG_TYPE "iree-transform-dialect-interpreter" | ||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
/// Pass declaration. | ||
/// Interpreter pass that applies transform dialect ops for codegen. | ||
/// This needs to be its own pass because the registration mechanism and ops | ||
/// available are different than for other interpreters. | ||
class TransformDialectInterpreterPass | ||
: public iree_compiler::TransformDialectInterpreterBase< | ||
TransformDialectInterpreterPass> { | ||
public: | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
// TODO: this is only necessary to make registry subset happy when running | ||
// the lowering to LLVM. The lowering should be changed to stop using the | ||
// nested pass manager and this will go away. | ||
|
||
// clang-format off | ||
registry.insert<mlir::iree_compiler::IREE::LinalgExt::IREELinalgExtDialect, | ||
mlir::iree_compiler::IREE::Flow::FlowDialect, | ||
arith::ArithmeticDialect, | ||
AffineDialect, | ||
bufferization::BufferizationDialect, | ||
func::FuncDialect, | ||
gpu::GPUDialect, | ||
linalg::LinalgDialect, | ||
linalg::transform::LinalgTransformDialect, | ||
LLVM::LLVMDialect, | ||
pdl::PDLDialect, | ||
pdl_interp::PDLInterpDialect, | ||
scf::SCFDialect, | ||
tensor::TensorDialect, | ||
transform::TransformDialect, | ||
vector::VectorDialect | ||
// clang-format on | ||
>(); | ||
|
||
// TODO: these should be registered by the extension instead, but there is | ||
// no support for it in core currently. | ||
arith::registerBufferizableOpInterfaceExternalModels(registry); | ||
linalg::registerBufferizableOpInterfaceExternalModels(registry); | ||
scf::registerBufferizableOpInterfaceExternalModels(registry); | ||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( | ||
registry); | ||
tensor::registerBufferizableOpInterfaceExternalModels(registry); | ||
vector::registerBufferizableOpInterfaceExternalModels(registry); | ||
|
||
registry.addExtensions< | ||
mlir::iree_compiler::IREE::LinalgExt::LinalgExtTransformOpsExtension, | ||
transform_ext::StructuredTransformOpsExtension>(); | ||
iree_compiler::registerTransformDialectCommonExtension(registry); | ||
iree_compiler::registerTransformDialectFlowExtension(registry); | ||
iree_compiler::registerTransformDialectLLVMCPUExtension(registry); | ||
iree_compiler::registerTransformDialectLLVMGPUExtension(registry); | ||
linalg::registerTransformDialectExtension(registry); | ||
} | ||
|
||
TransformDialectInterpreterPass(StringRef transformFileName = StringRef()) { | ||
this->transformFileName = transformFileName.str(); | ||
} | ||
TransformDialectInterpreterPass(const TransformDialectInterpreterPass &pass) { | ||
this->transformFileName = pass.transformFileName; | ||
// TODO: if we really don't like shared_ptr, we could also clone the | ||
// transformModule here. | ||
sharedTransformModule = pass.sharedTransformModule; | ||
} | ||
|
||
LogicalResult initialize(MLIRContext *context) override { | ||
OwningOpRef<ModuleOp> module; | ||
if (failed(transform::parseTransformModuleFromFile( | ||
context, transformFileName, module))) | ||
return failure(); | ||
sharedTransformModule = | ||
std::make_shared<OwningOpRef<ModuleOp>>(std::move(module)); | ||
return success(); | ||
} | ||
|
||
void runOnOperation() override { | ||
Operation *target = getOperation(); | ||
bool parsedTransform = (sharedTransformModule && *sharedTransformModule); | ||
assert(parsedTransform || (target->getNumRegions() == 1 && | ||
target->getRegion(0).getBlocks().size() == 1) && | ||
"Cannot extract transform from op"); | ||
Region &transformRegion = parsedTransform | ||
? (*sharedTransformModule)->getRegion() | ||
: target->getRegion(0); | ||
if (failed(transform::applyTransformsInRegion(transformRegion, target))) { | ||
target->emitOpError() << "transform dialect interpreter failed"; | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
private: | ||
// The parsed transform module to be used for transformations. | ||
// TODO: Figure a better way to build a transform module and transport it in | ||
// the proper places in the IR as it is transformed by IREE so that it is | ||
// available with better ownership semantics. | ||
// Note: we wrap the OwningOpRef to get the desired destruction mechanism. | ||
// Note: shared_ptr is not great but we know the sharedTransformModule is | ||
// readonly. | ||
// Alternatives comprise: | ||
// 1. no shared_ptr but copying the module with every pass clone that the | ||
// OpPassManager decides to perform. | ||
// 2. lifting ownership of the parsed transform module higher up in the | ||
// IREE stack. This may be only shift the problem as we have passes | ||
// building pass managers in IREE. | ||
// 3. build better support to embed the transformation module in the | ||
// input IR and transport it to the place of use in IREE. This is deemed | ||
// too intrusive atm. | ||
// 4. (future) config/resources mechanism that is being proposed in core? | ||
std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule; | ||
}; | ||
} // namespace | ||
|
||
namespace mlir { | ||
namespace iree_compiler { | ||
/// Create a Transform dialect interpreter pass. | ||
std::unique_ptr<Pass> createTransformDialectInterpreterPass( | ||
llvm::StringRef transformFileName) { | ||
return std::make_unique<TransformDialectInterpreterPass>(transformFileName); | ||
} | ||
} // namespace iree_compiler | ||
} // namespace mlir |
2 changes: 1 addition & 1 deletion
2
compiler/src/iree/compiler/Codegen/Common/test/transform_dialect_apply_pattern_op.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
compiler/src/iree/compiler/Codegen/LLVMCPU/test/transform_dialect_bufferize.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
compiler/src/iree/compiler/Codegen/LLVMGPU/test/transform_dialect_bufferize.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.