page_topics_model_executor.h 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_TOPICS_MODEL_EXECUTOR_H_
  5. #define COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_TOPICS_MODEL_EXECUTOR_H_
  6. #include <string>
  7. #include <unordered_map>
  8. #include <vector>
  9. #include "base/callback.h"
  10. #include "base/files/file_path.h"
  11. #include "base/memory/weak_ptr.h"
  12. #include "base/sequence_checker.h"
  13. #include "components/optimization_guide/core/bert_model_handler.h"
  14. #include "components/optimization_guide/core/page_content_annotation_job.h"
  15. #include "components/optimization_guide/core/page_content_annotation_job_executor.h"
  16. #include "components/optimization_guide/core/page_content_annotations_common.h"
  17. #include "third_party/abseil-cpp/absl/types/optional.h"
  18. namespace optimization_guide {
  19. // A BERT-based mode executor for page topics annotations. All the derived
  20. // functionality of this class is exclusive to the UI thread, but may post
  21. // things to the background task runner.
  22. class PageTopicsModelExecutor : public PageContentAnnotationJobExecutor,
  23. public BertModelHandler {
  24. public:
  25. PageTopicsModelExecutor(
  26. OptimizationGuideModelProvider* model_provider,
  27. scoped_refptr<base::SequencedTaskRunner> background_task_runner,
  28. const absl::optional<proto::Any>& model_metadata);
  29. ~PageTopicsModelExecutor() override;
  30. // PageContentAnnotationJobExecutor:
  31. void ExecuteJob(base::OnceClosure on_job_complete_callback,
  32. std::unique_ptr<PageContentAnnotationJob> job) override;
  33. void ExecuteOnSingleInput(
  34. AnnotationType annotation_type,
  35. const std::string& raw_input,
  36. base::OnceCallback<void(const BatchAnnotationResult&)> callback) override;
  37. // BertModelHandler:
  38. void UnloadModel() override;
  39. void OnModelUpdated(proto::OptimizationTarget optimization_target,
  40. const ModelInfo& model_info) override;
  41. // Creates a BatchAnnotationResult from the output of the model, calling
  42. // |ExtractCategoriesFromModelOutput| in the process.
  43. // Public for testing.
  44. void PostprocessCategoriesToBatchAnnotationResult(
  45. base::OnceCallback<void(const BatchAnnotationResult&)> callback,
  46. AnnotationType annotation_type,
  47. const std::string& raw_input,
  48. const absl::optional<std::vector<tflite::task::core::Category>>& output);
  49. // Extracts the scored categories from the output of the model.
  50. // Public for testing.
  51. absl::optional<std::vector<WeightedIdentifier>>
  52. ExtractCategoriesFromModelOutput(
  53. const std::vector<tflite::task::core::Category>& model_output) const;
  54. private:
  55. void OnOverrideListLoadAttemptDone(
  56. base::OnceClosure on_job_complete_callback,
  57. std::unique_ptr<PageContentAnnotationJob> job,
  58. absl::optional<
  59. std::unordered_map<std::string, std::vector<WeightedIdentifier>>>
  60. override_list);
  61. // Does the required preprocessing on a input domain.
  62. static std::string PreprocessHost(const std::string& host);
  63. scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
  64. // Set whenever a valid override list file is passed along with the model file
  65. // update. This will be reset if the provided file is deemed invalid on the
  66. // first attempted load.
  67. // Used on the UI thread.
  68. absl::optional<base::FilePath> override_list_file_path_;
  69. // Set whenever an override list file is available and the model file is
  70. // loaded into memory. Reset whenever the model file is unloaded.
  71. // Used on the UI thread. Lookups in this mapping should have |PreprocessHost|
  72. // applied first.
  73. absl::optional<
  74. std::unordered_map<std::string, std::vector<WeightedIdentifier>>>
  75. override_list_;
  76. SEQUENCE_CHECKER(sequence_checker_);
  77. base::WeakPtrFactory<PageTopicsModelExecutor> weak_ptr_factory_{this};
  78. };
  79. } // namespace optimization_guide
  80. #endif // COMPONENTS_OPTIMIZATION_GUIDE_CORE_PAGE_TOPICS_MODEL_EXECUTOR_H_