// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
- #include "components/assist_ranker/classifier_predictor.h"
- #include <memory>
- #include <utility>
- #include <vector>
- #include "base/bind.h"
- #include "base/callback_helpers.h"
- #include "base/files/file_path.h"
- #include "components/assist_ranker/example_preprocessing.h"
- #include "components/assist_ranker/nn_classifier.h"
- #include "components/assist_ranker/proto/ranker_model.pb.h"
- #include "components/assist_ranker/ranker_model.h"
- #include "components/assist_ranker/ranker_model_loader_impl.h"
- #include "services/network/public/cpp/shared_url_loader_factory.h"
- namespace assist_ranker {
- ClassifierPredictor::ClassifierPredictor(const PredictorConfig& config)
- : BasePredictor(config) {}
- ClassifierPredictor::~ClassifierPredictor() {}
- // static
- std::unique_ptr<ClassifierPredictor> ClassifierPredictor::Create(
- const PredictorConfig& config,
- const base::FilePath& model_path,
- scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory) {
- std::unique_ptr<ClassifierPredictor> predictor(
- new ClassifierPredictor(config));
- if (!predictor->is_query_enabled()) {
- DVLOG(1) << "Query disabled, bypassing model loading.";
- return predictor;
- }
- const GURL& model_url = predictor->GetModelUrl();
- DVLOG(1) << "Creating predictor instance for " << predictor->GetModelName();
- DVLOG(1) << "Model URL: " << model_url;
- auto model_loader = std::make_unique<RankerModelLoaderImpl>(
- base::BindRepeating(&ClassifierPredictor::ValidateModel),
- base::BindRepeating(&ClassifierPredictor::OnModelAvailable,
- base::Unretained(predictor.get())),
- url_loader_factory, model_path, model_url, config.uma_prefix);
- predictor->LoadModel(std::move(model_loader));
- return predictor;
- }
- bool ClassifierPredictor::Predict(const std::vector<float>& features,
- std::vector<float>* prediction) {
- if (!IsReady()) {
- DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
- return false;
- }
- *prediction = nn_classifier::Inference(model_, features);
- return true;
- }
- bool ClassifierPredictor::Predict(RankerExample example,
- std::vector<float>* prediction) {
- if (!IsReady()) {
- DVLOG(1) << "Predictor " << GetModelName() << " not ready for prediction.";
- return false;
- }
- if (!model_.has_preprocessor_config()) {
- DVLOG(1) << "No preprocessor config specified.";
- return false;
- }
- const int preprocessor_error =
- ExamplePreprocessor::Process(model_.preprocessor_config(), &example);
- // It is okay to ignore cases where there is an extra feature that is not in
- // the config.
- if (preprocessor_error != ExamplePreprocessor::kSuccess &&
- preprocessor_error != ExamplePreprocessor::kNoFeatureIndexFound) {
- DVLOG(1) << "Preprocessing error " << preprocessor_error;
- return false;
- }
- const auto& vec =
- example.features()
- .at(assist_ranker::ExamplePreprocessor::kVectorizedFeatureDefaultName)
- .float_list()
- .float_value();
- const std::vector<float> features(vec.begin(), vec.end());
- return Predict(features, prediction);
- }
- // static
- RankerModelStatus ClassifierPredictor::ValidateModel(const RankerModel& model) {
- if (model.proto().model_case() != RankerModelProto::kNnClassifier) {
- DVLOG(0) << "Model is incompatible.";
- return RankerModelStatus::INCOMPATIBLE;
- }
- return nn_classifier::Validate(model.proto().nn_classifier())
- ? RankerModelStatus::OK
- : RankerModelStatus::INCOMPATIBLE;
- }
- bool ClassifierPredictor::Initialize() {
- if (ranker_model_->proto().model_case() == RankerModelProto::kNnClassifier) {
- model_ = ranker_model_->proto().nn_classifier();
- return true;
- }
- DVLOG(0) << "Could not initialize inference module.";
- return false;
- }
- } // namespace assist_ranker
|