123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- /*
- * Copyright 2017 Google LLC.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef RLWE_NTT_PARAMETERS_H_
- #define RLWE_NTT_PARAMETERS_H_
- #include <algorithm>
- #include <cstdlib>
- #include <vector>
- #include "absl/memory/memory.h"
- #include "absl/strings/str_cat.h"
- #include "constants.h"
- #include "status_macros.h"
- #include "statusor.h"
- #include "third_party/shell-encryption/base/shell_encryption_export.h"
- namespace rlwe {
- namespace internal {
// Adds base^i (mod modulus) to (*row)[i] for every i in {0, 1, ..., n-1}.
//
// The powers are accumulated with AddInPlace rather than assigned, so callers
// that want the plain powers must pass a row whose first n entries are zero.
// row must already hold at least n elements; it is neither resized nor
// bounds-checked here.
template <typename ModularInt>
void FillWithEveryPower(const ModularInt& base, unsigned int n,
                        std::vector<ModularInt>* row,
                        const typename ModularInt::Params* params) {
  // Unsigned index: comparing a signed int against the unsigned n mixes
  // signedness.
  for (unsigned int i = 0; i < n; ++i) {
    (*row)[i].AddInPlace(base.ModExp(i, params), params);
  }
}
- template <typename ModularInt>
- rlwe::StatusOr<ModularInt> PrimitiveNthRootOfUnity(
- unsigned int log_n, const typename ModularInt::Params* params) {
- typename ModularInt::Int n = params->One() << log_n;
- typename ModularInt::Int half_n = n >> 1;
- // When the modulus is prime, the value k is a power such that any number
- // raised to it will be a n-th root of unity. (It will not necessarily be a
- // *primitive* root of unity, however).
- typename ModularInt::Int k = (params->modulus - params->One()) / n;
- // Test each number t to see whether t^k is a primitive n-th root
- // of unity - that t^{nk} is a root of unity but t^{(n/2)k} is not.
- ModularInt one = ModularInt::ImportOne(params);
- for (typename ModularInt::Int t = params->Two(); t < params->modulus;
- t = t + params->One()) {
- // Produce a candidate root of unity.
- RLWE_ASSIGN_OR_RETURN(auto mt, ModularInt::ImportInt(t, params));
- ModularInt candidate = mt.ModExp(k, params);
- // Check whether candidate^half_n = 1. If not, it is a primitive root of
- // unity.
- if (candidate.ModExp(half_n, params) != one) {
- return candidate;
- }
- }
- // Failure state. The above loop should always return successfully assuming
- // the parameters were set properly.
- return absl::UnknownError("Loop in PrimitiveNthRootOfUnity terminated.");
- }
- // Let psi be a primitive 2n-th root of unity, i.e., a 2n-th root of unity such
- // that psi^n = -1. When performing the NTT transformation, the powers of psi in
- // bitreversed order are needed. The vector produced by this helper function
- // contains the powers of psi (psi^0, psi^1, psi^2, ..., psi^(n-1)).
- //
- // Each item of the vector is in modular integer representation.
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<ModularInt>> NttPsis(
- unsigned int log_n, const typename ModularInt::Params* params) {
- // Obtain psi, a primitive 2n-th root of unity (hence log_n + 1).
- RLWE_ASSIGN_OR_RETURN(
- ModularInt psi,
- internal::PrimitiveNthRootOfUnity<ModularInt>(log_n + 1, params));
- unsigned int n = 1 << log_n;
- ModularInt zero = ModularInt::ImportZero(params);
- // Create a vector with the powers of psi.
- std::vector<ModularInt> row(n, zero);
- internal::FillWithEveryPower<ModularInt>(psi, n, &row, params);
- return row;
- }
- // Creates a vector containing the indices necessary to perform the NTT bit
- // reversal operation. Index i of the returned vector contains an integer with
- // the rightmost log_n bits of i reversed.
- SHELL_ENCRYPTION_EXPORT std::vector<unsigned int> BitrevArray(unsigned int log_n);
// Helper function: performs the bit-reversal permutation in place on
// *item_to_reverse by swapping element i with element bitrevs[i].
//
// bitrevs must have at least item_to_reverse->size() entries, each a valid
// index. When bitrevs is an involution (as a bit-reversal permutation is),
// element i afterwards holds the value previously at index bitrevs[i].
template <typename ModularInt>
static void BitrevHelper(const std::vector<unsigned int>& bitrevs,
                         std::vector<ModularInt>* item_to_reverse) {
  using std::swap;
  // size_t index: comparing a signed int against size() mixes signedness.
  const size_t count = item_to_reverse->size();
  for (size_t i = 0; i < count; ++i) {
    const size_t r = bitrevs[i];
    // Each pair {i, r} is visited from both ends; only swap when i < r so the
    // pair is exchanged exactly once.
    if (i < r) {
      swap((*item_to_reverse)[i], (*item_to_reverse)[r]);
    }
  }
}
- } // namespace internal
- // The precomputed roots of unity used during the forward NTT are the
- // bitreversed powers of the primitive 2n-th root of unity.
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<ModularInt>> NttPsisBitrev(
- unsigned int log_n, const typename ModularInt::Params* params) {
- // Retrieve the table for the forward transformation.
- RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> psis,
- internal::NttPsis<ModularInt>(log_n, params));
- // Bitreverse the vector.
- internal::BitrevHelper(internal::BitrevArray(log_n), &psis);
- return psis;
- }
- // The precomputed roots of unity used during the inverse NTT are the inverses
- // of the bitreversed powers of the primitive 2n-th root of unity plus 1.
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<ModularInt>> NttPsisInvBitrev(
- unsigned int log_n, const typename ModularInt::Params* params) {
- // Retrieve the table for the forward transformation.
- RLWE_ASSIGN_OR_RETURN(std::vector<ModularInt> row,
- internal::NttPsis<ModularInt>(log_n, params));
- // Reverse the items at indices 1 through (n - 1). Multiplying index i
- // of the reversed row by index i of the original row will yield psi^n = -1.
- // (The exception is psi^0 = 1, which is already its own inverse.)
- std::reverse(row.begin() + 1, row.end());
- // Get the inverse of psi
- ModularInt psi_inv = row[1].Negate(params);
- ModularInt negative_psi_inv = row[1];
- // Bitreverse the vector.
- internal::BitrevHelper(internal::BitrevArray(log_n), &row);
- // Finally, multiply each of the items at indices 1 to (n-1) by -1. Multiply
- // every entry by psi_inv.
- row[0].MulInPlace(psi_inv, params);
- for (int i = 1; i < row.size(); i++) {
- row[i].MulInPlace(negative_psi_inv, params);
- }
- return row;
- }
- // A struct that stores a package of NTT Parameters
- template <typename ModularInt>
- struct NttParameters {
- NttParameters() = default;
- // Disallow copy and copy-assign, allow move and move-assign.
- NttParameters(const NttParameters<ModularInt>&) = delete;
- NttParameters& operator=(const NttParameters<ModularInt>&) = delete;
- NttParameters(NttParameters<ModularInt>&&) = default;
- NttParameters& operator=(NttParameters<ModularInt>&&) = default;
- ~NttParameters() = default;
- int number_coeffs;
- absl::optional<ModularInt> n_inv_ptr;
- std::vector<ModularInt> psis_bitrev;
- std::vector<ModularInt> psis_inv_bitrev;
- std::vector<unsigned int> bitrevs;
- };
- // A convenient function that sets up all NTT parameters at once.
- // Does not take ownership of params.
- template <typename ModularInt>
- rlwe::StatusOr<NttParameters<ModularInt>> InitializeNttParameters(
- int log_n, const typename ModularInt::Params* params) {
- // Abort if log_n is non-positive.
- if (log_n <= 0) {
- return absl::InvalidArgumentError("log_n must be positive");
- } else if (log_n > kMaxLogNumCoeffs) {
- return absl::InvalidArgumentError(absl::StrCat(
- "log_n, ", log_n, ", must be less than ", kMaxLogNumCoeffs, "."));
- }
- if (!ModularInt::Params::DoesLogNFit(log_n)) {
- return absl::InvalidArgumentError(
- absl::StrCat("log_n, ", log_n,
- ", does not fit into underlying ModularInt::Int type."));
- }
- NttParameters<ModularInt> output;
- output.number_coeffs = 1 << log_n;
- typename ModularInt::Int two_times_n = params->One() << (log_n + 1);
- if (params->modulus % two_times_n != params->One()){
- return absl::InvalidArgumentError(
- absl::StrCat("modulus is not 1 mod 2n for logn, ", log_n));
- }
- // Compute the inverse of n.
- typename ModularInt::Int n = params->One() << log_n;
- RLWE_ASSIGN_OR_RETURN(auto mn, ModularInt::ImportInt(n, params));
- output.n_inv_ptr = mn.MultiplicativeInverse(params);
- RLWE_ASSIGN_OR_RETURN(output.psis_bitrev,
- NttPsisBitrev<ModularInt>(log_n, params));
- RLWE_ASSIGN_OR_RETURN(output.psis_inv_bitrev,
- NttPsisInvBitrev<ModularInt>(log_n, params));
- output.bitrevs = internal::BitrevArray(log_n);
- return output;
- }
- } // namespace rlwe
- #endif // RLWE_NTT_PARAMETERS_H_
|