page_topics_model_executor_unittest.cc 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "components/optimization_guide/core/page_topics_model_executor.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/callback_helpers.h"
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/files/file_util.h"
#include "base/files/scoped_temp_dir.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/task/thread_pool.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
#include "components/optimization_guide/core/page_entities_model_executor.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
#include "components/optimization_guide/proto/page_topics_override_list.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/zlib/google/compression_utils.h"
  24. namespace optimization_guide {
  25. class ModelObserverTracker : public TestOptimizationGuideModelProvider {
  26. public:
  27. void AddObserverForOptimizationTargetModel(
  28. proto::OptimizationTarget target,
  29. const absl::optional<proto::Any>& model_metadata,
  30. OptimizationTargetModelObserver* observer) override {
  31. registered_model_metadata_.insert_or_assign(target, model_metadata);
  32. }
  33. bool DidRegisterForTarget(
  34. proto::OptimizationTarget target,
  35. absl::optional<proto::Any>* out_model_metadata) const {
  36. auto it = registered_model_metadata_.find(target);
  37. if (it == registered_model_metadata_.end())
  38. return false;
  39. *out_model_metadata = registered_model_metadata_.at(target);
  40. return true;
  41. }
  42. private:
  43. base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>>
  44. registered_model_metadata_;
  45. };
  46. class TestPageTopicsModelExecutor : public PageTopicsModelExecutor {
  47. public:
  48. TestPageTopicsModelExecutor(
  49. OptimizationGuideModelProvider* model_provider,
  50. scoped_refptr<base::SequencedTaskRunner> background_task_runner,
  51. const absl::optional<proto::Any>& model_metadata)
  52. : PageTopicsModelExecutor(model_provider,
  53. background_task_runner,
  54. model_metadata) {}
  55. ~TestPageTopicsModelExecutor() override = default;
  56. void ExecuteModelWithInput(ExecutionCallback callback,
  57. const std::string& input) override {
  58. inputs_.push_back(input);
  59. std::move(callback).Run(absl::nullopt);
  60. }
  61. const std::vector<std::string>& inputs() const { return inputs_; }
  62. private:
  63. std::vector<std::string> inputs_;
  64. };
  65. class PageTopicsModelExecutorTest : public testing::Test {
  66. public:
  67. PageTopicsModelExecutorTest() {
  68. scoped_feature_list_.InitWithFeatures(
  69. {features::kPageContentAnnotations},
  70. {features::kPreventLongRunningPredictionModels});
  71. }
  72. ~PageTopicsModelExecutorTest() override = default;
  73. void SetUp() override {
  74. model_observer_tracker_ = std::make_unique<ModelObserverTracker>();
  75. model_executor_ = std::make_unique<TestPageTopicsModelExecutor>(
  76. model_observer_tracker_.get(),
  77. base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
  78. /*model_metadata=*/absl::nullopt);
  79. }
  80. void TearDown() override {
  81. model_executor_.reset();
  82. model_observer_tracker_.reset();
  83. RunUntilIdle();
  84. }
  85. void SendPageTopicsModelToExecutor(
  86. const absl::optional<proto::Any>& model_metadata) {
  87. base::FilePath source_root_dir;
  88. base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
  89. base::FilePath model_file_path =
  90. source_root_dir.AppendASCII("components")
  91. .AppendASCII("test")
  92. .AppendASCII("data")
  93. .AppendASCII("optimization_guide")
  94. .AppendASCII("bert_page_topics_model.tflite");
  95. std::unique_ptr<ModelInfo> model_info =
  96. TestModelInfoBuilder()
  97. .SetModelFilePath(model_file_path)
  98. .SetModelMetadata(model_metadata)
  99. .Build();
  100. model_executor()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
  101. *model_info);
  102. RunUntilIdle();
  103. }
  104. ModelObserverTracker* model_observer_tracker() const {
  105. return model_observer_tracker_.get();
  106. }
  107. TestPageTopicsModelExecutor* model_executor() const {
  108. return model_executor_.get();
  109. }
  110. void RunUntilIdle() { task_environment_.RunUntilIdle(); }
  111. private:
  112. base::test::TaskEnvironment task_environment_;
  113. base::test::ScopedFeatureList scoped_feature_list_;
  114. std::unique_ptr<ModelObserverTracker> model_observer_tracker_;
  115. std::unique_ptr<TestPageTopicsModelExecutor> model_executor_;
  116. };
  117. TEST_F(
  118. PageTopicsModelExecutorTest,
  119. GetContentModelAnnotationsFromOutputNonNumericAndLowWeightCategoriesPruned) {
  120. proto::PageTopicsModelMetadata model_metadata;
  121. model_metadata.set_version(123);
  122. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  123. ->mutable_category_params();
  124. category_params->set_max_categories(4);
  125. category_params->set_min_none_weight(0.8);
  126. category_params->set_min_category_weight(0.01);
  127. category_params->set_min_normalized_weight_within_top_n(0.1);
  128. proto::Any any_metadata;
  129. any_metadata.set_type_url(
  130. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  131. model_metadata.SerializeToString(any_metadata.mutable_value());
  132. SendPageTopicsModelToExecutor(any_metadata);
  133. std::vector<tflite::task::core::Category> model_output = {
  134. {"0", 0.0001}, {"1", 0.1}, {"not an int", 0.9}, {"2", 0.2}, {"3", 0.3},
  135. };
  136. absl::optional<std::vector<WeightedIdentifier>> categories =
  137. model_executor()->ExtractCategoriesFromModelOutput(model_output);
  138. ASSERT_TRUE(categories);
  139. EXPECT_THAT(*categories,
  140. testing::UnorderedElementsAre(WeightedIdentifier(1, 0.1),
  141. WeightedIdentifier(2, 0.2),
  142. WeightedIdentifier(3, 0.3)));
  143. }
  144. TEST_F(PageTopicsModelExecutorTest,
  145. GetContentModelAnnotationsFromOutputNoneWeightTooStrong) {
  146. proto::PageTopicsModelMetadata model_metadata;
  147. model_metadata.set_version(123);
  148. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  149. ->mutable_category_params();
  150. category_params->set_max_categories(4);
  151. category_params->set_min_none_weight(0.1);
  152. category_params->set_min_category_weight(0.01);
  153. category_params->set_min_normalized_weight_within_top_n(0.1);
  154. proto::Any any_metadata;
  155. any_metadata.set_type_url(
  156. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  157. model_metadata.SerializeToString(any_metadata.mutable_value());
  158. SendPageTopicsModelToExecutor(any_metadata);
  159. std::vector<tflite::task::core::Category> model_output = {
  160. {"-2", 0.9999},
  161. {"0", 0.3},
  162. {"1", 0.2},
  163. };
  164. absl::optional<std::vector<WeightedIdentifier>> categories =
  165. model_executor()->ExtractCategoriesFromModelOutput(model_output);
  166. EXPECT_FALSE(categories);
  167. }
  168. TEST_F(PageTopicsModelExecutorTest,
  169. GetContentModelAnnotationsFromOutputNoneInTopButNotStrongSoPruned) {
  170. proto::PageTopicsModelMetadata model_metadata;
  171. model_metadata.set_version(123);
  172. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  173. ->mutable_category_params();
  174. category_params->set_max_categories(4);
  175. category_params->set_min_none_weight(0.8);
  176. category_params->set_min_category_weight(0.01);
  177. category_params->set_min_normalized_weight_within_top_n(0.1);
  178. proto::Any any_metadata;
  179. any_metadata.set_type_url(
  180. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  181. model_metadata.SerializeToString(any_metadata.mutable_value());
  182. SendPageTopicsModelToExecutor(any_metadata);
  183. std::vector<tflite::task::core::Category> model_output = {
  184. {"-2", 0.1}, {"0", 0.3}, {"1", 0.2}, {"2", 0.4}, {"3", 0.05},
  185. };
  186. absl::optional<std::vector<WeightedIdentifier>> categories =
  187. model_executor()->ExtractCategoriesFromModelOutput(model_output);
  188. ASSERT_TRUE(categories);
  189. EXPECT_THAT(*categories,
  190. testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
  191. WeightedIdentifier(1, 0.2),
  192. WeightedIdentifier(2, 0.4)));
  193. }
  194. TEST_F(PageTopicsModelExecutorTest,
  195. GetContentModelAnnotationsFromOutputPrunedAfterNormalization) {
  196. proto::PageTopicsModelMetadata model_metadata;
  197. model_metadata.set_version(123);
  198. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  199. ->mutable_category_params();
  200. category_params->set_max_categories(4);
  201. category_params->set_min_none_weight(0.8);
  202. category_params->set_min_category_weight(0.01);
  203. category_params->set_min_normalized_weight_within_top_n(0.25);
  204. proto::Any any_metadata;
  205. any_metadata.set_type_url(
  206. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  207. model_metadata.SerializeToString(any_metadata.mutable_value());
  208. SendPageTopicsModelToExecutor(any_metadata);
  209. std::vector<tflite::task::core::Category> model_output = {
  210. {"0", 0.3},
  211. {"1", 0.25},
  212. {"2", 0.4},
  213. {"3", 0.05},
  214. };
  215. absl::optional<std::vector<WeightedIdentifier>> categories =
  216. model_executor()->ExtractCategoriesFromModelOutput(model_output);
  217. ASSERT_TRUE(categories);
  218. EXPECT_THAT(*categories,
  219. testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
  220. WeightedIdentifier(1, 0.25),
  221. WeightedIdentifier(2, 0.4)));
  222. }
  223. TEST_F(PageTopicsModelExecutorTest,
  224. PostprocessCategoriesToBatchAnnotationResult) {
  225. proto::PageTopicsModelMetadata model_metadata;
  226. model_metadata.set_version(123);
  227. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  228. ->mutable_category_params();
  229. category_params->set_max_categories(4);
  230. category_params->set_min_none_weight(0.8);
  231. category_params->set_min_category_weight(0.01);
  232. category_params->set_min_normalized_weight_within_top_n(0.25);
  233. proto::Any any_metadata;
  234. any_metadata.set_type_url(
  235. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  236. model_metadata.SerializeToString(any_metadata.mutable_value());
  237. SendPageTopicsModelToExecutor(any_metadata);
  238. std::vector<tflite::task::core::Category> model_output = {
  239. {"0", 0.3},
  240. {"1", 0.25},
  241. {"2", 0.4},
  242. {"3", 0.05},
  243. };
  244. BatchAnnotationResult topics_result =
  245. BatchAnnotationResult::CreateEmptyAnnotationsResult("");
  246. model_executor()->PostprocessCategoriesToBatchAnnotationResult(
  247. base::BindOnce(
  248. [](BatchAnnotationResult* out_result,
  249. const BatchAnnotationResult& in_result) {
  250. *out_result = in_result;
  251. },
  252. &topics_result),
  253. AnnotationType::kPageTopics, "input", model_output);
  254. EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult(
  255. "input", std::vector<WeightedIdentifier>{
  256. WeightedIdentifier(0, 0.3),
  257. WeightedIdentifier(1, 0.25),
  258. WeightedIdentifier(2, 0.4),
  259. }));
  260. }
  261. // Regression test for crbug.com/1303304.
  262. TEST_F(PageTopicsModelExecutorTest, NoneCategoryBelowMinWeight) {
  263. proto::PageTopicsModelMetadata model_metadata;
  264. model_metadata.set_version(123);
  265. auto* category_params = model_metadata.mutable_output_postprocessing_params()
  266. ->mutable_category_params();
  267. category_params->set_max_categories(4);
  268. category_params->set_min_none_weight(0.8);
  269. category_params->set_min_category_weight(0.01);
  270. category_params->set_min_normalized_weight_within_top_n(0.25);
  271. proto::Any any_metadata;
  272. any_metadata.set_type_url(
  273. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  274. model_metadata.SerializeToString(any_metadata.mutable_value());
  275. SendPageTopicsModelToExecutor(any_metadata);
  276. std::vector<tflite::task::core::Category> model_output = {
  277. {"-2", 0.001}, {"0", 0.001}, {"1", 0.25}, {"2", 0.4}, {"3", 0.05},
  278. };
  279. BatchAnnotationResult topics_result =
  280. BatchAnnotationResult::CreateEmptyAnnotationsResult("");
  281. model_executor()->PostprocessCategoriesToBatchAnnotationResult(
  282. base::BindOnce(
  283. [](BatchAnnotationResult* out_result,
  284. const BatchAnnotationResult& in_result) {
  285. *out_result = in_result;
  286. },
  287. &topics_result),
  288. AnnotationType::kPageTopics, "input", model_output);
  289. EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult(
  290. "input", std::vector<WeightedIdentifier>{
  291. WeightedIdentifier(1, 0.25),
  292. WeightedIdentifier(2, 0.4),
  293. }));
  294. }
  295. TEST_F(PageTopicsModelExecutorTest,
  296. NullPostprocessCategoriesToBatchAnnotationResult) {
  297. proto::PageTopicsModelMetadata model_metadata;
  298. model_metadata.set_version(123);
  299. proto::Any any_metadata;
  300. any_metadata.set_type_url(
  301. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  302. model_metadata.SerializeToString(any_metadata.mutable_value());
  303. SendPageTopicsModelToExecutor(any_metadata);
  304. BatchAnnotationResult topics_result =
  305. BatchAnnotationResult::CreateEmptyAnnotationsResult("");
  306. model_executor()->PostprocessCategoriesToBatchAnnotationResult(
  307. base::BindOnce(
  308. [](BatchAnnotationResult* out_result,
  309. const BatchAnnotationResult& in_result) {
  310. *out_result = in_result;
  311. },
  312. &topics_result),
  313. AnnotationType::kPageTopics, "", absl::nullopt);
  314. EXPECT_EQ(topics_result,
  315. BatchAnnotationResult::CreatePageTopicsResult("", absl::nullopt));
  316. }
  317. TEST_F(PageTopicsModelExecutorTest, HostPreprocessing) {
  318. std::vector<std::pair<std::string, std::string>> tests = {
  319. {"www.chromium.org", "chromium org"},
  320. {"foo-bar.com", "foo bar com"},
  321. {"foo_bar.com", "foo bar com"},
  322. {"cats.co.uk", "cats co uk"},
  323. {"cats+dogs.com", "cats dogs com"},
  324. {"www.foo-bar_.baz.com", "foo bar baz com"},
  325. {"www.foo-bar-baz.com", "foo bar baz com"},
  326. {"WwW.LOWER-CASE.com", "lower case com"},
  327. };
  328. for (const auto& test : tests) {
  329. std::string raw_host = test.first;
  330. std::string processed_host = test.second;
  331. std::string got_input;
  332. // The callback is run synchronously in this test.
  333. model_executor()->ExecuteOnSingleInput(
  334. AnnotationType::kPageTopics, raw_host,
  335. base::BindOnce(
  336. [](std::string* got_input_out,
  337. const BatchAnnotationResult& result) {
  338. EXPECT_EQ(result.type(), AnnotationType::kPageTopics);
  339. *got_input_out = result.input();
  340. },
  341. &got_input));
  342. EXPECT_EQ(raw_host, got_input);
  343. EXPECT_EQ(processed_host, model_executor()->inputs().back());
  344. }
  345. }
  346. class PageTopicsModelExecutorOverrideListTest
  347. : public PageTopicsModelExecutorTest {
  348. public:
  349. PageTopicsModelExecutorOverrideListTest() = default;
  350. ~PageTopicsModelExecutorOverrideListTest() override = default;
  351. void SetUp() override {
  352. PageTopicsModelExecutorTest::SetUp();
  353. ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
  354. }
  355. base::FilePath WriteToTempFile(const std::string& base_name,
  356. const std::string& contents) {
  357. base::FilePath abs_path = temp_dir_.GetPath().AppendASCII(base_name);
  358. EXPECT_TRUE(base::WriteFile(abs_path, contents));
  359. return abs_path;
  360. }
  361. std::string Compress(const std::string& data) {
  362. std::string compressed;
  363. EXPECT_TRUE(compression::GzipCompress(data, &compressed));
  364. return compressed;
  365. }
  366. void SendModelWithAdditionalFilesToExecutor(
  367. const base::flat_set<base::FilePath>& additional_files) {
  368. proto::PageTopicsModelMetadata model_metadata;
  369. model_metadata.set_version(123);
  370. proto::Any any_metadata;
  371. any_metadata.set_type_url(
  372. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  373. model_metadata.SerializeToString(any_metadata.mutable_value());
  374. base::FilePath source_root_dir;
  375. base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
  376. base::FilePath model_file_path =
  377. source_root_dir.AppendASCII("components")
  378. .AppendASCII("test")
  379. .AppendASCII("data")
  380. .AppendASCII("optimization_guide")
  381. // These tests don't need a valid model to execute as we don't care
  382. // about the model output or execution.
  383. .AppendASCII("model_doesnt_exist.tflite");
  384. std::unique_ptr<ModelInfo> model_info =
  385. TestModelInfoBuilder()
  386. .SetModelFilePath(model_file_path)
  387. .SetModelMetadata(any_metadata)
  388. .SetAdditionalFiles(additional_files)
  389. .Build();
  390. model_executor()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
  391. *model_info);
  392. RunUntilIdle();
  393. }
  394. const base::FilePath& temp_file_path() const { return temp_dir_.GetPath(); }
  395. private:
  396. base::ScopedTempDir temp_dir_;
  397. };
  398. TEST_F(PageTopicsModelExecutorOverrideListTest, NoAdditionalFiles) {
  399. base::HistogramTester histogram_tester;
  400. SendModelWithAdditionalFilesToExecutor({});
  401. histogram_tester.ExpectUniqueSample(
  402. "OptimizationGuide.PageTopicsOverrideList.GotFile", false, 1);
  403. }
  404. TEST_F(PageTopicsModelExecutorOverrideListTest, WrongAdditionalFileName) {
  405. base::HistogramTester histogram_tester;
  406. base::FilePath add_file =
  407. WriteToTempFile("tsil_eidrrevo.pb.gz", "file contents");
  408. SendModelWithAdditionalFilesToExecutor({add_file});
  409. histogram_tester.ExpectUniqueSample(
  410. "OptimizationGuide.PageTopicsOverrideList.GotFile", false, 1);
  411. }
  412. TEST_F(PageTopicsModelExecutorOverrideListTest, FileDoesntExist) {
  413. base::HistogramTester histogram_tester;
  414. base::FilePath doesnt_exist = temp_file_path().Append(
  415. base::FilePath(FILE_PATH_LITERAL("override_list.pb.gz")));
  416. SendModelWithAdditionalFilesToExecutor({doesnt_exist});
  417. base::RunLoop run_loop;
  418. model_executor()->ExecuteJob(
  419. run_loop.QuitClosure(),
  420. std::make_unique<PageContentAnnotationJob>(
  421. base::DoNothing(), std::vector<std::string>{"inputs"},
  422. AnnotationType::kPageTopics));
  423. run_loop.Run();
  424. histogram_tester.ExpectUniqueSample(
  425. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  426. /*OverrideListFileLoadResult::kCouldNotReadFile=*/2, 1);
  427. histogram_tester.ExpectUniqueSample(
  428. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  429. histogram_tester.ExpectTotalCount(
  430. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
  431. }
  432. TEST_F(PageTopicsModelExecutorOverrideListTest, BadGzip) {
  433. base::HistogramTester histogram_tester;
  434. base::FilePath add_file =
  435. WriteToTempFile("override_list.pb.gz", std::string());
  436. SendModelWithAdditionalFilesToExecutor({add_file});
  437. base::RunLoop run_loop;
  438. model_executor()->ExecuteJob(
  439. run_loop.QuitClosure(),
  440. std::make_unique<PageContentAnnotationJob>(
  441. base::DoNothing(), std::vector<std::string>{"inputs"},
  442. AnnotationType::kPageTopics));
  443. run_loop.Run();
  444. histogram_tester.ExpectUniqueSample(
  445. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  446. /*OverrideListFileLoadResult::kCouldNotUncompressFile=*/3, 1);
  447. histogram_tester.ExpectUniqueSample(
  448. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  449. histogram_tester.ExpectTotalCount(
  450. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
  451. }
  452. TEST_F(PageTopicsModelExecutorOverrideListTest, BadProto) {
  453. base::HistogramTester histogram_tester;
  454. base::FilePath add_file =
  455. WriteToTempFile("override_list.pb.gz", Compress("bad protobuf"));
  456. SendModelWithAdditionalFilesToExecutor({add_file});
  457. base::RunLoop run_loop;
  458. model_executor()->ExecuteJob(
  459. run_loop.QuitClosure(),
  460. std::make_unique<PageContentAnnotationJob>(
  461. base::DoNothing(), std::vector<std::string>{"inputs"},
  462. AnnotationType::kPageTopics));
  463. run_loop.Run();
  464. histogram_tester.ExpectUniqueSample(
  465. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  466. /*OverrideListFileLoadResult::kCouldNotUnmarshalProtobuf=*/4, 1);
  467. histogram_tester.ExpectUniqueSample(
  468. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  469. histogram_tester.ExpectTotalCount(
  470. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
  471. }
  472. TEST_F(PageTopicsModelExecutorOverrideListTest, SuccessCase) {
  473. base::HistogramTester histogram_tester;
  474. proto::PageTopicsOverrideList override_list;
  475. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  476. entry->set_domain("input com");
  477. entry->mutable_topics()->add_topic_ids(1337);
  478. std::string enc_pb;
  479. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  480. base::FilePath add_file =
  481. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  482. SendModelWithAdditionalFilesToExecutor({add_file});
  483. base::RunLoop run_loop;
  484. model_executor()->ExecuteJob(
  485. run_loop.QuitClosure(),
  486. std::make_unique<PageContentAnnotationJob>(
  487. base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
  488. ASSERT_EQ(results.size(), 1U);
  489. EXPECT_EQ(results[0].input(), "www.input.com");
  490. EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
  491. ASSERT_TRUE(results[0].topics());
  492. EXPECT_EQ(*results[0].topics(), (std::vector<WeightedIdentifier>{
  493. WeightedIdentifier(1337, 1.0),
  494. }));
  495. }),
  496. std::vector<std::string>{"www.input.com"},
  497. AnnotationType::kPageTopics));
  498. run_loop.Run();
  499. histogram_tester.ExpectUniqueSample(
  500. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  501. /*OverrideListFileLoadResult::kSuccess=*/1, 1);
  502. histogram_tester.ExpectUniqueSample(
  503. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  504. histogram_tester.ExpectUniqueSample(
  505. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
  506. }
  507. TEST_F(PageTopicsModelExecutorOverrideListTest, InputNotInOverride) {
  508. base::HistogramTester histogram_tester;
  509. proto::PageTopicsOverrideList override_list;
  510. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  511. entry->set_domain("other");
  512. entry->mutable_topics()->add_topic_ids(1337);
  513. std::string enc_pb;
  514. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  515. base::FilePath add_file =
  516. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  517. SendModelWithAdditionalFilesToExecutor({add_file});
  518. base::RunLoop run_loop;
  519. model_executor()->ExecuteJob(
  520. run_loop.QuitClosure(),
  521. std::make_unique<PageContentAnnotationJob>(
  522. base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
  523. ASSERT_EQ(results.size(), 1U);
  524. EXPECT_EQ(results[0].input(), "input");
  525. EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
  526. // The passed model file isn't valid so we don't expect an output
  527. // here.
  528. EXPECT_FALSE(results[0].topics());
  529. }),
  530. std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
  531. run_loop.Run();
  532. histogram_tester.ExpectUniqueSample(
  533. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  534. /*OverrideListFileLoadResult::kSuccess=*/1, 1);
  535. histogram_tester.ExpectUniqueSample(
  536. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  537. histogram_tester.ExpectUniqueSample(
  538. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", false, 1);
  539. }
  540. // Regression test for crbug.com/1321808.
  541. TEST_F(PageTopicsModelExecutorOverrideListTest, KeepsOrdering) {
  542. base::HistogramTester histogram_tester;
  543. proto::PageTopicsOverrideList override_list;
  544. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  545. entry->set_domain("in list");
  546. entry->mutable_topics()->add_topic_ids(1337);
  547. std::string enc_pb;
  548. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  549. base::FilePath add_file =
  550. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  551. SendModelWithAdditionalFilesToExecutor({add_file});
  552. base::RunLoop run_loop;
  553. model_executor()->ExecuteJob(
  554. run_loop.QuitClosure(),
  555. std::make_unique<PageContentAnnotationJob>(
  556. base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
  557. ASSERT_EQ(results.size(), 2U);
  558. EXPECT_EQ(results[0].input(), "not in list");
  559. EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
  560. EXPECT_FALSE(results[0].topics());
  561. EXPECT_EQ(results[1].input(), "in list");
  562. EXPECT_EQ(results[1].type(), AnnotationType::kPageTopics);
  563. EXPECT_TRUE(results[1].topics());
  564. }),
  565. std::vector<std::string>{"not in list", "in list"},
  566. AnnotationType::kPageTopics));
  567. run_loop.Run();
  568. histogram_tester.ExpectUniqueSample(
  569. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  570. /*OverrideListFileLoadResult::kSuccess=*/1, 1);
  571. histogram_tester.ExpectUniqueSample(
  572. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  573. }
  574. TEST_F(PageTopicsModelExecutorOverrideListTest, ModelUnloadsOverrideList) {
  575. base::HistogramTester histogram_tester;
  576. proto::PageTopicsOverrideList override_list;
  577. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  578. entry->set_domain("input");
  579. entry->mutable_topics()->add_topic_ids(1337);
  580. std::string enc_pb;
  581. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  582. base::FilePath add_file =
  583. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  584. SendModelWithAdditionalFilesToExecutor({add_file});
  585. {
  586. base::RunLoop run_loop;
  587. model_executor()->ExecuteJob(
  588. run_loop.QuitClosure(),
  589. std::make_unique<PageContentAnnotationJob>(
  590. base::DoNothing(), std::vector<std::string>{"input"},
  591. AnnotationType::kPageTopics));
  592. run_loop.Run();
  593. histogram_tester.ExpectUniqueSample(
  594. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
  595. }
  596. // Request the model to be unloaded, which should also unload the override
  597. // list.
  598. model_executor()->UnloadModel();
  599. // Retry an execution and check that the UMA reports the override list being
  600. // loaded twice.
  601. {
  602. base::RunLoop run_loop;
  603. model_executor()->ExecuteJob(
  604. run_loop.QuitClosure(),
  605. std::make_unique<PageContentAnnotationJob>(
  606. base::DoNothing(), std::vector<std::string>{"input"},
  607. AnnotationType::kPageTopics));
  608. run_loop.Run();
  609. histogram_tester.ExpectUniqueSample(
  610. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 2);
  611. }
  612. histogram_tester.ExpectUniqueSample(
  613. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  614. /*OverrideListFileLoadResult::kSuccess=*/1, 2);
  615. histogram_tester.ExpectUniqueSample(
  616. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
  617. }
  618. TEST_F(PageTopicsModelExecutorOverrideListTest, NewModelUnloadsOverrideList) {
  619. base::HistogramTester histogram_tester;
  620. {
  621. proto::PageTopicsOverrideList override_list;
  622. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  623. entry->set_domain("input");
  624. entry->mutable_topics()->add_topic_ids(1337);
  625. std::string enc_pb;
  626. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  627. base::FilePath add_file =
  628. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  629. SendModelWithAdditionalFilesToExecutor({add_file});
  630. base::RunLoop run_loop;
  631. model_executor()->ExecuteJob(
  632. run_loop.QuitClosure(),
  633. std::make_unique<PageContentAnnotationJob>(
  634. base::BindOnce(
  635. [](const std::vector<BatchAnnotationResult>& results) {
  636. ASSERT_EQ(results.size(), 1U);
  637. EXPECT_EQ(results[0].input(), "input");
  638. EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
  639. ASSERT_TRUE(results[0].topics());
  640. EXPECT_EQ(*results[0].topics(),
  641. (std::vector<WeightedIdentifier>{
  642. WeightedIdentifier(1337, 1.0),
  643. }));
  644. }),
  645. std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
  646. run_loop.Run();
  647. histogram_tester.ExpectUniqueSample(
  648. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
  649. }
  650. // Retry an execution and check that the UMA reports the override list being
  651. // loaded twice, and that the topics are now different.
  652. {
  653. proto::PageTopicsOverrideList override_list;
  654. proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
  655. entry->set_domain("input");
  656. entry->mutable_topics()->add_topic_ids(7331);
  657. std::string enc_pb;
  658. ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
  659. base::FilePath add_file =
  660. WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
  661. SendModelWithAdditionalFilesToExecutor({add_file});
  662. base::RunLoop run_loop;
  663. model_executor()->ExecuteJob(
  664. run_loop.QuitClosure(),
  665. std::make_unique<PageContentAnnotationJob>(
  666. base::BindOnce(
  667. [](const std::vector<BatchAnnotationResult>& results) {
  668. ASSERT_EQ(results.size(), 1U);
  669. EXPECT_EQ(results[0].input(), "input");
  670. EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
  671. ASSERT_TRUE(results[0].topics());
  672. EXPECT_EQ(*results[0].topics(),
  673. (std::vector<WeightedIdentifier>{
  674. WeightedIdentifier(7331, 1.0),
  675. }));
  676. }),
  677. std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
  678. run_loop.Run();
  679. }
  680. histogram_tester.ExpectUniqueSample(
  681. "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
  682. /*OverrideListFileLoadResult::kSuccess=*/1, 2);
  683. histogram_tester.ExpectUniqueSample(
  684. "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 2);
  685. histogram_tester.ExpectUniqueSample(
  686. "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 2);
  687. }
  688. } // namespace optimization_guide