symmetric_encryption.h 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977
  1. /*
  2. * Copyright 2017 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #ifndef RLWE_SYMMETRIC_ENCRYPTION_H_
  16. #define RLWE_SYMMETRIC_ENCRYPTION_H_
  17. #include <algorithm>
  18. #include <cstdint>
  19. #include <vector>
  20. #include "error_params.h"
  21. #include "polynomial.h"
  22. #include "prng/integral_prng_types.h"
  23. #include "prng/prng.h"
  24. #include "sample_error.h"
  25. #include "serialization.pb.h"
  26. #include "status_macros.h"
  27. namespace rlwe {
  28. // This file implements the somewhat homomorphic symmetric-key encryption scheme
  29. // from "Fully Homomorphic Encryption from Ring-LWE and Security for Key
  30. // Dependent Messages" by Zvika Brakerski and Vinod Vaikuntanathan. This
  31. // encryption scheme uses Ring Learning with Errors (RLWE).
  32. // http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf
  33. //
  34. // The scheme has CPA security under the hardness of the
  35. // Ring-Learning with Errors problem (see reference above for details). We do
  36. // not implement protections against timing attacks.
  37. //
  38. // The encryption scheme in this file is not fully homomorphic. It does not
  39. // implement any sort of bootstrapping.
  40. // Represents a ciphertext encrypted using a symmetric-key version of the ring
  41. // learning-with-errors (RLWE) encryption scheme. See the comments that follow
  42. // throughout this file for full details on the particular encryption scheme.
  43. //
  44. // This implementation supports the following homomorphic operations:
  45. // - Homomorphic addition.
  46. // - Scalar multiplication by a polynomial (absorption)
  47. // - Homomorphic multiplication.
  48. //
  49. // This implementation is only "somewhat homomorphic," not fully homomorphic.
  50. // There is no bootstrapping, so a limited number of homomorphic operations can
  51. // be performed before so much error accumulates that decryption is impossible.
  52. //
  53. // Each ciphertext comprises a vector of polynomials <c0, ..., cN>. Initially,
  54. // a ciphertext comprises a pair <c0, c1>. Homomorphic multiplications cause
  55. // the vector to grow longer.
  56. template <typename ModularInt>
  57. class SymmetricRlweCiphertext {
  58. using Int = typename ModularInt::Int;
  59. // BigInt is required in order to multiply two Int and ensure that no overflow
  60. // occurs during the multiplication of two ciphertexts.
  61. using BigInt = typename ModularInt::BigInt;
  62. public:
  63. // Default and copy constructors.
  64. explicit SymmetricRlweCiphertext(const typename ModularInt::Params* params,
  65. const ErrorParams<ModularInt>* error_params)
  66. : modulus_params_(params),
  67. error_params_(error_params),
  68. power_of_s_(1),
  69. error_(0) {}
  70. SymmetricRlweCiphertext(const SymmetricRlweCiphertext& that) = default;
  71. // Create a ciphertext by supplying the vector of components.
  72. explicit SymmetricRlweCiphertext(std::vector<Polynomial<ModularInt>> c,
  73. int power_of_s, double error,
  74. const typename ModularInt::Params* params,
  75. const ErrorParams<ModularInt>* error_params)
  76. : c_(std::move(c)),
  77. modulus_params_(params),
  78. error_params_(error_params),
  79. power_of_s_(power_of_s),
  80. error_(error) {}
  81. // Homomorphic addition: add the polynomials representing the ciphertexts
  82. // component-wise. The example below demonstrates why this procedure works
  83. // properly in the two-component case. The quantities a, s, m, t, and e are
  84. // introduced during encryption and are explained in the SymmetricRlweKey
  85. // class.
  86. //
  87. // (a1 * s + m1 + t * e1, -a1)
  88. // + (a2 * s + m2 + t * e2, -a2)
  89. // ------------------------------
  90. // ((a1 + a2) * s + (m1 + m2) + t * (e1 + e2), -(a1 + a2))
  91. //
  92. // Substitute (a1 + a2) = a3, (e1 + e2) = e3:
  93. //
  94. // (a3 * s + (m1 + m2) + t * e3, -a3)
  95. //
  96. // This result is a valid ciphertext where the value of a has changed, the
  97. // error has increased, and the encoded plaintext contains the sum of the
  98. // plaintexts that were encoded in the original two ciphertexts.
  99. rlwe::StatusOr<SymmetricRlweCiphertext> operator+(
  100. const SymmetricRlweCiphertext& that) const {
  101. SymmetricRlweCiphertext out = *this;
  102. RLWE_RETURN_IF_ERROR(out.AddInPlace(that));
  103. return out;
  104. }
  105. absl::Status AddInPlace(const SymmetricRlweCiphertext& that) {
  106. if (power_of_s_ != that.power_of_s_) {
  107. return absl::InvalidArgumentError(
  108. "Ciphertexts must be encrypted with the same key power.");
  109. }
  110. if (c_.size() < that.c_.size()) {
  111. Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_);
  112. c_.resize(that.c_.size(), zero);
  113. }
  114. for (int i = 0; i < that.c_.size(); i++) {
  115. RLWE_RETURN_IF_ERROR(c_[i].AddInPlace(that.c_[i], modulus_params_));
  116. }
  117. error_ += that.error_;
  118. return absl::OkStatus();
  119. }
  120. // Homomorphic subtraction: subtract the polynomials representing the
  121. // ciphertexts component-wise. The example below demonstrates why this
  122. // procedure works properly in the two-component case. The quantities a, s, m,
  123. // t, and e are introduced during encryption and are explained in the
  124. // SymmetricRlweKey class.
  125. //
  126. // (a1 * s + m1 + t * e1, -a1)
  127. // - (a2 * s + m2 + t * e2, -a2)
  128. // ------------------------------
  129. // ((a1 - a2) * s + (m1 - m2) + t * (e1 - e2), -(a1 - a2))
  130. //
  131. // Substitute (a1 - a2) = a3, (e1 - e2) = e3:
  132. //
  133. // (a3 * s + (m1 - m2) + t * e3, -a3)
  134. //
  // This result is a valid ciphertext where the value of a has changed, the
  // error has increased, and the encoded plaintext contains the difference of
  // the plaintexts that were encoded in the original two ciphertexts.
  138. rlwe::StatusOr<SymmetricRlweCiphertext> operator-(
  139. const SymmetricRlweCiphertext& that) const {
  140. SymmetricRlweCiphertext out = *this;
  141. RLWE_RETURN_IF_ERROR(out.SubInPlace(that));
  142. return out;
  143. }
  144. absl::Status SubInPlace(const SymmetricRlweCiphertext& that) {
  145. if (power_of_s_ != that.power_of_s_) {
  146. return absl::InvalidArgumentError(
  147. "Ciphertexts must be encrypted with the same key power.");
  148. }
  149. if (c_.size() < that.c_.size()) {
  150. Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_);
  151. c_.resize(that.c_.size(), zero);
  152. }
  153. for (int i = 0; i < that.c_.size(); i++) {
  154. RLWE_RETURN_IF_ERROR(c_[i].SubInPlace(that.c_[i], modulus_params_));
  155. }
  156. error_ += that.error_;
  157. return absl::OkStatus();
  158. }
  // Homomorphic absorption. Multiplies the current ciphertext {m1}_s (plaintext
  160. // m1 encrypted with symmetric key s) by a plaintext m2, resulting in a
  161. // ciphertext {m1 * m2}_s that stores m1 * m2 encrypted with symmetric key s.
  162. //
  163. // DO NOT CONFUSE THIS OPERATION WITH HOMOMORPHIC MULTIPLICATION.
  164. //
  // To perform this operation, multiply each component of the
  // ciphertext by the plaintext polynomial. The example below demonstrates why
  // this procedure works properly in the two-component case. The quantities a,
  // s, m, t, and e are introduced during encryption and are explained in the
  // Encrypt() function later in this file.
  170. //
  // (a1 * s + m1 + t * e1, -a1) * p
  // = (a1 * p * s + m1 * p + t * e1 * p, -a1 * p)
  //
  // Substitute (a1 * p) = a2 and (e1 * p) = e2:
  //
  // (a2 * s + m1 * p + t * e2, -a2)
  177. //
  178. // This result is a valid ciphertext where the value of a has changed, the
  179. // error has increased, and the encoded plaintext contains the product of
  180. // m1 and p.
  181. //
  182. // A few more details about the multiplication that takes place:
  183. //
  // The value stored in the resulting ciphertext is (m1 * m2) (mod x^N + 1)
  // (mod t), where N is the number of coefficients in s (or m1 or m2, since
  // they all have the same number of coefficients). In other words, the
  // result is the remainder of (m1 * m2) mod the polynomial (x^N + 1) with
  // each of the coefficients then taken mod t. Any coefficient between 0 and
  // modulus / 2 is treated as a positive number for the purposes of the final
  // (mod t); any coefficient between modulus / 2 and modulus is treated as
  // a negative number for the purposes of the final (mod t).
  192. rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
  193. const Polynomial<ModularInt>& that) const {
  194. SymmetricRlweCiphertext out = *this;
  195. RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that));
  196. return out;
  197. }
  198. absl::Status AbsorbInPlace(const Polynomial<ModularInt>& that) {
  199. for (auto& component : this->c_) {
  200. RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_));
  201. }
  202. error_ *= error_params_->B_plaintext();
  203. return absl::OkStatus();
  204. }
  205. // Homomorphically absorb a plaintext scalar. This function is exactly like
  206. // homomorphic absorb above, except the plaintext is a constant.
  207. rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
  208. const ModularInt& that) const {
  209. SymmetricRlweCiphertext out = *this;
  210. RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that));
  211. return out;
  212. }
  213. absl::Status AbsorbInPlace(const ModularInt& that) {
  214. for (auto& component : this->c_) {
  215. RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_));
  216. }
  217. error_ *= static_cast<double>(that.ExportInt(modulus_params_));
  218. return absl::OkStatus();
  219. }
  220. // Homomorphic multiply. Given two ciphertexts {m1}_s, {m2}_s containing
  221. // messages m1 and m2 encrypted with the same secret key s, return the
  222. // ciphertext {m1 * m2}_s containing the product of the messages.
  223. //
  224. // To perform this operation, treat the two ciphertext vectors as polynomials
  225. // and perform a polynomial multiplication:
  226. //
  // <c0, c1> * <c0', c1'> = <c0 * c0', c0 * c1' + c1 * c0', c1 * c1'>
  228. //
  229. // If the two ciphertext vectors are of length m and n, the resulting
  230. // ciphertext is of length m + n - 1.
  231. //
  232. // The details of the multiplication that takes place between m1 and m2 are
  233. // the same as in the homomorphic absorb operation above (the other overload
  234. // of the * operator).
  235. rlwe::StatusOr<SymmetricRlweCiphertext> operator*(
  236. const SymmetricRlweCiphertext& that) {
  237. if (power_of_s_ != that.power_of_s_) {
  238. return absl::InvalidArgumentError(
  239. "Ciphertexts must be encrypted with the same key power.");
  240. }
  241. if (c_.size() <= 0 || that.c_.size() <= 0) {
  242. return absl::InvalidArgumentError(
  243. "Cannot multiply using an empty ciphertext.");
  244. }
  245. if (c_[0].Len() <= 0 || that.c_[0].Len() <= 0) {
  246. return absl::InvalidArgumentError(
  247. "Cannot multiply using an empty polynomial in the ciphertext.");
  248. }
  249. Polynomial<ModularInt> temp(c_[0].Len(), modulus_params_);
  250. std::vector<Polynomial<ModularInt>> result(c_.size() + that.c_.size() - 1,
  251. temp);
  252. for (int i = 0; i < c_.size(); i++) {
  253. for (int j = 0; j < that.c_.size(); j++) {
  254. RLWE_ASSIGN_OR_RETURN(temp, c_[i].Mul(that.c_[j], modulus_params_));
  255. RLWE_RETURN_IF_ERROR(result[i + j].AddInPlace(temp, modulus_params_));
  256. }
  257. }
  258. return SymmetricRlweCiphertext(std::move(result), power_of_s_,
  259. error_ * that.error_, modulus_params_,
  260. error_params_);
  261. }
  262. // Convert this ciphertext from (mod p) to (mod q).
  263. // Assumes that ModularInt::Int and ModularIntQ::Int are the same type.
  264. //
  265. // The current modulus (mod t) must be equal to modulus q (mod t).
  266. // This will always be true. For NTT to work properly, any modulus must be
  267. // of the form 2N + 1, where N is a power of 2. Likewise, the implementation
  268. // requires that t is a power of 2. This means that, for any modulus q and
  269. // modulus t allowed by the RLWE implementation, q % t == 1.
  270. template <typename ModularIntQ>
  271. rlwe::StatusOr<SymmetricRlweCiphertext<ModularIntQ>> SwitchModulus(
  272. const NttParameters<ModularInt>* ntt_params_p,
  273. const typename ModularIntQ::Params* modulus_params_q,
  274. const NttParameters<ModularIntQ>* ntt_params_q,
  275. const ErrorParams<ModularIntQ>* error_params_q, const Int& t) {
  276. Int p = modulus_params_->modulus;
  277. Int q = modulus_params_q->modulus;
  278. // Configuration error.
  279. if (p % t != q % t) {
  280. return absl::InvalidArgumentError("p % t != q % t");
  281. }
  282. SymmetricRlweCiphertext<ModularIntQ> output(modulus_params_q,
  283. error_params_q);
  284. output.power_of_s_ = power_of_s_;
  285. // Overestimate the ratio of the two moduli.
  286. double modulus_ratio = static_cast<double>(modulus_params_q->log_modulus) /
  287. modulus_params_->log_modulus;
  288. output.error_ = modulus_ratio * error_ + error_params_q->B_scale();
  289. output.c_.reserve(c_.size());
  290. for (const Polynomial<ModularInt>& c : c_) {
  291. // Extract each component of the ciphertext from NTT form.
  292. std::vector<ModularInt> coeffs_p =
  293. c.InverseNtt(ntt_params_p, modulus_params_);
  294. std::vector<ModularIntQ> coeffs_q;
  295. coeffs_q.reserve(coeffs_p.size());
  296. // Convert each coefficient of the polynomial from (mod p) to (mod q)
  297. for (const ModularInt& coeff_p : coeffs_p) {
  298. Int int_p = coeff_p.ExportInt(modulus_params_);
  299. // Scale the integer.
  300. Int int_q = static_cast<Int>(ModularInt::DivAndTruncate(
  301. static_cast<BigInt>(int_p) * static_cast<BigInt>(q),
  302. static_cast<BigInt>(p)));
  303. // Ensure that int_p = int_q mod t by changing int_q as little as
  304. // possible.
  305. Int int_p_mod_t = int_p % t;
  306. Int int_q_mod_t = int_q % t;
  307. Int adjustment_up = modulus_params_->Zero();
  308. Int adjustment_down = modulus_params_->Zero();
  309. // Determine whether to adjust int_q up or down to make sure int_q =
  310. // int_p (mod t).
  311. adjustment_up = int_p_mod_t - int_q_mod_t;
  312. adjustment_down = t + int_q_mod_t - int_p_mod_t;
  313. if (int_p_mod_t < int_q_mod_t) {
  314. adjustment_up = adjustment_up + t;
  315. adjustment_down = adjustment_down - t;
  316. }
  317. RLWE_ASSIGN_OR_RETURN(auto m_int_q,
  318. ModularIntQ::ImportInt(int_q, modulus_params_q));
  319. if (adjustment_up > adjustment_down) {
  320. RLWE_ASSIGN_OR_RETURN(
  321. auto m_adjustment_up,
  322. ModularIntQ::ImportInt(adjustment_up, modulus_params_q));
  323. // Adjust up.
  324. coeffs_q.push_back(
  325. std::move(m_adjustment_up.AddInPlace(m_int_q, modulus_params_q)));
  326. } else {
  327. RLWE_ASSIGN_OR_RETURN(
  328. auto m_adjustment_down,
  329. ModularIntQ::ImportInt(q - adjustment_down, modulus_params_q));
  330. // Adjust down.
  331. coeffs_q.push_back(std::move(
  332. m_adjustment_down.AddInPlace(m_int_q, modulus_params_q)));
  333. }
  334. }
  335. // Convert back to NTT.
  336. output.c_.push_back(Polynomial<ModularIntQ>::ConvertToNtt(
  337. std::move(coeffs_q), ntt_params_q, modulus_params_q));
  338. }
  339. return output;
  340. }
  341. // Given a ciphertext c encrypting a plaintext p(x) under secret key s(x),
  342. // returns a ciphertext c' encrypting p(x^power) under the secret key
  343. // s(x^power).
  344. // Power must be an odd non-negative integer less than 2 * num_coeffs.
  345. // This method uses NTT conversions to apply the substitution in the
  346. // coefficient domain, and should be avoided if performance is an issue.
  347. // Substitutions of the form 2^j + 1 are used to obliviously expand a query
  348. // ciphertext into a query vector.
  349. rlwe::StatusOr<SymmetricRlweCiphertext> Substitute(
  350. int substitution_power,
  351. const NttParameters<ModularInt>* ntt_params) const {
  352. SymmetricRlweCiphertext output(modulus_params_, error_params_);
  353. output.c_.reserve(c_.size());
  354. for (const Polynomial<ModularInt>& c : c_) {
  355. RLWE_ASSIGN_OR_RETURN(
  356. auto elt,
  357. c.Substitute(substitution_power, ntt_params, modulus_params_));
  358. output.c_.push_back(std::move(elt));
  359. }
  360. output.power_of_s_ = (power_of_s_ * substitution_power) % (2 * c_[0].Len());
  361. output.error_ = error_;
  362. return output;
  363. }
  364. rlwe::StatusOr<SerializedSymmetricRlweCiphertext> Serialize() const {
  365. SerializedSymmetricRlweCiphertext output;
  366. output.set_power_of_s(power_of_s_);
  367. output.set_error(error_);
  368. for (const Polynomial<ModularInt>& c : c_) {
  369. RLWE_ASSIGN_OR_RETURN(*output.add_c(), c.Serialize(modulus_params_));
  370. }
  371. return output;
  372. }
  373. static rlwe::StatusOr<SymmetricRlweCiphertext> Deserialize(
  374. const SerializedSymmetricRlweCiphertext& serialized,
  375. const typename ModularInt::Params* modulus_params,
  376. const ErrorParams<ModularInt>* error_params) {
  377. SymmetricRlweCiphertext output(modulus_params, error_params);
  378. output.power_of_s_ = serialized.power_of_s();
  379. output.error_ = serialized.error();
  380. if (serialized.c_size() <= 0) {
  381. return absl::InvalidArgumentError("Ciphertext cannot be empty.");
  382. } else if (serialized.c_size() > kMaxNumCoeffs) {
  383. return absl::InvalidArgumentError(
  384. absl::StrCat("Number of coefficients, ", serialized.c_size(),
  385. ", cannot be more than ", kMaxNumCoeffs, "."));
  386. }
  387. for (int i = 0; i < serialized.c_size(); i++) {
  388. RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
  389. serialized.c(i), modulus_params));
  390. output.c_.push_back(std::move(elt));
  391. }
  392. return output;
  393. }
  394. // Accessors.
  395. unsigned int Len() const { return c_.size(); }
  396. rlwe::StatusOr<Polynomial<ModularInt>> Component(int index) const {
  397. if (0 > index || index >= c_.size()) {
  398. return absl::InvalidArgumentError("Index out of range.");
  399. }
  400. return c_[index];
  401. }
  402. const typename ModularInt::Params* ModulusParams() const {
  403. return modulus_params_;
  404. }
  405. const rlwe::ErrorParams<ModularInt>* ErrorParams() const {
  406. return error_params_;
  407. }
  408. int PowerOfS() const { return power_of_s_; }
  409. double Error() const { return error_; }
  410. void SetError(double error) { error_ = error; }
  411. private:
  412. // The ciphertext.
  // The ciphertext: the vector of polynomial components <c0, ..., cN>. Empty
  // after the two-argument constructor; grown by multiplication and by
  // addition/subtraction with a longer ciphertext.
  std::vector<Polynomial<ModularInt>> c_;
  // ModularInt parameters. Stored as a raw pointer; presumably owned by the
  // caller and expected to outlive this object -- TODO confirm.
  const typename ModularInt::Params* modulus_params_;
  // Error parameters (source of B_plaintext() / B_scale() used to update
  // error_). Stored as a raw pointer; presumably owned by the caller -- TODO
  // confirm.
  const rlwe::ErrorParams<ModularInt>* error_params_;
  // The power a in s(x^a) that the ciphertext can be decrypted with.
  // Initialized to 1 and changed by Substitute.
  int power_of_s_;
  // A heuristic on the error of the ciphertext. Updated by every homomorphic
  // operation and by SetError.
  double error_;
  // Make this class a friend of any version of this class, no matter the
  // template. SwitchModulus relies on this to fill in the private members of
  // a ciphertext instantiated with a different ModularInt type.
  template <typename Q>
  friend class SymmetricRlweCiphertext;
  426. };
  427. // Holds a key that can be used to encrypt messages using the RLWE-based
  428. // encryption scheme.
  429. template <typename ModularInt>
  430. class SymmetricRlweKey {
  431. using Int = typename ModularInt::Int;
  432. public:
  433. // Allow copy, copy-assign, move and move-assign.
  // All five special member functions are compiler-generated: keys are
  // copied and moved member-wise.
  SymmetricRlweKey(const SymmetricRlweKey&) = default;
  SymmetricRlweKey& operator=(const SymmetricRlweKey&) = default;
  SymmetricRlweKey(SymmetricRlweKey&&) = default;
  SymmetricRlweKey& operator=(SymmetricRlweKey&&) = default;
  ~SymmetricRlweKey() = default;
  439. // Static factory that samples a key from the error distribution. The
  440. // polynomial representing the key must have a number of coefficients that is
  441. // a power of two, which is enforced by the first argument.
  442. //
  443. // Does not take ownership of rand, modulus_params or ntt_params.
  444. static rlwe::StatusOr<SymmetricRlweKey> Sample(
  445. unsigned int log_num_coeffs, uint64_t variance, uint64_t log_t,
  446. const typename ModularInt::Params* modulus_params,
  447. const NttParameters<ModularInt>* ntt_params, SecurePrng* prng) {
  448. RLWE_ASSIGN_OR_RETURN(
  449. auto error, SampleFromErrorDistribution<ModularInt>(
  450. 1 << log_num_coeffs, variance, prng, modulus_params));
  451. Polynomial<ModularInt> key = Polynomial<ModularInt>::ConvertToNtt(
  452. std::move(error), ntt_params, modulus_params);
  453. RLWE_ASSIGN_OR_RETURN(
  454. auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) +
  455. modulus_params->One(),
  456. modulus_params));
  457. return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod),
  458. modulus_params, modulus_params, ntt_params);
  459. }
  460. rlwe::StatusOr<SerializedNttPolynomial> Serialize() const {
  461. return key_.Serialize(modulus_params_);
  462. }
  463. // Deserialize using modulus params as also the plaintext modulus params. Use
  464. // this when deserializing a non-modulus switched key.
  465. static rlwe::StatusOr<SymmetricRlweKey> Deserialize(
  466. Uint64 variance, Uint64 log_t,
  467. const SerializedNttPolynomial& serialized_key,
  468. const typename ModularInt::Params* modulus_params,
  469. const NttParameters<ModularInt>* ntt_params) {
  470. return Deserialize(variance, log_t, serialized_key, modulus_params,
  471. modulus_params, ntt_params);
  472. }
  473. static rlwe::StatusOr<SymmetricRlweKey> Deserialize(
  474. Uint64 variance, Uint64 log_t,
  475. const SerializedNttPolynomial& serialized_key,
  476. const typename ModularInt::Params* modulus_params,
  477. const typename ModularInt::Params* plaintext_modulus_params,
  478. const NttParameters<ModularInt>* ntt_params) {
  479. // Check that log_t is no larger than the log_modulus - 1.
  480. if (log_t > modulus_params->log_modulus - 1) {
  481. return absl::InvalidArgumentError(absl::StrCat(
  482. "The value of log_t, ", log_t, ", must be smaller than ",
  483. "log_modulus - 1, ", modulus_params->log_modulus - 1, "."));
  484. }
  485. RLWE_ASSIGN_OR_RETURN(
  486. Polynomial<ModularInt> key,
  487. Polynomial<ModularInt>::Deserialize(serialized_key, modulus_params));
  488. RLWE_ASSIGN_OR_RETURN(
  489. auto t_mod,
  490. ModularInt::ImportInt((plaintext_modulus_params->One() << log_t) +
  491. plaintext_modulus_params->One(),
  492. plaintext_modulus_params));
  493. return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod),
  494. modulus_params, plaintext_modulus_params,
  495. ntt_params);
  496. }
  497. // Generate a copy of this key in modulus q.
  498. //
  // The current modulus (mod t) must be equal to modulus q (mod t). This
  // property is implicitly enforced by the design of the code as described
  // by the corresponding comment on SymmetricRlweCiphertext::SwitchModulus.
  // This property is also dynamically enforced.
  503. //
  504. // The algorithms for modulus-switching ciphertexts and keys are similar but
  505. // slightly different. In particular, RLWE keys are guaranteed to have small
  506. // coefficients, and thus modulus switching can be made very simple. Hence
  507. // we have 2 separate implementations of SwitchModulus for keys and
  508. // ciphertexts.
  509. template <typename ModularIntQ>
  510. rlwe::StatusOr<SymmetricRlweKey<ModularIntQ>> SwitchModulus(
  511. const typename ModularIntQ::Params* modulus_params_q,
  512. const NttParameters<ModularIntQ>* ntt_params_q) const {
  513. // Configuration failure.
  514. Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One();
  515. if (modulus_params_->modulus % t != modulus_params_q->modulus % t) {
  516. return absl::InvalidArgumentError("p % t != q % t");
  517. }
  518. typename ModularIntQ::Int p_mod_q =
  519. modulus_params_->modulus % modulus_params_q->modulus;
  520. std::vector<ModularInt> coeffs_p =
  521. key_.InverseNtt(ntt_params_, modulus_params_);
  522. std::vector<ModularIntQ> coeffs_q;
  523. // Convert each coefficient of the polynomial from (mod p) to (mod q)
  524. for (const ModularInt& coeff_p : coeffs_p) {
  525. // Ensure that negative numbers (mod p) are translated into negative
  526. // numbers (mod q).
  527. Int int_p = coeff_p.ExportInt(modulus_params_);
  528. if (int_p > modulus_params_->modulus >> 1) {
  529. int_p = int_p - p_mod_q;
  530. }
  531. RLWE_ASSIGN_OR_RETURN(auto m_int_p,
  532. ModularIntQ::ImportInt(int_p, modulus_params_q));
  533. coeffs_q.push_back(std::move(m_int_p));
  534. }
  535. // Convert back to NTT.
  536. auto key_q = Polynomial<ModularIntQ>::ConvertToNtt(
  537. std::move(coeffs_q), ntt_params_q, modulus_params_q);
  538. RLWE_ASSIGN_OR_RETURN(
  539. auto t_mod, ModularInt::ImportInt((modulus_params_q->One() << log_t_) +
  540. modulus_params_q->One(),
  541. modulus_params_q));
  542. return SymmetricRlweKey<ModularIntQ>(std::move(key_q), variance_, log_t_,
  543. std::move(t_mod), modulus_params_q,
  544. modulus_params_q, ntt_params_q);
  545. }
  546. // Given s(x), returns a secret key s(x^a).
  547. // This performs an Inverse NTT on the key, substitutes the key in polynomial
  548. // representation, and then performs an NTT again.
  549. rlwe::StatusOr<SymmetricRlweKey> Substitute(const int power) const {
  550. RLWE_ASSIGN_OR_RETURN(
  551. auto t_mod, ModularInt::ImportInt((modulus_params_->One() << log_t_) +
  552. modulus_params_->One(),
  553. modulus_params_));
  554. RLWE_ASSIGN_OR_RETURN(auto sub,
  555. key_.Substitute(power, ntt_params_, modulus_params_));
  556. return SymmetricRlweKey(std::move(sub), variance_, log_t_, std::move(t_mod),
  557. modulus_params_, plaintext_modulus_params_,
  558. ntt_params_);
  559. }
  560. // Accessors.
  561. unsigned int Len() const { return key_.Len(); }
  562. const NttParameters<ModularInt>* NttParams() const { return ntt_params_; }
  563. const typename ModularInt::Params* ModulusParams() const {
  564. return modulus_params_;
  565. }
  566. const unsigned int BitsPerCoeff() const { return log_t_; }
  567. const Uint64 Variance() const { return variance_; }
  568. const unsigned int LogT() const { return log_t_; }
  569. const ModularInt& PlaintextModulus() const { return t_mod_; }
  570. const typename ModularInt::Params* PlaintextModulusParams() const {
  571. return plaintext_modulus_params_;
  572. }
  573. const Polynomial<ModularInt>& Key() const { return key_; }
  574. // Add two homomorphic encryption keys.
  575. rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Add(
  576. const SymmetricRlweKey<ModularInt>& other_key) {
  577. if (variance_ != other_key.variance_) {
  578. return absl::InvalidArgumentError(absl::StrCat(
  579. "The variance of the other key, ", other_key.variance_,
  580. ", is different than the variance of this key, ", variance_, "."));
  581. }
  582. if (log_t_ != other_key.log_t_) {
  583. return absl::InvalidArgumentError(absl::StrCat(
  584. "The log_t of the other key, ", other_key.log_t_,
  585. ", is different than the log_t of this key, ", log_t_, "."));
  586. }
  587. if (t_mod_ != other_key.t_mod_) {
  588. return absl::InvalidArgumentError(
  589. absl::StrCat("The plaintext space of the other key is different than "
  590. "the plaintext space of this key."));
  591. }
  592. RLWE_ASSIGN_OR_RETURN(auto key, key_.Add(other_key.key_, modulus_params_));
  593. return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_,
  594. t_mod_, modulus_params_,
  595. plaintext_modulus_params_, ntt_params_);
  596. }
  // Subtract two homomorphic encryption keys.
  598. rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Sub(
  599. const SymmetricRlweKey<ModularInt>& other_key) {
  600. if (variance_ != other_key.variance_) {
  601. return absl::InvalidArgumentError(absl::StrCat(
  602. "The variance of the other key, ", other_key.variance_,
  603. ", is different than the variance of this key, ", variance_, "."));
  604. }
  605. if (log_t_ != other_key.log_t_) {
  606. return absl::InvalidArgumentError(absl::StrCat(
  607. "The log_t of the other key, ", other_key.log_t_,
  608. ", is different than the log_t of this key, ", log_t_, "."));
  609. }
  610. if (t_mod_ != other_key.t_mod_) {
  611. return absl::InvalidArgumentError(
  612. absl::StrCat("The plaintext space of the other key is different than "
  613. "the plaintext space of this key."));
  614. }
  615. RLWE_ASSIGN_OR_RETURN(auto key, key_.Sub(other_key.key_, modulus_params_));
  616. return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_,
  617. t_mod_, modulus_params_,
  618. plaintext_modulus_params_, ntt_params_);
  619. }
  620. // Static function to create a null key (with value 0).
  621. static rlwe::StatusOr<SymmetricRlweKey> NullKey(
  622. unsigned int log_num_coeffs, Uint64 variance, Uint64 log_t,
  623. const typename ModularInt::Params* modulus_params,
  624. const NttParameters<ModularInt>* ntt_params) {
  625. Polynomial<ModularInt> zero(1 << log_num_coeffs, modulus_params);
  626. RLWE_ASSIGN_OR_RETURN(
  627. auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) +
  628. modulus_params->One(),
  629. modulus_params));
  630. return SymmetricRlweKey(std::move(zero), variance, log_t, std::move(t_mod),
  631. modulus_params, modulus_params, ntt_params);
  632. }
  633. private:
  634. // The contents of the key itself.
  635. Polynomial<ModularInt> key_;
  636. // The variance of the binomial distribution from which the key and error are
  637. // drawn.
  638. Uint64 variance_;
  639. // The maximum size of any one coefficient of the polynomial representing a
  640. // plaintext message.
  641. unsigned int log_t_;
  642. ModularInt t_mod_;
  643. // NTT parameters.
  644. const NttParameters<ModularInt>* ntt_params_;
  645. // ModularInt parameters.
  646. const typename ModularInt::Params* modulus_params_;
  647. const typename ModularInt::Params* plaintext_modulus_params_;
  648. // A constructor. Does not take ownership of params.
  649. SymmetricRlweKey(Polynomial<ModularInt> key, Uint64 variance,
  650. unsigned int log_t, ModularInt t_mod,
  651. const typename ModularInt::Params* modulus_params,
  652. const typename ModularInt::Params* plaintext_modulus_params,
  653. const NttParameters<ModularInt>* ntt_params)
  654. : key_(std::move(key)),
  655. variance_(variance),
  656. log_t_(log_t),
  657. t_mod_(std::move(t_mod)),
  658. ntt_params_(ntt_params),
  659. modulus_params_(modulus_params),
  660. plaintext_modulus_params_(plaintext_modulus_params) {}
  661. // Make this class a friend of any version of this class, no matter the
  662. // template.
  663. template <typename Q>
  664. friend class SymmetricRlweKey;
  665. };
  666. // Encrypts the plaintext using ring learning-with-errors (RLWE) encryption.
  667. // (b/79577340): The parameter t is specified via log_t, but its value is
  668. // (1 << log_t) + 1 so that t is odd. This is to allow multiplicative inverses
  669. // of powers of 2, which are used to compress and obliviously expand a query
  670. // ciphertext.
  671. //
  672. // The scheme works as follows:
  673. // KeyGen(n, modulus q, error distr):
  674. // Sample a degree (n-1) polynomial whose coefficients are drawn from the
  675. // error distribution (mod q). This is our secret key. Call it s.
  676. //
  677. // Encrypt(secret key s, plaintext m, modulus q, modulus t, error distr):
  678. // 1) Sample a degree (n-1) polynomial whose coefficients are drawn
  679. // uniformly from any integer (mod q). Call this polynomial a.
  680. // 2) Sample a degree (n-1) polynomial whose coefficients are drawn from
  681. // the error distribution (mod q). Call this polynomial e.
  682. // 3) Our secret key s and plaintext m are both degree (n-1) polynomials.
  683. // For decryption to work, each coefficient of m must be < t.
  684. // Compute (a * s + t * e + m) (mod x^n + 1). Call this polynomial b.
  685. // 4) The ciphertext is the pair (b, -a). We refer to the pair of
  686. // polynomials representing a ciphertext as (c0, c1) =
  687. // (a * s + m + e * t, -a).
  688. //
  689. // Decrypt(secret key s, ciphertext (b, -a), modulus t):
  690. // // Decryption when the ciphertext has two components.
  691. // Compute and return (b - as) (mod t). Doing out the algebra:
  692. // b - as (mod t)
  693. // = as + te + m - as (mod t)
  694. // = te + m (mod t)
  695. // = m
  696. // Quoting the paper, "the condition for correct decryption is that the
  697. // L_infinity norm of the polynomial [te + m] is smaller than q/2." In
  698. // other words, the largest of the values te + m (recall that e is
  699. // sampled from a distribution) cannot exceed q/2.
  700. //
  701. // When the ciphertext has more than two components <c0, c1, ..., cN>,
  702. // it can be decrypted by taking the dot product with the vector
  703. // <s^0, s^1, ..., s^N> containing powers of the secret key:
  704. // te + m = <c0, c1, ..., cN> dot <s^0, s^1, ..., s^N>
  705. // = c0 * s^0 + c1 * s^1 + ... + cN * s^N
  706. //
  707. // Note that the Encrypt() function takes the original plaintext as
  708. // a Polynomial<ModularInt>, while the corresponding Decrypt() method
  709. // returns a std::vector<typename ModularInt::Int>. The two values will be the
  710. // same once the original plaintext is converted out of NTT and Montgomery form.
  711. // - The Encrypt() function takes an NTT polynomial so that, if the same
  712. // plaintext is to be encrypted repeatedly, the NTT conversion only needs
  713. // to be performed once by the caller.
  714. // - The Decrypt() function returns a vector of integers because the final
  715. // (mod t) step requires taking the polynomial (te + m) out of NTT and
  716. // Montgomery form.
  717. // It would be straightforward to write a wrapper of Encrypt() that takes
  718. // a vector of integers as input, thereby making the plaintext types of the
  719. // Encrypt() and Decrypt() functions symmetric.
  720. namespace internal {
  721. // This functions allows injecting a specific polynomial "a" as the randomness
  722. // of the encryption (that is the negation of the c1 component of the
  723. // ciphertext) and returns only the resulting c1 component of the ciphertext.
  724. // This function is intended for internal use only.
  725. template <typename ModularInt>
  726. rlwe::StatusOr<Polynomial<ModularInt>> Encrypt(
  727. const SymmetricRlweKey<ModularInt>& key,
  728. const Polynomial<ModularInt>& plaintext, const Polynomial<ModularInt>& a,
  729. SecurePrng* prng) {
  730. // Sample the error term from the error distribution.
  731. unsigned int num_coeffs = key.Len();
  732. RLWE_ASSIGN_OR_RETURN(
  733. std::vector<ModularInt> e_coeffs,
  734. SampleFromErrorDistribution<ModularInt>(num_coeffs, key.Variance(), prng,
  735. key.ModulusParams()));
  736. // Create and return c0.
  737. auto e = Polynomial<ModularInt>::ConvertToNtt(
  738. std::move(e_coeffs), key.NttParams(), key.ModulusParams());
  739. RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> temp,
  740. a.Mul(key.Key(), key.ModulusParams()));
  741. RLWE_RETURN_IF_ERROR(
  742. e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
  743. RLWE_RETURN_IF_ERROR(temp.AddInPlace(e, key.ModulusParams()));
  744. RLWE_RETURN_IF_ERROR(temp.AddInPlace(plaintext, key.ModulusParams()));
  745. return temp;
  746. }
  747. } // namespace internal
  748. // Encrypts the supplied plaintext using the given key. Randomness is drawn from
  749. // the key's underlying ModulusParams.
  750. template <typename ModularInt>
  751. rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>> Encrypt(
  752. const SymmetricRlweKey<ModularInt>& key,
  753. const Polynomial<ModularInt>& plaintext,
  754. const ErrorParams<ModularInt>* error_params, SecurePrng* prng) {
  755. // Sample a from the uniform distribution.
  756. RLWE_ASSIGN_OR_RETURN(auto a, SamplePolynomialFromPrng<ModularInt>(
  757. key.Len(), prng, key.ModulusParams()));
  758. // Create c0.
  759. RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> c0,
  760. internal::Encrypt(key, plaintext, a, prng));
  761. // Compute c1 = -a and return the ciphertext.
  762. return SymmetricRlweCiphertext<ModularInt>(
  763. std::vector<Polynomial<ModularInt>>{
  764. std::move(c0), std::move(a.NegateInPlace(key.ModulusParams()))},
  765. 1, error_params->B_encryption(), key.ModulusParams(), error_params);
  766. }
  767. // Takes as input the result of decrypting a RLWE plaintext that still contains
  768. // the error. Concretely, it contains m + e * t (mod q). This function
  769. // eliminates the error and returns the message. For reasons described below,
  770. // this operation is more complicated than a simple (mod t).
  771. //
  772. // The error is drawn from a binomial distribution centered at zero and
  773. // multiplied by t, meaning error values are either positive or negative
  774. // multiples of t. Since each coefficient of the plaintext is smaller than
  775. // t, some coefficients of the quantity m + e * t (which is all that's
  776. // left in the vector error_and_message) could be negative. We are using
  777. // modular arithmetic, so negative values become large positive values.
  778. //
  779. // Unfortunately, these negative values cause the naive error elimination
  780. // strategy to fail. In theory we could take (m + e * t) mod t to
  781. // eliminate the error portion and extract the message. However, consider
  782. // a case where the error is negative. Suppose that t=2, m=1, and e=-1
  783. // with a modulus q=7:
  784. //
  785. // m + e * t (mod q) =
  786. // 1 + -1 * 2 (mod 7) =
  787. // -1 (mod 7) =
  788. // 6 (mod 7)
  789. //
  790. // When we take 6 (mod t) = 6 (mod 2), we get 0, which is not the original
  791. // bit of m. To avoid this problem, we treat negative values as negative
  792. // values, not as their equivalents mod q.
  793. //
  794. // We consider (m + e * t) to be negative whenever it is between q/2
  795. // and q. Recall that, if |m + e * t| is greater than q/2, decryption
  796. // fails.
  797. //
  798. // When the quantity (m + e * t) (mod q) represents a negative number
  799. // mod q, we can re-create its non-modular negative form by computing
  800. // ((m + e * t) - q). We can then take this value mod t to extract the
  801. // correct answer.
  802. //
  803. // 1. (m + e * t (mod q)) = // in the range [q/2, q)
  804. // 2. (m + e * t - q) = // in the range [-q/2, 0)
  805. // 3. m (mod t) + e * t (mod t) - q (mod t) = // taken (mod t)
  806. // 4. m - (q (mod t))
  807. //
  808. // If we subtract q at step 2, we return negative numbers to their
  809. // original form. Since we are going to perform a (mod t) operation
  810. // anyway, we can subtract q (mod t) at step 2 to get the same result.
  811. // Subtracting q (mod t) instead ensures that the quantity at step 2
  812. // does not become negative, which is convenient because we are using
  813. // an unsigned integer type.
  814. //
  815. // Concluding the example from before with the fix:
  816. //
  817. // m + e * t (mod q) - q (mod t) =
  818. // 1 + -1 * 2 (mod 7) - 7 (mod 2) =
  819. // -1 (mod 7) - 7 (mod 2) = 6 - 1 = 5
  820. //
  821. // 5 (mod t) = 1, which is the original message.
  // Removes the error term from each coefficient of m + e * t (mod q) and
  // returns the message coefficients reduced modulo t.
  //
  //   error_and_message  coefficients holding m + e * t (mod q); each one is
  //                      exported to an integer through `modulus_params_q`.
  //   q                  the modulus of the coefficients.
  //   t                  the plaintext modulus.
  //   modulus_params_q   parameters handed to ModularInt::ExportInt.
  //
  // The result has one entry per input coefficient, in the same order. An
  // exported value greater than q / 2 (integer division) is treated as a
  // negative number: q (mod t) is subtracted from it before the final
  // reduction modulo t.
  template <typename ModularInt>
  std::vector<typename ModularInt::Int> RemoveError(
      const std::vector<ModularInt>& error_and_message,
      const typename ModularInt::Int& q, const typename ModularInt::Int& t,
      const typename ModularInt::Params* modulus_params_q) {
    using Int = typename ModularInt::Int;

    // Both values are loop-invariant, so compute them once.
    const Int q_mod_t = q % t;
    const Int half_q = q >> 1;

    std::vector<Int> plaintext;
    plaintext.reserve(error_and_message.size());
    for (const ModularInt& coefficient : error_and_message) {
      Int value = coefficient.ExportInt(modulus_params_q);
      if (value > half_q) {
        // The value represents a negative number mod q: subtract q (mod t) so
        // that the reduction below yields the original message.
        value = value - q_mod_t;
      }
      plaintext.push_back(value % t);
    }
    return plaintext;
  }
  840. template <typename ModularInt>
  841. rlwe::StatusOr<std::vector<typename ModularInt::Int>> Decrypt(
  842. const SymmetricRlweKey<ModularInt>& key,
  843. const SymmetricRlweCiphertext<ModularInt>& ciphertext) {
  844. // Extract the error and message. To do so, take the dot product of the
  845. // ciphertext vector <c0, c1, ..., cN> and the vector of the powers of
  846. // the key <s^0, s^1, ..., s^N>.
  847. // Accumulator variables.
  848. Polynomial<ModularInt> error_and_message_ntt(key.Len(), key.ModulusParams());
  849. Polynomial<ModularInt> key_powers = key.Key();
  850. unsigned int ciphertext_len = ciphertext.Len();
  851. for (int i = 0; i < ciphertext_len; i++) {
  852. // Extract component i.
  853. RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> ci, ciphertext.Component(i));
  854. // Lazily increase the exponent of the key.
  855. if (i > 1) {
  856. RLWE_RETURN_IF_ERROR(
  857. key_powers.MulInPlace(key.Key(), key.ModulusParams()));
  858. }
  859. // Beyond c0, multiply the exponentiated key in.
  860. if (i > 0) {
  861. RLWE_RETURN_IF_ERROR(
  862. ci.MulInPlace(key_powers, ciphertext.ModulusParams()));
  863. }
  864. RLWE_RETURN_IF_ERROR(
  865. error_and_message_ntt.AddInPlace(ci, key.ModulusParams()));
  866. }
  867. // Invert the NTT process.
  868. std::vector<ModularInt> error_and_message =
  869. error_and_message_ntt.InverseNtt(key.NttParams(), key.ModulusParams());
  870. // Extract the message.
  871. return RemoveError<ModularInt>(
  872. error_and_message, key.ModulusParams()->modulus,
  873. key.PlaintextModulus().ExportInt(key.PlaintextModulusParams()),
  874. key.ModulusParams());
  875. }
  876. } // namespace rlwe
  877. #endif // RLWE_SYMMETRIC_ENCRYPTION_H_