distributed_point_function.h 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893
  1. /*
  2. * Copyright 2021 Google LLC
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
  17. #define DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
#include <glog/logging.h>
#include <openssl/cipher.h>

#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/meta/type_traits.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/value_type_helpers.h"
  31. namespace distributed_point_functions {
  32. // Type trait for all supported types. Used to provide meaningful error messages
  33. // in std::enable_if template guards.
  34. template <typename T>
  35. using is_supported_type = dpf_internal::is_supported_type<T>;
  36. template <typename T>
  37. constexpr bool is_supported_type_v = is_supported_type<T>::value;
  38. // Converts a given Value to the template parameter T.
  39. //
  40. // Returns INVALID_ARGUMENT if the conversion fails.
  41. template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
  42. absl::StatusOr<T> FromValue(const Value& value) {
  43. return dpf_internal::ValueTypeHelper<T>::FromValue(value);
  44. }
  45. // ToValue Converts the argument to a Value.
  46. template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
  47. Value ToValue(const T& input) {
  48. return dpf_internal::ValueTypeHelper<T>::ToValue(input);
  49. }
  50. // ToValueType<T> Returns a `ValueType` message describing T.
  51. template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
  52. ValueType ToValueType() {
  53. return dpf_internal::ValueTypeHelper<T>::ToValueType();
  54. }
  55. // Implements key generation and evaluation of distributed point functions.
  56. // A distributed point function (DPF) is parameterized by an index `alpha` and a
  57. // value `beta`. The key generation procedure produces two keys `k_a`, `k_b`.
  58. // Evaluating each key on any point `x` in the DPF domain results in an additive
  59. // secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. This
  60. // class also supports *incremental* DPFs that can additionally be evaluated on
  61. // prefixes of points, resulting in different values `beta_i`for each prefix of
  62. // `alpha`.
  63. class DistributedPointFunction {
  64. public:
  65. // Creates a new instance of a distributed point function that can be
  66. // evaluated only at the output layer.
  67. //
  68. // Returns INVALID_ARGUMENT if the parameters are invalid.
  69. static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> Create(
  70. const DpfParameters& parameters);
  71. // Creates a new instance of an *incremental* DPF that can be evaluated at
  72. // multiple layers. Each parameter set in `parameters` should specify the
  73. // domain size and element size at one of the layers to be evaluated, in
  74. // increasing domain size order. Element sizes must be non-decreasing.
  75. //
  76. // Returns INVALID_ARGUMENT if the parameters are invalid.
  77. static absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
  78. CreateIncremental(absl::Span<const DpfParameters> parameters);
  79. // DistributedPointFunction is neither copyable nor movable.
  80. DistributedPointFunction(const DistributedPointFunction&) = delete;
  81. DistributedPointFunction& operator=(const DistributedPointFunction&) = delete;
  82. // Converts the argument to a `Value` proto. Also registers the corresponding
  83. // value type with the DPF by calling `RegisterValueType<T>()`.
  84. template <typename T>
  85. absl::StatusOr<Value> ToValue(const T& in) {
  86. absl::Status status = RegisterValueType<T>();
  87. if (!status.ok()) {
  88. return status;
  89. }
  90. return distributed_point_functions::ToValue(in);
  91. }
  92. // Registers the template parameter type with this DPF. Note that it is rarely
  93. // necessary to call this function by hand: It is called by `Create` and
  94. // `CreateIncremental` for all unsigned integer types, including
  95. // absl::uint128, and on every call to ToValue<T>. Only call this function
  96. // when passing `Value`s created by other means than ToValue<T>.
  97. //
  98. // Returns OK on success and otherwise an INTERNAL status describing the
  99. // failure.
  100. template <typename T>
  101. absl::Status RegisterValueType() {
  102. return RegisterValueTypeImpl<T>(value_correction_functions_);
  103. }
  104. // Generates a pair of keys for a DPF that evaluates to `beta` when evaluated
  105. // `alpha`. The type of `beta` must match the ValueType passed in `parameters`
  106. // at construction.
  107. //
  108. // This function provides three overloads: One with `absl::uint128` for
  109. // `beta`, which implies the output type is a simple integer; One with a
  110. // `Value` proto for `beta`, which can be used for all supported value types;
  111. // And a templated version that computes the Value by calling ToValue<T> on
  112. // the argument.
  113. //
  114. // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
  115. //
  116. // // Simple integer:
  117. // dpf->GenerateKeys(23, 42);
  118. //
  119. // // Explicit `Value` proto:
  120. // Value value;
  121. // value[1]->mutable_tuple->add_elements()
  122. // ->mutable_integer->set_value_uint64(12);
  123. // value[1]->mutable_tuple->add_elements()
  124. // ->mutable_integer->set_value_uint64(34);
  125. // // Must be called once before calling GenerateKeys for any type that is
  126. // // not a simple integer. The type should match the one in the
  127. // // DpfParameters passed at construction.
  128. // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
  129. // dpf->GenerateKeys(23, value);
  130. //
  131. // // Templated version (no call to RegisterValueType needed):
  132. // dpf->GenerateKeys(23, Tuple<uint32_t, uint64_t>{12, 34});
  133. //
  134. // Returns INVALID_ARGUMENT if used on an incremental DPF with more
  135. // than one set of parameters, if `alpha` is outside of the domain specified
  136. // at construction, or if `beta` does not match the value type passed at
  137. // construction.
  138. // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
  139. // for the type in the `DpfParameters` passed at construction.
  140. // Overload for simple integers.
  141. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
  142. absl::uint128 beta) {
  143. return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
  144. }
  145. // Overload for explicit Value proto.
  146. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
  147. Value beta) {
  148. return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
  149. }
  150. // Template for automatic conversion to Value proto. Disabled if the argument
  151. // is convertible to `absl::uint128` or `Value` to make overloading
  152. // unambiguous.
  153. template <typename T, typename = absl::enable_if_t<
  154. !std::is_convertible<T, absl::uint128>::value &&
  155. !std::is_convertible<T, Value>::value &&
  156. is_supported_type_v<T>>>
  157. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
  158. const T& beta) {
  159. absl::StatusOr<Value> value = ToValue<T>(beta);
  160. if (!value.ok()) {
  161. return value.status();
  162. }
  163. return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&(*value), 1));
  164. }
  165. // Generates a pair of keys for an incremental DPF. For each parameter i
  166. // passed at construction, the DPF evaluates to `beta[i]` at the lowest
  167. // `parameters_[i].log_domain_size()` bits of `alpha`.
  168. //
  169. // Similar to `GenerateKeys`, supports three overloads: One for simple
  170. // integers, passed as an `absl::Span<const absl::uint128>`; One for a span of
  171. // `Value` protos; And a variadic function template that automatically
  172. // converts the passed arguments to a vector of `Value`s.
  173. //
  174. // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
  175. //
  176. // // Simple integers:
  177. // std::vector<absl::uint128> beta{123, 456};
  178. // dpf->GenerateKeysIncremental(23, beta);
  179. //
  180. // // Explicit Value protos:
  181. // std::vector<Value> beta(2);
  182. // value[0]->mutable_integer()->set_value_uint128(42);
  183. // value[1]->mutable_tuple->add_elements()
  184. // ->mutable_integer->set_value_uint64(12);
  185. // value[1]->mutable_tuple->add_elements()
  186. // ->mutable_integer->set_value_uint64(34);
  187. // // Must be called once before calling GenerateKeys for any type that is
  188. // // not a simple integer. The type should match the one in the
  189. // // DpfParameters passed at construction.
  190. // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
  191. // dpf->GenerateKeysIncremental(23, beta);
  192. //
  193. // // Templated version (equivalent to the one above):
  194. // dpf->GenerateKeysIncremental(23, 42, Tuple<uint32_t, uint64_t>{12, 34}));
  195. //
  196. // Returns INVALID_ARGUMENT if `beta.size() != parameters_.size()`, if `alpha`
  197. // is outside of the domain specified at construction, or if `beta` does not
  198. // match the element type passed at construction.
  199. // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
  200. // for all types in the `DpfParameters` passed at construction.
  201. // Overload for simple integers.
  202. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
  203. absl::uint128 alpha, absl::Span<const absl::uint128> beta) {
  204. std::vector<Value> values(beta.size());
  205. for (int i = 0; i < static_cast<int>(beta.size()); ++i) {
  206. absl::StatusOr<Value> value = ToValue(beta[i]);
  207. if (!value.ok()) {
  208. return value.status();
  209. }
  210. values[i] = std::move(*value);
  211. }
  212. return GenerateKeysIncremental(alpha, values);
  213. }
  214. // Overload for Value protos.
  215. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
  216. absl::uint128 alpha, absl::Span<const Value> beta);
  217. // Variadic template version. Disabled if the first argument is convertible to
  218. // a span of `absl::uint128`s or `Value`s to make overloading unambiguous.
  219. template <
  220. typename T0, typename... Tn,
  221. typename = absl::enable_if_t<
  222. !std::is_convertible<T0, absl::Span<const Value>>::value &&
  223. !std::is_convertible<T0, absl::Span<const absl::uint128>>::value &&
  224. absl::conjunction<is_supported_type<T0>,
  225. is_supported_type<Tn>...>::value>>
  226. absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
  227. absl::uint128 alpha, T0&& beta_0, Tn&&... beta_n);
  228. // Returns an `EvaluationContext` for incrementally evaluating the given
  229. // DpfKey.
  230. //
  231. // Returns INVALID_ARGUMENT if `key` doesn't match the parameters given at
  232. // construction.
  233. absl::StatusOr<EvaluationContext> CreateEvaluationContext(DpfKey key) const;
  234. // Evaluates the given `hierarchy_level` of the DPF under all `prefixes`
  235. // passed to this function. If `prefixes` is empty, evaluation starts from the
  236. // seed of `ctx.key`. Otherwise, each element of `prefixes` must fit in the
  237. // domain size of `ctx.previous_hierarchy_level`. Further, `prefixes` may only
  238. // contain extensions of the prefixes passed in the previous call. For
  239. // example, in the following sequence of calls, for each element p2 of
  240. // `prefixes2`, there must be an element p1 of `prefixes1` such that p1 is a
  241. // prefix of p2:
  242. //
  243. // DPF_ASSIGN_OR_RETURN(std::unique_ptr<EvaluationContext> ctx,
  244. // dpf->CreateEvaluationContext(key));
  245. // using T0 = ...;
  246. // DPF_ASSIGN_OR_RETURN(std::vector<T0> evaluations0,
  247. // dpf->EvaluateUntil(0, {}, *ctx));
  248. //
  249. // std::vector<absl::uint128> prefixes1 = ...;
  250. // using T1 = ...;
  251. // DPF_ASSIGN_OR_RETURN(std::vector<T1> evaluations1,
  252. // dpf->EvaluateUntil(1, prefixes1, *ctx));
  253. // ...
  254. // std::vector<absl::uint128> prefixes2 = ...;
  255. // using T2 = ...;
  256. // DPF_ASSIGN_OR_RETURN(std::vector<T2> evaluations2,
  257. // dpf->EvaluateUntil(3, prefixes2, *ctx));
  258. //
  259. // The prefixes are read from the lowest-order bits of the corresponding
  260. // absl::uint128. The number of bits used for each prefix depends on the
  261. // output domain size of the previously evaluated hierarchy level. For
  262. // example, if `ctx` was last evaluated on a hierarchy level with output
  263. // domain size 2**20, then the 20 lowest-order bits of each element in
  264. // `prefixes` are used.
  265. //
  266. // Returns `INVALID_ARGUMENT` if
  267. // - any element of `prefixes` is larger than the next hierarchy level's
  268. // log_domain_size,
  269. // - `prefixes` contains elements that are not extensions of previous
  270. // prefixes, or
  271. // - the bit-size of T doesn't match the next hierarchy level's
  272. // element_bitsize.
  273. template <typename T>
  274. absl::StatusOr<std::vector<T>> EvaluateUntil(
  275. int hierarchy_level, absl::Span<const absl::uint128> prefixes,
  276. EvaluationContext& ctx) const;
  277. template <typename T>
  278. absl::StatusOr<std::vector<T>> EvaluateNext(
  279. absl::Span<const absl::uint128> prefixes, EvaluationContext& ctx) const {
  280. if (prefixes.empty()) {
  281. return EvaluateUntil<T>(0, prefixes, ctx);
  282. } else {
  283. return EvaluateUntil<T>(ctx.previous_hierarchy_level() + 1, prefixes,
  284. ctx);
  285. }
  286. }
  287. // Evaluates a single key at one or multiple points, up to the given
  288. // hierarchy_level. Each element of `evaluation_points` must be within the
  289. // domain of this DPF at `hierarchy_level`.
  290. //
  291. // Example:
  292. //
  293. // DpfKey key = ...;
  294. // std::vector<absl::uint128> evaluation_points = {1, 23, 42};
  295. // // Evaluate `key` on {1, 23, 42}.
  296. // DPF_ASSIGN_OR_RETURN(std::vector<T> result,
  297. // dpf->EvaluateAt(key, 0, evaluation_points);
  298. //
  299. // Returns INVALID_ARGUMENT if `key` is malformed, or if `hierarchy_level` or
  300. // any element of `evaluation_points` is out of range.
  301. template <typename T>
  302. absl::StatusOr<std::vector<T>> EvaluateAt(
  303. const DpfKey& key, int hierarchy_level,
  304. absl::Span<const absl::uint128> evaluation_points) const;
  305. // Returns the DpfParameters of this DPF.
  306. inline absl::Span<const DpfParameters> parameters() const {
  307. return parameters_;
  308. }
  309. private:
  310. // BitVector is a vector of bools. Allows for faster access times than
  311. // std::vector<bool>, as well as inlining if the size is small.
  312. using BitVector =
  313. absl::InlinedVector<bool,
  314. std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;
  315. // Seeds and control bits resulting from a DPF expansion. This type is
  316. // returned by `ExpandSeeds` and `ExpandAndUpdateContext`.
  317. struct DpfExpansion {
  318. std::vector<absl::uint128> seeds;
  319. BitVector control_bits;
  320. };
  321. // A function for computing value corrections. Used as return type in
  322. // `GetValueCorrectionFunction`.
  323. using ValueCorrectionFunction = absl::StatusOr<std::vector<Value>> (*)(
  324. absl::string_view, absl::string_view, int block_index, const Value&,
  325. bool);
  326. // Private constructor, called by `CreateIncremental`.
  327. DistributedPointFunction(
  328. std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
  329. std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
  330. Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
  331. absl::flat_hash_map<std::string, ValueCorrectionFunction>
  332. value_correction_functions);
  333. // Computes the value correction for the given `hierarchy_level`, `seeds`,
  334. // index `alpha` and value `beta`. If `invert` is true, the individual values
  335. // in the returned block are multiplied element-wise by -1. Expands `seeds`
  336. // using `prg_ctx_value_`, then calls the function returned by
  337. // `GetValueCorrectionFunction(parameters_[hierarchy_level])` to obtain the
  338. // value correction words.
  339. //
  340. // Returns multiple values in the case of packing, and a single Value
  341. // otherwise.
  342. //
  343. // Returns INTERNAL in case the PRG expansion fails, and UNIMPLEMENTED if
  344. // `element_bitsize` is not supported.
  345. absl::StatusOr<std::vector<Value>> ComputeValueCorrection(
  346. int hierarchy_level, absl::Span<const absl::uint128> seeds,
  347. absl::uint128 alpha, const Value& beta, bool invert) const;
  348. // Expands the PRG seeds at the next `tree_level` for an incremental DPF with
  349. // index `alpha` and values `beta`, updates `seeds` and `control_bits`, and
  350. // writes the next correction word to `keys`. Called from
  351. // `GenerateKeysIncremental`.
  352. absl::Status GenerateNext(int tree_level, absl::uint128 alpha,
  353. absl::Span<const Value> beta,
  354. absl::Span<absl::uint128> seeds,
  355. absl::Span<bool> control_bits,
  356. absl::Span<DpfKey> keys) const;
  357. // Computes the tree index (representing a path in the FSS tree) from the
  358. // given `domain_index` and `hierarchy_level`. Does NOT check whether the
  359. // given domain index fits in the domain at `hierarchy_level`.
  360. absl::uint128 DomainToTreeIndex(absl::uint128 domain_index,
  361. int hierarchy_level) const;
  362. // Computes the block index (pointing to an element in a batched 128-bit
  363. // block) from the given `domain_index` and `hierarchy_level`. Does NOT check
  364. // whether the given domain index fits in the domain at `hierarchy_level`.
  365. int DomainToBlockIndex(absl::uint128 domain_index, int hierarchy_level) const;
  366. // Performs DPF evaluation of the given `partial_evaluations` using
  367. // prg_ctx_left_ or prg_ctx_right_, and the given `correction_words`. At each
  368. // level `l < correction_words.size()`, the evaluation for the i-th seed in
  369. // `partial_evaluations` continues along the left or right path depending on
  370. // the l-th most significant bit among the lowest `correction_words.size()`
  371. // bits of `paths[i]`.
  372. //
  373. // Returns INVALID_ARGUMENT if the input sizes don't match.
  374. // Returns INTERNAL in case of OpenSSL errors.
  375. absl::StatusOr<DpfExpansion> EvaluateSeeds(
  376. DpfExpansion partial_evaluations, absl::Span<const absl::uint128> paths,
  377. absl::Span<const CorrectionWord* const> correction_words) const;
  378. // Performs DPF expansion of the given `partial_evaluations` using
  379. // prg_ctx_left_ and prg_ctx_right_, and the given `correction_words`. In more
  380. // detail, each of the partial evaluations is subjected to a full subtree
  381. // expansion of `correction_words.size()` levels, and the concatenated result
  382. // is provided in the response. The result contains
  383. // `(partial_evaluations.size() * (2^correction_words.size())` evaluations in
  384. // a single `DpfExpansion`.
  385. //
  386. // Returns INTERNAL in case of OpenSSL errors.
  387. absl::StatusOr<DpfExpansion> ExpandSeeds(
  388. const DpfExpansion& partial_evaluations,
  389. absl::Span<const CorrectionWord* const> correction_words) const;
  390. // Computes partial evaluations of the paths to `prefixes` to be used as the
  391. // starting point of the expansion of `ctx`. If `update_ctx == true`, saves
  392. // the partial evaluations of `ctx.previous_hierarchy_level` to `ctx` and sets
  393. // `ctx.partial_evaluations_level` to `ctx.previous_hierarchy_level`.
  394. // Called by `ExpandAndUpdateContext`.
  395. //
  396. // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
  397. // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
  398. // duplicate seeds.
  399. absl::StatusOr<DpfExpansion> ComputePartialEvaluations(
  400. absl::Span<const absl::uint128> prefixes, bool update_ctx,
  401. EvaluationContext& ctx) const;
  402. // Extracts the seeds for the given `prefixes` from `ctx` and expands them as
  403. // far as needed for the next hierarchy level. Returns the result as a
  404. // `DpfExpansion`. Called by `EvaluateUntil`, where the expanded seeds are
  405. // corrected to obtain output values.
  406. // After expansion, `ctx.hierarchy_level()` is increased. If this isn't the
  407. // last expansion, the expanded seeds are also saved in `ctx` for the next
  408. // expansion.
  409. //
  410. // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
  411. // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
  412. // duplicate seeds. Returns INTERNAL in case of OpenSSL errors.
  413. absl::StatusOr<DpfExpansion> ExpandAndUpdateContext(
  414. int hierarchy_level, absl::Span<const absl::uint128> prefixes,
  415. EvaluationContext& ctx) const;
  416. // Compute output PRG value of expanded seeds using prg_ctx_value_.
  417. // Returns blocks_needed_[hierarchy_level] * expansion.seeds.size() blocks,
  418. // where every blocks_needed_[hierarchy_level] correspond to the hash of an
  419. // input seed.
  420. //
  421. // Returns INTERNAL in case of OpenSSL errors.
  422. absl::StatusOr<std::vector<absl::uint128>> HashExpandedSeeds(
  423. int hierarchy_level, absl::Span<const absl::uint128> expansion) const;
  424. // Deterministically serializes the given value_type.
  425. //
  426. // Returns OK on success and INTERNAL in case serialization fails.
  427. static absl::StatusOr<std::string> SerializeValueTypeDeterministically(
  428. const ValueType& value_type);
  429. // Returns the value correction function for the given parameters.
  430. // For all value types except unsigned integers, these functions have to be
  431. // first registered using RegisterValueType<T>.
  432. //
  433. // Returns UNIMPLEMENTED if no matching function was registered.
  434. absl::StatusOr<ValueCorrectionFunction> GetValueCorrectionFunction(
  435. const DpfParameters& parameters) const;
  436. // Static implementation of RegisterValueType<T>, so we can call it from
  437. // `Create`.
  438. template <typename T>
  439. static absl::Status RegisterValueTypeImpl(
  440. absl::flat_hash_map<std::string, ValueCorrectionFunction>&
  441. value_correction_functions);
  442. // Used to validate DpfParameters, DpfKey and EvaluationContext protos.
  443. const std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
  444. // DP parameters passed to the factory function. Contains the domain size and
  445. // element size for hierarchy level of the incremental DPF. Owned by
  446. // proto_validator_.
  447. const absl::Span<const DpfParameters> parameters_;
  448. // Number of levels in the evaluation tree. This is always less than or equal
  449. // to the largest log_domain_size in parameters_.
  450. const int tree_levels_needed_;
  451. // Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
  452. // of parameters_).
  453. const absl::flat_hash_map<int, int>& tree_to_hierarchy_;
  454. // The inverse of tree_to_hierarchy_.
  455. const std::vector<int>& hierarchy_to_tree_;
  456. // Cached numbers of AES blocks needed for value correction at each hierarchy
  457. // level.
  458. const std::vector<int> blocks_needed_;
  459. // Pseudorandom generator used for seed expansion (left and right), and value
  460. // correction. The PRG G(x) for hierarchy level i is defined as the
  461. // concatenation of
  462. //
  463. // H_left(x), H_right(x), H_value(x + 0), ..., H_value(x + k-1)
  464. //
  465. // where k is equal to blocks_needed_[i], and H_*(x) is the evaluation of
  466. // prg_*_ on input x.
  467. const Aes128FixedKeyHash prg_left_;
  468. const Aes128FixedKeyHash prg_right_;
  469. const Aes128FixedKeyHash prg_value_;
  470. // Maps serialized `ValueType` messages to the correct value correction
  471. // functions. Map values are instantiations of
  472. // `dpf_internal::ComputeValueCorrectionFor`. Relies on protobuf's
  473. // deterministic serialization feature. This has the caveat that messages with
  474. // unknown fields are not supported. However, as long as `ValueType` consists
  475. // of a single `oneof` field, this is fine, since we either know the value
  476. // type and have deterministic serialization because the `ValueType` can only
  477. // contain one field, or we don't know the type and wouldn't be able to
  478. // correct values for it anyway.
  479. absl::flat_hash_map<std::string, ValueCorrectionFunction>
  480. value_correction_functions_;
  481. };
  482. //========================//
  483. // Implementation Details //
  484. //========================//
  485. template <typename T>
  486. absl::Status DistributedPointFunction::RegisterValueTypeImpl(
  487. absl::flat_hash_map<std::string, ValueCorrectionFunction>&
  488. value_correction_functions) {
  489. ValueType value_type = ToValueType<T>();
  490. absl::StatusOr<std::string> serialized_value_type =
  491. SerializeValueTypeDeterministically(value_type);
  492. if (!serialized_value_type.ok()) {
  493. return serialized_value_type.status();
  494. }
  495. value_correction_functions[*serialized_value_type] =
  496. dpf_internal::ComputeValueCorrectionFor<T>;
  497. return absl::OkStatus();
  498. }
  499. template <typename T0, typename... Tn, typename /*= absl::enable_if_t<...>*/>
  500. absl::StatusOr<std::pair<DpfKey, DpfKey>>
  501. DistributedPointFunction::GenerateKeysIncremental(absl::uint128 alpha,
  502. T0&& beta_0, Tn&&... beta_n) {
  503. // Convert the first element of beta. We need to treat it separately to be
  504. // able to check its type in the enable_if above.
  505. absl::StatusOr<Value> value = ToValue(beta_0);
  506. if (!value.ok()) {
  507. return value.status();
  508. }
  509. std::vector<Value> values = {std::move(*value)};
  510. values.reserve(1 + sizeof...(beta_n));
  511. // Convert all values in the parameter pack, stopping at the first error.
  512. absl::Status status = absl::OkStatus();
  513. // We create an unused std::tuple<Tn...> here, because its braced-initializer
  514. // list constructor allows us to operate on beta_n in a well-defined order. In
  515. // C++17, this could be replaced by a fold expression instead.
  516. std::tuple<Tn...>{[this, &status, &values, &value](auto&& beta_i) -> Tn {
  517. if (status.ok()) {
  518. value = this->ToValue(beta_i);
  519. if (value.ok()) {
  520. values.push_back(std::move(*value));
  521. } else {
  522. status = value.status();
  523. }
  524. }
  525. return Tn{};
  526. }(beta_n)...};
  527. // Return if there was an error during conversion, otherwise generate keys.
  528. if (!status.ok()) {
  529. return status;
  530. }
  531. return GenerateKeysIncremental(alpha, values);
  532. }
  533. template <typename T>
  534. absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
  535. int hierarchy_level, absl::Span<const absl::uint128> prefixes,
  536. EvaluationContext& ctx) const {
  537. absl::Status status = proto_validator_->ValidateEvaluationContext(ctx);
  538. if (!status.ok()) {
  539. return status;
  540. }
  541. if (hierarchy_level < 0 ||
  542. hierarchy_level >= static_cast<int>(parameters_.size())) {
  543. return absl::InvalidArgumentError(
  544. "`hierarchy_level` must be non-negative and less than "
  545. "parameters_.size()");
  546. }
  547. absl::StatusOr<bool> types_are_equal = dpf_internal::ValueTypesAreEqual(
  548. ToValueType<T>(), parameters_[hierarchy_level].value_type());
  549. if (!types_are_equal.ok()) {
  550. return types_are_equal.status();
  551. } else if (!*types_are_equal) {
  552. return absl::InvalidArgumentError(
  553. "Value type T doesn't match parameters at `hierarchy_level`");
  554. }
  555. if (hierarchy_level <= ctx.previous_hierarchy_level()) {
  556. return absl::InvalidArgumentError(
  557. "`hierarchy_level` must be greater than "
  558. "`ctx.previous_hierarchy_level`");
  559. }
  560. if ((ctx.previous_hierarchy_level() < 0) != (prefixes.empty())) {
  561. return absl::InvalidArgumentError(
  562. "`prefixes` must be empty if and only if this is the first call with "
  563. "`ctx`.");
  564. }
  565. int previous_log_domain_size = 0;
  566. int previous_hierarchy_level = ctx.previous_hierarchy_level();
  567. if (!prefixes.empty()) {
  568. DCHECK(ctx.previous_hierarchy_level() >= 0);
  569. previous_log_domain_size =
  570. parameters_[previous_hierarchy_level].log_domain_size();
  571. for (absl::uint128 prefix : prefixes) {
  572. if (previous_log_domain_size < 128 &&
  573. prefix >= (absl::uint128{1} << previous_log_domain_size)) {
  574. return absl::InvalidArgumentError(
  575. absl::StrFormat("Index %d out of range for hierarchy level %d",
  576. prefix, previous_hierarchy_level));
  577. }
  578. }
  579. }
  580. int64_t prefixes_size = static_cast<int64_t>(prefixes.size());
  581. int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  582. if (log_domain_size - previous_log_domain_size > 62) {
  583. return absl::InvalidArgumentError(
  584. "Output size would be larger than 2**62. Please evaluate fewer "
  585. "hierarchy levels at once.");
  586. }
  587. // The `prefixes` passed in by the caller refer to the domain of the previous
  588. // hierarchy level. However, because we batch multiple elements of type T in a
  589. // single uint128 block, multiple prefixes can actually refer to the same
  590. // block in the FSS evaluation tree. On a high level, our approach is as
  591. // follows:
  592. //
  593. // 1. Split up each element of `prefixes` into a tree index, pointing to a
  594. // block in the FSS tree, and a block index, pointing to an element of type
  595. // T in that block.
  596. //
  597. // 2. Compute a list of unique `tree_indices`, and for each original prefix,
  598. // remember the position of the corresponding tree index in `tree_indices`.
  599. //
  600. // 3. After expanding the unique `tree_indices`, use the positions saved in
  601. // Step (2) together with the corresponding block index to retrieve the
  602. // expanded values for each prefix, and return them in the same order as
  603. // `prefixes`.
  604. //
  605. // `tree_indices` holds the unique tree indices from `prefixes`, to be passed
  606. // to `ExpandAndUpdateContext`.
  607. std::vector<absl::uint128> tree_indices;
  608. tree_indices.reserve(prefixes_size);
  609. // `tree_indices_inverse` is the inverse of `tree_indices`, used for
  610. // deduplicating and constructing `prefix_map`. Use a btree_map because we
  611. // expect `prefixes` (and thus `tree_indices`) to be sorted.
  612. absl::btree_map<absl::uint128, int64_t> tree_indices_inverse;
  613. // `prefix_map` maps each i < prefixes.size() to an element of `tree_indices`
  614. // and a block index. Used to select which elements to return after the
  615. // expansion, to ensure the result is ordered the same way as `prefixes`.
  616. std::vector<std::pair<int64_t, int>> prefix_map;
  617. prefix_map.reserve(prefixes_size);
  618. for (int64_t i = 0; i < prefixes_size; ++i) {
  619. absl::uint128 tree_index =
  620. DomainToTreeIndex(prefixes[i], previous_hierarchy_level);
  621. int block_index = DomainToBlockIndex(prefixes[i], previous_hierarchy_level);
  622. // Check if `tree_index` already exists in `tree_indices`.
  623. auto previous_size = tree_indices_inverse.size();
  624. auto it = tree_indices_inverse.try_emplace(tree_indices_inverse.end(),
  625. tree_index, tree_indices.size());
  626. if (tree_indices_inverse.size() > previous_size) {
  627. tree_indices.push_back(tree_index);
  628. }
  629. prefix_map.push_back(std::make_pair(it->second, block_index));
  630. }
  631. // Perform expansion of unique `tree_indices`.
  632. absl::StatusOr<DpfExpansion> expansion =
  633. ExpandAndUpdateContext(hierarchy_level, tree_indices, ctx);
  634. if (!expansion.ok()) {
  635. return expansion.status();
  636. }
  637. // Hash the expanded seeds.
  638. absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
  639. HashExpandedSeeds(hierarchy_level, expansion->seeds);
  640. if (!hashed_expansion.ok()) {
  641. return hashed_expansion.status();
  642. }
  643. // Get output correction word from `ctx`.
  644. constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
  645. const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
  646. if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
  647. value_correction =
  648. &(ctx.key()
  649. .correction_words(hierarchy_to_tree_[hierarchy_level])
  650. .value_correction());
  651. } else {
  652. // Last level value correction is stored in an extra proto field, since we
  653. // have one less correction word than tree levels.
  654. value_correction = &(ctx.key().last_level_value_correction());
  655. }
  656. // Split output correction into elements of type T.
  657. absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
  658. dpf_internal::ValuesToArray<T>(*value_correction);
  659. if (!correction_ints.ok()) {
  660. return correction_ints.status();
  661. }
  662. // Compute value corrections for each block in `expanded_seeds`. We have to
  663. // account for the fact that blocks might not be full (i.e., have less than
  664. // elements_per_block elements).
  665. const int corrected_elements_per_block =
  666. 1 << (parameters_[hierarchy_level].log_domain_size() -
  667. hierarchy_to_tree_[hierarchy_level]);
  668. const auto expansion_size = static_cast<int64_t>(expansion->seeds.size());
  669. const int blocks_needed = blocks_needed_[hierarchy_level];
  670. DCHECK(corrected_elements_per_block <= elements_per_block);
  671. std::vector<T> corrected_expansion(expansion_size *
  672. corrected_elements_per_block);
  673. for (int64_t i = 0; i < expansion_size; ++i) {
  674. std::array<T, elements_per_block> current_elements =
  675. dpf_internal::ConvertBytesToArrayOf<T>(
  676. absl::string_view(reinterpret_cast<const char*>(
  677. &(*hashed_expansion)[i * blocks_needed]),
  678. blocks_needed * sizeof(absl::uint128)));
  679. for (int j = 0; j < corrected_elements_per_block; ++j) {
  680. if (expansion->control_bits[i]) {
  681. current_elements[j] += (*correction_ints)[j];
  682. }
  683. if (ctx.key().party() == 1) {
  684. current_elements[j] = -current_elements[j];
  685. }
  686. corrected_expansion[i * corrected_elements_per_block + j] =
  687. current_elements[j];
  688. }
  689. }
  690. // Compute the number of outputs we will have. For each prefix, we will have a
  691. // full expansion from the previous heirarchy level to the current heirarchy
  692. // level.
  693. DCHECK(log_domain_size - previous_log_domain_size < 63);
  694. int64_t outputs_per_prefix = int64_t{1}
  695. << (log_domain_size - previous_log_domain_size);
  696. if (prefixes.empty()) {
  697. // If prefixes is empty (i.e., this is the first evaluation of `ctx`), just
  698. // return the expansion.
  699. DCHECK(static_cast<int>(corrected_expansion.size()) == outputs_per_prefix);
  700. return corrected_expansion;
  701. } else {
  702. // Otherwise, only return elements under `prefixes`.
  703. int blocks_per_tree_prefix = expansion->seeds.size() / tree_indices.size();
  704. std::vector<T> result(prefixes_size * outputs_per_prefix);
  705. for (int64_t i = 0; i < prefixes_size; ++i) {
  706. int64_t prefix_expansion_start =
  707. prefix_map[i].first * blocks_per_tree_prefix *
  708. corrected_elements_per_block +
  709. prefix_map[i].second * outputs_per_prefix;
  710. std::copy_n(&corrected_expansion[prefix_expansion_start],
  711. outputs_per_prefix, &result[i * outputs_per_prefix]);
  712. }
  713. return result;
  714. }
  715. }
  716. template <typename T>
  717. absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAt(
  718. const DpfKey& key, int hierarchy_level,
  719. absl::Span<const absl::uint128> evaluation_points) const {
  720. auto num_evaluation_points = static_cast<int64_t>(evaluation_points.size());
  721. if (hierarchy_level < 0) {
  722. return absl::InvalidArgumentError("`hierarchy_level` must be non-negative");
  723. }
  724. if (hierarchy_level >= static_cast<int>(parameters_.size())) {
  725. return absl::InvalidArgumentError(
  726. "`hierarchy_level` must be less than the number of parameters passed "
  727. "at construction");
  728. }
  729. absl::Status status = proto_validator_->ValidateDpfKey(key);
  730. if (!status.ok()) {
  731. return status;
  732. }
  733. // Get output correction word from `key`.
  734. constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
  735. const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
  736. if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
  737. value_correction =
  738. &(key.correction_words(hierarchy_to_tree_[hierarchy_level])
  739. .value_correction());
  740. } else {
  741. // Last level value correction is stored in an extra proto field, since we
  742. // have one less correction word than tree levels.
  743. value_correction = &(key.last_level_value_correction());
  744. }
  745. // Split output correction into elements of type T, and save it in
  746. // correction_ints.
  747. absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
  748. dpf_internal::ValuesToArray<T>(*value_correction);
  749. if (!correction_ints.ok()) {
  750. return correction_ints.status();
  751. }
  752. // Split up evaluation_points into tree indices and block indices, if we're
  753. // operating on a packed type. Otherwise set `tree_indices` to
  754. // `evaluation_points`.
  755. std::vector<absl::uint128> maybe_recomputed_tree_indices(0);
  756. absl::Span<const absl::uint128> tree_indices;
  757. if (elements_per_block > 1) {
  758. maybe_recomputed_tree_indices.reserve(num_evaluation_points);
  759. for (int64_t i = 0; i < num_evaluation_points; ++i) {
  760. maybe_recomputed_tree_indices.push_back(
  761. DomainToTreeIndex(evaluation_points[i], hierarchy_level));
  762. }
  763. tree_indices = absl::MakeConstSpan(maybe_recomputed_tree_indices);
  764. } else {
  765. // This avoids copying the evaluation points when elements_per_block == 1.
  766. tree_indices = evaluation_points;
  767. }
  768. // Extract seed and party for DPF evaluation.
  769. absl::uint128 seed = absl::MakeUint128(key.seed().high(), key.seed().low());
  770. bool party = key.party();
  771. DpfExpansion inputs;
  772. inputs.seeds.resize(num_evaluation_points, seed);
  773. inputs.control_bits.resize(num_evaluation_points, party);
  774. // Evaluate DPFs.
  775. const int stop_level = hierarchy_to_tree_[hierarchy_level];
  776. auto correction_words =
  777. absl::MakeConstSpan(key.correction_words()).subspan(0, stop_level);
  778. absl::StatusOr<DpfExpansion> evaluated_inputs =
  779. EvaluateSeeds(std::move(inputs), tree_indices, correction_words);
  780. if (!evaluated_inputs.ok()) {
  781. return evaluated_inputs.status();
  782. }
  783. DCHECK(static_cast<int64_t>(evaluated_inputs->seeds.size()) ==
  784. num_evaluation_points);
  785. // Hash DPF evaluations.
  786. absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
  787. HashExpandedSeeds(hierarchy_level, evaluated_inputs->seeds);
  788. if (!hashed_expansion.ok()) {
  789. return hashed_expansion.status();
  790. }
  791. // Perform value correction.
  792. std::vector<T> result;
  793. result.reserve(num_evaluation_points);
  794. const int blocks_needed = blocks_needed_[hierarchy_level];
  795. for (int64_t i = 0; i < num_evaluation_points; ++i) {
  796. std::array<T, elements_per_block> current_elements =
  797. dpf_internal::ConvertBytesToArrayOf<T>(
  798. absl::string_view(reinterpret_cast<const char*>(
  799. &(*hashed_expansion)[i * blocks_needed]),
  800. blocks_needed * sizeof(absl::uint128)));
  801. int block_index = 0;
  802. if (elements_per_block > 1) {
  803. block_index = DomainToBlockIndex(evaluation_points[i], hierarchy_level);
  804. }
  805. result.push_back(current_elements[block_index]);
  806. if (evaluated_inputs->control_bits[i]) {
  807. result[i] += (*correction_ints)[block_index];
  808. }
  809. if (party == 1) {
  810. result[i] = -result[i];
  811. }
  812. }
  813. return result;
  814. }
  815. } // namespace distributed_point_functions
  816. #endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_