base_predictor.h 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_
  5. #define COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_
  6. #include <memory>
  7. #include <string>
  8. #include "base/memory/weak_ptr.h"
  9. #include "components/assist_ranker/predictor_config.h"
  10. #include "components/assist_ranker/ranker_model_loader.h"
  11. #include "services/metrics/public/cpp/ukm_source_id.h"
  12. class GURL;
  13. namespace ukm {
  14. class UkmEntryBuilder;
  15. }
  16. namespace assist_ranker {
  // Value returned by |GetPredictThresholdReplacement| when no prediction
  // threshold replacement should be applied.
  // constexpr with a float literal: the initializer matches the declared type
  // (no double -> float conversion) and the constant is usable at compile time.
  constexpr float kNoPredictThresholdReplacement = 0.0f;
  20. class Feature;
  21. class RankerExample;
  22. class RankerModel;
  23. // Predictors are objects that provide an interface for prediction, as well as
  24. // encapsulate the logic for loading the model and logging. Sub-classes of
  25. // BasePredictor implement an interface that depends on the nature of the
  26. // supported model. Subclasses of BasePredictor will also need to implement an
  27. // Initialize method that will be called once the model is available, and a
  28. // static validation function with the following signature:
  29. //
  30. // static RankerModelStatus ValidateModel(const RankerModel& model);
  31. class BasePredictor : public base::SupportsWeakPtr<BasePredictor> {
  32. public:
  33. BasePredictor(const PredictorConfig& config);
  34. BasePredictor(const BasePredictor&) = delete;
  35. BasePredictor& operator=(const BasePredictor&) = delete;
  36. virtual ~BasePredictor();
  37. // Returns true if the predictor is ready to make predictions.
  38. bool IsReady();
  39. // Returns true if the base::Feature associated with this model is enabled.
  40. bool is_query_enabled() const { return is_query_enabled_; }
  41. // Logs the features of |example| to UKM using the given source_id.
  42. void LogExampleToUkm(const RankerExample& example, ukm::SourceId source_id);
  43. // Returns the model URL.
  44. GURL GetModelUrl() const;
  45. // Returns the threshold to use for prediction, or
  46. // kNoPredictThresholdReplacement to leave it unchanged.
  47. float GetPredictThresholdReplacement() const;
  48. // Returns the model name.
  49. std::string GetModelName() const;
  50. protected:
  51. // Preprocessing applied to an example before prediction. The original
  52. // RankerExample is not modified, so it is safe to use it later for logging.
  53. RankerExample PreprocessExample(const RankerExample& example);
  54. // Called when the RankerModelLoader has finished loading the model. Returns
  55. // true only if the model was succesfully loaded and is ready to predict.
  56. virtual bool Initialize() = 0;
  57. // Loads a model and make it available for prediction.
  58. void LoadModel(std::unique_ptr<RankerModelLoader> model_loader);
  59. // Called once the model loader as succesfully loaded the model.
  60. void OnModelAvailable(std::unique_ptr<RankerModel> model);
  61. std::unique_ptr<RankerModelLoader> model_loader_;
  62. // The model used for prediction.
  63. std::unique_ptr<RankerModel> ranker_model_;
  64. private:
  65. void LogFeatureToUkm(const std::string& feature_name,
  66. const Feature& feature,
  67. ukm::UkmEntryBuilder* ukm_builder);
  68. bool is_ready_ = false;
  69. bool is_query_enabled_ = false;
  70. PredictorConfig config_;
  71. };
  72. } // namespace assist_ranker
  73. #endif // COMPONENTS_ASSIST_RANKER_BASE_PREDICTOR_H_