/*
 * Copyright 2017 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
- #ifndef RLWE_POLYNOMIAL_H_
- #define RLWE_POLYNOMIAL_H_
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "constants.h"
#include "ntt_parameters.h"
#include "prng/prng.h"
#include "serialization.pb.h"
#include "status_macros.h"
#include "statusor.h"
- namespace rlwe {
- // A polynomial in NTT form. The length of the polynomial must be a power of 2.
- template <typename ModularInt>
- class Polynomial {
- using ModularIntParams = typename ModularInt::Params;
- public:
- // Default constructor.
- Polynomial() = default;
- // Copy constructor.
- Polynomial(const Polynomial& p) = default;
- Polynomial& operator=(const Polynomial& that) = default;
- // Basic constructor.
- explicit Polynomial(std::vector<ModularInt> poly_coeffs)
- : log_len_(log2(poly_coeffs.size())), coeffs_(std::move(poly_coeffs)) {}
- // Create an empty polynomial of the specified length. The length must be
- // a power of 2.
- explicit Polynomial(int len, const ModularIntParams* params)
- : Polynomial(
- std::vector<ModularInt>(len, ModularInt::ImportZero(params))) {}
- // This is an implementation of the FFT from [Sei18, Sec. 2].
- // [Sei18] https://eprint.iacr.org/2018/039
- // All polynomial arithmetic performed is modulo (x^n+1) for n a power of two,
- // with the coefficients operated on modulo a prime modulus.
- //
- // Let psi be a primitive 2n-th root of the unity, i.e., psi is a 2n-th root
- // of unity such that psi^n = -1. Hence it holds that
- // x^n+1 = x^n-psi^n = (x^n/2-psi^n/2)*(x^n/2+psi^n/2)
- //
- //
- // If f = f_0 + f_1*x + ... + f_{n-1}*x^(n-1) is the polynomial to transform,
- // the i-th coefficient of the polynomial mod x^n/2-psi^n/2 can thus be
- // computed as
- // f'_i = f_i + psi^(n/2)*f_(n/2+i),
- // and the i-th coefficient of the polynomial mod x^n/2+psi^n/2 can thus be
- // computed as
- // f''_i = f_i - psi^(n/2)*f_(n/2+i)
- // This operation is called the Cooley-Tukey butterfly and is done
- // iteratively during the NTT.
- //
- // The FFT can thus be performed in-place and after the k-th level, it
- // produces the vector of polynomials with pairs of coefficients
- // f mod (x^(n/2^(k+1))-psi^brv[2^k+1]), f mod (x^(n/2^(k+1))+psi^brv[2^k+1])
- // where brv maps a log(n)-bit number to its bitreversal.
- static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
- const NttParameters<ModularInt>* ntt_params,
- const ModularIntParams* modular_params) {
- // Check to ensure that the coefficient vector is of the correct length.
- int len = poly_coeffs.size();
- if (len <= 0 || (len & (len - 1)) != 0) {
- // An error value.
- return Polynomial();
- }
- Polynomial output(std::move(poly_coeffs));
- output.IterativeCooleyTukey(ntt_params->psis_bitrev, modular_params);
- return output;
- }
- // Deprecated ConvertToNtt function taking NttParameters by constant reference
- ABSL_DEPRECATED("Use ConvertToNtt function with NttParameters pointer above.")
- static Polynomial ConvertToNtt(std::vector<ModularInt> poly_coeffs,
- const NttParameters<ModularInt>& ntt_params,
- const ModularIntParams* modular_params) {
- return ConvertToNtt(std::move(poly_coeffs), &ntt_params, modular_params);
- }
- // The inverse NTT transform is computed similarly by iteratively inverting
- // the NTT representation. For instance, using the same notation as above,
- // f'_i + f''_i = 2f_i and psi^(-n/2)*(f'_i-f''_i) = 2c_(n/2+i).
- //
- // In particular, the butterfly operation differs from the Cooley-Tukey
- // butterfly used during the forward transform in that addition and
- // substraction come before multiplying with a power of the root of unity.
- // This butterfly operation is called the Gentleman-Sande butterfly.
- //
- // At the end of the computation, a normalization step by the inverse of
- // n=2^log(n) (the factor 2 obtained at each level of the butterfly) is
- // required.
- std::vector<ModularInt> InverseNtt(
- const NttParameters<ModularInt>* ntt_params,
- const ModularIntParams* modular_params) const {
- Polynomial copy(*this);
- copy.IterativeGentlemanSande(ntt_params->psis_inv_bitrev, modular_params);
- // Normalize the result by multiplying by the inverse of n.
- for (auto& coeff : copy.coeffs_) {
- coeff.MulInPlace(ntt_params->n_inv_ptr.value(), modular_params);
- }
- return copy.coeffs_;
- }
- // Deprecated InverseNtt function taking NttParameters by constant reference
- ABSL_DEPRECATED("Use InverseNtt function with NttParameters pointer above.")
- std::vector<ModularInt> InverseNtt(
- const NttParameters<ModularInt>& ntt_params,
- const ModularIntParams* modular_params) const {
- return InverseNtt(&ntt_params, modular_params);
- }
- // Specifies whether the Polynomial is valid.
- bool IsValid() const { return !coeffs_.empty(); }
- // Scalar multiply.
- rlwe::StatusOr<Polynomial> Mul(const ModularInt& scalar,
- const ModularIntParams* modular_params) const {
- Polynomial output = *this;
- RLWE_RETURN_IF_ERROR(output.MulInPlace(scalar, modular_params));
- return output;
- }
- // Scalar multiply in place.
- absl::Status MulInPlace(const ModularInt& scalar,
- const ModularIntParams* modular_params) {
- return ModularInt::BatchMulInPlace(&coeffs_, scalar, modular_params);
- }
- // Coordinate-wise multiplication.
- rlwe::StatusOr<Polynomial> Mul(const Polynomial& that,
- const ModularIntParams* modular_params) const {
- Polynomial output = *this;
- RLWE_RETURN_IF_ERROR(output.MulInPlace(that, modular_params));
- return output;
- }
- // Coordinate-wise multiplication in place.
- absl::Status MulInPlace(const Polynomial& that,
- const ModularIntParams* modular_params) {
- // If this operation is invalid, return an invalid error.
- if (Len() != that.Len()) {
- return absl::InvalidArgumentError(
- "The polynomials do not have the same length.");
- }
- return ModularInt::BatchMulInPlace(&coeffs_, that.coeffs_, modular_params);
- }
- // Negation.
- Polynomial Negate(const ModularIntParams* modular_params) const {
- Polynomial output = *this;
- output.NegateInPlace(modular_params);
- return output;
- }
- // Negation in place.
- Polynomial& NegateInPlace(const ModularIntParams* modular_params) {
- for (auto& coeff : coeffs_) {
- coeff.NegateInPlace(modular_params);
- }
- return *this;
- }
- // Coordinate-wise addition.
- rlwe::StatusOr<Polynomial> Add(const Polynomial& that,
- const ModularIntParams* modular_params) const {
- Polynomial output = *this;
- RLWE_RETURN_IF_ERROR(output.AddInPlace(that, modular_params));
- return output;
- }
- // Coordinate-wise substraction.
- rlwe::StatusOr<Polynomial> Sub(const Polynomial& that,
- const ModularIntParams* modular_params) const {
- Polynomial output = *this;
- RLWE_RETURN_IF_ERROR(output.SubInPlace(that, modular_params));
- return output;
- }
- // Coordinate-wise addition in place.
- absl::Status AddInPlace(const Polynomial& that,
- const ModularIntParams* modular_params) {
- // If this operation is invalid, return an invalid error.
- if (Len() != that.Len()) {
- return absl::InvalidArgumentError(
- "The polynomials do not have the same length.");
- }
- return ModularInt::BatchAddInPlace(&coeffs_, that.coeffs_, modular_params);
- }
- // Coordinate-wise substraction in place.
- absl::Status SubInPlace(const Polynomial& that,
- const ModularIntParams* modular_params) {
- // If this operation is invalid, return an invalid error.
- if (Len() != that.Len()) {
- return absl::InvalidArgumentError(
- "The polynomials do not have the same length.");
- }
- return ModularInt::BatchSubInPlace(&coeffs_, that.coeffs_, modular_params);
- }
- // Substitute: Given an Polynomial representing p(x), returns an
- // Polynomial representing p(x^power). Power must be an odd non-negative
- // integer less than 2 * Len().
- rlwe::StatusOr<Polynomial> Substitute(
- const int power, const NttParameters<ModularInt>* ntt_params,
- const ModularIntParams* modulus_params) const {
- // The NTT representation consists in the evaluations of the polynomial at
- // roots psi^brv[n/2], psi^brv[n/2+1], ..., psi^brv[n/2+n/2-1],
- // psi^(n/2+brv[n/2+1]), ..., psi^(n/2+brv[n/2+n/2-1]).
- // Let f(x) be the original polynomial, and out(x) be the polynomial after
- // the substitution. Note that (psi^i)^power = psi^{(i * power) % (2 * n).
- if (0 > power || (power % 2) == 0 || power >= 2 * Len()) {
- return absl::InvalidArgumentError(
- absl::StrCat("Substitution power must be a non-negative odd "
- "integer less than 2*n."));
- }
- Polynomial out = *this;
- // Get the index of the psi^power evaluation
- int psi_power_index = (power - 1) / 2;
- // Update the coefficients one by one: remember that they are stored in
- // bitreversed order.
- for (int i = 0; i < Len(); i++) {
- out.coeffs_[ntt_params->bitrevs[i]] =
- coeffs_[ntt_params->bitrevs[psi_power_index]];
- // Each time the index increases by 1, the psi_power_index increases by
- // power mod the length.
- psi_power_index = (psi_power_index + power) % Len();
- }
- return out;
- }
- // Deprecated Substitute function taking NttParameters by constant reference
- ABSL_DEPRECATED("Use Substitute function with NttParameters pointer above.")
- rlwe::StatusOr<Polynomial> Substitute(
- const int power, const NttParameters<ModularInt>& ntt_params,
- const ModularIntParams* modulus_params) const {
- return Substitute(power, &ntt_params, modulus_params);
- }
- // Boolean comparison.
- bool operator==(const Polynomial& that) const {
- if (Len() != that.Len()) {
- return false;
- }
- for (int i = 0; i < Len(); i++) {
- if (coeffs_[i] != that.coeffs_[i]) {
- return false;
- }
- }
- return true;
- }
- bool operator!=(const Polynomial& that) const { return !(*this == that); }
- int Len() const { return coeffs_.size(); }
- // Accessor for coefficients.
- std::vector<ModularInt> Coeffs() const { return coeffs_; }
- rlwe::StatusOr<SerializedNttPolynomial> Serialize(
- const ModularIntParams* modular_params) const {
- SerializedNttPolynomial output;
- RLWE_ASSIGN_OR_RETURN(*(output.mutable_coeffs()),
- ModularInt::SerializeVector(coeffs_, modular_params));
- output.set_num_coeffs(coeffs_.size());
- return output;
- }
- static rlwe::StatusOr<Polynomial> Deserialize(
- const SerializedNttPolynomial& serialized,
- const ModularIntParams* modular_params) {
- if (serialized.num_coeffs() <= 0) {
- return absl::InvalidArgumentError(
- "Number of serialized coefficients must be positive.");
- } else if (serialized.num_coeffs() > kMaxNumCoeffs) {
- return absl::InvalidArgumentError(absl::StrCat(
- "Number of serialized coefficients, ", serialized.num_coeffs(),
- ", must be less than ", kMaxNumCoeffs, "."));
- }
- Polynomial output(serialized.num_coeffs(), modular_params);
- RLWE_ASSIGN_OR_RETURN(
- output.coeffs_,
- ModularInt::DeserializeVector(serialized.num_coeffs(),
- serialized.coeffs(), modular_params));
- return output;
- }
- private:
- // Instance variables.
- size_t log_len_;
- std::vector<ModularInt> coeffs_;
- // Helper function: Perform iterations of the Cooley-Tukey butterfly.
- void IterativeCooleyTukey(const std::vector<ModularInt>& psis_bitrev,
- const ModularIntParams* modular_params) {
- int index_psi = 1;
- for (int i = log_len_ - 1; i >= 0; i--) {
- const unsigned int half_m = 1 << i;
- const unsigned int m = half_m << 1;
- for (int k = 0; k < Len(); k += m) {
- const ModularInt psi = psis_bitrev[index_psi];
- for (int j = 0; j < half_m; j++) {
- // The Cooley-Tukey butterfly operation.
- const ModularInt t = psi.Mul(coeffs_[k + j + half_m], modular_params);
- ModularInt u = coeffs_[k + j];
- coeffs_[k + j].AddInPlace(t, modular_params);
- coeffs_[k + j + half_m] = std::move(u.SubInPlace(t, modular_params));
- }
- index_psi++;
- }
- }
- }
- // Helper function: Perform iterations of the Gentleman-Sande butterfly.
- void IterativeGentlemanSande(const std::vector<ModularInt>& psis_inv_bitrev,
- const ModularIntParams* modular_params) {
- int index_psi_inv = 0;
- for (int i = 0; i < log_len_; i++) {
- const unsigned int half_m = 1 << i;
- const unsigned int m = half_m << 1;
- for (int k = 0; k < Len(); k += m) {
- const ModularInt psi_inv = psis_inv_bitrev[index_psi_inv];
- for (int j = 0; j < half_m; j++) {
- // The Gentleman-Sande butterfly operation.
- const ModularInt t = coeffs_[k + j + half_m];
- ModularInt u = coeffs_[k + j];
- coeffs_[k + j].AddInPlace(t, modular_params);
- coeffs_[k + j + half_m] =
- std::move(u.SubInPlace(t, modular_params)
- .MulInPlace(psi_inv, modular_params));
- }
- index_psi_inv++;
- }
- }
- }
- };
- template <typename ModularInt, typename Prng = rlwe::SecurePrng>
- rlwe::StatusOr<Polynomial<ModularInt>> SamplePolynomialFromPrng(
- int num_coeffs, Prng* prng,
- const typename ModularInt::Params* modulus_params) {
- // Sample a from the uniform distribution. Since a is uniformly distributed,
- // it can be generated directly in NTT form since the NTT transformation is
- // an automorphism.
- if (num_coeffs < 1) {
- return absl::InvalidArgumentError(
- "SamplePolynomialFromPrng: number of coefficients must be a "
- "non-negative integer.");
- }
- std::vector<ModularInt> a_ntt_coeffs(num_coeffs,
- ModularInt::ImportZero(modulus_params));
- for (int i = 0; i < num_coeffs; i++) {
- RLWE_ASSIGN_OR_RETURN(a_ntt_coeffs[i],
- ModularInt::ImportRandom(prng, modulus_params));
- }
- return Polynomial<ModularInt>(a_ntt_coeffs);
- }
- } // namespace rlwe
- #endif // RLWE_POLYNOMIAL_H_
|