123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904 |
- // Copyright 2019 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/optimization_guide/core/prediction_manager.h"
- #include <memory>
- #include <utility>
- #include "base/callback.h"
- #include "base/containers/flat_map.h"
- #include "base/containers/flat_set.h"
- #include "base/containers/flat_tree.h"
- #include "base/metrics/histogram_functions.h"
- #include "base/metrics/histogram_macros.h"
- #include "base/metrics/histogram_macros_local.h"
- #include "base/observer_list.h"
- #include "base/path_service.h"
- #include "base/rand_util.h"
- #include "base/sequence_checker.h"
- #include "base/task/sequenced_task_runner.h"
- #include "base/task/thread_pool.h"
- #include "base/threading/thread_task_runner_handle.h"
- #include "base/time/default_clock.h"
- #include "base/time/time.h"
- #include "components/optimization_guide/core/model_info.h"
- #include "components/optimization_guide/core/model_util.h"
- #include "components/optimization_guide/core/optimization_guide_constants.h"
- #include "components/optimization_guide/core/optimization_guide_enums.h"
- #include "components/optimization_guide/core/optimization_guide_features.h"
- #include "components/optimization_guide/core/optimization_guide_logger.h"
- #include "components/optimization_guide/core/optimization_guide_permissions_util.h"
- #include "components/optimization_guide/core/optimization_guide_prefs.h"
- #include "components/optimization_guide/core/optimization_guide_store.h"
- #include "components/optimization_guide/core/optimization_guide_switches.h"
- #include "components/optimization_guide/core/optimization_guide_util.h"
- #include "components/optimization_guide/core/optimization_target_model_observer.h"
- #include "components/optimization_guide/core/prediction_model_download_manager.h"
- #include "components/optimization_guide/core/prediction_model_fetcher_impl.h"
- #include "components/optimization_guide/core/prediction_model_override.h"
- #include "components/optimization_guide/core/store_update_data.h"
- #include "components/optimization_guide/proto/models.pb.h"
- #include "components/prefs/pref_service.h"
- #include "mojo/public/cpp/bindings/remote.h"
- #include "services/network/public/cpp/shared_url_loader_factory.h"
- namespace optimization_guide {
- namespace {
- // Provide a random time delta in seconds before fetching models.
- base::TimeDelta RandomFetchDelay() {
- return base::Seconds(
- base::RandInt(features::PredictionModelFetchRandomMinDelaySecs(),
- features::PredictionModelFetchRandomMaxDelaySecs()));
- }
- // Util class for recording the state of a prediction model. The result is
- // recorded when it goes out of scope and its destructor is called.
- class ScopedPredictionManagerModelStatusRecorder {
- public:
- explicit ScopedPredictionManagerModelStatusRecorder(
- proto::OptimizationTarget optimization_target)
- : optimization_target_(optimization_target) {}
- ~ScopedPredictionManagerModelStatusRecorder() {
- DCHECK_NE(status_, PredictionManagerModelStatus::kUnknown);
- base::UmaHistogramEnumeration(
- "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus",
- status_);
- base::UmaHistogramEnumeration(
- "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." +
- GetStringNameForOptimizationTarget(optimization_target_),
- status_);
- }
- void set_status(PredictionManagerModelStatus status) { status_ = status; }
- private:
- PredictionManagerModelStatus status_ = PredictionManagerModelStatus::kUnknown;
- const proto::OptimizationTarget optimization_target_;
- };
- // Util class for recording the construction and validation of a prediction
- // model. The result is recorded when it goes out of scope and its destructor is
- // called.
- class ScopedPredictionModelConstructionAndValidationRecorder {
- public:
- explicit ScopedPredictionModelConstructionAndValidationRecorder(
- proto::OptimizationTarget optimization_target)
- : validation_start_time_(base::TimeTicks::Now()),
- optimization_target_(optimization_target) {}
- ~ScopedPredictionModelConstructionAndValidationRecorder() {
- base::UmaHistogramBoolean("OptimizationGuide.IsPredictionModelValid",
- is_valid_);
- base::UmaHistogramBoolean(
- "OptimizationGuide.IsPredictionModelValid." +
- GetStringNameForOptimizationTarget(optimization_target_),
- is_valid_);
- // Only record the timing if the model is valid and was able to be
- // constructed.
- if (is_valid_) {
- base::TimeDelta validation_latency =
- base::TimeTicks::Now() - validation_start_time_;
- base::UmaHistogramTimes(
- "OptimizationGuide.PredictionModelValidationLatency",
- validation_latency);
- base::UmaHistogramTimes(
- "OptimizationGuide.PredictionModelValidationLatency." +
- GetStringNameForOptimizationTarget(optimization_target_),
- validation_latency);
- }
- }
- void set_is_valid(bool is_valid) { is_valid_ = is_valid; }
- private:
- bool is_valid_ = true;
- const base::TimeTicks validation_start_time_;
- const proto::OptimizationTarget optimization_target_;
- };
- void RecordModelUpdateVersion(const proto::ModelInfo& model_info) {
- base::UmaHistogramSparse(
- "OptimizationGuide.PredictionModelUpdateVersion." +
- GetStringNameForOptimizationTarget(model_info.optimization_target()),
- model_info.version());
- }
- void RecordLifecycleState(proto::OptimizationTarget optimization_target,
- ModelDeliveryEvent event) {
- base::UmaHistogramEnumeration(
- "OptimizationGuide.PredictionManager.ModelDeliveryEvents." +
- GetStringNameForOptimizationTarget(optimization_target),
- event);
- }
- // Returns whether models should be fetched from the
- // remote Optimization Guide Service.
- bool ShouldFetchModels(bool off_the_record, bool component_updates_enabled) {
- return features::IsRemoteFetchingEnabled() && !off_the_record &&
- features::IsModelDownloadingEnabled() && component_updates_enabled;
- }
- // Returns whether the model metadata proto is on the server allowlist.
- bool IsModelMetadataTypeOnServerAllowlist(const proto::Any& model_metadata) {
- return model_metadata.type_url() ==
- "type.googleapis.com/"
- "google.internal.chrome.optimizationguide.v1."
- "PageEntitiesModelMetadata" ||
- model_metadata.type_url() ==
- "type.googleapis.com/"
- "google.internal.chrome.optimizationguide.v1."
- "PageTopicsModelMetadata" ||
- model_metadata.type_url() ==
- "type.googleapis.com/"
- "google.internal.chrome.optimizationguide.v1."
- "SegmentationModelMetadata" ||
- model_metadata.type_url() ==
- "type.googleapis.com/"
- "google.privacy.webpermissionpredictions.v1."
- "WebPermissionPredictionsModelMetadata";
- }
- void RecordModelAvailableAtRegistration(
- proto::OptimizationTarget optimization_target,
- bool model_available_at_registration) {
- base::UmaHistogramBoolean(
- "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration." +
- GetStringNameForOptimizationTarget(optimization_target),
- model_available_at_registration);
- }
- } // namespace
- PredictionManager::PredictionManager(
- base::WeakPtr<OptimizationGuideStore> model_and_features_store,
- scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
- PrefService* pref_service,
- bool off_the_record,
- const std::string& application_locale,
- const base::FilePath& models_dir_path,
- OptimizationGuideLogger* optimization_guide_logger,
- BackgroundDownloadServiceProvider background_download_service_provider,
- ComponentUpdatesEnabledProvider component_updates_enabled_provider)
- : prediction_model_download_manager_(nullptr),
- model_and_features_store_(model_and_features_store),
- url_loader_factory_(url_loader_factory),
- optimization_guide_logger_(optimization_guide_logger),
- pref_service_(pref_service),
- component_updates_enabled_provider_(component_updates_enabled_provider),
- clock_(base::DefaultClock::GetInstance()),
- off_the_record_(off_the_record),
- application_locale_(application_locale),
- models_dir_path_(models_dir_path) {
- Initialize(std::move(background_download_service_provider));
- }
- PredictionManager::~PredictionManager() {
- if (prediction_model_download_manager_)
- prediction_model_download_manager_->RemoveObserver(this);
- }
- void PredictionManager::Initialize(
- BackgroundDownloadServiceProvider background_download_service_provider) {
- if (model_and_features_store_) {
- model_and_features_store_->Initialize(
- switches::ShouldPurgeModelAndFeaturesStoreOnStartup(),
- base::BindOnce(&PredictionManager::OnStoreInitialized,
- ui_weak_ptr_factory_.GetWeakPtr(),
- std::move(background_download_service_provider)));
- }
- }
- void PredictionManager::AddObserverForOptimizationTargetModel(
- proto::OptimizationTarget optimization_target,
- const absl::optional<proto::Any>& model_metadata,
- OptimizationTargetModelObserver* observer) {
- DCHECK(registered_observers_for_optimization_targets_.find(
- optimization_target) ==
- registered_observers_for_optimization_targets_.end());
- // As DCHECKS don't run in the wild, just do not register the observer if
- // something is already registered for the type. Otherwise, file reads may
- // blow up.
- if (registered_observers_for_optimization_targets_.find(
- optimization_target) !=
- registered_observers_for_optimization_targets_.end()) {
- DLOG(ERROR) << "Did not add observer for optimization target "
- << static_cast<int>(optimization_target)
- << " since an observer for the target was already registered ";
- return;
- }
- registered_observers_for_optimization_targets_[optimization_target]
- .AddObserver(observer);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Observer added for OptimizationTarget: " << optimization_target;
- }
- // Notify observer of existing model file path.
- auto model_it = optimization_target_model_info_map_.find(optimization_target);
- if (model_it != optimization_target_model_info_map_.end()) {
- observer->OnModelUpdated(optimization_target, *model_it->second);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "OnModelFileUpdated for OptimizationTarget: "
- << optimization_target << "\nFile path: "
- << (*model_it->second).GetModelFilePath().AsUTF8Unsafe()
- << "\nHas metadata: " << (model_metadata ? "True" : "False");
- }
- RecordLifecycleState(optimization_target,
- ModelDeliveryEvent::kModelDeliveredAtRegistration);
- }
- base::UmaHistogramMediumTimes(
- "OptimizationGuide.PredictionManager.RegistrationTimeSinceServiceInit." +
- GetStringNameForOptimizationTarget(optimization_target),
- !init_time_.is_null() ? base::TimeTicks::Now() - init_time_
- : base::TimeDelta());
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (registered_optimization_targets_and_metadata_.contains(
- optimization_target)) {
- return;
- }
- DCHECK(!model_metadata ||
- IsModelMetadataTypeOnServerAllowlist(*model_metadata));
- registered_optimization_targets_and_metadata_.emplace(optimization_target,
- model_metadata);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Registered new OptimizationTarget: " << optimization_target;
- }
- // Before loading/fetching models and features, the store must be ready.
- if (!store_is_ready_)
- return;
- // If no fetch is scheduled, maybe schedule one.
- if (!fetch_timer_.IsRunning())
- MaybeScheduleFirstModelFetch();
- // Otherwise, load prediction models for any newly registered targets.
- LoadPredictionModels({optimization_target});
- }
- void PredictionManager::RemoveObserverForOptimizationTargetModel(
- proto::OptimizationTarget optimization_target,
- OptimizationTargetModelObserver* observer) {
- auto observers_it =
- registered_observers_for_optimization_targets_.find(optimization_target);
- if (observers_it == registered_observers_for_optimization_targets_.end())
- return;
- observers_it->second.RemoveObserver(observer);
- }
- base::flat_set<proto::OptimizationTarget>
- PredictionManager::GetRegisteredOptimizationTargets() const {
- base::flat_set<proto::OptimizationTarget> optimization_targets;
- for (const auto& optimization_target_and_metadata :
- registered_optimization_targets_and_metadata_) {
- optimization_targets.insert(optimization_target_and_metadata.first);
- }
- return optimization_targets;
- }
- void PredictionManager::SetPredictionModelFetcherForTesting(
- std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher) {
- prediction_model_fetcher_ = std::move(prediction_model_fetcher);
- }
- void PredictionManager::SetPredictionModelDownloadManagerForTesting(
- std::unique_ptr<PredictionModelDownloadManager>
- prediction_model_download_manager) {
- prediction_model_download_manager_ =
- std::move(prediction_model_download_manager);
- }
- void PredictionManager::FetchModels(bool is_first_model_fetch) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- // The histogram that gets recorded here is used for integration tests that
- // pass in a model override. For simplicity, we place the recording of this
- // histogram here rather than somewhere else earlier in the session
- // initialization flow since the model engine version needs to continuously be
- // updated for the fetch.
- proto::ModelInfo base_model_info;
- // There should only be one supported model engine version at a time.
- base_model_info.add_supported_model_engine_versions(
- proto::MODEL_ENGINE_VERSION_TFLITE_2_11);
- // This histogram is used for integration tests. Do not remove.
- // Update this to be 10000 if/when we exceed 100 model engine versions.
- LOCAL_HISTOGRAM_COUNTS_100(
- "OptimizationGuide.PredictionManager.SupportedModelEngineVersion",
- static_cast<int>(
- *base_model_info.supported_model_engine_versions().begin()));
- if (switches::IsModelOverridePresent())
- return;
- if (!ShouldFetchModels(off_the_record_, pref_service_))
- return;
- if (is_first_model_fetch) {
- DCHECK(!init_time_.is_null());
- base::UmaHistogramMediumTimes(
- "OptimizationGuide.PredictionManager.FirstModelFetchSinceServiceInit",
- base::TimeTicks::Now() - init_time_);
- }
- // Models should not be fetched if there are no optimization targets
- // registered.
- if (registered_optimization_targets_and_metadata_.empty())
- return;
- // We should have already created a prediction model download manager if we
- // initiated the fetching of models.
- DCHECK(prediction_model_download_manager_);
- if (prediction_model_download_manager_) {
- bool download_service_available =
- prediction_model_download_manager_->IsAvailableForDownloads();
- base::UmaHistogramBoolean(
- "OptimizationGuide.PredictionManager."
- "DownloadServiceAvailabilityBlockedFetch",
- !download_service_available);
- if (!download_service_available) {
- for (const auto& optimization_target_and_metadata :
- registered_optimization_targets_and_metadata_) {
- RecordLifecycleState(optimization_target_and_metadata.first,
- ModelDeliveryEvent::kDownloadServiceUnavailable);
- }
- // We cannot download any models from the server, so don't refresh them.
- return;
- }
- prediction_model_download_manager_->CancelAllPendingDownloads();
- }
- // NOTE: ALL PRECONDITIONS FOR THIS FUNCTION MUST BE CHECKED ABOVE THIS LINE.
- // It is assumed that if we proceed past here, that a fetch will at least be
- // attempted.
- if (!prediction_model_fetcher_) {
- prediction_model_fetcher_ = std::make_unique<PredictionModelFetcherImpl>(
- url_loader_factory_,
- features::GetOptimizationGuideServiceGetModelsURL());
- }
- std::vector<proto::ModelInfo> models_info = std::vector<proto::ModelInfo>();
- models_info.reserve(registered_optimization_targets_and_metadata_.size());
- // For now, we will fetch for all registered optimization targets.
- for (const auto& optimization_target_and_metadata :
- registered_optimization_targets_and_metadata_) {
- proto::ModelInfo model_info(base_model_info);
- model_info.set_optimization_target(optimization_target_and_metadata.first);
- if (optimization_target_and_metadata.second.has_value()) {
- *model_info.mutable_model_metadata() =
- *optimization_target_and_metadata.second;
- }
- auto model_it = optimization_target_model_info_map_.find(
- optimization_target_and_metadata.first);
- if (model_it != optimization_target_model_info_map_.end())
- model_info.set_version(model_it->second.get()->GetVersion());
- models_info.push_back(model_info);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Fetching models for Optimization Target "
- << model_info.optimization_target();
- }
- RecordLifecycleState(optimization_target_and_metadata.first,
- ModelDeliveryEvent::kGetModelsRequest);
- }
- bool fetch_initiated =
- prediction_model_fetcher_->FetchOptimizationGuideServiceModels(
- models_info, proto::CONTEXT_BATCH_UPDATE_MODELS, application_locale_,
- base::BindOnce(&PredictionManager::OnModelsFetched,
- ui_weak_ptr_factory_.GetWeakPtr(), models_info));
- if (fetch_initiated)
- SetLastModelFetchAttemptTime(clock_->Now());
- // Schedule the next fetch regardless since we may not have initiated a fetch
- // due to a network condition and trying in the next minute to see if that is
- // unblocked is only a timer firing and not an actual query to the server.
- ScheduleModelsFetch();
- }
- void PredictionManager::OnModelsFetched(
- const std::vector<proto::ModelInfo> models_request_info,
- absl::optional<std::unique_ptr<proto::GetModelsResponse>>
- get_models_response_data) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (!get_models_response_data) {
- for (const auto& model_info : models_request_info) {
- RecordLifecycleState(model_info.optimization_target(),
- ModelDeliveryEvent::kGetModelsResponseFailure);
- }
- return;
- }
- SetLastModelFetchSuccessTime(clock_->Now());
- if ((*get_models_response_data)->models_size() > 0) {
- UpdatePredictionModels((*get_models_response_data)->models());
- }
- fetch_timer_.Stop();
- ScheduleModelsFetch();
- }
- void PredictionManager::UpdatePredictionModels(
- const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
- prediction_models) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (!model_and_features_store_)
- return;
- std::unique_ptr<StoreUpdateData> prediction_model_update_data =
- StoreUpdateData::CreatePredictionModelStoreUpdateData(
- clock_->Now() + features::StoredModelsValidDuration());
- bool has_models_to_update = false;
- for (const auto& model : prediction_models) {
- if (model.has_model() && !model.model().download_url().empty()) {
- // We should only be updating the store for on-the-record profiles and
- // after the store has been initialized.
- DCHECK(prediction_model_download_manager_);
- if (prediction_model_download_manager_) {
- GURL download_url(model.model().download_url());
- if (download_url.is_valid()) {
- prediction_model_download_manager_->StartDownload(
- download_url, model.model_info().optimization_target());
- }
- RecordLifecycleState(model.model_info().optimization_target(),
- download_url.is_valid()
- ? ModelDeliveryEvent::kDownloadServiceRequest
- : ModelDeliveryEvent::kDownloadURLInvalid);
- base::UmaHistogramBoolean(
- "OptimizationGuide.PredictionManager.IsDownloadUrlValid." +
- GetStringNameForOptimizationTarget(
- model.model_info().optimization_target()),
- download_url.is_valid());
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Model download required for Optimization Target: "
- << model.model_info().optimization_target();
- }
- }
- // Skip over models that have a download URL since they will be updated
- // once the download has completed successfully.
- continue;
- }
- if (!model.has_model()) {
- // We already have this updated model, so don't update in store.
- continue;
- }
- has_models_to_update = true;
- // Storing the model regardless of whether the model is valid or not. Model
- // will be removed from store if it fails to load.
- prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
- RecordModelUpdateVersion(model.model_info());
- OnLoadPredictionModel(model.model_info().optimization_target(),
- /*record_availability_metrics=*/false,
- std::make_unique<proto::PredictionModel>(model));
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Model Download Not Required for target: "
- << model.model_info().optimization_target() << "\nNew Version: "
- << base::NumberToString(model.model_info().version());
- }
- }
- if (has_models_to_update) {
- model_and_features_store_->UpdatePredictionModels(
- std::move(prediction_model_update_data),
- base::BindOnce(&PredictionManager::OnPredictionModelsStored,
- ui_weak_ptr_factory_.GetWeakPtr()));
- }
- }
- void PredictionManager::OnModelReady(const proto::PredictionModel& model) {
- if (switches::IsModelOverridePresent())
- return;
- if (!model_and_features_store_)
- return;
- DCHECK(model.model_info().has_version() &&
- model.model_info().has_optimization_target());
- RecordModelUpdateVersion(model.model_info());
- RecordLifecycleState(model.model_info().optimization_target(),
- ModelDeliveryEvent::kModelDownloaded);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "Model Files Downloaded target: "
- << model.model_info().optimization_target()
- << "\nNew Version: " +
- base::NumberToString(model.model_info().version());
- }
- // Store the received model in the store.
- std::unique_ptr<StoreUpdateData> prediction_model_update_data =
- StoreUpdateData::CreatePredictionModelStoreUpdateData(
- clock_->Now() + features::StoredModelsValidDuration());
- prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
- model_and_features_store_->UpdatePredictionModels(
- std::move(prediction_model_update_data),
- base::BindOnce(&PredictionManager::OnPredictionModelsStored,
- ui_weak_ptr_factory_.GetWeakPtr()));
- if (registered_optimization_targets_and_metadata_.contains(
- model.model_info().optimization_target())) {
- OnLoadPredictionModel(model.model_info().optimization_target(),
- /*record_availability_metrics=*/false,
- std::make_unique<proto::PredictionModel>(model));
- }
- }
- void PredictionManager::OnModelDownloadStarted(
- proto::OptimizationTarget optimization_target) {
- RecordLifecycleState(optimization_target,
- ModelDeliveryEvent::kModelDownloadStarted);
- }
- void PredictionManager::OnModelDownloadFailed(
- proto::OptimizationTarget optimization_target) {
- RecordLifecycleState(optimization_target,
- ModelDeliveryEvent::kModelDownloadFailure);
- }
- void PredictionManager::NotifyObserversOfNewModel(
- proto::OptimizationTarget optimization_target,
- const ModelInfo& model_info) {
- auto observers_it =
- registered_observers_for_optimization_targets_.find(optimization_target);
- if (observers_it == registered_observers_for_optimization_targets_.end())
- return;
- RecordLifecycleState(optimization_target,
- ModelDeliveryEvent::kModelDelivered);
- for (auto& observer : observers_it->second) {
- observer.OnModelUpdated(optimization_target, model_info);
- if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
- OPTIMIZATION_GUIDE_LOGGER(
- optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
- optimization_guide_logger_)
- << "OnModelFileUpdated for target: " << optimization_target
- << "\nFile path: " << model_info.GetModelFilePath().AsUTF8Unsafe()
- << "\nHas metadata: "
- << (model_info.GetModelMetadata() ? "True" : "False");
- }
- }
- }
- void PredictionManager::OnPredictionModelsStored() {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- LOCAL_HISTOGRAM_BOOLEAN(
- "OptimizationGuide.PredictionManager.PredictionModelsStored", true);
- }
- void PredictionManager::OnStoreInitialized(
- BackgroundDownloadServiceProvider background_download_service_provider) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- store_is_ready_ = true;
- init_time_ = base::TimeTicks::Now();
- LOCAL_HISTOGRAM_BOOLEAN(
- "OptimizationGuide.PredictionManager.StoreInitialized", true);
- // Create the download manager here if we are allowed to.
- if (features::IsModelDownloadingEnabled() && !off_the_record_ &&
- !prediction_model_download_manager_) {
- prediction_model_download_manager_ =
- std::make_unique<PredictionModelDownloadManager>(
- background_download_service_provider
- ? std::move(background_download_service_provider).Run()
- : nullptr,
- models_dir_path_,
- base::ThreadPool::CreateSequencedTaskRunner(
- {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
- base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}));
- prediction_model_download_manager_->AddObserver(this);
- }
- // Purge any inactive models from the store.
- model_and_features_store_->PurgeInactiveModels();
- // Only load models if there are optimization targets registered.
- if (registered_optimization_targets_and_metadata_.empty())
- return;
- // The store is ready so start loading models for the registered optimization
- // targets.
- LoadPredictionModels(GetRegisteredOptimizationTargets());
- MaybeScheduleFirstModelFetch();
- }
- void PredictionManager::OnPredictionModelOverrideLoaded(
- proto::OptimizationTarget optimization_target,
- std::unique_ptr<proto::PredictionModel> prediction_model) {
- OnLoadPredictionModel(optimization_target,
- /*record_availability_metrics=*/false,
- std::move(prediction_model));
- RecordModelAvailableAtRegistration(optimization_target,
- prediction_model != nullptr);
- }
- void PredictionManager::LoadPredictionModels(
- const base::flat_set<proto::OptimizationTarget>& optimization_targets) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (switches::IsModelOverridePresent()) {
- for (proto::OptimizationTarget optimization_target : optimization_targets) {
- BuildPredictionModelFromCommandLineForOptimizationTarget(
- optimization_target,
- base::BindOnce(&PredictionManager::OnPredictionModelOverrideLoaded,
- ui_weak_ptr_factory_.GetWeakPtr(),
- optimization_target));
- }
- return;
- }
- if (!model_and_features_store_)
- return;
- OptimizationGuideStore::EntryKey model_entry_key;
- for (const auto& optimization_target : optimization_targets) {
- // The prediction model for this optimization target has already been
- // loaded.
- bool model_stored_locally =
- model_and_features_store_->FindPredictionModelEntryKey(
- optimization_target, &model_entry_key);
- if (!model_stored_locally) {
- RecordModelAvailableAtRegistration(optimization_target,
- model_stored_locally);
- continue;
- }
- model_and_features_store_->LoadPredictionModel(
- model_entry_key,
- base::BindOnce(&PredictionManager::OnLoadPredictionModel,
- ui_weak_ptr_factory_.GetWeakPtr(), optimization_target,
- /*record_availability_metrics=*/true));
- }
- }
- void PredictionManager::OnLoadPredictionModel(
- proto::OptimizationTarget optimization_target,
- bool record_availability_metrics,
- std::unique_ptr<proto::PredictionModel> model) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (!model) {
- if (record_availability_metrics) {
- RecordModelAvailableAtRegistration(optimization_target, false);
- }
- return;
- }
- bool success = ProcessAndStoreLoadedModel(*model);
- DCHECK_EQ(optimization_target, model->model_info().optimization_target());
- if (record_availability_metrics)
- RecordModelAvailableAtRegistration(optimization_target, success);
- OnProcessLoadedModel(*model, success);
- }
- void PredictionManager::OnProcessLoadedModel(
- const proto::PredictionModel& model,
- bool success) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (success) {
- base::UmaHistogramSparse("OptimizationGuide.PredictionModelLoadedVersion." +
- GetStringNameForOptimizationTarget(
- model.model_info().optimization_target()),
- model.model_info().version());
- return;
- }
- // Remove model from store if it exists.
- OptimizationGuideStore::EntryKey model_entry_key;
- if (model_and_features_store_ &&
- model_and_features_store_->FindPredictionModelEntryKey(
- model.model_info().optimization_target(), &model_entry_key)) {
- LOCAL_HISTOGRAM_BOOLEAN("OptimizationGuide.PredictionModelRemoved." +
- GetStringNameForOptimizationTarget(
- model.model_info().optimization_target()),
- true);
- model_and_features_store_->RemovePredictionModelFromEntryKey(
- model_entry_key);
- }
- }
- bool PredictionManager::ProcessAndStoreLoadedModel(
- const proto::PredictionModel& model) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- if (!model.model_info().has_optimization_target())
- return false;
- if (!model.model_info().has_version())
- return false;
- if (!model.has_model())
- return false;
- if (!registered_optimization_targets_and_metadata_.contains(
- model.model_info().optimization_target())) {
- return false;
- }
- ScopedPredictionModelConstructionAndValidationRecorder
- prediction_model_recorder(model.model_info().optimization_target());
- std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model);
- if (!model_info) {
- prediction_model_recorder.set_is_valid(false);
- return false;
- }
- proto::OptimizationTarget optimization_target =
- model.model_info().optimization_target();
- // See if we should update the loaded model.
- if (!ShouldUpdateStoredModelForTarget(optimization_target,
- model.model_info().version())) {
- return true;
- }
- // Update prediction model file if that is what we have loaded.
- if (model_info) {
- StoreLoadedModelInfo(optimization_target, std::move(model_info));
- }
- return true;
- }
- bool PredictionManager::ShouldUpdateStoredModelForTarget(
- proto::OptimizationTarget optimization_target,
- int64_t new_version) const {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- auto model_meta_it =
- optimization_target_model_info_map_.find(optimization_target);
- if (model_meta_it != optimization_target_model_info_map_.end())
- return model_meta_it->second->GetVersion() != new_version;
- return true;
- }
- void PredictionManager::StoreLoadedModelInfo(
- proto::OptimizationTarget optimization_target,
- std::unique_ptr<ModelInfo> model_info) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- DCHECK(model_info);
- // Notify observers of new model file path.
- base::ThreadTaskRunnerHandle::Get()->PostTask(
- FROM_HERE, base::BindOnce(&PredictionManager::NotifyObserversOfNewModel,
- ui_weak_ptr_factory_.GetWeakPtr(),
- optimization_target, *model_info));
- optimization_target_model_info_map_.insert_or_assign(optimization_target,
- std::move(model_info));
- }
- void PredictionManager::MaybeScheduleFirstModelFetch() {
- if (!ShouldFetchModels(off_the_record_,
- component_updates_enabled_provider_.Run()))
- return;
- // Add a slight delay to allow the rest of the browser startup process to
- // finish up.
- fetch_timer_.Start(FROM_HERE, features::PredictionModelFetchStartupDelay(),
- base::BindOnce(&PredictionManager::FetchModels,
- base::Unretained(this), true));
- }
- base::Time PredictionManager::GetLastFetchAttemptTime() const {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds(
- pref_service_->GetInt64(prefs::kModelAndFeaturesLastFetchAttempt)));
- }
- base::Time PredictionManager::GetLastFetchSuccessTime() const {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds(
- pref_service_->GetInt64(prefs::kModelLastFetchSuccess)));
- }
- void PredictionManager::ScheduleModelsFetch() {
- DCHECK(!fetch_timer_.IsRunning());
- DCHECK(store_is_ready_);
- const base::TimeDelta time_until_update_time =
- GetLastFetchSuccessTime() + features::PredictionModelFetchInterval() -
- clock_->Now();
- const base::TimeDelta time_until_retry =
- GetLastFetchAttemptTime() + features::PredictionModelFetchRetryDelay() -
- clock_->Now();
- base::TimeDelta fetcher_delay =
- std::max(time_until_update_time, time_until_retry);
- if (fetcher_delay <= base::TimeDelta()) {
- fetch_timer_.Start(FROM_HERE, RandomFetchDelay(),
- base::BindOnce(&PredictionManager::FetchModels,
- base::Unretained(this), false));
- return;
- }
- fetch_timer_.Start(FROM_HERE, fetcher_delay, this,
- &PredictionManager::ScheduleModelsFetch);
- }
- void PredictionManager::SetLastModelFetchAttemptTime(
- base::Time last_attempt_time) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- pref_service_->SetInt64(
- prefs::kModelAndFeaturesLastFetchAttempt,
- last_attempt_time.ToDeltaSinceWindowsEpoch().InMicroseconds());
- }
- void PredictionManager::SetLastModelFetchSuccessTime(
- base::Time last_success_time) {
- DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
- pref_service_->SetInt64(
- prefs::kModelLastFetchSuccess,
- last_success_time.ToDeltaSinceWindowsEpoch().InMicroseconds());
- }
- void PredictionManager::SetClockForTesting(const base::Clock* clock) {
- clock_ = clock;
- }
- void PredictionManager::OverrideTargetModelForTesting(
- proto::OptimizationTarget optimization_target,
- std::unique_ptr<ModelInfo> model_info) {
- if (!model_info) {
- return;
- }
- ModelInfo model_info_copy = *model_info;
- optimization_target_model_info_map_.insert_or_assign(optimization_target,
- std::move(model_info));
- NotifyObserversOfNewModel(optimization_target, model_info_copy);
- }
- } // namespace optimization_guide
|