-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This pass ensures that operands of addition and subtraction have the same scale by artificially multiplying the one with smaller scale by 1 (the scale of the product is the product of scales). Signed-off-by: Andrzej Turko <[email protected]>
- Loading branch information
1 parent
049dbcf
commit a166c88
Showing
7 changed files
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#ifndef HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_ | ||
#define HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_ | ||
|
||
#include "mlir/Pass/Pass.h" | ||
|
||
struct MatchScalePass : public mlir::PassWrapper<MatchScalePass, mlir::OperationPass<>> | ||
{ | ||
void getDependentDialects(mlir::DialectRegistry ®istry) const override; | ||
|
||
void runOnOperation() override; | ||
|
||
MatchScalePass() = default; | ||
MatchScalePass(const MatchScalePass &){}; // Necessary to make Options work | ||
|
||
mlir::StringRef getArgument() const final | ||
{ | ||
return "evamatchscale"; | ||
} | ||
|
||
Option<int> source_scale{*this, "source_scale", llvm::cl::desc("Binary exponent of the scale" | ||
"of the fixed point representation value in the source nodes"), | ||
llvm::cl::init(30)}; | ||
|
||
Option<int> waterline{*this, "waterline", llvm::cl::desc("Binary exponent of the treshold scale for rescaling"), | ||
llvm::cl::init(60)}; | ||
|
||
Option<int> scale_drop{*this, "scale_drop", llvm::cl::desc("Binary exponent of the scale drop after rescaling"), | ||
llvm::cl::init(60)}; | ||
}; | ||
|
||
#endif // HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
set(LLVM_TARGET_DEFINITIONS MatchScale.td) | ||
mlir_tablegen(MatchScale.cpp.inc -gen-rewriters) | ||
add_public_tablegen_target(evamatchscale) | ||
|
||
set(CMAKE_INCLUDE_CURRENT_DIR ON) | ||
|
||
add_heco_conversion_library(HECOMatchScale | ||
MatchScale.cpp | ||
|
||
#ADDITIONAL_HEADER_DIRS | ||
#Passes | ||
|
||
DEPENDS | ||
evamatchscale | ||
|
||
LINK_COMPONENTS | ||
Core | ||
|
||
LINK_LIBS PUBLIC | ||
HECOEVADialect | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#include "heco/Passes/evamatchscale/MatchScale.h" | ||
#include <iostream> | ||
#include "heco/IR/EVA/EVADialect.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace mlir; | ||
using namespace heco; | ||
using namespace eva; | ||
|
||
void MatchScalePass::getDependentDialects(mlir::DialectRegistry ®istry) const | ||
{ | ||
registry.insert<eva::EVADialect>(); | ||
} | ||
|
||
template <typename OpTy> | ||
class MatchScalePattern final : public RewritePattern | ||
{ | ||
public: | ||
MatchScalePattern(PatternBenefit benefit, mlir::MLIRContext *context, eva::ScaleAnalysis *_analysis) : | ||
RewritePattern(OpTy::getOperationName(), benefit, context) { | ||
analysis = _analysis; | ||
} | ||
|
||
LogicalResult matchAndRewrite( | ||
Operation *op, PatternRewriter &rewriter) const override | ||
{ | ||
std::string op_name = op->getName().getStringRef().str(); | ||
|
||
int left_scale = analysis->getValueScaleWithoutImpliedRescales(op->getOperand(0)); | ||
int right_scale = analysis->getValueScaleWithoutImpliedRescales(op->getOperand(1)); | ||
|
||
if (left_scale == -1 or right_scale == -1) { | ||
return failure(); | ||
} | ||
|
||
if (left_scale == right_scale) { | ||
return failure(); | ||
} | ||
|
||
rewriter.setInsertionPoint(op); | ||
|
||
int diff = std::abs(left_scale - right_scale); | ||
// we assume the lengths to be identical | ||
int vec_len = op->getOperand(0).getType().cast<CipherType>().getSize(); | ||
|
||
if (left_scale < right_scale) { | ||
|
||
Operation *constant = rewriter.create<eva::ConstOp>(op->getLoc(), op->getOperand(0).getType(), | ||
rewriter.getI32ArrayAttr(std::vector <int> (vec_len, 1)), diff, -1); | ||
|
||
Operation *mult = rewriter.create<eva::MultiplyOp>(op->getLoc(), op->getOperand(0).getType(), ValueRange({op->getOperand(0), constant->getResult(0)})); | ||
rewriter.replaceOpWithNewOp<OpTy>(op, TypeRange(op->getResultTypes()), ValueRange({mult->getResult(0), op->getOperand(1)})); | ||
|
||
} else { | ||
|
||
Operation *constant = rewriter.create<eva::ConstOp>(op->getLoc(), op->getOperand(0).getType(), | ||
rewriter.getI32ArrayAttr(std::vector <int> (vec_len, 1)), diff, -1); | ||
|
||
Operation *mult = rewriter.create<eva::MultiplyOp>(op->getLoc(), op->getOperand(1).getType(), ValueRange({op->getOperand(1), constant->getResult(0)})); | ||
rewriter.replaceOpWithNewOp<OpTy>(op, TypeRange(op->getResultTypes()), ValueRange({op->getOperand(0), mult->getResult(0)})); | ||
} | ||
|
||
return success(); | ||
}; | ||
|
||
private: | ||
eva::ScaleAnalysis *analysis; | ||
}; | ||
|
||
void MatchScalePass::runOnOperation() | ||
{ | ||
mlir::RewritePatternSet patterns(&getContext()); | ||
|
||
// Configure the ScaleAnalysis | ||
eva::ScaleAnalysis::argument_scale = source_scale; | ||
eva::ScaleAnalysis::waterline = waterline; | ||
eva::ScaleAnalysis::scale_drop = scale_drop; | ||
|
||
ScaleAnalysis analysis = getAnalysis<ScaleAnalysis>(); | ||
|
||
patterns.add<MatchScalePattern<eva::AddOp>>(PatternBenefit(10), patterns.getContext(), &analysis); | ||
patterns.add<MatchScalePattern<eva::SubOp>>(PatternBenefit(10), patterns.getContext(), &analysis); | ||
|
||
GreedyRewriteConfig config; | ||
// force topological order of processing operations | ||
config.useTopDownTraversal = true; | ||
config.maxIterations = 1; | ||
config.strictMode = GreedyRewriteStrictness::ExistingOps; | ||
|
||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config))) | ||
signalPassFailure(); | ||
|
||
getOperation()->print(llvm::outs()); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#ifndef HECO_evamatchscale_MATCHSCALE_TD | ||
#define HECO_evamatchscale_MATCHSCALE_TD | ||
|
||
include "mlir/IR/PatternBase.td" | ||
|
||
#endif // HECO_evamatchscale_MATCHSCALE_TD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters