model_provider.h 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_
  5. #define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_
#include <cstdint>
#include <memory>
#include <vector>

#include "base/callback.h"
#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
  10. namespace segmentation_platform {
  11. namespace proto {
  12. class SegmentationModelMetadata;
  13. } // namespace proto
  14. // Interface used by the segmentation platform to get model and metadata for a
  15. // single optimization target.
  16. class ModelProvider {
  17. public:
  18. using ModelUpdatedCallback = base::RepeatingCallback<
  19. void(proto::SegmentId, proto::SegmentationModelMetadata, int64_t)>;
  20. using ExecutionCallback =
  21. base::OnceCallback<void(const absl::optional<float>&)>;
  22. explicit ModelProvider(proto::SegmentId segment_id);
  23. virtual ~ModelProvider();
  24. ModelProvider(ModelProvider&) = delete;
  25. ModelProvider& operator=(ModelProvider&) = delete;
  26. // Implementation should return metadata that will be used to execute model.
  27. // The metadata provided should define the number of features needed by the
  28. // ExecuteModelWithInput() method. Starts a fetch request for the model for
  29. // optimization target. The `model_updated_callback` can be called multiple
  30. // times when new models are available for the optimization target.
  31. virtual void InitAndFetchModel(
  32. const ModelUpdatedCallback& model_updated_callback) = 0;
  33. // Executes the latest model available, with the given inputs and returns
  34. // result via `callback`. Should be called only after InitAndFetchModel()
  35. // otherwise returns absl::nullopt. Implementation could be a heuristic or
  36. // model execution to return a result. The inputs to this method are the
  37. // computed tensors based on the features provided in the latest call to
  38. // `model_updated_callback`. The result is a float score with the probability
  39. // of positive result. Also see `discrete_mapping` field in the
  40. // `SegmentationModelMetadata` for how the score will be used to determine the
  41. // segment.
  42. virtual void ExecuteModelWithInput(const std::vector<float>& inputs,
  43. ExecutionCallback callback) = 0;
  44. // Returns true if a model is available.
  45. virtual bool ModelAvailable() = 0;
  46. protected:
  47. const proto::SegmentId segment_id_;
  48. };
  49. // Interface used by segmentation platform to create ModelProvider(s).
  50. class ModelProviderFactory {
  51. public:
  52. virtual ~ModelProviderFactory();
  53. // Creates a model provider for the given `segment_id`.
  54. virtual std::unique_ptr<ModelProvider> CreateProvider(proto::SegmentId) = 0;
  55. // Creates a default model provider to be used when the original provider did
  56. // not provide a model. Returns `nullptr` when a default provider is not
  57. // available.
  58. // TODO(crbug.com/1346389): This method should be moved to Config after
  59. // migrating all the tests that use this.
  60. virtual std::unique_ptr<ModelProvider> CreateDefaultProvider(
  61. proto::SegmentId) = 0;
  62. };
  63. } // namespace segmentation_platform
  64. #endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_MODEL_PROVIDER_H_