123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442 |
- /*
- * Copyright 2018 Google LLC.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "relinearization_key.h"
- #include "absl/numeric/int128.h"
- #include "bits_util.h"
- #include "montgomery.h"
- #include "prng/integral_prng_types.h"
- #include "status_macros.h"
- #include "statusor.h"
- #include "symmetric_encryption_with_prng.h"
- #include "third_party/shell-encryption/base/shell_encryption_export.h"
- #include "third_party/shell-encryption/base/shell_encryption_export_template.h"
- namespace rlwe {
- namespace {
- // Method to compute the number of digits needed to represent integers mod
- // q in base T. Upcasts the modulus to absl::uint128 to handle all Uint*
- // types.
- inline int ComputeDimension(Uint64 log_decomposition_modulus,
- absl::uint128 modulus) {
- Uint64 modulus_bits = static_cast<Uint64>(internal::BitLength(modulus));
- return (modulus_bits + (log_decomposition_modulus - 1)) /
- log_decomposition_modulus;
- }
- // Returns a random vector r orthogonal to (1,s). The second component is chosen
- // using randomness-of-encryption sampled using the specified PRNG. The first
- // component is then chosen so that r is perpendicular to (1,s).
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> SampleOrthogonalFromPrng(
- const SymmetricRlweKey<ModularInt>& key, SecurePrng* prng) {
- // Sample a random polynomial r using a PRNG.
- RLWE_ASSIGN_OR_RETURN(auto r, SamplePolynomialFromPrng<ModularInt>(
- key.Len(), prng, key.ModulusParams()));
- // Top entries of the matrix R will be -s*r, thus R is orthogonal to
- // (1,s).
- RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> r_top,
- r.Mul(key.Key(), key.ModulusParams()));
- r_top.NegateInPlace(key.ModulusParams());
- std::vector<Polynomial<ModularInt>> res = {std::move(r_top), std::move(r)};
- return res;
- }
- // The i-th component of the result is (T^i key_power).
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT(
- const Polynomial<ModularInt>& key_power,
- const SymmetricRlweKey<ModularInt>& key,
- const ModularInt& decomposition_modulus, int dimension) {
- std::vector<Polynomial<ModularInt>> result;
- result.reserve(dimension);
- Polynomial<ModularInt> key_to_i = key_power;
- for (int i = 0; i < dimension; i++) {
- // Increase the power of T in T^i s in place.
- if (i != 0) {
- RLWE_RETURN_IF_ERROR(
- key_to_i.MulInPlace(decomposition_modulus, key.ModulusParams()));
- }
- result.push_back(key_to_i);
- }
- return result;
- }
- // The i-th component of the result contains a vector of i-th digits of the
- // coefficients in base T (the decomposition modulus).
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose(
- const std::vector<ModularInt>& coefficients,
- const typename ModularInt::Params* modulus_params,
- const Uint64 log_decomposition_modulus, int dimension) {
- std::vector<typename ModularInt::Int> ciphertext_coeffs(coefficients.size(),
- 0);
- std::transform(
- coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(),
- [modulus_params](ModularInt x) { return x.ExportInt(modulus_params); });
- std::vector<std::vector<ModularInt>> result(dimension);
- for (int i = 0; i < dimension; i++) {
- result[i].reserve(ciphertext_coeffs.size());
- for (int j = 0; j < ciphertext_coeffs.size(); ++j) {
- RLWE_ASSIGN_OR_RETURN(
- auto coefficient_part,
- ModularInt::ImportInt(
- (ciphertext_coeffs[j] % (1L << log_decomposition_modulus)),
- modulus_params));
- result[i].push_back(std::move(coefficient_part));
- ciphertext_coeffs[j] = ciphertext_coeffs[j] >> log_decomposition_modulus;
- }
- }
- return result;
- }
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
- std::vector<std::vector<ModularInt>> decomposed_coefficients,
- const std::vector<std::vector<Polynomial<ModularInt>>>& matrix,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params);
- Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params);
- std::vector<Polynomial<ModularInt>> result(2, temp);
- for (int i = 0; i < matrix[0].size(); i++) {
- ntt_part = Polynomial<ModularInt>::ConvertToNtt(
- std::move(decomposed_coefficients[i]), ntt_params, modulus_params);
- RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params));
- RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params));
- RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params))
- RLWE_RETURN_IF_ERROR(result[1].AddInPlace(ntt_part, modulus_params));
- }
- return result;
- }
- } // namespace
- template <typename ModularInt>
- rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
- RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create(
- const Polynomial<ModularInt>& key_power,
- const SymmetricRlweKey<ModularInt>& key,
- const Uint64 log_decomposition_modulus,
- const ModularInt& decomposition_modulus, int dimension, SecurePrng* prng,
- SecurePrng* prng_encryption) {
- std::vector<std::vector<Polynomial<ModularInt>>> matrix(2);
- for (auto& row : matrix) {
- row.reserve(dimension);
- }
- // Compute a vector of (T^i key_power).
- RLWE_ASSIGN_OR_RETURN(
- auto powers_of_t,
- PowersOfT(key_power, key, decomposition_modulus, dimension));
- // For key_power = s^j, the ith iteration of this loop computes the column of
- // the KeyPart corresponding to (T^i s^j).
- for (int i = 0; i < dimension; ++i) {
- // Sample r component orthogonal to (1,s).
- RLWE_ASSIGN_OR_RETURN(auto r, SampleOrthogonalFromPrng(key, prng));
- // Sample error.
- RLWE_ASSIGN_OR_RETURN(auto error,
- SampleFromErrorDistribution<ModularInt>(
- key_power.Len(), key.Variance(), prng_encryption,
- key.ModulusParams()));
- // Convert the error coefficients into an error polynomial.
- auto e = Polynomial<ModularInt>::ConvertToNtt(
- std::move(error), key.NttParams(), key.ModulusParams());
- // Set the column of the Relinearization matrix.
- RLWE_RETURN_IF_ERROR(
- e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
- RLWE_RETURN_IF_ERROR(e.AddInPlace(r[0], key.ModulusParams()));
- RLWE_RETURN_IF_ERROR(e.AddInPlace(powers_of_t[i], key.ModulusParams()));
- matrix[0].push_back(std::move(e));
- matrix[1].push_back(std::move(r[1]));
- }
- return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
- }
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<Polynomial<ModularInt>>>
- RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo(
- const Polynomial<ModularInt>& ciphertext_part,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params) const {
- // Convert ciphertext out of NTT form.
- std::vector<ModularInt> ciphertext_coefficients =
- ciphertext_part.InverseNtt(ntt_params, modulus_params);
- // Bit-decompose the vector of coefficients in the ciphertext.
- RLWE_ASSIGN_OR_RETURN(
- std::vector<std::vector<ModularInt>> decomposed_coefficients,
- BitDecompose<ModularInt>(ciphertext_coefficients, modulus_params,
- log_decomposition_modulus_, matrix_[0].size()));
- // Matrix multiply with the bit-decomposed coefficients.
- return MatrixMultiply<ModularInt>(std::move(decomposed_coefficients), matrix_,
- modulus_params, ntt_params);
- }
- template <typename ModularInt>
- rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
- RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize(
- const std::vector<SerializedNttPolynomial>& polynomials,
- Uint64 log_decomposition_modulus, SecurePrng* prng,
- const ModularIntParams* modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- // The polynomials input is a flattened representation of a 2 x dimension
- // matrix where the first half corresponds to the first row of matrix and the
- // second half corresponds to the second row of matrix. This matrix makes up
- // the RelinearizationKeyPart.
- int dimension = polynomials.size();
- auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
- matrix[0].reserve(dimension);
- matrix[1].reserve(dimension);
- for (int i = 0; i < dimension; i++) {
- RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
- polynomials[i], modulus_params));
- matrix[0].push_back(std::move(elt));
- RLWE_ASSIGN_OR_RETURN(auto sample,
- SamplePolynomialFromPrng<ModularInt>(
- matrix[0][i].Len(), prng, modulus_params));
- matrix[1].push_back(std::move(sample));
- }
- return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
- }
- template <typename ModularInt>
- RelinearizationKey<ModularInt>::RelinearizationKey(
- const SymmetricRlweKey<ModularInt>& key, absl::string_view prng_seed,
- ssize_t num_parts, Uint64 log_decomposition_modulus,
- Uint64 substitution_power, ModularInt decomposition_modulus,
- std::vector<RelinearizationKeyPart> relinearization_key)
- : dimension_(ComputeDimension(log_decomposition_modulus,
- key.ModulusParams()->modulus)),
- num_parts_(num_parts),
- log_decomposition_modulus_(log_decomposition_modulus),
- decomposition_modulus_(decomposition_modulus),
- substitution_power_(substitution_power),
- modulus_params_(key.ModulusParams()),
- ntt_params_(key.NttParams()),
- relinearization_key_(std::move(relinearization_key)),
- prng_seed_(prng_seed) {}
- template <typename ModularInt>
- rlwe::StatusOr<RelinearizationKey<ModularInt>>
- RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key,
- absl::string_view prng_seed,
- ssize_t num_parts,
- Uint64 log_decomposition_modulus,
- Uint64 substitution_power) {
- if (num_parts <= 0) {
- return absl::InvalidArgumentError(
- absl::StrCat("Num parts: ", num_parts, " must be positive."));
- }
- if (log_decomposition_modulus <= 0) {
- return absl::InvalidArgumentError(
- absl::StrCat("Log decomposition modulus, ", log_decomposition_modulus,
- ", must be positive."));
- } else if (log_decomposition_modulus > key.ModulusParams()->log_modulus) {
- return absl::InvalidArgumentError(absl::StrCat(
- "Log decomposition modulus, ", log_decomposition_modulus,
- ", must be at most: ", key.ModulusParams()->log_modulus, "."));
- }
- RLWE_ASSIGN_OR_RETURN(auto decomposition_modulus,
- ModularInt::ImportInt(key.ModulusParams()->One()
- << log_decomposition_modulus,
- key.ModulusParams()));
- // Initialize the first part of the secret key, s.
- RLWE_ASSIGN_OR_RETURN(auto key_base, key.Substitute(substitution_power));
- auto key_power = key_base.Key();
- RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(prng_seed));
- RLWE_ASSIGN_OR_RETURN(auto prng_encryption_seed,
- SingleThreadPrng::GenerateSeed());
- RLWE_ASSIGN_OR_RETURN(auto prng_encryption,
- SingleThreadPrng::Create(prng_encryption_seed));
- auto dimension =
- ComputeDimension(log_decomposition_modulus, key.ModulusParams()->modulus);
- std::vector<RelinearizationKeyPart> relinearization_key;
- relinearization_key.reserve(num_parts);
- // Create RealinearizationKeyPart for each of the secret key parts: s, ...,
- // s^k.
- for (int i = 1; i < num_parts; i++) {
- if (i != 1) {
- // Increment the power of s.
- RLWE_RETURN_IF_ERROR(
- key_power.MulInPlace(key_base.Key(), key.ModulusParams()));
- }
- RLWE_ASSIGN_OR_RETURN(
- auto key_part,
- RelinearizationKeyPart::Create(
- key_power, key, log_decomposition_modulus, decomposition_modulus,
- dimension, prng.get(), prng_encryption.get()));
- relinearization_key.push_back(std::move(key_part));
- }
- return RelinearizationKey<ModularInt>(
- key, prng_seed, num_parts, log_decomposition_modulus, substitution_power,
- decomposition_modulus, std::move(relinearization_key));
- }
- template <typename ModularInt>
- rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>>
- RelinearizationKey<ModularInt>::ApplyTo(
- const SymmetricRlweCiphertext<ModularInt>& ciphertext) const {
- // Ensure that the length of the ciphertext is less than or equal to the
- // length of the relinearization key.
- if (ciphertext.Len() > num_parts_) {
- return absl::InvalidArgumentError(
- "RelinearizationKey not large enough for ciphertext.");
- }
- // Initialize the result ciphertext of length 2.
- RLWE_ASSIGN_OR_RETURN(auto comp, ciphertext.Component(0));
- std::vector<Polynomial<ModularInt>> result(
- 2, Polynomial<ModularInt>(comp.Len(), modulus_params_));
- // Apply each RelinearizationKeyPart to the part of the ciphertext it
- // corresponds to. The first component of the ciphertext corresponds to the
- // "1" part of the secret key, and is added without any
- // RelinearizationKeyPart.
- result[0] = std::move(comp);
- for (int i = 0; i < relinearization_key_.size(); i++) {
- // Add RelinearizationKeyPart_i c_i to the result vector.
- RLWE_ASSIGN_OR_RETURN(auto temp_comp, ciphertext.Component(i + 1));
- RLWE_ASSIGN_OR_RETURN(auto result_part,
- relinearization_key_[i].ApplyPartTo(
- temp_comp, modulus_params_, ntt_params_));
- RLWE_RETURN_IF_ERROR(result[0].AddInPlace(result_part[0], modulus_params_));
- RLWE_RETURN_IF_ERROR(result[1].AddInPlace(result_part[1], modulus_params_));
- }
- return SymmetricRlweCiphertext<ModularInt>(
- std::move(result), 1,
- ciphertext.Error() +
- ciphertext.ErrorParams()->B_relinearize(log_decomposition_modulus_),
- modulus_params_, ciphertext.ErrorParams());
- }
- template <typename ModularInt>
- rlwe::StatusOr<SerializedRelinearizationKey>
- RelinearizationKey<ModularInt>::Serialize() const {
- SerializedRelinearizationKey output;
- output.set_log_decomposition_modulus(log_decomposition_modulus_);
- output.set_num_parts(num_parts_);
- output.set_prng_seed(prng_seed_);
- output.set_power_of_s(substitution_power_);
- for (const RelinearizationKeyPart& matrix : relinearization_key_) {
- // Only serialize the first row of each matrix.
- for (const Polynomial<ModularInt>& c : matrix.Matrix()) {
- RLWE_ASSIGN_OR_RETURN(*output.add_c(), c.Serialize(modulus_params_));
- }
- }
- return output;
- }
- template <typename ModularInt>
- rlwe::StatusOr<RelinearizationKey<ModularInt>>
- RelinearizationKey<ModularInt>::Deserialize(
- const SerializedRelinearizationKey& serialized,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- // Verifies that the number of polynomials in serialized is expected.
- // A RelinearizationKey can decrypt ciphertexts with num_parts number of
- // components corresponding to decryption under (1, s, ..., s^k) or (1,
- // s(x^power)) but only contains parts corresponding to the non-"1"
- // components.
- if (serialized.num_parts() <= 1) {
- return absl::InvalidArgumentError(
- absl::StrCat("The number of parts, ", serialized.num_parts(),
- ", must be greater than one."));
- } else if (serialized.c_size() % (serialized.num_parts() - 1) != 0) {
- return absl::InvalidArgumentError(
- absl::StrCat("The length of serialized, ", serialized.c_size(), ", ",
- "must be divisible by the number of parts minus one ",
- serialized.num_parts() - 1, "."));
- }
- // Return an error when log decomposition modulus is non-positive.
- if (serialized.log_decomposition_modulus() <= 0) {
- return absl::InvalidArgumentError(absl::StrCat(
- "Log decomposition modulus, ", serialized.log_decomposition_modulus(),
- ", must be positive."));
- } else if (serialized.log_decomposition_modulus() >
- modulus_params->log_modulus) {
- return absl::InvalidArgumentError(absl::StrCat(
- "Log decomposition modulus, ", serialized.log_decomposition_modulus(),
- ", must be at most: ", modulus_params->log_modulus, "."));
- }
- int polynomials_per_matrix =
- serialized.c_size() / (serialized.num_parts() - 1);
- int dimension = polynomials_per_matrix;
- if (dimension != ComputeDimension(serialized.log_decomposition_modulus(),
- modulus_params->modulus)) {
- return absl::InvalidArgumentError(
- absl::StrCat("Number of NTT Polynomials does not match expected ",
- "number of matrix entries."));
- }
- RLWE_ASSIGN_OR_RETURN(
- auto decomposition_modulus,
- ModularInt::ImportInt(static_cast<typename ModularInt::Int>(1)
- << serialized.log_decomposition_modulus(),
- modulus_params));
- RelinearizationKey output(serialized.log_decomposition_modulus(),
- decomposition_modulus, modulus_params, ntt_params);
- output.dimension_ = dimension;
- output.num_parts_ = serialized.num_parts();
- output.prng_seed_ = serialized.prng_seed();
- output.substitution_power_ = serialized.power_of_s();
- // Create prng based on seed.
- RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(output.prng_seed_));
- // Takes each polynomials_per_matrix chunk of serialized.c()'s and places them
- // into a KeyPart.
- output.relinearization_key_.reserve(serialized.num_parts() - 1);
- for (int i = 0; i < (serialized.num_parts() - 1); i++) {
- auto start = serialized.c().begin() + i * polynomials_per_matrix;
- auto end = start + polynomials_per_matrix;
- std::vector<SerializedNttPolynomial> chunk(start, end);
- RLWE_ASSIGN_OR_RETURN(auto deserialized,
- RelinearizationKeyPart::Deserialize(
- chunk, serialized.log_decomposition_modulus(),
- prng.get(), modulus_params, ntt_params));
- output.relinearization_key_.push_back(std::move(deserialized));
- }
- return output;
- }
- // Instantiations of RelinearizationKey with specific MontgomeryInt classes.
- // If any new types are added, montgomery.h should be updated accordingly (such
- // as ensuring BigInt is correctly specialized, etc.).
- template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint16>>;
- template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint32>>;
- template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint64>>;
- template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<absl::uint128>>;
- } // namespace rlwe
|