Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Composite Scaling CKKS Bootstrapping (#910 phase 3) #931

Open
wants to merge 16 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Recalculate sizeP and update EstimateLogP interface.
  • Loading branch information
fdiasmor committed Dec 19, 2024
commit 6369eaad34009c235da2c20fb41b27917ebc2ab2
3 changes: 3 additions & 0 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,15 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
* @param extraModulusSize bit size for extra modulus in FLEXIBLEAUTOEXT (CKKS and BGV only)
* @param numPrimes number of moduli witout extraModulus
* @param auxBits size of auxiliar moduli used for hybrid key switching
* @param scalTech scaling technique
* @param compositeDegree number of moduli in each level (CKKS only)
* @param addOne should an extra bit be added (for CKKS and BGV)
*
* @return log2 of the modulus and number of RNS limbs.
*/
static std::pair<double, uint32_t> EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits,
double extraModulusSize, uint32_t numPrimes, uint32_t auxBits,
ScalingTechnique scalTech, uint32_t compositeDegree = 1,
bool addOne = false);

/*
Expand Down
7 changes: 6 additions & 1 deletion src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ BFV implementation. See https://eprint.iacr.org/2021/204 for details.
#include "scheme/bfvrns/bfvrns-parametergeneration.h"
#include "scheme/scheme-utils.h"

#include <vector>
#include <memory>
#include <string>

namespace lbcrypto {

bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParams,
Expand Down Expand Up @@ -125,7 +129,8 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
uint32_t k = static_cast<uint32_t>(std::ceil(std::ceil(logq) / dcrtBits));
// set the number of digits
uint32_t numPartQ = ComputeNumLargeDigits(numDigits, k - 1);
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits);
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits, scalTech);
logq += std::get<0>(hybridKSInfo);
}
return static_cast<double>(
Expand Down
19 changes: 12 additions & 7 deletions src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ BGV implementation. See https://eprint.iacr.org/2021/204 for details.
#include "scheme/bgvrns/bgvrns-cryptoparameters.h"
#include "scheme/bgvrns/bgvrns-parametergeneration.h"

#include <vector>
#include <memory>
#include <string>
#include <utility>

namespace lbcrypto {

uint32_t ParameterGenerationBGVRNS::computeRingDimension(
Expand Down Expand Up @@ -151,7 +156,7 @@ uint64_t ParameterGenerationBGVRNS::getCyclicOrder(const uint32_t ringDimension,
if (pow2ptm < cyclOrder)
pow2ptm = cyclOrder;

lcmCyclOrderPtm = (uint64_t)pow2ptm * plaintextModulus;
lcmCyclOrderPtm = static_cast<uint64_t>(pow2ptm) * plaintextModulus;
}
else {
lcmCyclOrderPtm = cyclOrder;
Expand Down Expand Up @@ -451,8 +456,8 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters

uint32_t auxTowers = 0;
if (ksTech == HYBRID) {
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes, auxBits, true);
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes,
auxBits, scalTech, 1, true);
qBound += std::get<0>(hybridKSInfo);
auxTowers = std::get<1>(hybridKSInfo);
}
Expand Down Expand Up @@ -484,7 +489,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
numPartQ, std::log2(moduliQ[0].ConvertToDouble()),
(moduliQ.size() > 1) ? std::log2(moduliQ[1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? std::log2(moduliQ[moduliQ.size() - 1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits, false);
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits, scalTech, 1, false);
newQBound += std::get<0>(hybridKSInfo);
}
} while (qBound < newQBound);
Expand All @@ -511,7 +516,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
if (pow2ptm < cyclOrder)
pow2ptm = cyclOrder;

modulusOrder = (uint64_t)pow2ptm * plaintextModulus;
modulusOrder = static_cast<uint64_t>(pow2ptm) * plaintextModulus;

// Get the largest prime with size less or equal to firstModSize bits.
moduliQ[0] = LastPrime<NativeInteger>(firstModSize, modulusOrder);
Expand Down Expand Up @@ -592,10 +597,10 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
else {
// set batchsize to the actual batchsize i.e. n/d where d is the
// order of ptm mod CyclOrder
a = (uint64_t)ptm % cyclOrder;
a = static_cast<uint64_t>(ptm) % cyclOrder;
b = 1;
while (a != 1) {
a = ((uint64_t)(a * ptm)) % cyclOrder;
a = static_cast<uint64_t>(a * ptm) % cyclOrder;
b++;
}

Expand Down
9 changes: 3 additions & 6 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
cryptoParamsCKKSRNS->ConfigureCompositeDegree(firstModSize);
uint32_t compositeDegree = cryptoParamsCKKSRNS->GetCompositeDegree();
uint32_t registerWordSize = cryptoParamsCKKSRNS->GetRegisterWordSize();
compositeDegree *= static_cast<uint32_t>(1); // @fdiasmor: Avoid unused variable compilation error.
registerWordSize *= static_cast<uint32_t>(1); // @fdiasmor: Avoid unused variable compilation error.
// Bookeeping unique prime moduli
// std::unordered_set<uint64_t> moduliQRecord;

if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) {
if (compositeDegree > 2 && (firstModSize <= 68 || scalingModSize <= 67)) {
Expand Down Expand Up @@ -120,7 +116,7 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
// Estimate ciphertext modulus Q*P bound (in case of HYBRID P*Q)
if (ksTech == HYBRID) {
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize, extraModSize,
numPrimes, auxBits, true);
numPrimes, auxBits, scalTech, compositeDegree, true);
if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) {
uint32_t tmpFactor = (compositeDegree == 2) ? 2 : 4;
qBound += ceil(ceil(static_cast<double>(qBound) / numPartQ) / (tmpFactor * auxBits)) * tmpFactor * auxBits;
Expand Down Expand Up @@ -171,7 +167,8 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete

if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) {
numPrimes *= compositeDegree;
std::cout << __FUNCTION__ << "::" << __LINE__ << " numPrimes: " << numPrimes << std::endl;
std::cout << __FUNCTION__ << "::" << __LINE__ << " numPrimes: " << numPrimes << " qBound: " << qBound
<< std::endl;
}

uint32_t vecSize = (extraModSize == 0) ? numPrimes : numPrimes + 1;
Expand Down
41 changes: 38 additions & 3 deletions src/pke/lib/schemerns/rns-cryptoparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
#include "cryptocontext.h"
#include "schemerns/rns-cryptoparameters.h"

#include <vector>
#include <memory>
#include <utility>

namespace lbcrypto {

void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, ScalingTechnique scalTech,
Expand Down Expand Up @@ -128,7 +132,21 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
maxBits = bits;
}
// Select number of primes in auxiliary CRT basis
uint32_t sizeP = static_cast<uint32_t>(std::ceil(static_cast<double>(maxBits) / auxBits));
uint32_t sizeP = static_cast<uint32_t>(std::ceil(static_cast<double>(maxBits) / auxBits));
if (GetScalingTechnique() == COMPOSITESCALINGAUTO || GetScalingTechnique() == COMPOSITESCALINGMANUAL) {
usint compositeDegree = GetCompositeDegree();
switch (compositeDegree) {
case 0: // not allowed
case 1: // not composite
break;
case 2: // composite degree == 2
sizeP += (sizeP % 2);
break;
default: // composite degree > 2
sizeP += (sizeP % 4);
break;
}
}
uint64_t primeStep = FindAuxPrimeStep();

// Choose special primes in auxiliary basis and compute their roots
Expand Down Expand Up @@ -387,7 +405,9 @@ uint64_t CryptoParametersRNS::FindAuxPrimeStep() const {

std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, double firstModulusSize,
double dcrtBits, double extraModulusSize,
uint32_t numPrimes, uint32_t auxBits, bool addOne) {
uint32_t numPrimes, uint32_t auxBits,
ScalingTechnique scalTech, uint32_t compositeDegree,
bool addOne) {
// numPartQ can not be zero as there is a division by numPartQ
if (numPartQ == 0)
OPENFHE_THROW("numPartQ is zero");
Expand Down Expand Up @@ -421,7 +441,8 @@ std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ,
size_t endTower = ((j + 1) * numPerPartQ - 1 < sizeQ) ? (j + 1) * numPerPartQ - 1 : sizeQ - 1;

// sum qi elements qi[startTower] + ... + qi[endTower] inclusive. the end element should be qi.begin()+(endTower+1)
uint32_t bits = static_cast<uint32_t>(std::accumulate(qi.begin() + startTower, qi.begin() + (endTower + 1), 0.0));
uint32_t bits =
static_cast<uint32_t>(std::accumulate(qi.begin() + startTower, qi.begin() + (endTower + 1), 0.0));
if (bits > maxBits)
maxBits = bits;
}
Expand All @@ -434,6 +455,20 @@ std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ,
// Select number of primes in auxiliary CRT basis
auto sizeP = static_cast<uint32_t>(std::ceil(static_cast<double>(maxBits) / auxBits));

if (scalTech == COMPOSITESCALINGAUTO || scalTech == COMPOSITESCALINGMANUAL) {
switch (compositeDegree) {
case 0: // not allowed
case 1: // not composite
break;
case 2: // composite degree == 2
sizeP += (sizeP % 2);
break;
default: // composite degree > 2
sizeP += (sizeP % 4);
break;
}
}

return std::make_pair(sizeP * auxBits, sizeP);
}

Expand Down