page_visibility_model_executor_unittest.cc 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/page_visibility_model_executor.h"
  5. #include "base/containers/flat_map.h"
  6. #include "base/path_service.h"
  7. #include "base/task/thread_pool.h"
  8. #include "base/test/scoped_feature_list.h"
  9. #include "base/test/task_environment.h"
  10. #include "components/optimization_guide/core/optimization_guide_features.h"
  11. #include "components/optimization_guide/core/optimization_guide_model_provider.h"
  12. #include "components/optimization_guide/core/page_entities_model_executor.h"
  13. #include "components/optimization_guide/core/test_model_info_builder.h"
  14. #include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
  15. #include "components/optimization_guide/proto/models.pb.h"
  16. #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
  17. #include "testing/gmock/include/gmock/gmock.h"
  18. #include "testing/gtest/include/gtest/gtest.h"
  19. namespace optimization_guide {
  20. class ModelObserverTracker : public TestOptimizationGuideModelProvider {
  21. public:
  22. void AddObserverForOptimizationTargetModel(
  23. proto::OptimizationTarget target,
  24. const absl::optional<proto::Any>& model_metadata,
  25. OptimizationTargetModelObserver* observer) override {
  26. registered_model_metadata_.insert_or_assign(target, model_metadata);
  27. }
  28. bool DidRegisterForTarget(
  29. proto::OptimizationTarget target,
  30. absl::optional<proto::Any>* out_model_metadata) const {
  31. auto it = registered_model_metadata_.find(target);
  32. if (it == registered_model_metadata_.end())
  33. return false;
  34. *out_model_metadata = registered_model_metadata_.at(target);
  35. return true;
  36. }
  37. private:
  38. base::flat_map<proto::OptimizationTarget, absl::optional<proto::Any>>
  39. registered_model_metadata_;
  40. };
  41. class PageVisibilityModelExecutorTest : public testing::Test {
  42. public:
  43. PageVisibilityModelExecutorTest() {
  44. scoped_feature_list_.InitAndEnableFeature(
  45. features::kPageContentAnnotations);
  46. }
  47. ~PageVisibilityModelExecutorTest() override = default;
  48. void SetUp() override {
  49. model_observer_tracker_ = std::make_unique<ModelObserverTracker>();
  50. model_executor_ = std::make_unique<PageVisibilityModelExecutor>(
  51. model_observer_tracker_.get(),
  52. base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
  53. /*model_metadata=*/absl::nullopt);
  54. }
  55. void TearDown() override {
  56. model_executor_.reset();
  57. model_observer_tracker_.reset();
  58. RunUntilIdle();
  59. }
  60. void SendPageVisibilityModelToExecutor(
  61. const absl::optional<proto::Any>& model_metadata) {
  62. base::FilePath source_root_dir;
  63. base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
  64. base::FilePath model_file_path =
  65. source_root_dir.AppendASCII("components")
  66. .AppendASCII("test")
  67. .AppendASCII("data")
  68. .AppendASCII("optimization_guide")
  69. .AppendASCII("bert_page_topics_model.tflite");
  70. std::unique_ptr<ModelInfo> model_info =
  71. TestModelInfoBuilder()
  72. .SetModelFilePath(model_file_path)
  73. .SetModelMetadata(model_metadata)
  74. .Build();
  75. model_executor()->OnModelUpdated(proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY,
  76. *model_info);
  77. RunUntilIdle();
  78. }
  79. ModelObserverTracker* model_observer_tracker() const {
  80. return model_observer_tracker_.get();
  81. }
  82. PageVisibilityModelExecutor* model_executor() const {
  83. return model_executor_.get();
  84. }
  85. void RunUntilIdle() { task_environment_.RunUntilIdle(); }
  86. private:
  87. base::test::TaskEnvironment task_environment_;
  88. base::test::ScopedFeatureList scoped_feature_list_;
  89. std::unique_ptr<ModelObserverTracker> model_observer_tracker_;
  90. std::unique_ptr<PageVisibilityModelExecutor> model_executor_;
  91. };
  92. TEST_F(PageVisibilityModelExecutorTest, NoModelMetadataNoOutput) {
  93. // Note that |SendPageVisibilityModelToExecutor| is not called so no metadata
  94. // has been loaded.
  95. std::vector<tflite::task::core::Category> model_output = {
  96. {"VISIBILITY_HERE", 0.3},
  97. };
  98. absl::optional<double> score =
  99. model_executor()->ExtractContentVisibilityFromModelOutput(model_output);
  100. EXPECT_FALSE(score);
  101. }
  102. TEST_F(PageVisibilityModelExecutorTest, NoParamsNoOutput) {
  103. proto::PageTopicsModelMetadata model_metadata;
  104. model_metadata.set_version(123);
  105. proto::Any any_metadata;
  106. any_metadata.set_type_url(
  107. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  108. model_metadata.SerializeToString(any_metadata.mutable_value());
  109. SendPageVisibilityModelToExecutor(any_metadata);
  110. std::vector<tflite::task::core::Category> model_output = {
  111. {"VISIBILITY_HERE", 0.3},
  112. };
  113. absl::optional<double> score =
  114. model_executor()->ExtractContentVisibilityFromModelOutput(model_output);
  115. EXPECT_FALSE(score);
  116. }
  117. TEST_F(PageVisibilityModelExecutorTest, VisibilityNotEvaluated) {
  118. proto::PageTopicsModelMetadata model_metadata;
  119. model_metadata.set_version(123);
  120. model_metadata.mutable_output_postprocessing_params()
  121. ->mutable_visibility_params()
  122. ->set_category_name("VISIBILITY_HERE");
  123. proto::Any any_metadata;
  124. any_metadata.set_type_url(
  125. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  126. model_metadata.SerializeToString(any_metadata.mutable_value());
  127. SendPageVisibilityModelToExecutor(any_metadata);
  128. std::vector<tflite::task::core::Category> model_output = {
  129. {"something else", 0.3},
  130. };
  131. absl::optional<double> score =
  132. model_executor()->ExtractContentVisibilityFromModelOutput(model_output);
  133. ASSERT_TRUE(score);
  134. EXPECT_THAT(*score, testing::DoubleEq(-1));
  135. }
  136. TEST_F(PageVisibilityModelExecutorTest, SuccessCase) {
  137. proto::PageTopicsModelMetadata model_metadata;
  138. model_metadata.set_version(123);
  139. model_metadata.mutable_output_postprocessing_params()
  140. ->mutable_visibility_params()
  141. ->set_category_name("VISIBILITY_HERE");
  142. proto::Any any_metadata;
  143. any_metadata.set_type_url(
  144. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  145. model_metadata.SerializeToString(any_metadata.mutable_value());
  146. SendPageVisibilityModelToExecutor(any_metadata);
  147. std::vector<tflite::task::core::Category> model_output = {
  148. {"VISIBILITY_HERE", 0.3},
  149. {"0", 0.4},
  150. {"1", 0.5},
  151. };
  152. absl::optional<double> score =
  153. model_executor()->ExtractContentVisibilityFromModelOutput(model_output);
  154. ASSERT_TRUE(score);
  155. EXPECT_THAT(*score, testing::DoubleEq(0.7));
  156. }
  157. TEST_F(PageVisibilityModelExecutorTest,
  158. PostprocessCategoriesToBatchAnnotationResult) {
  159. proto::PageTopicsModelMetadata model_metadata;
  160. model_metadata.set_version(123);
  161. model_metadata.mutable_output_postprocessing_params()
  162. ->mutable_visibility_params()
  163. ->set_category_name("VISIBILITY_HERE");
  164. proto::Any any_metadata;
  165. any_metadata.set_type_url(
  166. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  167. model_metadata.SerializeToString(any_metadata.mutable_value());
  168. SendPageVisibilityModelToExecutor(any_metadata);
  169. std::vector<tflite::task::core::Category> model_output = {
  170. {"0", 0.3},
  171. {"1", 0.25},
  172. {"2", 0.4},
  173. {"3", 0.05},
  174. {"VISIBILITY_HERE", 0.4},
  175. };
  176. BatchAnnotationResult viz_result =
  177. BatchAnnotationResult::CreateEmptyAnnotationsResult("");
  178. model_executor()->PostprocessCategoriesToBatchAnnotationResult(
  179. base::BindOnce(
  180. [](BatchAnnotationResult* out_result,
  181. const BatchAnnotationResult& in_result) {
  182. *out_result = in_result;
  183. },
  184. &viz_result),
  185. AnnotationType::kContentVisibility, "input", model_output);
  186. EXPECT_EQ(viz_result,
  187. BatchAnnotationResult::CreateContentVisibilityResult("input", 0.6));
  188. }
  189. TEST_F(PageVisibilityModelExecutorTest,
  190. NullPostprocessCategoriesToBatchAnnotationResult) {
  191. proto::PageTopicsModelMetadata model_metadata;
  192. model_metadata.set_version(123);
  193. model_metadata.mutable_output_postprocessing_params()
  194. ->mutable_visibility_params()
  195. ->set_category_name("VISIBILITY_HERE");
  196. proto::Any any_metadata;
  197. any_metadata.set_type_url(
  198. "type.googleapis.com/com.foo.PageTopicsModelMetadata");
  199. model_metadata.SerializeToString(any_metadata.mutable_value());
  200. SendPageVisibilityModelToExecutor(any_metadata);
  201. BatchAnnotationResult viz_result =
  202. BatchAnnotationResult::CreateEmptyAnnotationsResult("");
  203. model_executor()->PostprocessCategoriesToBatchAnnotationResult(
  204. base::BindOnce(
  205. [](BatchAnnotationResult* out_result,
  206. const BatchAnnotationResult& in_result) {
  207. *out_result = in_result;
  208. },
  209. &viz_result),
  210. AnnotationType::kContentVisibility, "input", absl::nullopt);
  211. EXPECT_EQ(viz_result, BatchAnnotationResult::CreateContentVisibilityResult(
  212. "input", absl::nullopt));
  213. }
  214. } // namespace optimization_guide