Skip to content

Commit

Permalink
Split VectorizePackUnPackOps into decomposition and vectorization pas…
Browse files Browse the repository at this point in the history
…ses (iree-org#12865)

- Rename VectorizePackUnPackOps.cpp to DecomposePackUnPackOps.cpp
- Each backend has its own vectorization pass, it delagates the
  vectorization to them.
  • Loading branch information
hanhanW authored Apr 4, 2023
1 parent 11ffea4 commit 820222c
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 171 deletions.
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ iree_compiler_cc_library(
"ConcretizePadResultShape.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposePackUnPackOps.cpp",
"EraseHALDescriptorTypeFromMemRef.cpp",
"FixupSubspanWithOffsets.cpp",
"FlattenMemRefSubspanPass.cpp",
Expand All @@ -135,7 +136,6 @@ iree_compiler_cc_library(
"TileAndDistributeToWorkgroupsPass.cpp",
"TileDispatchUsingInterface.cpp",
"TypePropagationPass.cpp",
"VectorizePackUnPackOps.cpp",
"VectorizePad.cpp",
"WorkgroupSpecializationPass.cpp",
],
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ iree_cc_library(
"ConcretizePadResultShape.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposePackUnPackOps.cpp"
"EraseHALDescriptorTypeFromMemRef.cpp"
"FixupSubspanWithOffsets.cpp"
"FlattenMemRefSubspanPass.cpp"
Expand All @@ -111,7 +112,6 @@ iree_cc_library(
"TileAndDistributeToWorkgroupsPass.cpp"
"TileDispatchUsingInterface.cpp"
"TypePropagationPass.cpp"
"VectorizePackUnPackOps.cpp"
"VectorizePad.cpp"
"WorkgroupSpecializationPass.cpp"
DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "llvm/Support/Debug.h"
Expand All @@ -14,9 +13,8 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
Expand All @@ -38,17 +36,6 @@ struct DecomposePackUnPackOpsPass

void runOnOperation() override;
};

struct VectorizePackUnPackOpsPass
: public VectorizePackUnPackOpsBase<VectorizePackUnPackOpsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, func::FuncDialect,
arith::ArithDialect, scf::SCFDialect, tensor::TensorDialect,
vector::VectorDialect>();
}

void runOnOperation() override;
};
} // namespace

void DecomposePackUnPackOpsPass::runOnOperation() {
Expand Down Expand Up @@ -140,30 +127,10 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
});
}

void VectorizePackUnPackOpsPass::runOnOperation() {
MLIRContext *ctx = &getContext();

// Kick in generic vectorizer.
RewritePatternSet patterns(ctx);
patterns.add<IREE::LinalgExt::LinalgVectorizationPattern>(ctx);
linalg::populatePadOpVectorizationPatterns(patterns);
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
// TODO(hanchung): Capture the failure after the vectorization pattern
// rewrite converges.
(void)(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)));
}

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposePackUnPackOpsPass() {
return std::make_unique<DecomposePackUnPackOpsPass>();
}

std::unique_ptr<OperationPass<func::FuncOp>>
createVectorizePackUnPackOpsPass() {
return std::make_unique<VectorizePackUnPackOpsPass>();
}

} // namespace iree_compiler
} // namespace mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ static void populateVectorizationPatterns(RewritePatternSet &patterns,

IREE::LinalgExt::LinalgVectorizationOptions vectorizationOptions;
VectorizationPatterns<linalg::FillOp, linalg::GenericOp,
linalg::Conv1DNwcWcfOp,
linalg::Conv1DNcwFcwOp>::insert(patterns,
vectorizationOptions,
f);
linalg::Conv1DNwcWcfOp, linalg::Conv1DNcwFcwOp,
linalg::TransposeOp>::insert(patterns,
vectorizationOptions, f);
patterns.add<linalg::CopyVectorizationPattern>(ctx);
patterns.add<LinalgVectorizationPattern>(
ctx, vectorizationOptions,
f.addOpFilter<linalg::ContractionOpInterface>());
linalg::populatePadOpVectorizationPatterns(patterns);
}

namespace {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"dead_alloc.mlir",
"decompose_affine_ops.mlir",
"decompose_linalg_generic.mlir",
"decompose_pack_unpack_ops.mlir",
"distribute_gpu_shared_memory.mlir",
"eliminate_empty_tensors.mlir",
"erase_hal_descriptor_type.mlir",
Expand Down Expand Up @@ -58,7 +59,6 @@ iree_lit_test_suite(
"transform_ops_invalid.mlir",
"transpose_canonicalization.mlir",
"type_propagation.mlir",
"vectorize_pack_unpack_ops.mlir",
"vectorize_tensor_pad.mlir",
"warp_reduction.mlir",
"workgroup_specialization.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_lit_test_suite(
"dead_alloc.mlir"
"decompose_affine_ops.mlir"
"decompose_linalg_generic.mlir"
"decompose_pack_unpack_ops.mlir"
"distribute_gpu_shared_memory.mlir"
"eliminate_empty_tensors.mlir"
"erase_hal_descriptor_type.mlir"
Expand Down Expand Up @@ -54,7 +55,6 @@ iree_lit_test_suite(
"transform_ops_invalid.mlir"
"transpose_canonicalization.mlir"
"type_propagation.mlir"
"vectorize_pack_unpack_ops.mlir"
"vectorize_tensor_pad.mlir"
"warp_reduction.mlir"
"workgroup_specialization.mlir"
Expand Down
Loading

0 comments on commit 820222c

Please sign in to comment.