classifier_predictor.cc 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/classifier_predictor.h"
  5. #include <memory>
  6. #include <utility>
  7. #include <vector>
  8. #include "base/bind.h"
  9. #include "base/callback_helpers.h"
  10. #include "base/files/file_path.h"
  11. #include "components/assist_ranker/example_preprocessing.h"
  12. #include "components/assist_ranker/nn_classifier.h"
  13. #include "components/assist_ranker/proto/ranker_model.pb.h"
  14. #include "components/assist_ranker/ranker_model.h"
  15. #include "components/assist_ranker/ranker_model_loader_impl.h"
  16. #include "services/network/public/cpp/shared_url_loader_factory.h"
  17. namespace assist_ranker {
  18. ClassifierPredictor::ClassifierPredictor(const PredictorConfig& config)
  19. : BasePredictor(config) {}
  20. ClassifierPredictor::~ClassifierPredictor() {}
  21. // static
  22. std::unique_ptr<ClassifierPredictor> ClassifierPredictor::Create(
  23. const PredictorConfig& config,
  24. const base::FilePath& model_path,
  25. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
  26. std::unique_ptr<ClassifierPredictor> predictor(
  27. new ClassifierPredictor(config));
  28. if (!predictor->is_query_enabled()) {
  29. DVLOG(1) << "Query disabled, bypassing model loading.";
  30. return predictor;
  31. }
  32. const GURL& model_url = predictor->GetModelUrl();
  33. DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
  34. DVLOG(1) << "Model URL: " << model_url;
  35. auto model_loader = std::make_unique<RankerModelLoaderImpl>(
  36. base::BindRepeating(&ClassifierPredictor::ValidateModel),
  37. base::BindRepeating(&ClassifierPredictor::OnModelAvailable,
  38. base::Unretained(predictor.get())),
  39. url_loader_factory, model_path, model_url, config.uma_prefix);
  40. predictor->LoadModel(std::move(model_loader));
  41. return predictor;
  42. }
  43. bool ClassifierPredictor::Predict(const std::vector<float>& features,
  44. std::vector<float>* prediction) {
  45. if (!IsReady()) {
  46. DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
  47. return false;
  48. }
  49. *prediction = nn_classifier::Inference(model_, features);
  50. return true;
  51. }
  52. bool ClassifierPredictor::Predict(RankerExample example,
  53. std::vector<float>* prediction) {
  54. if (!IsReady()) {
  55. DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
  56. return false;
  57. }
  58. if (!model_.has_preprocessor_config()) {
  59. DVLOG(1) << "No preprocessor config specified.";
  60. return false;
  61. }
  62. const int preprocessor_error =
  63. ExamplePreprocessor::Process(model_.preprocessor_config(), &example);
  64. // It is okay to ignore cases where there is an extra feature that is not in
  65. // the config.
  66. if (preprocessor_error != ExamplePreprocessor::kSuccess &&
  67. preprocessor_error != ExamplePreprocessor::kNoFeatureIndexFound) {
  68. DVLOG(1) << "Preprocessing error " << preprocessor_error;
  69. return false;
  70. }
  71. const auto& vec =
  72. example.features()
  73. .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
  74. .float_list()
  75. .float_value();
  76. const std::vector<float> features(vec.begin(), vec.end());
  77. return Predict(features, prediction);
  78. }
  79. // static
  80. RankerModelStatus ClassifierPredictor::ValidateModel(const RankerModel& model) {
  81. if (model.proto().model_case() != RankerModelProto::kNnClassifier) {
  82. DVLOG(0) << "Model is incompatible.";
  83. return RankerModelStatus::INCOMPATIBLE;
  84. }
  85. return nn_classifier::Validate(model.proto().nn_classifier())
  86. ? RankerModelStatus::OK
  87. : RankerModelStatus::INCOMPATIBLE;
  88. }
  89. bool ClassifierPredictor::Initialize() {
  90. if (ranker_model_->proto().model_case() == RankerModelProto::kNnClassifier) {
  91. model_ = ranker_model_->proto().nn_classifier();
  92. return true;
  93. }
  94. DVLOG(0) << "Could not initialize inference module.";
  95. return false;
  96. }
  97. } // namespace assist_ranker