Skip to content

Commit

Permalink
Fboemer/reference inv ntt (intel#89)
Browse files Browse the repository at this point in the history
* Add reference radix-2 Inv NTT
  • Loading branch information
fboemer authored Nov 1, 2021
1 parent eecc3bf commit 8a976dd
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
15 changes: 14 additions & 1 deletion hexl/ntt/ntt-internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ void ForwardTransformToBitReverseRadix4(
const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1,
uint64_t output_mod_factor = 1);

/// @brief Reference NTT which is written for clarity rather than performance
/// @brief Reference forward NTT which is written for clarity rather than
/// performance
/// @param[in, out] operand Input data. Overwritten with NTT output
/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
/// power of two.
Expand All @@ -68,6 +69,18 @@ void ReferenceForwardTransformToBitReverse(
uint64_t* operand, uint64_t n, uint64_t modulus,
const uint64_t* root_of_unity_powers);

/// @brief Reference inverse NTT which is written for clarity rather than
/// performance
/// @param[in, out] operand Input data. Overwritten with NTT output
/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
/// power of two.
/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n
/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in
/// F_q. In bit-reversed order.
void ReferenceInverseTransformFromBitReverse(
uint64_t* operand, uint64_t n, uint64_t modulus,
const uint64_t* inv_root_of_unity_powers);

/// @brief Radix-2 native C++ NTT implementation of the inverse NTT
/// @param[out] result Output data. Overwritten with NTT output
/// @param[in] operand Input data.
Expand Down
37 changes: 37 additions & 0 deletions hexl/ntt/ntt-radix-2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,43 @@ void ReferenceForwardTransformToBitReverse(
}
}

void ReferenceInverseTransformFromBitReverse(
uint64_t* operand, uint64_t n, uint64_t modulus,
const uint64_t* inv_root_of_unity_powers) {
HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
HEXL_CHECK(inv_root_of_unity_powers != nullptr,
"inv_root_of_unity_powers == nullptr");
HEXL_CHECK(operand != nullptr, "operand == nullptr");

size_t t = 1;
size_t root_index = 1;
for (size_t m = (n >> 1); m >= 1; m >>= 1) {
size_t j1 = 0;
for (size_t i = 0; i < m; i++, root_index++) {
const uint64_t W = inv_root_of_unity_powers[root_index];
uint64_t* X_r = operand + j1;
uint64_t* Y_r = X_r + t;
for (size_t j = 0; j < t; j++) {
uint64_t X_op = *X_r;
uint64_t Y_op = *Y_r;
// Butterfly X' = (X + Y) mod q, Y' = W(X-Y) mod q
*X_r = AddUIntMod(X_op, Y_op, modulus);
*Y_r = MultiplyMod(W, SubUIntMod(X_op, Y_op, modulus), modulus);
X_r++;
Y_r++;
}
j1 += (t << 1);
}
t <<= 1;
}

// Final multiplication by N^{-1}
const uint64_t inv_n = InverseMod(n, modulus);
for (size_t i = 0; i < n; ++i) {
operand[i] = MultiplyMod(operand[i], inv_n, modulus);
}
}

void InverseTransformFromBitReverseRadix2(
uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
const uint64_t* inv_root_of_unity_powers,
Expand Down
19 changes: 19 additions & 0 deletions test/test-ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,9 @@ TEST_P(DegreeModulusInputOutput, API) {
ReferenceForwardTransformToBitReverse(input.data(), N, modulus,
ntt.GetRootOfUnityPowers().data());
AssertEqual(input, exp_output);
ReferenceInverseTransformFromBitReverse(input.data(), N, modulus,
ntt.GetInvRootOfUnityPowers().data());
AssertEqual(input, input_copy);

// Test round-trip
input = input_copy;
Expand Down Expand Up @@ -448,6 +451,22 @@ TEST_P(NttNativeTest, InverseRadix4Random) {
AssertEqual(input, input_radix4);
}

TEST_P(NttNativeTest, InverseRadix2Random) {
auto input = GenerateInsecureUniformRandomValues(m_N, 1, 2);
auto input_reference = input;

InverseTransformFromBitReverseRadix2(
input.data(), input.data(), m_N, m_modulus,
m_ntt.GetInvRootOfUnityPowers().data(),
m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1);

ReferenceInverseTransformFromBitReverse(
input_reference.data(), m_N, m_modulus,
m_ntt.GetInvRootOfUnityPowers().data());

AssertEqual(input, input_reference);
}

INSTANTIATE_TEST_SUITE_P(
NTT, NttNativeTest,
::testing::Combine(
Expand Down

0 comments on commit 8a976dd

Please sign in to comment.