prediction_manager_unittest.cc 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/prediction_manager.h"
  5. #include <map>
  6. #include <memory>
  7. #include <string>
  8. #include <utility>
  9. #include "base/base64.h"
  10. #include "base/command_line.h"
  11. #include "base/strings/string_number_conversions.h"
  12. #include "base/strings/stringprintf.h"
  13. #include "base/test/gtest_util.h"
  14. #include "base/test/metrics/histogram_tester.h"
  15. #include "base/test/scoped_feature_list.h"
  16. #include "base/test/task_environment.h"
  17. #include "base/time/time.h"
  18. #include "build/build_config.h"
  19. #include "build/chromeos_buildflags.h"
  20. #include "components/leveldb_proto/testing/fake_db.h"
  21. #include "components/optimization_guide/core/model_util.h"
  22. #include "components/optimization_guide/core/optimization_guide_features.h"
  23. #include "components/optimization_guide/core/optimization_guide_logger.h"
  24. #include "components/optimization_guide/core/optimization_guide_prefs.h"
  25. #include "components/optimization_guide/core/optimization_guide_store.h"
  26. #include "components/optimization_guide/core/optimization_guide_switches.h"
  27. #include "components/optimization_guide/core/optimization_guide_test_util.h"
  28. #include "components/optimization_guide/core/optimization_guide_util.h"
  29. #include "components/optimization_guide/core/optimization_target_model_observer.h"
  30. #include "components/optimization_guide/core/prediction_model_download_manager.h"
  31. #include "components/optimization_guide/core/prediction_model_fetcher.h"
  32. #include "components/optimization_guide/core/prediction_model_fetcher_impl.h"
  33. #include "components/optimization_guide/core/proto_database_provider_test_base.h"
  34. #include "components/optimization_guide/proto/hint_cache.pb.h"
  35. #include "components/optimization_guide/proto/models.pb.h"
  36. #include "components/prefs/testing_pref_service.h"
  37. #include "components/variations/scoped_variations_ids_provider.h"
  38. #include "services/network/public/cpp/shared_url_loader_factory.h"
  39. #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
  40. #include "services/network/test/test_network_connection_tracker.h"
  41. #include "services/network/test/test_url_loader_factory.h"
  42. #include "testing/gtest/include/gtest/gtest.h"
  43. #include "ui/base/page_transition_types.h"
  44. using leveldb_proto::test::FakeDB;
  namespace {

  // Seconds to fast-forward so that a fetch retry can happen: the 2 minute
  // retry delay plus 62 seconds to let the random delay added on top elapse.
  constexpr int kTestFetchRetryDelaySecs = 2 * 60 + 62;

  // Seconds to fast-forward so that the periodic model update fetch can
  // happen: 24 hours plus 62 seconds to cover the random fetch delay.
  constexpr int kUpdateFetchModelAndFeaturesTimeSecs = 24 * 60 * 60 + 62;

  }  // namespace
  52. namespace optimization_guide {
  53. proto::PredictionModel CreatePredictionModel() {
  54. proto::PredictionModel prediction_model;
  55. proto::ModelInfo* model_info = prediction_model.mutable_model_info();
  56. model_info->set_version(1);
  57. model_info->set_optimization_target(
  58. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  59. model_info->add_supported_model_engine_versions(
  60. proto::ModelEngineVersion::MODEL_ENGINE_VERSION_TFLITE_2_8);
  61. prediction_model.mutable_model()->set_download_url(
  62. "https://example.com/model");
  63. return prediction_model;
  64. }
  65. std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse() {
  66. std::unique_ptr<proto::GetModelsResponse> get_models_response =
  67. std::make_unique<proto::GetModelsResponse>();
  68. proto::PredictionModel prediction_model = CreatePredictionModel();
  69. prediction_model.mutable_model_info()->set_version(2);
  70. *get_models_response->add_models() = std::move(prediction_model);
  71. return get_models_response;
  72. }
  73. class FakeOptimizationTargetModelObserver
  74. : public OptimizationTargetModelObserver {
  75. public:
  76. void OnModelUpdated(proto::OptimizationTarget optimization_target,
  77. const ModelInfo& model_info) override {
  78. last_received_models_.insert_or_assign(optimization_target, model_info);
  79. }
  80. absl::optional<ModelInfo> last_received_model_for_target(
  81. proto::OptimizationTarget optimization_target) const {
  82. auto model_it = last_received_models_.find(optimization_target);
  83. if (model_it == last_received_models_.end())
  84. return absl::nullopt;
  85. return model_it->second;
  86. }
  87. // Resets the state of the observer.
  88. void Reset() { last_received_models_.clear(); }
  89. private:
  90. base::flat_map<proto::OptimizationTarget, ModelInfo> last_received_models_;
  91. };
  92. class FakePredictionModelDownloadManager
  93. : public PredictionModelDownloadManager {
  94. public:
  95. explicit FakePredictionModelDownloadManager(
  96. const base::FilePath& models_dir_path,
  97. scoped_refptr<base::SequencedTaskRunner> task_runner)
  98. : PredictionModelDownloadManager(/*download_service=*/nullptr,
  99. models_dir_path,
  100. task_runner) {}
  101. ~FakePredictionModelDownloadManager() override = default;
  102. void StartDownload(const GURL& url,
  103. proto::OptimizationTarget optimization_target) override {
  104. last_requested_download_ = url;
  105. last_requested_optimization_target_ = optimization_target;
  106. }
  107. GURL last_requested_download() const { return last_requested_download_; }
  108. proto::OptimizationTarget last_requested_optimization_target() const {
  109. return last_requested_optimization_target_;
  110. }
  111. void CancelAllPendingDownloads() override { cancel_downloads_called_ = true; }
  112. bool cancel_downloads_called() const { return cancel_downloads_called_; }
  113. bool IsAvailableForDownloads() const override { return is_available_; }
  114. void SetAvailableForDownloads(bool is_available) {
  115. is_available_ = is_available;
  116. }
  117. private:
  118. GURL last_requested_download_;
  119. proto::OptimizationTarget last_requested_optimization_target_;
  120. bool cancel_downloads_called_ = false;
  121. bool is_available_ = true;
  122. };
  // Outcome that TestPredictionModelFetcher simulates for each fetch.
  enum class PredictionModelFetcherEndState : int {
    // The fetch fails; the callback receives no response.
    kFetchFailed = 0,
    // The fetch succeeds with the response from BuildGetModelsResponse().
    kFetchSuccessWithModels = 1,
    // The fetch succeeds with an empty GetModelsResponse.
    kFetchSuccessWithEmptyResponse = 2,
  };
  128. void RunGetModelsCallback(
  129. ModelsFetchedCallback callback,
  130. std::unique_ptr<proto::GetModelsResponse> get_models_response) {
  131. if (get_models_response) {
  132. std::move(callback).Run(std::move(get_models_response));
  133. return;
  134. }
  135. std::move(callback).Run(absl::nullopt);
  136. }
  137. // A mock class implementation of PredictionModelFetcherImpl.
  138. class TestPredictionModelFetcher : public PredictionModelFetcherImpl {
  139. public:
  140. TestPredictionModelFetcher(
  141. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  142. const GURL& optimization_guide_service_get_models_url,
  143. PredictionModelFetcherEndState fetch_state)
  144. : PredictionModelFetcherImpl(url_loader_factory,
  145. optimization_guide_service_get_models_url),
  146. fetch_state_(fetch_state) {}
  147. bool FetchOptimizationGuideServiceModels(
  148. const std::vector<proto::ModelInfo>& models_request_info,
  149. proto::RequestContext request_context,
  150. const std::string& locale,
  151. ModelsFetchedCallback models_fetched_callback) override {
  152. if (!ValidateModelsInfoForFetch(models_request_info)) {
  153. std::move(models_fetched_callback).Run(absl::nullopt);
  154. return false;
  155. }
  156. std::unique_ptr<proto::GetModelsResponse> get_models_response;
  157. locale_requested_ = locale;
  158. switch (fetch_state_) {
  159. case PredictionModelFetcherEndState::kFetchFailed:
  160. get_models_response = nullptr;
  161. break;
  162. case PredictionModelFetcherEndState::kFetchSuccessWithModels:
  163. models_fetched_ = true;
  164. get_models_response = BuildGetModelsResponse();
  165. break;
  166. case PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse:
  167. models_fetched_ = true;
  168. get_models_response = std::make_unique<proto::GetModelsResponse>();
  169. break;
  170. }
  171. base::ThreadTaskRunnerHandle::Get()->PostTask(
  172. FROM_HERE, base::BindOnce(&RunGetModelsCallback,
  173. std::move(models_fetched_callback),
  174. std::move(get_models_response)));
  175. return true;
  176. }
  177. bool ValidateModelsInfoForFetch(
  178. const std::vector<proto::ModelInfo>& models_request_info) {
  179. for (const auto& model_info : models_request_info) {
  180. if (model_info.supported_model_engine_versions_size() == 0 ||
  181. !proto::ModelEngineVersion_IsValid(
  182. model_info.supported_model_engine_versions(0))) {
  183. return false;
  184. }
  185. if (!model_info.has_optimization_target() ||
  186. !proto::OptimizationTarget_IsValid(
  187. model_info.optimization_target())) {
  188. return false;
  189. }
  190. if (check_expected_version_) {
  191. auto version_it =
  192. expected_version_.find(model_info.optimization_target());
  193. if (model_info.has_version() !=
  194. (version_it != expected_version_.end())) {
  195. return false;
  196. }
  197. if (model_info.has_version() &&
  198. model_info.version() != version_it->second) {
  199. return false;
  200. }
  201. }
  202. auto it = expected_metadata_.find(model_info.optimization_target());
  203. if (model_info.has_model_metadata() != (it != expected_metadata_.end()))
  204. return false;
  205. if (model_info.has_model_metadata()) {
  206. proto::Any expected_metadata = it->second;
  207. if (model_info.model_metadata().type_url() !=
  208. expected_metadata.type_url()) {
  209. return false;
  210. }
  211. if (model_info.model_metadata().value() != expected_metadata.value())
  212. return false;
  213. }
  214. }
  215. return true;
  216. }
  217. void SetExpectedModelMetadataForOptimizationTarget(
  218. proto::OptimizationTarget optimization_target,
  219. const proto::Any& model_metadata) {
  220. expected_metadata_[optimization_target] = model_metadata;
  221. }
  222. void SetExpectedVersionForOptimizationTarget(
  223. proto::OptimizationTarget optimization_target,
  224. int64_t version) {
  225. expected_version_[optimization_target] = version;
  226. }
  227. void SetCheckExpectedVersion() { check_expected_version_ = true; }
  228. void Reset() { models_fetched_ = false; }
  229. bool models_fetched() const { return models_fetched_; }
  230. std::string locale_requested() const { return locale_requested_; }
  231. private:
  232. bool models_fetched_ = false;
  233. bool check_expected_version_ = false;
  234. std::string locale_requested_;
  235. // The desired behavior of the TestPredictionModelFetcher.
  236. PredictionModelFetcherEndState fetch_state_;
  237. base::flat_map<proto::OptimizationTarget, proto::Any> expected_metadata_;
  238. base::flat_map<proto::OptimizationTarget, int64_t> expected_version_;
  239. };
  240. class TestOptimizationGuideStore : public OptimizationGuideStore {
  241. public:
  242. TestOptimizationGuideStore(
  243. std::unique_ptr<StoreEntryProtoDatabase> database,
  244. scoped_refptr<base::SequencedTaskRunner> store_task_runner)
  245. : OptimizationGuideStore(std::move(database),
  246. store_task_runner,
  247. nullptr) {}
  248. ~TestOptimizationGuideStore() override = default;
  249. void Initialize(bool purge_existing_data,
  250. base::OnceClosure callback) override {
  251. init_callback_ = std::move(callback);
  252. status_ = Status::kAvailable;
  253. }
  254. void RunInitCallback(bool load_models = true,
  255. bool have_models_in_store = true) {
  256. load_models_ = load_models;
  257. have_models_in_store_ = have_models_in_store;
  258. std::move(init_callback_).Run();
  259. }
  260. void LoadPredictionModel(const EntryKey& prediction_model_entry_key,
  261. PredictionModelLoadedCallback callback) override {
  262. model_loaded_ = true;
  263. if (load_models_) {
  264. std::move(callback).Run(
  265. std::make_unique<proto::PredictionModel>(CreatePredictionModel()));
  266. } else {
  267. std::move(callback).Run(nullptr);
  268. }
  269. }
  270. bool FindPredictionModelEntryKey(
  271. proto::OptimizationTarget optimization_target,
  272. OptimizationGuideStore::EntryKey* out_prediction_model_entry_key)
  273. override {
  274. if (optimization_target ==
  275. proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) {
  276. return false;
  277. }
  278. if (have_models_in_store_) {
  279. *out_prediction_model_entry_key =
  280. "4_" + base::NumberToString(static_cast<int>(optimization_target));
  281. return true;
  282. }
  283. return false;
  284. }
  285. void UpdatePredictionModels(
  286. std::unique_ptr<StoreUpdateData> prediction_models_update_data,
  287. base::OnceClosure callback) override {
  288. std::move(callback).Run();
  289. }
  290. bool WasModelLoaded() const { return model_loaded_; }
  291. private:
  292. base::OnceClosure init_callback_;
  293. bool model_loaded_ = false;
  294. bool load_models_ = true;
  295. bool have_models_in_store_ = true;
  296. };
  297. class TestPredictionManager : public PredictionManager {
  298. public:
  299. TestPredictionManager(
  300. base::WeakPtr<OptimizationGuideStore> model_and_features_store,
  301. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  302. PrefService* pref_service,
  303. ComponentUpdatesEnabledProvider component_updates_enabled_provider,
  304. bool off_the_record,
  305. const std::string& application_locale,
  306. const base::FilePath& models_dir_path)
  307. : PredictionManager(
  308. model_and_features_store,
  309. url_loader_factory,
  310. pref_service,
  311. off_the_record,
  312. application_locale,
  313. models_dir_path,
  314. &optimization_guide_logger_,
  315. /*background_download_service_provider=*/
  316. base::OnceCallback<download::BackgroundDownloadService*()>(),
  317. component_updates_enabled_provider) {}
  318. ~TestPredictionManager() override = default;
  319. private:
  320. OptimizationGuideLogger optimization_guide_logger_;
  321. };
  322. class PredictionManagerTestBase : public ProtoDatabaseProviderTestBase {
  323. public:
  324. using StoreEntry = proto::StoreEntry;
  325. using StoreEntryMap = std::map<OptimizationGuideStore::EntryKey, StoreEntry>;
  326. PredictionManagerTestBase() = default;
  327. ~PredictionManagerTestBase() override = default;
  328. PredictionManagerTestBase(const PredictionManagerTestBase&) = delete;
  329. PredictionManagerTestBase& operator=(const PredictionManagerTestBase&) =
  330. delete;
  331. void SetUp() override {
  332. ProtoDatabaseProviderTestBase::SetUp();
  333. pref_service_ = std::make_unique<TestingPrefServiceSimple>();
  334. prefs::RegisterProfilePrefs(pref_service_->registry());
  335. url_loader_factory_ =
  336. base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
  337. &test_url_loader_factory_);
  338. base::CommandLine::ForCurrentProcess()->AppendSwitch(
  339. switches::kDisableCheckingUserPermissionsForTesting);
  340. }
  341. void CreatePredictionManager() {
  342. if (prediction_manager_) {
  343. db_store_.clear();
  344. model_and_features_store_.reset();
  345. prediction_manager_.reset();
  346. }
  347. model_and_features_store_ = CreateModelAndHostModelFeaturesStore();
  348. prediction_manager_ = std::make_unique<TestPredictionManager>(
  349. model_and_features_store_->AsWeakPtr(), url_loader_factory_,
  350. pref_service_.get(),
  351. base::BindRepeating(
  352. &PredictionManagerTestBase::AreComponentUpdatesEnabled,
  353. base::Unretained(this)),
  354. false, "en-US", temp_dir());
  355. prediction_manager_->SetClockForTesting(task_environment_.GetMockClock());
  356. }
  357. std::unique_ptr<TestOptimizationGuideStore>
  358. CreateModelAndHostModelFeaturesStore() {
  359. // Setup the fake db and the class under test.
  360. auto db = std::make_unique<FakeDB<StoreEntry>>(&db_store_);
  361. return std::make_unique<TestOptimizationGuideStore>(
  362. std::move(db), task_environment_.GetMainThreadTaskRunner());
  363. }
  364. TestPredictionManager* prediction_manager() const {
  365. return prediction_manager_.get();
  366. }
  367. void TearDown() override { ProtoDatabaseProviderTestBase::TearDown(); }
  368. std::unique_ptr<TestPredictionModelFetcher> BuildTestPredictionModelFetcher(
  369. PredictionModelFetcherEndState end_state) {
  370. std::unique_ptr<TestPredictionModelFetcher> prediction_model_fetcher =
  371. std::make_unique<TestPredictionModelFetcher>(
  372. url_loader_factory_, GURL("https://hintsserver.com"), end_state);
  373. return prediction_model_fetcher;
  374. }
  375. void SetStoreInitialized(bool load_models = true,
  376. bool have_models_in_store = true) {
  377. models_and_features_store()->RunInitCallback(load_models,
  378. have_models_in_store);
  379. RunUntilIdle();
  380. // Move clock forward for any short delays added for the fetcher, until the
  381. // startup fetch could start.
  382. MoveClockForwardBy(base::Seconds(12));
  383. }
  384. void MoveClockForwardBy(base::TimeDelta time_delta) {
  385. task_environment_.FastForwardBy(time_delta);
  386. RunUntilIdle();
  387. }
  388. TestPredictionModelFetcher* prediction_model_fetcher() const {
  389. return static_cast<TestPredictionModelFetcher*>(
  390. prediction_manager()->prediction_model_fetcher());
  391. }
  392. FakePredictionModelDownloadManager* prediction_model_download_manager()
  393. const {
  394. return static_cast<FakePredictionModelDownloadManager*>(
  395. temp_dir(), prediction_manager()->prediction_model_download_manager());
  396. }
  397. TestOptimizationGuideStore* models_and_features_store() const {
  398. base::WeakPtr<OptimizationGuideStore> store =
  399. prediction_manager()->model_and_features_store();
  400. DCHECK(store);
  401. return static_cast<TestOptimizationGuideStore*>(store.get());
  402. }
  403. base::FilePath temp_dir() const { return temp_dir_.GetPath(); }
  404. TestingPrefServiceSimple* pref_service() const { return pref_service_.get(); }
  405. void RunUntilIdle() {
  406. task_environment_.RunUntilIdle();
  407. base::RunLoop().RunUntilIdle();
  408. }
  409. base::test::TaskEnvironment* task_environment() { return &task_environment_; }
  410. void SetComponentUpdatesPrefEnabled(bool enabled) {
  411. component_updates_enabled_ = enabled;
  412. }
  413. bool AreComponentUpdatesEnabled() const { return component_updates_enabled_; }
  414. protected:
  415. // |feature_list_| needs to be destroyed after |task_environment_|, to avoid
  416. // tsan flakes caused by other tasks running while |feature_list_| is
  417. // destroyed.
  418. base::test::ScopedFeatureList feature_list_;
  419. private:
  420. base::test::TaskEnvironment task_environment_{
  421. base::test::TaskEnvironment::MainThreadType::UI,
  422. base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  423. StoreEntryMap db_store_;
  424. std::unique_ptr<TestOptimizationGuideStore> model_and_features_store_;
  425. std::unique_ptr<TestPredictionManager> prediction_manager_;
  426. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
  427. network::TestURLLoaderFactory test_url_loader_factory_;
  428. std::unique_ptr<TestingPrefServiceSimple> pref_service_;
  429. std::unique_ptr<TestingPrefServiceSimple> local_state_prefs_;
  430. bool component_updates_enabled_ = true;
  431. };
  432. class PredictionManagerRemoteFetchingDisabledTest
  433. : public PredictionManagerTestBase {
  434. public:
  435. PredictionManagerRemoteFetchingDisabledTest() {
  436. // This needs to be done before any tasks are run that might check if a
  437. // feature is enabled, to avoid tsan errors.
  438. feature_list_.InitAndDisableFeature(
  439. features::kRemoteOptimizationGuideFetching);
  440. }
  441. };
  442. TEST_F(PredictionManagerRemoteFetchingDisabledTest, RemoteFetchingDisabled) {
  443. CreatePredictionManager();
  444. prediction_manager()->SetPredictionModelFetcherForTesting(
  445. BuildTestPredictionModelFetcher(
  446. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  447. FakeOptimizationTargetModelObserver observer;
  448. prediction_manager()->AddObserverForOptimizationTargetModel(
  449. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  450. SetStoreInitialized();
  451. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  452. }
  453. class PredictionManagerModelDownloadingDisabledTest
  454. : public PredictionManagerTestBase {
  455. public:
  456. PredictionManagerModelDownloadingDisabledTest() {
  457. // This needs to be done before any tasks are run that might check if a
  458. // feature is enabled, to avoid tsan errors.
  459. feature_list_.InitAndDisableFeature(
  460. features::kOptimizationGuideModelDownloading);
  461. }
  462. };
  463. TEST_F(PredictionManagerModelDownloadingDisabledTest,
  464. ModelDownloadingDisabledShouldNotFetch) {
  465. CreatePredictionManager();
  466. prediction_manager()->SetPredictionModelFetcherForTesting(
  467. BuildTestPredictionModelFetcher(
  468. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  469. FakeOptimizationTargetModelObserver observer;
  470. prediction_manager()->AddObserverForOptimizationTargetModel(
  471. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  472. SetStoreInitialized();
  473. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  474. }
  475. class PredictionManagerTest : public PredictionManagerTestBase {
  476. public:
  477. PredictionManagerTest() {
  478. // This needs to be done before any tasks are run that might check if a
  479. // feature is enabled, to avoid tsan errors.
  480. feature_list_.InitWithFeatures(
  481. {features::kRemoteOptimizationGuideFetching,
  482. features::kOptimizationGuideModelDownloading},
  483. {});
  484. }
  485. private:
  486. variations::ScopedVariationsIdsProvider scoped_variations_ids_provider_{
  487. variations::VariationsIdsProvider::Mode::kUseSignedInState};
  488. };
  489. TEST_F(PredictionManagerTest, RemoteFetchingPrefDisabled) {
  490. SetComponentUpdatesPrefEnabled(false);
  491. CreatePredictionManager();
  492. prediction_manager()->SetPredictionModelFetcherForTesting(
  493. BuildTestPredictionModelFetcher(
  494. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  495. FakeOptimizationTargetModelObserver observer;
  496. prediction_manager()->AddObserverForOptimizationTargetModel(
  497. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  498. SetStoreInitialized();
  499. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  500. }
  501. TEST_F(PredictionManagerTest, AddObserverForOptimizationTargetModel) {
  502. base::HistogramTester histogram_tester;
  503. CreatePredictionManager();
  504. prediction_manager()->SetPredictionModelFetcherForTesting(
  505. BuildTestPredictionModelFetcher(
  506. PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse));
  507. proto::Any model_metadata;
  508. model_metadata.set_type_url(
  509. "type.googleapis.com/"
  510. "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata");
  511. prediction_model_fetcher()->SetExpectedModelMetadataForOptimizationTarget(
  512. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata);
  513. histogram_tester.ExpectTotalCount(
  514. "OptimizationGuide.PredictionManager.RegistrationTimeSinceServiceInit."
  515. "PainfulPageLoad",
  516. 0);
  517. histogram_tester.ExpectTotalCount(
  518. "OptimizationGuide.PredictionManager.FirstModelFetchSinceServiceInit", 0);
  519. FakeOptimizationTargetModelObserver observer;
  520. prediction_manager()->AddObserverForOptimizationTargetModel(
  521. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata, &observer);
  522. SetStoreInitialized(/* load_models= */ false,
  523. /* have_models_in_store= */ false);
  524. histogram_tester.ExpectUniqueSample(
  525. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  526. "PainfulPageLoad",
  527. false, 1);
  528. histogram_tester.ExpectTotalCount(
  529. "OptimizationGuide.PredictionManager.RegistrationTimeSinceServiceInit."
  530. "PainfulPageLoad",
  531. 1);
  532. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  533. // Make sure the test histogram is recorded. We don't check for value here
  534. // since that is too much toil for someone whenever they add a new version.
  535. histogram_tester.ExpectTotalCount(
  536. "OptimizationGuide.PredictionManager.SupportedModelEngineVersion", 1);
  537. histogram_tester.ExpectTotalCount(
  538. "OptimizationGuide.PredictionManager.FirstModelFetchSinceServiceInit", 1);
  539. EXPECT_TRUE(prediction_manager()->GetRegisteredOptimizationTargets().contains(
  540. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
  541. EXPECT_FALSE(observer
  542. .last_received_model_for_target(
  543. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  544. .has_value());
  545. base::FilePath additional_file_path =
  546. temp_dir().AppendASCII("whatever").AppendASCII("additional_file.txt");
  547. proto::ModelInfo model_info;
  548. model_info.set_optimization_target(
  549. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  550. model_info.set_version(1);
  551. model_info.mutable_model_metadata()->set_type_url("sometypeurl");
  552. model_info.add_additional_files()->set_file_path(
  553. FilePathToString(additional_file_path));
  554. // An empty file path should be be ignored.
  555. model_info.add_additional_files()->set_file_path("");
  556. // Ensure observer is hooked up.
  557. {
  558. base::HistogramTester model_ready_histogram_tester;
  559. proto::PredictionModel model1;
  560. *model1.mutable_model_info() = model_info;
  561. model1.mutable_model()->set_download_url(
  562. FilePathToString(temp_dir().AppendASCII("whatever")));
  563. prediction_manager()->OnModelReady(model1);
  564. RunUntilIdle();
  565. absl::optional<ModelInfo> received_model =
  566. observer.last_received_model_for_target(
  567. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  568. EXPECT_EQ(received_model->GetModelMetadata()->type_url(), "sometypeurl");
  569. EXPECT_EQ(received_model->GetModelFilePath().BaseName().value(),
  570. FILE_PATH_LITERAL("whatever"));
  571. EXPECT_EQ(received_model->GetAdditionalFiles(),
  572. base::flat_set<base::FilePath>{additional_file_path});
  573. // Make sure we do not record the model available histogram again.
  574. model_ready_histogram_tester.ExpectTotalCount(
  575. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  576. "PainfulPageLoad",
  577. 0);
  578. }
  579. // Reset fetcher and make sure version is sent in the new request and not
  580. // counted as re-loaded or updated.
  581. {
  582. base::HistogramTester histogram_tester2;
  583. prediction_model_fetcher()->Reset();
  584. prediction_model_fetcher()->SetCheckExpectedVersion();
  585. prediction_model_fetcher()->SetExpectedVersionForOptimizationTarget(
  586. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 1);
  587. MoveClockForwardBy(base::Seconds(kUpdateFetchModelAndFeaturesTimeSecs));
  588. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  589. histogram_tester2.ExpectTotalCount(
  590. "OptimizationGuide.PredictionModelUpdateVersion.PainfulPageLoad", 0);
  591. histogram_tester2.ExpectTotalCount(
  592. "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0);
  593. histogram_tester2.ExpectTotalCount(
  594. "OptimizationGuide.PredictionModelRemoved.PainfulPageLoad", 0);
  595. }
  596. // Now remove and reset observer.
  597. prediction_manager()->RemoveObserverForOptimizationTargetModel(
  598. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &observer);
  599. observer.Reset();
  600. proto::PredictionModel model2;
  601. *model2.mutable_model_info() = model_info;
  602. model2.mutable_model_info()->set_version(2);
  603. model2.mutable_model()->set_download_url(
  604. FilePathToString(temp_dir().AppendASCII("whatever2")));
  605. prediction_manager()->OnModelReady(model2);
  606. RunUntilIdle();
  607. // Last received path should not have been updated since the observer was
  608. // removed.
  609. EXPECT_FALSE(observer
  610. .last_received_model_for_target(
  611. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  612. .has_value());
  613. }
  614. TEST_F(PredictionManagerTest,
  615. AddObserverForOptimizationTargetModelAddAnotherObserverForSameTarget) {
  616. // Fails under "threadsafe" mode.
  617. testing::GTEST_FLAG(death_test_style) = "fast";
  618. CreatePredictionManager();
  619. FakeOptimizationTargetModelObserver observer1;
  620. prediction_manager()->AddObserverForOptimizationTargetModel(
  621. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
  622. /*model_metadata=*/absl::nullopt, &observer1);
  623. SetStoreInitialized(/* load_models= */ false,
  624. /* have_models_in_store= */ false);
  625. proto::ModelInfo model_info;
  626. model_info.set_optimization_target(
  627. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  628. model_info.set_version(1);
  629. // Ensure observer is hooked up.
  630. proto::PredictionModel model1;
  631. *model1.mutable_model_info() = model_info;
  632. model1.mutable_model()->set_download_url(
  633. FilePathToString(temp_dir().AppendASCII("whatever")));
  634. prediction_manager()->OnModelReady(model1);
  635. RunUntilIdle();
  636. EXPECT_EQ(observer1
  637. .last_received_model_for_target(
  638. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  639. ->GetModelFilePath()
  640. .BaseName()
  641. .value(),
  642. FILE_PATH_LITERAL("whatever"));
  643. #if !BUILDFLAG(IS_WIN)
  644. // Do not run the DCHECK death test on Windows since there's some weird
  645. // behavior there.
  646. // Now, register a new observer - it should die.
  647. FakeOptimizationTargetModelObserver observer2;
  648. EXPECT_DCHECK_DEATH(
  649. prediction_manager()->AddObserverForOptimizationTargetModel(
  650. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
  651. /*model_metadata=*/absl::nullopt, &observer2));
  652. RunUntilIdle();
  653. #endif
  654. }
  655. // See crbug/1227996.
  656. #if !BUILDFLAG(IS_WIN)
  657. TEST_F(PredictionManagerTest,
  658. AddObserverForOptimizationTargetModelCommandLineOverride) {
  659. base::HistogramTester histogram_tester;
  660. optimization_guide::proto::Any metadata;
  661. metadata.set_type_url(
  662. "type.googleapis.com/"
  663. "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata");
  664. std::string encoded_metadata;
  665. metadata.SerializeToString(&encoded_metadata);
  666. base::Base64Encode(encoded_metadata, &encoded_metadata);
  667. base::CommandLine::ForCurrentProcess()->AppendSwitchASCII(
  668. switches::kModelOverride,
  669. base::StringPrintf("OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:%s:%s",
  670. kTestAbsoluteFilePath, encoded_metadata.c_str()));
  671. CreatePredictionManager();
  672. prediction_manager()->SetPredictionModelFetcherForTesting(
  673. BuildTestPredictionModelFetcher(
  674. PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse));
  675. proto::Any model_metadata;
  676. model_metadata.set_type_url(
  677. "type.googleapis.com/"
  678. "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata");
  679. prediction_model_fetcher()->SetExpectedModelMetadataForOptimizationTarget(
  680. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata);
  681. FakeOptimizationTargetModelObserver observer;
  682. prediction_manager()->AddObserverForOptimizationTargetModel(
  683. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata, &observer);
  684. SetStoreInitialized(/* load_models= */ false,
  685. /* have_models_in_store= */ false);
  686. // Make sure no models are fetched.
  687. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  688. // However, expect that the histogram for model engine version is recorded.
  689. // We don't check for value here since that is too much toil for someone
  690. // whenever they add a new version.
  691. histogram_tester.ExpectTotalCount(
  692. "OptimizationGuide.PredictionManager.SupportedModelEngineVersion", 1);
  693. EXPECT_TRUE(prediction_manager()->GetRegisteredOptimizationTargets().contains(
  694. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD));
  695. EXPECT_EQ(
  696. observer
  697. .last_received_model_for_target(
  698. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  699. ->GetModelMetadata()
  700. ->type_url(),
  701. "type.googleapis.com/"
  702. "google.internal.chrome.optimizationguide.v1.PageEntitiesModelMetadata");
  703. EXPECT_EQ(observer
  704. .last_received_model_for_target(
  705. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  706. ->GetModelFilePath()
  707. .value(),
  708. FILE_PATH_LITERAL(kTestAbsoluteFilePath));
  709. // Now reset observer. New model downloads should not update the observer.
  710. observer.Reset();
  711. proto::PredictionModel model;
  712. model.mutable_model_info()->set_optimization_target(
  713. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  714. model.mutable_model_info()->set_version(1);
  715. model.mutable_model()->set_download_url(
  716. FilePathToString(temp_dir().AppendASCII("whatever2")));
  717. prediction_manager()->OnModelReady(model);
  718. RunUntilIdle();
  719. // Last received path should not have been updated since the observer was
  720. // reset and override is in place.
  721. EXPECT_FALSE(observer
  722. .last_received_model_for_target(
  723. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  724. .has_value());
  725. }
  726. #endif
  727. TEST_F(PredictionManagerTest,
  728. NoPredictionModelForRegisteredOptimizationTarget) {
  729. base::HistogramTester histogram_tester;
  730. CreatePredictionManager();
  731. SetStoreInitialized(/*load_models=*/false, /*have_models_in_store=*/false);
  732. FakeOptimizationTargetModelObserver observer;
  733. prediction_manager()->AddObserverForOptimizationTargetModel(
  734. proto::OPTIMIZATION_TARGET_MODEL_VALIDATION, absl::nullopt, &observer);
  735. histogram_tester.ExpectUniqueSample(
  736. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  737. "ModelValidation",
  738. false, 1);
  739. }
  740. TEST_F(PredictionManagerTest, UpdatePredictionModelsWithInvalidModel) {
  741. base::HistogramTester histogram_tester;
  742. CreatePredictionManager();
  743. FakeOptimizationTargetModelObserver observer;
  744. prediction_manager()->AddObserverForOptimizationTargetModel(
  745. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
  746. /*model_metadata=*/absl::nullopt, &observer);
  747. // Set invalid model with no download url.
  748. proto::PredictionModel model;
  749. model.mutable_model_info()->set_optimization_target(
  750. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  751. model.mutable_model_info()->set_version(3);
  752. model.mutable_model();
  753. prediction_manager()->OnModelReady(model);
  754. RunUntilIdle();
  755. histogram_tester.ExpectBucketCount("OptimizationGuide.IsPredictionModelValid",
  756. false, 1);
  757. histogram_tester.ExpectTotalCount(
  758. "OptimizationGuide.PredictionModelValidationLatency", 0);
  759. histogram_tester.ExpectTotalCount(
  760. "OptimizationGuide.PredictionModelUpdateVersion.PainfulPageLoad", 1);
  761. histogram_tester.ExpectTotalCount(
  762. "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0);
  763. histogram_tester.ExpectUniqueSample(
  764. "OptimizationGuide.PredictionModelRemoved.PainfulPageLoad", true, 1);
  765. }
  766. TEST_F(PredictionManagerTest, UpdateModelFileWithSameVersion) {
  767. base::HistogramTester histogram_tester;
  768. CreatePredictionManager();
  769. FakeOptimizationTargetModelObserver observer;
  770. prediction_manager()->AddObserverForOptimizationTargetModel(
  771. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD,
  772. /*model_metadata=*/absl::nullopt, &observer);
  773. proto::PredictionModel model;
  774. model.mutable_model_info()->set_optimization_target(
  775. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  776. model.mutable_model_info()->set_version(3);
  777. model.mutable_model()->set_download_url(
  778. FilePathToString(temp_dir().AppendASCII("whatever2")));
  779. prediction_manager()->OnModelReady(model);
  780. RunUntilIdle();
  781. EXPECT_TRUE(observer
  782. .last_received_model_for_target(
  783. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  784. .has_value());
  785. // Now reset the observer state.
  786. observer.Reset();
  787. // Send the same model again.
  788. prediction_manager()->OnModelReady(model);
  789. // The observer should not have received an update.
  790. EXPECT_FALSE(observer
  791. .last_received_model_for_target(
  792. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)
  793. .has_value());
  794. }
  795. TEST_F(PredictionManagerTest, DownloadManagerUnavailableShouldNotFetch) {
  796. base::HistogramTester histogram_tester;
  797. CreatePredictionManager();
  798. prediction_manager()->SetPredictionModelFetcherForTesting(
  799. BuildTestPredictionModelFetcher(
  800. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  801. prediction_manager()->SetPredictionModelDownloadManagerForTesting(
  802. std::make_unique<FakePredictionModelDownloadManager>(
  803. temp_dir(), task_environment()->GetMainThreadTaskRunner()));
  804. prediction_model_download_manager()->SetAvailableForDownloads(false);
  805. FakeOptimizationTargetModelObserver observer;
  806. prediction_manager()->AddObserverForOptimizationTargetModel(
  807. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  808. SetStoreInitialized(/*load_models=*/true, /*have_models_in_store=*/false);
  809. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  810. histogram_tester.ExpectUniqueSample(
  811. "OptimizationGuide.PredictionManager."
  812. "DownloadServiceAvailabilityBlockedFetch",
  813. true, 1);
  814. histogram_tester.ExpectUniqueSample(
  815. "OptimizationGuide.PredictionManager.ModelDeliveryEvents.PainfulPageLoad",
  816. ModelDeliveryEvent::kDownloadServiceUnavailable, 1);
  817. }
  818. TEST_F(PredictionManagerTest, UpdateModelWithDownloadUrl) {
  819. base::HistogramTester histogram_tester;
  820. CreatePredictionManager();
  821. prediction_manager()->SetPredictionModelFetcherForTesting(
  822. BuildTestPredictionModelFetcher(
  823. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  824. prediction_manager()->SetPredictionModelDownloadManagerForTesting(
  825. std::make_unique<FakePredictionModelDownloadManager>(
  826. temp_dir(), task_environment()->GetMainThreadTaskRunner()));
  827. FakeOptimizationTargetModelObserver observer;
  828. prediction_manager()->AddObserverForOptimizationTargetModel(
  829. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  830. SetStoreInitialized();
  831. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  832. EXPECT_TRUE(prediction_model_download_manager()->cancel_downloads_called());
  833. histogram_tester.ExpectTotalCount(
  834. "OptimizationGuide.PredictionManager.PredictionModelsStored", 0);
  835. histogram_tester.ExpectUniqueSample(
  836. "OptimizationGuide.PredictionManager."
  837. "DownloadServiceAvailabilityBlockedFetch",
  838. false, 1);
  839. histogram_tester.ExpectUniqueSample(
  840. "OptimizationGuide.PredictionManager.IsDownloadUrlValid.PainfulPageLoad",
  841. true, 1);
  842. EXPECT_EQ(prediction_model_download_manager()->last_requested_download(),
  843. GURL("https://example.com/model"));
  844. EXPECT_EQ(
  845. prediction_model_download_manager()->last_requested_optimization_target(),
  846. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  847. }
  848. TEST_F(PredictionManagerTest, UpdateModelForUnregisteredTargetOnModelReady) {
  849. base::HistogramTester histogram_tester;
  850. CreatePredictionManager();
  851. SetStoreInitialized();
  852. proto::PredictionModel model;
  853. model.mutable_model_info()->set_optimization_target(
  854. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD);
  855. model.mutable_model_info()->set_version(3);
  856. model.mutable_model()->set_download_url(
  857. FilePathToString(temp_dir().AppendASCII("whatever")));
  858. prediction_manager()->OnModelReady(model);
  859. histogram_tester.ExpectTotalCount(
  860. "OptimizationGuide.PredictionManager.PredictionModelsStored", 1);
  861. histogram_tester.ExpectTotalCount(
  862. "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0);
  863. // Now register the model.
  864. FakeOptimizationTargetModelObserver observer;
  865. prediction_manager()->AddObserverForOptimizationTargetModel(
  866. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  867. RunUntilIdle();
  868. histogram_tester.ExpectUniqueSample(
  869. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  870. "PainfulPageLoad",
  871. true, 1);
  872. histogram_tester.ExpectTotalCount(
  873. "OptimizationGuide.PredictionManager.ModelDeliveryEvents.PainfulPageLoad",
  874. 2);
  875. histogram_tester.ExpectBucketCount(
  876. "OptimizationGuide.PredictionManager.ModelDeliveryEvents.PainfulPageLoad",
  877. ModelDeliveryEvent::kModelDownloaded, 1);
  878. histogram_tester.ExpectBucketCount(
  879. "OptimizationGuide.PredictionManager.ModelDeliveryEvents.PainfulPageLoad",
  880. ModelDeliveryEvent::kModelDelivered, 1);
  881. }
  882. TEST_F(PredictionManagerTest,
  883. StoreInitializedAfterOptimizationTargetRegistered) {
  884. base::HistogramTester histogram_tester;
  885. CreatePredictionManager();
  886. // Ensure that the fetch does not cause any models or features to load.
  887. prediction_manager()->SetPredictionModelFetcherForTesting(
  888. BuildTestPredictionModelFetcher(
  889. PredictionModelFetcherEndState::kFetchFailed));
  890. FakeOptimizationTargetModelObserver observer;
  891. prediction_manager()->AddObserverForOptimizationTargetModel(
  892. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  893. EXPECT_FALSE(models_and_features_store()->WasModelLoaded());
  894. SetStoreInitialized();
  895. EXPECT_TRUE(models_and_features_store()->WasModelLoaded());
  896. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  897. histogram_tester.ExpectUniqueSample(
  898. "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 1, 1);
  899. histogram_tester.ExpectUniqueSample(
  900. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  901. "PainfulPageLoad",
  902. true, 1);
  903. }
  904. TEST_F(PredictionManagerTest,
  905. StoreInitializedBeforeOptimizationTargetRegistered) {
  906. base::HistogramTester histogram_tester;
  907. CreatePredictionManager();
  908. // Ensure that the fetch does not cause any models or features to load.
  909. prediction_manager()->SetPredictionModelFetcherForTesting(
  910. BuildTestPredictionModelFetcher(
  911. PredictionModelFetcherEndState::kFetchFailed));
  912. SetStoreInitialized();
  913. EXPECT_FALSE(models_and_features_store()->WasModelLoaded());
  914. FakeOptimizationTargetModelObserver observer;
  915. prediction_manager()->AddObserverForOptimizationTargetModel(
  916. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  917. RunUntilIdle();
  918. EXPECT_TRUE(models_and_features_store()->WasModelLoaded());
  919. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  920. histogram_tester.ExpectUniqueSample(
  921. "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 1, 1);
  922. histogram_tester.ExpectUniqueSample(
  923. "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration."
  924. "PainfulPageLoad",
  925. true, 1);
  926. }
  927. TEST_F(PredictionManagerTest, ModelFetcherTimerRetryDelay) {
  928. CreatePredictionManager();
  929. prediction_manager()->SetPredictionModelFetcherForTesting(
  930. BuildTestPredictionModelFetcher(
  931. PredictionModelFetcherEndState::kFetchFailed));
  932. prediction_manager()->SetPredictionModelDownloadManagerForTesting(
  933. std::make_unique<FakePredictionModelDownloadManager>(
  934. temp_dir(), task_environment()->GetMainThreadTaskRunner()));
  935. FakeOptimizationTargetModelObserver observer;
  936. prediction_manager()->AddObserverForOptimizationTargetModel(
  937. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  938. SetStoreInitialized();
  939. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  940. MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs));
  941. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  942. prediction_manager()->SetPredictionModelFetcherForTesting(
  943. BuildTestPredictionModelFetcher(
  944. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  945. MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs));
  946. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  947. }
  948. TEST_F(PredictionManagerTest, ModelFetcherTimerFetchSucceeds) {
  949. CreatePredictionManager();
  950. prediction_manager()->SetPredictionModelFetcherForTesting(
  951. BuildTestPredictionModelFetcher(
  952. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  953. prediction_manager()->SetPredictionModelDownloadManagerForTesting(
  954. std::make_unique<FakePredictionModelDownloadManager>(
  955. temp_dir(), task_environment()->GetMainThreadTaskRunner()));
  956. FakeOptimizationTargetModelObserver observer;
  957. prediction_manager()->AddObserverForOptimizationTargetModel(
  958. proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt, &observer);
  959. SetStoreInitialized();
  960. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  961. EXPECT_EQ("en-US", prediction_model_fetcher()->locale_requested());
  962. // Reset the prediction model fetcher to detect when the next fetch occurs.
  963. prediction_manager()->SetPredictionModelFetcherForTesting(
  964. BuildTestPredictionModelFetcher(
  965. PredictionModelFetcherEndState::kFetchSuccessWithModels));
  966. MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs));
  967. EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
  968. MoveClockForwardBy(base::Seconds(kUpdateFetchModelAndFeaturesTimeSecs));
  969. EXPECT_TRUE(prediction_model_fetcher()->models_fetched());
  970. }
  971. } // namespace optimization_guide