Skip to content

Commit

Permalink
EVA: Implement the Matchscale pass
Browse files Browse the repository at this point in the history
This pass ensures that operands of addition and subtraction have the same
scale by artificially multiplying the one with smaller scale by 1
(the scale of the product is the product of scales).

Signed-off-by: Andrzej Turko <[email protected]>
  • Loading branch information
a-turko authored and AlexanderViand-Intel committed Dec 18, 2023
1 parent 049dbcf commit a166c88
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 0 deletions.
31 changes: 31 additions & 0 deletions include/heco/Passes/evamatchscale/MatchScale.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_
#define HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_

#include "mlir/Pass/Pass.h"

struct MatchScalePass : public mlir::PassWrapper<MatchScalePass, mlir::OperationPass<>>
{
void getDependentDialects(mlir::DialectRegistry &registry) const override;

void runOnOperation() override;

MatchScalePass() = default;
MatchScalePass(const MatchScalePass &){}; // Necessary to make Options work

mlir::StringRef getArgument() const final
{
return "evamatchscale";
}

Option<int> source_scale{*this, "source_scale", llvm::cl::desc("Binary exponent of the scale"
"of the fixed point representation value in the source nodes"),
llvm::cl::init(30)};

Option<int> waterline{*this, "waterline", llvm::cl::desc("Binary exponent of the treshold scale for rescaling"),
llvm::cl::init(60)};

Option<int> scale_drop{*this, "scale_drop", llvm::cl::desc("Binary exponent of the scale drop after rescaling"),
llvm::cl::init(60)};
};

#endif // HECO_PASSES_EVAMATCHSCALE_MATCHSCALE_H_
1 change: 1 addition & 0 deletions src/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(evamatchscale)
add_subdirectory(evalazymodswitch)
add_subdirectory(evametadata)
add_subdirectory(fhe2bfv)
Expand Down
22 changes: 22 additions & 0 deletions src/Passes/evamatchscale/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
set(LLVM_TARGET_DEFINITIONS MatchScale.td)
mlir_tablegen(MatchScale.cpp.inc -gen-rewriters)
add_public_tablegen_target(evamatchscale)

set(CMAKE_INCLUDE_CURRENT_DIR ON)

add_heco_conversion_library(HECOMatchScale
MatchScale.cpp

#ADDITIONAL_HEADER_DIRS
#Passes

DEPENDS
evamatchscale

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
HECOEVADialect
)

96 changes: 96 additions & 0 deletions src/Passes/evamatchscale/MatchScale.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#include "heco/Passes/evamatchscale/MatchScale.h"
#include <iostream>
#include "heco/IR/EVA/EVADialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace heco;
using namespace eva;

void MatchScalePass::getDependentDialects(mlir::DialectRegistry &registry) const
{
registry.insert<eva::EVADialect>();
}

template <typename OpTy>
class MatchScalePattern final : public RewritePattern
{
public:
MatchScalePattern(PatternBenefit benefit, mlir::MLIRContext *context, eva::ScaleAnalysis *_analysis) :
RewritePattern(OpTy::getOperationName(), benefit, context) {
analysis = _analysis;
}

LogicalResult matchAndRewrite(
Operation *op, PatternRewriter &rewriter) const override
{
std::string op_name = op->getName().getStringRef().str();

int left_scale = analysis->getValueScaleWithoutImpliedRescales(op->getOperand(0));
int right_scale = analysis->getValueScaleWithoutImpliedRescales(op->getOperand(1));

if (left_scale == -1 or right_scale == -1) {
return failure();
}

if (left_scale == right_scale) {
return failure();
}

rewriter.setInsertionPoint(op);

int diff = std::abs(left_scale - right_scale);
// we assume the lengths to be identical
int vec_len = op->getOperand(0).getType().cast<CipherType>().getSize();

if (left_scale < right_scale) {

Operation *constant = rewriter.create<eva::ConstOp>(op->getLoc(), op->getOperand(0).getType(),
rewriter.getI32ArrayAttr(std::vector <int> (vec_len, 1)), diff, -1);

Operation *mult = rewriter.create<eva::MultiplyOp>(op->getLoc(), op->getOperand(0).getType(), ValueRange({op->getOperand(0), constant->getResult(0)}));
rewriter.replaceOpWithNewOp<OpTy>(op, TypeRange(op->getResultTypes()), ValueRange({mult->getResult(0), op->getOperand(1)}));

} else {

Operation *constant = rewriter.create<eva::ConstOp>(op->getLoc(), op->getOperand(0).getType(),
rewriter.getI32ArrayAttr(std::vector <int> (vec_len, 1)), diff, -1);

Operation *mult = rewriter.create<eva::MultiplyOp>(op->getLoc(), op->getOperand(1).getType(), ValueRange({op->getOperand(1), constant->getResult(0)}));
rewriter.replaceOpWithNewOp<OpTy>(op, TypeRange(op->getResultTypes()), ValueRange({op->getOperand(0), mult->getResult(0)}));
}

return success();
};

private:
eva::ScaleAnalysis *analysis;
};

void MatchScalePass::runOnOperation()
{
mlir::RewritePatternSet patterns(&getContext());

// Configure the ScaleAnalysis
eva::ScaleAnalysis::argument_scale = source_scale;
eva::ScaleAnalysis::waterline = waterline;
eva::ScaleAnalysis::scale_drop = scale_drop;

ScaleAnalysis analysis = getAnalysis<ScaleAnalysis>();

patterns.add<MatchScalePattern<eva::AddOp>>(PatternBenefit(10), patterns.getContext(), &analysis);
patterns.add<MatchScalePattern<eva::SubOp>>(PatternBenefit(10), patterns.getContext(), &analysis);

GreedyRewriteConfig config;
// force topological order of processing operations
config.useTopDownTraversal = true;
config.maxIterations = 1;
config.strictMode = GreedyRewriteStrictness::ExistingOps;

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), config)))
signalPassFailure();

getOperation()->print(llvm::outs());
}
6 changes: 6 additions & 0 deletions src/Passes/evamatchscale/MatchScale.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef HECO_evamatchscale_MATCHSCALE_TD
#define HECO_evamatchscale_MATCHSCALE_TD

include "mlir/IR/PatternBase.td"

#endif // HECO_evamatchscale_MATCHSCALE_TD
1 change: 1 addition & 0 deletions src/tools/heco/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ set(LIBS
HECOLowerFHEToBFV
HECOLowerFHEToEmitC
HECOLowerFHEToEVA
HECOMatchScale
HECOMarkMetadata
HECOLazyModswitch
HECOCatchAll
Expand Down
2 changes: 2 additions & 0 deletions src/tools/heco/heco.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "heco/IR/EVA/EVADialect.h"
#include "heco/IR/FHE/FHEDialect.h"
#include "heco/IR/Poly/PolyDialect.h"
#include "heco/Passes/evamatchscale/MatchScale.h"
#include "heco/Passes/evalazymodswitch/LazyModswitch.h"
#include "heco/Passes/evametadata/MarkMetadata.h"
#include "heco/Passes/bfv2emitc/LowerBFVToEmitC.h"
Expand Down Expand Up @@ -163,6 +164,7 @@ int main(int argc, char **argv)
PassRegistration<LowerBFVToLLVMPass>();
PassRegistration<LowerFHEToEmitCPass>();
PassRegistration<LowerFHEToEVAPass>();
PassRegistration<MatchScalePass>();
PassRegistration<LazyModswitchPass>();
PassRegistration<MarkMetadataPass>();

Expand Down

0 comments on commit a166c88

Please sign in to comment.