Skip to content

Commit

Permalink
Benchmarking Montgomery Multiplication (intel#94)
Browse files Browse the repository at this point in the history
* Adding Benchmarks for different EltWise Mont Mult

* Fixes for Debug

* Removing comment

* Cleanup

* Replacing inv_mod with neg_inv_mod

* Using MultiplyMod
  • Loading branch information
joserochh authored Nov 22, 2021
1 parent a6bb385 commit 2571f20
Show file tree
Hide file tree
Showing 7 changed files with 309 additions and 47 deletions.
52 changes: 50 additions & 2 deletions benchmark/bench-eltwise-mult-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "eltwise/eltwise-mult-mod-avx512.hpp"
#include "eltwise/eltwise-mult-mod-internal.hpp"
#include "eltwise/eltwise-reduce-mod-avx512.hpp"
#include "hexl/eltwise/eltwise-mult-mod.hpp"
#include "hexl/logging/logging.hpp"
#include "hexl/number-theory/number-theory.hpp"
Expand Down Expand Up @@ -145,8 +146,10 @@ static void BM_EltwiseMultModAVX512IFMAInt(
size_t input_mod_factor = state.range(1);
size_t modulus = 100;

auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, modulus);
auto input2 = GenerateInsecureUniformRandomValues(input_size, 0, modulus);
auto input1 = GenerateInsecureUniformRandomValues(input_size, 0,
input_mod_factor * modulus);
auto input2 = GenerateInsecureUniformRandomValues(input_size, 0,
input_mod_factor * modulus);
AlignedVector64<uint64_t> output(input_size, 3);

for (auto _ : state) {
Expand Down Expand Up @@ -174,5 +177,50 @@ BENCHMARK(BM_EltwiseMultModAVX512IFMAInt)

//=================================================================

#ifdef HEXL_HAS_AVX512IFMA

// Benchmarks an element-wise modular multiplication performed through
// Montgomery form, timing the conversion of the first operand into Montgomery
// form together with the Montgomery multiply-and-reduce step.
// state[0] is the degree
// state[1] is the input_mod_factor
static void BM_EltwiseMultModMontAVX512IFMAIntEConv(
    benchmark::State& state) {  // NOLINT
  constexpr int kBitShift = 52;
  constexpr int kRBits = 51;  // R = 2^51 = 2251799813685248

  const size_t input_size = state.range(0);
  const size_t input_mod_factor = state.range(1);
  const uint64_t modulus = (1ULL << 50) + 7;  // 1125899906842631

  // Both operands are sampled with the same bounds: 0 and
  // input_mod_factor * modulus.
  const uint64_t sample_bound = input_mod_factor * modulus;
  auto lhs = GenerateInsecureUniformRandomValues(input_size, 0, sample_bound);
  auto rhs = GenerateInsecureUniformRandomValues(input_size, 0, sample_bound);
  AlignedVector64<uint64_t> output(input_size, 3);

  // R^2 mod q, i.e. mod(2251799813685248*2251799813685248;1125899906842631),
  // built from R mod q so the product is taken on already-reduced values.
  const uint64_t r_mod_q = ReduceMod<2>(1ULL << kRBits, modulus);
  const uint64_t r_square_mod_q = MultiplyMod(r_mod_q, r_mod_q, modulus);
  // NOTE(review): presumably yields neg_inv_mod with q * neg_inv_mod ≡ −1 mod R,
  // as the Montgomery kernels expect — confirm against HenselLemma2adicRoot.
  const uint64_t neg_inv_mod = HenselLemma2adicRoot(kRBits, modulus);

  // Operands only need the extra reduction pass when the factor is not 1.
  const bool needs_reduction = (input_mod_factor != 1);

  for (auto _ : state) {
    if (needs_reduction) {
      // In-place reduction of both operands.
      EltwiseReduceModAVX512(lhs.data(), lhs.data(), input_size, modulus,
                             input_mod_factor, 1);
      EltwiseReduceModAVX512(rhs.data(), rhs.data(), input_size, modulus,
                             input_mod_factor, 1);
    }
    // Move the first operand into Montgomery form ...
    EltwiseMontgomeryFormInAVX512<kBitShift, kRBits>(
        output.data(), lhs.data(), r_square_mod_q, input_size, modulus,
        neg_inv_mod);
    // ... then multiply by the second operand with a Montgomery reduction,
    // overwriting output in place.
    EltwiseMontReduceModAVX512<kBitShift, kRBits>(
        output.data(), output.data(), rhs.data(), input_size, modulus,
        neg_inv_mod);
  }
}

// Registers the benchmark for every combination of degree
// {1024, 4096, 16384} (state[0]) and input_mod_factor {1, 2, 4} (state[1]),
// reporting times in microseconds.
BENCHMARK(BM_EltwiseMultModMontAVX512IFMAIntEConv)
->Unit(benchmark::kMicrosecond)
->ArgsProduct({{1024, 4096, 16384}, {1, 2, 4}});

#endif

//=================================================================

} // namespace hexl
} // namespace intel
25 changes: 12 additions & 13 deletions benchmark/bench-eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ BENCHMARK(BM_EltwiseReduceModMontAVX512BitShift52LT)
->Args({4096})
->Args({16384});

static void BM_EltwiseReduceModMontFormAVX512BitShift52LT(
static void BM_EltwiseReduceModMontFormInAVX512BitShift52LT(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
uint64_t modulus = 67280421310725ULL;
Expand All @@ -266,18 +266,18 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift52LT(
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q,
input_size, modulus, inv_mod);
EltwiseMontgomeryFormInAVX512<52, 46>(
output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod);
}
}

BENCHMARK(BM_EltwiseReduceModMontFormAVX512BitShift52LT)
BENCHMARK(BM_EltwiseReduceModMontFormInAVX512BitShift52LT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
->Args({16384});

static void BM_EltwiseReduceModMontFormAVX512BitShift64LT(
static void BM_EltwiseReduceModMontFormInAVX512BitShift64LT(
benchmark::State& state) { // NOLINT
size_t input_size = state.range(0);
uint64_t modulus = 67280421310725ULL;
Expand All @@ -292,12 +292,12 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift64LT(
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseMontgomeryFormAVX512<64, 46>(output.data(), input_a.data(), R2_mod_q,
input_size, modulus, inv_mod);
EltwiseMontgomeryFormInAVX512<64, 46>(
output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod);
}
}

BENCHMARK(BM_EltwiseReduceModMontFormAVX512BitShift64LT)
BENCHMARK(BM_EltwiseReduceModMontFormInAVX512BitShift64LT)
->Unit(benchmark::kMicrosecond)
->Args({1024})
->Args({4096})
Expand All @@ -309,7 +309,6 @@ static void BM_EltwiseReduceModInOutMontFormAVX512BitShift52LT(
uint64_t modulus = 67280421310725ULL;

auto input_a = GenerateInsecureUniformRandomValues(input_size, 0, modulus);
AlignedVector64<uint64_t> input_b(input_size, 42006526039321);

int r = 46; // R^2 mod N = 42006526039321
const uint64_t R2_mod_q = 42006526039321;
Expand All @@ -318,10 +317,10 @@ static void BM_EltwiseReduceModInOutMontFormAVX512BitShift52LT(
AlignedVector64<uint64_t> output(input_size, 0);

for (auto _ : state) {
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q,
input_size, modulus, inv_mod);
EltwiseMontgomeryFormAVX512<52, 46>(output.data(), output.data(), 1ULL,
input_size, modulus, inv_mod);
EltwiseMontgomeryFormInAVX512<52, 46>(
output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod);
EltwiseMontgomeryFormOutAVX512<52, 46>(output.data(), output.data(),
input_size, modulus, inv_mod);
}
}

Expand Down
97 changes: 85 additions & 12 deletions hexl/eltwise/eltwise-reduce-mod-avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,15 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand,
/// @param[in] a input vector. T = ab in the range [0, Rq − 1].
/// @param[in] b input vector.
/// @param[in] modulus such that gcd(R, modulus) = 1.
/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R,
/// @param[in] n number of elements in input vector.
/// @param[out] result unsigned long int vector in the range [0, q − 1] such
/// that S ≡ TR^−1 mod q
template <int BitShift, int r>
void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
const uint64_t* b, uint64_t n, uint64_t modulus,
uint64_t inv_mod) {
uint64_t neg_inv_mod) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
HEXL_CHECK(b != nullptr, "Require operand b != nullptr");
HEXL_CHECK(n != 0, "Require n != 0");
Expand All @@ -169,6 +170,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
uint64_t mod_R_mask = R - 1;
uint64_t prod_rs;
if (BitShift == 64) {
HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow");
prod_rs = (1ULL << 63) - 1;
} else {
prod_rs = (1ULL << (52 - r));
Expand All @@ -183,7 +185,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
uint64_t T_lo;
MultiplyUInt64(a[i], b[i], &T_hi, &T_lo);
result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
inv_mod);
neg_inv_mod);
}
a += n_mod_8;
b += n_mod_8;
Expand All @@ -195,7 +197,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
const __m512i* v_b = reinterpret_cast<const __m512i*>(b);
__m512i* v_result = reinterpret_cast<__m512i*>(result);
__m512i v_modulus = _mm512_set1_epi64(modulus);
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
__m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod);
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);

for (size_t i = 0; i < n_tmp; i += 8) {
Expand All @@ -204,6 +206,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
__m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b_op);
__m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b_op);

// Convert to 63 bits to save intermediate carry
if (BitShift == 64) {
v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
__m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
Expand All @@ -212,7 +215,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
}

__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs);
HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
"v_op exceeds bound " << modulus);
_mm512_storeu_si512(v_result, v_c);
Expand All @@ -230,14 +233,15 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
/// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1].
/// @param[in] R2_mod_q R^2 mod q.
/// @param[in] modulus such that gcd(R, modulus) = 1.
/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R,
/// @param[in] n number of elements in input vector.
/// @param[out] result unsigned long int vector in the range [0, q − 1] such
/// that S ≡ TR^−1 mod q
template <int BitShift, int r>
void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
uint64_t R2_mod_q, uint64_t n,
uint64_t modulus, uint64_t inv_mod) {
void EltwiseMontgomeryFormInAVX512(uint64_t* result, const uint64_t* a,
uint64_t R2_mod_q, uint64_t n,
uint64_t modulus, uint64_t neg_inv_mod) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
HEXL_CHECK(n != 0, "Require n != 0");
HEXL_CHECK(modulus > 1, "Require modulus > 1");
Expand All @@ -251,6 +255,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
uint64_t mod_R_mask = R - 1;
uint64_t prod_rs;
if (BitShift == 64) {
HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow");
prod_rs = (1ULL << 63) - 1;
} else {
prod_rs = (1ULL << (52 - r));
Expand All @@ -265,7 +270,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
uint64_t T_lo;
MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo);
result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
inv_mod);
neg_inv_mod);
}
a += n_mod_8;
result += n_mod_8;
Expand All @@ -276,14 +281,15 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
__m512i* v_result = reinterpret_cast<__m512i*>(result);
__m512i v_b = _mm512_set1_epi64(R2_mod_q);
__m512i v_modulus = _mm512_set1_epi64(modulus);
__m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
__m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod);
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);

for (size_t i = 0; i < n_tmp; i += 8) {
__m512i v_a_op = _mm512_loadu_si512(v_a);
__m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b);
__m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b);

// Convert to 63 bits to save intermediate carry
if (BitShift == 64) {
v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
__m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
Expand All @@ -292,7 +298,74 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
}

__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs);
HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
"v_op exceeds bound " << modulus);
_mm512_storeu_si512(v_result, v_c);
++v_a;
++v_result;
}
}

/// @brief Convert out of the Montgomery Form computed via the REDC algorithm,
/// also known as Montgomery reduction.
/// @tparam BitShift denotes the operational length, in bits, of the operands
/// and result values.
/// @tparam r defines the value of R, being R = 2^r. R > modulus.
/// @param[in] a input vector in Montgomery Form.
/// @param[in] modulus such that gcd(R, modulus) = 1.
/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R,
/// @param[in] n number of elements in input vector.
/// @param[out] result unsigned long int vector in the range [0, q − 1] such
/// that S ≡ TR^−1 mod q
template <int BitShift, int r>
void EltwiseMontgomeryFormOutAVX512(uint64_t* result, const uint64_t* a,
uint64_t n, uint64_t modulus,
uint64_t neg_inv_mod) {
HEXL_CHECK(result != nullptr, "Require result != nullptr");
HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
HEXL_CHECK(n != 0, "Require n != 0");
HEXL_CHECK(modulus > 1, "Require modulus > 1");

uint64_t R = (1ULL << r);
HEXL_CHECK(std::__gcd(static_cast<int64_t>(modulus), static_cast<int64_t>(R)),
1);
HEXL_CHECK(R > modulus, "Needs R bigger than q.");

// mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
uint64_t mod_R_mask = R - 1;
uint64_t prod_rs;
if (BitShift == 64) {
HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow");
prod_rs = (1ULL << 63) - 1;
} else {
prod_rs = (1ULL << (52 - r));
}
uint64_t n_tmp = n;

// Deals with n not divisible by 8
uint64_t n_mod_8 = n_tmp % 8;
if (n_mod_8 != 0) {
for (size_t i = 0; i < n_mod_8; ++i) {
result[i] = MontgomeryReduce<BitShift>(0, a[i], modulus, r, mod_R_mask,
neg_inv_mod);
}
a += n_mod_8;
result += n_mod_8;
n_tmp -= n_mod_8;
}

const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
__m512i* v_result = reinterpret_cast<__m512i*>(result);
__m512i v_modulus = _mm512_set1_epi64(modulus);
__m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod);
__m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
__m512i v_T_hi = _mm512_set1_epi64(0);

for (size_t i = 0; i < n_tmp; i += 8) {
__m512i v_T_lo = _mm512_loadu_si512(v_a);
__m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs);
HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
"v_op exceeds bound " << modulus);
_mm512_storeu_si512(v_result, v_c);
Expand Down
3 changes: 1 addition & 2 deletions hexl/util/avx512-util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,7 @@ inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
if (BitShift == 52) {
// Operation:
// m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
__m512i m = ClearTopBits64<r>(T_lo);
m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
__m512i m = _mm512_hexl_mullo_epi<BitShift>(T_lo, v_inv_mod);
m = ClearTopBits64<r>(m);

// Operation: t ← (T + mN) / R = (T + m*q) >> r
Expand Down
Loading

0 comments on commit 2571f20

Please sign in to comment.