page_visibility_model_executor.cc 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/page_visibility_model_executor.h"
  5. #include "components/optimization_guide/core/optimization_guide_model_provider.h"
  6. #include "components/optimization_guide/proto/models.pb.h"
  7. #include "components/optimization_guide/proto/page_topics_model_metadata.pb.h"
  8. namespace optimization_guide {
  9. PageVisibilityModelExecutor::PageVisibilityModelExecutor(
  10. OptimizationGuideModelProvider* model_provider,
  11. scoped_refptr<base::SequencedTaskRunner> background_task_runner,
  12. const absl::optional<proto::Any>& model_metadata)
  13. : BertModelHandler(model_provider,
  14. background_task_runner,
  15. proto::OPTIMIZATION_TARGET_PAGE_VISIBILITY,
  16. model_metadata) {
  17. SetShouldUnloadModelOnComplete(false);
  18. }
  19. PageVisibilityModelExecutor::~PageVisibilityModelExecutor() = default;
  20. void PageVisibilityModelExecutor::ExecuteOnSingleInput(
  21. AnnotationType annotation_type,
  22. const std::string& input,
  23. base::OnceCallback<void(const BatchAnnotationResult&)> callback) {
  24. ExecuteModelWithInput(
  25. base::BindOnce(&PageVisibilityModelExecutor::
  26. PostprocessCategoriesToBatchAnnotationResult,
  27. weak_ptr_factory_.GetWeakPtr(), std::move(callback),
  28. annotation_type, input),
  29. input);
  30. }
  31. void PageVisibilityModelExecutor::PostprocessCategoriesToBatchAnnotationResult(
  32. base::OnceCallback<void(const BatchAnnotationResult&)> callback,
  33. AnnotationType annotation_type,
  34. const std::string& input,
  35. const absl::optional<std::vector<tflite::task::core::Category>>& output) {
  36. DCHECK_EQ(annotation_type, AnnotationType::kContentVisibility);
  37. absl::optional<double> visibility_score;
  38. if (output) {
  39. visibility_score = ExtractContentVisibilityFromModelOutput(*output);
  40. }
  41. std::move(callback).Run(BatchAnnotationResult::CreateContentVisibilityResult(
  42. input, visibility_score));
  43. }
  44. absl::optional<double>
  45. PageVisibilityModelExecutor::ExtractContentVisibilityFromModelOutput(
  46. const std::vector<tflite::task::core::Category>& model_output) const {
  47. absl::optional<proto::PageTopicsModelMetadata> model_metadata =
  48. ParsedSupportedFeaturesForLoadedModel<proto::PageTopicsModelMetadata>();
  49. if (!model_metadata) {
  50. return absl::nullopt;
  51. }
  52. if (!model_metadata->output_postprocessing_params().has_visibility_params()) {
  53. return absl::nullopt;
  54. }
  55. if (!model_metadata->output_postprocessing_params()
  56. .visibility_params()
  57. .has_category_name()) {
  58. return absl::nullopt;
  59. }
  60. std::string visibility_category_name =
  61. model_metadata->output_postprocessing_params()
  62. .visibility_params()
  63. .category_name();
  64. for (const auto& category : model_output) {
  65. if (category.class_name == visibility_category_name) {
  66. return 1.0 - category.score;
  67. }
  68. }
  69. // -1 is a sentinel value that means the visibility of the page was not
  70. // evaluated.
  71. return -1.0;
  72. }
  73. } // namespace optimization_guide