123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277 |
- /*
- * Copyright 2017 Google LLC.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "ntt_parameters.h"
- #include <cstdint>
- #include <vector>
- #include <gmock/gmock.h>
- #include <gtest/gtest.h>
- #include "absl/numeric/int128.h"
- #include "constants.h"
- #include "montgomery.h"
- #include "status_macros.h"
- #include "testing/parameters.h"
- #include "testing/status_matchers.h"
- #include "testing/status_testing.h"
- namespace {
- using ::rlwe::testing::StatusIs;
- using ::testing::HasSubstr;
- template <typename ModularInt>
- class NttParametersTest : public testing::Test {};
- TYPED_TEST_SUITE(NttParametersTest, rlwe::testing::ModularIntTypes);
- TYPED_TEST(NttParametersTest, LogNumCoeffsTooLarge) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- int log_n = rlwe::kMaxLogNumCoeffs + 1;
- EXPECT_THAT(
- rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()),
- StatusIs(
- ::absl::StatusCode::kInvalidArgument,
- HasSubstr(absl::StrCat("log_n, ", log_n, ", must be less than ",
- rlwe::kMaxLogNumCoeffs, "."))));
- log_n = (sizeof(typename TypeParam::Int) * 8) - 1;
- if (log_n <= rlwe::kMaxLogNumCoeffs) {
- EXPECT_THAT(
- rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()),
- StatusIs(
- ::absl::StatusCode::kInvalidArgument,
- HasSubstr(absl::StrCat(
- "log_n, ", log_n,
- ", does not fit into underlying ModularInt::Int type."))));
- }
- }
- }
- TYPED_TEST(NttParametersTest, PrimitiveNthRootOfUnity) {
- unsigned int log_ns[] = {2u, 4u, 6u, 8u, 11u};
- unsigned int len = 5;
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- for (unsigned int i = 0; i < len; i++) {
- ASSERT_OK_AND_ASSIGN(TypeParam w,
- rlwe::internal::PrimitiveNthRootOfUnity<TypeParam>(
- log_ns[i], modulus_params.get()));
- unsigned int n = 1 << log_ns[i];
- // Ensure it is really a n-th root of unity.
- auto res = w.ModExp(n, modulus_params.get());
- auto one = TypeParam::ImportOne(modulus_params.get());
- EXPECT_EQ(res, one) << "Not an n-th root of unity.";
- // Ensure it is really a primitive n-th root of unity.
- auto res2 = w.ModExp(n / 2, modulus_params.get());
- EXPECT_NE(res2, one) << "Not a primitive n-th root of unity.";
- }
- }
- }
- TYPED_TEST(NttParametersTest, NttPsis) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- const size_t n = 1 << params.log_n;
- // The values of psi should be the powers of the primitive 2n-th root of
- // unity.
- // Obtain the psis.
- ASSERT_OK_AND_ASSIGN(
- std::vector<TypeParam> psis,
- rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
- // Verify that that the 0th entry is 1.
- TypeParam one = TypeParam::ImportOne(modulus_params.get());
- EXPECT_EQ(one, psis[0]);
- // Verify that the 1th entry is a primitive 2n-th root of unity.
- auto r1 = psis[1].ModExp(2 * n, modulus_params.get());
- auto r2 = psis[1].ModExp(n, modulus_params.get());
- EXPECT_EQ(one, r1);
- EXPECT_NE(one, r2);
- // Verify that each subsequent entry is the appropriate power of the 1th
- // entry.
- for (unsigned int i = 2; i < n; i++) {
- auto ri = psis[1].ModExp(i, modulus_params.get());
- EXPECT_EQ(psis[i], ri);
- }
- }
- }
- TYPED_TEST(NttParametersTest, NttPsisBitrev) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- const size_t n = 1 << params.log_n;
- // The values of psi should be bitreversed.
- // Target vector: obtain the psis in bitreversed order.
- ASSERT_OK_AND_ASSIGN(
- std::vector<TypeParam> psis_bitrev,
- rlwe::NttPsisBitrev<TypeParam>(params.log_n, modulus_params.get()));
- // Obtain the psis.
- ASSERT_OK_AND_ASSIGN(
- std::vector<TypeParam> psis,
- rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
- // Obtain the mapping for bitreversed order
- std::vector<unsigned int> bit_rev =
- rlwe::internal::BitrevArray(params.log_n);
- for (unsigned int i = 0; i < n; i++) {
- EXPECT_EQ(psis_bitrev[i], psis[bit_rev[i]]);
- }
- }
- }
- TYPED_TEST(NttParametersTest, NttPsisInvBitrev) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- const size_t n = 1 << params.log_n;
- // The values of the vectors should be psi^(-(brv[k]+1) for all k.
- // Target vector: obtain the psi inv in bit reversed order.
- ASSERT_OK_AND_ASSIGN(
- std::vector<TypeParam> psis_inv_bitrev,
- rlwe::NttPsisInvBitrev<TypeParam>(params.log_n, modulus_params.get()));
- // Obtain the psis.
- ASSERT_OK_AND_ASSIGN(
- std::vector<TypeParam> psis,
- rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
- // Obtain the mapping for bitreversed order
- std::vector<unsigned int> bit_rev =
- rlwe::internal::BitrevArray(params.log_n);
- for (unsigned int i = 0; i < n; i++) {
- EXPECT_EQ(modulus_params->One(),
- psis_inv_bitrev[i]
- .Mul(psis[1], modulus_params.get())
- .Mul(psis[bit_rev[i]], modulus_params.get())
- .ExportInt(modulus_params.get()));
- }
- }
- }
- TEST(NttParametersRegularTest, Bitrev) {
- for (unsigned int log_N = 2; log_N < 11; log_N++) {
- unsigned int N = 1 << log_N;
- std::vector<unsigned int> bit_rev = rlwe::internal::BitrevArray(log_N);
- // Visit each entry of the array.
- for (unsigned int i = 0; i < N; i++) {
- for (unsigned int j = 0; j < log_N; j++) {
- // Ensure bit j of i is equal to bit (log_N - j) of bit_rev[i].
- rlwe::Uint64 mask1 = 1 << j;
- rlwe::Uint64 mask2 = 1 << (log_N - j - 1);
- EXPECT_EQ((i & mask1) == 0, (bit_rev[i] & mask2) == 0);
- }
- }
- }
- }
- TYPED_TEST(NttParametersTest, IncorrectNTTParams) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- // modulus + 2, will no longer be 1 mod 2*n
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus + 2));
- EXPECT_THAT(
- rlwe::InitializeNttParameters<TypeParam>(params.log_n,
- modulus_params.get()),
- StatusIs(::absl::StatusCode::kInvalidArgument,
- HasSubstr(absl::StrCat("modulus is not 1 mod 2n for logn, ",
- params.log_n))));
- }
- }
- // Test all the NTT Parameter fields.
- TYPED_TEST(NttParametersTest, Initialize) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- // Do not create a context, since it creates NttParameters already. Instead,
- // create the modulus parameters manually.
- ASSERT_OK_AND_ASSIGN(auto modulus_params,
- TypeParam::Params::Create(params.modulus));
- const size_t n = 1 << params.log_n;
- ASSERT_OK_AND_ASSIGN(rlwe::NttParameters<TypeParam> ntt_params,
- rlwe::InitializeNttParameters<TypeParam>(
- params.log_n, modulus_params.get()));
- TypeParam one = TypeParam::ImportOne(modulus_params.get());
- // Obtain the mapping for bitreversed order
- std::vector<unsigned int> bit_rev =
- rlwe::internal::BitrevArray(params.log_n);
- // Test first entry of psis in bitreversed order is one.
- EXPECT_EQ(one, ntt_params.psis_bitrev[0]);
- // Test n/2-th (brv[1]-th) entry of psis in bitreversed order is a primitive
- // 2n-th root of unity.
- auto psi = ntt_params.psis_bitrev[bit_rev[1]];
- auto r1 = psi.ModExp(2 * n, modulus_params.get());
- auto r2 = psi.ModExp(n, modulus_params.get());
- EXPECT_EQ(one, r1);
- EXPECT_NE(one, r2);
- // The values of psis should be the powers of the primitive 2n-th root of
- // unity in bitreversed order.
- for (unsigned int i = 0; i < n; i++) {
- auto bi = psi.ModExp(i, modulus_params.get());
- EXPECT_EQ(ntt_params.psis_bitrev[bit_rev[i]], bi);
- }
- // Test psis_inv_bitrev contains the inverses of the powers of psi in
- // bitreversed order, each multiplied by the inverse of psi.
- for (unsigned int i = 0; i < n; i++) {
- EXPECT_EQ(modulus_params->One(),
- ntt_params.psis_bitrev[i]
- .Mul(psi, modulus_params.get())
- .Mul(ntt_params.psis_inv_bitrev[i], modulus_params.get())
- .ExportInt(modulus_params.get()));
- }
- }
- }
- } // namespace
|