classifier_predictor.h 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_
  5. #define COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_
  6. #include <memory>
  7. #include <vector>
  8. #include "components/assist_ranker/base_predictor.h"
  9. #include "components/assist_ranker/nn_classifier.h"
  10. #include "components/assist_ranker/proto/ranker_example.pb.h"
  11. namespace base {
  12. class FilePath;
  13. }
  14. namespace network {
  15. class SharedURLLoaderFactory;
  16. }
  17. namespace assist_ranker {
  18. // Predictor class for single-layer neural network models.
  19. class ClassifierPredictor : public BasePredictor {
  20. public:
  21. ClassifierPredictor(const ClassifierPredictor&) = delete;
  22. ClassifierPredictor& operator=(const ClassifierPredictor&) = delete;
  23. ~ClassifierPredictor() override;
  24. // Returns an new predictor instance with the given |config| and initialize
  25. // its model loader. The |request_context getter| is passed to the
  26. // predictor's model_loader which holds it as scoped_refptr.
  27. [[nodiscard]] static std::unique_ptr<ClassifierPredictor> Create(
  28. const PredictorConfig& config,
  29. const base::FilePath& model_path,
  30. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
  31. // Performs inferencing on the specified RankerExample. The example is first
  32. // preprocessed using the model config. Returns false if a prediction could
  33. // not be made (e.g. the model is not loaded yet).
  34. [[nodiscard]] bool Predict(RankerExample example,
  35. std::vector<float>* prediction);
  36. // Performs inferencing on the specified feature vector. Returns false if
  37. // a prediction could not be made.
  38. [[nodiscard]] bool Predict(const std::vector<float>& features,
  39. std::vector<float>* prediction);
  40. // Validates that the loaded RankerModel is a valid BinaryClassifier model.
  41. static RankerModelStatus ValidateModel(const RankerModel& model);
  42. protected:
  43. // Instantiates the inference module.
  44. bool Initialize() override;
  45. private:
  46. friend class ClassifierPredictorTest;
  47. ClassifierPredictor(const PredictorConfig& config);
  48. NNClassifierModel model_;
  49. };
  50. } // namespace assist_ranker
  51. #endif // COMPONENTS_ASSIST_RANKER_CLASSIFIER_PREDICTOR_H_