12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893 |
- /*
- * Copyright 2021 Google LLC
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
- #define DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
- #include <glog/logging.h>
- #include <openssl/cipher.h>
- #include <memory>
- #include <type_traits>
- #include "absl/container/btree_map.h"
- #include "absl/container/flat_hash_map.h"
- #include "absl/meta/type_traits.h"
- #include "absl/status/statusor.h"
- #include "absl/strings/str_format.h"
- #include "dpf/aes_128_fixed_key_hash.h"
- #include "dpf/distributed_point_function.pb.h"
- #include "dpf/internal/proto_validator.h"
- #include "dpf/internal/value_type_helpers.h"
- namespace distributed_point_functions {
- // Type trait for all supported types. Used to provide meaningful error messages
- // in std::enable_if template guards.
- template <typename T>
- using is_supported_type = dpf_internal::is_supported_type<T>;
- template <typename T>
- constexpr bool is_supported_type_v = is_supported_type<T>::value;
- // Converts a given Value to the template parameter T.
- //
- // Returns INVALID_ARGUMENT if the conversion fails.
- template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
- absl::StatusOr<T> FromValue(const Value& value) {
- return dpf_internal::ValueTypeHelper<T>::FromValue(value);
- }
- // ToValue Converts the argument to a Value.
- template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
- Value ToValue(const T& input) {
- return dpf_internal::ValueTypeHelper<T>::ToValue(input);
- }
- // ToValueType<T> Returns a `ValueType` message describing T.
- template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
- ValueType ToValueType() {
- return dpf_internal::ValueTypeHelper<T>::ToValueType();
- }
- // Implements key generation and evaluation of distributed point functions.
- // A distributed point function (DPF) is parameterized by an index `alpha` and a
- // value `beta`. The key generation procedure produces two keys `k_a`, `k_b`.
- // Evaluating each key on any point `x` in the DPF domain results in an additive
- // secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. This
- // class also supports *incremental* DPFs that can additionally be evaluated on
- // prefixes of points, resulting in different values `beta_i`for each prefix of
- // `alpha`.
- class DistributedPointFunction {
- public:
- // Creates a new instance of a distributed point function that can be
- // evaluated only at the output layer.
- //
- // Returns INVALID_ARGUMENT if the parameters are invalid.
- static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> Create(
- const DpfParameters& parameters);
- // Creates a new instance of an *incremental* DPF that can be evaluated at
- // multiple layers. Each parameter set in `parameters` should specify the
- // domain size and element size at one of the layers to be evaluated, in
- // increasing domain size order. Element sizes must be non-decreasing.
- //
- // Returns INVALID_ARGUMENT if the parameters are invalid.
- static absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
- CreateIncremental(absl::Span<const DpfParameters> parameters);
- // DistributedPointFunction is neither copyable nor movable.
- DistributedPointFunction(const DistributedPointFunction&) = delete;
- DistributedPointFunction& operator=(const DistributedPointFunction&) = delete;
- // Converts the argument to a `Value` proto. Also registers the corresponding
- // value type with the DPF by calling `RegisterValueType<T>()`.
- template <typename T>
- absl::StatusOr<Value> ToValue(const T& in) {
- absl::Status status = RegisterValueType<T>();
- if (!status.ok()) {
- return status;
- }
- return distributed_point_functions::ToValue(in);
- }
- // Registers the template parameter type with this DPF. Note that it is rarely
- // necessary to call this function by hand: It is called by `Create` and
- // `CreateIncremental` for all unsigned integer types, including
- // absl::uint128, and on every call to ToValue<T>. Only call this function
- // when passing `Value`s created by other means than ToValue<T>.
- //
- // Returns OK on success and otherwise an INTERNAL status describing the
- // failure.
- template <typename T>
- absl::Status RegisterValueType() {
- return RegisterValueTypeImpl<T>(value_correction_functions_);
- }
- // Generates a pair of keys for a DPF that evaluates to `beta` when evaluated
- // `alpha`. The type of `beta` must match the ValueType passed in `parameters`
- // at construction.
- //
- // This function provides three overloads: One with `absl::uint128` for
- // `beta`, which implies the output type is a simple integer; One with a
- // `Value` proto for `beta`, which can be used for all supported value types;
- // And a templated version that computes the Value by calling ToValue<T> on
- // the argument.
- //
- // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
- //
- // // Simple integer:
- // dpf->GenerateKeys(23, 42);
- //
- // // Explicit `Value` proto:
- // Value value;
- // value[1]->mutable_tuple->add_elements()
- // ->mutable_integer->set_value_uint64(12);
- // value[1]->mutable_tuple->add_elements()
- // ->mutable_integer->set_value_uint64(34);
- // // Must be called once before calling GenerateKeys for any type that is
- // // not a simple integer. The type should match the one in the
- // // DpfParameters passed at construction.
- // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
- // dpf->GenerateKeys(23, value);
- //
- // // Templated version (no call to RegisterValueType needed):
- // dpf->GenerateKeys(23, Tuple<uint32_t, uint64_t>{12, 34});
- //
- // Returns INVALID_ARGUMENT if used on an incremental DPF with more
- // than one set of parameters, if `alpha` is outside of the domain specified
- // at construction, or if `beta` does not match the value type passed at
- // construction.
- // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
- // for the type in the `DpfParameters` passed at construction.
- // Overload for simple integers.
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
- absl::uint128 beta) {
- return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
- }
- // Overload for explicit Value proto.
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
- Value beta) {
- return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
- }
- // Template for automatic conversion to Value proto. Disabled if the argument
- // is convertible to `absl::uint128` or `Value` to make overloading
- // unambiguous.
- template <typename T, typename = absl::enable_if_t<
- !std::is_convertible<T, absl::uint128>::value &&
- !std::is_convertible<T, Value>::value &&
- is_supported_type_v<T>>>
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
- const T& beta) {
- absl::StatusOr<Value> value = ToValue<T>(beta);
- if (!value.ok()) {
- return value.status();
- }
- return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&(*value), 1));
- }
- // Generates a pair of keys for an incremental DPF. For each parameter i
- // passed at construction, the DPF evaluates to `beta[i]` at the lowest
- // `parameters_[i].log_domain_size()` bits of `alpha`.
- //
- // Similar to `GenerateKeys`, supports three overloads: One for simple
- // integers, passed as an `absl::Span<const absl::uint128>`; One for a span of
- // `Value` protos; And a variadic function template that automatically
- // converts the passed arguments to a vector of `Value`s.
- //
- // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
- //
- // // Simple integers:
- // std::vector<absl::uint128> beta{123, 456};
- // dpf->GenerateKeysIncremental(23, beta);
- //
- // // Explicit Value protos:
- // std::vector<Value> beta(2);
- // value[0]->mutable_integer()->set_value_uint128(42);
- // value[1]->mutable_tuple->add_elements()
- // ->mutable_integer->set_value_uint64(12);
- // value[1]->mutable_tuple->add_elements()
- // ->mutable_integer->set_value_uint64(34);
- // // Must be called once before calling GenerateKeys for any type that is
- // // not a simple integer. The type should match the one in the
- // // DpfParameters passed at construction.
- // dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
- // dpf->GenerateKeysIncremental(23, beta);
- //
- // // Templated version (equivalent to the one above):
- // dpf->GenerateKeysIncremental(23, 42, Tuple<uint32_t, uint64_t>{12, 34}));
- //
- // Returns INVALID_ARGUMENT if `beta.size() != parameters_.size()`, if `alpha`
- // is outside of the domain specified at construction, or if `beta` does not
- // match the element type passed at construction.
- // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
- // for all types in the `DpfParameters` passed at construction.
- // Overload for simple integers.
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
- absl::uint128 alpha, absl::Span<const absl::uint128> beta) {
- std::vector<Value> values(beta.size());
- for (int i = 0; i < static_cast<int>(beta.size()); ++i) {
- absl::StatusOr<Value> value = ToValue(beta[i]);
- if (!value.ok()) {
- return value.status();
- }
- values[i] = std::move(*value);
- }
- return GenerateKeysIncremental(alpha, values);
- }
- // Overload for Value protos.
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
- absl::uint128 alpha, absl::Span<const Value> beta);
- // Variadic template version. Disabled if the first argument is convertible to
- // a span of `absl::uint128`s or `Value`s to make overloading unambiguous.
- template <
- typename T0, typename... Tn,
- typename = absl::enable_if_t<
- !std::is_convertible<T0, absl::Span<const Value>>::value &&
- !std::is_convertible<T0, absl::Span<const absl::uint128>>::value &&
- absl::conjunction<is_supported_type<T0>,
- is_supported_type<Tn>...>::value>>
- absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
- absl::uint128 alpha, T0&& beta_0, Tn&&... beta_n);
- // Returns an `EvaluationContext` for incrementally evaluating the given
- // DpfKey.
- //
- // Returns INVALID_ARGUMENT if `key` doesn't match the parameters given at
- // construction.
- absl::StatusOr<EvaluationContext> CreateEvaluationContext(DpfKey key) const;
- // Evaluates the given `hierarchy_level` of the DPF under all `prefixes`
- // passed to this function. If `prefixes` is empty, evaluation starts from the
- // seed of `ctx.key`. Otherwise, each element of `prefixes` must fit in the
- // domain size of `ctx.previous_hierarchy_level`. Further, `prefixes` may only
- // contain extensions of the prefixes passed in the previous call. For
- // example, in the following sequence of calls, for each element p2 of
- // `prefixes2`, there must be an element p1 of `prefixes1` such that p1 is a
- // prefix of p2:
- //
- // DPF_ASSIGN_OR_RETURN(std::unique_ptr<EvaluationContext> ctx,
- // dpf->CreateEvaluationContext(key));
- // using T0 = ...;
- // DPF_ASSIGN_OR_RETURN(std::vector<T0> evaluations0,
- // dpf->EvaluateUntil(0, {}, *ctx));
- //
- // std::vector<absl::uint128> prefixes1 = ...;
- // using T1 = ...;
- // DPF_ASSIGN_OR_RETURN(std::vector<T1> evaluations1,
- // dpf->EvaluateUntil(1, prefixes1, *ctx));
- // ...
- // std::vector<absl::uint128> prefixes2 = ...;
- // using T2 = ...;
- // DPF_ASSIGN_OR_RETURN(std::vector<T2> evaluations2,
- // dpf->EvaluateUntil(3, prefixes2, *ctx));
- //
- // The prefixes are read from the lowest-order bits of the corresponding
- // absl::uint128. The number of bits used for each prefix depends on the
- // output domain size of the previously evaluated hierarchy level. For
- // example, if `ctx` was last evaluated on a hierarchy level with output
- // domain size 2**20, then the 20 lowest-order bits of each element in
- // `prefixes` are used.
- //
- // Returns `INVALID_ARGUMENT` if
- // - any element of `prefixes` is larger than the next hierarchy level's
- // log_domain_size,
- // - `prefixes` contains elements that are not extensions of previous
- // prefixes, or
- // - the bit-size of T doesn't match the next hierarchy level's
- // element_bitsize.
- template <typename T>
- absl::StatusOr<std::vector<T>> EvaluateUntil(
- int hierarchy_level, absl::Span<const absl::uint128> prefixes,
- EvaluationContext& ctx) const;
- template <typename T>
- absl::StatusOr<std::vector<T>> EvaluateNext(
- absl::Span<const absl::uint128> prefixes, EvaluationContext& ctx) const {
- if (prefixes.empty()) {
- return EvaluateUntil<T>(0, prefixes, ctx);
- } else {
- return EvaluateUntil<T>(ctx.previous_hierarchy_level() + 1, prefixes,
- ctx);
- }
- }
- // Evaluates a single key at one or multiple points, up to the given
- // hierarchy_level. Each element of `evaluation_points` must be within the
- // domain of this DPF at `hierarchy_level`.
- //
- // Example:
- //
- // DpfKey key = ...;
- // std::vector<absl::uint128> evaluation_points = {1, 23, 42};
- // // Evaluate `key` on {1, 23, 42}.
- // DPF_ASSIGN_OR_RETURN(std::vector<T> result,
- // dpf->EvaluateAt(key, 0, evaluation_points);
- //
- // Returns INVALID_ARGUMENT if `key` is malformed, or if `hierarchy_level` or
- // any element of `evaluation_points` is out of range.
- template <typename T>
- absl::StatusOr<std::vector<T>> EvaluateAt(
- const DpfKey& key, int hierarchy_level,
- absl::Span<const absl::uint128> evaluation_points) const;
- // Returns the DpfParameters of this DPF.
- inline absl::Span<const DpfParameters> parameters() const {
- return parameters_;
- }
- private:
- // BitVector is a vector of bools. Allows for faster access times than
- // std::vector<bool>, as well as inlining if the size is small.
- using BitVector =
- absl::InlinedVector<bool,
- std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;
- // Seeds and control bits resulting from a DPF expansion. This type is
- // returned by `ExpandSeeds` and `ExpandAndUpdateContext`.
- struct DpfExpansion {
- std::vector<absl::uint128> seeds;
- BitVector control_bits;
- };
- // A function for computing value corrections. Used as return type in
- // `GetValueCorrectionFunction`.
- using ValueCorrectionFunction = absl::StatusOr<std::vector<Value>> (*)(
- absl::string_view, absl::string_view, int block_index, const Value&,
- bool);
- // Private constructor, called by `CreateIncremental`.
- DistributedPointFunction(
- std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
- std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
- Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
- absl::flat_hash_map<std::string, ValueCorrectionFunction>
- value_correction_functions);
- // Computes the value correction for the given `hierarchy_level`, `seeds`,
- // index `alpha` and value `beta`. If `invert` is true, the individual values
- // in the returned block are multiplied element-wise by -1. Expands `seeds`
- // using `prg_ctx_value_`, then calls the function returned by
- // `GetValueCorrectionFunction(parameters_[hierarchy_level])` to obtain the
- // value correction words.
- //
- // Returns multiple values in the case of packing, and a single Value
- // otherwise.
- //
- // Returns INTERNAL in case the PRG expansion fails, and UNIMPLEMENTED if
- // `element_bitsize` is not supported.
- absl::StatusOr<std::vector<Value>> ComputeValueCorrection(
- int hierarchy_level, absl::Span<const absl::uint128> seeds,
- absl::uint128 alpha, const Value& beta, bool invert) const;
- // Expands the PRG seeds at the next `tree_level` for an incremental DPF with
- // index `alpha` and values `beta`, updates `seeds` and `control_bits`, and
- // writes the next correction word to `keys`. Called from
- // `GenerateKeysIncremental`.
- absl::Status GenerateNext(int tree_level, absl::uint128 alpha,
- absl::Span<const Value> beta,
- absl::Span<absl::uint128> seeds,
- absl::Span<bool> control_bits,
- absl::Span<DpfKey> keys) const;
- // Computes the tree index (representing a path in the FSS tree) from the
- // given `domain_index` and `hierarchy_level`. Does NOT check whether the
- // given domain index fits in the domain at `hierarchy_level`.
- absl::uint128 DomainToTreeIndex(absl::uint128 domain_index,
- int hierarchy_level) const;
- // Computes the block index (pointing to an element in a batched 128-bit
- // block) from the given `domain_index` and `hierarchy_level`. Does NOT check
- // whether the given domain index fits in the domain at `hierarchy_level`.
- int DomainToBlockIndex(absl::uint128 domain_index, int hierarchy_level) const;
- // Performs DPF evaluation of the given `partial_evaluations` using
- // prg_ctx_left_ or prg_ctx_right_, and the given `correction_words`. At each
- // level `l < correction_words.size()`, the evaluation for the i-th seed in
- // `partial_evaluations` continues along the left or right path depending on
- // the l-th most significant bit among the lowest `correction_words.size()`
- // bits of `paths[i]`.
- //
- // Returns INVALID_ARGUMENT if the input sizes don't match.
- // Returns INTERNAL in case of OpenSSL errors.
- absl::StatusOr<DpfExpansion> EvaluateSeeds(
- DpfExpansion partial_evaluations, absl::Span<const absl::uint128> paths,
- absl::Span<const CorrectionWord* const> correction_words) const;
- // Performs DPF expansion of the given `partial_evaluations` using
- // prg_ctx_left_ and prg_ctx_right_, and the given `correction_words`. In more
- // detail, each of the partial evaluations is subjected to a full subtree
- // expansion of `correction_words.size()` levels, and the concatenated result
- // is provided in the response. The result contains
- // `(partial_evaluations.size() * (2^correction_words.size())` evaluations in
- // a single `DpfExpansion`.
- //
- // Returns INTERNAL in case of OpenSSL errors.
- absl::StatusOr<DpfExpansion> ExpandSeeds(
- const DpfExpansion& partial_evaluations,
- absl::Span<const CorrectionWord* const> correction_words) const;
- // Computes partial evaluations of the paths to `prefixes` to be used as the
- // starting point of the expansion of `ctx`. If `update_ctx == true`, saves
- // the partial evaluations of `ctx.previous_hierarchy_level` to `ctx` and sets
- // `ctx.partial_evaluations_level` to `ctx.previous_hierarchy_level`.
- // Called by `ExpandAndUpdateContext`.
- //
- // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
- // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
- // duplicate seeds.
- absl::StatusOr<DpfExpansion> ComputePartialEvaluations(
- absl::Span<const absl::uint128> prefixes, bool update_ctx,
- EvaluationContext& ctx) const;
- // Extracts the seeds for the given `prefixes` from `ctx` and expands them as
- // far as needed for the next hierarchy level. Returns the result as a
- // `DpfExpansion`. Called by `EvaluateUntil`, where the expanded seeds are
- // corrected to obtain output values.
- // After expansion, `ctx.hierarchy_level()` is increased. If this isn't the
- // last expansion, the expanded seeds are also saved in `ctx` for the next
- // expansion.
- //
- // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
- // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
- // duplicate seeds. Returns INTERNAL in case of OpenSSL errors.
- absl::StatusOr<DpfExpansion> ExpandAndUpdateContext(
- int hierarchy_level, absl::Span<const absl::uint128> prefixes,
- EvaluationContext& ctx) const;
- // Compute output PRG value of expanded seeds using prg_ctx_value_.
- // Returns blocks_needed_[hierarchy_level] * expansion.seeds.size() blocks,
- // where every blocks_needed_[hierarchy_level] correspond to the hash of an
- // input seed.
- //
- // Returns INTERNAL in case of OpenSSL errors.
- absl::StatusOr<std::vector<absl::uint128>> HashExpandedSeeds(
- int hierarchy_level, absl::Span<const absl::uint128> expansion) const;
- // Deterministically serializes the given value_type.
- //
- // Returns OK on success and INTERNAL in case serialization fails.
- static absl::StatusOr<std::string> SerializeValueTypeDeterministically(
- const ValueType& value_type);
- // Returns the value correction function for the given parameters.
- // For all value types except unsigned integers, these functions have to be
- // first registered using RegisterValueType<T>.
- //
- // Returns UNIMPLEMENTED if no matching function was registered.
- absl::StatusOr<ValueCorrectionFunction> GetValueCorrectionFunction(
- const DpfParameters& parameters) const;
- // Static implementation of RegisterValueType<T>, so we can call it from
- // `Create`.
- template <typename T>
- static absl::Status RegisterValueTypeImpl(
- absl::flat_hash_map<std::string, ValueCorrectionFunction>&
- value_correction_functions);
- // Used to validate DpfParameters, DpfKey and EvaluationContext protos.
- const std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
- // DP parameters passed to the factory function. Contains the domain size and
- // element size for hierarchy level of the incremental DPF. Owned by
- // proto_validator_.
- const absl::Span<const DpfParameters> parameters_;
- // Number of levels in the evaluation tree. This is always less than or equal
- // to the largest log_domain_size in parameters_.
- const int tree_levels_needed_;
- // Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
- // of parameters_).
- const absl::flat_hash_map<int, int>& tree_to_hierarchy_;
- // The inverse of tree_to_hierarchy_.
- const std::vector<int>& hierarchy_to_tree_;
- // Cached numbers of AES blocks needed for value correction at each hierarchy
- // level.
- const std::vector<int> blocks_needed_;
- // Pseudorandom generator used for seed expansion (left and right), and value
- // correction. The PRG G(x) for hierarchy level i is defined as the
- // concatenation of
- //
- // H_left(x), H_right(x), H_value(x + 0), ..., H_value(x + k-1)
- //
- // where k is equal to blocks_needed_[i], and H_*(x) is the evaluation of
- // prg_*_ on input x.
- const Aes128FixedKeyHash prg_left_;
- const Aes128FixedKeyHash prg_right_;
- const Aes128FixedKeyHash prg_value_;
- // Maps serialized `ValueType` messages to the correct value correction
- // functions. Map values are instantiations of
- // `dpf_internal::ComputeValueCorrectionFor`. Relies on protobuf's
- // deterministic serialization feature. This has the caveat that messages with
- // unknown fields are not supported. However, as long as `ValueType` consists
- // of a single `oneof` field, this is fine, since we either know the value
- // type and have deterministic serialization because the `ValueType` can only
- // contain one field, or we don't know the type and wouldn't be able to
- // correct values for it anyway.
- absl::flat_hash_map<std::string, ValueCorrectionFunction>
- value_correction_functions_;
- };
- //========================//
- // Implementation Details //
- //========================//
- template <typename T>
- absl::Status DistributedPointFunction::RegisterValueTypeImpl(
- absl::flat_hash_map<std::string, ValueCorrectionFunction>&
- value_correction_functions) {
- ValueType value_type = ToValueType<T>();
- absl::StatusOr<std::string> serialized_value_type =
- SerializeValueTypeDeterministically(value_type);
- if (!serialized_value_type.ok()) {
- return serialized_value_type.status();
- }
- value_correction_functions[*serialized_value_type] =
- dpf_internal::ComputeValueCorrectionFor<T>;
- return absl::OkStatus();
- }
- template <typename T0, typename... Tn, typename /*= absl::enable_if_t<...>*/>
- absl::StatusOr<std::pair<DpfKey, DpfKey>>
- DistributedPointFunction::GenerateKeysIncremental(absl::uint128 alpha,
- T0&& beta_0, Tn&&... beta_n) {
- // Convert the first element of beta. We need to treat it separately to be
- // able to check its type in the enable_if above.
- absl::StatusOr<Value> value = ToValue(beta_0);
- if (!value.ok()) {
- return value.status();
- }
- std::vector<Value> values = {std::move(*value)};
- values.reserve(1 + sizeof...(beta_n));
- // Convert all values in the parameter pack, stopping at the first error.
- absl::Status status = absl::OkStatus();
- // We create an unused std::tuple<Tn...> here, because its braced-initializer
- // list constructor allows us to operate on beta_n in a well-defined order. In
- // C++17, this could be replaced by a fold expression instead.
- std::tuple<Tn...>{[this, &status, &values, &value](auto&& beta_i) -> Tn {
- if (status.ok()) {
- value = this->ToValue(beta_i);
- if (value.ok()) {
- values.push_back(std::move(*value));
- } else {
- status = value.status();
- }
- }
- return Tn{};
- }(beta_n)...};
- // Return if there was an error during conversion, otherwise generate keys.
- if (!status.ok()) {
- return status;
- }
- return GenerateKeysIncremental(alpha, values);
- }
- template <typename T>
- absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
- int hierarchy_level, absl::Span<const absl::uint128> prefixes,
- EvaluationContext& ctx) const {
- absl::Status status = proto_validator_->ValidateEvaluationContext(ctx);
- if (!status.ok()) {
- return status;
- }
- if (hierarchy_level < 0 ||
- hierarchy_level >= static_cast<int>(parameters_.size())) {
- return absl::InvalidArgumentError(
- "`hierarchy_level` must be non-negative and less than "
- "parameters_.size()");
- }
- absl::StatusOr<bool> types_are_equal = dpf_internal::ValueTypesAreEqual(
- ToValueType<T>(), parameters_[hierarchy_level].value_type());
- if (!types_are_equal.ok()) {
- return types_are_equal.status();
- } else if (!*types_are_equal) {
- return absl::InvalidArgumentError(
- "Value type T doesn't match parameters at `hierarchy_level`");
- }
- if (hierarchy_level <= ctx.previous_hierarchy_level()) {
- return absl::InvalidArgumentError(
- "`hierarchy_level` must be greater than "
- "`ctx.previous_hierarchy_level`");
- }
- if ((ctx.previous_hierarchy_level() < 0) != (prefixes.empty())) {
- return absl::InvalidArgumentError(
- "`prefixes` must be empty if and only if this is the first call with "
- "`ctx`.");
- }
- int previous_log_domain_size = 0;
- int previous_hierarchy_level = ctx.previous_hierarchy_level();
- if (!prefixes.empty()) {
- DCHECK(ctx.previous_hierarchy_level() >= 0);
- previous_log_domain_size =
- parameters_[previous_hierarchy_level].log_domain_size();
- for (absl::uint128 prefix : prefixes) {
- if (previous_log_domain_size < 128 &&
- prefix >= (absl::uint128{1} << previous_log_domain_size)) {
- return absl::InvalidArgumentError(
- absl::StrFormat("Index %d out of range for hierarchy level %d",
- prefix, previous_hierarchy_level));
- }
- }
- }
- int64_t prefixes_size = static_cast<int64_t>(prefixes.size());
- int log_domain_size = parameters_[hierarchy_level].log_domain_size();
- if (log_domain_size - previous_log_domain_size > 62) {
- return absl::InvalidArgumentError(
- "Output size would be larger than 2**62. Please evaluate fewer "
- "hierarchy levels at once.");
- }
- // The `prefixes` passed in by the caller refer to the domain of the previous
- // hierarchy level. However, because we batch multiple elements of type T in a
- // single uint128 block, multiple prefixes can actually refer to the same
- // block in the FSS evaluation tree. On a high level, our approach is as
- // follows:
- //
- // 1. Split up each element of `prefixes` into a tree index, pointing to a
- // block in the FSS tree, and a block index, pointing to an element of type
- // T in that block.
- //
- // 2. Compute a list of unique `tree_indices`, and for each original prefix,
- // remember the position of the corresponding tree index in `tree_indices`.
- //
- // 3. After expanding the unique `tree_indices`, use the positions saved in
- // Step (2) together with the corresponding block index to retrieve the
- // expanded values for each prefix, and return them in the same order as
- // `prefixes`.
- //
- // `tree_indices` holds the unique tree indices from `prefixes`, to be passed
- // to `ExpandAndUpdateContext`.
- std::vector<absl::uint128> tree_indices;
- tree_indices.reserve(prefixes_size);
- // `tree_indices_inverse` is the inverse of `tree_indices`, used for
- // deduplicating and constructing `prefix_map`. Use a btree_map because we
- // expect `prefixes` (and thus `tree_indices`) to be sorted.
- absl::btree_map<absl::uint128, int64_t> tree_indices_inverse;
- // `prefix_map` maps each i < prefixes.size() to an element of `tree_indices`
- // and a block index. Used to select which elements to return after the
- // expansion, to ensure the result is ordered the same way as `prefixes`.
- std::vector<std::pair<int64_t, int>> prefix_map;
- prefix_map.reserve(prefixes_size);
- for (int64_t i = 0; i < prefixes_size; ++i) {
- absl::uint128 tree_index =
- DomainToTreeIndex(prefixes[i], previous_hierarchy_level);
- int block_index = DomainToBlockIndex(prefixes[i], previous_hierarchy_level);
- // Check if `tree_index` already exists in `tree_indices`.
- auto previous_size = tree_indices_inverse.size();
- auto it = tree_indices_inverse.try_emplace(tree_indices_inverse.end(),
- tree_index, tree_indices.size());
- if (tree_indices_inverse.size() > previous_size) {
- tree_indices.push_back(tree_index);
- }
- prefix_map.push_back(std::make_pair(it->second, block_index));
- }
- // Perform expansion of unique `tree_indices`.
- absl::StatusOr<DpfExpansion> expansion =
- ExpandAndUpdateContext(hierarchy_level, tree_indices, ctx);
- if (!expansion.ok()) {
- return expansion.status();
- }
- // Hash the expanded seeds.
- absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
- HashExpandedSeeds(hierarchy_level, expansion->seeds);
- if (!hashed_expansion.ok()) {
- return hashed_expansion.status();
- }
- // Get output correction word from `ctx`.
- constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
- const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
- if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
- value_correction =
- &(ctx.key()
- .correction_words(hierarchy_to_tree_[hierarchy_level])
- .value_correction());
- } else {
- // Last level value correction is stored in an extra proto field, since we
- // have one less correction word than tree levels.
- value_correction = &(ctx.key().last_level_value_correction());
- }
- // Split output correction into elements of type T.
- absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
- dpf_internal::ValuesToArray<T>(*value_correction);
- if (!correction_ints.ok()) {
- return correction_ints.status();
- }
- // Compute value corrections for each block in `expanded_seeds`. We have to
- // account for the fact that blocks might not be full (i.e., have less than
- // elements_per_block elements).
- const int corrected_elements_per_block =
- 1 << (parameters_[hierarchy_level].log_domain_size() -
- hierarchy_to_tree_[hierarchy_level]);
- const auto expansion_size = static_cast<int64_t>(expansion->seeds.size());
- const int blocks_needed = blocks_needed_[hierarchy_level];
- DCHECK(corrected_elements_per_block <= elements_per_block);
- std::vector<T> corrected_expansion(expansion_size *
- corrected_elements_per_block);
- for (int64_t i = 0; i < expansion_size; ++i) {
- std::array<T, elements_per_block> current_elements =
- dpf_internal::ConvertBytesToArrayOf<T>(
- absl::string_view(reinterpret_cast<const char*>(
- &(*hashed_expansion)[i * blocks_needed]),
- blocks_needed * sizeof(absl::uint128)));
- for (int j = 0; j < corrected_elements_per_block; ++j) {
- if (expansion->control_bits[i]) {
- current_elements[j] += (*correction_ints)[j];
- }
- if (ctx.key().party() == 1) {
- current_elements[j] = -current_elements[j];
- }
- corrected_expansion[i * corrected_elements_per_block + j] =
- current_elements[j];
- }
- }
- // Compute the number of outputs we will have. For each prefix, we will have a
- // full expansion from the previous hierarchy level to the current hierarchy
- // level.
- DCHECK(log_domain_size - previous_log_domain_size < 63);
- int64_t outputs_per_prefix = int64_t{1}
- << (log_domain_size - previous_log_domain_size);
- if (prefixes.empty()) {
- // If prefixes is empty (i.e., this is the first evaluation of `ctx`), just
- // return the expansion.
- DCHECK(static_cast<int>(corrected_expansion.size()) == outputs_per_prefix);
- return corrected_expansion;
- } else {
- // Otherwise, only return elements under `prefixes`.
- int blocks_per_tree_prefix = expansion->seeds.size() / tree_indices.size();
- std::vector<T> result(prefixes_size * outputs_per_prefix);
- for (int64_t i = 0; i < prefixes_size; ++i) {
- int64_t prefix_expansion_start =
- prefix_map[i].first * blocks_per_tree_prefix *
- corrected_elements_per_block +
- prefix_map[i].second * outputs_per_prefix;
- std::copy_n(&corrected_expansion[prefix_expansion_start],
- outputs_per_prefix, &result[i * outputs_per_prefix]);
- }
- return result;
- }
- }
- // Evaluates `key` at every point in `evaluation_points` for the given
- // `hierarchy_level` and returns one value of type T per point, in input order.
- //
- // Returns an InvalidArgumentError if `hierarchy_level` is negative or not
- // smaller than the number of parameters. Any error reported while validating
- // `key`, parsing its value correction, evaluating the seeds, or hashing them is
- // forwarded unchanged.
- template <typename T>
- absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAt(
- const DpfKey& key, int hierarchy_level,
- absl::Span<const absl::uint128> evaluation_points) const {
- // Reject hierarchy levels outside of [0, parameters_.size()).
- if (hierarchy_level < 0) {
- return absl::InvalidArgumentError("`hierarchy_level` must be non-negative");
- }
- const int num_hierarchy_levels = static_cast<int>(parameters_.size());
- if (hierarchy_level >= num_hierarchy_levels) {
- return absl::InvalidArgumentError(
- "`hierarchy_level` must be less than the number of parameters passed "
- "at construction");
- }
- absl::Status key_status = proto_validator_->ValidateDpfKey(key);
- if (!key_status.ok()) {
- return key_status;
- }
- // All but the last hierarchy level keep their value correction inside the
- // correction word of the matching tree level. The last level uses a dedicated
- // proto field, since there is one less correction word than tree levels.
- const bool is_last_level = hierarchy_level >= num_hierarchy_levels - 1;
- const auto& value_correction =
- is_last_level
- ? key.last_level_value_correction()
- : key.correction_words(hierarchy_to_tree_[hierarchy_level])
- .value_correction();
- // Parse the value correction into one element of type T per block slot.
- constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
- absl::StatusOr<std::array<T, elements_per_block>> block_corrections =
- dpf_internal::ValuesToArray<T>(value_correction);
- if (!block_corrections.ok()) {
- return block_corrections.status();
- }
- // For packed types (more than one element of T per block), every evaluation
- // point is translated into the tree index of its block. Otherwise the
- // evaluation points are used directly, which avoids copying them.
- const auto num_points = static_cast<int64_t>(evaluation_points.size());
- std::vector<absl::uint128> packed_tree_indices;
- absl::Span<const absl::uint128> tree_indices = evaluation_points;
- if (elements_per_block > 1) {
- packed_tree_indices.reserve(num_points);
- for (const absl::uint128& point : evaluation_points) {
- packed_tree_indices.push_back(DomainToTreeIndex(point, hierarchy_level));
- }
- tree_indices = absl::MakeConstSpan(packed_tree_indices);
- }
- // Every evaluation starts from the key's seed, with the party as control bit.
- const bool party = key.party();
- DpfExpansion inputs;
- inputs.seeds.resize(num_points,
- absl::MakeUint128(key.seed().high(), key.seed().low()));
- inputs.control_bits.resize(num_points, party);
- // Evaluate the seeds using the first `stop_level` correction words only.
- const int stop_level = hierarchy_to_tree_[hierarchy_level];
- auto used_correction_words =
- absl::MakeConstSpan(key.correction_words()).subspan(0, stop_level);
- absl::StatusOr<DpfExpansion> evaluated =
- EvaluateSeeds(std::move(inputs), tree_indices, used_correction_words);
- if (!evaluated.ok()) {
- return evaluated.status();
- }
- DCHECK(static_cast<int64_t>(evaluated->seeds.size()) == num_points);
- // Hash the evaluated seeds.
- absl::StatusOr<std::vector<absl::uint128>> hashed =
- HashExpandedSeeds(hierarchy_level, evaluated->seeds);
- if (!hashed.ok()) {
- return hashed.status();
- }
- // Value correction: pick the element addressed by each evaluation point, add
- // the correction if the control bit is set, and negate it for party 1.
- const int blocks_needed = blocks_needed_[hierarchy_level];
- const auto bytes_per_point = blocks_needed * sizeof(absl::uint128);
- std::vector<T> result;
- result.reserve(num_points);
- for (int64_t i = 0; i < num_points; ++i) {
- const char* raw_bytes =
- reinterpret_cast<const char*>(&(*hashed)[i * blocks_needed]);
- const std::array<T, elements_per_block> elements =
- dpf_internal::ConvertBytesToArrayOf<T>(
- absl::string_view(raw_bytes, bytes_per_point));
- // Without packing there is exactly one element per block, at index 0.
- const int block_index =
- elements_per_block > 1
- ? DomainToBlockIndex(evaluation_points[i], hierarchy_level)
- : 0;
- T share = elements[block_index];
- if (evaluated->control_bits[i]) {
- share += (*block_corrections)[block_index];
- }
- if (party) {
- share = -share;
- }
- result.push_back(std::move(share));
- }
- return result;
- }
- } // namespace distributed_point_functions
- #endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
|