polynomial.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. /*
  2. * Copyright 2017 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #ifndef RLWE_POLYNOMIAL_H_
  16. #define RLWE_POLYNOMIAL_H_
  17. #include <cmath>
  18. #include <vector>
  19. #include "absl/strings/str_cat.h"
  20. #include "constants.h"
  21. #include "ntt_parameters.h"
  22. #include "prng/prng.h"
  23. #include "serialization.pb.h"
  24. #include "status_macros.h"
  25. #include "statusor.h"
  26. namespace rlwe {
  27. // A polynomial in NTT form. The length of the polynomial must be a power of 2.
  28. template <typename ModularInt>
  29. class Polynomial {
  30. using ModularIntParams = typename ModularInt::Params;
  31. public:
  32. // Default constructor.
  33. Polynomial() = default;
  34. // Copy constructor.
  35. Polynomial(const Polynomial& p) = default;
  36. Polynomial& operator=(const Polynomial& that) = default;
  37. // Basic constructor.
  38. explicit Polynomial(std::vector<ModularInt> poly_coeffs)
  39. : log_len_(log2(poly_coeffs.size())), coeffs_(std::move(poly_coeffs)) {}
  40. // Create an empty polynomial of the specified length. The length must be
  41. // a power of 2.
  42. explicit Polynomial(int len, const ModularIntParams* params)
  43. : Polynomial(
  44. std::vector<ModularInt>(len, ModularInt::ImportZero(params))) {}
  45. // This is an implementation of the FFT from [Sei18, Sec. 2].
  46. // [Sei18] https://eprint.iacr.org/2018/039
  47. // All polynomial arithmetic performed is modulo (x^n+1) for n a power of two,
  48. // with the coefficients operated on modulo a prime modulus.
  49. //
  50. // Let psi be a primitive 2n-th root of the unity, i.e., psi is a 2n-th root
  51. // of unity such that psi^n = -1. Hence it holds that
  52. // x^n+1 = x^n-psi^n = (x^n/2-psi^n/2)*(x^n/2+psi^n/2)
  53. //
  54. //
  55. // If f = f_0 + f_1*x + ... + f_{n-1}*x^(n-1) is the polynomial to transform,
  56. // the i-th coefficient of the polynomial mod x^n/2-psi^n/2 can thus be
  57. // computed as
  58. // f'_i = f_i + psi^(n/2)*f_(n/2+i),
  59. // and the i-th coefficient of the polynomial mod x^n/2+psi^n/2 can thus be
  60. // computed as
  61. // f''_i = f_i - psi^(n/2)*f_(n/2+i)
  62. // This operation is called the Cooley-Tukey butterfly and is done
  63. // iteratively during the NTT.
  64. //
  65. // The FFT can thus be performed in-place and after the k-th level, it
  66. // produces the vector of polynomials with pairs of coefficients
  67. // f mod (x^(n/2^(k+1))-psi^brv[2^k+1]), f mod (x^(n/2^(k+1))+psi^brv[2^k+1])
  68. // where brv maps a log(n)-bit number to its bitreversal.
  69. static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
  70. const NttParameters<ModularInt>* ntt_params,
  71. const ModularIntParams* modular_params) {
  72. // Check to ensure that the coefficient vector is of the correct length.
  73. int len = poly_coeffs.size();
  74. if (len <= 0 || (len & (len - 1)) != 0) {
  75. // An error value.
  76. return Polynomial();
  77. }
  78. Polynomial output(std::move(poly_coeffs));
  79. output.IterativeCooleyTukey(ntt_params->psis_bitrev, modular_params);
  80. return output;
  81. }
  82. // Deprecated ConvertToNtt function taking NttParameters by constant reference
  83. ABSL_DEPRECATED("Use ConvertToNtt function with NttParameters pointer above.")
  84. static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
  85. const NttParameters<ModularInt>& ntt_params,
  86. const ModularIntParams* modular_params) {
  87. return ConvertToNtt(std::move(poly_coeffs), &ntt_params, modular_params);
  88. }
  89. // The inverse NTT transform is computed similarly by iteratively inverting
  90. // the NTT representation. For instance, using the same notation as above,
  91. // f'_i + f''_i = 2f_i and psi^(-n/2)*(f'_i-f''_i) = 2c_(n/2+i).
  92. //
  93. // In particular, the butterfly operation differs from the Cooley-Tukey
  94. // butterfly used during the forward transform in that addition and
  95. // substraction come before multiplying with a power of the root of unity.
  96. // This butterfly operation is called the Gentleman-Sande butterfly.
  97. //
  98. // At the end of the computation, a normalization step by the inverse of
  99. // n=2^log(n) (the factor 2 obtained at each level of the butterfly) is
  100. // required.
  101. std::vector<ModularInt> InverseNtt(
  102. const NttParameters<ModularInt>* ntt_params,
  103. const ModularIntParams* modular_params) const {
  104. Polynomial copy(*this);
  105. copy.IterativeGentlemanSande(ntt_params->psis_inv_bitrev, modular_params);
  106. // Normalize the result by multiplying by the inverse of n.
  107. for (auto& coeff : copy.coeffs_) {
  108. coeff.MulInPlace(ntt_params->n_inv_ptr.value(), modular_params);
  109. }
  110. return copy.coeffs_;
  111. }
  112. // Deprecated InverseNtt function taking NttParameters by constant reference
  113. ABSL_DEPRECATED("Use InverseNtt function with NttParameters pointer above.")
  114. std::vector<ModularInt> InverseNtt(
  115. const NttParameters<ModularInt>& ntt_params,
  116. const ModularIntParams* modular_params) const {
  117. return InverseNtt(&ntt_params, modular_params);
  118. }
  119. // Specifies whether the Polynomial is valid.
  120. bool IsValid() const { return !coeffs_.empty(); }
  121. // Scalar multiply.
  122. rlwe::StatusOr<Polynomial> Mul(const ModularInt& scalar,
  123. const ModularIntParams* modular_params) const {
  124. Polynomial output = *this;
  125. RLWE_RETURN_IF_ERROR(output.MulInPlace(scalar, modular_params));
  126. return output;
  127. }
  128. // Scalar multiply in place.
  129. absl::Status MulInPlace(const ModularInt& scalar,
  130. const ModularIntParams* modular_params) {
  131. return ModularInt::BatchMulInPlace(&coeffs_, scalar, modular_params);
  132. }
  133. // Coordinate-wise multiplication.
  134. rlwe::StatusOr<Polynomial> Mul(const Polynomial& that,
  135. const ModularIntParams* modular_params) const {
  136. Polynomial output = *this;
  137. RLWE_RETURN_IF_ERROR(output.MulInPlace(that, modular_params));
  138. return output;
  139. }
  140. // Coordinate-wise multiplication in place.
  141. absl::Status MulInPlace(const Polynomial& that,
  142. const ModularIntParams* modular_params) {
  143. // If this operation is invalid, return an invalid error.
  144. if (Len() != that.Len()) {
  145. return absl::InvalidArgumentError(
  146. "The polynomials do not have the same length.");
  147. }
  148. return ModularInt::BatchMulInPlace(&coeffs_, that.coeffs_, modular_params);
  149. }
  150. // Negation.
  151. Polynomial Negate(const ModularIntParams* modular_params) const {
  152. Polynomial output = *this;
  153. output.NegateInPlace(modular_params);
  154. return output;
  155. }
  156. // Negation in place.
  157. Polynomial& NegateInPlace(const ModularIntParams* modular_params) {
  158. for (auto& coeff : coeffs_) {
  159. coeff.NegateInPlace(modular_params);
  160. }
  161. return *this;
  162. }
  163. // Coordinate-wise addition.
  164. rlwe::StatusOr<Polynomial> Add(const Polynomial& that,
  165. const ModularIntParams* modular_params) const {
  166. Polynomial output = *this;
  167. RLWE_RETURN_IF_ERROR(output.AddInPlace(that, modular_params));
  168. return output;
  169. }
  170. // Coordinate-wise substraction.
  171. rlwe::StatusOr<Polynomial> Sub(const Polynomial& that,
  172. const ModularIntParams* modular_params) const {
  173. Polynomial output = *this;
  174. RLWE_RETURN_IF_ERROR(output.SubInPlace(that, modular_params));
  175. return output;
  176. }
  177. // Coordinate-wise addition in place.
  178. absl::Status AddInPlace(const Polynomial& that,
  179. const ModularIntParams* modular_params) {
  180. // If this operation is invalid, return an invalid error.
  181. if (Len() != that.Len()) {
  182. return absl::InvalidArgumentError(
  183. "The polynomials do not have the same length.");
  184. }
  185. return ModularInt::BatchAddInPlace(&coeffs_, that.coeffs_, modular_params);
  186. }
  187. // Coordinate-wise substraction in place.
  188. absl::Status SubInPlace(const Polynomial& that,
  189. const ModularIntParams* modular_params) {
  190. // If this operation is invalid, return an invalid error.
  191. if (Len() != that.Len()) {
  192. return absl::InvalidArgumentError(
  193. "The polynomials do not have the same length.");
  194. }
  195. return ModularInt::BatchSubInPlace(&coeffs_, that.coeffs_, modular_params);
  196. }
  197. // Substitute: Given an Polynomial representing p(x), returns an
  198. // Polynomial representing p(x^power). Power must be an odd non-negative
  199. // integer less than 2 * Len().
  200. rlwe::StatusOr<Polynomial> Substitute(
  201. const int power, const NttParameters<ModularInt>* ntt_params,
  202. const ModularIntParams* modulus_params) const {
  203. // The NTT representation consists in the evaluations of the polynomial at
  204. // roots psi^brv[n/2], psi^brv[n/2+1], ..., psi^brv[n/2+n/2-1],
  205. // psi^(n/2+brv[n/2+1]), ..., psi^(n/2+brv[n/2+n/2-1]).
  206. // Let f(x) be the original polynomial, and out(x) be the polynomial after
  207. // the substitution. Note that (psi^i)^power = psi^{(i * power) % (2 * n).
  208. if (0 > power || (power % 2) == 0 || power >= 2 * Len()) {
  209. return absl::InvalidArgumentError(
  210. absl::StrCat("Substitution power must be a non-negative odd "
  211. "integer less than 2*n."));
  212. }
  213. Polynomial out = *this;
  214. // Get the index of the psi^power evaluation
  215. int psi_power_index = (power - 1) / 2;
  216. // Update the coefficients one by one: remember that they are stored in
  217. // bitreversed order.
  218. for (int i = 0; i < Len(); i++) {
  219. out.coeffs_[ntt_params->bitrevs[i]] =
  220. coeffs_[ntt_params->bitrevs[psi_power_index]];
  221. // Each time the index increases by 1, the psi_power_index increases by
  222. // power mod the length.
  223. psi_power_index = (psi_power_index + power) % Len();
  224. }
  225. return out;
  226. }
  227. // Deprecated Substitute function taking NttParameters by constant reference
  228. ABSL_DEPRECATED("Use Substitute function with NttParameters pointer above.")
  229. rlwe::StatusOr<Polynomial> Substitute(
  230. const int power, const NttParameters<ModularInt>& ntt_params,
  231. const ModularIntParams* modulus_params) const {
  232. return Substitute(power, &ntt_params, modulus_params);
  233. }
  234. // Boolean comparison.
  235. bool operator==(const Polynomial& that) const {
  236. if (Len() != that.Len()) {
  237. return false;
  238. }
  239. for (int i = 0; i < Len(); i++) {
  240. if (coeffs_[i] != that.coeffs_[i]) {
  241. return false;
  242. }
  243. }
  244. return true;
  245. }
  246. bool operator!=(const Polynomial& that) const { return !(*this == that); }
  247. int Len() const { return coeffs_.size(); }
  248. // Accessor for coefficients.
  249. std::vector<ModularInt> Coeffs() const { return coeffs_; }
  250. rlwe::StatusOr<SerializedNttPolynomial> Serialize(
  251. const ModularIntParams* modular_params) const {
  252. SerializedNttPolynomial output;
  253. RLWE_ASSIGN_OR_RETURN(*(output.mutable_coeffs()),
  254. ModularInt::SerializeVector(coeffs_, modular_params));
  255. output.set_num_coeffs(coeffs_.size());
  256. return output;
  257. }
  258. static rlwe::StatusOr<Polynomial> Deserialize(
  259. const SerializedNttPolynomial& serialized,
  260. const ModularIntParams* modular_params) {
  261. if (serialized.num_coeffs() <= 0) {
  262. return absl::InvalidArgumentError(
  263. "Number of serialized coefficients must be positive.");
  264. } else if (serialized.num_coeffs() > kMaxNumCoeffs) {
  265. return absl::InvalidArgumentError(absl::StrCat(
  266. "Number of serialized coefficients, ", serialized.num_coeffs(),
  267. ", must be less than ", kMaxNumCoeffs, "."));
  268. }
  269. Polynomial output(serialized.num_coeffs(), modular_params);
  270. RLWE_ASSIGN_OR_RETURN(
  271. output.coeffs_,
  272. ModularInt::DeserializeVector(serialized.num_coeffs(),
  273. serialized.coeffs(), modular_params));
  274. return output;
  275. }
  276. private:
  277. // Instance variables.
  278. size_t log_len_;
  279. std::vector<ModularInt> coeffs_;
  280. // Helper function: Perform iterations of the Cooley-Tukey butterfly.
  281. void IterativeCooleyTukey(const std::vector<ModularInt>& psis_bitrev,
  282. const ModularIntParams* modular_params) {
  283. int index_psi = 1;
  284. for (int i = log_len_ - 1; i >= 0; i--) {
  285. const unsigned int half_m = 1 << i;
  286. const unsigned int m = half_m << 1;
  287. for (int k = 0; k < Len(); k += m) {
  288. const ModularInt psi = psis_bitrev[index_psi];
  289. for (int j = 0; j < half_m; j++) {
  290. // The Cooley-Tukey butterfly operation.
  291. const ModularInt t = psi.Mul(coeffs_[k + j + half_m], modular_params);
  292. ModularInt u = coeffs_[k + j];
  293. coeffs_[k + j].AddInPlace(t, modular_params);
  294. coeffs_[k + j + half_m] = std::move(u.SubInPlace(t, modular_params));
  295. }
  296. index_psi++;
  297. }
  298. }
  299. }
  300. // Helper function: Perform iterations of the Gentleman-Sande butterfly.
  301. void IterativeGentlemanSande(const std::vector<ModularInt>& psis_inv_bitrev,
  302. const ModularIntParams* modular_params) {
  303. int index_psi_inv = 0;
  304. for (int i = 0; i < log_len_; i++) {
  305. const unsigned int half_m = 1 << i;
  306. const unsigned int m = half_m << 1;
  307. for (int k = 0; k < Len(); k += m) {
  308. const ModularInt psi_inv = psis_inv_bitrev[index_psi_inv];
  309. for (int j = 0; j < half_m; j++) {
  310. // The Gentleman-Sande butterfly operation.
  311. const ModularInt t = coeffs_[k + j + half_m];
  312. ModularInt u = coeffs_[k + j];
  313. coeffs_[k + j].AddInPlace(t, modular_params);
  314. coeffs_[k + j + half_m] =
  315. std::move(u.SubInPlace(t, modular_params)
  316. .MulInPlace(psi_inv, modular_params));
  317. }
  318. index_psi_inv++;
  319. }
  320. }
  321. }
  322. };
  323. template <typename ModularInt, typename Prng = rlwe::SecurePrng>
  324. rlwe::StatusOr<Polynomial<ModularInt>> SamplePolynomialFromPrng(
  325. int num_coeffs, Prng* prng,
  326. const typename ModularInt::Params* modulus_params) {
  327. // Sample a from the uniform distribution. Since a is uniformly distributed,
  328. // it can be generated directly in NTT form since the NTT transformation is
  329. // an automorphism.
  330. if (num_coeffs < 1) {
  331. return absl::InvalidArgumentError(
  332. "SamplePolynomialFromPrng: number of coefficients must be a "
  333. "non-negative integer.");
  334. }
  335. std::vector<ModularInt> a_ntt_coeffs(num_coeffs,
  336. ModularInt::ImportZero(modulus_params));
  337. for (int i = 0; i < num_coeffs; i++) {
  338. RLWE_ASSIGN_OR_RETURN(a_ntt_coeffs[i],
  339. ModularInt::ImportRandom(prng, modulus_params));
  340. }
  341. return Polynomial<ModularInt>(a_ntt_coeffs);
  342. }
  343. } // namespace rlwe
  344. #endif // RLWE_POLYNOMIAL_H_