relinearization_key_test.cc 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. /*
  2. * Copyright 2018 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #include "relinearization_key.h"
  16. #include <gmock/gmock.h>
  17. #include <gtest/gtest.h>
  18. #include "constants.h"
  19. #include "montgomery.h"
  20. #include "ntt_parameters.h"
  21. #include "polynomial.h"
  22. #include "prng/integral_prng_types.h"
  23. #include "status_macros.h"
  24. #include "symmetric_encryption.h"
  25. #include "testing/status_matchers.h"
  26. #include "testing/status_testing.h"
  27. #include "testing/testing_prng.h"
  28. namespace {
  29. unsigned int seed = 1;
  30. // Useful typedefs.
  31. using uint_m = rlwe::MontgomeryInt<absl::uint128>;
  32. using Polynomial = rlwe::Polynomial<uint_m>;
  33. using Ciphertext = rlwe::SymmetricRlweCiphertext<uint_m>;
  34. using Key = rlwe::SymmetricRlweKey<uint_m>;
  35. using RelinearizationKey = rlwe::RelinearizationKey<uint_m>;
  36. using ErrorParams = rlwe::ErrorParams<uint_m>;
  37. // Set constants.
  38. const ssize_t kLogPlaintextModulus = 1;
  39. const ssize_t kPlaintextModulus = (1 << kLogPlaintextModulus) + 1;
  40. const ssize_t kDefaultVariance = 4;
  41. const ssize_t kCoeffs = rlwe::kNewhopeDegreeBound;
  42. const ssize_t kLogCoeffs = rlwe::kNewhopeLogDegreeBound;
  43. const ssize_t kSmallLogDecompositionModulus = 2;
  44. const ssize_t kLargeLogDecompositionModulus = 20;
  45. using ::rlwe::testing::StatusIs;
  46. using ::testing::HasSubstr;
  47. // Test fixture.
  48. class RelinearizationKeyTest : public ::testing::Test {
  49. protected:
  50. void SetUp() override {
  51. ASSERT_OK_AND_ASSIGN(params14_,
  52. uint_m::Params::Create(rlwe::kNewhopeModulus));
  53. ASSERT_OK_AND_ASSIGN(params80_, uint_m::Params::Create(rlwe::kModulus80));
  54. ASSERT_OK_AND_ASSIGN(auto ntt_params, rlwe::InitializeNttParameters<uint_m>(
  55. kLogCoeffs, params14_.get()));
  56. ASSERT_OK_AND_ASSIGN(
  57. auto ntt_params80,
  58. rlwe::InitializeNttParameters<uint_m>(kLogCoeffs, params80_.get()));
  59. ntt_params_ = absl::make_unique<const rlwe::NttParameters<uint_m>>(
  60. std::move(ntt_params));
  61. ntt_params80_ = absl::make_unique<const rlwe::NttParameters<uint_m>>(
  62. std::move(ntt_params80));
  63. ASSERT_OK_AND_ASSIGN(auto error_params,
  64. rlwe::ErrorParams<uint_m>::Create(
  65. kLogPlaintextModulus, kDefaultVariance,
  66. params14_.get(), ntt_params_.get()));
  67. error_params_ = absl::make_unique<const ErrorParams>(error_params);
  68. ASSERT_OK_AND_ASSIGN(auto error_params80,
  69. rlwe::ErrorParams<uint_m>::Create(
  70. kLogPlaintextModulus, kDefaultVariance,
  71. params80_.get(), ntt_params80_.get()));
  72. error_params80_ = absl::make_unique<const ErrorParams>(error_params80);
  73. }
  74. // Convert a vector of integers to a vector of montgomery integers.
  75. rlwe::StatusOr<std::vector<uint_m>> ConvertToMontgomery(
  76. const std::vector<uint_m::Int>& coeffs, const uint_m::Params* params) {
  77. std::vector<uint_m> output(coeffs.size(), uint_m::ImportZero(params));
  78. for (unsigned int i = 0; i < output.size(); i++) {
  79. RLWE_ASSIGN_OR_RETURN(output[i], uint_m::ImportInt(coeffs[i], params));
  80. }
  81. return output;
  82. }
  83. // Sample a random key.
  84. rlwe::StatusOr<Key> SampleKey(rlwe::Uint64 variance = kDefaultVariance,
  85. rlwe::Uint64 log_t = kLogPlaintextModulus) {
  86. RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
  87. rlwe::SingleThreadPrng::GenerateSeed());
  88. RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
  89. return Key::Sample(kLogCoeffs, variance, log_t, params14_.get(),
  90. ntt_params_.get(), prng.get());
  91. }
  92. // Sample a random plaintext.
  93. std::vector<uint_m::Int> SamplePlaintext(uint_m::Int t = kPlaintextModulus,
  94. rlwe::Uint64 coeffs = kCoeffs) {
  95. std::vector<uint_m::Int> plaintext(kCoeffs);
  96. for (unsigned int i = 0; i < kCoeffs; i++) {
  97. plaintext[i] = rand_r(&seed) % t;
  98. }
  99. return plaintext;
  100. }
  101. // Encrypt a plaintext.
  102. rlwe::StatusOr<Ciphertext> Encrypt(
  103. const Key& key, const std::vector<uint_m::Int>& plaintext,
  104. const uint_m::Params* params,
  105. const rlwe::NttParameters<uint_m>* ntt_params,
  106. const ErrorParams* error_params) {
  107. RLWE_ASSIGN_OR_RETURN(auto m_plaintext,
  108. ConvertToMontgomery(plaintext, params));
  109. auto plaintext_ntt =
  110. Polynomial::ConvertToNtt(m_plaintext, ntt_params, params);
  111. RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
  112. rlwe::SingleThreadPrng::GenerateSeed());
  113. RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
  114. return rlwe::Encrypt<uint_m>(key, plaintext_ntt, error_params, prng.get());
  115. }
  116. std::unique_ptr<const uint_m::Params> params14_;
  117. std::unique_ptr<const uint_m::Params> params80_;
  118. std::unique_ptr<const rlwe::NttParameters<uint_m>> ntt_params_;
  119. std::unique_ptr<const rlwe::NttParameters<uint_m>> ntt_params80_;
  120. std::unique_ptr<const ErrorParams> error_params_;
  121. std::unique_ptr<const ErrorParams> error_params80_;
  122. };
  123. TEST_F(RelinearizationKeyTest, RelinearizationKeyReducesSizeOfCiphertext) {
  124. ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  125. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  126. rlwe::SingleThreadPrng::GenerateSeed());
  127. ASSERT_OK_AND_ASSIGN(
  128. auto relinearization_key,
  129. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  130. kSmallLogDecompositionModulus));
  131. // Create the initial plaintexts.
  132. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  133. ASSERT_OK_AND_ASSIGN(auto mp1,
  134. ConvertToMontgomery(plaintext1, params14_.get()));
  135. Polynomial plaintext1_ntt =
  136. Polynomial::ConvertToNtt(mp1, ntt_params_.get(), params14_.get());
  137. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  138. ASSERT_OK_AND_ASSIGN(auto mp2,
  139. ConvertToMontgomery(plaintext2, params14_.get()));
  140. Polynomial plaintext2_ntt =
  141. Polynomial::ConvertToNtt(mp2, ntt_params_.get(), params14_.get());
  142. // Encrypt, multiply, apply the relinearization key and decrypt.
  143. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  144. Encrypt(key, plaintext1, params14_.get(),
  145. ntt_params_.get(), error_params_.get()));
  146. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  147. Encrypt(key, plaintext2, params14_.get(),
  148. ntt_params_.get(), error_params_.get()));
  149. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  150. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  151. relinearization_key.ApplyTo(product));
  152. EXPECT_EQ(product.Len(), 3);
  153. EXPECT_EQ(relinearized_product.Len(), 2);
  154. }
  155. TEST_F(RelinearizationKeyTest, RelinearizeKey3PartsDecrypts) {
  156. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  157. rlwe::SingleThreadPrng::GenerateSeed());
  158. ASSERT_OK_AND_ASSIGN(auto key_prng,
  159. rlwe::SingleThreadPrng::Create(key_prng_seed));
  160. ASSERT_OK_AND_ASSIGN(
  161. auto key,
  162. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  163. params80_.get(), ntt_params80_.get(), key_prng.get()));
  164. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  165. rlwe::SingleThreadPrng::GenerateSeed());
  166. ASSERT_OK_AND_ASSIGN(
  167. auto relinearization_key,
  168. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  169. kSmallLogDecompositionModulus));
  170. // Create the initial plaintexts.
  171. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  172. ASSERT_OK_AND_ASSIGN(auto mp1,
  173. ConvertToMontgomery(plaintext1, params80_.get()));
  174. Polynomial plaintext1_ntt =
  175. Polynomial::ConvertToNtt(mp1, ntt_params80_.get(), params80_.get());
  176. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  177. ASSERT_OK_AND_ASSIGN(auto mp2,
  178. ConvertToMontgomery(plaintext2, params80_.get()));
  179. Polynomial plaintext2_ntt =
  180. Polynomial::ConvertToNtt(mp2, ntt_params80_.get(), params80_.get());
  181. // Encrypt, multiply, apply the relinearization key and decrypt.
  182. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  183. Encrypt(key, plaintext1, params80_.get(),
  184. ntt_params80_.get(), error_params80_.get()));
  185. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  186. Encrypt(key, plaintext2, params80_.get(),
  187. ntt_params80_.get(), error_params80_.get()));
  188. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  189. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  190. relinearization_key.ApplyTo(product));
  191. ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
  192. rlwe::Decrypt<uint_m>(key, relinearized_product));
  193. // Create the polynomial we expect.
  194. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext2_ntt, params80_.get()));
  195. std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
  196. plaintext1_ntt.InverseNtt(ntt_params80_.get(), params80_.get()),
  197. params80_->modulus, kPlaintextModulus, params80_.get());
  198. EXPECT_EQ(decrypted, expected);
  199. }
  200. TEST_F(RelinearizationKeyTest, RelinearizeKey4PartsDecrypts) {
  201. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  202. rlwe::SingleThreadPrng::GenerateSeed());
  203. ASSERT_OK_AND_ASSIGN(auto key_prng,
  204. rlwe::SingleThreadPrng::Create(key_prng_seed));
  205. ASSERT_OK_AND_ASSIGN(
  206. auto key,
  207. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  208. params80_.get(), ntt_params80_.get(), key_prng.get()));
  209. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  210. rlwe::SingleThreadPrng::GenerateSeed());
  211. ASSERT_OK_AND_ASSIGN(
  212. auto relinearization_key,
  213. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/4,
  214. kLargeLogDecompositionModulus));
  215. // Create the initial plaintexts.
  216. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  217. ASSERT_OK_AND_ASSIGN(auto mp1,
  218. ConvertToMontgomery(plaintext1, params80_.get()));
  219. Polynomial plaintext1_ntt =
  220. Polynomial::ConvertToNtt(mp1, ntt_params80_.get(), params80_.get());
  221. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  222. ASSERT_OK_AND_ASSIGN(auto mp2,
  223. ConvertToMontgomery(plaintext2, params80_.get()));
  224. Polynomial plaintext2_ntt =
  225. Polynomial::ConvertToNtt(mp2, ntt_params80_.get(), params80_.get());
  226. std::vector<uint_m::Int> plaintext3 = SamplePlaintext(kPlaintextModulus);
  227. ASSERT_OK_AND_ASSIGN(auto mp3,
  228. ConvertToMontgomery(plaintext3, params80_.get()));
  229. Polynomial plaintext3_ntt =
  230. Polynomial::ConvertToNtt(mp3, ntt_params80_.get(), params80_.get());
  231. // Relinearize a 4 component ciphertext produced from three multiplications.
  232. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  233. Encrypt(key, plaintext1, params80_.get(),
  234. ntt_params80_.get(), error_params80_.get()));
  235. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  236. Encrypt(key, plaintext2, params80_.get(),
  237. ntt_params80_.get(), error_params80_.get()));
  238. ASSERT_OK_AND_ASSIGN(auto ciphertext3,
  239. Encrypt(key, plaintext3, params80_.get(),
  240. ntt_params80_.get(), error_params80_.get()));
  241. ASSERT_OK_AND_ASSIGN(auto intermediate, ciphertext1* ciphertext2);
  242. ASSERT_OK_AND_ASSIGN(auto product, intermediate* ciphertext3);
  243. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  244. relinearization_key.ApplyTo(product));
  245. ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
  246. rlwe::Decrypt<uint_m>(key, relinearized_product));
  247. // Create the polynomial we expect.
  248. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext2_ntt, params80_.get()));
  249. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext3_ntt, params80_.get()));
  250. std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
  251. plaintext1_ntt.InverseNtt(ntt_params80_.get(), params80_.get()),
  252. params80_->modulus, kPlaintextModulus, params80_.get());
  253. EXPECT_EQ(decrypted, expected);
  254. }
  255. TEST_F(RelinearizationKeyTest, RelinearizeKeyLargeModulusDecrypts) {
  256. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  257. rlwe::SingleThreadPrng::GenerateSeed());
  258. ASSERT_OK_AND_ASSIGN(auto key_prng,
  259. rlwe::SingleThreadPrng::Create(key_prng_seed));
  260. ASSERT_OK_AND_ASSIGN(
  261. auto key,
  262. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  263. params80_.get(), ntt_params80_.get(), key_prng.get()));
  264. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  265. rlwe::SingleThreadPrng::GenerateSeed());
  266. ASSERT_OK_AND_ASSIGN(
  267. auto relinearization_key,
  268. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  269. kLargeLogDecompositionModulus));
  270. // Create the initial plaintexts.
  271. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  272. ASSERT_OK_AND_ASSIGN(auto mp1,
  273. ConvertToMontgomery(plaintext1, params80_.get()));
  274. Polynomial plaintext1_ntt =
  275. Polynomial::ConvertToNtt(mp1, ntt_params80_.get(), params80_.get());
  276. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  277. ASSERT_OK_AND_ASSIGN(auto mp2,
  278. ConvertToMontgomery(plaintext2, params80_.get()));
  279. Polynomial plaintext2_ntt =
  280. Polynomial::ConvertToNtt(mp2, ntt_params80_.get(), params80_.get());
  281. // Multiply, apply the relinearization key, multiply, relinearize and decrypt.
  282. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  283. Encrypt(key, plaintext1, params80_.get(),
  284. ntt_params80_.get(), error_params80_.get()));
  285. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  286. Encrypt(key, plaintext2, params80_.get(),
  287. ntt_params80_.get(), error_params80_.get()));
  288. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  289. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  290. relinearization_key.ApplyTo(product));
  291. ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
  292. rlwe::Decrypt<uint_m>(key, relinearized_product));
  293. // Create the polynomial we expect.
  294. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext2_ntt, params80_.get()));
  295. std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
  296. plaintext1_ntt.InverseNtt(ntt_params80_.get(), params80_.get()),
  297. params80_->modulus, kPlaintextModulus, params80_.get());
  298. EXPECT_EQ(decrypted, expected);
  299. }
  300. TEST_F(RelinearizationKeyTest, RepeatedRelinearizationDecrypts) {
  301. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  302. rlwe::SingleThreadPrng::GenerateSeed());
  303. ASSERT_OK_AND_ASSIGN(auto key_prng,
  304. rlwe::SingleThreadPrng::Create(key_prng_seed));
  305. ASSERT_OK_AND_ASSIGN(
  306. auto key,
  307. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  308. params80_.get(), ntt_params80_.get(), key_prng.get()));
  309. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  310. rlwe::SingleThreadPrng::GenerateSeed());
  311. ASSERT_OK_AND_ASSIGN(
  312. auto relinearization_key,
  313. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  314. kLargeLogDecompositionModulus));
  315. // Create the initial plaintexts.
  316. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  317. ASSERT_OK_AND_ASSIGN(auto mp1,
  318. ConvertToMontgomery(plaintext1, params80_.get()));
  319. Polynomial plaintext1_ntt =
  320. Polynomial::ConvertToNtt(mp1, ntt_params80_.get(), params80_.get());
  321. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  322. ASSERT_OK_AND_ASSIGN(auto mp2,
  323. ConvertToMontgomery(plaintext2, params80_.get()));
  324. Polynomial plaintext2_ntt =
  325. Polynomial::ConvertToNtt(mp2, ntt_params80_.get(), params80_.get());
  326. std::vector<uint_m::Int> plaintext3 = SamplePlaintext(kPlaintextModulus);
  327. ASSERT_OK_AND_ASSIGN(auto mp3,
  328. ConvertToMontgomery(plaintext3, params80_.get()));
  329. Polynomial plaintext3_ntt =
  330. Polynomial::ConvertToNtt(mp3, ntt_params80_.get(), params80_.get());
  331. // Multiply, apply the relinearization key, multiply, relinearize and decrypt.
  332. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  333. Encrypt(key, plaintext1, params80_.get(),
  334. ntt_params80_.get(), error_params80_.get()));
  335. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  336. Encrypt(key, plaintext2, params80_.get(),
  337. ntt_params80_.get(), error_params80_.get()));
  338. ASSERT_OK_AND_ASSIGN(auto ciphertext3,
  339. Encrypt(key, plaintext3, params80_.get(),
  340. ntt_params80_.get(), error_params80_.get()));
  341. ASSERT_OK_AND_ASSIGN(auto product1, ciphertext1* ciphertext2);
  342. ASSERT_OK_AND_ASSIGN(auto relinearized_product1,
  343. relinearization_key.ApplyTo(product1));
  344. ASSERT_OK_AND_ASSIGN(auto product2, relinearized_product1* ciphertext3);
  345. ASSERT_OK_AND_ASSIGN(auto relinearized_product2,
  346. relinearization_key.ApplyTo(product2));
  347. ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
  348. rlwe::Decrypt<uint_m>(key, relinearized_product2));
  349. // Create the polynomial we expect.
  350. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext2_ntt, params80_.get()));
  351. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext3_ntt, params80_.get()));
  352. std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
  353. plaintext1_ntt.InverseNtt(ntt_params80_.get(), params80_.get()),
  354. params80_->modulus, kPlaintextModulus, params80_.get());
  355. EXPECT_EQ(decrypted, expected);
  356. }
  357. TEST_F(RelinearizationKeyTest, CiphertextWithTooManyComponents) {
  358. ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  359. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  360. rlwe::SingleThreadPrng::GenerateSeed());
  361. // RelinearizationKey has length 2.
  362. ASSERT_OK_AND_ASSIGN(
  363. auto relinearization_key,
  364. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/2,
  365. kSmallLogDecompositionModulus));
  366. // Create the initial plaintexts.
  367. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  368. ASSERT_OK_AND_ASSIGN(auto mp1,
  369. ConvertToMontgomery(plaintext1, params14_.get()));
  370. Polynomial plaintext1_ntt =
  371. Polynomial::ConvertToNtt(mp1, ntt_params_.get(), params14_.get());
  372. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  373. ASSERT_OK_AND_ASSIGN(auto mp2,
  374. ConvertToMontgomery(plaintext2, params14_.get()));
  375. Polynomial plaintext2_ntt =
  376. Polynomial::ConvertToNtt(mp2, ntt_params_.get(), params14_.get());
  377. // Encrypt, multiply, apply the relinearization key.
  378. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  379. Encrypt(key, plaintext1, params14_.get(),
  380. ntt_params_.get(), error_params_.get()));
  381. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  382. Encrypt(key, plaintext2, params14_.get(),
  383. ntt_params_.get(), error_params_.get()));
  384. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  385. EXPECT_THAT(relinearization_key.ApplyTo(product),
  386. StatusIs(absl::StatusCode::kInvalidArgument,
  387. HasSubstr("RelinearizationKey not large enough")));
  388. }
  389. TEST_F(RelinearizationKeyTest, LogDecompositionModulusOutOfBounds) {
  390. ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  391. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  392. rlwe::SingleThreadPrng::GenerateSeed());
  393. // RelinearizationKey has length 2.
  394. EXPECT_THAT(
  395. rlwe::RelinearizationKey<uint_m>::Create(
  396. key, prng_seed, /*num_parts=*/2,
  397. /*log_decomposition_modulus=*/key.ModulusParams()->log_modulus + 1),
  398. StatusIs(absl::StatusCode::kInvalidArgument,
  399. HasSubstr(absl::StrCat(
  400. "Log decomposition modulus, ",
  401. key.ModulusParams()->log_modulus + 1, ", ",
  402. "must be at most: ", key.ModulusParams()->log_modulus))));
  403. int log_decomposition_modulus = 0;
  404. EXPECT_THAT(rlwe::RelinearizationKey<uint_m>::Create(
  405. key, prng_seed, /*num_parts=*/3, log_decomposition_modulus),
  406. StatusIs(absl::StatusCode::kInvalidArgument,
  407. HasSubstr(absl::StrCat("Log decomposition modulus, ",
  408. log_decomposition_modulus,
  409. ", must be positive."))));
  410. }
  411. TEST_F(RelinearizationKeyTest, NumPartsMustBePositive) {
  412. ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  413. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  414. rlwe::SingleThreadPrng::GenerateSeed());
  415. EXPECT_THAT(
  416. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/-1,
  417. kSmallLogDecompositionModulus),
  418. StatusIs(absl::StatusCode::kInvalidArgument,
  419. HasSubstr("Num parts: -1 must be positive.")));
  420. }
  421. TEST_F(RelinearizationKeyTest, InvalidDeserialize) {
  422. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  423. rlwe::SingleThreadPrng::GenerateSeed());
  424. ASSERT_OK_AND_ASSIGN(auto key_prng,
  425. rlwe::SingleThreadPrng::Create(key_prng_seed));
  426. ASSERT_OK_AND_ASSIGN(
  427. auto key,
  428. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  429. params80_.get(), ntt_params80_.get(), key_prng.get()));
  430. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  431. rlwe::SingleThreadPrng::GenerateSeed());
  432. ASSERT_OK_AND_ASSIGN(
  433. auto relinearization_key,
  434. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  435. kLargeLogDecompositionModulus));
  436. // Serialize and deserialize.
  437. ASSERT_OK_AND_ASSIGN(rlwe::SerializedRelinearizationKey serialized,
  438. relinearization_key.Serialize());
  439. for (int i = -1; i <= 1; i++) {
  440. serialized.set_num_parts(i);
  441. EXPECT_THAT(RelinearizationKey::Deserialize(serialized, params80_.get(),
  442. ntt_params80_.get()),
  443. StatusIs(absl::StatusCode::kInvalidArgument,
  444. HasSubstr(absl::StrCat(
  445. "The number of parts, ", serialized.num_parts(),
  446. ", must be greater than one."))));
  447. }
  448. ASSERT_GT(serialized.c_size(), 2);
  449. serialized.set_num_parts(serialized.c_size() - 1);
  450. EXPECT_THAT(
  451. RelinearizationKey::Deserialize(serialized, params80_.get(),
  452. ntt_params80_.get()),
  453. StatusIs(absl::StatusCode::kInvalidArgument,
  454. HasSubstr(absl::StrCat(
  455. "The length of serialized, ", serialized.c_size(), ", ",
  456. "must be divisible by the number of parts minus one ",
  457. serialized.num_parts() - 1, "."))));
  458. ASSERT_EQ(serialized.c_size(),
  459. /* log2(kModulus80) / kLargeLogDecompositionModulus = */ 8);
  460. serialized.set_num_parts(serialized.c_size() + 1);
  461. EXPECT_THAT(RelinearizationKey::Deserialize(serialized, params80_.get(),
  462. ntt_params80_.get()),
  463. StatusIs(absl::StatusCode::kInvalidArgument,
  464. HasSubstr(absl::StrCat(
  465. "Number of NTT Polynomials does not match expected ",
  466. "number of matrix entries."))));
  467. }
  468. TEST_F(RelinearizationKeyTest, SerializeKey) {
  469. ASSERT_OK_AND_ASSIGN(std::string key_prng_seed,
  470. rlwe::SingleThreadPrng::GenerateSeed());
  471. ASSERT_OK_AND_ASSIGN(auto key_prng,
  472. rlwe::SingleThreadPrng::Create(key_prng_seed));
  473. ASSERT_OK_AND_ASSIGN(
  474. auto key,
  475. Key::Sample(kLogCoeffs, kDefaultVariance, kLogPlaintextModulus,
  476. params80_.get(), ntt_params80_.get(), key_prng.get()));
  477. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  478. rlwe::SingleThreadPrng::GenerateSeed());
  479. ASSERT_OK_AND_ASSIGN(
  480. auto relinearization_key,
  481. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  482. kLargeLogDecompositionModulus));
  483. // Serialize and deserialize.
  484. ASSERT_OK_AND_ASSIGN(rlwe::SerializedRelinearizationKey serialized,
  485. relinearization_key.Serialize());
  486. ASSERT_OK_AND_ASSIGN(auto deserialized,
  487. RelinearizationKey::Deserialize(
  488. serialized, params80_.get(), ntt_params80_.get()));
  489. // Create the initial plaintexts.
  490. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  491. ASSERT_OK_AND_ASSIGN(auto mp1,
  492. ConvertToMontgomery(plaintext1, params80_.get()));
  493. Polynomial plaintext1_ntt =
  494. Polynomial::ConvertToNtt(mp1, ntt_params80_.get(), params80_.get());
  495. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  496. ASSERT_OK_AND_ASSIGN(auto mp2,
  497. ConvertToMontgomery(plaintext2, params80_.get()));
  498. Polynomial plaintext2_ntt =
  499. Polynomial::ConvertToNtt(mp2, ntt_params80_.get(), params80_.get());
  500. // Encrypt, multiply, apply the relinearization key and decrypt.
  501. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  502. Encrypt(key, plaintext1, params80_.get(),
  503. ntt_params80_.get(), error_params80_.get()));
  504. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  505. Encrypt(key, plaintext2, params80_.get(),
  506. ntt_params80_.get(), error_params80_.get()));
  507. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  508. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  509. deserialized.ApplyTo(product));
  510. ASSERT_OK_AND_ASSIGN(std::vector<uint_m::Int> decrypted,
  511. rlwe::Decrypt<uint_m>(key, relinearized_product));
  512. // Create the polynomial we expect.
  513. ASSERT_OK(plaintext1_ntt.MulInPlace(plaintext2_ntt, params80_.get()));
  514. std::vector<uint_m::Int> expected = rlwe::RemoveError<uint_m>(
  515. plaintext1_ntt.InverseNtt(ntt_params80_.get(), params80_.get()),
  516. params80_->modulus, kPlaintextModulus, params80_.get());
  517. EXPECT_EQ(decrypted, expected);
  518. }
  519. TEST_F(RelinearizationKeyTest, RelinearizationKeyIncreasesError) {
  520. ASSERT_OK_AND_ASSIGN(auto key, SampleKey());
  521. ASSERT_OK_AND_ASSIGN(std::string prng_seed,
  522. rlwe::SingleThreadPrng::GenerateSeed());
  523. ASSERT_OK_AND_ASSIGN(
  524. auto relinearization_key,
  525. rlwe::RelinearizationKey<uint_m>::Create(key, prng_seed, /*num_parts=*/3,
  526. kSmallLogDecompositionModulus));
  527. // Create the initial plaintexts.
  528. std::vector<uint_m::Int> plaintext1 = SamplePlaintext(kPlaintextModulus);
  529. ASSERT_OK_AND_ASSIGN(auto mp1,
  530. ConvertToMontgomery(plaintext1, params14_.get()));
  531. Polynomial plaintext1_ntt =
  532. Polynomial::ConvertToNtt(mp1, ntt_params_.get(), params14_.get());
  533. std::vector<uint_m::Int> plaintext2 = SamplePlaintext(kPlaintextModulus);
  534. ASSERT_OK_AND_ASSIGN(auto mp2,
  535. ConvertToMontgomery(plaintext2, params14_.get()));
  536. Polynomial plaintext2_ntt =
  537. Polynomial::ConvertToNtt(mp2, ntt_params_.get(), params14_.get());
  538. // Encrypt, multiply, apply the relinearization key and decrypt.
  539. ASSERT_OK_AND_ASSIGN(auto ciphertext1,
  540. Encrypt(key, plaintext1, params14_.get(),
  541. ntt_params_.get(), error_params_.get()));
  542. ASSERT_OK_AND_ASSIGN(auto ciphertext2,
  543. Encrypt(key, plaintext2, params14_.get(),
  544. ntt_params_.get(), error_params_.get()));
  545. ASSERT_OK_AND_ASSIGN(auto product, ciphertext1* ciphertext2);
  546. ASSERT_OK_AND_ASSIGN(auto relinearized_product,
  547. relinearization_key.ApplyTo(product));
  548. // Expect that the error grows after relinearization.
  549. EXPECT_GT(relinearized_product.Error(), product.Error());
  550. }
  551. } // namespace