prediction_manager.cc 36 KB


  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "components/optimization_guide/core/prediction_manager.h"

#include <memory>
#include <utility>

#include "base/callback.h"
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/containers/flat_tree.h"
#include "base/metrics/histogram.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/observer_list.h"
#include "base/path_service.h"
#include "base/rand_util.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/time/default_clock.h"
#include "base/time/time.h"
#include "components/optimization_guide/core/model_info.h"
#include "components/optimization_guide/core/model_util.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_logger.h"
#include "components/optimization_guide/core/optimization_guide_permissions_util.h"
#include "components/optimization_guide/core/optimization_guide_prefs.h"
#include "components/optimization_guide/core/optimization_guide_store.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/core/optimization_target_model_observer.h"
#include "components/optimization_guide/core/prediction_model_download_manager.h"
#include "components/optimization_guide/core/prediction_model_fetcher_impl.h"
#include "components/optimization_guide/core/prediction_model_override.h"
#include "components/optimization_guide/core/store_update_data.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/prefs/pref_service.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
  43. namespace optimization_guide {
  44. namespace {
  45. // Provide a random time delta in seconds before fetching models.
  46. base::TimeDelta RandomFetchDelay() {
  47. return base::Seconds(
  48. base::RandInt(features::PredictionModelFetchRandomMinDelaySecs(),
  49. features::PredictionModelFetchRandomMaxDelaySecs()));
  50. }
  51. // Util class for recording the state of a prediction model. The result is
  52. // recorded when it goes out of scope and its destructor is called.
  53. class ScopedPredictionManagerModelStatusRecorder {
  54. public:
  55. explicit ScopedPredictionManagerModelStatusRecorder(
  56. proto::OptimizationTarget optimization_target)
  57. : optimization_target_(optimization_target) {}
  58. ~ScopedPredictionManagerModelStatusRecorder() {
  59. DCHECK_NE(status_, PredictionManagerModelStatus::kUnknown);
  60. base::UmaHistogramEnumeration(
  61. "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus",
  62. status_);
  63. base::UmaHistogramEnumeration(
  64. "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." +
  65. GetStringNameForOptimizationTarget(optimization_target_),
  66. status_);
  67. }
  68. void set_status(PredictionManagerModelStatus status) { status_ = status; }
  69. private:
  70. PredictionManagerModelStatus status_ = PredictionManagerModelStatus::kUnknown;
  71. const proto::OptimizationTarget optimization_target_;
  72. };
  73. // Util class for recording the construction and validation of a prediction
  74. // model. The result is recorded when it goes out of scope and its destructor is
  75. // called.
  76. class ScopedPredictionModelConstructionAndValidationRecorder {
  77. public:
  78. explicit ScopedPredictionModelConstructionAndValidationRecorder(
  79. proto::OptimizationTarget optimization_target)
  80. : validation_start_time_(base::TimeTicks::Now()),
  81. optimization_target_(optimization_target) {}
  82. ~ScopedPredictionModelConstructionAndValidationRecorder() {
  83. base::UmaHistogramBoolean("OptimizationGuide.IsPredictionModelValid",
  84. is_valid_);
  85. base::UmaHistogramBoolean(
  86. "OptimizationGuide.IsPredictionModelValid." +
  87. GetStringNameForOptimizationTarget(optimization_target_),
  88. is_valid_);
  89. // Only record the timing if the model is valid and was able to be
  90. // constructed.
  91. if (is_valid_) {
  92. base::TimeDelta validation_latency =
  93. base::TimeTicks::Now() - validation_start_time_;
  94. base::UmaHistogramTimes(
  95. "OptimizationGuide.PredictionModelValidationLatency",
  96. validation_latency);
  97. base::UmaHistogramTimes(
  98. "OptimizationGuide.PredictionModelValidationLatency." +
  99. GetStringNameForOptimizationTarget(optimization_target_),
  100. validation_latency);
  101. }
  102. }
  103. void set_is_valid(bool is_valid) { is_valid_ = is_valid; }
  104. private:
  105. bool is_valid_ = true;
  106. const base::TimeTicks validation_start_time_;
  107. const proto::OptimizationTarget optimization_target_;
  108. };
  109. void RecordModelUpdateVersion(const proto::ModelInfo& model_info) {
  110. base::UmaHistogramSparse(
  111. "OptimizationGuide.PredictionModelUpdateVersion." +
  112. GetStringNameForOptimizationTarget(model_info.optimization_target()),
  113. model_info.version());
  114. }
  115. void RecordLifecycleState(proto::OptimizationTarget optimization_target,
  116. ModelDeliveryEvent event) {
  117. base::UmaHistogramEnumeration(
  118. "OptimizationGuide.PredictionManager.ModelDeliveryEvents." +
  119. GetStringNameForOptimizationTarget(optimization_target),
  120. event);
  121. }
  122. // Returns whether models should be fetched from the
  123. // remote Optimization Guide Service.
  124. bool ShouldFetchModels(bool off_the_record, bool component_updates_enabled) {
  125. return features::IsRemoteFetchingEnabled() && !off_the_record &&
  126. features::IsModelDownloadingEnabled() && component_updates_enabled;
  127. }
  128. // Returns whether the model metadata proto is on the server allowlist.
  129. bool IsModelMetadataTypeOnServerAllowlist(const proto::Any& model_metadata) {
  130. return model_metadata.type_url() ==
  131. "type.googleapis.com/"
  132. "google.internal.chrome.optimizationguide.v1."
  133. "PageEntitiesModelMetadata" ||
  134. model_metadata.type_url() ==
  135. "type.googleapis.com/"
  136. "google.internal.chrome.optimizationguide.v1."
  137. "PageTopicsModelMetadata" ||
  138. model_metadata.type_url() ==
  139. "type.googleapis.com/"
  140. "google.internal.chrome.optimizationguide.v1."
  141. "SegmentationModelMetadata" ||
  142. model_metadata.type_url() ==
  143. "type.googleapis.com/"
  144. "google.privacy.webpermissionpredictions.v1."
  145. "WebPermissionPredictionsModelMetadata";
  146. }
  147. void RecordModelAvailableAtRegistration(
  148. proto::OptimizationTarget optimization_target,
  149. bool model_available_at_registration) {
  150. base::UmaHistogramBoolean(
  151. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration." +
  152. GetStringNameForOptimizationTarget(optimization_target),
  153. model_available_at_registration);
  154. }
  155. } // namespace
  156. PredictionManager::PredictionManager(
  157. base::WeakPtr<OptimizationGuideStore> model_and_features_store,
  158. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  159. PrefService* pref_service,
  160. bool off_the_record,
  161. const std::string& application_locale,
  162. const base::FilePath& models_dir_path,
  163. OptimizationGuideLogger* optimization_guide_logger,
  164. BackgroundDownloadServiceProvider background_download_service_provider,
  165. ComponentUpdatesEnabledProvider component_updates_enabled_provider)
  166. : prediction_model_download_manager_(nullptr),
  167. model_and_features_store_(model_and_features_store),
  168. url_loader_factory_(url_loader_factory),
  169. optimization_guide_logger_(optimization_guide_logger),
  170. pref_service_(pref_service),
  171. component_updates_enabled_provider_(component_updates_enabled_provider),
  172. clock_(base::DefaultClock::GetInstance()),
  173. off_the_record_(off_the_record),
  174. application_locale_(application_locale),
  175. models_dir_path_(models_dir_path) {
  176. Initialize(std::move(background_download_service_provider));
  177. }
  178. PredictionManager::~PredictionManager() {
  179. if (prediction_model_download_manager_)
  180. prediction_model_download_manager_->RemoveObserver(this);
  181. }
  182. void PredictionManager::Initialize(
  183. BackgroundDownloadServiceProvider background_download_service_provider) {
  184. if (model_and_features_store_) {
  185. model_and_features_store_->Initialize(
  186. switches::ShouldPurgeModelAndFeaturesStoreOnStartup(),
  187. base::BindOnce(&PredictionManager::OnStoreInitialized,
  188. ui_weak_ptr_factory_.GetWeakPtr(),
  189. std::move(background_download_service_provider)));
  190. }
  191. }
  192. void PredictionManager::AddObserverForOptimizationTargetModel(
  193. proto::OptimizationTarget optimization_target,
  194. const absl::optional<proto::Any>& model_metadata,
  195. OptimizationTargetModelObserver* observer) {
  196. DCHECK(registered_observers_for_optimization_targets_.find(
  197. optimization_target) ==
  198. registered_observers_for_optimization_targets_.end());
  199. // As DCHECKS don't run in the wild, just do not register the observer if
  200. // something is already registered for the type. Otherwise, file reads may
  201. // blow up.
  202. if (registered_observers_for_optimization_targets_.find(
  203. optimization_target) !=
  204. registered_observers_for_optimization_targets_.end()) {
  205. DLOG(ERROR) << "Did not add observer for optimization target "
  206. << static_cast<int>(optimization_target)
  207. << " since an observer for the target was already registered ";
  208. return;
  209. }
  210. registered_observers_for_optimization_targets_[optimization_target]
  211. .AddObserver(observer);
  212. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  213. OPTIMIZATION_GUIDE_LOGGER(
  214. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  215. optimization_guide_logger_)
  216. << "Observer added for OptimizationTarget: " << optimization_target;
  217. }
  218. // Notify observer of existing model file path.
  219. auto model_it = optimization_target_model_info_map_.find(optimization_target);
  220. if (model_it != optimization_target_model_info_map_.end()) {
  221. observer->OnModelUpdated(optimization_target, *model_it->second);
  222. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  223. OPTIMIZATION_GUIDE_LOGGER(
  224. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  225. optimization_guide_logger_)
  226. << "OnModelFileUpdated for OptimizationTarget: "
  227. << optimization_target << "\nFile path: "
  228. << (*model_it->second).GetModelFilePath().AsUTF8Unsafe()
  229. << "\nHas metadata: " << (model_metadata ? "True" : "False");
  230. }
  231. RecordLifecycleState(optimization_target,
  232. ModelDeliveryEvent::kModelDeliveredAtRegistration);
  233. }
  234. base::UmaHistogramMediumTimes(
  235. "OptimizationGuide.PredictionManager.RegistrationTimeSinceServiceInit." +
  236. GetStringNameForOptimizationTarget(optimization_target),
  237. !init_time_.is_null() ? base::TimeTicks::Now() - init_time_
  238. : base::TimeDelta());
  239. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  240. if (registered_optimization_targets_and_metadata_.contains(
  241. optimization_target)) {
  242. return;
  243. }
  244. DCHECK(!model_metadata ||
  245. IsModelMetadataTypeOnServerAllowlist(*model_metadata));
  246. registered_optimization_targets_and_metadata_.emplace(optimization_target,
  247. model_metadata);
  248. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  249. OPTIMIZATION_GUIDE_LOGGER(
  250. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  251. optimization_guide_logger_)
  252. << "Registered new OptimizationTarget: " << optimization_target;
  253. }
  254. // Before loading/fetching models and features, the store must be ready.
  255. if (!store_is_ready_)
  256. return;
  257. // If no fetch is scheduled, maybe schedule one.
  258. if (!fetch_timer_.IsRunning())
  259. MaybeScheduleFirstModelFetch();
  260. // Otherwise, load prediction models for any newly registered targets.
  261. LoadPredictionModels({optimization_target});
  262. }
  263. void PredictionManager::RemoveObserverForOptimizationTargetModel(
  264. proto::OptimizationTarget optimization_target,
  265. OptimizationTargetModelObserver* observer) {
  266. auto observers_it =
  267. registered_observers_for_optimization_targets_.find(optimization_target);
  268. if (observers_it == registered_observers_for_optimization_targets_.end())
  269. return;
  270. observers_it->second.RemoveObserver(observer);
  271. }
  272. base::flat_set<proto::OptimizationTarget>
  273. PredictionManager::GetRegisteredOptimizationTargets() const {
  274. base::flat_set<proto::OptimizationTarget> optimization_targets;
  275. for (const auto& optimization_target_and_metadata :
  276. registered_optimization_targets_and_metadata_) {
  277. optimization_targets.insert(optimization_target_and_metadata.first);
  278. }
  279. return optimization_targets;
  280. }
  281. void PredictionManager::SetPredictionModelFetcherForTesting(
  282. std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher) {
  283. prediction_model_fetcher_ = std::move(prediction_model_fetcher);
  284. }
  285. void PredictionManager::SetPredictionModelDownloadManagerForTesting(
  286. std::unique_ptr<PredictionModelDownloadManager>
  287. prediction_model_download_manager) {
  288. prediction_model_download_manager_ =
  289. std::move(prediction_model_download_manager);
  290. }
  291. void PredictionManager::FetchModels(bool is_first_model_fetch) {
  292. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  293. // The histogram that gets recorded here is used for integration tests that
  294. // pass in a model override. For simplicity, we place the recording of this
  295. // histogram here rather than somewhere else earlier in the session
  296. // initialization flow since the model engine version needs to continuously be
  297. // updated for the fetch.
  298. proto::ModelInfo base_model_info;
  299. // There should only be one supported model engine version at a time.
  300. base_model_info.add_supported_model_engine_versions(
  301. proto::MODEL_ENGINE_VERSION_TFLITE_2_11);
  302. // This histogram is used for integration tests. Do not remove.
  303. // Update this to be 10000 if/when we exceed 100 model engine versions.
  304. LOCAL_HISTOGRAM_COUNTS_100(
  305. "OptimizationGuide.PredictionManager.SupportedModelEngineVersion",
  306. static_cast<int>(
  307. *base_model_info.supported_model_engine_versions().begin()));
  308. if (switches::IsModelOverridePresent())
  309. return;
  310. if (!ShouldFetchModels(off_the_record_, pref_service_))
  311. return;
  312. if (is_first_model_fetch) {
  313. DCHECK(!init_time_.is_null());
  314. base::UmaHistogramMediumTimes(
  315. "OptimizationGuide.PredictionManager.FirstModelFetchSinceServiceInit",
  316. base::TimeTicks::Now() - init_time_);
  317. }
  318. // Models should not be fetched if there are no optimization targets
  319. // registered.
  320. if (registered_optimization_targets_and_metadata_.empty())
  321. return;
  322. // We should have already created a prediction model download manager if we
  323. // initiated the fetching of models.
  324. DCHECK(prediction_model_download_manager_);
  325. if (prediction_model_download_manager_) {
  326. bool download_service_available =
  327. prediction_model_download_manager_->IsAvailableForDownloads();
  328. base::UmaHistogramBoolean(
  329. "OptimizationGuide.PredictionManager."
  330. "DownloadServiceAvailabilityBlockedFetch",
  331. !download_service_available);
  332. if (!download_service_available) {
  333. for (const auto& optimization_target_and_metadata :
  334. registered_optimization_targets_and_metadata_) {
  335. RecordLifecycleState(optimization_target_and_metadata.first,
  336. ModelDeliveryEvent::kDownloadServiceUnavailable);
  337. }
  338. // We cannot download any models from the server, so don't refresh them.
  339. return;
  340. }
  341. prediction_model_download_manager_->CancelAllPendingDownloads();
  342. }
  343. // NOTE: ALL PRECONDITIONS FOR THIS FUNCTION MUST BE CHECKED ABOVE THIS LINE.
  344. // It is assumed that if we proceed past here, that a fetch will at least be
  345. // attempted.
  346. if (!prediction_model_fetcher_) {
  347. prediction_model_fetcher_ = std::make_unique<PredictionModelFetcherImpl>(
  348. url_loader_factory_,
  349. features::GetOptimizationGuideServiceGetModelsURL());
  350. }
  351. std::vector<proto::ModelInfo> models_info = std::vector<proto::ModelInfo>();
  352. models_info.reserve(registered_optimization_targets_and_metadata_.size());
  353. // For now, we will fetch for all registered optimization targets.
  354. for (const auto& optimization_target_and_metadata :
  355. registered_optimization_targets_and_metadata_) {
  356. proto::ModelInfo model_info(base_model_info);
  357. model_info.set_optimization_target(optimization_target_and_metadata.first);
  358. if (optimization_target_and_metadata.second.has_value()) {
  359. *model_info.mutable_model_metadata() =
  360. *optimization_target_and_metadata.second;
  361. }
  362. auto model_it = optimization_target_model_info_map_.find(
  363. optimization_target_and_metadata.first);
  364. if (model_it != optimization_target_model_info_map_.end())
  365. model_info.set_version(model_it->second.get()->GetVersion());
  366. models_info.push_back(model_info);
  367. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  368. OPTIMIZATION_GUIDE_LOGGER(
  369. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  370. optimization_guide_logger_)
  371. << "Fetching models for Optimization Target "
  372. << model_info.optimization_target();
  373. }
  374. RecordLifecycleState(optimization_target_and_metadata.first,
  375. ModelDeliveryEvent::kGetModelsRequest);
  376. }
  377. bool fetch_initiated =
  378. prediction_model_fetcher_->FetchOptimizationGuideServiceModels(
  379. models_info, proto::CONTEXT_BATCH_UPDATE_MODELS, application_locale_,
  380. base::BindOnce(&PredictionManager::OnModelsFetched,
  381. ui_weak_ptr_factory_.GetWeakPtr(), models_info));
  382. if (fetch_initiated)
  383. SetLastModelFetchAttemptTime(clock_->Now());
  384. // Schedule the next fetch regardless since we may not have initiated a fetch
  385. // due to a network condition and trying in the next minute to see if that is
  386. // unblocked is only a timer firing and not an actual query to the server.
  387. ScheduleModelsFetch();
  388. }
  389. void PredictionManager::OnModelsFetched(
  390. const std::vector<proto::ModelInfo> models_request_info,
  391. absl::optional<std::unique_ptr<proto::GetModelsResponse>>
  392. get_models_response_data) {
  393. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  394. if (!get_models_response_data) {
  395. for (const auto& model_info : models_request_info) {
  396. RecordLifecycleState(model_info.optimization_target(),
  397. ModelDeliveryEvent::kGetModelsResponseFailure);
  398. }
  399. return;
  400. }
  401. SetLastModelFetchSuccessTime(clock_->Now());
  402. if ((*get_models_response_data)->models_size() > 0) {
  403. UpdatePredictionModels((*get_models_response_data)->models());
  404. }
  405. fetch_timer_.Stop();
  406. ScheduleModelsFetch();
  407. }
  408. void PredictionManager::UpdatePredictionModels(
  409. const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
  410. prediction_models) {
  411. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  412. if (!model_and_features_store_)
  413. return;
  414. std::unique_ptr<StoreUpdateData> prediction_model_update_data =
  415. StoreUpdateData::CreatePredictionModelStoreUpdateData(
  416. clock_->Now() + features::StoredModelsValidDuration());
  417. bool has_models_to_update = false;
  418. for (const auto& model : prediction_models) {
  419. if (model.has_model() && !model.model().download_url().empty()) {
  420. // We should only be updating the store for on-the-record profiles and
  421. // after the store has been initialized.
  422. DCHECK(prediction_model_download_manager_);
  423. if (prediction_model_download_manager_) {
  424. GURL download_url(model.model().download_url());
  425. if (download_url.is_valid()) {
  426. prediction_model_download_manager_->StartDownload(
  427. download_url, model.model_info().optimization_target());
  428. }
  429. RecordLifecycleState(model.model_info().optimization_target(),
  430. download_url.is_valid()
  431. ? ModelDeliveryEvent::kDownloadServiceRequest
  432. : ModelDeliveryEvent::kDownloadURLInvalid);
  433. base::UmaHistogramBoolean(
  434. "OptimizationGuide.PredictionManager.IsDownloadUrlValid." +
  435. GetStringNameForOptimizationTarget(
  436. model.model_info().optimization_target()),
  437. download_url.is_valid());
  438. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  439. OPTIMIZATION_GUIDE_LOGGER(
  440. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  441. optimization_guide_logger_)
  442. << "Model download required for Optimization Target: "
  443. << model.model_info().optimization_target();
  444. }
  445. }
  446. // Skip over models that have a download URL since they will be updated
  447. // once the download has completed successfully.
  448. continue;
  449. }
  450. if (!model.has_model()) {
  451. // We already have this updated model, so don't update in store.
  452. continue;
  453. }
  454. has_models_to_update = true;
  455. // Storing the model regardless of whether the model is valid or not. Model
  456. // will be removed from store if it fails to load.
  457. prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
  458. RecordModelUpdateVersion(model.model_info());
  459. OnLoadPredictionModel(model.model_info().optimization_target(),
  460. /*record_availability_metrics=*/false,
  461. std::make_unique<proto::PredictionModel>(model));
  462. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  463. OPTIMIZATION_GUIDE_LOGGER(
  464. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  465. optimization_guide_logger_)
  466. << "Model Download Not Required for target: "
  467. << model.model_info().optimization_target() << "\nNew Version: "
  468. << base::NumberToString(model.model_info().version());
  469. }
  470. }
  471. if (has_models_to_update) {
  472. model_and_features_store_->UpdatePredictionModels(
  473. std::move(prediction_model_update_data),
  474. base::BindOnce(&PredictionManager::OnPredictionModelsStored,
  475. ui_weak_ptr_factory_.GetWeakPtr()));
  476. }
  477. }
  478. void PredictionManager::OnModelReady(const proto::PredictionModel& model) {
  479. if (switches::IsModelOverridePresent())
  480. return;
  481. if (!model_and_features_store_)
  482. return;
  483. DCHECK(model.model_info().has_version() &&
  484. model.model_info().has_optimization_target());
  485. RecordModelUpdateVersion(model.model_info());
  486. RecordLifecycleState(model.model_info().optimization_target(),
  487. ModelDeliveryEvent::kModelDownloaded);
  488. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  489. OPTIMIZATION_GUIDE_LOGGER(
  490. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  491. optimization_guide_logger_)
  492. << "Model Files Downloaded target: "
  493. << model.model_info().optimization_target()
  494. << "\nNew Version: " +
  495. base::NumberToString(model.model_info().version());
  496. }
  497. // Store the received model in the store.
  498. std::unique_ptr<StoreUpdateData> prediction_model_update_data =
  499. StoreUpdateData::CreatePredictionModelStoreUpdateData(
  500. clock_->Now() + features::StoredModelsValidDuration());
  501. prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
  502. model_and_features_store_->UpdatePredictionModels(
  503. std::move(prediction_model_update_data),
  504. base::BindOnce(&PredictionManager::OnPredictionModelsStored,
  505. ui_weak_ptr_factory_.GetWeakPtr()));
  506. if (registered_optimization_targets_and_metadata_.contains(
  507. model.model_info().optimization_target())) {
  508. OnLoadPredictionModel(model.model_info().optimization_target(),
  509. /*record_availability_metrics=*/false,
  510. std::make_unique<proto::PredictionModel>(model));
  511. }
  512. }
  513. void PredictionManager::OnModelDownloadStarted(
  514. proto::OptimizationTarget optimization_target) {
  515. RecordLifecycleState(optimization_target,
  516. ModelDeliveryEvent::kModelDownloadStarted);
  517. }
  518. void PredictionManager::OnModelDownloadFailed(
  519. proto::OptimizationTarget optimization_target) {
  520. RecordLifecycleState(optimization_target,
  521. ModelDeliveryEvent::kModelDownloadFailure);
  522. }
  523. void PredictionManager::NotifyObserversOfNewModel(
  524. proto::OptimizationTarget optimization_target,
  525. const ModelInfo& model_info) {
  526. auto observers_it =
  527. registered_observers_for_optimization_targets_.find(optimization_target);
  528. if (observers_it == registered_observers_for_optimization_targets_.end())
  529. return;
  530. RecordLifecycleState(optimization_target,
  531. ModelDeliveryEvent::kModelDelivered);
  532. for (auto& observer : observers_it->second) {
  533. observer.OnModelUpdated(optimization_target, model_info);
  534. if (optimization_guide_logger_->ShouldEnableDebugLogs()) {
  535. OPTIMIZATION_GUIDE_LOGGER(
  536. optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT,
  537. optimization_guide_logger_)
  538. << "OnModelFileUpdated for target: " << optimization_target
  539. << "\nFile path: " << model_info.GetModelFilePath().AsUTF8Unsafe()
  540. << "\nHas metadata: "
  541. << (model_info.GetModelMetadata() ? "True" : "False");
  542. }
  543. }
  544. }
  545. void PredictionManager::OnPredictionModelsStored() {
  546. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  547. LOCAL_HISTOGRAM_BOOLEAN(
  548. "OptimizationGuide.PredictionManager.PredictionModelsStored", true);
  549. }
  550. void PredictionManager::OnStoreInitialized(
  551. BackgroundDownloadServiceProvider background_download_service_provider) {
  552. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  553. store_is_ready_ = true;
  554. init_time_ = base::TimeTicks::Now();
  555. LOCAL_HISTOGRAM_BOOLEAN(
  556. "OptimizationGuide.PredictionManager.StoreInitialized", true);
  557. // Create the download manager here if we are allowed to.
  558. if (features::IsModelDownloadingEnabled() && !off_the_record_ &&
  559. !prediction_model_download_manager_) {
  560. prediction_model_download_manager_ =
  561. std::make_unique<PredictionModelDownloadManager>(
  562. background_download_service_provider
  563. ? std::move(background_download_service_provider).Run()
  564. : nullptr,
  565. models_dir_path_,
  566. base::ThreadPool::CreateSequencedTaskRunner(
  567. {base::MayBlock(), base::TaskPriority::BEST_EFFORT,
  568. base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN}));
  569. prediction_model_download_manager_->AddObserver(this);
  570. }
  571. // Purge any inactive models from the store.
  572. model_and_features_store_->PurgeInactiveModels();
  573. // Only load models if there are optimization targets registered.
  574. if (registered_optimization_targets_and_metadata_.empty())
  575. return;
  576. // The store is ready so start loading models for the registered optimization
  577. // targets.
  578. LoadPredictionModels(GetRegisteredOptimizationTargets());
  579. MaybeScheduleFirstModelFetch();
  580. }
  581. void PredictionManager::OnPredictionModelOverrideLoaded(
  582. proto::OptimizationTarget optimization_target,
  583. std::unique_ptr<proto::PredictionModel> prediction_model) {
  584. OnLoadPredictionModel(optimization_target,
  585. /*record_availability_metrics=*/false,
  586. std::move(prediction_model));
  587. RecordModelAvailableAtRegistration(optimization_target,
  588. prediction_model != nullptr);
  589. }
  590. void PredictionManager::LoadPredictionModels(
  591. const base::flat_set<proto::OptimizationTarget>& optimization_targets) {
  592. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  593. if (switches::IsModelOverridePresent()) {
  594. for (proto::OptimizationTarget optimization_target : optimization_targets) {
  595. BuildPredictionModelFromCommandLineForOptimizationTarget(
  596. optimization_target,
  597. base::BindOnce(&PredictionManager::OnPredictionModelOverrideLoaded,
  598. ui_weak_ptr_factory_.GetWeakPtr(),
  599. optimization_target));
  600. }
  601. return;
  602. }
  603. if (!model_and_features_store_)
  604. return;
  605. OptimizationGuideStore::EntryKey model_entry_key;
  606. for (const auto& optimization_target : optimization_targets) {
  607. // The prediction model for this optimization target has already been
  608. // loaded.
  609. bool model_stored_locally =
  610. model_and_features_store_->FindPredictionModelEntryKey(
  611. optimization_target, &model_entry_key);
  612. if (!model_stored_locally) {
  613. RecordModelAvailableAtRegistration(optimization_target,
  614. model_stored_locally);
  615. continue;
  616. }
  617. model_and_features_store_->LoadPredictionModel(
  618. model_entry_key,
  619. base::BindOnce(&PredictionManager::OnLoadPredictionModel,
  620. ui_weak_ptr_factory_.GetWeakPtr(), optimization_target,
  621. /*record_availability_metrics=*/true));
  622. }
  623. }
  624. void PredictionManager::OnLoadPredictionModel(
  625. proto::OptimizationTarget optimization_target,
  626. bool record_availability_metrics,
  627. std::unique_ptr<proto::PredictionModel> model) {
  628. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  629. if (!model) {
  630. if (record_availability_metrics) {
  631. RecordModelAvailableAtRegistration(optimization_target, false);
  632. }
  633. return;
  634. }
  635. bool success = ProcessAndStoreLoadedModel(*model);
  636. DCHECK_EQ(optimization_target, model->model_info().optimization_target());
  637. if (record_availability_metrics)
  638. RecordModelAvailableAtRegistration(optimization_target, success);
  639. OnProcessLoadedModel(*model, success);
  640. }
  641. void PredictionManager::OnProcessLoadedModel(
  642. const proto::PredictionModel& model,
  643. bool success) {
  644. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  645. if (success) {
  646. base::UmaHistogramSparse("OptimizationGuide.PredictionModelLoadedVersion." +
  647. GetStringNameForOptimizationTarget(
  648. model.model_info().optimization_target()),
  649. model.model_info().version());
  650. return;
  651. }
  652. // Remove model from store if it exists.
  653. OptimizationGuideStore::EntryKey model_entry_key;
  654. if (model_and_features_store_ &&
  655. model_and_features_store_->FindPredictionModelEntryKey(
  656. model.model_info().optimization_target(), &model_entry_key)) {
  657. LOCAL_HISTOGRAM_BOOLEAN("OptimizationGuide.PredictionModelRemoved." +
  658. GetStringNameForOptimizationTarget(
  659. model.model_info().optimization_target()),
  660. true);
  661. model_and_features_store_->RemovePredictionModelFromEntryKey(
  662. model_entry_key);
  663. }
  664. }
  665. bool PredictionManager::ProcessAndStoreLoadedModel(
  666. const proto::PredictionModel& model) {
  667. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  668. if (!model.model_info().has_optimization_target())
  669. return false;
  670. if (!model.model_info().has_version())
  671. return false;
  672. if (!model.has_model())
  673. return false;
  674. if (!registered_optimization_targets_and_metadata_.contains(
  675. model.model_info().optimization_target())) {
  676. return false;
  677. }
  678. ScopedPredictionModelConstructionAndValidationRecorder
  679. prediction_model_recorder(model.model_info().optimization_target());
  680. std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model);
  681. if (!model_info) {
  682. prediction_model_recorder.set_is_valid(false);
  683. return false;
  684. }
  685. proto::OptimizationTarget optimization_target =
  686. model.model_info().optimization_target();
  687. // See if we should update the loaded model.
  688. if (!ShouldUpdateStoredModelForTarget(optimization_target,
  689. model.model_info().version())) {
  690. return true;
  691. }
  692. // Update prediction model file if that is what we have loaded.
  693. if (model_info) {
  694. StoreLoadedModelInfo(optimization_target, std::move(model_info));
  695. }
  696. return true;
  697. }
  698. bool PredictionManager::ShouldUpdateStoredModelForTarget(
  699. proto::OptimizationTarget optimization_target,
  700. int64_t new_version) const {
  701. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  702. auto model_meta_it =
  703. optimization_target_model_info_map_.find(optimization_target);
  704. if (model_meta_it != optimization_target_model_info_map_.end())
  705. return model_meta_it->second->GetVersion() != new_version;
  706. return true;
  707. }
  708. void PredictionManager::StoreLoadedModelInfo(
  709. proto::OptimizationTarget optimization_target,
  710. std::unique_ptr<ModelInfo> model_info) {
  711. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  712. DCHECK(model_info);
  713. // Notify observers of new model file path.
  714. base::ThreadTaskRunnerHandle::Get()->PostTask(
  715. FROM_HERE, base::BindOnce(&PredictionManager::NotifyObserversOfNewModel,
  716. ui_weak_ptr_factory_.GetWeakPtr(),
  717. optimization_target, *model_info));
  718. optimization_target_model_info_map_.insert_or_assign(optimization_target,
  719. std::move(model_info));
  720. }
  721. void PredictionManager::MaybeScheduleFirstModelFetch() {
  722. if (!ShouldFetchModels(off_the_record_,
  723. component_updates_enabled_provider_.Run()))
  724. return;
  725. // Add a slight delay to allow the rest of the browser startup process to
  726. // finish up.
  727. fetch_timer_.Start(FROM_HERE, features::PredictionModelFetchStartupDelay(),
  728. base::BindOnce(&PredictionManager::FetchModels,
  729. base::Unretained(this), true));
  730. }
  731. base::Time PredictionManager::GetLastFetchAttemptTime() const {
  732. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  733. return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds(
  734. pref_service_->GetInt64(prefs::kModelAndFeaturesLastFetchAttempt)));
  735. }
  736. base::Time PredictionManager::GetLastFetchSuccessTime() const {
  737. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  738. return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds(
  739. pref_service_->GetInt64(prefs::kModelLastFetchSuccess)));
  740. }
  741. void PredictionManager::ScheduleModelsFetch() {
  742. DCHECK(!fetch_timer_.IsRunning());
  743. DCHECK(store_is_ready_);
  744. const base::TimeDelta time_until_update_time =
  745. GetLastFetchSuccessTime() + features::PredictionModelFetchInterval() -
  746. clock_->Now();
  747. const base::TimeDelta time_until_retry =
  748. GetLastFetchAttemptTime() + features::PredictionModelFetchRetryDelay() -
  749. clock_->Now();
  750. base::TimeDelta fetcher_delay =
  751. std::max(time_until_update_time, time_until_retry);
  752. if (fetcher_delay <= base::TimeDelta()) {
  753. fetch_timer_.Start(FROM_HERE, RandomFetchDelay(),
  754. base::BindOnce(&PredictionManager::FetchModels,
  755. base::Unretained(this), false));
  756. return;
  757. }
  758. fetch_timer_.Start(FROM_HERE, fetcher_delay, this,
  759. &PredictionManager::ScheduleModelsFetch);
  760. }
  761. void PredictionManager::SetLastModelFetchAttemptTime(
  762. base::Time last_attempt_time) {
  763. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  764. pref_service_->SetInt64(
  765. prefs::kModelAndFeaturesLastFetchAttempt,
  766. last_attempt_time.ToDeltaSinceWindowsEpoch().InMicroseconds());
  767. }
  768. void PredictionManager::SetLastModelFetchSuccessTime(
  769. base::Time last_success_time) {
  770. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  771. pref_service_->SetInt64(
  772. prefs::kModelLastFetchSuccess,
  773. last_success_time.ToDeltaSinceWindowsEpoch().InMicroseconds());
  774. }
  775. void PredictionManager::SetClockForTesting(const base::Clock* clock) {
  776. clock_ = clock;
  777. }
  778. void PredictionManager::OverrideTargetModelForTesting(
  779. proto::OptimizationTarget optimization_target,
  780. std::unique_ptr<ModelInfo> model_info) {
  781. if (!model_info) {
  782. return;
  783. }
  784. ModelInfo model_info_copy = *model_info;
  785. optimization_target_model_info_map_.insert_or_assign(optimization_target,
  786. std::move(model_info));
  787. NotifyObserversOfNewModel(optimization_target, model_info_copy);
  788. }
  789. } // namespace optimization_guide