12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977 |
- /*
- * Copyright 2017 Google LLC.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef RLWE_SYMMETRIC_ENCRYPTION_H_
- #define RLWE_SYMMETRIC_ENCRYPTION_H_
- #include <algorithm>
- #include <cstdint>
- #include <vector>
- #include "error_params.h"
- #include "polynomial.h"
- #include "prng/integral_prng_types.h"
- #include "prng/prng.h"
- #include "sample_error.h"
- #include "serialization.pb.h"
- #include "status_macros.h"
- namespace rlwe {
- // This file implements the somewhat homomorphic symmetric-key encryption scheme
- // from "Fully Homomorphic Encryption from Ring-LWE and Security for Key
- // Dependent Messages" by Zvika Brakerski and Vinod Vaikuntanathan. This
- // encryption scheme uses Ring Learning with Errors (RLWE).
- // http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf
- //
- // The scheme has CPA security under the hardness of the
- // Ring-Learning with Errors problem (see reference above for details). We do
- // not implement protections against timing attacks.
- //
- // The encryption scheme in this file is not fully homomorphic. It does not
- // implement any sort of bootstrapping.
- // Represents a ciphertext encrypted using a symmetric-key version of the ring
- // learning-with-errors (RLWE) encryption scheme. See the comments that follow
- // throughout this file for full details on the particular encryption scheme.
- //
- // This implementation supports the following homomorphic operations:
- // - Homomorphic addition.
- // - Scalar multiplication by a polynomial (absorption)
- // - Homomorphic multiplication.
- //
- // This implementation is only "somewhat homomorphic," not fully homomorphic.
- // There is no bootstrapping, so a limited number of homomorphic operations can
- // be performed before so much error accumulates that decryption is impossible.
- //
- // Each ciphertext comprises a vector of polynomials <c0, ..., cN>. Initially,
- // a ciphertext comprises a pair <c0, c1>. Homomorphic multiplications cause
- // the vector to grow longer.
- template <typename ModularInt>
- class SymmetricRlweCiphertext {
- using Int = typename ModularInt::Int;
- // BigInt is required in order to multiply two Int and ensure that no overflow
- // occurs during the multiplication of two ciphertexts.
- using BigInt = typename ModularInt::BigInt;
- public:
- // Default and copy constructors.
- explicit SymmetricRlweCiphertext(const typename ModularInt::Params* params,
- const ErrorParams<ModularInt>* error_params)
- : modulus_params_(params),
- error_params_(error_params),
- power_of_s_(1),
- error_(0) {}
- SymmetricRlweCiphertext(const SymmetricRlweCiphertext& that) = default;
- // Create a ciphertext by supplying the vector of components.
- explicit SymmetricRlweCiphertext(std::vector<Polynomial<ModularInt>> c,
- int power_of_s, double error,
- const typename ModularInt::Params* params,
- const ErrorParams<ModularInt>* error_params)
- : c_(std::move(c)),
- modulus_params_(params),
- error_params_(error_params),
- power_of_s_(power_of_s),
- error_(error) {}
- // Homomorphic addition: add the polynomials representing the ciphertexts
- // component-wise. The example below demonstrates why this procedure works
- // properly in the two-component case. The quantities a, s, m, t, and e are
- // introduced during encryption and are explained in the SymmetricRlweKey
- // class.
- //
- // (a1 * s + m1 + t * e1, -a1)
- // + (a2 * s + m2 + t * e2, -a2)
- // ------------------------------
- // ((a1 + a2) * s + (m1 + m2) + t * (e1 + e2), -(a1 + a2))
- //
- // Substitute (a1 + a2) = a3, (e1 + e2) = e3:
- //
- // (a3 * s + (m1 + m2) + t * e3, -a3)
- //
- // This result is a valid ciphertext where the value of a has changed, the
- // error has increased, and the encoded plaintext contains the sum of the
- // plaintexts that were encoded in the original two ciphertexts.
- rlwe::StatusOr<SymmetricRlweCiphertext> operator+(
- const SymmetricRlweCiphertext& that) const {
- SymmetricRlweCiphertext out = *this;
- RLWE_RETURN_IF_ERROR(out.AddInPlace(that));
- return out;
- }
- absl::Status AddInPlace(const SymmetricRlweCiphertext& that) {
- if (power_of_s_ != that.power_of_s_) {
- return absl::InvalidArgumentError(
- "Ciphertexts must be encrypted with the same key power.");
- }
- if (c_.size() < that.c_.size()) {
- Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_);
- c_.resize(that.c_.size(), zero);
- }
- for (int i = 0; i < that.c_.size(); i++) {
- RLWE_RETURN_IF_ERROR(c_[i].AddInPlace(that.c_[i], modulus_params_));
- }
- error_ += that.error_;
- return absl::OkStatus();
- }
- // Homomorphic subtraction: subtract the polynomials representing the
- // ciphertexts component-wise. The example below demonstrates why this
- // procedure works properly in the two-component case. The quantities a, s, m,
- // t, and e are introduced during encryption and are explained in the
- // SymmetricRlweKey class.
- //
- // (a1 * s + m1 + t * e1, -a1)
- // - (a2 * s + m2 + t * e2, -a2)
- // ------------------------------
- // ((a1 - a2) * s + (m1 - m2) + t * (e1 - e2), -(a1 - a2))
- //
- // Substitute (a1 - a2) = a3, (e1 - e2) = e3:
- //
- // (a3 * s + (m1 - m2) + t * e3, -a3)
- //
- // This result is a valid ciphertext where the value of a has changed, the
- // error has increased, and the encoded plaintext contains the sum of the
- // plaintexts that were encoded in the original two ciphertexts.
- rlwe::StatusOr<SymmetricRlweCiphertext> operator-(
- const SymmetricRlweCiphertext& that) const {
- SymmetricRlweCiphertext out = *this;
- RLWE_RETURN_IF_ERROR(out.SubInPlace(that));
- return out;
- }
- absl::Status SubInPlace(const SymmetricRlweCiphertext& that) {
- if (power_of_s_ != that.power_of_s_) {
- return absl::InvalidArgumentError(
- "Ciphertexts must be encrypted with the same key power.");
- }
- if (c_.size() < that.c_.size()) {
- Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_);
- c_.resize(that.c_.size(), zero);
- }
- for (int i = 0; i < that.c_.size(); i++) {
- RLWE_RETURN_IF_ERROR(c_[i].SubInPlace(that.c_[i], modulus_params_));
- }
- error_ += that.error_;
- return absl::OkStatus();
- }
- // Homomorphic absorbtion. Multiplies the current ciphertext {m1}_s (plaintext
- // m1 encrypted with symmetric key s) by a plaintext m2, resulting in a
- // ciphertext {m1 * m2}_s that stores m1 * m2 encrypted with symmetric key s.
- //
- // DO NOT CONFUSE THIS OPERATION WITH HOMOMORPHIC MULTIPLICATION.
- //
- // To perform this operation, multiply the each component of the
- // ciphertext by the plaintext polynomial. The example below demonstrates why
- // this procedure works properly in the two-component case. The quantities a,
- // s, m, t, and e are introduced during encryption and are explained in the
- // Encrypt() function later in this file.
- //
- // (a1 * s + m1 + t * e1, -a1) * p
- // = (a1 * s * p + m1 * p + t * e1 * p)
- //
- // Substitute (a1 * p) = a2 and (e1 * p) = e2:
- //
- // (a2 * s + m1 * p + t * e2)
- //
- // This result is a valid ciphertext where the value of a has changed, the
- // error has increased, and the encoded plaintext contains the product of
- // m1 and p.
- //
- // A few more details about the multiplication that takes place:
- //
- // The value stored in the resulting ciphertext is (m1 * m2) (mod 2^N + 1)
- // (mod t), where N is the number of coefficients in s (or m1 or m2, since
- // the all have the same number of coefficients). In other words, the
- // result is the remainder of (m1 * m2) mod the polynomial (2^N + 1) with
- // each of the coefficients the ntaken mod t. Any coefficient between 0 and
- // modulus / 2 is treated as a positive number for the purposes of the final
- // (mod t); any coefficient between modulus/2 and modulus is treated as
- // a negative number for the purposes of the final (mod t).
- rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
- const Polynomial<ModularInt>& that) const {
- SymmetricRlweCiphertext out = *this;
- RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that));
- return out;
- }
- absl::Status AbsorbInPlace(const Polynomial<ModularInt>& that) {
- for (auto& component : this->c_) {
- RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_));
- }
- error_ *= error_params_->B_plaintext();
- return absl::OkStatus();
- }
- // Homomorphically absorb a plaintext scalar. This function is exactly like
- // homomorphic absorb above, except the plaintext is a constant.
- rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
- const ModularInt& that) const {
- SymmetricRlweCiphertext out = *this;
- RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that));
- return out;
- }
- absl::Status AbsorbInPlace(const ModularInt& that) {
- for (auto& component : this->c_) {
- RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_));
- }
- error_ *= static_cast<double>(that.ExportInt(modulus_params_));
- return absl::OkStatus();
- }
- // Homomorphic multiply. Given two ciphertexts {m1}_s, {m2}_s containing
- // messages m1 and m2 encrypted with the same secret key s, return the
- // ciphertext {m1 * m2}_s containing the product of the messages.
- //
- // To perform this operation, treat the two ciphertext vectors as polynomials
- // and perform a polynomial multiplication:
- //
- // <c0, c1> * <c0', c1'> = <c0 * c0, c0 * c1 + c1 * c0, c1 * c1>
- //
- // If the two ciphertext vectors are of length m and n, the resulting
- // ciphertext is of length m + n - 1.
- //
- // The details of the multiplication that takes place between m1 and m2 are
- // the same as in the homomorphic absorb operation above (the other overload
- // of the * operator).
- rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
- const SymmetricRlweCiphertext& that) {
- if (power_of_s_ != that.power_of_s_) {
- return absl::InvalidArgumentError(
- "Ciphertexts must be encrypted with the same key power.");
- }
- if (c_.size() <= 0 || that.c_.size() <= 0) {
- return absl::InvalidArgumentError(
- "Cannot multiply using an empty ciphertext.");
- }
- if (c_[0].Len() <= 0 || that.c_[0].Len() <= 0) {
- return absl::InvalidArgumentError(
- "Cannot multiply using an empty polynomial in the ciphertext.");
- }
- Polynomial<ModularInt> temp(c_[0].Len(), modulus_params_);
- std::vector<Polynomial<ModularInt>> result(c_.size() + that.c_.size() - 1,
- temp);
- for (int i = 0; i < c_.size(); i++) {
- for (int j = 0; j < that.c_.size(); j++) {
- RLWE_ASSIGN_OR_RETURN(temp, c_[i].Mul(that.c_[j], modulus_params_));
- RLWE_RETURN_IF_ERROR(result[i + j].AddInPlace(temp, modulus_params_));
- }
- }
- return SymmetricRlweCiphertext(std::move(result), power_of_s_,
- error_ * that.error_, modulus_params_,
- error_params_);
- }
- // Convert this ciphertext from (mod p) to (mod q).
- // Assumes that ModularInt::Int and ModularIntQ::Int are the same type.
- //
- // The current modulus (mod t) must be equal to modulus q (mod t).
- // This will always be true. For NTT to work properly, any modulus must be
- // of the form 2N + 1, where N is a power of 2. Likewise, the implementation
- // requires that t is a power of 2. This means that, for any modulus q and
- // modulus t allowed by the RLWE implementation, q % t == 1.
- template <typename ModularIntQ>
- rlwe::StatusOr<SymmetricRlweCiphertext<ModularIntQ>> SwitchModulus(
- const NttParameters<ModularInt>* ntt_params_p,
- const typename ModularIntQ::Params* modulus_params_q,
- const NttParameters<ModularIntQ>* ntt_params_q,
- const ErrorParams<ModularIntQ>* error_params_q, const Int& t) {
- Int p = modulus_params_->modulus;
- Int q = modulus_params_q->modulus;
- // Configuration error.
- if (p % t != q % t) {
- return absl::InvalidArgumentError("p % t != q % t");
- }
- SymmetricRlweCiphertext<ModularIntQ> output(modulus_params_q,
- error_params_q);
- output.power_of_s_ = power_of_s_;
- // Overestimate the ratio of the two moduli.
- double modulus_ratio = static_cast<double>(modulus_params_q->log_modulus) /
- modulus_params_->log_modulus;
- output.error_ = modulus_ratio * error_ + error_params_q->B_scale();
- output.c_.reserve(c_.size());
- for (const Polynomial<ModularInt>& c : c_) {
- // Extract each component of the ciphertext from NTT form.
- std::vector<ModularInt> coeffs_p =
- c.InverseNtt(ntt_params_p, modulus_params_);
- std::vector<ModularIntQ> coeffs_q;
- coeffs_q.reserve(coeffs_p.size());
- // Convert each coefficient of the polynomial from (mod p) to (mod q)
- for (const ModularInt& coeff_p : coeffs_p) {
- Int int_p = coeff_p.ExportInt(modulus_params_);
- // Scale the integer.
- Int int_q = static_cast<Int>(ModularInt::DivAndTruncate(
- static_cast<BigInt>(int_p) * static_cast<BigInt>(q),
- static_cast<BigInt>(p)));
- // Ensure that int_p = int_q mod t by changing int_q as little as
- // possible.
- Int int_p_mod_t = int_p % t;
- Int int_q_mod_t = int_q % t;
- Int adjustment_up = modulus_params_->Zero();
- Int adjustment_down = modulus_params_->Zero();
- // Determine whether to adjust int_q up or down to make sure int_q =
- // int_p (mod t).
- adjustment_up = int_p_mod_t - int_q_mod_t;
- adjustment_down = t + int_q_mod_t - int_p_mod_t;
- if (int_p_mod_t < int_q_mod_t) {
- adjustment_up = adjustment_up + t;
- adjustment_down = adjustment_down - t;
- }
- RLWE_ASSIGN_OR_RETURN(auto m_int_q,
- ModularIntQ::ImportInt(int_q, modulus_params_q));
- if (adjustment_up > adjustment_down) {
- RLWE_ASSIGN_OR_RETURN(
- auto m_adjustment_up,
- ModularIntQ::ImportInt(adjustment_up, modulus_params_q));
- // Adjust up.
- coeffs_q.push_back(
- std::move(m_adjustment_up.AddInPlace(m_int_q, modulus_params_q)));
- } else {
- RLWE_ASSIGN_OR_RETURN(
- auto m_adjustment_down,
- ModularIntQ::ImportInt(q - adjustment_down, modulus_params_q));
- // Adjust down.
- coeffs_q.push_back(std::move(
- m_adjustment_down.AddInPlace(m_int_q, modulus_params_q)));
- }
- }
- // Convert back to NTT.
- output.c_.push_back(Polynomial<ModularIntQ>::ConvertToNtt(
- std::move(coeffs_q), ntt_params_q, modulus_params_q));
- }
- return output;
- }
- // Given a ciphertext c encrypting a plaintext p(x) under secret key s(x),
- // returns a ciphertext c' encrypting p(x^power) under the secret key
- // s(x^power).
- // Power must be an odd non-negative integer less than 2 * num_coeffs.
- // This method uses NTT conversions to apply the substitution in the
- // coefficient domain, and should be avoided if performance is an issue.
- // Substitutions of the form 2^j + 1 are used to obliviously expand a query
- // ciphertext into a query vector.
- rlwe::StatusOr<SymmetricRlweCiphertext> Substitute(
- int substitution_power,
- const NttParameters<ModularInt>* ntt_params) const {
- SymmetricRlweCiphertext output(modulus_params_, error_params_);
- output.c_.reserve(c_.size());
- for (const Polynomial<ModularInt>& c : c_) {
- RLWE_ASSIGN_OR_RETURN(
- auto elt,
- c.Substitute(substitution_power, ntt_params, modulus_params_));
- output.c_.push_back(std::move(elt));
- }
- output.power_of_s_ = (power_of_s_ * substitution_power) % (2 * c_[0].Len());
- output.error_ = error_;
- return output;
- }
- rlwe::StatusOr<SerializedSymmetricRlweCiphertext> Serialize() const {
- SerializedSymmetricRlweCiphertext output;
- output.set_power_of_s(power_of_s_);
- output.set_error(error_);
- for (const Polynomial<ModularInt>& c : c_) {
- RLWE_ASSIGN_OR_RETURN(*output.add_c(), c.Serialize(modulus_params_));
- }
- return output;
- }
- static rlwe::StatusOr<SymmetricRlweCiphertext> Deserialize(
- const SerializedSymmetricRlweCiphertext& serialized,
- const typename ModularInt::Params* modulus_params,
- const ErrorParams<ModularInt>* error_params) {
- SymmetricRlweCiphertext output(modulus_params, error_params);
- output.power_of_s_ = serialized.power_of_s();
- output.error_ = serialized.error();
- if (serialized.c_size() <= 0) {
- return absl::InvalidArgumentError("Ciphertext cannot be empty.");
- } else if (serialized.c_size() > kMaxNumCoeffs) {
- return absl::InvalidArgumentError(
- absl::StrCat("Number of coefficients, ", serialized.c_size(),
- ", cannot be more than ", kMaxNumCoeffs, "."));
- }
- for (int i = 0; i < serialized.c_size(); i++) {
- RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
- serialized.c(i), modulus_params));
- output.c_.push_back(std::move(elt));
- }
- return output;
- }
- // Accessors.
- unsigned int Len() const { return c_.size(); }
- rlwe::StatusOr<Polynomial<ModularInt>> Component(int index) const {
- if (0 > index || index >= c_.size()) {
- return absl::InvalidArgumentError("Index out of range.");
- }
- return c_[index];
- }
- const typename ModularInt::Params* ModulusParams() const {
- return modulus_params_;
- }
- const rlwe::ErrorParams<ModularInt>* ErrorParams() const {
- return error_params_;
- }
- int PowerOfS() const { return power_of_s_; }
- double Error() const { return error_; }
- void SetError(double error) { error_ = error; }
- private:
- // The ciphertext.
- std::vector<Polynomial<ModularInt>> c_;
- // ModularInt parameters.
- const typename ModularInt::Params* modulus_params_;
- // Error parameters.
- const rlwe::ErrorParams<ModularInt>* error_params_;
- // The power a in s(x^a) that the ciphertext can be decrypted with.
- int power_of_s_;
- // A heuristic on the error of the ciphertext.
- double error_;
- // Make this class a friend of any version of this class, no matter the
- // template.
- template <typename Q>
- friend class SymmetricRlweCiphertext;
- };
- // Holds a key that can be used to encrypt messages using the RLWE-based
- // encryption scheme.
- template <typename ModularInt>
- class SymmetricRlweKey {
- using Int = typename ModularInt::Int;
- public:
- // Allow copy, copy-assign, move and move-assign.
- SymmetricRlweKey(const SymmetricRlweKey&) = default;
- SymmetricRlweKey& operator=(const SymmetricRlweKey&) = default;
- SymmetricRlweKey(SymmetricRlweKey&&) = default;
- SymmetricRlweKey& operator=(SymmetricRlweKey&&) = default;
- ~SymmetricRlweKey() = default;
- // Static factory that samples a key from the error distribution. The
- // polynomial representing the key must have a number of coefficients that is
- // a power of two, which is enforced by the first argument.
- //
- // Does not take ownership of rand, modulus_params or ntt_params.
- static rlwe::StatusOr<SymmetricRlweKey> Sample(
- unsigned int log_num_coeffs, uint64_t variance, uint64_t log_t,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params, SecurePrng* prng) {
- RLWE_ASSIGN_OR_RETURN(
- auto error, SampleFromErrorDistribution<ModularInt>(
- 1 << log_num_coeffs, variance, prng, modulus_params));
- Polynomial<ModularInt> key = Polynomial<ModularInt>::ConvertToNtt(
- std::move(error), ntt_params, modulus_params);
- RLWE_ASSIGN_OR_RETURN(
- auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) +
- modulus_params->One(),
- modulus_params));
- return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod),
- modulus_params, modulus_params, ntt_params);
- }
- rlwe::StatusOr<SerializedNttPolynomial> Serialize() const {
- return key_.Serialize(modulus_params_);
- }
- // Deserialize using modulus params as also the plaintext modulus params. Use
- // this when deserializing a non-modulus switched key.
- static rlwe::StatusOr<SymmetricRlweKey> Deserialize(
- Uint64 variance, Uint64 log_t,
- const SerializedNttPolynomial& serialized_key,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- return Deserialize(variance, log_t, serialized_key, modulus_params,
- modulus_params, ntt_params);
- }
- static rlwe::StatusOr<SymmetricRlweKey> Deserialize(
- Uint64 variance, Uint64 log_t,
- const SerializedNttPolynomial& serialized_key,
- const typename ModularInt::Params* modulus_params,
- const typename ModularInt::Params* plaintext_modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- // Check that log_t is no larger than the log_modulus - 1.
- if (log_t > modulus_params->log_modulus - 1) {
- return absl::InvalidArgumentError(absl::StrCat(
- "The value of log_t, ", log_t, ", must be smaller than ",
- "log_modulus - 1, ", modulus_params->log_modulus - 1, "."));
- }
- RLWE_ASSIGN_OR_RETURN(
- Polynomial<ModularInt> key,
- Polynomial<ModularInt>::Deserialize(serialized_key, modulus_params));
- RLWE_ASSIGN_OR_RETURN(
- auto t_mod,
- ModularInt::ImportInt((plaintext_modulus_params->One() << log_t) +
- plaintext_modulus_params->One(),
- plaintext_modulus_params));
- return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod),
- modulus_params, plaintext_modulus_params,
- ntt_params);
- }
- // Generate a copy of this key in modulus q.
- //
- // The current modulus (mod t) must be equal to modulus q (mod t). This
- // property is implicitly enforced by the design of the code as described
- // by the corresponding comment on SymmetricRlweKey::SwitchModulus. This
- // property is also dynamically enforced.
- //
- // The algorithms for modulus-switching ciphertexts and keys are similar but
- // slightly different. In particular, RLWE keys are guaranteed to have small
- // coefficients, and thus modulus switching can be made very simple. Hence
- // we have 2 separate implementations of SwitchModulus for keys and
- // ciphertexts.
- template <typename ModularIntQ>
- rlwe::StatusOr<SymmetricRlweKey<ModularIntQ>> SwitchModulus(
- const typename ModularIntQ::Params* modulus_params_q,
- const NttParameters<ModularIntQ>* ntt_params_q) const {
- // Configuration failure.
- Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One();
- if (modulus_params_->modulus % t != modulus_params_q->modulus % t) {
- return absl::InvalidArgumentError("p % t != q % t");
- }
- typename ModularIntQ::Int p_mod_q =
- modulus_params_->modulus % modulus_params_q->modulus;
- std::vector<ModularInt> coeffs_p =
- key_.InverseNtt(ntt_params_, modulus_params_);
- std::vector<ModularIntQ> coeffs_q;
- // Convert each coefficient of the polynomial from (mod p) to (mod q)
- for (const ModularInt& coeff_p : coeffs_p) {
- // Ensure that negative numbers (mod p) are translated into negative
- // numbers (mod q).
- Int int_p = coeff_p.ExportInt(modulus_params_);
- if (int_p > modulus_params_->modulus >> 1) {
- int_p = int_p - p_mod_q;
- }
- RLWE_ASSIGN_OR_RETURN(auto m_int_p,
- ModularIntQ::ImportInt(int_p, modulus_params_q));
- coeffs_q.push_back(std::move(m_int_p));
- }
- // Convert back to NTT.
- auto key_q = Polynomial<ModularIntQ>::ConvertToNtt(
- std::move(coeffs_q), ntt_params_q, modulus_params_q);
- RLWE_ASSIGN_OR_RETURN(
- auto t_mod, ModularInt::ImportInt((modulus_params_q->One() << log_t_) +
- modulus_params_q->One(),
- modulus_params_q));
- return SymmetricRlweKey<ModularIntQ>(std::move(key_q), variance_, log_t_,
- std::move(t_mod), modulus_params_q,
- modulus_params_q, ntt_params_q);
- }
- // Given s(x), returns a secret key s(x^a).
- // This performs an Inverse NTT on the key, substitutes the key in polynomial
- // representation, and then performs an NTT again.
- rlwe::StatusOr<SymmetricRlweKey> Substitute(const int power) const {
- RLWE_ASSIGN_OR_RETURN(
- auto t_mod, ModularInt::ImportInt((modulus_params_->One() << log_t_) +
- modulus_params_->One(),
- modulus_params_));
- RLWE_ASSIGN_OR_RETURN(auto sub,
- key_.Substitute(power, ntt_params_, modulus_params_));
- return SymmetricRlweKey(std::move(sub), variance_, log_t_, std::move(t_mod),
- modulus_params_, plaintext_modulus_params_,
- ntt_params_);
- }
- // Accessors.
- unsigned int Len() const { return key_.Len(); }
- const NttParameters<ModularInt>* NttParams() const { return ntt_params_; }
- const typename ModularInt::Params* ModulusParams() const {
- return modulus_params_;
- }
- const unsigned int BitsPerCoeff() const { return log_t_; }
- const Uint64 Variance() const { return variance_; }
- const unsigned int LogT() const { return log_t_; }
- const ModularInt& PlaintextModulus() const { return t_mod_; }
- const typename ModularInt::Params* PlaintextModulusParams() const {
- return plaintext_modulus_params_;
- }
- const Polynomial<ModularInt>& Key() const { return key_; }
- // Add two homomorphic encryption keys.
- rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Add(
- const SymmetricRlweKey<ModularInt>& other_key) {
- if (variance_ != other_key.variance_) {
- return absl::InvalidArgumentError(absl::StrCat(
- "The variance of the other key, ", other_key.variance_,
- ", is different than the variance of this key, ", variance_, "."));
- }
- if (log_t_ != other_key.log_t_) {
- return absl::InvalidArgumentError(absl::StrCat(
- "The log_t of the other key, ", other_key.log_t_,
- ", is different than the log_t of this key, ", log_t_, "."));
- }
- if (t_mod_ != other_key.t_mod_) {
- return absl::InvalidArgumentError(
- absl::StrCat("The plaintext space of the other key is different than "
- "the plaintext space of this key."));
- }
- RLWE_ASSIGN_OR_RETURN(auto key, key_.Add(other_key.key_, modulus_params_));
- return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_,
- t_mod_, modulus_params_,
- plaintext_modulus_params_, ntt_params_);
- }
- // Substract two homomorphic encryption keys.
- rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Sub(
- const SymmetricRlweKey<ModularInt>& other_key) {
- if (variance_ != other_key.variance_) {
- return absl::InvalidArgumentError(absl::StrCat(
- "The variance of the other key, ", other_key.variance_,
- ", is different than the variance of this key, ", variance_, "."));
- }
- if (log_t_ != other_key.log_t_) {
- return absl::InvalidArgumentError(absl::StrCat(
- "The log_t of the other key, ", other_key.log_t_,
- ", is different than the log_t of this key, ", log_t_, "."));
- }
- if (t_mod_ != other_key.t_mod_) {
- return absl::InvalidArgumentError(
- absl::StrCat("The plaintext space of the other key is different than "
- "the plaintext space of this key."));
- }
- RLWE_ASSIGN_OR_RETURN(auto key, key_.Sub(other_key.key_, modulus_params_));
- return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_,
- t_mod_, modulus_params_,
- plaintext_modulus_params_, ntt_params_);
- }
- // Static function to create a null key (with value 0).
- static rlwe::StatusOr<SymmetricRlweKey> NullKey(
- unsigned int log_num_coeffs, Uint64 variance, Uint64 log_t,
- const typename ModularInt::Params* modulus_params,
- const NttParameters<ModularInt>* ntt_params) {
- Polynomial<ModularInt> zero(1 << log_num_coeffs, modulus_params);
- RLWE_ASSIGN_OR_RETURN(
- auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) +
- modulus_params->One(),
- modulus_params));
- return SymmetricRlweKey(std::move(zero), variance, log_t, std::move(t_mod),
- modulus_params, modulus_params, ntt_params);
- }
- private:
- // The contents of the key itself.
- Polynomial<ModularInt> key_;
- // The variance of the binomial distribution from which the key and error are
- // drawn.
- Uint64 variance_;
- // The maximum size of any one coefficient of the polynomial representing a
- // plaintext message.
- unsigned int log_t_;
- ModularInt t_mod_;
- // NTT parameters.
- const NttParameters<ModularInt>* ntt_params_;
- // ModularInt parameters.
- const typename ModularInt::Params* modulus_params_;
- const typename ModularInt::Params* plaintext_modulus_params_;
- // A constructor. Does not take ownership of params.
- SymmetricRlweKey(Polynomial<ModularInt> key, Uint64 variance,
- unsigned int log_t, ModularInt t_mod,
- const typename ModularInt::Params* modulus_params,
- const typename ModularInt::Params* plaintext_modulus_params,
- const NttParameters<ModularInt>* ntt_params)
- : key_(std::move(key)),
- variance_(variance),
- log_t_(log_t),
- t_mod_(std::move(t_mod)),
- ntt_params_(ntt_params),
- modulus_params_(modulus_params),
- plaintext_modulus_params_(plaintext_modulus_params) {}
- // Make this class a friend of any version of this class, no matter the
- // template.
- template <typename Q>
- friend class SymmetricRlweKey;
- };
- // Encrypts the plaintext using ring learning-with-errors (RLWE) encryption.
- // (b/79577340): The parameter t is specified by log_t right now, but is equal to
- // (1 << log_t) + 1 so that t is odd. This is to allow multiplicative inverses
- // of powers of 2, which are used to compress and obliviously expand a query
- // ciphertext.
- //
- // The scheme works as follows:
- // KeyGen(n, modulus q, error distr):
- // Sample a degree (n-1) polynomial whose coefficients are drawn from the
- // error distribution (mod q). This is our secret key. Call it s.
- //
- // Encrypt(secret key s, plaintext m, modulus q, modulus t, error distr):
- // 1) Sample a degree (n-1) polynomial whose coefficients are drawn
- // uniformly from any integer (mod q). Call this polynomial a.
- // 2) Sample a degree (n-1) polynomial whose coefficients are drawn from
- // the error distribution (mod q). Call this polynomial e.
- // 3) Our secret key s and plaintext m are both degree (n-1) polynomials.
- // For decryption to work, each coefficient of m must be < t.
- // Compute (a * s + t * e + m) (mod x^n + 1). Call this polynomial b.
- // 4) The ciphertext is the pair (b, -a). We refer to the pair of
- // polynomials representing a ciphertext as (c0, c1) =
- // (a * s + m + e * t, -a).
- //
- // Decrypt(secret key s, ciphertext (b, -a), modulus t):
- // // Decryption when the ciphertext has two components.
- // Compute and return (b - as) (mod t). Doing out the algebra:
- // b - as (mod t)
- // = as + te + m - as (mod t)
- // = te + m (mod t)
- // = m
- // Quoting the paper, "the condition for correct decryption is that the
- // L_infinity norm of the polynomial [te + m] is smaller than q/2." In
- // other words, the largest of the values te + m (recall that e is
- // sampled from a distribution) cannot exceed q/2.
- //
- // When the ciphertext has more than two components <c0, c1, ..., cN>,
- // it can be decrypted by taking the dot product with the vector
- // <s^0, s^1, ..., s^N> containing powers of the secret key:
- // te + m = <c0, c1, ..., cN> dot <s^0, s^1, ..., s^N>
- // = c0 * s^0 + c1 * s^1 + ... + cN * s^N
- //
- // Note that the Encrypt() function takes the original plaintext as
- // a Polynomial<ModularInt>, while the corresponding Decrypt() method
- // returns a std::vector<typename ModularInt::Int>. The two values will be the
- // same once the original plaintext is converted out of NTT and Montgomery form.
- // - The Encrypt() function takes an NTT polynomial so that, if the same
- // plaintext is to be encrypted repeatedly, the NTT conversion only needs
- // to be performed once by the caller.
- // - The Decrypt() function returns a vector of integers because the final
- // (mod t) step requires taking the polynomial (te + m) out of NTT and
- // Montgomery form.
- // It would be straightforward to write a wrapper of Encrypt() that takes
- // a vector of integers as input, thereby making the plaintext types of the
- // Encrypt() and Decrypt() functions symmetric.
- namespace internal {
- // This functions allows injecting a specific polynomial "a" as the randomness
- // of the encryption (that is the negation of the c1 component of the
- // ciphertext) and returns only the resulting c1 component of the ciphertext.
- // This function is intended for internal use only.
- template <typename ModularInt>
- rlwe::StatusOr<Polynomial<ModularInt>> Encrypt(
- const SymmetricRlweKey<ModularInt>& key,
- const Polynomial<ModularInt>& plaintext, const Polynomial<ModularInt>& a,
- SecurePrng* prng) {
- // Sample the error term from the error distribution.
- unsigned int num_coeffs = key.Len();
- RLWE_ASSIGN_OR_RETURN(
- std::vector<ModularInt> e_coeffs,
- SampleFromErrorDistribution<ModularInt>(num_coeffs, key.Variance(), prng,
- key.ModulusParams()));
- // Create and return c0.
- auto e = Polynomial<ModularInt>::ConvertToNtt(
- std::move(e_coeffs), key.NttParams(), key.ModulusParams());
- RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> temp,
- a.Mul(key.Key(), key.ModulusParams()));
- RLWE_RETURN_IF_ERROR(
- e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
- RLWE_RETURN_IF_ERROR(temp.AddInPlace(e, key.ModulusParams()));
- RLWE_RETURN_IF_ERROR(temp.AddInPlace(plaintext, key.ModulusParams()));
- return temp;
- }
- } // namespace internal
- // Encrypts the supplied plaintext using the given key. Randomness is drawn from
- // the key's underlying ModulusParams.
- template <typename ModularInt>
- rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>> Encrypt(
- const SymmetricRlweKey<ModularInt>& key,
- const Polynomial<ModularInt>& plaintext,
- const ErrorParams<ModularInt>* error_params, SecurePrng* prng) {
- // Sample a from the uniform distribution.
- RLWE_ASSIGN_OR_RETURN(auto a, SamplePolynomialFromPrng<ModularInt>(
- key.Len(), prng, key.ModulusParams()));
- // Create c0.
- RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> c0,
- internal::Encrypt(key, plaintext, a, prng));
- // Compute c1 = -a and return the ciphertext.
- return SymmetricRlweCiphertext<ModularInt>(
- std::vector<Polynomial<ModularInt>>{
- std::move(c0), std::move(a.NegateInPlace(key.ModulusParams()))},
- 1, error_params->B_encryption(), key.ModulusParams(), error_params);
- }
- // Takes as input the result of decrypting a RLWE plaintext that still contains
- // the error. Concretely, it contains m + e * t (mod q). This function
- // eliminates the error and returns the message. For reasons described below,
- // this operation is more complicated than a simple (mod t).
- //
- // The error is drawn from a binomial distribution centered at zero and
- // multiplied by t, meaning error values are either positive or negative
- // multiples of t. Since each coefficient of the plaintext is smaller than
- // t, some coefficients of the quantity m + e * t (which is all that's
- // left in the vector error_and_message) could be negative. We are using
- // modular arithmetic, so negative values become large positive values.
- //
- // Unfortunately, these negative values cause the naive error elimination
- // strategy to fail. In theory we could take (m + e * t) mod t to
- // eliminate the error portion and extract the message. However, consider
- // a case where the error is negative. Suppose that t=2, m=1, and e=-1
- // with a modulus q=7:
- //
- // m + e * t (mod q) =
- // 1 + -1 * 2 (mod 7) =
- // -1 (mod 7) =
- // 6 (mod 7)
- //
- // When we take 6 (mod t) = 6 (mod 2), we get 0, which is not the original
- // bit of m. To avoid this problem, we treat negative values as negative
- // values, not as their equivalents mod q.
- //
- // We consider (m + e * t) to be negative whenever it is between q/2
- // and q. Recall that, if |m + e * t| is greater than q/2, decryption
- // fails.
- //
- // When the quantity (m + e * t) (mod q) represents a negative number
- // mod q, we can re-create its non-modular negative form by computing
- // ((m + e * t) - q). We can then take this value mod t to extract the
- // correct answer.
- //
- // 1. (m + e * t (mod q)) = // in the range [q/2, q)
- // 2. (m + e * t - q) = // in the range [-q/2, 0)
- // 3. m (mod t) + e * t (mod t) - q (mod t) = // taken (mod t)
- // 4. m - (q (mod t))
- //
- // If we subtract q at step 2, we return negative numbers to their
- // original form. Since we are going to perform a (mod t) operation
- // anyway, we can subtract q (mod t) at step 2 to get the same result.
- // Subtracting q (mod t) instead ensures that the quantity at step 2
- // does not become negative, which is convenient because we are using
- // an unsigned integer type.
- //
- // Concluding the example from before with the fix:
- //
- // m + e * t (mod q) - q (mod t) =
- // 1 + -1 * 2 (mod 7) - 7 (mod 2) =
- // -1 (mod 7) - 7 (mod 2) = 6 - 1 = 5
- //
- // 5 (mod t) = 1, which is the original message.
//
// Parameters:
//   error_and_message: coefficients holding m + e * t (mod q); each one is
//     exported to a plain integer with `modulus_params_q`.
//   q: the ciphertext modulus.
//   t: the plaintext modulus.
//   modulus_params_q: parameters used to export the coefficients and to
//     obtain the zero value the output is initialized with.
// Returns one integer per input coefficient, each reduced mod t.
template <typename ModularInt>
std::vector<typename ModularInt::Int> RemoveError(
    const std::vector<ModularInt>& error_and_message,
    const typename ModularInt::Int& q, const typename ModularInt::Int& t,
    const typename ModularInt::Params* modulus_params_q) {
  using Int = typename ModularInt::Int;

  // Loop-invariant quantities: the "negative" threshold q/2 and the amount
  // q (mod t) subtracted from values above that threshold.
  const Int half_q = q >> 1;
  const Int correction = q % t;
  const Int zero = modulus_params_q->Zero();

  std::vector<Int> messages(error_and_message.size(), zero);
  auto out = messages.begin();
  for (const ModularInt& coefficient : error_and_message) {
    Int value = coefficient.ExportInt(modulus_params_q);
    // Values above q/2 stand for negative numbers; shift them by q (mod t)
    // so that the reduction mod t below yields the intended residue.
    if (value > half_q) {
      value = value - correction;
    }
    *out = value % t;
    ++out;
  }
  return messages;
}
- template <typename ModularInt>
- rlwe::StatusOr<std::vector<typename ModularInt::Int>> Decrypt(
- const SymmetricRlweKey<ModularInt>& key,
- const SymmetricRlweCiphertext<ModularInt>& ciphertext) {
- // Extract the error and message. To do so, take the dot product of the
- // ciphertext vector <c0, c1, ..., cN> and the vector of the powers of
- // the key <s^0, s^1, ..., s^N>.
- // Accumulator variables.
- Polynomial<ModularInt> error_and_message_ntt(key.Len(), key.ModulusParams());
- Polynomial<ModularInt> key_powers = key.Key();
- unsigned int ciphertext_len = ciphertext.Len();
- for (int i = 0; i < ciphertext_len; i++) {
- // Extract component i.
- RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> ci, ciphertext.Component(i));
- // Lazily increase the exponent of the key.
- if (i > 1) {
- RLWE_RETURN_IF_ERROR(
- key_powers.MulInPlace(key.Key(), key.ModulusParams()));
- }
- // Beyond c0, multiply the exponentiated key in.
- if (i > 0) {
- RLWE_RETURN_IF_ERROR(
- ci.MulInPlace(key_powers, ciphertext.ModulusParams()));
- }
- RLWE_RETURN_IF_ERROR(
- error_and_message_ntt.AddInPlace(ci, key.ModulusParams()));
- }
- // Invert the NTT process.
- std::vector<ModularInt> error_and_message =
- error_and_message_ntt.InverseNtt(key.NttParams(), key.ModulusParams());
- // Extract the message.
- return RemoveError<ModularInt>(
- error_and_message, key.ModulusParams()->modulus,
- key.PlaintextModulus().ExportInt(key.PlaintextModulusParams()),
- key.ModulusParams());
- }
- } // namespace rlwe
- #endif // RLWE_SYMMETRIC_ENCRYPTION_H_
|