diff --git a/benchmark/bench-eltwise-mult-mod.cpp b/benchmark/bench-eltwise-mult-mod.cpp index a88ff000..c15bee2b 100644 --- a/benchmark/bench-eltwise-mult-mod.cpp +++ b/benchmark/bench-eltwise-mult-mod.cpp @@ -7,6 +7,7 @@ #include "eltwise/eltwise-mult-mod-avx512.hpp" #include "eltwise/eltwise-mult-mod-internal.hpp" +#include "eltwise/eltwise-reduce-mod-avx512.hpp" #include "hexl/eltwise/eltwise-mult-mod.hpp" #include "hexl/logging/logging.hpp" #include "hexl/number-theory/number-theory.hpp" @@ -145,8 +146,10 @@ static void BM_EltwiseMultModAVX512IFMAInt( size_t input_mod_factor = state.range(1); size_t modulus = 100; - auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, modulus); - auto input2 = GenerateInsecureUniformRandomValues(input_size, 0, modulus); + auto input1 = GenerateInsecureUniformRandomValues(input_size, 0, + input_mod_factor * modulus); + auto input2 = GenerateInsecureUniformRandomValues(input_size, 0, + input_mod_factor * modulus); AlignedVector64 output(input_size, 3); for (auto _ : state) { @@ -174,5 +177,50 @@ BENCHMARK(BM_EltwiseMultModAVX512IFMAInt) //================================================================= +#ifdef HEXL_HAS_AVX512IFMA + +// state[0] is the degree +// state[1] is the input_mod_factor +static void BM_EltwiseMultModMontAVX512IFMAIntEConv( + benchmark::State& state) { // NOLINT + + size_t input_size = state.range(0); + size_t input_mod_factor = state.range(1); + uint64_t modulus = (1ULL << 50) + 7; // 1125899906842631 + auto op1 = GenerateInsecureUniformRandomValues(input_size, 0, + input_mod_factor * modulus); + auto op2 = GenerateInsecureUniformRandomValues(input_size, 0, + input_mod_factor * modulus); + AlignedVector64 output(input_size, 3); + + int r = 51; // R = 2251799813685248 + // mod(2251799813685248*2251799813685248;1125899906842631) + uint64_t R_reduced = ReduceMod<2>(1ULL << r, modulus); + const uint64_t R_square_mod_q = MultiplyMod(R_reduced, R_reduced, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); + + for (auto _ : state) { + if (input_mod_factor != 1) { + EltwiseReduceModAVX512(op1.data(), op1.data(), input_size, modulus, + input_mod_factor, 1); + EltwiseReduceModAVX512(op2.data(), op2.data(), input_size, modulus, + input_mod_factor, 1); + } + EltwiseMontgomeryFormInAVX512<52, 51>(output.data(), op1.data(), + R_square_mod_q, input_size, modulus, + neg_inv_mod); + EltwiseMontReduceModAVX512<52, 51>(output.data(), output.data(), op2.data(), + input_size, modulus, neg_inv_mod); + } +} + +BENCHMARK(BM_EltwiseMultModMontAVX512IFMAIntEConv) + ->Unit(benchmark::kMicrosecond) + ->ArgsProduct({{1024, 4096, 16384}, {1, 2, 4}}); + +#endif + +//================================================================= + } // namespace hexl } // namespace intel diff --git a/benchmark/bench-eltwise-reduce-mod.cpp b/benchmark/bench-eltwise-reduce-mod.cpp index c97570e6..8de6234b 100644 --- a/benchmark/bench-eltwise-reduce-mod.cpp +++ b/benchmark/bench-eltwise-reduce-mod.cpp @@ -251,7 +251,7 @@ BENCHMARK(BM_EltwiseReduceModMontAVX512BitShift52LT) ->Args({4096}) ->Args({16384}); -static void BM_EltwiseReduceModMontFormAVX512BitShift52LT( +static void BM_EltwiseReduceModMontFormInAVX512BitShift52LT( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); uint64_t modulus = 67280421310725ULL; @@ -266,18 +266,18 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift52LT( AlignedVector64 output(input_size, 0); for (auto _ : state) { - EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q, - input_size, modulus, inv_mod); + EltwiseMontgomeryFormInAVX512<52, 46>( + output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod); } } -BENCHMARK(BM_EltwiseReduceModMontFormAVX512BitShift52LT) +BENCHMARK(BM_EltwiseReduceModMontFormInAVX512BitShift52LT) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) ->Args({16384}); -static void BM_EltwiseReduceModMontFormAVX512BitShift64LT( +static void BM_EltwiseReduceModMontFormInAVX512BitShift64LT( benchmark::State& state) { // NOLINT size_t input_size = state.range(0); uint64_t modulus = 67280421310725ULL; @@ -292,12 +292,12 @@ static void BM_EltwiseReduceModMontFormAVX512BitShift64LT( AlignedVector64 output(input_size, 0); for (auto _ : state) { - EltwiseMontgomeryFormAVX512<64, 46>(output.data(), input_a.data(), R2_mod_q, - input_size, modulus, inv_mod); + EltwiseMontgomeryFormInAVX512<64, 46>( + output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod); } } -BENCHMARK(BM_EltwiseReduceModMontFormAVX512BitShift64LT) +BENCHMARK(BM_EltwiseReduceModMontFormInAVX512BitShift64LT) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) @@ -309,7 +309,6 @@ static void BM_EltwiseReduceModInOutMontFormAVX512BitShift52LT( uint64_t modulus = 67280421310725ULL; auto input_a = GenerateInsecureUniformRandomValues(input_size, 0, modulus); - AlignedVector64 input_b(input_size, 42006526039321); int r = 46; // R^2 mod N = 42006526039321 const uint64_t R2_mod_q = 42006526039321; @@ -318,10 +317,10 @@ static void BM_EltwiseReduceModInOutMontFormAVX512BitShift52LT( AlignedVector64 output(input_size, 0); for (auto _ : state) { - EltwiseMontgomeryFormAVX512<52, 46>(output.data(), input_a.data(), R2_mod_q, - input_size, modulus, inv_mod); - EltwiseMontgomeryFormAVX512<52, 46>(output.data(), output.data(), 1ULL, - input_size, modulus, inv_mod); + EltwiseMontgomeryFormInAVX512<52, 46>( + output.data(), input_a.data(), R2_mod_q, input_size, modulus, inv_mod); + EltwiseMontgomeryFormOutAVX512<52, 46>(output.data(), output.data(), + input_size, modulus, inv_mod); } } diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp index 66361e6e..a2f2e0d4 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp @@ -147,14 +147,15 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, /// @param[in] a input vector. T = ab in the range [0, Rq − 1]. /// @param[in] b input vector. /// @param[in] modulus such that gcd(R, modulus) = 1. -/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R, +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, /// @param[in] n number of elements in input vector. /// @param[out] result unsigned long int vector in the range [0, q − 1] such /// that S ≡ TR^−1 mod q template void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, const uint64_t* b, uint64_t n, uint64_t modulus, - uint64_t inv_mod) { + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); HEXL_CHECK(b != nullptr, "Require operand b != nullptr"); HEXL_CHECK(n != 0, "Require n != 0"); @@ -169,6 +170,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, uint64_t mod_R_mask = R - 1; uint64_t prod_rs; if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); prod_rs = (1ULL << 63) - 1; } else { prod_rs = (1ULL << (52 - r)); @@ -183,7 +185,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, uint64_t T_lo; MultiplyUInt64(a[i], b[i], &T_hi, &T_lo); result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, - inv_mod); + neg_inv_mod); } a += n_mod_8; b += n_mod_8; @@ -195,7 +197,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, const __m512i* v_b = reinterpret_cast(b); __m512i* v_result = reinterpret_cast<__m512i*>(result); __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); for (size_t i = 0; i < n_tmp; i += 8) { @@ -204,6 +206,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b_op); __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b_op); + // Convert to 63 bits to save intermediate carry if (BitShift == 64) { v_T_hi = _mm512_slli_epi64(v_T_hi, 1); __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); @@ -212,7 +215,7 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, } __m512i v_c = _mm512_hexl_montgomery_reduce( - v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs); + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_c); @@ -230,14 +233,15 @@ void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, /// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1]. /// @param[in] R2_mod_q R^2 mod q. /// @param[in] modulus such that gcd(R, modulus) = 1. -/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R, +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, /// @param[in] n number of elements in input vector. /// @param[out] result unsigned long int vector in the range [0, q − 1] such /// that S ≡ TR^−1 mod q template -void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, - uint64_t R2_mod_q, uint64_t n, - uint64_t modulus, uint64_t inv_mod) { +void EltwiseMontgomeryFormInAVX512(uint64_t* result, const uint64_t* a, + uint64_t R2_mod_q, uint64_t n, + uint64_t modulus, uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); HEXL_CHECK(n != 0, "Require n != 0"); HEXL_CHECK(modulus > 1, "Require modulus > 1"); @@ -251,6 +255,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, uint64_t mod_R_mask = R - 1; uint64_t prod_rs; if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); prod_rs = (1ULL << 63) - 1; } else { prod_rs = (1ULL << (52 - r)); @@ -265,7 +270,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, uint64_t T_lo; MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo); result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, - inv_mod); + neg_inv_mod); } a += n_mod_8; result += n_mod_8; @@ -276,7 +281,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, __m512i* v_result = reinterpret_cast<__m512i*>(result); __m512i v_b = _mm512_set1_epi64(R2_mod_q); __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); for (size_t i = 0; i < n_tmp; i += 8) { @@ -284,6 +289,7 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b); __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b); + // Convert to 63 bits to save intermediate carry if (BitShift == 64) { v_T_hi = _mm512_slli_epi64(v_T_hi, 1); __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); @@ -292,7 +298,74 @@ void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a, } __m512i v_c = _mm512_hexl_montgomery_reduce( - v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs); + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_result; + } +} + +/// @brief Convert out of the Montgomery Form computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector in Montgomery Form. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. +/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontgomeryFormOutAVX512(uint64_t* result, const uint64_t* a, + uint64_t n, uint64_t modulus, + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::__gcd(static_cast(modulus), static_cast(R)), + 1); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + result[i] = MontgomeryReduce(0, a[i], modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + __m512i v_T_hi = _mm512_set1_epi64(0); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_T_lo = _mm512_loadu_si512(v_a); + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, "v_op exceeds bound " << modulus); _mm512_storeu_si512(v_result, v_c); diff --git a/hexl/util/avx512-util.hpp b/hexl/util/avx512-util.hpp index 3f3c54ad..cac8d83c 100644 --- a/hexl/util/avx512-util.hpp +++ b/hexl/util/avx512-util.hpp @@ -395,8 +395,7 @@ inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo, if (BitShift == 52) { // Operation: // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask - __m512i m = ClearTopBits64(T_lo); - m = _mm512_hexl_mullo_epi(m, v_inv_mod); + __m512i m = _mm512_hexl_mullo_epi(T_lo, v_inv_mod); m = ClearTopBits64(m); // Operation: t ← (T + mN) / R = (T + m*q) >> r diff --git a/test/test-avx512-util.cpp b/test/test-avx512-util.cpp index c437981c..02fd1ed4 100644 --- a/test/test-avx512-util.cpp +++ b/test/test-avx512-util.cpp @@ -380,20 +380,20 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) { uint64_t modulus = 5; int r = 3; uint64_t prod_rs = (1ULL << (52 - r)); - uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); __m512i _c = _mm512_hexl_montgomery_reduce<52, 3>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(_c, expected_out); // Out of Montgomery form - _c = _mm512_hexl_montgomery_reduce<52, 3>(T_hi, _c, v_modulus, v_inv_mod, - v_prod_rs); + _c = _mm512_hexl_montgomery_reduce<52, 3>(T_hi, _c, v_modulus, + v_neg_inv_mod, v_prod_rs); AssertEqual(_c, expected_c_out); } @@ -418,12 +418,12 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) { __m512i T_hi = _mm512_set_epi64(559639348720ULL, 0, 0, 0, 0, 0, 0, 0); __m512i T_lo = _mm512_set_epi64(1832906312477596ULL, 0, 0, 0, 0, 0, 0, 0); __m512i v_modulus = _mm512_set1_epi64(67280421310725); - __m512i v_inv_mod = _mm512_set1_epi64(62463730494515); + __m512i v_neg_inv_mod = _mm512_set1_epi64(62463730494515); __m512i v_prod_rs = _mm512_set1_epi64(64); // 52 bits __m512i c = _mm512_hexl_montgomery_reduce<52, 46>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(c, expected_out); } @@ -431,17 +431,17 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce52) { { int r = 51; uint64_t modulus = 2251799813684809; - uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); uint64_t prod_rs = (1ULL << (52 - r)); __m512i expected_out = _mm512_set_epi64(1832909426971103, 0, 0, 0, 0, 0, 0, 0); __m512i T_hi = _mm512_set_epi64(5446ULL, 0, 0, 0, 0, 0, 0, 0); __m512i T_lo = _mm512_set_epi64(3006504763740625ULL, 0, 0, 0, 0, 0, 0, 0); __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); __m512i c = _mm512_hexl_montgomery_reduce<52, 51>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(c, expected_out); } } @@ -459,7 +459,7 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) { __m512i T_hi = _mm512_set_epi64(559639348720ULL, 0, 0, 0, 0, 0, 0, 0); __m512i T_lo = _mm512_set_epi64(1832906312477596ULL, 0, 0, 0, 0, 0, 0, 0); __m512i v_modulus = _mm512_set1_epi64(67280421310725); - __m512i v_inv_mod = _mm512_set1_epi64(62463730494515); + __m512i v_neg_inv_mod = _mm512_set1_epi64(62463730494515); // 64 bits uint64_t prod_rs = (1ULL << 63) - 1; @@ -469,7 +469,7 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) { T_lo = _mm512_set_epi64(6847304339915631516, 0, 0, 0, 0, 0, 0, 0); __m512i c = _mm512_hexl_montgomery_reduce<64, 46>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(c, expected_out); } @@ -477,7 +477,7 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) { { int r = 61; uint64_t modulus = 2305843009213693487; - uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); uint64_t prod_rs = (1ULL << 63) - 1; __m512i expected_out = _mm512_set_epi64(59185395909485265, 0, 0, 0, 0, 0, 0, 0); @@ -485,10 +485,10 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) { __m512i T_lo = _mm512_set_epi64(9074465024201096609ULL, 0, 0, 0, 0, 0, 0, 0); __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); __m512i c = _mm512_hexl_montgomery_reduce<64, 61>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(c, expected_out); } @@ -496,17 +496,17 @@ TEST(AVX512, _mm512_hexl_montgomery_reduce64) { { int r = 62; uint64_t modulus = 4611686018427387631; - uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); uint64_t prod_rs = (1ULL << 63) - 1; __m512i expected_out = _mm512_set_epi64(34747555017826833, 0, 0, 0, 0, 0, 0, 0); __m512i T_hi = _mm512_set_epi64(1ULL, 0, 0, 0, 0, 0, 0, 0); __m512i T_lo = _mm512_set_epi64(262710483011949601ULL, 0, 0, 0, 0, 0, 0, 0); __m512i v_modulus = _mm512_set1_epi64(modulus); - __m512i v_inv_mod = _mm512_set1_epi64(inv_mod); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); __m512i c = _mm512_hexl_montgomery_reduce<64, 62>(T_hi, T_lo, v_modulus, - v_inv_mod, v_prod_rs); + v_neg_inv_mod, v_prod_rs); AssertEqual(c, expected_out); } } diff --git a/test/test-eltwise-mult-mod-avx512.cpp b/test/test-eltwise-mult-mod-avx512.cpp index df9b6455..9898c1f0 100644 --- a/test/test-eltwise-mult-mod-avx512.cpp +++ b/test/test-eltwise-mult-mod-avx512.cpp @@ -7,6 +7,7 @@ #include "eltwise/eltwise-mult-mod-avx512.hpp" #include "eltwise/eltwise-mult-mod-internal.hpp" +#include "eltwise/eltwise-reduce-mod-avx512.hpp" #include "hexl/eltwise/eltwise-mult-mod.hpp" #include "hexl/logging/logging.hpp" #include "hexl/number-theory/number-theory.hpp" @@ -205,6 +206,59 @@ TEST(EltwiseMultMod, avx512dqint_big) { } } } + +// Checks Montgomery and AVX512DQInt eltwise mult implementations match +TEST(EltwiseMultModMont_EConv, avx512dqint_big) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + size_t length = 1024; + std::vector rs1(length, 0); + std::vector rs2(length, 0); + + uint64_t modulus = (1ULL << 60) + 7; // 1152921504606846983 + auto op1 = GenerateInsecureUniformRandomValues(length, 0, modulus); + auto op2 = GenerateInsecureUniformRandomValues(length, 0, modulus); + + int r = 61; // R = 2305843009213693952 + // mod(2305843009213693952*2305843009213693952;1152921504606846983) + uint64_t R_reduced = ReduceMod<2>(1ULL << r, modulus); + const uint64_t R_square_mod_q = MultiplyMod(R_reduced, R_reduced, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); + + EltwiseMultModAVX512DQInt<1>(rs1.data(), op1.data(), op2.data(), op1.size(), + modulus); + EltwiseMontgomeryFormInAVX512<64, 61>(op1.data(), op1.data(), R_square_mod_q, + op1.size(), modulus, neg_inv_mod); + EltwiseMontReduceModAVX512<64, 61>(rs2.data(), op1.data(), op2.data(), + rs2.size(), modulus, neg_inv_mod); + ASSERT_EQ(rs2, rs1); +} + +// Checks Montgomery and AVX512DQInt eltwise mult implementations match +TEST(EltwiseMultModMont_NoConv, avx512dqint_big) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + size_t length = 1024; + std::vector rs1(length, 0); + std::vector rs2(length, 0); + + uint64_t modulus = 2305843009213693951; + auto op1 = GenerateInsecureUniformRandomValues(length, 0, modulus); + auto op2 = GenerateInsecureUniformRandomValues(length, 0, modulus); + + int r = 61; // R = 2305843009213693952 + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); + + EltwiseMultModAVX512DQInt<1>(rs1.data(), op1.data(), op2.data(), op1.size(), + modulus); + EltwiseMontReduceModAVX512<64, 61>(rs2.data(), op1.data(), op2.data(), + rs2.size(), modulus, neg_inv_mod); + ASSERT_EQ(rs2, rs1); +} #endif #ifdef HEXL_HAS_AVX512IFMA @@ -275,6 +329,36 @@ TEST(EltwiseMultMod, avx512ifma_big) { } } } + +// Checks Montgomery and AVX512ifmaInt eltwise mult implementations match +TEST(EltwiseMultModMont, avx512ifmaint_big) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + size_t length = 1024; + std::vector rs1(length, 0); + std::vector rs2(length, 0); + std::vector rs3(length, 0); + + uint64_t modulus = (1ULL << 49) + 7; // 562949953421319 + auto op1 = GenerateInsecureUniformRandomValues(length, 0, modulus); + auto op2 = GenerateInsecureUniformRandomValues(length, 0, modulus); + + int r = 50; // R = 1125899906842624 + // mod(1125899906842624*1125899906842624;562949953421319) + uint64_t R_reduced = ReduceMod<2>(1ULL << r, modulus); + const uint64_t R_square_mod_q = MultiplyMod(R_reduced, R_reduced, modulus); + uint64_t neg_inv_mod = HenselLemma2adicRoot(r, modulus); + + EltwiseMultModAVX512IFMAInt<1>(rs1.data(), op1.data(), op2.data(), op1.size(), + modulus); + EltwiseMontgomeryFormInAVX512<52, 50>(op1.data(), op1.data(), R_square_mod_q, + op1.size(), modulus, neg_inv_mod); + EltwiseMontReduceModAVX512<52, 50>(rs2.data(), op1.data(), op2.data(), + rs2.size(), modulus, neg_inv_mod); + ASSERT_EQ(rs2, rs1); +} + #endif } // namespace hexl diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index e8bfff53..6c8d069c 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -35,6 +35,35 @@ TEST(EltwiseReduceMod, avx512_64_mod_1) { CheckEqual(result, exp_out); } +TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + uint64_t modulus = 67280421310725ULL; + std::vector input_a{0, + 67280421310000, + 25040294381203, + 340231313, + 769231483400, + 90032324, + 120042353, + 1530}; + std::vector output{0, 0, 0, 0, 0, 0, 0, 0}; + + int r = 46; // R^2 mod N = 42006526039321 + uint64_t R_reduced = ReduceMod<2>(1ULL << r, modulus); + const uint64_t R_square_mod_q = MultiplyMod(R_reduced, R_reduced, modulus); + uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + + EltwiseMontgomeryFormInAVX512<64, 46>(output.data(), input_a.data(), + R_square_mod_q, input_a.size(), modulus, + inv_mod); + EltwiseMontgomeryFormOutAVX512<64, 46>(output.data(), output.data(), + input_a.size(), modulus, inv_mod); + CheckEqual(input_a, output); +} + #ifdef HEXL_HAS_AVX512IFMA TEST(EltwiseReduceMod, avx512_52_mod_1) { if (!has_avx512dq) { @@ -74,6 +103,36 @@ TEST(EltwiseReduceMod, avx512Big_mod_1) { input_mod_factor, output_mod_factor); CheckEqual(result, exp_out); } + +TEST(EltwiseReduceModMontInOut, avx512_52_mod_1) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + uint64_t modulus = 67280421310725ULL; + std::vector input_a{0, + 67280421310000, + 25040294381203, + 340231313, + 769231483400, + 90032324, + 120042353, + 1530}; + std::vector output{0, 0, 0, 0, 0, 0, 0, 0}; + + int r = 46; // R^2 mod N = 42006526039321 + uint64_t R_reduced = ReduceMod<2>(1ULL << r, modulus); + const uint64_t R_square_mod_q = MultiplyMod(R_reduced, R_reduced, modulus); + uint64_t inv_mod = HenselLemma2adicRoot(r, modulus); + + EltwiseMontgomeryFormInAVX512<52, 46>(output.data(), input_a.data(), + R_square_mod_q, input_a.size(), modulus, + inv_mod); + EltwiseMontgomeryFormOutAVX512<52, 46>(output.data(), output.data(), + input_a.size(), modulus, inv_mod); + CheckEqual(input_a, output); +} + #endif TEST(EltwiseReduceMod, avx512_2_1) {