classifier_predictor_unittest.cc 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/assist_ranker/classifier_predictor.h"
  5. #include <memory>
  6. #include <string>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/bind.h"
  10. #include "base/callback_helpers.h"
  11. #include "base/feature_list.h"
  12. #include "base/metrics/field_trial_params.h"
  13. #include "base/test/scoped_feature_list.h"
  14. #include "components/assist_ranker/example_preprocessing.h"
  15. #include "components/assist_ranker/fake_ranker_model_loader.h"
  16. #include "components/assist_ranker/nn_classifier_test_util.h"
  17. #include "components/assist_ranker/proto/ranker_model.pb.h"
  18. #include "components/assist_ranker/ranker_model.h"
  19. #include "testing/gmock/include/gmock/gmock-matchers.h"
  20. #include "testing/gtest/include/gtest/gtest.h"
  21. namespace assist_ranker {
  22. using ::assist_ranker::testing::FakeRankerModelLoader;
  23. using ::testing::FloatEq;
  24. // Preprocessor feature names.
  25. const char kFeatureName0[] = "feature_0";
  26. const char kFeatureName1[] = "feature_1";
  27. const char kFeatureExtra[] = "feature_extra";
  28. class ClassifierPredictorTest : public ::testing::Test {
  29. public:
  30. void SetUp() override;
  31. std::unique_ptr<ClassifierPredictor> InitPredictor(
  32. std::unique_ptr<RankerModel> ranker_model,
  33. const PredictorConfig& config);
  34. PredictorConfig GetConfig();
  35. protected:
  36. base::test::ScopedFeatureList scoped_feature_list_;
  37. };
  38. void ClassifierPredictorTest::SetUp() {
  39. ::testing::Test::SetUp();
  40. scoped_feature_list_.Init();
  41. }
  42. std::unique_ptr<ClassifierPredictor> ClassifierPredictorTest::InitPredictor(
  43. std::unique_ptr<RankerModel> ranker_model,
  44. const PredictorConfig& config) {
  45. std::unique_ptr<ClassifierPredictor> predictor(
  46. new ClassifierPredictor(config));
  47. auto fake_model_loader = std::make_unique<FakeRankerModelLoader>(
  48. base::BindRepeating(&ClassifierPredictor::ValidateModel),
  49. base::BindRepeating(&ClassifierPredictor::OnModelAvailable,
  50. base::Unretained(predictor.get())),
  51. std::move(ranker_model));
  52. predictor->LoadModel(std::move(fake_model_loader));
  53. return predictor;
  54. }
  55. const base::Feature kTestRankerQuery{"TestRankerQuery",
  56. base::FEATURE_ENABLED_BY_DEFAULT};
  57. const base::FeatureParam<std::string> kTestRankerUrl{
  58. &kTestRankerQuery, "url-param-name", "https://default.model.url"};
  59. PredictorConfig ClassifierPredictorTest::GetConfig() {
  60. return PredictorConfig("model_name", "logging_name", "uma_prefix", LOG_NONE,
  61. GetEmptyAllowlist(), &kTestRankerQuery,
  62. &kTestRankerUrl, 0);
  63. }
  64. TEST_F(ClassifierPredictorTest, EmptyRankerModel) {
  65. auto ranker_model = std::make_unique<RankerModel>();
  66. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  67. EXPECT_FALSE(predictor->IsReady());
  68. RankerExample ranker_example;
  69. auto& features = *ranker_example.mutable_features();
  70. features[kFeatureName0].set_bool_value(true);
  71. std::vector<float> response;
  72. EXPECT_FALSE(predictor->Predict(ranker_example, &response));
  73. }
  74. TEST_F(ClassifierPredictorTest, NoInferenceModuleForModel) {
  75. auto ranker_model = std::make_unique<RankerModel>();
  76. // TranslateRankerModel does not have an inference module. Validation will
  77. // fail.
  78. ranker_model->mutable_proto()
  79. ->mutable_translate()
  80. ->mutable_translate_logistic_regression_model()
  81. ->set_bias(1);
  82. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  83. EXPECT_FALSE(predictor->IsReady());
  84. RankerExample ranker_example;
  85. auto& features = *ranker_example.mutable_features();
  86. features[kFeatureName0].set_bool_value(true);
  87. std::vector<float> response;
  88. EXPECT_FALSE(predictor->Predict(ranker_example, &response));
  89. EXPECT_FALSE(predictor->Predict({0, 0}, &response));
  90. }
  91. TEST_F(ClassifierPredictorTest, PredictFeatureVector) {
  92. auto ranker_model = std::make_unique<RankerModel>();
  93. *ranker_model->mutable_proto()->mutable_nn_classifier() =
  94. nn_classifier::CreateXorClassifierModel();
  95. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  96. EXPECT_TRUE(predictor->IsReady());
  97. std::vector<float> response;
  98. // True responses.
  99. EXPECT_TRUE(predictor->Predict({0, 1}, &response));
  100. EXPECT_EQ(response.size(), 1u);
  101. EXPECT_THAT(response[0], FloatEq(2.8271765));
  102. EXPECT_TRUE(predictor->Predict({1, 0}, &response));
  103. EXPECT_EQ(response.size(), 1u);
  104. EXPECT_THAT(response[0], FloatEq(2.6790769));
  105. // False responses.
  106. EXPECT_TRUE(predictor->Predict({0, 0}, &response));
  107. EXPECT_EQ(response.size(), 1u);
  108. EXPECT_THAT(response[0], FloatEq(-2.7154054));
  109. EXPECT_TRUE(predictor->Predict({1, 1}, &response));
  110. EXPECT_EQ(response.size(), 1u);
  111. EXPECT_THAT(response[0], FloatEq(-3.1652793));
  112. }
  113. TEST_F(ClassifierPredictorTest, PredictRankerExampleNoPreprocessor) {
  114. auto ranker_model = std::make_unique<RankerModel>();
  115. *ranker_model->mutable_proto()->mutable_nn_classifier() =
  116. nn_classifier::CreateXorClassifierModel();
  117. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  118. EXPECT_TRUE(predictor->IsReady());
  119. // Prediction of RankerExample without preprocessor config should fail.
  120. std::vector<float> response;
  121. RankerExample example;
  122. EXPECT_FALSE(predictor->Predict(RankerExample(), &response));
  123. }
  124. TEST_F(ClassifierPredictorTest, PredictRankerExampleWithPreprocessor) {
  125. auto ranker_model = std::make_unique<RankerModel>();
  126. auto& model = *ranker_model->mutable_proto()->mutable_nn_classifier();
  127. model = nn_classifier::CreateXorClassifierModel();
  128. // Set up the preprocessor config with two features at feature vector
  129. // indices 0 and 1.
  130. auto& indices =
  131. *model.mutable_preprocessor_config()->mutable_feature_indices();
  132. indices[kFeatureName0] = 0;
  133. indices[kFeatureName1] = 1;
  134. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  135. EXPECT_TRUE(predictor->IsReady());
  136. // Prediction of RankerExample with preprocessor config should work.
  137. RankerExample example;
  138. auto& feature_map = *example.mutable_features();
  139. std::vector<float> response;
  140. // True responses.
  141. feature_map[kFeatureName0].set_float_value(0);
  142. feature_map[kFeatureName1].set_float_value(1);
  143. EXPECT_TRUE(predictor->Predict(example, &response));
  144. EXPECT_EQ(response.size(), 1u);
  145. EXPECT_THAT(response[0], FloatEq(2.8271765));
  146. feature_map[kFeatureName0].set_float_value(1);
  147. feature_map[kFeatureName1].set_float_value(0);
  148. EXPECT_TRUE(predictor->Predict(example, &response));
  149. EXPECT_EQ(response.size(), 1u);
  150. EXPECT_THAT(response[0], FloatEq(2.6790769));
  151. // False responses.
  152. feature_map[kFeatureName0].set_float_value(0);
  153. feature_map[kFeatureName1].set_float_value(0);
  154. EXPECT_TRUE(predictor->Predict(example, &response));
  155. EXPECT_EQ(response.size(), 1u);
  156. EXPECT_THAT(response[0], FloatEq(-2.7154054));
  157. feature_map[kFeatureName0].set_float_value(1);
  158. feature_map[kFeatureName1].set_float_value(1);
  159. EXPECT_TRUE(predictor->Predict(example, &response));
  160. EXPECT_EQ(response.size(), 1u);
  161. EXPECT_THAT(response[0], FloatEq(-3.1652793));
  162. // Check that extra features do not cause an error.
  163. feature_map[kFeatureName0].set_float_value(0);
  164. feature_map[kFeatureName1].set_float_value(1);
  165. feature_map[kFeatureExtra].set_float_value(1);
  166. EXPECT_TRUE(predictor->Predict(example, &response));
  167. EXPECT_EQ(response.size(), 1u);
  168. EXPECT_THAT(response[0], FloatEq(2.8271765));
  169. }
  170. TEST_F(ClassifierPredictorTest, PredictRankerExamplePreprocessorError) {
  171. auto ranker_model = std::make_unique<RankerModel>();
  172. auto& model = *ranker_model->mutable_proto()->mutable_nn_classifier();
  173. model = nn_classifier::CreateXorClassifierModel();
  174. // Set up the preprocessor config with two features at feature vector
  175. // indices 0 and 1.
  176. auto& config = *model.mutable_preprocessor_config();
  177. auto& indices = *config.mutable_feature_indices();
  178. indices[kFeatureName0] = 0;
  179. indices[kFeatureName1] = 1;
  180. // Zero normalizer will generate a preprocessing error.
  181. (*config.mutable_normalizers())[kFeatureName0] = 0;
  182. auto predictor = InitPredictor(std::move(ranker_model), GetConfig());
  183. EXPECT_TRUE(predictor->IsReady());
  184. // Prediction of RankerExample should fail due to preprocessing error.
  185. RankerExample example;
  186. auto& feature_map = *example.mutable_features();
  187. std::vector<float> response;
  188. feature_map[kFeatureName0].set_float_value(0);
  189. feature_map[kFeatureName1].set_float_value(1);
  190. EXPECT_FALSE(predictor->Predict(example, &response));
  191. }
  192. } // namespace assist_ranker