123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- // Copyright 2017 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/assist_ranker/binary_classifier_predictor.h"
- #include <memory>
- #include "base/bind.h"
- #include "base/callback_helpers.h"
- #include "base/feature_list.h"
- #include "base/metrics/field_trial_params.h"
- #include "base/test/scoped_feature_list.h"
- #include "components/assist_ranker/fake_ranker_model_loader.h"
- #include "components/assist_ranker/proto/ranker_model.pb.h"
- #include "components/assist_ranker/ranker_model.h"
- #include "testing/gtest/include/gtest/gtest.h"
- namespace assist_ranker {
- using ::assist_ranker::testing::FakeRankerModelLoader;
- class BinaryClassifierPredictorTest : public ::testing::Test {
- public:
- void SetUp() override;
- std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
- std::unique_ptr<RankerModel> ranker_model,
- const PredictorConfig& config);
- // This model will return the value of |feature| as a prediction.
- GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
- PredictorConfig GetConfig();
- PredictorConfig GetConfig(float predictor_threshold_replacement);
- protected:
- const std::string feature_ = "feature";
- const float weight_ = 1.0;
- const float threshold_ = 0.5;
- base::test::ScopedFeatureList scoped_feature_list_;
- };
- void BinaryClassifierPredictorTest::SetUp() {
- ::testing::Test::SetUp();
- scoped_feature_list_.Init();
- }
- std::unique_ptr<BinaryClassifierPredictor>
- BinaryClassifierPredictorTest::InitPredictor(
- std::unique_ptr<RankerModel> ranker_model,
- const PredictorConfig& config) {
- std::unique_ptr<BinaryClassifierPredictor> predictor(
- new BinaryClassifierPredictor(config));
- auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
- base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
- base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
- base::Unretained(predictor.get())),
- std::move(ranker_model));
- predictor->LoadModel(std::move(fake_model_loader));
- return predictor;
- }
- const base::Feature kTestRankerQuery{"TestRankerQuery",
- base::FEATURE_ENABLED_BY_DEFAULT};
- const base::FeatureParam<std::string> kTestRankerUrl{
- &kTestRankerQuery, "url-param-name", "https://default.model.url"};
- PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
- return GetConfig(kNoPredictThresholdReplacement);
- }
- PredictorConfig BinaryClassifierPredictorTest::GetConfig(
- float predictor_threshold_replacement) {
- PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
- GetEmptyAllowlist(), &kTestRankerQuery,
- &kTestRankerUrl, predictor_threshold_replacement);
- return config;
- }
- GenericLogisticRegressionModel
- BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
- GenericLogisticRegressionModel lr_model;
- lr_model.set_bias(-0.5);
- lr_model.set_threshold(threshold_);
- (*lr_model.mutable_weights())[feature_].set_scalar(weight_);
- return lr_model;
- }
- // TODO(hamelphi): Test BinaryClassifierPredictor::Create.
- TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
- auto ranker_model = std::make_unique<RankerModel>();
- auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
- EXPECT_FALSE(predictor->IsReady());
- RankerExample ranker_example;
- auto& features = *ranker_example.mutable_features();
- features[feature_].set_bool_value(true);
- bool bool_response;
- EXPECT_FALSE(predictor->Predict(ranker_example, &bool_response));
- float float_response;
- EXPECT_FALSE(predictor->PredictScore(ranker_example, &float_response));
- }
- TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
- auto ranker_model = std::make_unique<RankerModel>();
- // TranslateRankerModel does not have an inference module. Validation will
- // fail.
- ranker_model->mutable_proto()
- ->mutable_translate()
- ->mutable_translate_logistic_regression_model()
- ->set_bias(1);
- auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
- EXPECT_FALSE(predictor->IsReady());
- RankerExample ranker_example;
- auto& features = *ranker_example.mutable_features();
- features[feature_].set_bool_value(true);
- bool bool_response;
- EXPECT_FALSE(predictor->Predict(ranker_example, &bool_response));
- float float_response;
- EXPECT_FALSE(predictor->PredictScore(ranker_example, &float_response));
- }
- TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
- auto ranker_model = std::make_unique<RankerModel>();
- *ranker_model->mutable_proto()->mutable_logistic_regression() =
- GetSimpleLogisticRegressionModel();
- auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
- EXPECT_TRUE(predictor->IsReady());
- RankerExample ranker_example;
- auto& features = *ranker_example.mutable_features();
- features[feature_].set_bool_value(true);
- bool bool_response;
- EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
- EXPECT_TRUE(bool_response);
- float float_response;
- EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
- EXPECT_GT(float_response, threshold_);
- features[feature_].set_bool_value(false);
- EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
- EXPECT_FALSE(bool_response);
- EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
- EXPECT_LT(float_response, threshold_);
- }
- TEST_F(BinaryClassifierPredictorTest,
- GenericLogisticRegressionPreprocessedModel) {
- auto ranker_model = std::make_unique<RankerModel>();
- auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
- glr = GetSimpleLogisticRegressionModel();
- glr.clear_weights();
- glr.set_is_preprocessed_model(true);
- (*glr.mutable_fullname_weights())[feature_] = weight_;
- auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
- EXPECT_TRUE(predictor->IsReady());
- RankerExample ranker_example;
- auto& features = *ranker_example.mutable_features();
- features[feature_].set_bool_value(true);
- bool bool_response;
- EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
- EXPECT_TRUE(bool_response);
- float float_response;
- EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
- EXPECT_GT(float_response, threshold_);
- features[feature_].set_bool_value(false);
- EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
- EXPECT_FALSE(bool_response);
- EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
- EXPECT_LT(float_response, threshold_);
- }
- TEST_F(BinaryClassifierPredictorTest,
- GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
- auto ranker_model = std::make_unique<RankerModel>();
- auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
- glr = GetSimpleLogisticRegressionModel();
- glr.clear_weights();
- glr.set_is_preprocessed_model(true);
- (*glr.mutable_fullname_weights())[feature_] = weight_;
- float high_threshold = 0.9; // Some high threshold.
- auto predictor =
- InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
- EXPECT_TRUE(predictor->IsReady());
- RankerExample ranker_example;
- auto& features = *ranker_example.mutable_features();
- features[feature_].set_bool_value(true);
- bool bool_response;
- EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
- EXPECT_FALSE(bool_response);
- float float_response;
- EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
- EXPECT_GT(float_response, threshold_);
- EXPECT_LT(float_response, high_threshold);
- }
- } // namespace assist_ranker
|