// Copyright 2017 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/assist_ranker/generic_logistic_regression_inference.h"

#include "components/assist_ranker/example_preprocessing.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/protobuf/src/google/protobuf/map.h"

namespace assist_ranker {

using ::google::protobuf::Map;
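
// The tests below expect PredictScore() to compute
// Sigmoid(bias + sum of active feature weights), where
// Sigmoid(x) = 1 / (1 + exp(-x)), and Predict() to compare the resulting
// score against the model threshold.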
class GenericLogisticRegressionInferenceTest : public testing::Test {
 protected:
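  // Builds a test model containing one of each supported weight type:
  // scalar, one-hot, sparse, and bucketized.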
  GenericLogisticRegressionModel GetProto() {
    GenericLogisticRegressionModel proto;
    proto.set_bias(bias_);
    proto.set_threshold(threshold_);

    auto& weights = *proto.mutable_weights();
    weights[scalar1_name_].set_scalar(scalar1_weight_);
    weights[scalar2_name_].set_scalar(scalar2_weight_);
    weights[scalar3_name_].set_scalar(scalar3_weight_);

    auto* one_hot_feat = weights[one_hot_name_].mutable_one_hot();
    one_hot_feat->set_default_weight(one_hot_default_weight_);
    (*one_hot_feat->mutable_weights())[one_hot_elem1_name_] =
        one_hot_elem1_weight_;
    (*one_hot_feat->mutable_weights())[one_hot_elem2_name_] =
        one_hot_elem2_weight_;
    (*one_hot_feat->mutable_weights())[one_hot_elem3_name_] =
        one_hot_elem3_weight_;

    SparseWeights* sparse_feat = weights[sparse_name_].mutable_sparse();
    sparse_feat->set_default_weight(sparse_default_weight_);
    (*sparse_feat->mutable_weights())[sparse_elem1_name_] =
        sparse_elem1_weight_;
    (*sparse_feat->mutable_weights())[sparse_elem2_name_] =
        sparse_elem2_weight_;

    BucketizedWeights* bucketized_feat =
        weights[bucketized_name_].mutable_bucketized();
    bucketized_feat->set_default_weight(bucketization_default_weight_);
    for (const float boundary : bucketization_boundaries_) {
      bucketized_feat->add_boundaries(boundary);
    }
    for (const float weight : bucketization_weights_) {
      bucketized_feat->add_weights(weight);
    }
    return proto;
  }

  const std::string scalar1_name_ = "scalar_feature1";
  const std::string scalar2_name_ = "scalar_feature2";
  const std::string scalar3_name_ = "scalar_feature3";
  const std::string one_hot_name_ = "one_hot_feature";
  const std::string one_hot_elem1_name_ = "one_hot_elem1";
  const std::string one_hot_elem2_name_ = "one_hot_elem2";
  const std::string one_hot_elem3_name_ = "one_hot_elem3";
  const float bias_ = 1.5f;
  const float threshold_ = 0.6f;
  const float scalar1_weight_ = 0.8f;
  const float scalar2_weight_ = -2.4f;
  const float scalar3_weight_ = 0.01f;
  const float one_hot_elem1_weight_ = -1.0f;
  const float one_hot_elem2_weight_ = 5.0f;
  const float one_hot_elem3_weight_ = -1.5f;
  const float one_hot_default_weight_ = 10.0f;
  const float epsilon_ = 0.001f;
  const std::string sparse_name_ = "sparse_feature";
  const std::string sparse_elem1_name_ = "sparse_elem1";
  const std::string sparse_elem2_name_ = "sparse_elem2";
  const float sparse_elem1_weight_ = -2.2f;
  const float sparse_elem2_weight_ = 3.1f;
  const float sparse_default_weight_ = 4.4f;
  const std::string bucketized_name_ = "bucketized_feature";
  const float bucketization_boundaries_[2] = {0.3f, 0.7f};
  const float bucketization_weights_[3] = {-1.0f, 1.0f, 3.0f};
  const float bucketization_default_weight_ = -3.3f;
};

TEST_F(GenericLogisticRegressionInferenceTest, BaseTest) {
  auto predictor = GenericLogisticRegressionInference(GetProto());

  RankerExample example;
  auto& features = *example.mutable_features();
  features[scalar1_name_].set_bool_value(true);
  features[scalar2_name_].set_int32_value(42);
  features[scalar3_name_].set_float_value(0.666f);
  features[one_hot_name_].set_string_value(one_hot_elem1_name_);
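  // A bool feature contributes 1.0 * weight, an int32 feature contributes its
  // value times the weight, and a one-hot string value selects the weight of
  // the matching element.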
  float score = predictor.PredictScore(example);
  float expected_score =
      Sigmoid(bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
              0.666f * scalar3_weight_ + one_hot_elem1_weight_);
  EXPECT_NEAR(expected_score, score, epsilon_);
  EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
}

TEST_F(GenericLogisticRegressionInferenceTest, UnknownElement) {
  RankerExample example;
  auto& features = *example.mutable_features();
  features[one_hot_name_].set_string_value("Unknown element");

  auto predictor = GenericLogisticRegressionInference(GetProto());
  float score = predictor.PredictScore(example);
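  // A one-hot value that is not in the weight map falls back to the feature's
  // default weight.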
  float expected_score = Sigmoid(bias_ + one_hot_default_weight_);
  EXPECT_NEAR(expected_score, score, epsilon_);
}

TEST_F(GenericLogisticRegressionInferenceTest, MissingFeatures) {
  RankerExample example;
  auto predictor = GenericLogisticRegressionInference(GetProto());
  float score = predictor.PredictScore(example);
  // Missing features use the default weight for one-hot features and
  // contribute nothing for scalar features.
  float expected_score = Sigmoid(bias_ + one_hot_default_weight_);
  EXPECT_NEAR(expected_score, score, epsilon_);
}

TEST_F(GenericLogisticRegressionInferenceTest, UnknownFeatures) {
  RankerExample example;
  auto& features = *example.mutable_features();
  features["foo1"].set_bool_value(true);
  features["foo2"].set_int32_value(42);
  features["foo3"].set_float_value(0.666f);
  features["foo4"].set_string_value(one_hot_elem1_name_);
  // All features except this one will be ignored.
  features[one_hot_name_].set_string_value(one_hot_elem2_name_);

  auto predictor = GenericLogisticRegressionInference(GetProto());
  float score = predictor.PredictScore(example);
  // Unknown features will be ignored.
  float expected_score = Sigmoid(bias_ + one_hot_elem2_weight_);
  EXPECT_NEAR(expected_score, score, epsilon_);
}

TEST_F(GenericLogisticRegressionInferenceTest, Threshold) {
  // In this test, we calculate the score for a given example and set the model
  // threshold to this value. We then add a feature to the example that should
  // tip the score slightly to either side of the threshold and verify that the
  // decision is as expected.
  auto proto = GetProto();
  auto threshold_calculator = GenericLogisticRegressionInference(proto);

  RankerExample example;
  auto& features = *example.mutable_features();
  features[scalar1_name_].set_bool_value(true);
  features[scalar2_name_].set_int32_value(2);
  features[one_hot_name_].set_string_value(one_hot_elem1_name_);
  float threshold = threshold_calculator.PredictScore(example);

  // Set up a model with the calculated threshold.
  proto.set_threshold(threshold);
  auto predictor = GenericLogisticRegressionInference(proto);

  // Add a small positive contribution from scalar3 to tip the decision to the
  // positive side of the threshold.
  features[scalar3_name_].set_float_value(0.01f);
  float score = predictor.PredictScore(example);
  // The score is now greater than, but still near, the threshold. The
  // decision should be positive.
  EXPECT_LT(threshold, score);
  EXPECT_NEAR(threshold, score, epsilon_);
  EXPECT_TRUE(predictor.Predict(example));

  // A small negative contribution from scalar3 should tip the decision the
  // other way.
  features[scalar3_name_].set_float_value(-0.01f);
  score = predictor.PredictScore(example);
  EXPECT_GT(threshold, score);
  EXPECT_NEAR(threshold, score, epsilon_);
  EXPECT_FALSE(predictor.Predict(example));
}

TEST_F(GenericLogisticRegressionInferenceTest, NoThreshold) {
  auto proto = GetProto();
  // When no threshold is specified, the default of 0.5 is used.
  proto.clear_threshold();
  auto predictor = GenericLogisticRegressionInference(proto);

  RankerExample example;
  auto& features = *example.mutable_features();
  // one_hot_elem3 exactly balances the bias, so we expect the pre-sigmoid
  // score to be zero, and the post-sigmoid score to be 0.5, if this is the
  // only active feature.
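  // Concretely: bias_ = 1.5f and one_hot_elem3_weight_ = -1.5f, so the score
  // is Sigmoid(1.5f - 1.5f) = Sigmoid(0.0f) = 0.5f.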
  features[one_hot_name_].set_string_value(one_hot_elem3_name_);
  float score = predictor.PredictScore(example);
  EXPECT_NEAR(0.5f, score, epsilon_);

  // Add a small contribution from scalar3 to tip the decision to one side or
  // the other of the threshold.
  features[scalar3_name_].set_float_value(0.01f);
  score = predictor.PredictScore(example);
  // The score is now greater than, but still near, 0.5. The decision should
  // be positive.
  EXPECT_LT(0.5f, score);
  EXPECT_NEAR(0.5f, score, epsilon_);
  EXPECT_TRUE(predictor.Predict(example));

  features[scalar3_name_].set_float_value(-0.01f);
  score = predictor.PredictScore(example);
  // The score is now lower than, but still near, 0.5. The decision should be
  // negative.
  EXPECT_GT(0.5f, score);
  EXPECT_NEAR(0.5f, score, epsilon_);
  EXPECT_FALSE(predictor.Predict(example));
}

TEST_F(GenericLogisticRegressionInferenceTest, PreprocessedModel) {
  GenericLogisticRegressionModel proto = GetProto();
  proto.set_is_preprocessed_model(true);
  // Clear the structured weights to make sure inference is driven by
  // fullname_weights.
  proto.clear_weights();

  // Build fullname weights.
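  // Each (feature, element) pair is flattened into a single key by
  // ExamplePreprocessor::FeatureFullname(), so the score is a weighted sum
  // over fullname_weights.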
  Map<std::string, float>& weights = *proto.mutable_fullname_weights();
  weights[scalar1_name_] = scalar1_weight_;
  weights[scalar2_name_] = scalar2_weight_;
  weights[scalar3_name_] = scalar3_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem1_name_)] = one_hot_elem1_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem2_name_)] = one_hot_elem2_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      one_hot_name_, one_hot_elem3_name_)] = one_hot_elem3_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      sparse_name_, sparse_elem1_name_)] = sparse_elem1_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      sparse_name_, sparse_elem2_name_)] = sparse_elem2_weight_;
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "0")] =
      bucketization_weights_[0];
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "1")] =
      bucketization_weights_[1];
  weights[ExamplePreprocessor::FeatureFullname(bucketized_name_, "2")] =
      bucketization_weights_[2];
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, one_hot_name_)] =
      one_hot_default_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, sparse_name_)] =
      sparse_default_weight_;
  weights[ExamplePreprocessor::FeatureFullname(
      ExamplePreprocessor::kMissingFeatureDefaultName, bucketized_name_)] =
      bucketization_default_weight_;

  // Build preprocessor_config.
  ExamplePreprocessorConfig& config = *proto.mutable_preprocessor_config();
  config.add_missing_features(one_hot_name_);
  config.add_missing_features(sparse_name_);
  config.add_missing_features(bucketized_name_);
  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
      bucketization_boundaries_[0]);
  (*config.mutable_bucketizers())[bucketized_name_].add_boundaries(
      bucketization_boundaries_[1]);

  auto predictor = GenericLogisticRegressionInference(proto);

  // Build example.
  RankerExample example;
  Map<std::string, Feature>& features = *example.mutable_features();
  features[scalar1_name_].set_bool_value(true);
  features[scalar2_name_].set_int32_value(42);
  features[scalar3_name_].set_float_value(0.666f);
  features[one_hot_name_].set_string_value(one_hot_elem1_name_);
  features[sparse_name_].mutable_string_list()->add_string_value(
      sparse_elem1_name_);
  features[sparse_name_].mutable_string_list()->add_string_value(
      sparse_elem2_name_);
  features[bucketized_name_].set_float_value(0.98f);
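  // With boundaries {0.3f, 0.7f}, the value 0.98f lands in the last bucket,
  // so bucketization_weights_[2] is expected to contribute to the score.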

  // Inference.
  float score = predictor.PredictScore(example);
  float expected_score = Sigmoid(
      bias_ + 1.0f * scalar1_weight_ + 42.0f * scalar2_weight_ +
      0.666f * scalar3_weight_ + one_hot_elem1_weight_ + sparse_elem1_weight_ +
      sparse_elem2_weight_ + bucketization_weights_[2]);
  EXPECT_NEAR(expected_score, score, epsilon_);
  EXPECT_EQ(expected_score >= threshold_, predictor.Predict(example));
}

}  // namespace assist_ranker