ntt_parameters_test.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. /*
  2. * Copyright 2017 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
#include "ntt_parameters.h"

#include <cstddef>
#include <cstdint>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/numeric/int128.h"
#include "absl/strings/str_cat.h"
#include "constants.h"
#include "montgomery.h"
#include "status_macros.h"
#include "testing/parameters.h"
#include "testing/status_matchers.h"
#include "testing/status_testing.h"
  27. namespace {
  28. using ::rlwe::testing::StatusIs;
  29. using ::testing::HasSubstr;
  30. template <typename ModularInt>
  31. class NttParametersTest : public testing::Test {};
  32. TYPED_TEST_SUITE(NttParametersTest, rlwe::testing::ModularIntTypes);
  33. TYPED_TEST(NttParametersTest, LogNumCoeffsTooLarge) {
  34. for (const auto& params :
  35. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  36. // Do not create a context, since it creates NttParameters already. Instead,
  37. // create the modulus parameters manually.
  38. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  39. TypeParam::Params::Create(params.modulus));
  40. int log_n = rlwe::kMaxLogNumCoeffs + 1;
  41. EXPECT_THAT(
  42. rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()),
  43. StatusIs(
  44. ::absl::StatusCode::kInvalidArgument,
  45. HasSubstr(absl::StrCat("log_n, ", log_n, ", must be less than ",
  46. rlwe::kMaxLogNumCoeffs, "."))));
  47. log_n = (sizeof(typename TypeParam::Int) * 8) - 1;
  48. if (log_n <= rlwe::kMaxLogNumCoeffs) {
  49. EXPECT_THAT(
  50. rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()),
  51. StatusIs(
  52. ::absl::StatusCode::kInvalidArgument,
  53. HasSubstr(absl::StrCat(
  54. "log_n, ", log_n,
  55. ", does not fit into underlying ModularInt::Int type."))));
  56. }
  57. }
  58. }
  59. TYPED_TEST(NttParametersTest, PrimitiveNthRootOfUnity) {
  60. unsigned int log_ns[] = {2u, 4u, 6u, 8u, 11u};
  61. unsigned int len = 5;
  62. for (const auto& params :
  63. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  64. // Do not create a context, since it creates NttParameters already. Instead,
  65. // create the modulus parameters manually.
  66. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  67. TypeParam::Params::Create(params.modulus));
  68. for (unsigned int i = 0; i < len; i++) {
  69. ASSERT_OK_AND_ASSIGN(TypeParam w,
  70. rlwe::internal::PrimitiveNthRootOfUnity<TypeParam>(
  71. log_ns[i], modulus_params.get()));
  72. unsigned int n = 1 << log_ns[i];
  73. // Ensure it is really a n-th root of unity.
  74. auto res = w.ModExp(n, modulus_params.get());
  75. auto one = TypeParam::ImportOne(modulus_params.get());
  76. EXPECT_EQ(res, one) << "Not an n-th root of unity.";
  77. // Ensure it is really a primitive n-th root of unity.
  78. auto res2 = w.ModExp(n / 2, modulus_params.get());
  79. EXPECT_NE(res2, one) << "Not a primitive n-th root of unity.";
  80. }
  81. }
  82. }
  83. TYPED_TEST(NttParametersTest, NttPsis) {
  84. for (const auto& params :
  85. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  86. // Do not create a context, since it creates NttParameters already. Instead,
  87. // create the modulus parameters manually.
  88. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  89. TypeParam::Params::Create(params.modulus));
  90. const size_t n = 1 << params.log_n;
  91. // The values of psi should be the powers of the primitive 2n-th root of
  92. // unity.
  93. // Obtain the psis.
  94. ASSERT_OK_AND_ASSIGN(
  95. std::vector<TypeParam> psis,
  96. rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
  97. // Verify that that the 0th entry is 1.
  98. TypeParam one = TypeParam::ImportOne(modulus_params.get());
  99. EXPECT_EQ(one, psis[0]);
  100. // Verify that the 1th entry is a primitive 2n-th root of unity.
  101. auto r1 = psis[1].ModExp(2 * n, modulus_params.get());
  102. auto r2 = psis[1].ModExp(n, modulus_params.get());
  103. EXPECT_EQ(one, r1);
  104. EXPECT_NE(one, r2);
  105. // Verify that each subsequent entry is the appropriate power of the 1th
  106. // entry.
  107. for (unsigned int i = 2; i < n; i++) {
  108. auto ri = psis[1].ModExp(i, modulus_params.get());
  109. EXPECT_EQ(psis[i], ri);
  110. }
  111. }
  112. }
  113. TYPED_TEST(NttParametersTest, NttPsisBitrev) {
  114. for (const auto& params :
  115. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  116. // Do not create a context, since it creates NttParameters already. Instead,
  117. // create the modulus parameters manually.
  118. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  119. TypeParam::Params::Create(params.modulus));
  120. const size_t n = 1 << params.log_n;
  121. // The values of psi should be bitreversed.
  122. // Target vector: obtain the psis in bitreversed order.
  123. ASSERT_OK_AND_ASSIGN(
  124. std::vector<TypeParam> psis_bitrev,
  125. rlwe::NttPsisBitrev<TypeParam>(params.log_n, modulus_params.get()));
  126. // Obtain the psis.
  127. ASSERT_OK_AND_ASSIGN(
  128. std::vector<TypeParam> psis,
  129. rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
  130. // Obtain the mapping for bitreversed order
  131. std::vector<unsigned int> bit_rev =
  132. rlwe::internal::BitrevArray(params.log_n);
  133. for (unsigned int i = 0; i < n; i++) {
  134. EXPECT_EQ(psis_bitrev[i], psis[bit_rev[i]]);
  135. }
  136. }
  137. }
  138. TYPED_TEST(NttParametersTest, NttPsisInvBitrev) {
  139. for (const auto& params :
  140. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  141. // Do not create a context, since it creates NttParameters already. Instead,
  142. // create the modulus parameters manually.
  143. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  144. TypeParam::Params::Create(params.modulus));
  145. const size_t n = 1 << params.log_n;
  146. // The values of the vectors should be psi^(-(brv[k]+1) for all k.
  147. // Target vector: obtain the psi inv in bit reversed order.
  148. ASSERT_OK_AND_ASSIGN(
  149. std::vector<TypeParam> psis_inv_bitrev,
  150. rlwe::NttPsisInvBitrev<TypeParam>(params.log_n, modulus_params.get()));
  151. // Obtain the psis.
  152. ASSERT_OK_AND_ASSIGN(
  153. std::vector<TypeParam> psis,
  154. rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get()));
  155. // Obtain the mapping for bitreversed order
  156. std::vector<unsigned int> bit_rev =
  157. rlwe::internal::BitrevArray(params.log_n);
  158. for (unsigned int i = 0; i < n; i++) {
  159. EXPECT_EQ(modulus_params->One(),
  160. psis_inv_bitrev[i]
  161. .Mul(psis[1], modulus_params.get())
  162. .Mul(psis[bit_rev[i]], modulus_params.get())
  163. .ExportInt(modulus_params.get()));
  164. }
  165. }
  166. }
  167. TEST(NttParametersRegularTest, Bitrev) {
  168. for (unsigned int log_N = 2; log_N < 11; log_N++) {
  169. unsigned int N = 1 << log_N;
  170. std::vector<unsigned int> bit_rev = rlwe::internal::BitrevArray(log_N);
  171. // Visit each entry of the array.
  172. for (unsigned int i = 0; i < N; i++) {
  173. for (unsigned int j = 0; j < log_N; j++) {
  174. // Ensure bit j of i is equal to bit (log_N - j) of bit_rev[i].
  175. rlwe::Uint64 mask1 = 1 << j;
  176. rlwe::Uint64 mask2 = 1 << (log_N - j - 1);
  177. EXPECT_EQ((i & mask1) == 0, (bit_rev[i] & mask2) == 0);
  178. }
  179. }
  180. }
  181. }
  182. TYPED_TEST(NttParametersTest, IncorrectNTTParams) {
  183. for (const auto& params :
  184. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  185. // Do not create a context, since it creates NttParameters already. Instead,
  186. // create the modulus parameters manually.
  187. // modulus + 2, will no longer be 1 mod 2*n
  188. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  189. TypeParam::Params::Create(params.modulus + 2));
  190. EXPECT_THAT(
  191. rlwe::InitializeNttParameters<TypeParam>(params.log_n,
  192. modulus_params.get()),
  193. StatusIs(::absl::StatusCode::kInvalidArgument,
  194. HasSubstr(absl::StrCat("modulus is not 1 mod 2n for logn, ",
  195. params.log_n))));
  196. }
  197. }
  198. // Test all the NTT Parameter fields.
  199. TYPED_TEST(NttParametersTest, Initialize) {
  200. for (const auto& params :
  201. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  202. // Do not create a context, since it creates NttParameters already. Instead,
  203. // create the modulus parameters manually.
  204. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  205. TypeParam::Params::Create(params.modulus));
  206. const size_t n = 1 << params.log_n;
  207. ASSERT_OK_AND_ASSIGN(rlwe::NttParameters<TypeParam> ntt_params,
  208. rlwe::InitializeNttParameters<TypeParam>(
  209. params.log_n, modulus_params.get()));
  210. TypeParam one = TypeParam::ImportOne(modulus_params.get());
  211. // Obtain the mapping for bitreversed order
  212. std::vector<unsigned int> bit_rev =
  213. rlwe::internal::BitrevArray(params.log_n);
  214. // Test first entry of psis in bitreversed order is one.
  215. EXPECT_EQ(one, ntt_params.psis_bitrev[0]);
  216. // Test n/2-th (brv[1]-th) entry of psis in bitreversed order is a primitive
  217. // 2n-th root of unity.
  218. auto psi = ntt_params.psis_bitrev[bit_rev[1]];
  219. auto r1 = psi.ModExp(2 * n, modulus_params.get());
  220. auto r2 = psi.ModExp(n, modulus_params.get());
  221. EXPECT_EQ(one, r1);
  222. EXPECT_NE(one, r2);
  223. // The values of psis should be the powers of the primitive 2n-th root of
  224. // unity in bitreversed order.
  225. for (unsigned int i = 0; i < n; i++) {
  226. auto bi = psi.ModExp(i, modulus_params.get());
  227. EXPECT_EQ(ntt_params.psis_bitrev[bit_rev[i]], bi);
  228. }
  229. // Test psis_inv_bitrev contains the inverses of the powers of psi in
  230. // bitreversed order, each multiplied by the inverse of psi.
  231. for (unsigned int i = 0; i < n; i++) {
  232. EXPECT_EQ(modulus_params->One(),
  233. ntt_params.psis_bitrev[i]
  234. .Mul(psi, modulus_params.get())
  235. .Mul(ntt_params.psis_inv_bitrev[i], modulus_params.get())
  236. .ExportInt(modulus_params.get()));
  237. }
  238. }
  239. }
  240. } // namespace