// error_params_test.cc (7.8 KB)
/*
 * Copyright 2018 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  15. #ifndef RLWE_ERROR_PARAMS_TEST_H_
  16. #define RLWE_ERROR_PARAMS_TEST_H_
  17. #include "error_params.h"
  18. #include <gmock/gmock.h>
  19. #include <gtest/gtest.h>
  20. #include "constants.h"
  21. #include "context.h"
  22. #include "montgomery.h"
  23. #include "ntt_parameters.h"
  24. #include "prng/integral_prng_types.h"
  25. #include "status_macros.h"
  26. #include "symmetric_encryption.h"
  27. #include "testing/parameters.h"
  28. #include "testing/status_matchers.h"
  29. #include "testing/status_testing.h"
  30. #include "testing/testing_prng.h"
  31. #include "testing/testing_utils.h"
  32. namespace {
  33. using ::rlwe::testing::StatusIs;
  34. using ::testing::HasSubstr;
  35. // Number of samples used to compute the actual variance.
  36. const rlwe::Uint64 kSamples = 50;
  37. template <typename ModularInt>
  38. class ErrorParamsTest : public testing::Test {
  39. using Int = typename ModularInt::Int;
  40. using Polynomial = rlwe::Polynomial<ModularInt>;
  41. using Ciphertext = rlwe::SymmetricRlweCiphertext<ModularInt>;
  42. using Key = rlwe::SymmetricRlweKey<ModularInt>;
  43. public:
  44. // Computes the l-infinity norm of a vector of Ints.
  45. double ComputeNorm(const std::vector<Int>& coeffs) {
  46. return static_cast<double>(*std::max_element(coeffs.begin(), coeffs.end()));
  47. }
  48. // Sample a random key.
  49. rlwe::StatusOr<Key> SampleKey(const rlwe::RlweContext<ModularInt>* context) {
  50. RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
  51. rlwe::SingleThreadPrng::GenerateSeed());
  52. RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
  53. return Key::Sample(context->GetLogN(), context->GetVariance(),
  54. context->GetLogT(), context->GetModulusParams(),
  55. context->GetNttParams(), prng.get());
  56. }
  57. // Encrypt a plaintext.
  58. rlwe::StatusOr<Ciphertext> Encrypt(
  59. const Key& key, const std::vector<Int>& plaintext,
  60. const rlwe::RlweContext<ModularInt>* context) {
  61. RLWE_ASSIGN_OR_RETURN(auto m_p,
  62. rlwe::testing::ConvertToMontgomery<ModularInt>(
  63. plaintext, context->GetModulusParams()));
  64. auto plaintext_ntt = Polynomial::ConvertToNtt(m_p, context->GetNttParams(),
  65. context->GetModulusParams());
  66. RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
  67. rlwe::SingleThreadPrng::GenerateSeed());
  68. RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
  69. return rlwe::Encrypt<ModularInt>(key, plaintext_ntt,
  70. context->GetErrorParams(), prng.get());
  71. }
  72. // Decrypt without removing the error, returning (m + et).
  73. rlwe::StatusOr<std::vector<Int>> GetErrorAndMessage(
  74. const Key& key, const Ciphertext& ciphertext) {
  75. Polynomial error_and_message_ntt(key.Len(), key.ModulusParams());
  76. Polynomial key_powers = key.Key();
  77. for (int i = 0; i < ciphertext.Len(); i++) {
  78. // Extract component i.
  79. RLWE_ASSIGN_OR_RETURN(Polynomial ci, ciphertext.Component(i));
  80. if (i > 1) {
  81. RLWE_RETURN_IF_ERROR(
  82. key_powers.MulInPlace(key.Key(), key.ModulusParams()));
  83. }
  84. // Beyond c0, multiply the exponentiated key in.
  85. if (i > 0) {
  86. RLWE_RETURN_IF_ERROR(ci.MulInPlace(key_powers, key.ModulusParams()));
  87. }
  88. RLWE_RETURN_IF_ERROR(
  89. error_and_message_ntt.AddInPlace(ci, key.ModulusParams()));
  90. }
  91. auto error_and_message =
  92. error_and_message_ntt.InverseNtt(key.NttParams(), key.ModulusParams());
  93. // Convert the integers mod q to integers.
  94. std::vector<Int> error_and_message_ints(error_and_message.size(), 0);
  95. for (int i = 0; i < error_and_message.size(); i++) {
  96. error_and_message_ints[i] =
  97. error_and_message[i].ExportInt(key.ModulusParams());
  98. if (error_and_message_ints[i] > (key.ModulusParams()->modulus >> 1)) {
  99. error_and_message_ints[i] =
  100. key.ModulusParams()->modulus - error_and_message_ints[i];
  101. }
  102. }
  103. return error_and_message_ints;
  104. }
  105. };
  106. TYPED_TEST_SUITE(ErrorParamsTest, rlwe::testing::ModularIntTypes);
  107. TYPED_TEST(ErrorParamsTest, CreateError) {
  108. for (const auto& params :
  109. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  110. ASSERT_OK_AND_ASSIGN(auto context,
  111. rlwe::RlweContext<TypeParam>::Create(params));
  112. // large value for log_t
  113. const int log_t = context->GetModulusParams()->log_modulus;
  114. EXPECT_THAT(rlwe::ErrorParams<TypeParam>::Create(
  115. log_t, rlwe::testing::kDefaultVariance,
  116. context->GetModulusParams(), context->GetNttParams()),
  117. StatusIs(::absl::StatusCode::kInvalidArgument,
  118. HasSubstr(absl::StrCat(
  119. "The value log_t, ", log_t,
  120. ", must be smaller than log_modulus - 1, ",
  121. log_t - 1, "."))));
  122. }
  123. }
  124. TYPED_TEST(ErrorParamsTest, PlaintextError) {
  125. for (const auto& params :
  126. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  127. ASSERT_OK_AND_ASSIGN(auto context,
  128. rlwe::RlweContext<TypeParam>::Create(params));
  129. // Randomly sample polynomials and expect l-infinity norm is bounded by
  130. // b_plaintext.
  131. for (int i = 0; i < kSamples; i++) {
  132. // Samples a polynomial with kLogT and kDefaultCoeffs.
  133. auto plaintext = rlwe::testing::SamplePlaintext<TypeParam>(
  134. context->GetN(), context->GetT());
  135. // Expect that the norm of the coefficients of the plaintext is less than
  136. // b_plaintext.
  137. double norm = this->ComputeNorm(plaintext);
  138. EXPECT_LT(norm, context->GetErrorParams()->B_plaintext());
  139. }
  140. }
  141. }
  142. TYPED_TEST(ErrorParamsTest, EncryptionError) {
  143. for (const auto& params :
  144. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  145. ASSERT_OK_AND_ASSIGN(auto context,
  146. rlwe::RlweContext<TypeParam>::Create(params));
  147. ASSERT_OK_AND_ASSIGN(auto key, this->SampleKey(context.get()));
  148. // Randomly sample polynomials, decrypt, and compute the size of the result
  149. // before removing error.
  150. for (int i = 0; i < kSamples; i++) {
  151. // Expect that the norm of the coefficients of (m + et) is less than
  152. // b_encryption.
  153. auto plaintext = rlwe::testing::SamplePlaintext<TypeParam>(
  154. context->GetN(), context->GetT());
  155. ASSERT_OK_AND_ASSIGN(auto ciphertext,
  156. this->Encrypt(key, plaintext, context.get()));
  157. ASSERT_OK_AND_ASSIGN(auto error_and_message,
  158. this->GetErrorAndMessage(key, ciphertext));
  159. EXPECT_LT(this->ComputeNorm(error_and_message),
  160. context->GetErrorParams()->B_encryption());
  161. }
  162. }
  163. }
  164. TYPED_TEST(ErrorParamsTest, RelinearizationErrorScalesWithT) {
  165. for (const auto& params :
  166. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  167. ASSERT_OK_AND_ASSIGN(auto context,
  168. rlwe::RlweContext<TypeParam>::Create(params));
  169. // Error scales by (T / logT) when all other constants are fixed.
  170. int small_decomposition_modulus = 1;
  171. int large_decomposition_modulus = 10;
  172. EXPECT_LT(
  173. context->GetErrorParams()->B_relinearize(small_decomposition_modulus),
  174. context->GetErrorParams()->B_relinearize(large_decomposition_modulus));
  175. }
  176. }
  177. } // namespace
  178. #endif // RLWE_ERROR_PARAMS_TEST_H_