From 8a976dd6de18277613f9c64f877bb3975348ba2f Mon Sep 17 00:00:00 2001 From: Fabian Boemer Date: Mon, 1 Nov 2021 07:20:41 -0700 Subject: [PATCH] Fboemer/reference inv ntt (#89) * Add reference radix-2 Inv NTT --- hexl/ntt/ntt-internal.hpp | 15 ++++++++++++++- hexl/ntt/ntt-radix-2.cpp | 37 +++++++++++++++++++++++++++++++++++++ test/test-ntt.cpp | 19 +++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/hexl/ntt/ntt-internal.hpp b/hexl/ntt/ntt-internal.hpp index b1449ce1..1f0e7fbc 100644 --- a/hexl/ntt/ntt-internal.hpp +++ b/hexl/ntt/ntt-internal.hpp @@ -57,7 +57,8 @@ void ForwardTransformToBitReverseRadix4( const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); -/// @brief Reference NTT which is written for clarity rather than performance +/// @brief Reference forward NTT which is written for clarity rather than +/// performance /// @param[in, out] operand Input data. Overwritten with NTT output /// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a /// power of two. @@ -68,6 +69,18 @@ void ReferenceForwardTransformToBitReverse( uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers); +/// @brief Reference inverse NTT which is written for clarity rather than +/// performance +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers); + /// @brief Radix-2 native C++ NTT implementation of the inverse NTT /// @param[out] result Output data. Overwritten with NTT output /// @param[in] operand Input data. diff --git a/hexl/ntt/ntt-radix-2.cpp b/hexl/ntt/ntt-radix-2.cpp index 1549740b..2e0786bb 100644 --- a/hexl/ntt/ntt-radix-2.cpp +++ b/hexl/ntt/ntt-radix-2.cpp @@ -290,6 +290,43 @@ void ReferenceForwardTransformToBitReverse( } } +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + size_t t = 1; + size_t root_index = 1; + for (size_t m = (n >> 1); m >= 1; m >>= 1) { + size_t j1 = 0; + for (size_t i = 0; i < m; i++, root_index++) { + const uint64_t W = inv_root_of_unity_powers[root_index]; + uint64_t* X_r = operand + j1; + uint64_t* Y_r = X_r + t; + for (size_t j = 0; j < t; j++) { + uint64_t X_op = *X_r; + uint64_t Y_op = *Y_r; + // Butterfly X' = (X + Y) mod q, Y' = W(X-Y) mod q + *X_r = AddUIntMod(X_op, Y_op, modulus); + *Y_r = MultiplyMod(W, SubUIntMod(X_op, Y_op, modulus), modulus); + X_r++; + Y_r++; + } + j1 += (t << 1); + } + t <<= 1; + } + + // Final multiplication by N^{-1} + const uint64_t inv_n = InverseMod(n, modulus); + for (size_t i = 0; i < n; ++i) { + operand[i] = MultiplyMod(operand[i], inv_n, modulus); + } +} + void InverseTransformFromBitReverseRadix2( uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* inv_root_of_unity_powers, diff --git a/test/test-ntt.cpp b/test/test-ntt.cpp index 70543691..12b28541 100644 --- a/test/test-ntt.cpp +++ b/test/test-ntt.cpp @@ -255,6 +255,9 @@ TEST_P(DegreeModulusInputOutput, API) { ReferenceForwardTransformToBitReverse(input.data(), N, modulus, ntt.GetRootOfUnityPowers().data()); AssertEqual(input, exp_output); + ReferenceInverseTransformFromBitReverse(input.data(), N, modulus, + ntt.GetInvRootOfUnityPowers().data()); + AssertEqual(input, input_copy); // Test round-trip input = input_copy; @@ -448,6 +451,22 @@ TEST_P(NttNativeTest, InverseRadix4Random) { AssertEqual(input, input_radix4); } +TEST_P(NttNativeTest, InverseRadix2Random) { + auto input = GenerateInsecureUniformRandomValues(m_N, 1, 2); + auto input_reference = input; + + InverseTransformFromBitReverseRadix2( + input.data(), input.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data(), + m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); + + ReferenceInverseTransformFromBitReverse( + input_reference.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data()); + + AssertEqual(input, input_reference); +} + INSTANTIATE_TEST_SUITE_P( NTT, NttNativeTest, ::testing::Combine(