1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- // Copyright 2022 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #ifndef COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_
- #define COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_
#include <map>
#include <memory>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/no_destructor.h"
#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
- namespace optimization_guide {
- class OptimizationGuideModelProvider;
- }
- namespace segmentation_platform {
- struct Config;
- class ModelProviderFactoryImpl : public ModelProviderFactory {
- public:
- ModelProviderFactoryImpl(
- optimization_guide::OptimizationGuideModelProvider*
- optimization_guide_provider,
- std::vector<std::unique_ptr<Config>>& configs,
- scoped_refptr<base::SequencedTaskRunner> background_task_runner);
- ~ModelProviderFactoryImpl() override;
- ModelProviderFactoryImpl(ModelProviderFactoryImpl&) = delete;
- ModelProviderFactoryImpl& operator=(ModelProviderFactoryImpl&) = delete;
- // ModelProviderFactory impl:
- std::unique_ptr<ModelProvider> CreateProvider(
- proto::SegmentId segment_id) override;
- std::unique_ptr<ModelProvider> CreateDefaultProvider(
- proto::SegmentId segment_id) override;
- private:
- raw_ptr<optimization_guide::OptimizationGuideModelProvider>
- optimization_guide_provider_;
- base::flat_map<proto::SegmentId, std::unique_ptr<ModelProvider>>
- default_models_;
- scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
- };
- // Used only in tests to override the default model.
- class TestDefaultModelOverride {
- public:
- static TestDefaultModelOverride& GetInstance();
- ~TestDefaultModelOverride();
- TestDefaultModelOverride(const TestDefaultModelOverride& client) = delete;
- TestDefaultModelOverride& operator=(const TestDefaultModelOverride& client) =
- delete;
- std::unique_ptr<ModelProvider> TakeOwnershipOfModelProvider(
- proto::SegmentId target);
- void SetModelForTesting(proto::SegmentId target,
- std::unique_ptr<ModelProvider>);
- private:
- friend class base::NoDestructor<TestDefaultModelOverride>;
- TestDefaultModelOverride();
- std::map<proto::SegmentId, std::unique_ptr<ModelProvider>> providers_;
- };
- } // namespace segmentation_platform
- #endif // COMPONENTS_SEGMENTATION_PLATFORM_EMBEDDER_MODEL_PROVIDER_FACTORY_IMPL_H_
|