binary_classifier_predictor.h 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
  5. #define COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
#include <memory>

#include "components/assist_ranker/base_predictor.h"
#include "components/assist_ranker/proto/ranker_example.pb.h"
  8. namespace base {
  9. class FilePath;
  10. }
  11. namespace network {
  12. class SharedURLLoaderFactory;
  13. }
  14. namespace assist_ranker {
  15. class GenericLogisticRegressionInference;
  16. // Predictor class for models that output a binary decision.
  17. class BinaryClassifierPredictor : public BasePredictor {
  18. public:
  19. BinaryClassifierPredictor(const BinaryClassifierPredictor&) = delete;
  20. BinaryClassifierPredictor& operator=(const BinaryClassifierPredictor&) =
  21. delete;
  22. ~BinaryClassifierPredictor() override;
  23. // Returns an new predictor instance with the given |config| and initialize
  24. // its model loader. The |request_context getter| is passed to the
  25. // predictor's model_loader which holds it as scoped_refptr.
  26. [[nodiscard]] static std::unique_ptr<BinaryClassifierPredictor> Create(
  27. const PredictorConfig& config,
  28. const base::FilePath& model_path,
  29. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
  30. // Fills in a boolean decision given a RankerExample. Returns false if a
  31. // prediction could not be made (e.g. the model is not loaded yet).
  32. [[nodiscard]] bool Predict(const RankerExample& example, bool* prediction);
  33. // Returns a score between 0 and 1. Returns false if a
  34. // prediction could not be made (e.g. the model is not loaded yet).
  35. [[nodiscard]] bool PredictScore(const RankerExample& example,
  36. float* prediction);
  37. // Validates that the loaded RankerModel is a valid BinaryClassifier model.
  38. static RankerModelStatus ValidateModel(const RankerModel& model);
  39. protected:
  40. // Instatiates the inference module.
  41. bool Initialize() override;
  42. private:
  43. friend class BinaryClassifierPredictorTest;
  44. BinaryClassifierPredictor(const PredictorConfig& config);
  45. // TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to
  46. // generalize to other models.
  47. std::unique_ptr<GenericLogisticRegressionInference> inference_module_;
  48. };
  49. } // namespace assist_ranker
  50. #endif // COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_