Skip to content

Commit

Permalink
[NFC] Deprecate filter based decomposition patterns. (iree-org#12812)
Browse files Browse the repository at this point in the history
The filter is not needed in decomposition becaues the patterns will only be applied once.
  • Loading branch information
hanhanW authored Mar 28, 2023
1 parent 8e4ebcf commit 5c0c4ea
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 66 deletions.
17 changes: 4 additions & 13 deletions compiler/src/iree/compiler/Codegen/Common/GPUVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,12 @@ struct GPUVectorizationPass
MLIRContext *context = &getContext();

// Pre-process convolution ops.
RewritePatternSet decompositionPattern(funcOp.getContext());
IREE::LinalgExt::LinalgTransformationFilter f(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getVectorizeMarker()));
f.setMatchByDefault();
decompositionPattern
.add<IREE::LinalgExt::DownscaleSizeOneWindowed2DConvolution<
linalg::Conv2DNhwcHwcfOp, linalg::Conv1DNwcWcfOp>,
IREE::LinalgExt::DownscaleSizeOneWindowed2DConvolution<
linalg::Conv2DNchwFchwOp, linalg::Conv1DNcwFcwOp>,
IREE::LinalgExt::DownscaleDepthwiseConv2DNhwcHwcOp>(
funcOp.getContext(), f);
RewritePatternSet decompositionPattern(context);
linalg::populateDecomposeConvolutionPatterns(decompositionPattern);
if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(decompositionPattern))))
std::move(decompositionPattern)))) {
return signalPassFailure();
}

RewritePatternSet vectorizationPatterns(context);
populateVectorizationPatterns(vectorizationPatterns, maxVectorSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,59 +337,6 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
: LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
};

/// Wraps upstream Linalg pattern in a filter check + update.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
: public OpRewritePattern<Conv2DOp> {
DownscaleSizeOneWindowed2DConvolution(MLIRContext *context,
LinalgTransformationFilter f)
: OpRewritePattern<Conv2DOp>(context, /*benefit=*/1),
filter(std::move(f)) {}

LogicalResult matchAndRewrite(Conv2DOp convOp,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, convOp)))
return failure();
linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp> p(
convOp.getContext());
auto maybeConv1DOp = p.returningMatchAndRewrite(convOp, rewriter);
if (failed(maybeConv1DOp))
return failure();
filter.replaceLinalgTransformationFilter(rewriter, *maybeConv1DOp);
return success();
}

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};

/// Wraps upstream Linalg pattern in a filter check + update.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
: public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
LinalgTransformationFilter f)
: OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>(context,
/*benefit=*/1),
filter(std::move(f)) {}

LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, convOp)))
return failure();
linalg::DownscaleDepthwiseConv2DNhwcHwcOp p(convOp.getContext());
auto maybeConv1DOp = p.returningMatchAndRewrite(convOp, rewriter);
if (failed(maybeConv1DOp))
return failure();
filter.replaceLinalgTransformationFilter(rewriter, *maybeConv1DOp);
return success();
}

private:
/// LinalgTransformMarker handles special attribute manipulations.
LinalgTransformationFilter filter;
};

FailureOr<linalg::TileLoopNest> tileConsumerAndFuseProducers(
OpBuilder &b, linalg::LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange,
Expand Down

0 comments on commit 5c0c4ea

Please sign in to comment.