ranker_example_util_unittest.cc 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "components/assist_ranker/ranker_example_util.h"

#include <cstdint>
#include <limits>
#include <string>

#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
  7. namespace assist_ranker {
  8. using ::testing::ElementsAreArray;
  9. class RankerExampleUtilTest : public ::testing::Test {
  10. protected:
  11. void SetUp() override {
  12. auto& features = *example_.mutable_features();
  13. features[bool_name_].set_bool_value(bool_value_);
  14. features[int32_name_].set_int32_value(int32_value_);
  15. features[float_name_].set_float_value(float_value_);
  16. features[one_hot_name_].set_string_value(one_hot_value_);
  17. }
  18. RankerExample example_;
  19. const std::string bool_name_ = "bool_feature";
  20. const bool bool_value_ = true;
  21. const std::string int32_name_ = "int32_feature";
  22. const int int32_value_ = 2;
  23. const std::string float_name_ = "float_feature";
  24. const float float_value_ = 3.0f;
  25. const std::string one_hot_name_ = "one_hot_feature";
  26. const std::string elem1_ = "elem1";
  27. const std::string elem2_ = "elem2";
  28. const std::string one_hot_value_ = elem1_;
  29. const float epsilon_ = 0.00000001f;
  30. };
  31. TEST_F(RankerExampleUtilTest, CheckFeature) {
  32. EXPECT_TRUE(SafeGetFeature(bool_name_, example_, nullptr));
  33. EXPECT_TRUE(SafeGetFeature(int32_name_, example_, nullptr));
  34. EXPECT_TRUE(SafeGetFeature(float_name_, example_, nullptr));
  35. EXPECT_TRUE(SafeGetFeature(one_hot_name_, example_, nullptr));
  36. EXPECT_FALSE(SafeGetFeature("", example_, nullptr));
  37. EXPECT_FALSE(SafeGetFeature("foo", example_, nullptr));
  38. }
  39. TEST_F(RankerExampleUtilTest, SafeGetFeature) {
  40. Feature feature;
  41. EXPECT_TRUE(SafeGetFeature(bool_name_, example_, &feature));
  42. EXPECT_TRUE(feature.bool_value());
  43. feature.Clear();
  44. EXPECT_TRUE(SafeGetFeature(int32_name_, example_, &feature));
  45. EXPECT_EQ(int32_value_, feature.int32_value());
  46. feature.Clear();
  47. EXPECT_TRUE(SafeGetFeature(float_name_, example_, &feature));
  48. EXPECT_NEAR(float_value_, feature.float_value(), epsilon_);
  49. feature.Clear();
  50. EXPECT_TRUE(SafeGetFeature(one_hot_name_, example_, &feature));
  51. EXPECT_EQ(one_hot_value_, feature.string_value());
  52. feature.Clear();
  53. EXPECT_FALSE(SafeGetFeature("", example_, &feature));
  54. EXPECT_FALSE(SafeGetFeature("foo", example_, &feature));
  55. }
  56. TEST_F(RankerExampleUtilTest, GetFeatureValueAsFloat) {
  57. float value;
  58. EXPECT_TRUE(GetFeatureValueAsFloat(bool_name_, example_, &value));
  59. EXPECT_NEAR(1.0f, value, epsilon_);
  60. EXPECT_TRUE(GetFeatureValueAsFloat(int32_name_, example_, &value));
  61. EXPECT_NEAR(2.0f, value, epsilon_);
  62. EXPECT_TRUE(GetFeatureValueAsFloat(float_name_, example_, &value));
  63. EXPECT_NEAR(3.0f, value, epsilon_);
  64. EXPECT_FALSE(GetFeatureValueAsFloat(one_hot_name_, example_, &value));
  65. // Value remains unchanged if GetFeatureValueAsFloat returns false.
  66. EXPECT_NEAR(3.0f, value, epsilon_);
  67. EXPECT_FALSE(GetFeatureValueAsFloat("", example_, &value));
  68. EXPECT_FALSE(GetFeatureValueAsFloat("foo", example_, &value));
  69. }
  70. TEST_F(RankerExampleUtilTest, GetOneHotValue) {
  71. std::string value;
  72. EXPECT_FALSE(GetOneHotValue(bool_name_, example_, &value));
  73. EXPECT_FALSE(GetOneHotValue(int32_name_, example_, &value));
  74. EXPECT_FALSE(GetOneHotValue(float_name_, example_, &value));
  75. EXPECT_TRUE(GetOneHotValue(one_hot_name_, example_, &value));
  76. EXPECT_EQ(one_hot_value_, value);
  77. EXPECT_FALSE(GetOneHotValue("", example_, &value));
  78. EXPECT_FALSE(GetOneHotValue("foo", example_, &value));
  79. }
  80. TEST_F(RankerExampleUtilTest, ScalarFeatureInt64Conversion) {
  81. Feature feature;
  82. int64_t int64_value;
  83. feature.set_bool_value(true);
  84. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  85. EXPECT_EQ(int64_value, 72057594037927937LL);
  86. feature.set_int32_value(std::numeric_limits<int32_t>::max());
  87. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  88. EXPECT_EQ(int64_value, 216172784261267455LL);
  89. feature.set_int32_value(std::numeric_limits<int32_t>::lowest());
  90. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  91. EXPECT_EQ(int64_value, 216172784261267456LL);
  92. feature.set_string_value("foo");
  93. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  94. EXPECT_EQ(int64_value, 288230377439557724LL);
  95. }
  96. TEST_F(RankerExampleUtilTest, FloatFeatureInt64Conversion) {
  97. Feature feature;
  98. int64_t int64_value;
  99. feature.set_float_value(std::numeric_limits<float>::epsilon());
  100. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  101. EXPECT_EQ(int64_value, 144115188948271104LL);
  102. feature.set_float_value(-std::numeric_limits<float>::epsilon());
  103. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  104. EXPECT_EQ(int64_value, 144115191095754752LL);
  105. feature.set_float_value(std::numeric_limits<float>::max());
  106. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  107. EXPECT_EQ(int64_value, 144115190214950911LL);
  108. feature.set_float_value(std::numeric_limits<float>::lowest());
  109. EXPECT_TRUE(FeatureToInt64(feature, &int64_value));
  110. EXPECT_EQ(int64_value, 144115192362434559LL);
  111. }
  112. TEST_F(RankerExampleUtilTest, StringListInt64Conversion) {
  113. Feature feature;
  114. int64_t int64_value;
  115. feature.mutable_string_list()->add_string_value("");
  116. feature.mutable_string_list()->add_string_value("TEST");
  117. EXPECT_TRUE(FeatureToInt64(feature, &int64_value, 1));
  118. EXPECT_EQ(int64_value, 360287974776690660LL);
  119. }
  120. TEST_F(RankerExampleUtilTest, HashExampleFeatureNames) {
  121. auto hashed_example = HashExampleFeatureNames(example_);
  122. // Hashed example has the same number of features.
  123. EXPECT_EQ(example_.features().size(), hashed_example.features().size());
  124. // But the feature names have changed.
  125. EXPECT_FALSE(SafeGetFeature(bool_name_, hashed_example, nullptr));
  126. EXPECT_FALSE(SafeGetFeature(int32_name_, hashed_example, nullptr));
  127. EXPECT_FALSE(SafeGetFeature(float_name_, hashed_example, nullptr));
  128. EXPECT_FALSE(SafeGetFeature(one_hot_name_, hashed_example, nullptr));
  129. EXPECT_TRUE(
  130. SafeGetFeature(HashFeatureName(bool_name_), hashed_example, nullptr));
  131. // Values have not changed.
  132. float float_value;
  133. EXPECT_TRUE(GetFeatureValueAsFloat(HashFeatureName(float_name_),
  134. hashed_example, &float_value));
  135. EXPECT_EQ(float_value_, float_value);
  136. std::string string_value;
  137. EXPECT_TRUE(GetOneHotValue(HashFeatureName(one_hot_name_), hashed_example,
  138. &string_value));
  139. EXPECT_EQ(one_hot_value_, string_value);
  140. }
  141. } // namespace assist_ranker