montgomery_test.cc 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
  1. /*
  2. * Copyright 2017 Google LLC.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. *
  7. * https://www.apache.org/licenses/LICENSE-2.0
  8. *
  9. * Unless required by applicable law or agreed to in writing, software
  10. * distributed under the License is distributed on an "AS IS" BASIS,
  11. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. * See the License for the specific language governing permissions and
  13. * limitations under the License.
  14. */
  15. #include "montgomery.h"
  16. #include <cstdint>
  17. #include <limits>
  18. #include <list>
  19. #include <memory>
  20. #include <random>
  21. #include <gmock/gmock.h>
  22. #include <gtest/gtest.h>
  23. #include "absl/numeric/int128.h"
  24. #include "constants.h"
  25. #include "serialization.pb.h"
  26. #include "status_macros.h"
  27. #include "testing/parameters.h"
  28. #include "testing/status_matchers.h"
  29. #include "testing/status_testing.h"
  30. #include "testing/testing_prng.h"
  31. #include "third_party/shell-encryption/base/shell_encryption_export.h"
  32. #include "third_party/shell-encryption/base/shell_encryption_export_template.h"
  33. namespace rlwe {
  34. namespace {
  35. // Random uniform distribution for Uint64.
  36. std::uniform_int_distribution<Uint64> uniform_uint64;
  37. using ::rlwe::testing::StatusIs;
  38. using ::testing::HasSubstr;
  39. const Uint64 kTestingRounds = 10;
  40. const size_t kExhaustiveTest = 1000;
  41. // Generate a random integer of a specified number of bits.
  42. template <class TypeParam>
  43. TypeParam GenerateRandom(unsigned int* seed) {
  44. std::minstd_rand random(*seed);
  45. *seed += 1;
  46. return static_cast<TypeParam>(uniform_uint64(random));
  47. }
  48. // Specialization for absl::uint128 and uint256.
  49. template <>
  50. absl::uint128 GenerateRandom(unsigned int* seed) {
  51. Uint64 hi = GenerateRandom<Uint64>(seed);
  52. Uint64 lo = GenerateRandom<Uint64>(seed);
  53. return absl::MakeUint128(hi, lo);
  54. }
  55. template <>
  56. uint256 GenerateRandom(unsigned int* seed) {
  57. absl::uint128 hi = GenerateRandom<absl::uint128>(seed);
  58. absl::uint128 lo = GenerateRandom<absl::uint128>(seed);
  59. return uint256(hi, lo);
  60. }
  61. template <typename T>
  62. class EXPORT_TEMPLATE_DECLARE(SHELL_ENCRYPTION_EXPORT) MontgomeryTest : public ::testing::Test {};
  63. TYPED_TEST_SUITE(MontgomeryTest, testing::ModularIntTypes);
  64. TYPED_TEST(MontgomeryTest, ModulusTooLarge) {
  65. using Int = typename TypeParam::Int;
  66. unsigned int seed = 0;
  67. Int modulus;
  68. for (Int i = 0; i < kTestingRounds; ++i) {
  69. // Sample an invalid odd modulus in (max(Int)/2, max(Int)).
  70. modulus =
  71. (std::numeric_limits<Int>::max() / 2) +
  72. (GenerateRandom<Int>(&seed) % (std::numeric_limits<Int>::max() / 2));
  73. modulus |= 1; // Ensure that the modulus is odd.
  74. EXPECT_THAT(
  75. MontgomeryIntParams<Int>::Create(modulus),
  76. StatusIs(absl::StatusCode::kInvalidArgument,
  77. HasSubstr(absl::StrCat("The modulus should be less than 2^",
  78. (sizeof(Int) * 8 - 2), "."))));
  79. // Sample an even modulus in the allowed range.
  80. modulus =
  81. (GenerateRandom<Int>(&seed) % (std::numeric_limits<Int>::max() / 8))
  82. << 1;
  83. EXPECT_THAT(
  84. MontgomeryIntParams<Int>::Create(modulus),
  85. StatusIs(absl::StatusCode::kInvalidArgument,
  86. HasSubstr(absl::StrCat("The modulus should be odd."))));
  87. }
  88. }
  89. // Verifies that the MontgomeryIntParams code computes the inverse modulus.
  90. TYPED_TEST(MontgomeryTest, ParamsInvModulus) {
  91. using Int = typename TypeParam::Int;
  92. using BigInt = typename internal::BigInt<Int>::value_type;
  93. for (const auto& params :
  94. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  95. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  96. TypeParam::Params::Create(params.modulus));
  97. EXPECT_EQ(modulus_params->r * modulus_params->inv_r -
  98. static_cast<BigInt>(modulus_params->modulus) *
  99. modulus_params->inv_modulus,
  100. 1);
  101. }
  102. }
  103. // Verifies that numbers can be imported and exported properly.
  104. TYPED_TEST(MontgomeryTest, ImportExportInt) {
  105. using Int = typename TypeParam::Int;
  106. unsigned int seed = 0;
  107. for (const auto& params :
  108. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  109. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  110. TypeParam::Params::Create(params.modulus));
  111. for (Int i = 0; i < kTestingRounds; ++i) {
  112. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  113. ASSERT_OK_AND_ASSIGN(auto m,
  114. TypeParam::ImportInt(a, modulus_params.get()));
  115. Int after = m.ExportInt(modulus_params.get());
  116. EXPECT_EQ(after, a);
  117. }
  118. }
  119. }
  120. // Verifies that numbers can be added correctly.
  121. TYPED_TEST(MontgomeryTest, AddSub) {
  122. using Int = typename TypeParam::Int;
  123. // Test over a selection of the possible input space.
  124. unsigned int seed = 0;
  125. for (const auto& params :
  126. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  127. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  128. TypeParam::Params::Create(params.modulus));
  129. for (int i = 0; i < kTestingRounds; i++) {
  130. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  131. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  132. ASSERT_OK_AND_ASSIGN(TypeParam ma,
  133. TypeParam::ImportInt(a, modulus_params.get()));
  134. ASSERT_OK_AND_ASSIGN(TypeParam mb,
  135. TypeParam::ImportInt(b, modulus_params.get()));
  136. TypeParam mc = ma.Add(mb, modulus_params.get());
  137. Int c = mc.ExportInt(modulus_params.get());
  138. Int expected = (a + b) % modulus_params->modulus;
  139. EXPECT_EQ(expected, c);
  140. TypeParam md = ma.Sub(mb, modulus_params.get());
  141. Int d = md.ExportInt(modulus_params.get());
  142. Int expected2 =
  143. (a + modulus_params->modulus - b) % modulus_params->modulus;
  144. EXPECT_EQ(expected2, d);
  145. }
  146. }
  147. }
  148. TYPED_TEST(MontgomeryTest, InlineAddSub) {
  149. using Int = typename TypeParam::Int;
  150. // Test over a selection of the possible input space.
  151. unsigned int seed = 0;
  152. for (const auto& params :
  153. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  154. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  155. TypeParam::Params::Create(params.modulus));
  156. for (int i = 0; i < kTestingRounds; i++) {
  157. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  158. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  159. ASSERT_OK_AND_ASSIGN(TypeParam ma,
  160. TypeParam::ImportInt(a, modulus_params.get()));
  161. ASSERT_OK_AND_ASSIGN(TypeParam mb,
  162. TypeParam::ImportInt(b, modulus_params.get()));
  163. ma.AddInPlace(mb, modulus_params.get());
  164. Int c = ma.ExportInt(modulus_params.get());
  165. Int expected = (a + b) % modulus_params->modulus;
  166. EXPECT_EQ(expected, c);
  167. ASSERT_OK_AND_ASSIGN(ma, TypeParam::ImportInt(a, modulus_params.get()));
  168. ma.SubInPlace(mb, modulus_params.get());
  169. Int d = ma.ExportInt(modulus_params.get());
  170. Int expected2 =
  171. (a + modulus_params->modulus - b) % modulus_params->modulus;
  172. EXPECT_EQ(expected2, d);
  173. }
  174. }
  175. }
  176. // Verifies that equality functions properly.
  177. TYPED_TEST(MontgomeryTest, Equality) {
  178. using Int = typename TypeParam::Int;
  179. // Test over a selection of the possible input space.
  180. unsigned int seed = 0;
  181. for (const auto& params :
  182. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  183. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  184. TypeParam::Params::Create(params.modulus));
  185. for (int i = 0; i < kTestingRounds; i++) {
  186. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  187. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  188. while (b == a) {
  189. b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  190. }
  191. ASSERT_OK_AND_ASSIGN(auto ma1,
  192. TypeParam::ImportInt(a, modulus_params.get()));
  193. ASSERT_OK_AND_ASSIGN(auto ma2,
  194. TypeParam::ImportInt(a, modulus_params.get()));
  195. ASSERT_OK_AND_ASSIGN(auto mb1,
  196. TypeParam::ImportInt(b, modulus_params.get()));
  197. ASSERT_OK_AND_ASSIGN(auto mb2,
  198. TypeParam::ImportInt(b, modulus_params.get()));
  199. EXPECT_TRUE(ma1 == ma2);
  200. EXPECT_TRUE(ma2 == ma1);
  201. EXPECT_FALSE(ma1 != ma2);
  202. EXPECT_FALSE(ma2 != ma1);
  203. EXPECT_TRUE(mb1 == mb2);
  204. EXPECT_TRUE(mb2 == mb1);
  205. EXPECT_FALSE(mb1 != mb2);
  206. EXPECT_FALSE(mb2 != mb1);
  207. EXPECT_TRUE(ma1 != mb1);
  208. EXPECT_TRUE(mb1 != ma1);
  209. EXPECT_FALSE(ma1 == mb1);
  210. EXPECT_FALSE(mb1 == ma1);
  211. }
  212. }
  213. }
  214. // Verifies that numbers can be negated correctly.
  215. TYPED_TEST(MontgomeryTest, Negate) {
  216. using Int = typename TypeParam::Int;
  217. for (const auto& params :
  218. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  219. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  220. TypeParam::Params::Create(params.modulus));
  221. for (Int i = 0; i < 4 * kNewhopeModulus; i++) {
  222. ASSERT_OK_AND_ASSIGN(auto mi,
  223. TypeParam::ImportInt(i, modulus_params.get()));
  224. EXPECT_EQ(0, mi.Add(mi.Negate(modulus_params.get()), modulus_params.get())
  225. .ExportInt(modulus_params.get()));
  226. }
  227. }
  228. }
  229. // Verifies that repeated addition works properly.
  230. TYPED_TEST(MontgomeryTest, AddRepeatedly) {
  231. using Int = typename TypeParam::Int;
  232. // Test over a selection of the possible input space.
  233. unsigned int seed = 0;
  234. for (const auto& params :
  235. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  236. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  237. TypeParam::Params::Create(params.modulus));
  238. for (int i = 0; i < kTestingRounds; i++) {
  239. Int sum = 0;
  240. Int diff = 0;
  241. TypeParam mont_sum = TypeParam::ImportZero(modulus_params.get());
  242. TypeParam mont_diff = TypeParam::ImportZero(modulus_params.get());
  243. for (int j = 0; j < 1000; j++) {
  244. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  245. sum = (sum + a) % modulus_params->modulus;
  246. ASSERT_OK_AND_ASSIGN(auto ma,
  247. TypeParam::ImportInt(a, modulus_params.get()));
  248. mont_sum = mont_sum.Add(ma, modulus_params.get());
  249. diff = (diff + modulus_params->modulus - a) % modulus_params->modulus;
  250. mont_diff = mont_diff.Sub(ma, modulus_params.get());
  251. }
  252. EXPECT_EQ(sum, mont_sum.ExportInt(modulus_params.get()));
  253. EXPECT_EQ(diff, mont_diff.ExportInt(modulus_params.get()));
  254. }
  255. }
  256. }
  257. // Verifies that numbers can be multiplied correctly.
  258. TYPED_TEST(MontgomeryTest, Multiply) {
  259. using Int = typename TypeParam::Int;
  260. using BigInt = typename internal::BigInt<Int>::value_type;
  261. // Test over a selection of the possible input space.
  262. unsigned int seed = 0;
  263. for (const auto& params :
  264. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  265. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  266. TypeParam::Params::Create(params.modulus));
  267. // Test over many random values.
  268. for (int i = 0; i < kTestingRounds; i++) {
  269. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  270. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  271. ASSERT_OK_AND_ASSIGN(auto ma,
  272. TypeParam::ImportInt(a, modulus_params.get()));
  273. ASSERT_OK_AND_ASSIGN(auto mb,
  274. TypeParam::ImportInt(b, modulus_params.get()));
  275. TypeParam mc = ma.Mul(mb, modulus_params.get());
  276. Int c = mc.ExportInt(modulus_params.get());
  277. Int expected =
  278. static_cast<Int>((static_cast<BigInt>(a) * static_cast<BigInt>(b)) %
  279. static_cast<BigInt>(modulus_params->modulus));
  280. EXPECT_EQ(expected, c);
  281. }
  282. // Test the multiplication of the maximum values together.
  283. Int a = modulus_params->modulus - 1;
  284. ASSERT_OK_AND_ASSIGN(auto ma,
  285. TypeParam::ImportInt(a, modulus_params.get()));
  286. TypeParam mb = ma.Mul(ma, modulus_params.get());
  287. Int b = mb.ExportInt(modulus_params.get());
  288. EXPECT_EQ(1, b);
  289. }
  290. }
  291. TYPED_TEST(MontgomeryTest, MulInPlace) {
  292. using Int = typename TypeParam::Int;
  293. using BigInt = typename internal::BigInt<Int>::value_type;
  294. // Test over a selection of the possible input space.
  295. unsigned int seed = 0;
  296. for (const auto& params :
  297. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  298. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  299. TypeParam::Params::Create(params.modulus));
  300. // Test over many random values.
  301. for (int i = 0; i < kTestingRounds; i++) {
  302. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  303. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  304. ASSERT_OK_AND_ASSIGN(auto ma,
  305. TypeParam::ImportInt(a, modulus_params.get()));
  306. ASSERT_OK_AND_ASSIGN(auto mb,
  307. TypeParam::ImportInt(b, modulus_params.get()));
  308. ma.MulInPlace(mb, modulus_params.get());
  309. Int c = ma.ExportInt(modulus_params.get());
  310. Int expected =
  311. static_cast<Int>((static_cast<BigInt>(a) * static_cast<BigInt>(b)) %
  312. static_cast<BigInt>(modulus_params->modulus));
  313. EXPECT_EQ(expected, c);
  314. }
  315. }
  316. }
  317. TYPED_TEST(MontgomeryTest, MulConstantInPlace) {
  318. using Int = typename TypeParam::Int;
  319. // Test over a selection of the possible input space.
  320. unsigned int seed = 0;
  321. for (const auto& params :
  322. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  323. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  324. TypeParam::Params::Create(params.modulus));
  325. // Test over many random values.
  326. for (int i = 0; i < kTestingRounds; i++) {
  327. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  328. Int b = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  329. ASSERT_OK_AND_ASSIGN(auto ma,
  330. TypeParam::ImportInt(a, modulus_params.get()));
  331. auto ma_clone = ma;
  332. ASSERT_OK_AND_ASSIGN(auto mb,
  333. TypeParam::ImportInt(b, modulus_params.get()));
  334. auto constants_tuple = mb.GetConstant(modulus_params.get());
  335. auto constant = std::get<0>(constants_tuple);
  336. auto constant_barrett = std::get<1>(constants_tuple);
  337. ma.MulInPlace(mb, modulus_params.get());
  338. ma_clone.MulConstantInPlace(constant, constant_barrett,
  339. modulus_params.get());
  340. EXPECT_EQ(ma.ExportInt(modulus_params.get()),
  341. ma_clone.ExportInt(modulus_params.get()));
  342. }
  343. }
  344. }
  345. // Verifies that repeated addition works properly.
  346. TYPED_TEST(MontgomeryTest, MultiplyRepeatedly) {
  347. using Int = typename TypeParam::Int;
  348. using BigInt = typename internal::BigInt<Int>::value_type;
  349. // Test over a selection of the possible input space.
  350. unsigned int seed = 0;
  351. for (const auto& params :
  352. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  353. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  354. TypeParam::Params::Create(params.modulus));
  355. for (int i = 0; i < kTestingRounds; i++) {
  356. BigInt prod = 1;
  357. TypeParam mont_prod = TypeParam::ImportOne(modulus_params.get());
  358. for (int j = 0; j < 1000; j++) {
  359. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  360. prod = (prod * static_cast<BigInt>(a)) %
  361. static_cast<BigInt>(modulus_params->modulus);
  362. ASSERT_OK_AND_ASSIGN(auto ma,
  363. TypeParam::ImportInt(a, modulus_params.get()));
  364. mont_prod = mont_prod.Mul(ma, modulus_params.get());
  365. }
  366. EXPECT_EQ(static_cast<Int>(prod),
  367. mont_prod.ExportInt(modulus_params.get()));
  368. }
  369. }
  370. }
  371. // Test the entire space for a small modulus.
  372. TYPED_TEST(MontgomeryTest, SmallModulus) {
  373. using Int = typename TypeParam::Int;
  374. using BigInt = typename internal::BigInt<Int>::value_type;
  375. for (const auto& params :
  376. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  377. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  378. TypeParam::Params::Create(params.modulus));
  379. const BigInt modulus = static_cast<BigInt>(modulus_params->modulus);
  380. for (Int a = 0; a < kExhaustiveTest; a++) {
  381. Int b = a + 1;
  382. BigInt a_BigInt = static_cast<BigInt>(a);
  383. BigInt b_BigInt = static_cast<BigInt>(b);
  384. ASSERT_OK_AND_ASSIGN(auto ma,
  385. TypeParam::ImportInt(a, modulus_params.get()));
  386. ASSERT_OK_AND_ASSIGN(auto mb,
  387. TypeParam::ImportInt(b, modulus_params.get()));
  388. TypeParam mc = ma.Add(mb, modulus_params.get());
  389. // Equality.
  390. if (a_BigInt % modulus == b_BigInt % modulus) {
  391. EXPECT_TRUE(ma == mb);
  392. EXPECT_FALSE(ma != mb);
  393. } else {
  394. EXPECT_TRUE(ma != mb);
  395. EXPECT_FALSE(ma == mb);
  396. }
  397. // Addition.
  398. EXPECT_EQ(static_cast<Int>((a_BigInt + b_BigInt) % modulus),
  399. mc.ExportInt(modulus_params.get()));
  400. // Negation.
  401. EXPECT_EQ(
  402. static_cast<Int>((2 * modulus - a_BigInt) % modulus),
  403. (ma.Negate(modulus_params.get())).ExportInt(modulus_params.get()));
  404. EXPECT_EQ(
  405. static_cast<Int>((2 * modulus - b_BigInt) % modulus),
  406. (mb.Negate(modulus_params.get())).ExportInt(modulus_params.get()));
  407. EXPECT_EQ(
  408. static_cast<Int>((4 * modulus - a_BigInt - b_BigInt) % modulus),
  409. (mc.Negate(modulus_params.get())).ExportInt(modulus_params.get()));
  410. // Subtraction.
  411. EXPECT_EQ(
  412. static_cast<Int>((2 * modulus - a_BigInt + b_BigInt) % modulus),
  413. (mb.Sub(ma, modulus_params.get()).ExportInt(modulus_params.get())));
  414. EXPECT_EQ(
  415. static_cast<Int>((2 * modulus - b_BigInt + a_BigInt) % modulus),
  416. (ma.Sub(mb, modulus_params.get()).ExportInt(modulus_params.get())));
  417. // Multiplication and commutativity.
  418. EXPECT_EQ(
  419. static_cast<Int>((a_BigInt * b_BigInt) % modulus),
  420. (ma.Mul(mb, modulus_params.get())).ExportInt(modulus_params.get()));
  421. EXPECT_EQ(
  422. static_cast<Int>((a_BigInt * b_BigInt) % modulus),
  423. (mb.Mul(ma, modulus_params.get())).ExportInt(modulus_params.get()));
  424. }
  425. }
  426. }
  427. TYPED_TEST(MontgomeryTest, ModExpModulus) {
  428. using Int = typename TypeParam::Int;
  429. using BigInt = typename internal::BigInt<Int>::value_type;
  430. // Test over a selection of the possible input space.
  431. unsigned int seed = 0;
  432. for (const auto& params :
  433. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  434. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  435. TypeParam::Params::Create(params.modulus));
  436. const BigInt modulus = static_cast<BigInt>(modulus_params->modulus);
  437. for (int i = 0; i < kTestingRounds; i++) {
  438. BigInt expected = 1;
  439. Int base = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  440. for (Int exp = 0; exp < kExhaustiveTest; exp++) {
  441. ASSERT_OK_AND_ASSIGN(auto base_m,
  442. TypeParam::ImportInt(base, modulus_params.get()));
  443. auto actual_m = base_m.ModExp(exp, modulus_params.get());
  444. Int actual = actual_m.ExportInt(modulus_params.get());
  445. ASSERT_EQ(actual, expected);
  446. expected *= static_cast<BigInt>(base);
  447. expected %= modulus;
  448. }
  449. }
  450. }
  451. }
  452. TYPED_TEST(MontgomeryTest, InverseModulus) {
  453. using Int = typename TypeParam::Int;
  454. // Test over a selection of the possible input space.
  455. unsigned int seed = 0;
  456. for (const auto& params :
  457. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  458. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  459. TypeParam::Params::Create(params.modulus));
  460. for (int i = 0; i < kTestingRounds; i++) {
  461. Int a = GenerateRandom<Int>(&seed) % modulus_params->modulus;
  462. ASSERT_OK_AND_ASSIGN(auto a_m,
  463. TypeParam::ImportInt(a, modulus_params.get()));
  464. TypeParam inv = a_m.MultiplicativeInverse(modulus_params.get());
  465. ASSERT_EQ(
  466. (a_m.Mul(inv, modulus_params.get())).ExportInt(modulus_params.get()),
  467. 1);
  468. }
  469. }
  470. }
  471. TYPED_TEST(MontgomeryTest, Serialization) {
  472. using Int = typename TypeParam::Int;
  473. for (const auto& params :
  474. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  475. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  476. TypeParam::Params::Create(params.modulus));
  477. for (Int i = 0; i < kExhaustiveTest; i++) {
  478. Int input_int = i % modulus_params->modulus;
  479. // Serialize and ensure the byte length is as expected.
  480. ASSERT_OK_AND_ASSIGN(
  481. auto int_value,
  482. TypeParam::ImportInt(input_int, modulus_params.get()));
  483. ASSERT_OK_AND_ASSIGN(std::string serialized,
  484. int_value.Serialize(modulus_params.get()));
  485. EXPECT_EQ(serialized.length(), modulus_params->SerializedSize());
  486. // Ensure that deserialization works properly.
  487. ASSERT_OK_AND_ASSIGN(
  488. auto int_deserialized,
  489. TypeParam::Deserialize(serialized, modulus_params.get()));
  490. EXPECT_EQ(int_deserialized, int_value);
  491. // Ensure that that any bit beyond bit the serialized bit length can be
  492. // wiped out without issue. That is, ensure that the bit size is accurate.
  493. serialized[serialized.length() - 1] &=
  494. (static_cast<Uint8>(1)
  495. << (modulus_params->log_modulus - 8 * (serialized.length() - 1))) -
  496. 1;
  497. ASSERT_OK_AND_ASSIGN(
  498. auto int_deserialized2,
  499. TypeParam::Deserialize(serialized, modulus_params.get()));
  500. EXPECT_EQ(int_deserialized2, int_value);
  501. }
  502. }
  503. }
  504. TYPED_TEST(MontgomeryTest, ExceedMaxNumCoeffVectorSerialization) {
  505. int num_coeffs = kMaxNumCoeffs + 1;
  506. for (const auto& params :
  507. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  508. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  509. TypeParam::Params::Create(params.modulus));
  510. std::vector<TypeParam> coeffs;
  511. for (int i = 0; i < num_coeffs; ++i) {
  512. coeffs.push_back(TypeParam::ImportOne(modulus_params.get()));
  513. }
  514. EXPECT_THAT(TypeParam::SerializeVector(coeffs, modulus_params.get()),
  515. StatusIs(absl::StatusCode::kInvalidArgument,
  516. HasSubstr(absl::StrCat(
  517. "Number of coefficients, ", num_coeffs,
  518. ", cannot be larger than ", kMaxNumCoeffs, "."))));
  519. }
  520. }
  521. TYPED_TEST(MontgomeryTest, EmptyVectorSerialization) {
  522. for (const auto& params :
  523. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  524. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  525. TypeParam::Params::Create(params.modulus));
  526. std::vector<TypeParam> coeffs;
  527. EXPECT_THAT(TypeParam::SerializeVector(coeffs, modulus_params.get()),
  528. StatusIs(absl::StatusCode::kInvalidArgument,
  529. HasSubstr("Cannot serialize an empty vector")));
  530. }
  531. }
  532. TYPED_TEST(MontgomeryTest, VectorSerialization) {
  533. // Prng to generate random values
  534. auto prng = absl::make_unique<rlwe::testing::TestingPrng>(0);
  535. for (const auto& params :
  536. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  537. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  538. TypeParam::Params::Create(params.modulus));
  539. for (int num_coeffs = 3; num_coeffs <= 25; ++num_coeffs) {
  540. std::vector<TypeParam> coeffs;
  541. coeffs.reserve(num_coeffs);
  542. for (int i = 0; i < num_coeffs; ++i) {
  543. ASSERT_OK_AND_ASSIGN(
  544. auto int_value,
  545. TypeParam::ImportRandom(prng.get(), modulus_params.get()));
  546. coeffs.push_back(int_value);
  547. }
  548. ASSERT_OK_AND_ASSIGN(
  549. std::string serialized,
  550. TypeParam::SerializeVector(coeffs, modulus_params.get()));
  551. int expected_size = (num_coeffs * modulus_params->log_modulus + 7) / 8;
  552. EXPECT_EQ(serialized.size(), expected_size);
  553. ASSERT_OK_AND_ASSIGN(auto deserialized,
  554. TypeParam::DeserializeVector(num_coeffs, serialized,
  555. modulus_params.get()));
  556. EXPECT_EQ(deserialized.size(), num_coeffs);
  557. for (int i = 0; i < num_coeffs; ++i) {
  558. EXPECT_EQ(coeffs[i], deserialized[i]);
  559. }
  560. }
  561. }
  562. }
  563. TYPED_TEST(MontgomeryTest, ExceedMaxNumCoeffVectorDeserialization) {
  564. int num_coeffs = kMaxNumCoeffs + 1;
  565. for (const auto& params :
  566. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  567. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  568. TypeParam::Params::Create(params.modulus));
  569. EXPECT_THAT(TypeParam::DeserializeVector(num_coeffs, std::string(),
  570. modulus_params.get()),
  571. StatusIs(absl::StatusCode::kInvalidArgument,
  572. HasSubstr(absl::StrCat(
  573. "Number of coefficients, ", num_coeffs,
  574. ", cannot be larger than ", kMaxNumCoeffs, "."))));
  575. }
  576. }
  577. TYPED_TEST(MontgomeryTest, NegativeVectorDeserialization) {
  578. int num_coeffs = -1;
  579. for (const auto& params :
  580. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  581. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  582. TypeParam::Params::Create(params.modulus));
  583. EXPECT_THAT(
  584. TypeParam::DeserializeVector(num_coeffs, std::string(),
  585. modulus_params.get()),
  586. StatusIs(absl::StatusCode::kInvalidArgument,
  587. HasSubstr("Number of coefficients must be non-negative.")));
  588. }
  589. }
  590. TYPED_TEST(MontgomeryTest, ImportRandomWithPrngWithSameKeys) {
  591. unsigned seed = 0;
  592. unsigned int seed_prng = GenerateRandom<unsigned int>(&seed);
  593. for (const auto& params :
  594. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  595. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  596. TypeParam::Params::Create(params.modulus));
  597. auto prng1 = absl::make_unique<rlwe::testing::TestingPrng>(seed_prng);
  598. auto prng2 = absl::make_unique<rlwe::testing::TestingPrng>(seed_prng);
  599. ASSERT_OK_AND_ASSIGN(
  600. auto r1, TypeParam::ImportRandom(prng1.get(), modulus_params.get()));
  601. ASSERT_OK_AND_ASSIGN(
  602. auto r2, TypeParam::ImportRandom(prng2.get(), modulus_params.get()));
  603. EXPECT_EQ(r1, r2);
  604. }
  605. }
  606. TYPED_TEST(MontgomeryTest, ImportRandomWithPrngWithDifferentKeys) {
  607. unsigned seed = 0;
  608. unsigned int seed_prng1 = GenerateRandom<unsigned int>(&seed);
  609. unsigned int seed_prng2 = seed_prng1 + 1; // Different seed
  610. for (const auto& params :
  611. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  612. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  613. TypeParam::Params::Create(params.modulus));
  614. auto prng1 = absl::make_unique<rlwe::testing::TestingPrng>(seed_prng1);
  615. auto prng2 = absl::make_unique<rlwe::testing::TestingPrng>(seed_prng2);
  616. ASSERT_OK_AND_ASSIGN(
  617. auto r1, TypeParam::ImportRandom(prng1.get(), modulus_params.get()));
  618. ASSERT_OK_AND_ASSIGN(
  619. auto r2, TypeParam::ImportRandom(prng2.get(), modulus_params.get()));
  620. EXPECT_NE(r1, r2);
  621. }
  622. }
  623. // Verifies that Barrett reductions functions properly.
  624. TYPED_TEST(MontgomeryTest, VerifyBarrett) {
  625. using Int = typename TypeParam::Int;
  626. using BigInt = typename internal::BigInt<Int>::value_type;
  627. for (const auto& params :
  628. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  629. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  630. TypeParam::Params::Create(params.modulus));
  631. // Test over a selection of the possible input space.
  632. for (unsigned int j = 0; j < kTestingRounds; j++) {
  633. unsigned int seed = j;
  634. for (int i = 0; i < kTestingRounds; i++) {
  635. // Verify Barrett reduction up to max(Int).
  636. Int a = modulus_params->modulus +
  637. (GenerateRandom<Int>(&seed) %
  638. (std::numeric_limits<Int>::max() - modulus_params->modulus));
  639. EXPECT_EQ(modulus_params->BarrettReduce(a), a % params.modulus);
  640. }
  641. }
  642. }
  643. }
  644. TYPED_TEST(MontgomeryTest, BatchOperations) {
  645. using Int = typename TypeParam::Int;
  646. unsigned int seed = 0;
  647. unsigned int seed_prng = GenerateRandom<unsigned int>(&seed);
  648. auto prng = absl::make_unique<rlwe::testing::TestingPrng>(seed_prng);
  649. for (const auto& params :
  650. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  651. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  652. TypeParam::Params::Create(params.modulus));
  653. for (size_t length : {1, 2, 7, 32, 500, 1024}) {
  654. std::vector<TypeParam> a, b;
  655. std::vector<Int> b_constant, b_constant_barrett;
  656. std::vector<TypeParam> expected_add, expected_sub, expected_mul;
  657. TypeParam scalar =
  658. TypeParam::ImportRandom(prng.get(), modulus_params.get())
  659. .value();
  660. auto scalar_constants_tuple = scalar.GetConstant(modulus_params.get());
  661. auto scalar_constant = std::get<0>(scalar_constants_tuple);
  662. auto scalar_constant_barrett = std::get<1>(scalar_constants_tuple);
  663. std::vector<TypeParam> expected_add_scalar, expected_sub_scalar,
  664. expected_mul_scalar;
  665. for (size_t i = 0; i < length; i++) {
  666. a.push_back(TypeParam::ImportRandom(prng.get(), modulus_params.get())
  667. .value());
  668. b.push_back(TypeParam::ImportRandom(prng.get(), modulus_params.get())
  669. .value());
  670. auto constants_tuple = b[i].GetConstant(modulus_params.get());
  671. auto constant = std::get<0>(constants_tuple);
  672. auto constant_barrett = std::get<1>(constants_tuple);
  673. b_constant.push_back(constant);
  674. b_constant_barrett.push_back(constant_barrett);
  675. expected_add.push_back(a[i].Add(b[i], modulus_params.get()));
  676. expected_sub.push_back(a[i].Sub(b[i], modulus_params.get()));
  677. expected_mul.push_back(a[i].Mul(b[i], modulus_params.get()));
  678. expected_add_scalar.push_back(a[i].Add(scalar, modulus_params.get()));
  679. expected_sub_scalar.push_back(a[i].Sub(scalar, modulus_params.get()));
  680. expected_mul_scalar.push_back(a[i].Mul(scalar, modulus_params.get()));
  681. }
  682. ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> add,
  683. TypeParam::BatchAdd(a, b, modulus_params.get()));
  684. ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> sub,
  685. TypeParam::BatchSub(a, b, modulus_params.get()));
  686. ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> mul,
  687. TypeParam::BatchMul(a, b, modulus_params.get()));
  688. ASSERT_OK_AND_ASSIGN(
  689. std::vector<TypeParam> mul_constant,
  690. TypeParam::BatchMulConstant(a, b_constant, b_constant_barrett,
  691. modulus_params.get()));
  692. ASSERT_OK_AND_ASSIGN(
  693. std::vector<TypeParam> add_scalar,
  694. TypeParam::BatchAdd(a, scalar, modulus_params.get()));
  695. ASSERT_OK_AND_ASSIGN(
  696. std::vector<TypeParam> sub_scalar,
  697. TypeParam::BatchSub(a, scalar, modulus_params.get()));
  698. ASSERT_OK_AND_ASSIGN(
  699. std::vector<TypeParam> mul_scalar,
  700. TypeParam::BatchMul(a, scalar, modulus_params.get()));
  701. ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> mul_scalar_constant,
  702. TypeParam::BatchMulConstant(a, scalar_constant,
  703. scalar_constant_barrett,
  704. modulus_params.get()));
  705. EXPECT_EQ(add.size(), expected_add.size());
  706. EXPECT_EQ(sub.size(), expected_sub.size());
  707. EXPECT_EQ(mul.size(), expected_mul.size());
  708. EXPECT_EQ(mul_constant.size(), expected_mul.size());
  709. EXPECT_EQ(add_scalar.size(), expected_add_scalar.size());
  710. EXPECT_EQ(sub_scalar.size(), expected_sub_scalar.size());
  711. EXPECT_EQ(mul_scalar.size(), expected_mul_scalar.size());
  712. EXPECT_EQ(mul_scalar_constant.size(), expected_mul_scalar.size());
  713. for (size_t i = 0; i < length; i++) {
  714. EXPECT_EQ(add[i].ExportInt(modulus_params.get()),
  715. expected_add[i].ExportInt(modulus_params.get()));
  716. EXPECT_EQ(sub[i].ExportInt(modulus_params.get()),
  717. expected_sub[i].ExportInt(modulus_params.get()));
  718. EXPECT_EQ(mul[i].ExportInt(modulus_params.get()),
  719. expected_mul[i].ExportInt(modulus_params.get()));
  720. EXPECT_EQ(mul_constant[i].ExportInt(modulus_params.get()),
  721. expected_mul[i].ExportInt(modulus_params.get()));
  722. EXPECT_EQ(add_scalar[i].ExportInt(modulus_params.get()),
  723. expected_add_scalar[i].ExportInt(modulus_params.get()));
  724. EXPECT_EQ(sub_scalar[i].ExportInt(modulus_params.get()),
  725. expected_sub_scalar[i].ExportInt(modulus_params.get()));
  726. EXPECT_EQ(mul_scalar[i].ExportInt(modulus_params.get()),
  727. expected_mul_scalar[i].ExportInt(modulus_params.get()));
  728. EXPECT_EQ(mul_scalar_constant[i].ExportInt(modulus_params.get()),
  729. expected_mul_scalar[i].ExportInt(modulus_params.get()));
  730. }
  731. }
  732. }
  733. }
  734. TYPED_TEST(MontgomeryTest, BatchOperationsFailsWithVectorsOfDifferentSize) {
  735. using Int = typename TypeParam::Int;
  736. for (const auto& params :
  737. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  738. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  739. TypeParam::Params::Create(params.modulus));
  740. for (size_t length_a : {1, 2, 7, 32, 500, 1024}) {
  741. for (size_t length_b : {1, 2, 7, 32, 500, 1024}) {
  742. if (length_a != length_b) {
  743. std::vector<TypeParam> a(length_a,
  744. TypeParam::ImportZero(modulus_params.get()));
  745. std::vector<Int> a_constant(length_a, static_cast<Int>(0));
  746. std::vector<TypeParam> b(length_b,
  747. TypeParam::ImportZero(modulus_params.get()));
  748. std::vector<Int> b_constant(length_b, static_cast<Int>(0));
  749. EXPECT_THAT(
  750. TypeParam::BatchAdd(a, b, modulus_params.get()),
  751. StatusIs(absl::StatusCode::kInvalidArgument,
  752. HasSubstr("Input vectors are not of same size")));
  753. EXPECT_THAT(
  754. TypeParam::BatchAddInPlace(&a, b, modulus_params.get()),
  755. StatusIs(absl::StatusCode::kInvalidArgument,
  756. HasSubstr("Input vectors are not of same size")));
  757. EXPECT_THAT(
  758. TypeParam::BatchSub(a, b, modulus_params.get()),
  759. StatusIs(absl::StatusCode::kInvalidArgument,
  760. HasSubstr("Input vectors are not of same size")));
  761. EXPECT_THAT(
  762. TypeParam::BatchSubInPlace(&a, b, modulus_params.get()),
  763. StatusIs(absl::StatusCode::kInvalidArgument,
  764. HasSubstr("Input vectors are not of same size")));
  765. EXPECT_THAT(
  766. TypeParam::BatchMul(a, b, modulus_params.get()),
  767. StatusIs(absl::StatusCode::kInvalidArgument,
  768. HasSubstr("Input vectors are not of same size")));
  769. EXPECT_THAT(
  770. TypeParam::BatchMulInPlace(&a, b, modulus_params.get()),
  771. StatusIs(absl::StatusCode::kInvalidArgument,
  772. HasSubstr("Input vectors are not of same size")));
  773. EXPECT_THAT(
  774. TypeParam::BatchMulConstant(a, b_constant, b_constant,
  775. modulus_params.get()),
  776. StatusIs(absl::StatusCode::kInvalidArgument,
  777. HasSubstr("Input vectors are not of same size")));
  778. EXPECT_THAT(
  779. TypeParam::BatchMulConstantInPlace(&a, b_constant, b_constant,
  780. modulus_params.get()),
  781. StatusIs(absl::StatusCode::kInvalidArgument,
  782. HasSubstr("Input vectors are not of same size")));
  783. EXPECT_THAT(
  784. TypeParam::BatchMulConstantInPlace(&a, a_constant, b_constant,
  785. modulus_params.get()),
  786. StatusIs(absl::StatusCode::kInvalidArgument,
  787. HasSubstr("Input vectors are not of same size")));
  788. }
  789. }
  790. }
  791. }
  792. }
  793. // This PRNG tests templating with a Prng that does not inherit from SecurePrng.
  794. class FakePrng {
  795. public:
  796. StatusOr<Uint8> Rand8() { return 0; }
  797. StatusOr<Uint64> Rand64() { return 0; }
  798. };
  799. TYPED_TEST(MontgomeryTest, PrngTemplateParameterizationWorks) {
  800. for (const auto& params :
  801. rlwe::testing::ContextParameters<TypeParam>::Value()) {
  802. ASSERT_OK_AND_ASSIGN(auto modulus_params,
  803. TypeParam::Params::Create(params.modulus));
  804. FakePrng prng;
  805. ASSERT_OK(TypeParam::ImportRandom(&prng, modulus_params.get()));
  806. }
  807. }
  808. } // namespace
  809. } // namespace rlwe