relinearization_key.cc 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. /*
  2. * Copyright 2018 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #include "relinearization_key.h"
  16. #include "absl/numeric/int128.h"
  17. #include "bits_util.h"
  18. #include "montgomery.h"
  19. #include "prng/integral_prng_types.h"
  20. #include "status_macros.h"
  21. #include "statusor.h"
  22. #include "symmetric_encryption_with_prng.h"
  23. #include "third_party/shell-encryption/base/shell_encryption_export.h"
  24. #include "third_party/shell-encryption/base/shell_encryption_export_template.h"
  25. namespace rlwe {
  26. namespace {
  27. // Method to compute the number of digits needed to represent integers mod
  28. // q in base T. Upcasts the modulus to absl::uint128 to handle all Uint*
  29. // types.
  30. inline int ComputeDimension(Uint64 log_decomposition_modulus,
  31. absl::uint128 modulus) {
  32. Uint64 modulus_bits = static_cast<Uint64>(internal::BitLength(modulus));
  33. return (modulus_bits + (log_decomposition_modulus - 1)) /
  34. log_decomposition_modulus;
  35. }
  36. // Returns a random vector r orthogonal to (1,s). The second component is chosen
  37. // using randomness-of-encryption sampled using the specified PRNG. The first
  38. // component is then chosen so that r is perpendicular to (1,s).
  39. template <typename ModularInt>
  40. rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> SampleOrthogonalFromPrng(
  41. const SymmetricRlweKey<ModularInt>& key, SecurePrng* prng) {
  42. // Sample a random polynomial r using a PRNG.
  43. RLWE_ASSIGN_OR_RETURN(auto r, SamplePolynomialFromPrng<ModularInt>(
  44. key.Len(), prng, key.ModulusParams()));
  45. // Top entries of the matrix R will be -s*r, thus R is orthogonal to
  46. // (1,s).
  47. RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> r_top,
  48. r.Mul(key.Key(), key.ModulusParams()));
  49. r_top.NegateInPlace(key.ModulusParams());
  50. std::vector<Polynomial<ModularInt>> res = {std::move(r_top), std::move(r)};
  51. return res;
  52. }
  53. // The i-th component of the result is (T^i key_power).
  54. template <typename ModularInt>
  55. rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> PowersOfT(
  56. const Polynomial<ModularInt>& key_power,
  57. const SymmetricRlweKey<ModularInt>& key,
  58. const ModularInt& decomposition_modulus, int dimension) {
  59. std::vector<Polynomial<ModularInt>> result;
  60. result.reserve(dimension);
  61. Polynomial<ModularInt> key_to_i = key_power;
  62. for (int i = 0; i < dimension; i++) {
  63. // Increase the power of T in T^i s in place.
  64. if (i != 0) {
  65. RLWE_RETURN_IF_ERROR(
  66. key_to_i.MulInPlace(decomposition_modulus, key.ModulusParams()));
  67. }
  68. result.push_back(key_to_i);
  69. }
  70. return result;
  71. }
  72. // The i-th component of the result contains a vector of i-th digits of the
  73. // coefficients in base T (the decomposition modulus).
  74. template <typename ModularInt>
  75. rlwe::StatusOr<std::vector<std::vector<ModularInt>>> BitDecompose(
  76. const std::vector<ModularInt>& coefficients,
  77. const typename ModularInt::Params* modulus_params,
  78. const Uint64 log_decomposition_modulus, int dimension) {
  79. std::vector<typename ModularInt::Int> ciphertext_coeffs(coefficients.size(),
  80. 0);
  81. std::transform(
  82. coefficients.begin(), coefficients.end(), ciphertext_coeffs.begin(),
  83. [modulus_params](ModularInt x) { return x.ExportInt(modulus_params); });
  84. std::vector<std::vector<ModularInt>> result(dimension);
  85. for (int i = 0; i < dimension; i++) {
  86. result[i].reserve(ciphertext_coeffs.size());
  87. for (int j = 0; j < ciphertext_coeffs.size(); ++j) {
  88. RLWE_ASSIGN_OR_RETURN(
  89. auto coefficient_part,
  90. ModularInt::ImportInt(
  91. (ciphertext_coeffs[j] % (1L << log_decomposition_modulus)),
  92. modulus_params));
  93. result[i].push_back(std::move(coefficient_part));
  94. ciphertext_coeffs[j] = ciphertext_coeffs[j] >> log_decomposition_modulus;
  95. }
  96. }
  97. return result;
  98. }
  99. template <typename ModularInt>
  100. rlwe::StatusOr<std::vector<Polynomial<ModularInt>>> MatrixMultiply(
  101. std::vector<std::vector<ModularInt>> decomposed_coefficients,
  102. const std::vector<std::vector<Polynomial<ModularInt>>>& matrix,
  103. const typename ModularInt::Params* modulus_params,
  104. const NttParameters<ModularInt>* ntt_params) {
  105. Polynomial<ModularInt> temp(matrix[0][0].Len(), modulus_params);
  106. Polynomial<ModularInt> ntt_part(matrix[0][0].Len(), modulus_params);
  107. std::vector<Polynomial<ModularInt>> result(2, temp);
  108. for (int i = 0; i < matrix[0].size(); i++) {
  109. ntt_part = Polynomial<ModularInt>::ConvertToNtt(
  110. std::move(decomposed_coefficients[i]), ntt_params, modulus_params);
  111. RLWE_ASSIGN_OR_RETURN(temp, ntt_part.Mul(matrix[0][i], modulus_params));
  112. RLWE_RETURN_IF_ERROR(result[0].AddInPlace(temp, modulus_params));
  113. RLWE_RETURN_IF_ERROR(ntt_part.MulInPlace(matrix[1][i], modulus_params))
  114. RLWE_RETURN_IF_ERROR(result[1].AddInPlace(ntt_part, modulus_params));
  115. }
  116. return result;
  117. }
  118. } // namespace
  119. template <typename ModularInt>
  120. rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
  121. RelinearizationKey<ModularInt>::RelinearizationKeyPart::Create(
  122. const Polynomial<ModularInt>& key_power,
  123. const SymmetricRlweKey<ModularInt>& key,
  124. const Uint64 log_decomposition_modulus,
  125. const ModularInt& decomposition_modulus, int dimension, SecurePrng* prng,
  126. SecurePrng* prng_encryption) {
  127. std::vector<std::vector<Polynomial<ModularInt>>> matrix(2);
  128. for (auto& row : matrix) {
  129. row.reserve(dimension);
  130. }
  131. // Compute a vector of (T^i key_power).
  132. RLWE_ASSIGN_OR_RETURN(
  133. auto powers_of_t,
  134. PowersOfT(key_power, key, decomposition_modulus, dimension));
  135. // For key_power = s^j, the ith iteration of this loop computes the column of
  136. // the KeyPart corresponding to (T^i s^j).
  137. for (int i = 0; i < dimension; ++i) {
  138. // Sample r component orthogonal to (1,s).
  139. RLWE_ASSIGN_OR_RETURN(auto r, SampleOrthogonalFromPrng(key, prng));
  140. // Sample error.
  141. RLWE_ASSIGN_OR_RETURN(auto error,
  142. SampleFromErrorDistribution<ModularInt>(
  143. key_power.Len(), key.Variance(), prng_encryption,
  144. key.ModulusParams()));
  145. // Convert the error coefficients into an error polynomial.
  146. auto e = Polynomial<ModularInt>::ConvertToNtt(
  147. std::move(error), key.NttParams(), key.ModulusParams());
  148. // Set the column of the Relinearization matrix.
  149. RLWE_RETURN_IF_ERROR(
  150. e.MulInPlace(key.PlaintextModulus(), key.ModulusParams()));
  151. RLWE_RETURN_IF_ERROR(e.AddInPlace(r[0], key.ModulusParams()));
  152. RLWE_RETURN_IF_ERROR(e.AddInPlace(powers_of_t[i], key.ModulusParams()));
  153. matrix[0].push_back(std::move(e));
  154. matrix[1].push_back(std::move(r[1]));
  155. }
  156. return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
  157. }
  158. template <typename ModularInt>
  159. rlwe::StatusOr<std::vector<Polynomial<ModularInt>>>
  160. RelinearizationKey<ModularInt>::RelinearizationKeyPart::ApplyPartTo(
  161. const Polynomial<ModularInt>& ciphertext_part,
  162. const typename ModularInt::Params* modulus_params,
  163. const NttParameters<ModularInt>* ntt_params) const {
  164. // Convert ciphertext out of NTT form.
  165. std::vector<ModularInt> ciphertext_coefficients =
  166. ciphertext_part.InverseNtt(ntt_params, modulus_params);
  167. // Bit-decompose the vector of coefficients in the ciphertext.
  168. RLWE_ASSIGN_OR_RETURN(
  169. std::vector<std::vector<ModularInt>> decomposed_coefficients,
  170. BitDecompose<ModularInt>(ciphertext_coefficients, modulus_params,
  171. log_decomposition_modulus_, matrix_[0].size()));
  172. // Matrix multiply with the bit-decomposed coefficients.
  173. return MatrixMultiply<ModularInt>(std::move(decomposed_coefficients), matrix_,
  174. modulus_params, ntt_params);
  175. }
  176. template <typename ModularInt>
  177. rlwe::StatusOr<typename RelinearizationKey<ModularInt>::RelinearizationKeyPart>
  178. RelinearizationKey<ModularInt>::RelinearizationKeyPart::Deserialize(
  179. const std::vector<SerializedNttPolynomial>& polynomials,
  180. Uint64 log_decomposition_modulus, SecurePrng* prng,
  181. const ModularIntParams* modulus_params,
  182. const NttParameters<ModularInt>* ntt_params) {
  183. // The polynomials input is a flattened representation of a 2 x dimension
  184. // matrix where the first half corresponds to the first row of matrix and the
  185. // second half corresponds to the second row of matrix. This matrix makes up
  186. // the RelinearizationKeyPart.
  187. int dimension = polynomials.size();
  188. auto matrix = std::vector<std::vector<Polynomial<ModularInt>>>(2);
  189. matrix[0].reserve(dimension);
  190. matrix[1].reserve(dimension);
  191. for (int i = 0; i < dimension; i++) {
  192. RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize(
  193. polynomials[i], modulus_params));
  194. matrix[0].push_back(std::move(elt));
  195. RLWE_ASSIGN_OR_RETURN(auto sample,
  196. SamplePolynomialFromPrng<ModularInt>(
  197. matrix[0][i].Len(), prng, modulus_params));
  198. matrix[1].push_back(std::move(sample));
  199. }
  200. return RelinearizationKeyPart(std::move(matrix), log_decomposition_modulus);
  201. }
// Constructs a RelinearizationKey from already generated key parts (this is
// the constructor invoked at the end of Create()).
//
// dimension_ is derived with ComputeDimension from log_decomposition_modulus
// and the modulus of `key`; modulus_params_ and ntt_params_ are the pointers
// returned by `key`. The remaining members are initialized from the
// corresponding arguments, with `relinearization_key` moved into place.
//
// log_decomposition_modulus is used as a divisor by ComputeDimension, so it
// must be non-zero.
template <typename ModularInt>
RelinearizationKey<ModularInt>::RelinearizationKey(
    const SymmetricRlweKey<ModularInt>& key, absl::string_view prng_seed,
    ssize_t num_parts, Uint64 log_decomposition_modulus,
    Uint64 substitution_power, ModularInt decomposition_modulus,
    std::vector<RelinearizationKeyPart> relinearization_key)
    : dimension_(ComputeDimension(log_decomposition_modulus,
                                  key.ModulusParams()->modulus)),
      num_parts_(num_parts),
      log_decomposition_modulus_(log_decomposition_modulus),
      decomposition_modulus_(decomposition_modulus),
      substitution_power_(substitution_power),
      modulus_params_(key.ModulusParams()),
      ntt_params_(key.NttParams()),
      relinearization_key_(std::move(relinearization_key)),
      prng_seed_(prng_seed) {}
  218. template <typename ModularInt>
  219. rlwe::StatusOr<RelinearizationKey<ModularInt>>
  220. RelinearizationKey<ModularInt>::Create(const SymmetricRlweKey<ModularInt>& key,
  221. absl::string_view prng_seed,
  222. ssize_t num_parts,
  223. Uint64 log_decomposition_modulus,
  224. Uint64 substitution_power) {
  225. if (num_parts <= 0) {
  226. return absl::InvalidArgumentError(
  227. absl::StrCat("Num parts: ", num_parts, " must be positive."));
  228. }
  229. if (log_decomposition_modulus <= 0) {
  230. return absl::InvalidArgumentError(
  231. absl::StrCat("Log decomposition modulus, ", log_decomposition_modulus,
  232. ", must be positive."));
  233. } else if (log_decomposition_modulus > key.ModulusParams()->log_modulus) {
  234. return absl::InvalidArgumentError(absl::StrCat(
  235. "Log decomposition modulus, ", log_decomposition_modulus,
  236. ", must be at most: ", key.ModulusParams()->log_modulus, "."));
  237. }
  238. RLWE_ASSIGN_OR_RETURN(auto decomposition_modulus,
  239. ModularInt::ImportInt(key.ModulusParams()->One()
  240. << log_decomposition_modulus,
  241. key.ModulusParams()));
  242. // Initialize the first part of the secret key, s.
  243. RLWE_ASSIGN_OR_RETURN(auto key_base, key.Substitute(substitution_power));
  244. auto key_power = key_base.Key();
  245. RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(prng_seed));
  246. RLWE_ASSIGN_OR_RETURN(auto prng_encryption_seed,
  247. SingleThreadPrng::GenerateSeed());
  248. RLWE_ASSIGN_OR_RETURN(auto prng_encryption,
  249. SingleThreadPrng::Create(prng_encryption_seed));
  250. auto dimension =
  251. ComputeDimension(log_decomposition_modulus, key.ModulusParams()->modulus);
  252. std::vector<RelinearizationKeyPart> relinearization_key;
  253. relinearization_key.reserve(num_parts);
  254. // Create RealinearizationKeyPart for each of the secret key parts: s, ...,
  255. // s^k.
  256. for (int i = 1; i < num_parts; i++) {
  257. if (i != 1) {
  258. // Increment the power of s.
  259. RLWE_RETURN_IF_ERROR(
  260. key_power.MulInPlace(key_base.Key(), key.ModulusParams()));
  261. }
  262. RLWE_ASSIGN_OR_RETURN(
  263. auto key_part,
  264. RelinearizationKeyPart::Create(
  265. key_power, key, log_decomposition_modulus, decomposition_modulus,
  266. dimension, prng.get(), prng_encryption.get()));
  267. relinearization_key.push_back(std::move(key_part));
  268. }
  269. return RelinearizationKey<ModularInt>(
  270. key, prng_seed, num_parts, log_decomposition_modulus, substitution_power,
  271. decomposition_modulus, std::move(relinearization_key));
  272. }
  273. template <typename ModularInt>
  274. rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>>
  275. RelinearizationKey<ModularInt>::ApplyTo(
  276. const SymmetricRlweCiphertext<ModularInt>& ciphertext) const {
  277. // Ensure that the length of the ciphertext is less than or equal to the
  278. // length of the relinearization key.
  279. if (ciphertext.Len() > num_parts_) {
  280. return absl::InvalidArgumentError(
  281. "RelinearizationKey not large enough for ciphertext.");
  282. }
  283. // Initialize the result ciphertext of length 2.
  284. RLWE_ASSIGN_OR_RETURN(auto comp, ciphertext.Component(0));
  285. std::vector<Polynomial<ModularInt>> result(
  286. 2, Polynomial<ModularInt>(comp.Len(), modulus_params_));
  287. // Apply each RelinearizationKeyPart to the part of the ciphertext it
  288. // corresponds to. The first component of the ciphertext corresponds to the
  289. // "1" part of the secret key, and is added without any
  290. // RelinearizationKeyPart.
  291. result[0] = std::move(comp);
  292. for (int i = 0; i < relinearization_key_.size(); i++) {
  293. // Add RelinearizationKeyPart_i c_i to the result vector.
  294. RLWE_ASSIGN_OR_RETURN(auto temp_comp, ciphertext.Component(i + 1));
  295. RLWE_ASSIGN_OR_RETURN(auto result_part,
  296. relinearization_key_[i].ApplyPartTo(
  297. temp_comp, modulus_params_, ntt_params_));
  298. RLWE_RETURN_IF_ERROR(result[0].AddInPlace(result_part[0], modulus_params_));
  299. RLWE_RETURN_IF_ERROR(result[1].AddInPlace(result_part[1], modulus_params_));
  300. }
  301. return SymmetricRlweCiphertext<ModularInt>(
  302. std::move(result), 1,
  303. ciphertext.Error() +
  304. ciphertext.ErrorParams()->B_relinearize(log_decomposition_modulus_),
  305. modulus_params_, ciphertext.ErrorParams());
  306. }
  307. template <typename ModularInt>
  308. rlwe::StatusOr<SerializedRelinearizationKey>
  309. RelinearizationKey<ModularInt>::Serialize() const {
  310. SerializedRelinearizationKey output;
  311. output.set_log_decomposition_modulus(log_decomposition_modulus_);
  312. output.set_num_parts(num_parts_);
  313. output.set_prng_seed(prng_seed_);
  314. output.set_power_of_s(substitution_power_);
  315. for (const RelinearizationKeyPart& matrix : relinearization_key_) {
  316. // Only serialize the first row of each matrix.
  317. for (const Polynomial<ModularInt>& c : matrix.Matrix()) {
  318. RLWE_ASSIGN_OR_RETURN(*output.add_c(), c.Serialize(modulus_params_));
  319. }
  320. }
  321. return output;
  322. }
  323. template <typename ModularInt>
  324. rlwe::StatusOr<RelinearizationKey<ModularInt>>
  325. RelinearizationKey<ModularInt>::Deserialize(
  326. const SerializedRelinearizationKey& serialized,
  327. const typename ModularInt::Params* modulus_params,
  328. const NttParameters<ModularInt>* ntt_params) {
  329. // Verifies that the number of polynomials in serialized is expected.
  330. // A RelinearizationKey can decrypt ciphertexts with num_parts number of
  331. // components corresponding to decryption under (1, s, ..., s^k) or (1,
  332. // s(x^power)) but only contains parts corresponding to the non-"1"
  333. // components.
  334. if (serialized.num_parts() <= 1) {
  335. return absl::InvalidArgumentError(
  336. absl::StrCat("The number of parts, ", serialized.num_parts(),
  337. ", must be greater than one."));
  338. } else if (serialized.c_size() % (serialized.num_parts() - 1) != 0) {
  339. return absl::InvalidArgumentError(
  340. absl::StrCat("The length of serialized, ", serialized.c_size(), ", ",
  341. "must be divisible by the number of parts minus one ",
  342. serialized.num_parts() - 1, "."));
  343. }
  344. // Return an error when log decomposition modulus is non-positive.
  345. if (serialized.log_decomposition_modulus() <= 0) {
  346. return absl::InvalidArgumentError(absl::StrCat(
  347. "Log decomposition modulus, ", serialized.log_decomposition_modulus(),
  348. ", must be positive."));
  349. } else if (serialized.log_decomposition_modulus() >
  350. modulus_params->log_modulus) {
  351. return absl::InvalidArgumentError(absl::StrCat(
  352. "Log decomposition modulus, ", serialized.log_decomposition_modulus(),
  353. ", must be at most: ", modulus_params->log_modulus, "."));
  354. }
  355. int polynomials_per_matrix =
  356. serialized.c_size() / (serialized.num_parts() - 1);
  357. int dimension = polynomials_per_matrix;
  358. if (dimension != ComputeDimension(serialized.log_decomposition_modulus(),
  359. modulus_params->modulus)) {
  360. return absl::InvalidArgumentError(
  361. absl::StrCat("Number of NTT Polynomials does not match expected ",
  362. "number of matrix entries."));
  363. }
  364. RLWE_ASSIGN_OR_RETURN(
  365. auto decomposition_modulus,
  366. ModularInt::ImportInt(static_cast<typename ModularInt::Int>(1)
  367. << serialized.log_decomposition_modulus(),
  368. modulus_params));
  369. RelinearizationKey output(serialized.log_decomposition_modulus(),
  370. decomposition_modulus, modulus_params, ntt_params);
  371. output.dimension_ = dimension;
  372. output.num_parts_ = serialized.num_parts();
  373. output.prng_seed_ = serialized.prng_seed();
  374. output.substitution_power_ = serialized.power_of_s();
  375. // Create prng based on seed.
  376. RLWE_ASSIGN_OR_RETURN(auto prng, SingleThreadPrng::Create(output.prng_seed_));
  377. // Takes each polynomials_per_matrix chunk of serialized.c()'s and places them
  378. // into a KeyPart.
  379. output.relinearization_key_.reserve(serialized.num_parts() - 1);
  380. for (int i = 0; i < (serialized.num_parts() - 1); i++) {
  381. auto start = serialized.c().begin() + i * polynomials_per_matrix;
  382. auto end = start + polynomials_per_matrix;
  383. std::vector<SerializedNttPolynomial> chunk(start, end);
  384. RLWE_ASSIGN_OR_RETURN(auto deserialized,
  385. RelinearizationKeyPart::Deserialize(
  386. chunk, serialized.log_decomposition_modulus(),
  387. prng.get(), modulus_params, ntt_params));
  388. output.relinearization_key_.push_back(std::move(deserialized));
  389. }
  390. return output;
  391. }
// Explicit instantiations of RelinearizationKey for the supported
// MontgomeryInt widths (16, 32, 64 and 128 bit). The member templates above are
// defined in this translation unit, so each ModularInt type that is used must
// be instantiated here. If any new types are added, montgomery.h should be
// updated accordingly (such as ensuring BigInt is correctly specialized, etc.).
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint16>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint32>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<Uint64>>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) RelinearizationKey<MontgomeryInt<absl::uint128>>;
  399. } // namespace rlwe