model_provider_factory_impl.h 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_
  5. #define COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_
#include <map>
#include <memory>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/no_destructor.h"
#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
  12. namespace optimization_guide {
  13. class OptimizationGuideModelProvider;
  14. }
  15. namespace segmentation_platform {
  16. struct Config;
  17. class ModelProviderFactoryImpl : public ModelProviderFactory {
  18. public:
  19. ModelProviderFactoryImpl(
  20. optimization_guide::OptimizationGuideModelProvider*
  21. optimization_guide_provider,
  22. std::vector<std::unique_ptr<Config>>& configs,
  23. scoped_refptr<base::SequencedTaskRunner> background_task_runner);
  24. ~ModelProviderFactoryImpl() override;
  25. ModelProviderFactoryImpl(ModelProviderFactoryImpl&) = delete;
  26. ModelProviderFactoryImpl& operator=(ModelProviderFactoryImpl&) = delete;
  27. // ModelProviderFactory impl:
  28. std::unique_ptr<ModelProvider> CreateProvider(
  29. proto::SegmentId segment_id) override;
  30. std::unique_ptr<ModelProvider> CreateDefaultProvider(
  31. proto::SegmentId segment_id) override;
  32. private:
  33. raw_ptr<optimization_guide::OptimizationGuideModelProvider>
  34. optimization_guide_provider_;
  35. base::flat_map<proto::SegmentId, std::unique_ptr<ModelProvider>>
  36. default_models_;
  37. scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
  38. };
  39. // Used only in tests to override the default model.
  40. class TestDefaultModelOverride {
  41. public:
  42. static TestDefaultModelOverride& GetInstance();
  43. ~TestDefaultModelOverride();
  44. TestDefaultModelOverride(const TestDefaultModelOverride& client) = delete;
  45. TestDefaultModelOverride& operator=(const TestDefaultModelOverride& client) =
  46. delete;
  47. std::unique_ptr<ModelProvider> TakeOwnershipOfModelProvider(
  48. proto::SegmentId target);
  49. void SetModelForTesting(proto::SegmentId target,
  50. std::unique_ptr<ModelProvider>);
  51. private:
  52. friend class base::NoDestructor<TestDefaultModelOverride>;
  53. TestDefaultModelOverride();
  54. std::map<proto::SegmentId, std::unique_ptr<ModelProvider>> providers_;
  55. };
  56. } // namespace segmentation_platform
  57. #endif // COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_