123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804 |
- // Copyright 2021 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/optimization_guide/core/page_topics_model_executor.h"
- #include "base/containers/flat_map.h"
- #include "base/files/file_util.h"
- #include "base/files/scoped_temp_dir.h"
- #include "base/path_service.h"
- #include "base/task/thread_pool.h"
- #include "base/test/metrics/histogram_tester.h"
- #include "base/test/scoped_feature_list.h"
- #include "base/test/task_environment.h"
- #include "components/optimization_guide/core/optimization_guide_features.h"
- #include "components/optimization_guide/core/optimization_guide_model_provider.h"
- #include "components/optimization_guide/core/page_entities_model_executor.h"
- #include "components/optimization_guide/core/test_model_info_builder.h"
- #include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
- #include "components/optimization_guide/proto/models.pb.h"
- #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
- #include "components/optimization_guide/proto/page_topics_override_list.pb.h"
- #include "testing/gmock/include/gmock/gmock.h"
- #include "testing/gtest/include/gtest/gtest.h"
- #include "third_party/zlib/google/compression_utils.h"
- namespace optimization_guide {
- class ModelObserverTracker : public TestOptimizationGuideModelProvider {
- public:
- void AddObserverForOptimizationTargetModel(
- proto::OptimizationTarget target,
- const absl::optional<proto::Any>& model_metadata,
- OptimizationTargetModelObserver* observer) override {
- registered_model_metadata_.insert_or_assign(target, model_metadata);
- }
- bool DidRegisterForTarget(
- proto::OptimizationTarget target,
- absl::optional<proto::Any>* out_model_metadata) const {
- auto it = registered_model_metadata_.find(target);
- if (it == registered_model_metadata_.end())
- return false;
- *out_model_metadata = registered_model_metadata_.at(target);
- return true;
- }
- private:
- base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>>
- registered_model_metadata_;
- };
- class TestPageTopicsModelExecutor : public PageTopicsModelExecutor {
- public:
- TestPageTopicsModelExecutor(
- OptimizationGuideModelProvider* model_provider,
- scoped_refptr<base::SequencedTaskRunner> background_task_runner,
- const absl::optional<proto::Any>& model_metadata)
- : PageTopicsModelExecutor(model_provider,
- background_task_runner,
- model_metadata) {}
- ~TestPageTopicsModelExecutor() override = default;
- void ExecuteModelWithInput(ExecutionCallback callback,
- const std::string& input) override {
- inputs_.push_back(input);
- std::move(callback).Run(absl::nullopt);
- }
- const std::vector<std::string>& inputs() const { return inputs_; }
- private:
- std::vector<std::string> inputs_;
- };
- class PageTopicsModelExecutorTest : public testing::Test {
- public:
- PageTopicsModelExecutorTest() {
- scoped_feature_list_.InitWithFeatures(
- {features::kPageContentAnnotations},
- {features::kPreventLongRunningPredictionModels});
- }
- ~PageTopicsModelExecutorTest() override = default;
- void SetUp() override {
- model_observer_tracker_ = std::make_unique<ModelObserverTracker>();
- model_executor_ = std::make_unique<TestPageTopicsModelExecutor>(
- model_observer_tracker_.get(),
- base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
- /*model_metadata=*/absl::nullopt);
- }
- void TearDown() override {
- model_executor_.reset();
- model_observer_tracker_.reset();
- RunUntilIdle();
- }
- void SendPageTopicsModelToExecutor(
- const absl::optional<proto::Any>& model_metadata) {
- base::FilePath source_root_dir;
- base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
- base::FilePath model_file_path =
- source_root_dir.AppendASCII("components")
- .AppendASCII("test")
- .AppendASCII("data")
- .AppendASCII("optimization_guide")
- .AppendASCII("bert_page_topics_model.tflite");
- std::unique_ptr<ModelInfo> model_info =
- TestModelInfoBuilder()
- .SetModelFilePath(model_file_path)
- .SetModelMetadata(model_metadata)
- .Build();
- model_executor()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
- *model_info);
- RunUntilIdle();
- }
- ModelObserverTracker* model_observer_tracker() const {
- return model_observer_tracker_.get();
- }
- TestPageTopicsModelExecutor* model_executor() const {
- return model_executor_.get();
- }
- void RunUntilIdle() { task_environment_.RunUntilIdle(); }
- private:
- base::test::TaskEnvironment task_environment_;
- base::test::ScopedFeatureList scoped_feature_list_;
- std::unique_ptr<ModelObserverTracker> model_observer_tracker_;
- std::unique_ptr<TestPageTopicsModelExecutor> model_executor_;
- };
- TEST_F(
- PageTopicsModelExecutorTest,
- GetContentModelAnnotationsFromOutputNonNumericAndLowWeightCategoriesPruned) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.8);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.1);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"0", 0.0001}, {"1", 0.1}, {"not an int", 0.9}, {"2", 0.2}, {"3", 0.3},
- };
- absl::optional<std::vector<WeightedIdentifier>> categories =
- model_executor()->ExtractCategoriesFromModelOutput(model_output);
- ASSERT_TRUE(categories);
- EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedIdentifier(1, 0.1),
- WeightedIdentifier(2, 0.2),
- WeightedIdentifier(3, 0.3)));
- }
- TEST_F(PageTopicsModelExecutorTest,
- GetContentModelAnnotationsFromOutputNoneWeightTooStrong) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.1);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.1);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"-2", 0.9999},
- {"0", 0.3},
- {"1", 0.2},
- };
- absl::optional<std::vector<WeightedIdentifier>> categories =
- model_executor()->ExtractCategoriesFromModelOutput(model_output);
- EXPECT_FALSE(categories);
- }
- TEST_F(PageTopicsModelExecutorTest,
- GetContentModelAnnotationsFromOutputNoneInTopButNotStrongSoPruned) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.8);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.1);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"-2", 0.1}, {"0", 0.3}, {"1", 0.2}, {"2", 0.4}, {"3", 0.05},
- };
- absl::optional<std::vector<WeightedIdentifier>> categories =
- model_executor()->ExtractCategoriesFromModelOutput(model_output);
- ASSERT_TRUE(categories);
- EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
- WeightedIdentifier(1, 0.2),
- WeightedIdentifier(2, 0.4)));
- }
- TEST_F(PageTopicsModelExecutorTest,
- GetContentModelAnnotationsFromOutputPrunedAfterNormalization) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.8);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.25);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"0", 0.3},
- {"1", 0.25},
- {"2", 0.4},
- {"3", 0.05},
- };
- absl::optional<std::vector<WeightedIdentifier>> categories =
- model_executor()->ExtractCategoriesFromModelOutput(model_output);
- ASSERT_TRUE(categories);
- EXPECT_THAT(*categories,
- testing::UnorderedElementsAre(WeightedIdentifier(0, 0.3),
- WeightedIdentifier(1, 0.25),
- WeightedIdentifier(2, 0.4)));
- }
- TEST_F(PageTopicsModelExecutorTest,
- PostprocessCategoriesToBatchAnnotationResult) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.8);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.25);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"0", 0.3},
- {"1", 0.25},
- {"2", 0.4},
- {"3", 0.05},
- };
- BatchAnnotationResult topics_result =
- BatchAnnotationResult::CreateEmptyAnnotationsResult("");
- model_executor()->PostprocessCategoriesToBatchAnnotationResult(
- base::BindOnce(
- [](BatchAnnotationResult* out_result,
- const BatchAnnotationResult& in_result) {
- *out_result = in_result;
- },
- &topics_result),
- AnnotationType::kPageTopics, "input", model_output);
- EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult(
- "input", std::vector<WeightedIdentifier>{
- WeightedIdentifier(0, 0.3),
- WeightedIdentifier(1, 0.25),
- WeightedIdentifier(2, 0.4),
- }));
- }
- // Regression test for crbug.com/1303304.
- TEST_F(PageTopicsModelExecutorTest, NoneCategoryBelowMinWeight) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- auto* category_params = model_metadata.mutable_output_postprocessing_params()
- ->mutable_category_params();
- category_params->set_max_categories(4);
- category_params->set_min_none_weight(0.8);
- category_params->set_min_category_weight(0.01);
- category_params->set_min_normalized_weight_within_top_n(0.25);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- std::vector<tflite::task::core::Category> model_output = {
- {"-2", 0.001}, {"0", 0.001}, {"1", 0.25}, {"2", 0.4}, {"3", 0.05},
- };
- BatchAnnotationResult topics_result =
- BatchAnnotationResult::CreateEmptyAnnotationsResult("");
- model_executor()->PostprocessCategoriesToBatchAnnotationResult(
- base::BindOnce(
- [](BatchAnnotationResult* out_result,
- const BatchAnnotationResult& in_result) {
- *out_result = in_result;
- },
- &topics_result),
- AnnotationType::kPageTopics, "input", model_output);
- EXPECT_EQ(topics_result, BatchAnnotationResult::CreatePageTopicsResult(
- "input", std::vector<WeightedIdentifier>{
- WeightedIdentifier(1, 0.25),
- WeightedIdentifier(2, 0.4),
- }));
- }
- TEST_F(PageTopicsModelExecutorTest,
- NullPostprocessCategoriesToBatchAnnotationResult) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- SendPageTopicsModelToExecutor(any_metadata);
- BatchAnnotationResult topics_result =
- BatchAnnotationResult::CreateEmptyAnnotationsResult("");
- model_executor()->PostprocessCategoriesToBatchAnnotationResult(
- base::BindOnce(
- [](BatchAnnotationResult* out_result,
- const BatchAnnotationResult& in_result) {
- *out_result = in_result;
- },
- &topics_result),
- AnnotationType::kPageTopics, "", absl::nullopt);
- EXPECT_EQ(topics_result,
- BatchAnnotationResult::CreatePageTopicsResult("", absl::nullopt));
- }
- TEST_F(PageTopicsModelExecutorTest, HostPreprocessing) {
- std::vector<std::pair<std::string, std::string>> tests = {
- {"www.chromium.org", "chromium org"},
- {"foo-bar.com", "foo bar com"},
- {"foo_bar.com", "foo bar com"},
- {"cats.co.uk", "cats co uk"},
- {"cats+dogs.com", "cats dogs com"},
- {"www.foo-bar_.baz.com", "foo bar baz com"},
- {"www.foo-bar-baz.com", "foo bar baz com"},
- {"WwW.LOWER-CASE.com", "lower case com"},
- };
- for (const auto& test : tests) {
- std::string raw_host = test.first;
- std::string processed_host = test.second;
- std::string got_input;
- // The callback is run synchronously in this test.
- model_executor()->ExecuteOnSingleInput(
- AnnotationType::kPageTopics, raw_host,
- base::BindOnce(
- [](std::string* got_input_out,
- const BatchAnnotationResult& result) {
- EXPECT_EQ(result.type(), AnnotationType::kPageTopics);
- *got_input_out = result.input();
- },
- &got_input));
- EXPECT_EQ(raw_host, got_input);
- EXPECT_EQ(processed_host, model_executor()->inputs().back());
- }
- }
- class PageTopicsModelExecutorOverrideListTest
- : public PageTopicsModelExecutorTest {
- public:
- PageTopicsModelExecutorOverrideListTest() = default;
- ~PageTopicsModelExecutorOverrideListTest() override = default;
- void SetUp() override {
- PageTopicsModelExecutorTest::SetUp();
- ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
- }
- base::FilePath WriteToTempFile(const std::string& base_name,
- const std::string& contents) {
- base::FilePath abs_path = temp_dir_.GetPath().AppendASCII(base_name);
- EXPECT_TRUE(base::WriteFile(abs_path, contents));
- return abs_path;
- }
- std::string Compress(const std::string& data) {
- std::string compressed;
- EXPECT_TRUE(compression::GzipCompress(data, &compressed));
- return compressed;
- }
- void SendModelWithAdditionalFilesToExecutor(
- const base::flat_set<base::FilePath>& additional_files) {
- proto::PageTopicsModelMetadata model_metadata;
- model_metadata.set_version(123);
- proto::Any any_metadata;
- any_metadata.set_type_url(
- "type.googleapis.com/com.foo.PageTopicsModelMetadata");
- model_metadata.SerializeToString(any_metadata.mutable_value());
- base::FilePath source_root_dir;
- base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
- base::FilePath model_file_path =
- source_root_dir.AppendASCII("components")
- .AppendASCII("test")
- .AppendASCII("data")
- .AppendASCII("optimization_guide")
- // These tests don't need a valid model to execute as we don't care
- // about the model output or execution.
- .AppendASCII("model_doesnt_exist.tflite");
- std::unique_ptr<ModelInfo> model_info =
- TestModelInfoBuilder()
- .SetModelFilePath(model_file_path)
- .SetModelMetadata(any_metadata)
- .SetAdditionalFiles(additional_files)
- .Build();
- model_executor()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_TOPICS_V2,
- *model_info);
- RunUntilIdle();
- }
- const base::FilePath& temp_file_path() const { return temp_dir_.GetPath(); }
- private:
- base::ScopedTempDir temp_dir_;
- };
- TEST_F(PageTopicsModelExecutorOverrideListTest, NoAdditionalFiles) {
- base::HistogramTester histogram_tester;
- SendModelWithAdditionalFilesToExecutor({});
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", false, 1);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, WrongAdditionalFileName) {
- base::HistogramTester histogram_tester;
- base::FilePath add_file =
- WriteToTempFile("tsil_eidrrevo.pb.gz", "file contents");
- SendModelWithAdditionalFilesToExecutor({add_file});
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", false, 1);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, FileDoesntExist) {
- base::HistogramTester histogram_tester;
- base::FilePath doesnt_exist = temp_file_path().Append(
- base::FilePath(FILE_PATH_LITERAL("override_list.pb.gz")));
- SendModelWithAdditionalFilesToExecutor({doesnt_exist});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::DoNothing(), std::vector<std::string>{"inputs"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kCouldNotReadFile=*/2, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, BadGzip) {
- base::HistogramTester histogram_tester;
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", std::string());
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::DoNothing(), std::vector<std::string>{"inputs"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kCouldNotUncompressFile=*/3, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, BadProto) {
- base::HistogramTester histogram_tester;
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress("bad protobuf"));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::DoNothing(), std::vector<std::string>{"inputs"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kCouldNotUnmarshalProtobuf=*/4, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- histogram_tester.ExpectTotalCount(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", 0);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, SuccessCase) {
- base::HistogramTester histogram_tester;
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("input com");
- entry->mutable_topics()->add_topic_ids(1337);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
- ASSERT_EQ(results.size(), 1U);
- EXPECT_EQ(results[0].input(), "www.input.com");
- EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
- ASSERT_TRUE(results[0].topics());
- EXPECT_EQ(*results[0].topics(), (std::vector<WeightedIdentifier>{
- WeightedIdentifier(1337, 1.0),
- }));
- }),
- std::vector<std::string>{"www.input.com"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kSuccess=*/1, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, InputNotInOverride) {
- base::HistogramTester histogram_tester;
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("other");
- entry->mutable_topics()->add_topic_ids(1337);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
- ASSERT_EQ(results.size(), 1U);
- EXPECT_EQ(results[0].input(), "input");
- EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
- // The passed model file isn't valid so we don't expect an output
- // here.
- EXPECT_FALSE(results[0].topics());
- }),
- std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kSuccess=*/1, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", false, 1);
- }
- // Regression test for crbug.com/1321808.
- TEST_F(PageTopicsModelExecutorOverrideListTest, KeepsOrdering) {
- base::HistogramTester histogram_tester;
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("in list");
- entry->mutable_topics()->add_topic_ids(1337);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::BindOnce([](const std::vector<BatchAnnotationResult>& results) {
- ASSERT_EQ(results.size(), 2U);
- EXPECT_EQ(results[0].input(), "not in list");
- EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
- EXPECT_FALSE(results[0].topics());
- EXPECT_EQ(results[1].input(), "in list");
- EXPECT_EQ(results[1].type(), AnnotationType::kPageTopics);
- EXPECT_TRUE(results[1].topics());
- }),
- std::vector<std::string>{"not in list", "in list"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kSuccess=*/1, 1);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, ModelUnloadsOverrideList) {
- base::HistogramTester histogram_tester;
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("input");
- entry->mutable_topics()->add_topic_ids(1337);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- {
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::DoNothing(), std::vector<std::string>{"input"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
- }
- // Request the model to be unloaded, which should also unload the override
- // list.
- model_executor()->UnloadModel();
- // Retry an execution and check that the UMA reports the override list being
- // loaded twice.
- {
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::DoNothing(), std::vector<std::string>{"input"},
- AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 2);
- }
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kSuccess=*/1, 2);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 1);
- }
- TEST_F(PageTopicsModelExecutorOverrideListTest, NewModelUnloadsOverrideList) {
- base::HistogramTester histogram_tester;
- {
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("input");
- entry->mutable_topics()->add_topic_ids(1337);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::BindOnce(
- [](const std::vector<BatchAnnotationResult>& results) {
- ASSERT_EQ(results.size(), 1U);
- EXPECT_EQ(results[0].input(), "input");
- EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
- ASSERT_TRUE(results[0].topics());
- EXPECT_EQ(*results[0].topics(),
- (std::vector<WeightedIdentifier>{
- WeightedIdentifier(1337, 1.0),
- }));
- }),
- std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
- run_loop.Run();
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 1);
- }
- // Retry an execution and check that the UMA reports the override list being
- // loaded twice, and that the topics are now different.
- {
- proto::PageTopicsOverrideList override_list;
- proto::PageTopicsOverrideEntry* entry = override_list.add_entries();
- entry->set_domain("input");
- entry->mutable_topics()->add_topic_ids(7331);
- std::string enc_pb;
- ASSERT_TRUE(override_list.SerializeToString(&enc_pb));
- base::FilePath add_file =
- WriteToTempFile("override_list.pb.gz", Compress(enc_pb));
- SendModelWithAdditionalFilesToExecutor({add_file});
- base::RunLoop run_loop;
- model_executor()->ExecuteJob(
- run_loop.QuitClosure(),
- std::make_unique<PageContentAnnotationJob>(
- base::BindOnce(
- [](const std::vector<BatchAnnotationResult>& results) {
- ASSERT_EQ(results.size(), 1U);
- EXPECT_EQ(results[0].input(), "input");
- EXPECT_EQ(results[0].type(), AnnotationType::kPageTopics);
- ASSERT_TRUE(results[0].topics());
- EXPECT_EQ(*results[0].topics(),
- (std::vector<WeightedIdentifier>{
- WeightedIdentifier(7331, 1.0),
- }));
- }),
- std::vector<std::string>{"input"}, AnnotationType::kPageTopics));
- run_loop.Run();
- }
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.FileLoadResult",
- /*OverrideListFileLoadResult::kSuccess=*/1, 2);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.GotFile", true, 2);
- histogram_tester.ExpectUniqueSample(
- "OptimizationGuide.PageTopicsOverrideList.UsedOverride", true, 2);
- }
- } // namespace optimization_guide
|