binary_classifier_predictor_unittest.cc 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/binary_classifier_predictor.h"
  5. #include <memory>
  6. #include "base/bind.h"
  7. #include "base/callback_helpers.h"
  8. #include "base/feature_list.h"
  9. #include "base/metrics/field_trial_params.h"
  10. #include "base/test/scoped_feature_list.h"
  11. #include "components/assist_ranker/fake_ranker_model_loader.h"
  12. #include "components/assist_ranker/proto/ranker_model.pb.h"
  13. #include "components/assist_ranker/ranker_model.h"
  14. #include "testing/gtest/include/gtest/gtest.h"
  15. namespace assist_ranker {
  16. using ::assist_ranker::testing::FakeRankerModelLoader;
  17. class BinaryClassifierPredictorTest : public ::testing::Test {
  18. public:
  19. void SetUp() override;
  20. std::unique_ptr<BinaryClassifierPredictor> InitPredictor(
  21. std::unique_ptr<RankerModel> ranker_model,
  22. const PredictorConfig& config);
  23. // This model will return the value of |feature| as a prediction.
  24. GenericLogisticRegressionModel GetSimpleLogisticRegressionModel();
  25. PredictorConfig GetConfig();
  26. PredictorConfig GetConfig(float predictor_threshold_replacement);
  27. protected:
  28. const std::string feature_ = "feature";
  29. const float weight_ = 1.0;
  30. const float threshold_ = 0.5;
  31. base::test::ScopedFeatureList scoped_feature_list_;
  32. };
  33. void BinaryClassifierPredictorTest::SetUp() {
  34. ::testing::Test::SetUp();
  35. scoped_feature_list_.Init();
  36. }
  37. std::unique_ptr<BinaryClassifierPredictor>
  38. BinaryClassifierPredictorTest::InitPredictor(
  39. std::unique_ptr<RankerModel> ranker_model,
  40. const PredictorConfig& config) {
  41. std::unique_ptr<BinaryClassifierPredictor> predictor(
  42. new BinaryClassifierPredictor(config));
  43. auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
  44. base::BindRepeating(&BinaryClassifierPredictor::ValidateModel),
  45. base::BindRepeating(&BinaryClassifierPredictor::OnModelAvailable,
  46. base::Unretained(predictor.get())),
  47. std::move(ranker_model));
  48. predictor->LoadModel(std::move(fake_model_loader));
  49. return predictor;
  50. }
  51. const base::Feature kTestRankerQuery{"TestRankerQuery",
  52. base::FEATURE_ENABLED_BY_DEFAULT};
  53. const base::FeatureParam<std::string> kTestRankerUrl{
  54. &kTestRankerQuery, "url-param-name", "https://default.model.url"};
  55. PredictorConfig BinaryClassifierPredictorTest::GetConfig() {
  56. return GetConfig(kNoPredictThresholdReplacement);
  57. }
  58. PredictorConfig BinaryClassifierPredictorTest::GetConfig(
  59. float predictor_threshold_replacement) {
  60. PredictorConfig config("model_name", "logging_name", "uma_prefix", LOG_NONE,
  61. GetEmptyAllowlist(), &kTestRankerQuery,
  62. &kTestRankerUrl, predictor_threshold_replacement);
  63. return config;
  64. }
  65. GenericLogisticRegressionModel
  66. BinaryClassifierPredictorTest::GetSimpleLogisticRegressionModel() {
  67. GenericLogisticRegressionModel lr_model;
  68. lr_model.set_bias(-0.5);
  69. lr_model.set_threshold(threshold_);
  70. (*lr_model.mutable_weights())[feature_].set_scalar(weight_);
  71. return lr_model;
  72. }
  73. // TODO(hamelphi): Test BinaryClassifierPredictor::Create.
  74. TEST_F(BinaryClassifierPredictorTest, EmptyRankerModel) {
  75. auto ranker_model = std::make_unique<RankerModel>();
  76. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  77. EXPECT_FALSE(predictor->IsReady());
  78. RankerExample ranker_example;
  79. auto& features = *ranker_example.mutable_features();
  80. features[feature_].set_bool_value(true);
  81. bool bool_response;
  82. EXPECT_FALSE(predictor->Predict(ranker_example, &bool_response));
  83. float float_response;
  84. EXPECT_FALSE(predictor->PredictScore(ranker_example, &float_response));
  85. }
  86. TEST_F(BinaryClassifierPredictorTest, NoInferenceModuleForModel) {
  87. auto ranker_model = std::make_unique<RankerModel>();
  88. // TranslateRankerModel does not have an inference module. Validation will
  89. // fail.
  90. ranker_model->mutable_proto()
  91. ->mutable_translate()
  92. ->mutable_translate_logistic_regression_model()
  93. ->set_bias(1);
  94. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  95. EXPECT_FALSE(predictor->IsReady());
  96. RankerExample ranker_example;
  97. auto& features = *ranker_example.mutable_features();
  98. features[feature_].set_bool_value(true);
  99. bool bool_response;
  100. EXPECT_FALSE(predictor->Predict(ranker_example, &bool_response));
  101. float float_response;
  102. EXPECT_FALSE(predictor->PredictScore(ranker_example, &float_response));
  103. }
  104. TEST_F(BinaryClassifierPredictorTest, GenericLogisticRegressionModel) {
  105. auto ranker_model = std::make_unique<RankerModel>();
  106. *ranker_model->mutable_proto()->mutable_logistic_regression() =
  107. GetSimpleLogisticRegressionModel();
  108. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  109. EXPECT_TRUE(predictor->IsReady());
  110. RankerExample ranker_example;
  111. auto& features = *ranker_example.mutable_features();
  112. features[feature_].set_bool_value(true);
  113. bool bool_response;
  114. EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
  115. EXPECT_TRUE(bool_response);
  116. float float_response;
  117. EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  118. EXPECT_GT(float_response, threshold_);
  119. features[feature_].set_bool_value(false);
  120. EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
  121. EXPECT_FALSE(bool_response);
  122. EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  123. EXPECT_LT(float_response, threshold_);
  124. }
  125. TEST_F(BinaryClassifierPredictorTest,
  126. GenericLogisticRegressionPreprocessedModel) {
  127. auto ranker_model = std::make_unique<RankerModel>();
  128. auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
  129. glr = GetSimpleLogisticRegressionModel();
  130. glr.clear_weights();
  131. glr.set_is_preprocessed_model(true);
  132. (*glr.mutable_fullname_weights())[feature_] = weight_;
  133. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  134. EXPECT_TRUE(predictor->IsReady());
  135. RankerExample ranker_example;
  136. auto& features = *ranker_example.mutable_features();
  137. features[feature_].set_bool_value(true);
  138. bool bool_response;
  139. EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
  140. EXPECT_TRUE(bool_response);
  141. float float_response;
  142. EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  143. EXPECT_GT(float_response, threshold_);
  144. features[feature_].set_bool_value(false);
  145. EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
  146. EXPECT_FALSE(bool_response);
  147. EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  148. EXPECT_LT(float_response, threshold_);
  149. }
  150. TEST_F(BinaryClassifierPredictorTest,
  151. GenericLogisticRegressionPreprocessedModelReplacedThreshold) {
  152. auto ranker_model = std::make_unique<RankerModel>();
  153. auto& glr = *ranker_model->mutable_proto()->mutable_logistic_regression();
  154. glr = GetSimpleLogisticRegressionModel();
  155. glr.clear_weights();
  156. glr.set_is_preprocessed_model(true);
  157. (*glr.mutable_fullname_weights())[feature_] = weight_;
  158. float high_threshold = 0.9; // Some high threshold.
  159. auto predictor =
  160. InitPredictor(std::move(ranker_model), GetConfig(high_threshold));
  161. EXPECT_TRUE(predictor->IsReady());
  162. RankerExample ranker_example;
  163. auto& features = *ranker_example.mutable_features();
  164. features[feature_].set_bool_value(true);
  165. bool bool_response;
  166. EXPECT_TRUE(predictor->Predict(ranker_example, &bool_response));
  167. EXPECT_FALSE(bool_response);
  168. float float_response;
  169. EXPECT_TRUE(predictor->PredictScore(ranker_example, &float_response));
  170. EXPECT_GT(float_response, threshold_);
  171. EXPECT_LT(float_response, high_threshold);
  172. }
  173. } // namespace assist_ranker