Skip to content

Commit

Permalink
[Transform] Add a new transform op that applies patterns (iree-org#9676)
Browse files Browse the repository at this point in the history
This revision introduces a `transform.iree.apply_patterns` operation that operates on
an isolated from above op and applies a set of patterns while listening and updating
transform dialect handles.

The list of patterns is specified via attributes that are additive and roughly play the role of
populate functions.

iree-opt is extended so that it can run transform dialect directly without files, which
results in general simplifications.
  • Loading branch information
nicolasvasilache authored Jul 1, 2022
1 parent e789f95 commit 2e03b3c
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "TransformDialectCommonExtensions.h"

#include "iree-dialects/Dialect/LinalgTransform/StructuredTransformOpsExt.h"
#include "iree-dialects/Transforms/ListenerGreedyPatternRewriteDriver.h"
#include "iree/compiler/Codegen/Interfaces/BufferizationInterfaces.h"
#include "iree/compiler/Codegen/Passes.h"
#include "mlir/Pass/PassManager.h"
Expand All @@ -27,6 +29,37 @@ void mlir::iree_compiler::registerTransformDialectCommonExtension(
registry.addExtensions<transform_dialect::TransformDialectCommonExtensions>();
}

//===---------------------------------------------------------------------===//
// ApplyPatternsOp
//===---------------------------------------------------------------------===//

static void addAllRegisteredCanonicalizationPatterns(
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
for (Dialect *dialect : ctx->getLoadedDialects())
dialect->getCanonicalizationPatterns(patterns);
for (RegisteredOperationName op : ctx->getRegisteredOperations())
op.getCanonicalizationPatterns(patterns, ctx);
}

FailureOr<Operation *> transform_dialect::ApplyPatternsOp::applyToOne(
Operation *target, transform::TransformState &state) {
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>())
return emitOpError() << "applies only to isolated-from-above targets "
"because it needs to apply patterns greedily";
MLIRContext *ctx = target->getContext();
RewritePatternSet patterns(ctx);
if (getCanonicalization()) addAllRegisteredCanonicalizationPatterns(patterns);

TrackingListener listener(state);
GreedyRewriteConfig config;
LogicalResult result = applyPatternsAndFoldGreedily(
target, std::move(patterns), config, &listener);
LogicalResult listenerResult = listener.checkErrorState();
if (failed(result) || failed(listenerResult)) return failure();
return target;
}

//===---------------------------------------------------------------------===//
// Default allocation functions for CPU backend
// TODO: register the bufferization behavior in a target-specific way.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECTEXTENSIONS_TRANSFORMDIALECTCOMMONEXTENSIONS_H_
#define IREE_COMPILER_CODEGEN_COMMON_TRANSFORMDIALECTEXTENSIONS_TRANSFORMDIALECTCOMMONEXTENSIONS_H_

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"

#define GET_OP_CLASSES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@ include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def ApplyPatternsOp : Op<Transform_Dialect, "iree.apply_patterns",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformEachOpTrait,
TransformOpInterface]> {
let description = [{
Greedily applies patterns as specified by its attributes.

Must be applied to an op with trait IsolatedFromAbove since the
GreedyPatternRewriter asserts those.

Returns the IsolatedFromAbove op whose content it has modified for better
chaining APIs.
}];

let arguments = (ins PDL_Operation:$target,
UnitAttr:$canonicalization);
let results = (outs PDL_Operation:$result);

let assemblyFormat = "$target attr-dict";
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";

let extraClassDeclaration = [{
::mlir::FailureOr<::mlir::Operation *> applyToOne(
::mlir::Operation *target, ::mlir::transform::TransformState &state);
}];
}

def IREEBufferizeOp : Op<Transform_Dialect, "iree.bufferize",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ iree_lit_test_suite(
"swizzle_workgroup.mlir",
"test_partitionable_loops_interface.mlir",
"tile_and_distribute_to_workgroups.mlir",
"transform_dialect_apply_pattern_op.mlir",
"transpose_canonicalization.mlir",
"type_propagation.mlir",
"vectorize_linalg_conv.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_lit_test_suite(
"swizzle_workgroup.mlir"
"test_partitionable_loops_interface.mlir"
"tile_and_distribute_to_workgroups.mlir"
"transform_dialect_apply_pattern_op.mlir"
"transpose_canonicalization.mlir"
"type_propagation.mlir"
"vectorize_linalg_conv.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: iree-opt %s -transform-dialect-interpreter -transform-dialect-drop-schedule | FileCheck %s

// CHECK-LABEL: @select_cmp_eq_select
// CHECK: return %arg1
func.func @select_cmp_eq_select(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.cmpi eq, %arg0, %arg1 : i64
%1 = arith.select %0, %arg0, %arg1 : i64
return %1 : i64
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @pdl_fun_target : benefit(1) {
%args = operands
%results = types
%0 = operation "func.func"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
// TODO: we don't want this, but it is the required terminator for pdl.pattern
rewrite %0 with "transform.dialect"
}

transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_fun_target in %arg1
transform.iree.apply_patterns %0 { canonicalization }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H
#define IREE_DIALECTS_DIALECT_LINALG_TRANSFORM_STRUCTUREDTRANSFORMOPSEXT_H

#include "iree-dialects/Transforms/Listener.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
Expand All @@ -19,6 +20,50 @@ class LinalgOp;
namespace scf {
class ForOp;
} // namespace scf

class TrackingListener : public RewriteListener,
public transform::TransformState::Extension {
public:
explicit TrackingListener(transform::TransformState &state)
: transform::TransformState::Extension(state) {}

~TrackingListener() override {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(errorStateChecked && "must check listener error state");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

void notifyOperationReplaced(Operation *op, ValueRange newValues) override;

void notifyOperationRemoved(Operation *op) override;

LogicalResult checkErrorState() const {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
errorStateChecked = true;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return failure(hadErrors);
}

private:
InFlightDiagnostic emitError(Operation *op, const llvm::Twine &message = {}) {
mayFail(failure());
return op->emitError(message);
}

void mayFail(LogicalResult result) {
hadErrors |= result.failed();
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
errorStateChecked = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

bool hadErrors = false;

#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
mutable bool errorStateChecked = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};

} // namespace mlir

#define GET_OP_CLASSES
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,84 +390,41 @@ static Operation *findSingleDefiningOp(Operation *replacedOp,
.Default([](Operation *) -> Operation * { return nullptr; });
}

namespace detail {
class TrackingListener : public RewriteListener,
public transform::TransformState::Extension {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TrackingListener);

explicit TrackingListener(transform::TransformState &state)
: transform::TransformState::Extension(state) {}

~TrackingListener() override {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(errorStateChecked && "must check listener error state");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}

void notifyOperationReplaced(Operation *op, ValueRange newValues) override {
// Bail out if in error state.
if (hadErrors)
return;

// Exit early if the op is not tracked.
Value handle = getTransformState().getHandleForPayloadOp(op);
if (!handle)
return;

Operation *replacement = findSingleDefiningOp(op, newValues);
if (!replacement) {
emitError(op) << "could not find replacement for tracked op";
return;
}

LLVM_DEBUG(DBGS() << "replacing tracked " << *op << " with " << *replacement
<< " for " << handle << "\n");
mayFail(replacePayloadOp(op, replacement));
}

void notifyOperationRemoved(Operation *op) override {
// Bail out if in error state.
if (hadErrors)
return;

// Exit early if the op is not tracked.
Value handle = getTransformState().getHandleForPayloadOp(op);
if (!handle)
return;
void mlir::TrackingListener::notifyOperationReplaced(Operation *op,
ValueRange newValues) {
// Bail out if in error state.
if (hadErrors)
return;

LLVM_DEBUG(DBGS() << "removing tracked " << *op << " for " << handle
<< "\n");
mayFail(replacePayloadOp(op, nullptr));
}
// Exit early if the op is not tracked.
Value handle = getTransformState().getHandleForPayloadOp(op);
if (!handle)
return;

LogicalResult checkErrorState() const {
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
errorStateChecked = true;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return failure(hadErrors);
Operation *replacement = findSingleDefiningOp(op, newValues);
if (!replacement) {
emitError(op) << "could not find replacement for tracked op";
return;
}

private:
InFlightDiagnostic emitError(Operation *op, const llvm::Twine &message = {}) {
mayFail(failure());
return op->emitError(message);
}
LLVM_DEBUG(DBGS() << "replacing tracked " << *op << " with " << *replacement
<< " for " << handle << "\n");
mayFail(replacePayloadOp(op, replacement));
}

void mayFail(LogicalResult result) {
hadErrors |= result.failed();
#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
errorStateChecked = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
void mlir::TrackingListener::notifyOperationRemoved(Operation *op) {
// Bail out if in error state.
if (hadErrors)
return;

bool hadErrors = false;
// Exit early if the op is not tracked.
Value handle = getTransformState().getHandleForPayloadOp(op);
if (!handle)
return;

#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
mutable bool errorStateChecked = false;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
} // namespace detail
LLVM_DEBUG(DBGS() << "removing tracked " << *op << " for " << handle << "\n");
mayFail(replacePayloadOp(op, nullptr));
}

//===----------------------------------------------------------------------===//
// CanonicalizedSequenceOp
Expand Down Expand Up @@ -533,9 +490,9 @@ DiagnosedSilenceableFailure transform_ext::CanonicalizedSequenceOp::apply(

transform::TransformState::RegionScope regionScope =
state.make_region_scope(getBodyRegion());
auto &listener = state.addExtension<::detail::TrackingListener>();
auto &listener = state.addExtension<::mlir::TrackingListener>();
auto detachListener = llvm::make_scope_exit(
[&] { state.removeExtension<::detail::TrackingListener>(); });
[&] { state.removeExtension<::mlir::TrackingListener>(); });
if (failed(mapBlockArguments(state)))
return DiagnosedSilenceableFailure::definiteFailure();

Expand Down
10 changes: 10 additions & 0 deletions tools/iree-opt-main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ int main(int argc, char **argv) {
mlir::iree_compiler::registerAllPasses();
mlir::iree_compiler::registerHALTargetBackends();

// Also register the transform interpreter pass so that iree-opt can run
// transform dialect IR without resorting to a separate file.
// Resorting to separate files is a convenience for iree-compile to be able to
// use the transform dialect without requiring special plumbing.
// Still the preferred mode of execution should be to transport the relevant
// piece of transform IR in the right location, for each piece of code we
// want to transform for.
mlir::linalg::transform::registerTransformDialectInterpreterPass();
mlir::linalg::transform::registerDropSchedulePass();

if (failed(MlirOptMain(argc, argv, "IREE modular optimizer driver\n",
registry,
/*preloadDialectsInContext=*/false))) {
Expand Down

0 comments on commit 2e03b3c

Please sign in to comment.