one_hot_unittest.cc 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/one_hot.h"
  5. #include "testing/gtest/include/gtest/gtest.h"
  6. namespace media {
  7. namespace learning {
  8. class OneHotTest : public testing::Test {
  9. public:
  10. OneHotTest() {}
  11. };
  12. TEST_F(OneHotTest, EmptyLearningTaskWorks) {
  13. LearningTask empty_task("EmptyTask", LearningTask::Model::kExtraTrees, {},
  14. LearningTask::ValueDescription({"target"}));
  15. TrainingData empty_training_data;
  16. OneHotConverter one_hot(empty_task, empty_training_data);
  17. EXPECT_EQ(one_hot.converted_task().feature_descriptions.size(), 0u);
  18. }
  19. TEST_F(OneHotTest, SimpleConversionWorks) {
  20. LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
  21. {{"feature1", LearningTask::Ordering::kUnordered}},
  22. LearningTask::ValueDescription({"target"}));
  23. TrainingData training_data;
  24. training_data.push_back({{FeatureValue("abc")}, TargetValue(0)});
  25. training_data.push_back({{FeatureValue("def")}, TargetValue(1)});
  26. training_data.push_back({{FeatureValue("ghi")}, TargetValue(2)});
  27. // Push a duplicate as the last one.
  28. training_data.push_back({{FeatureValue("def")}, TargetValue(3)});
  29. OneHotConverter one_hot(task, training_data);
  30. // There should be one feature for each distinct value in features[0].
  31. const size_t adjusted_feature_size = 3u;
  32. EXPECT_EQ(one_hot.converted_task().feature_descriptions.size(),
  33. adjusted_feature_size);
  34. EXPECT_EQ(one_hot.converted_task().feature_descriptions[0].ordering,
  35. LearningTask::Ordering::kNumeric);
  36. EXPECT_EQ(one_hot.converted_task().feature_descriptions[1].ordering,
  37. LearningTask::Ordering::kNumeric);
  38. EXPECT_EQ(one_hot.converted_task().feature_descriptions[2].ordering,
  39. LearningTask::Ordering::kNumeric);
  40. TrainingData converted_training_data = one_hot.Convert(training_data);
  41. EXPECT_EQ(converted_training_data.size(), training_data.size());
  42. // Exactly one feature should be 1.
  43. for (size_t i = 0; i < converted_training_data.size(); i++) {
  44. EXPECT_EQ(converted_training_data[i].features[0].value() +
  45. converted_training_data[i].features[1].value() +
  46. converted_training_data[i].features[2].value(),
  47. 1);
  48. }
  49. // Each of the first three training examples should have distinct vectors.
  50. for (size_t f = 0; f < adjusted_feature_size; f++) {
  51. int num_ones = 0;
  52. // 3u is the number of distinct examples. [3] is a duplicate.
  53. for (size_t i = 0; i < 3u; i++)
  54. num_ones += converted_training_data[i].features[f].value();
  55. EXPECT_EQ(num_ones, 1);
  56. }
  57. // The features of examples 1 and 3 should be the same.
  58. for (size_t f = 0; f < adjusted_feature_size; f++) {
  59. EXPECT_EQ(converted_training_data[1].features[f],
  60. converted_training_data[3].features[f]);
  61. }
  62. // Converting each feature vector should result in the same one as before.
  63. for (size_t f = 0; f < adjusted_feature_size; f++) {
  64. FeatureVector converted_feature_vector =
  65. one_hot.Convert(training_data[f].features);
  66. EXPECT_EQ(converted_feature_vector, converted_training_data[f].features);
  67. }
  68. }
  69. TEST_F(OneHotTest, NumericsAreNotConverted) {
  70. LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
  71. {{"feature1", LearningTask::Ordering::kNumeric}},
  72. LearningTask::ValueDescription({"target"}));
  73. OneHotConverter one_hot(task, TrainingData());
  74. EXPECT_EQ(one_hot.converted_task().feature_descriptions.size(), 1u);
  75. EXPECT_EQ(one_hot.converted_task().feature_descriptions[0].ordering,
  76. LearningTask::Ordering::kNumeric);
  77. TrainingData training_data;
  78. training_data.push_back({{FeatureValue(5)}, TargetValue(0)});
  79. TrainingData converted_training_data = one_hot.Convert(training_data);
  80. EXPECT_EQ(converted_training_data[0], training_data[0]);
  81. FeatureVector converted_feature_vector =
  82. one_hot.Convert(training_data[0].features);
  83. EXPECT_EQ(converted_feature_vector, training_data[0].features);
  84. }
  85. TEST_F(OneHotTest, UnknownValuesAreZeroHot) {
  86. LearningTask task("SimpleTask", LearningTask::Model::kExtraTrees,
  87. {{"feature1", LearningTask::Ordering::kUnordered}},
  88. LearningTask::ValueDescription({"target"}));
  89. TrainingData training_data;
  90. training_data.push_back({{FeatureValue("abc")}, TargetValue(0)});
  91. training_data.push_back({{FeatureValue("def")}, TargetValue(1)});
  92. training_data.push_back({{FeatureValue("ghi")}, TargetValue(2)});
  93. OneHotConverter one_hot(task, training_data);
  94. // Send in an unknown value, and see if it becomes {0, 0, 0}.
  95. FeatureVector converted_feature_vector =
  96. one_hot.Convert(FeatureVector({FeatureValue("jkl")}));
  97. EXPECT_EQ(converted_feature_vector.size(), 3u);
  98. for (size_t i = 0; i < converted_feature_vector.size(); i++)
  99. EXPECT_EQ(converted_feature_vector[i], FeatureValue(0));
  100. }
  101. } // namespace learning
  102. } // namespace media