montgomery.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. // Copyright 2020 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "montgomery.h"
  15. #include "third_party/shell-encryption/base/shell_encryption_export.h"
  16. #include "third_party/shell-encryption/base/shell_encryption_export_template.h"
  17. #include "transcription.h"
  18. namespace rlwe {
  19. template <typename T>
  20. rlwe::StatusOr<std::unique_ptr<const MontgomeryIntParams<T>>>
  21. MontgomeryIntParams<T>::Create(Int modulus) {
  22. // Check that the modulus is smaller than max(Int) / 4.
  23. Int most_significant_bit = modulus >> (bitsize_int - 2);
  24. if (most_significant_bit != 0) {
  25. return absl::InvalidArgumentError(absl::StrCat(
  26. "The modulus should be less than 2^", (bitsize_int - 2), "."));
  27. }
  28. if ((modulus % 2) == 0) {
  29. return absl::InvalidArgumentError(
  30. absl::StrCat("The modulus should be odd."));
  31. }
  32. return absl::WrapUnique<const MontgomeryIntParams>(
  33. new MontgomeryIntParams(modulus));
  34. }
// Computes a pair (x, y) satisfying x * 2^w - y * modulus = 1 using the
// shift-based algorithm from Hacker's Delight, and returns both values
// truncated to Int.
//
// Parameters:
//   modulus_bigint: the modulus, widened to BigInt. Must be odd (invariant 3).
//   r: the value added to y whenever modulus is added to x. The comments
//      below refer to it as 2^w (presumably w == bitsize_int — confirm at the
//      call site).
//
// Returns: std::make_tuple(static_cast<Int>(x), static_cast<Int>(y)).
template <typename T>
std::tuple<T, T> MontgomeryIntParams<T>::Inverses(BigInt modulus_bigint,
                                                  BigInt r) {
  // Invariants
  //   1) sum = x * 2^w - y * modulus.
  //   2) sum is always a power of 2.
  //   3) modulus is odd.
  //   4) y is always even.
  // sum will decrease from 2^w to 2^0 = 1
  BigInt x = 1;
  BigInt y = 0;
  // One iteration per bit of Int: each iteration halves x and y, and
  // therefore halves sum.
  for (int i = bitsize_int; i > 0; i--) {
    // Ensure that x is even.
    if ((x & 1) == 1) {
      // If x is odd, make x even by adding modulus to x and changing the
      // value of y accordingly (y remains even).
      //
      //     sum = x * 2^w - y * modulus
      //     sum = (x + modulus) * 2^w - (y + 2^w) * modulus
      //
      // We can then divide the new values of x and y by 2 safely.
      x += modulus_bigint;
      y += r;
    }
    // Divide x and y by 2
    x >>= 1;
    y >>= 1;
  }
  // Return the inverses. With sum == 1, invariant 1 reads
  // x * 2^w - y * modulus = 1.
  return std::make_tuple(static_cast<Int>(x), static_cast<Int>(y));
}
// Builds a MontgomeryInt from the plain integer `n` by computing
// n * params->r_mod_modulus reduced into [0, params->modulus).
//
// The reduction estimates a quotient from the precomputed constant
// params->r_mod_modulus_barrett (presumably a Barrett-style scaled quotient of
// r_mod_modulus by the modulus — confirm in MontgomeryIntParams), subtracts
// quotient * modulus, and then applies one conditional correction.
//
// No error path is visible in this function; it always returns a value.
template <typename T>
rlwe::StatusOr<MontgomeryInt<T>> MontgomeryInt<T>::ImportInt(
    Int n, const Params* params) {
  // Double-width product; only its upper half is kept as the estimated
  // quotient.
  BigInt product = static_cast<BigInt>(params->r_mod_modulus_barrett) * n;
  Int result = static_cast<Int>(product >> Params::bitsize_int);
  // Subtract (estimated quotient) * modulus from n * r_mod_modulus.
  result = n * params->r_mod_modulus - result * params->modulus;
  // The steps above produce an integer that is in the range [0, 2N), where N
  // denotes params->modulus.
  // We now reduce to the range [0, N).
  result -= (result >= params->modulus) ? params->modulus : 0;
  return MontgomeryInt(result);
}
  78. template <typename T>
  79. MontgomeryInt<T> MontgomeryInt<T>::ImportZero(const Params* params) {
  80. return MontgomeryInt(params->Zero());
  81. }
  82. template <typename T>
  83. MontgomeryInt<T> MontgomeryInt<T>::ImportOne(const Params* params) {
  84. // 1 should be multiplied by r_mod_modulus; we load directly r_mod_modulus.
  85. return MontgomeryInt(static_cast<Int>(params->r_mod_modulus));
  86. }
  87. template <typename T>
  88. typename internal::BigInt<T>::value_type MontgomeryInt<T>::DivAndTruncate(
  89. BigInt dividend, BigInt divisor) {
  90. return dividend / divisor;
  91. }
  92. template <typename T>
  93. rlwe::StatusOr<std::string> MontgomeryInt<T>::Serialize(
  94. const Params* params) const {
  95. // Use transcription to transform all the LogModulus() bits of input into a
  96. // vector of unsigned char.
  97. RLWE_ASSIGN_OR_RETURN(
  98. auto v, (TranscribeBits<Int, Uint8>({this->n_}, params->log_modulus,
  99. params->log_modulus, 8)));
  100. // Return a string
  101. return std::string(std::make_move_iterator(v.begin()),
  102. std::make_move_iterator(v.end()));
  103. }
  104. template <typename T>
  105. rlwe::StatusOr<std::string> MontgomeryInt<T>::SerializeVector(
  106. const std::vector<MontgomeryInt>& coeffs, const Params* params) {
  107. if (coeffs.size() > kMaxNumCoeffs) {
  108. return absl::InvalidArgumentError(
  109. absl::StrCat("Number of coefficients, ", coeffs.size(),
  110. ", cannot be larger than ", kMaxNumCoeffs, "."));
  111. } else if (coeffs.empty()) {
  112. return absl::InvalidArgumentError("Cannot serialize an empty vector.");
  113. }
  114. // Bits required to represent modulus.
  115. int bit_size = params->log_modulus;
  116. // Extract the values
  117. std::vector<Int> coeffs_values;
  118. coeffs_values.reserve(coeffs.size());
  119. for (const auto& c : coeffs) {
  120. coeffs_values.push_back(c.n_);
  121. }
  122. // Use transcription to transform all the bit_size bits of input into a
  123. // vector of unsigned char.
  124. RLWE_ASSIGN_OR_RETURN(
  125. auto v,
  126. (TranscribeBits<Int, Uint8>(
  127. coeffs_values, coeffs_values.size() * bit_size, bit_size, 8)));
  128. // Return a string
  129. return std::string(std::make_move_iterator(v.begin()),
  130. std::make_move_iterator(v.end()));
  131. }
  132. template <typename T>
  133. rlwe::StatusOr<MontgomeryInt<T>> MontgomeryInt<T>::Deserialize(
  134. absl::string_view payload, const Params* params) {
  135. // Parse the string as unsigned char
  136. std::vector<Uint8> input(payload.begin(), payload.end());
  137. // Bits required to represent modulus.
  138. int bit_size = params->log_modulus;
  139. // Recover the coefficients from the input stream.
  140. RLWE_ASSIGN_OR_RETURN(auto coeffs_values, (TranscribeBits<Uint8, Int>(
  141. input, bit_size, 8, bit_size)));
  142. // There will be at least one coefficient in coeff_values because bit_size
  143. // is always expected to be positive.
  144. return MontgomeryInt(coeffs_values[0]);
  145. }
  146. template <typename T>
  147. rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
  148. MontgomeryInt<T>::DeserializeVector(int num_coeffs,
  149. absl::string_view serialized,
  150. const Params* params) {
  151. if (num_coeffs < 0) {
  152. return absl::InvalidArgumentError(
  153. "Number of coefficients must be non-negative.");
  154. }
  155. if (num_coeffs > kMaxNumCoeffs) {
  156. return absl::InvalidArgumentError(
  157. absl::StrCat("Number of coefficients, ", num_coeffs, ", cannot be ",
  158. "larger than ", kMaxNumCoeffs, "."));
  159. }
  160. // Parse the string as unsigned char
  161. std::vector<Uint8> input(serialized.begin(), serialized.end());
  162. // Bits required to represent modulus.
  163. int bit_size = params->log_modulus;
  164. // Recover the coefficients from the input stream.
  165. RLWE_ASSIGN_OR_RETURN(
  166. auto coeffs_values,
  167. (TranscribeBits<Uint8, Int>(input, bit_size * num_coeffs, 8, bit_size)));
  168. // Check that the number of coefficients recovered is at least what is
  169. // expected.
  170. if (coeffs_values.size() < num_coeffs) {
  171. return absl::InvalidArgumentError("Given serialization is invalid.");
  172. }
  173. // Create a vector of Montgomery Int from the values.
  174. std::vector<MontgomeryInt> coeffs;
  175. coeffs.reserve(num_coeffs);
  176. for (int i = 0; i < num_coeffs; i++) {
  177. coeffs.push_back(MontgomeryInt(coeffs_values[i]));
  178. }
  179. return coeffs;
  180. }
  181. template <typename T>
  182. std::tuple<T, T> MontgomeryInt<T>::GetConstant(const Params* params) const {
  183. Int constant = ExportInt(params);
  184. Int constant_barrett = static_cast<Int>(
  185. (static_cast<BigInt>(constant) << params->bitsize_int) / params->modulus);
  186. return std::make_tuple(constant, constant_barrett);
  187. }
  188. template <typename T>
  189. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchAdd(
  190. const std::vector<MontgomeryInt>& in1,
  191. const std::vector<MontgomeryInt>& in2, const Params* params) {
  192. std::vector<MontgomeryInt> out = in1;
  193. RLWE_RETURN_IF_ERROR(BatchAddInPlace(&out, in2, params));
  194. return out;
  195. }
  196. template <typename T>
  197. absl::Status MontgomeryInt<T>::BatchAddInPlace(
  198. std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
  199. const Params* params) {
  200. // If the input vectors' sizes don't match, return an error.
  201. if (in1->size() != in2.size()) {
  202. return absl::InvalidArgumentError("Input vectors are not of same size");
  203. }
  204. int i = 0;
  205. // The remaining elements, if any, are added in place sequentially.
  206. for (; i < in1->size(); i++) {
  207. (*in1)[i].AddInPlace(in2[i], params);
  208. }
  209. return absl::OkStatus();
  210. }
  211. template <typename T>
  212. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchAdd(
  213. const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
  214. const Params* params) {
  215. std::vector<MontgomeryInt> out = in1;
  216. RLWE_RETURN_IF_ERROR(BatchAddInPlace(&out, in2, params));
  217. return out;
  218. }
  219. template <typename T>
  220. absl::Status MontgomeryInt<T>::BatchAddInPlace(std::vector<MontgomeryInt>* in1,
  221. const MontgomeryInt& in2,
  222. const Params* params) {
  223. int i = 0;
  224. std::for_each(in1->begin() + i, in1->end(),
  225. [&in2 = in2, params](MontgomeryInt& coeff) {
  226. coeff.AddInPlace(in2, params);
  227. });
  228. return absl::OkStatus();
  229. }
  230. template <typename T>
  231. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchSub(
  232. const std::vector<MontgomeryInt>& in1,
  233. const std::vector<MontgomeryInt>& in2, const Params* params) {
  234. std::vector<MontgomeryInt> out = in1;
  235. RLWE_RETURN_IF_ERROR(BatchSubInPlace(&out, in2, params));
  236. return out;
  237. }
  238. template <typename T>
  239. absl::Status MontgomeryInt<T>::BatchSubInPlace(
  240. std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
  241. const Params* params) {
  242. // If the input vectors' sizes don't match, return an error.
  243. if (in1->size() != in2.size()) {
  244. return absl::InvalidArgumentError("Input vectors are not of same size");
  245. }
  246. int i = 0;
  247. for (; i < in1->size(); i++) {
  248. (*in1)[i].SubInPlace(in2[i], params);
  249. }
  250. return absl::OkStatus();
  251. }
  252. template <typename T>
  253. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchSub(
  254. const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
  255. const Params* params) {
  256. std::vector<MontgomeryInt> out = in1;
  257. RLWE_RETURN_IF_ERROR(BatchSubInPlace(&out, in2, params));
  258. return out;
  259. }
  260. template <typename T>
  261. absl::Status MontgomeryInt<T>::BatchSubInPlace(std::vector<MontgomeryInt>* in1,
  262. const MontgomeryInt& in2,
  263. const Params* params) {
  264. int i = 0;
  265. std::for_each(in1->begin() + i, in1->end(),
  266. [&in2 = in2, params](MontgomeryInt& coeff) {
  267. coeff.SubInPlace(in2, params);
  268. });
  269. return absl::OkStatus();
  270. }
  271. template <typename T>
  272. rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
  273. MontgomeryInt<T>::BatchMulConstant(const std::vector<MontgomeryInt>& in1,
  274. const std::vector<Int>& constant,
  275. const std::vector<Int>& constant_barrett,
  276. const Params* params) {
  277. std::vector<MontgomeryInt> out = in1;
  278. RLWE_RETURN_IF_ERROR(
  279. BatchMulConstantInPlace(&out, constant, constant_barrett, params));
  280. return out;
  281. }
  282. template <typename T>
  283. absl::Status MontgomeryInt<T>::BatchMulConstantInPlace(
  284. std::vector<MontgomeryInt>* in1, const std::vector<Int>& constant,
  285. const std::vector<Int>& constant_barrett, const Params* params) {
  286. // If the input vectors' sizes don't match, return an error.
  287. if (in1->size() != constant.size() ||
  288. constant.size() != constant_barrett.size()) {
  289. return absl::InvalidArgumentError("Input vectors are not of same size");
  290. }
  291. int i = 0;
  292. for (; i < in1->size(); i++) {
  293. (*in1)[i].MulConstantInPlace(constant[i], constant_barrett[i], params);
  294. }
  295. return absl::OkStatus();
  296. }
  297. template <typename T>
  298. rlwe::StatusOr<std::vector<MontgomeryInt<T>>>
  299. MontgomeryInt<T>::BatchMulConstant(const std::vector<MontgomeryInt>& in1,
  300. const Int& constant,
  301. const Int& constant_barrett,
  302. const Params* params) {
  303. std::vector<MontgomeryInt> out = in1;
  304. RLWE_RETURN_IF_ERROR(
  305. BatchMulConstantInPlace(&out, constant, constant_barrett, params));
  306. return out;
  307. }
  308. template <typename T>
  309. absl::Status MontgomeryInt<T>::BatchMulConstantInPlace(
  310. std::vector<MontgomeryInt>* in1, const Int& constant,
  311. const Int& constant_barrett, const Params* params) {
  312. int i = 0;
  313. for (; i < in1->size(); i++) {
  314. (*in1)[i].MulConstantInPlace(constant, constant_barrett, params);
  315. }
  316. return absl::OkStatus();
  317. }
  318. template <typename T>
  319. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchMul(
  320. const std::vector<MontgomeryInt>& in1,
  321. const std::vector<MontgomeryInt>& in2, const Params* params) {
  322. std::vector<MontgomeryInt> out = in1;
  323. RLWE_RETURN_IF_ERROR(BatchMulInPlace(&out, in2, params));
  324. return out;
  325. }
  326. template <typename T>
  327. absl::Status MontgomeryInt<T>::BatchMulInPlace(
  328. std::vector<MontgomeryInt>* in1, const std::vector<MontgomeryInt>& in2,
  329. const Params* params) {
  330. // If the input vectors' sizes don't match, return an error.
  331. if (in1->size() != in2.size()) {
  332. return absl::InvalidArgumentError("Input vectors are not of same size");
  333. }
  334. int i = 0;
  335. for (; i < in1->size(); i++) {
  336. (*in1)[i].MulInPlace(in2[i], params);
  337. }
  338. return absl::OkStatus();
  339. }
  340. template <typename T>
  341. rlwe::StatusOr<std::vector<MontgomeryInt<T>>> MontgomeryInt<T>::BatchMul(
  342. const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
  343. const Params* params) {
  344. std::vector<MontgomeryInt> out = in1;
  345. RLWE_RETURN_IF_ERROR(BatchMulInPlace(&out, in2, params));
  346. return out;
  347. }
  348. template <typename T>
  349. absl::Status MontgomeryInt<T>::BatchMulInPlace(std::vector<MontgomeryInt>* in1,
  350. const MontgomeryInt& in2,
  351. const Params* params) {
  352. int i = 0;
  353. std::for_each(in1->begin() + i, in1->end(),
  354. [&in2 = in2, params](MontgomeryInt& coeff) {
  355. coeff.MulInPlace(in2, params);
  356. });
  357. return absl::OkStatus();
  358. }
  359. template <typename T>
  360. MontgomeryInt<T> MontgomeryInt<T>::ModExp(Int exponent,
  361. const Params* params) const {
  362. MontgomeryInt result = MontgomeryInt::ImportOne(params);
  363. MontgomeryInt base = *this;
  364. // Uses the bits of the exponent to gradually compute the result.
  365. // When bit k of the exponent is 1, the result is multiplied by
  366. // base^{2^k}.
  367. while (exponent > 0) {
  368. // If the current bit (bit k) is 1, multiply base^{2^k} into the result.
  369. if (exponent % 2 == 1) {
  370. result.MulInPlace(base, params);
  371. }
  372. // Update base from base^{2^k} to base^{2^{k+1}}.
  373. base.MulInPlace(base, params);
  374. exponent >>= 1;
  375. }
  376. return result;
  377. }
  378. template <typename T>
  379. MontgomeryInt<T> MontgomeryInt<T>::MultiplicativeInverse(
  380. const Params* params) const {
  381. return (*this).ModExp(static_cast<Int>(params->modulus - 2), params);
  382. }
// Explicit instantiation definitions of MontgomeryIntParams and MontgomeryInt
// for the integral types Uint16, Uint32, Uint64 and absl::uint128.
//
// The member functions above are defined in this .cc file, so these
// definitions are what provide them to other translation units for the listed
// types. EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) presumably attaches
// the library's export/visibility annotation to each instantiation — see the
// shell_encryption_export headers included at the top of this file.
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint16>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint32>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<Uint64>;
template struct EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryIntParams<absl::uint128>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint16>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint32>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<Uint64>;
template class EXPORT_TEMPLATE_DEFINE(SHELL_ENCRYPTION_EXPORT) MontgomeryInt<absl::uint128>;
  393. } // namespace rlwe