random_tree_trainer_unittest.cc 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/random_tree_trainer.h"
  5. #include "base/bind.h"
  6. #include "base/run_loop.h"
  7. #include "base/test/task_environment.h"
  8. #include "media/learning/impl/test_random_number_generator.h"
  9. #include "testing/gtest/include/gtest/gtest.h"
  10. namespace media {
  11. namespace learning {
  12. class RandomTreeTest : public testing::TestWithParam<LearningTask::Ordering> {
  13. public:
  14. RandomTreeTest()
  15. : rng_(0),
  16. trainer_(&rng_),
  17. ordering_(GetParam()) {}
  18. // Set up |task_| to have |n| features with the given ordering.
  19. void SetupFeatures(size_t n) {
  20. for (size_t i = 0; i < n; i++) {
  21. LearningTask::ValueDescription desc;
  22. desc.ordering = ordering_;
  23. task_.feature_descriptions.push_back(desc);
  24. }
  25. }
  26. std::unique_ptr<Model> Train(const LearningTask& task,
  27. const TrainingData& data) {
  28. std::unique_ptr<Model> model;
  29. trainer_.Train(
  30. task_, data,
  31. base::BindOnce(
  32. [](std::unique_ptr<Model>* model_out,
  33. std::unique_ptr<Model> model) { *model_out = std::move(model); },
  34. &model));
  35. task_environment_.RunUntilIdle();
  36. return model;
  37. }
  38. base::test::TaskEnvironment task_environment_;
  39. TestRandomNumberGenerator rng_;
  40. RandomTreeTrainer trainer_;
  41. LearningTask task_;
  42. // Feature ordering.
  43. LearningTask::Ordering ordering_;
  44. };
  45. TEST_P(RandomTreeTest, EmptyTrainingDataWorks) {
  46. TrainingData empty;
  47. std::unique_ptr<Model> model = Train(task_, empty);
  48. EXPECT_NE(model.get(), nullptr);
  49. EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
  50. }
  51. TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
  52. SetupFeatures(2);
  53. LabelledExample example({FeatureValue(123), FeatureValue(456)},
  54. TargetValue(789));
  55. TrainingData training_data;
  56. const size_t n_examples = 10;
  57. for (size_t i = 0; i < n_examples; i++)
  58. training_data.push_back(example);
  59. std::unique_ptr<Model> model = Train(task_, training_data);
  60. // The tree should produce a distribution for one value (our target), which
  61. // has one count.
  62. TargetHistogram distribution = model->PredictDistribution(example.features);
  63. EXPECT_EQ(distribution.size(), 1u);
  64. EXPECT_EQ(distribution[example.target_value], 1.0);
  65. }
  66. TEST_P(RandomTreeTest, SimpleSeparableTrainingData) {
  67. SetupFeatures(1);
  68. TrainingData training_data;
  69. LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  70. LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
  71. training_data.push_back(example_1);
  72. training_data.push_back(example_2);
  73. std::unique_ptr<Model> model = Train(task_, training_data);
  74. // Each value should have a distribution with one target value with one count.
  75. TargetHistogram distribution = model->PredictDistribution(example_1.features);
  76. EXPECT_NE(model.get(), nullptr);
  77. EXPECT_EQ(distribution.size(), 1u);
  78. EXPECT_EQ(distribution[example_1.target_value], 1u);
  79. distribution = model->PredictDistribution(example_2.features);
  80. EXPECT_EQ(distribution.size(), 1u);
  81. EXPECT_EQ(distribution[example_2.target_value], 1u);
  82. }
  83. TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
  84. // Building a random tree with numeric splits isn't terribly likely to work,
  85. // so just skip it. Entirely randomized splits are just too random. The
  86. // RandomForest unittests will test them as part of an ensemble.
  87. if (ordering_ == LearningTask::Ordering::kNumeric)
  88. return;
  89. SetupFeatures(4);
  90. // Build a four-feature training set that's completely separable, but one
  91. // needs all four features to do it.
  92. TrainingData training_data;
  93. for (int f1 = 0; f1 < 2; f1++) {
  94. for (int f2 = 0; f2 < 2; f2++) {
  95. for (int f3 = 0; f3 < 2; f3++) {
  96. for (int f4 = 0; f4 < 2; f4++) {
  97. LabelledExample example(
  98. {FeatureValue(f1), FeatureValue(f2), FeatureValue(f3),
  99. FeatureValue(f4)},
  100. TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8));
  101. // Add two copies of each example.
  102. training_data.push_back(example);
  103. training_data.push_back(example);
  104. }
  105. }
  106. }
  107. }
  108. std::unique_ptr<Model> model = Train(task_, training_data);
  109. EXPECT_NE(model.get(), nullptr);
  110. // Each example should have a distribution that selects the right value.
  111. for (const LabelledExample& example : training_data) {
  112. TargetHistogram distribution = model->PredictDistribution(example.features);
  113. TargetValue singular_max;
  114. EXPECT_TRUE(distribution.FindSingularMax(&singular_max));
  115. EXPECT_EQ(singular_max, example.target_value);
  116. }
  117. }
  118. TEST_P(RandomTreeTest, UnseparableTrainingData) {
  119. SetupFeatures(1);
  120. TrainingData training_data;
  121. LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  122. LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
  123. training_data.push_back(example_1);
  124. training_data.push_back(example_2);
  125. std::unique_ptr<Model> model = Train(task_, training_data);
  126. EXPECT_NE(model.get(), nullptr);
  127. // Each value should have a distribution with two targets with equal counts.
  128. TargetHistogram distribution = model->PredictDistribution(example_1.features);
  129. EXPECT_EQ(distribution.size(), 2u);
  130. EXPECT_EQ(distribution[example_1.target_value], 0.5);
  131. EXPECT_EQ(distribution[example_2.target_value], 0.5);
  132. distribution = model->PredictDistribution(example_2.features);
  133. EXPECT_EQ(distribution.size(), 2u);
  134. EXPECT_EQ(distribution[example_1.target_value], 0.5);
  135. EXPECT_EQ(distribution[example_2.target_value], 0.5);
  136. }
  137. TEST_P(RandomTreeTest, UnknownFeatureValueHandling) {
  138. // Verify how a previously unseen feature value is handled.
  139. SetupFeatures(1);
  140. TrainingData training_data;
  141. LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  142. LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
  143. training_data.push_back(example_1);
  144. training_data.push_back(example_2);
  145. auto model = Train(task_, training_data);
  146. auto distribution =
  147. model->PredictDistribution(FeatureVector({FeatureValue(789)}));
  148. if (ordering_ == LearningTask::Ordering::kUnordered) {
  149. // OOV data could be split on either feature first, so we don't really know
  150. // which to expect. We assert that there should be exactly one example, but
  151. // whether it's |example_1| or |example_2| isn't clear.
  152. EXPECT_EQ(distribution.size(), 1u);
  153. EXPECT_EQ(distribution[example_1.target_value] +
  154. distribution[example_2.target_value],
  155. 1u);
  156. } else {
  157. // The unknown feature is numerically higher than |example_2|, so we
  158. // expect it to fall into that bucket.
  159. EXPECT_EQ(distribution.size(), 1u);
  160. EXPECT_EQ(distribution[example_2.target_value], 1u);
  161. }
  162. }
  163. TEST_P(RandomTreeTest, NumericFeaturesSplitMultipleTimes) {
  164. // Verify that numeric features can be split more than once in the tree.
  165. // This should also pass for nominal features, though it's less interesting.
  166. SetupFeatures(1);
  167. TrainingData training_data;
  168. const int feature_mult = 10;
  169. for (size_t i = 0; i < 4; i++) {
  170. LabelledExample example({FeatureValue(i * feature_mult)}, TargetValue(i));
  171. training_data.push_back(example);
  172. }
  173. std::unique_ptr<Model> model = Train(task_, training_data);
  174. for (size_t i = 0; i < 4; i++) {
  175. // Get a prediction for the |i|-th feature value.
  176. TargetHistogram distribution = model->PredictDistribution(
  177. FeatureVector({FeatureValue(i * feature_mult)}));
  178. // The distribution should have one count that should be correct. If
  179. // the feature isn't split four times, then some feature value will have too
  180. // many or too few counts.
  181. EXPECT_EQ(distribution.total_counts(), 1u);
  182. EXPECT_EQ(distribution[TargetValue(i)], 1u);
  183. }
  184. }
  185. INSTANTIATE_TEST_SUITE_P(RandomTreeTest,
  186. RandomTreeTest,
  187. testing::ValuesIn({LearningTask::Ordering::kUnordered,
  188. LearningTask::Ordering::kNumeric}));
  189. } // namespace learning
  190. } // namespace media