123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- // Copyright 2017 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #ifndef COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
- #define COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
- #include "components/assist_ranker/base_predictor.h"
- #include "components/assist_ranker/proto/ranker_example.pb.h"
- namespace base {
- class FilePath;  // Forward declaration; used here only as a const reference.
- }  // namespace base
- namespace network {
- class SharedURLLoaderFactory;  // Forward declaration; used via scoped_refptr.
- }  // namespace network
- namespace assist_ranker {
- class GenericLogisticRegressionInference;  // Forward declaration; held by std::unique_ptr.
- // Predictor for ranker models whose output is a binary decision.
- //
- // Instances are obtained through Create(): the constructor is private, and
- // copy construction and copy assignment are deleted.
- class BinaryClassifierPredictor : public BasePredictor {
-  public:
-   BinaryClassifierPredictor(const BinaryClassifierPredictor&) = delete;
-   BinaryClassifierPredictor& operator=(const BinaryClassifierPredictor&) =
-       delete;
-   ~BinaryClassifierPredictor() override;
-   // Returns a new predictor instance with the given |config| and initializes
-   // its model loader. The |url_loader_factory| is passed to the predictor's
-   // model loader, which holds it as a scoped_refptr.
-   // NOTE(review): |model_path| is presumably where the model is stored on
-   // disk -- confirm against the implementation.
-   [[nodiscard]] static std::unique_ptr<BinaryClassifierPredictor> Create(
-       const PredictorConfig& config,
-       const base::FilePath& model_path,
-       scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory);
-   // Writes a boolean decision for |example| into |*prediction|. Returns false
-   // if a prediction could not be made (e.g. the model is not loaded yet).
-   [[nodiscard]] bool Predict(const RankerExample& example, bool* prediction);
-   // Writes a score between 0 and 1 for |example| into |*prediction|. Returns
-   // false if a prediction could not be made (e.g. the model is not loaded
-   // yet).
-   [[nodiscard]] bool PredictScore(const RankerExample& example,
-                                   float* prediction);
-   // Validates that the loaded RankerModel is a valid BinaryClassifier model.
-   static RankerModelStatus ValidateModel(const RankerModel& model);
-  protected:
-   // Instantiates the inference module.
-   bool Initialize() override;
-  private:
-   friend class BinaryClassifierPredictorTest;
-   // Private: construct through Create(). Marked explicit so that a
-   // PredictorConfig is never implicitly converted into a predictor.
-   explicit BinaryClassifierPredictor(const PredictorConfig& config);
-   // TODO(hamelphi): Use an abstract BinaryClassifierInferenceModule in order to
-   // generalize to other models.
-   std::unique_ptr<GenericLogisticRegressionInference> inference_module_;
- };
- } // namespace assist_ranker
- #endif // COMPONENTS_ASSIST_RANKER_BINARY_CLASSIFIER_PREDICTOR_H_
|