123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- /*
- * Copyright 2018 Google LLC.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef RLWE_ERROR_PARAMS_TEST_H_
- #define RLWE_ERROR_PARAMS_TEST_H_
- #include "error_params.h"
- #include <gmock/gmock.h>
- #include <gtest/gtest.h>
- #include "constants.h"
- #include "context.h"
- #include "montgomery.h"
- #include "ntt_parameters.h"
- #include "prng/integral_prng_types.h"
- #include "status_macros.h"
- #include "symmetric_encryption.h"
- #include "testing/parameters.h"
- #include "testing/status_matchers.h"
- #include "testing/status_testing.h"
- #include "testing/testing_prng.h"
- #include "testing/testing_utils.h"
- namespace {
- using ::rlwe::testing::StatusIs;
- using ::testing::HasSubstr;
- // Number of samples used to compute the actual variance.
- const rlwe::Uint64 kSamples = 50;
- template <typename ModularInt>
- class ErrorParamsTest : public testing::Test {
- using Int = typename ModularInt::Int;
- using Polynomial = rlwe::Polynomial<ModularInt>;
- using Ciphertext = rlwe::SymmetricRlweCiphertext<ModularInt>;
- using Key = rlwe::SymmetricRlweKey<ModularInt>;
- public:
- // Computes the l-infinity norm of a vector of Ints.
- double ComputeNorm(const std::vector<Int>& coeffs) {
- return static_cast<double>(*std::max_element(coeffs.begin(), coeffs.end()));
- }
- // Sample a random key.
- rlwe::StatusOr<Key> SampleKey(const rlwe::RlweContext<ModularInt>* context) {
- RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
- rlwe::SingleThreadPrng::GenerateSeed());
- RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
- return Key::Sample(context->GetLogN(), context->GetVariance(),
- context->GetLogT(), context->GetModulusParams(),
- context->GetNttParams(), prng.get());
- }
- // Encrypt a plaintext.
- rlwe::StatusOr<Ciphertext> Encrypt(
- const Key& key, const std::vector<Int>& plaintext,
- const rlwe::RlweContext<ModularInt>* context) {
- RLWE_ASSIGN_OR_RETURN(auto m_p,
- rlwe::testing::ConvertToMontgomery<ModularInt>(
- plaintext, context->GetModulusParams()));
- auto plaintext_ntt = Polynomial::ConvertToNtt(m_p, context->GetNttParams(),
- context->GetModulusParams());
- RLWE_ASSIGN_OR_RETURN(std::string prng_seed,
- rlwe::SingleThreadPrng::GenerateSeed());
- RLWE_ASSIGN_OR_RETURN(auto prng, rlwe::SingleThreadPrng::Create(prng_seed));
- return rlwe::Encrypt<ModularInt>(key, plaintext_ntt,
- context->GetErrorParams(), prng.get());
- }
- // Decrypt without removing the error, returning (m + et).
- rlwe::StatusOr<std::vector<Int>> GetErrorAndMessage(
- const Key& key, const Ciphertext& ciphertext) {
- Polynomial error_and_message_ntt(key.Len(), key.ModulusParams());
- Polynomial key_powers = key.Key();
- for (int i = 0; i < ciphertext.Len(); i++) {
- // Extract component i.
- RLWE_ASSIGN_OR_RETURN(Polynomial ci, ciphertext.Component(i));
- if (i > 1) {
- RLWE_RETURN_IF_ERROR(
- key_powers.MulInPlace(key.Key(), key.ModulusParams()));
- }
- // Beyond c0, multiply the exponentiated key in.
- if (i > 0) {
- RLWE_RETURN_IF_ERROR(ci.MulInPlace(key_powers, key.ModulusParams()));
- }
- RLWE_RETURN_IF_ERROR(
- error_and_message_ntt.AddInPlace(ci, key.ModulusParams()));
- }
- auto error_and_message =
- error_and_message_ntt.InverseNtt(key.NttParams(), key.ModulusParams());
- // Convert the integers mod q to integers.
- std::vector<Int> error_and_message_ints(error_and_message.size(), 0);
- for (int i = 0; i < error_and_message.size(); i++) {
- error_and_message_ints[i] =
- error_and_message[i].ExportInt(key.ModulusParams());
- if (error_and_message_ints[i] > (key.ModulusParams()->modulus >> 1)) {
- error_and_message_ints[i] =
- key.ModulusParams()->modulus - error_and_message_ints[i];
- }
- }
- return error_and_message_ints;
- }
- };
- TYPED_TEST_SUITE(ErrorParamsTest, rlwe::testing::ModularIntTypes);
- TYPED_TEST(ErrorParamsTest, CreateError) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- ASSERT_OK_AND_ASSIGN(auto context,
- rlwe::RlweContext<TypeParam>::Create(params));
- // large value for log_t
- const int log_t = context->GetModulusParams()->log_modulus;
- EXPECT_THAT(rlwe::ErrorParams<TypeParam>::Create(
- log_t, rlwe::testing::kDefaultVariance,
- context->GetModulusParams(), context->GetNttParams()),
- StatusIs(::absl::StatusCode::kInvalidArgument,
- HasSubstr(absl::StrCat(
- "The value log_t, ", log_t,
- ", must be smaller than log_modulus - 1, ",
- log_t - 1, "."))));
- }
- }
- TYPED_TEST(ErrorParamsTest, PlaintextError) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- ASSERT_OK_AND_ASSIGN(auto context,
- rlwe::RlweContext<TypeParam>::Create(params));
- // Randomly sample polynomials and expect l-infinity norm is bounded by
- // b_plaintext.
- for (int i = 0; i < kSamples; i++) {
- // Samples a polynomial with kLogT and kDefaultCoeffs.
- auto plaintext = rlwe::testing::SamplePlaintext<TypeParam>(
- context->GetN(), context->GetT());
- // Expect that the norm of the coefficients of the plaintext is less than
- // b_plaintext.
- double norm = this->ComputeNorm(plaintext);
- EXPECT_LT(norm, context->GetErrorParams()->B_plaintext());
- }
- }
- }
- TYPED_TEST(ErrorParamsTest, EncryptionError) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- ASSERT_OK_AND_ASSIGN(auto context,
- rlwe::RlweContext<TypeParam>::Create(params));
- ASSERT_OK_AND_ASSIGN(auto key, this->SampleKey(context.get()));
- // Randomly sample polynomials, decrypt, and compute the size of the result
- // before removing error.
- for (int i = 0; i < kSamples; i++) {
- // Expect that the norm of the coefficients of (m + et) is less than
- // b_encryption.
- auto plaintext = rlwe::testing::SamplePlaintext<TypeParam>(
- context->GetN(), context->GetT());
- ASSERT_OK_AND_ASSIGN(auto ciphertext,
- this->Encrypt(key, plaintext, context.get()));
- ASSERT_OK_AND_ASSIGN(auto error_and_message,
- this->GetErrorAndMessage(key, ciphertext));
- EXPECT_LT(this->ComputeNorm(error_and_message),
- context->GetErrorParams()->B_encryption());
- }
- }
- }
- TYPED_TEST(ErrorParamsTest, RelinearizationErrorScalesWithT) {
- for (const auto& params :
- rlwe::testing::ContextParameters<TypeParam>::Value()) {
- ASSERT_OK_AND_ASSIGN(auto context,
- rlwe::RlweContext<TypeParam>::Create(params));
- // Error scales by (T / logT) when all other constants are fixed.
- int small_decomposition_modulus = 1;
- int large_decomposition_modulus = 10;
- EXPECT_LT(
- context->GetErrorParams()->B_relinearize(small_decomposition_modulus),
- context->GetErrorParams()->B_relinearize(large_decomposition_modulus));
- }
- }
- } // namespace
- #endif // RLWE_ERROR_PARAMS_TEST_H_
|