ntt_parameters.h 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. /*
  2. * Copyright 2017 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #ifndef RLWE_NTT_PARAMETERS_H_
  16. #define RLWE_NTT_PARAMETERS_H_
  17. #include <algorithm>
  18. #include <cstdlib>
  19. #include <vector>
  20. #include "absl/memory/memory.h"
  21. #include "absl/strings/str_cat.h"
  22. #include "constants.h"
  23. #include "status_macros.h"
  24. #include "statusor.h"
  25. #include "third_party/shell-encryption/base/shell_encryption_export.h"
  26. namespace rlwe {
  27. namespace internal {
  // Adds the successive powers of base (mod modulus) into row:
  // (*row)[i] += base^i for every i in {0, 1, ..., n-1}.
  //
  // The in-file caller (NttPsis) passes a zero-initialized row, so the result
  // there is exactly base^0, base^1, ..., base^(n-1). row must already hold at
  // least n entries; it is not resized, and entries at index >= n are untouched.
  //
  // Each power is derived from the previous one with a single modular
  // multiplication instead of an independent ModExp per entry, so filling the
  // row costs n multiplications rather than O(n log n).
  template <typename ModularInt>
  void FillWithEveryPower(const ModularInt& base, unsigned int n,
                          std::vector<ModularInt>* row,
                          const typename ModularInt::Params* params) {
    ModularInt power = ModularInt::ImportOne(params);  // base^0
    for (unsigned int i = 0; i < n; ++i) {
      (*row)[i].AddInPlace(power, params);
      power.MulInPlace(base, params);  // advance to base^(i+1)
    }
  }
  37. template <typename ModularInt>
  38. rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
  39. unsigned int log_n, const typename ModularInt::Params* params) {
  40. typename ModularInt::Int n = params->One() << log_n;
  41. typename ModularInt::Int half_n = n >> 1;
  42. // When the modulus is prime, the value k is a power such that any number
  43. // raised to it will be a n-th root of unity. (It will not necessarily be a
  44. // *primitive* root of unity, however).
  45. typename ModularInt::Int k = (params->modulus - params->One()) / n;
  46. // Test each number t to see whether t^k is a primitive n-th root
  47. // of unity - that t^{nk} is a root of unity but t^{(n/2)k} is not.
  48. ModularInt one = ModularInt::ImportOne(params);
  49. for (typename ModularInt::Int t = params->Two(); t < params->modulus;
  50. t = t + params->One()) {
  51. // Produce a candidate root of unity.
  52. RLWE_ASSIGN_OR_RETURN(auto mt, ModularInt::ImportInt(t, params));
  53. ModularInt candidate = mt.ModExp(k, params);
  54. // Check whether candidate^half_n = 1. If not, it is a primitive root of
  55. // unity.
  56. if (candidate.ModExp(half_n, params) != one) {
  57. return candidate;
  58. }
  59. }
  60. // Failure state. The above loop should always return successfully assuming
  61. // the parameters were set properly.
  62. return absl::UnknownError("Loop in PrimitiveNthRootOfUnity terminated.");
  63. }
  64. // Let psi be a primitive 2n-th root of unity, i.e., a 2n-th root of unity such
  65. // that psi^n = -1. When performing the NTT transformation, the powers of psi in
  66. // bitreversed order are needed. The vector produced by this helper function
  67. // contains the powers of psi (psi^0, psi^1, psi^2, ..., psi^(n-1)).
  68. //
  69. // Each item of the vector is in modular integer representation.
  70. template <typename ModularInt>
  71. rlwe::StatusOr<std::vector<ModularInt>> NttPsis(
  72. unsigned int log_n, const typename ModularInt::Params* params) {
  73. // Obtain psi, a primitive 2n-th root of unity (hence log_n + 1).
  74. RLWE_ASSIGN_OR_RETURN(
  75. ModularInt psi,
  76. internal::PrimitiveNthRootOfUnity<ModularInt>(log_n + 1, params));
  77. unsigned int n = 1 << log_n;
  78. ModularInt zero = ModularInt::ImportZero(params);
  79. // Create a vector with the powers of psi.
  80. std::vector<ModularInt> row(n, zero);
  81. internal::FillWithEveryPower<ModularInt>(psi, n, &row, params);
  82. return row;
  83. }
  84. // Creates a vector containing the indices necessary to perform the NTT bit
  85. // reversal operation. Index i of the returned vector contains an integer with
  86. // the rightmost log_n bits of i reversed.
  87. SHELL_ENCRYPTION_EXPORT std::vector<unsigned int> BitrevArray(unsigned int log_n);
  // Applies the permutation bitrevs to *item_to_reverse in place, by swapping
  // element i with element bitrevs[i] for each pair with i < bitrevs[i].
  //
  // bitrevs must have at least item_to_reverse->size() entries, each a valid
  // index into *item_to_reverse. For the result to be the permuted vector,
  // bitrevs must be an involution (bitrevs[bitrevs[i]] == i), which a
  // bit-reversal permutation is.
  //
  // Deliberately not `static`. A function template may already be defined in
  // every translation unit. Internal linkage would make the external-linkage
  // templates in this header that call it refer to a different entity in each
  // TU, which is an ODR violation.
  template <typename ModularInt>
  void BitrevHelper(const std::vector<unsigned int>& bitrevs,
                    std::vector<ModularInt>* item_to_reverse) {
    using std::swap;
    std::vector<ModularInt>& items = *item_to_reverse;
    for (std::size_t i = 0; i < items.size(); ++i) {
      // Only swap in one direction so that each pair is exchanged exactly once.
      const std::size_t r = bitrevs[i];
      if (i < r) {
        swap(items[i], items[r]);
      }
    }
  }
  101. } // namespace internal
  102. // The precomputed roots of unity used during the forward NTT are the
  103. // bitreversed powers of the primitive 2n-th root of unity.
  104. template <typename ModularInt>
  105. rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
  106. unsigned int log_n, const typename ModularInt::Params* params) {
  107. // Retrieve the table for the forward transformation.
  108. RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> psis,
  109. internal::NttPsis<ModularInt>(log_n, params));
  110. // Bitreverse the vector.
  111. internal::BitrevHelper(internal::BitrevArray(log_n), &psis);
  112. return psis;
  113. }
  114. // The precomputed roots of unity used during the inverse NTT are the inverses
  115. // of the bitreversed powers of the primitive 2n-th root of unity plus 1.
  116. template <typename ModularInt>
  117. rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev(
  118. unsigned int log_n, const typename ModularInt::Params* params) {
  119. // Retrieve the table for the forward transformation.
  120. RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row,
  121. internal::NttPsis<ModularInt>(log_n, params));
  122. // Reverse the items at indices 1 through (n - 1). Multiplying index i
  123. // of the reversed row by index i of the original row will yield psi^n = -1.
  124. // (The exception is psi^0 = 1, which is already its own inverse.)
  125. std::reverse(row.begin() + 1, row.end());
  126. // Get the inverse of psi
  127. ModularInt psi_inv = row[1].Negate(params);
  128. ModularInt negative_psi_inv = row[1];
  129. // Bitreverse the vector.
  130. internal::BitrevHelper(internal::BitrevArray(log_n), &row);
  131. // Finally, multiply each of the items at indices 1 to (n-1) by -1. Multiply
  132. // every entry by psi_inv.
  133. row[0].MulInPlace(psi_inv, params);
  134. for (int i = 1; i < row.size(); i++) {
  135. row[i].MulInPlace(negative_psi_inv, params);
  136. }
  137. return row;
  138. }
  139. // A struct that stores a package of NTT Parameters
  140. template <typename ModularInt>
  141. struct NttParameters {
  142. NttParameters() = default;
  143. // Disallow copy and copy-assign, allow move and move-assign.
  144. NttParameters(const NttParameters<ModularInt>&) = delete;
  145. NttParameters& operator=(const NttParameters<ModularInt>&) = delete;
  146. NttParameters(NttParameters<ModularInt>&&) = default;
  147. NttParameters& operator=(NttParameters<ModularInt>&&) = default;
  148. ~NttParameters() = default;
  149. int number_coeffs;
  150. absl::optional<ModularInt> n_inv_ptr;
  151. std::vector<ModularInt> psis_bitrev;
  152. std::vector<ModularInt> psis_inv_bitrev;
  153. std::vector<unsigned int> bitrevs;
  154. };
  155. // A convenient function that sets up all NTT parameters at once.
  156. // Does not take ownership of params.
  157. template <typename ModularInt>
  158. rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
  159. int log_n, const typename ModularInt::Params* params) {
  160. // Abort if log_n is non-positive.
  161. if (log_n <= 0) {
  162. return absl::InvalidArgumentError("log_n must be positive");
  163. } else if (log_n > kMaxLogNumCoeffs) {
  164. return absl::InvalidArgumentError(absl::StrCat(
  165. "log_n, ", log_n, ", must be less than ", kMaxLogNumCoeffs, "."));
  166. }
  167. if (!ModularInt::Params::DoesLogNFit(log_n)) {
  168. return absl::InvalidArgumentError(
  169. absl::StrCat("log_n, ", log_n,
  170. ", does not fit into underlying ModularInt::Int type."));
  171. }
  172. NttParameters<ModularInt> output;
  173. output.number_coeffs = 1 << log_n;
  174. typename ModularInt::Int two_times_n = params->One() << (log_n + 1);
  175. if (params->modulus % two_times_n != params->One()){
  176. return absl::InvalidArgumentError(
  177. absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n));
  178. }
  179. // Compute the inverse of n.
  180. typename ModularInt::Int n = params->One() << log_n;
  181. RLWE_ASSIGN_OR_RETURN(auto mn, ModularInt::ImportInt(n, params));
  182. output.n_inv_ptr = mn.MultiplicativeInverse(params);
  183. RLWE_ASSIGN_OR_RETURN(output.psis_bitrev,
  184. NttPsisBitrev<ModularInt>(log_n, params));
  185. RLWE_ASSIGN_OR_RETURN(output.psis_inv_bitrev,
  186. NttPsisInvBitrev<ModularInt>(log_n, params));
  187. output.bitrevs = internal::BitrevArray(log_n);
  188. return output;
  189. }
  190. } // namespace rlwe
  191. #endif // RLWE_NTT_PARAMETERS_H_