// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/impl/extra_trees_trainer.h"

#include <memory>
#include <set>

#include "base/bind.h"
#include "base/memory/ref_counted.h"
#include "base/test/task_environment.h"
#include "media/learning/impl/fisher_iris_dataset.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {

class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> {
 public:
  ExtraTreesTest() : rng_(0), ordering_(GetParam()) {
    trainer_.SetRandomNumberGeneratorForTesting(&rng_);
  }

  // Set up |task_| to have |n| features with the given ordering.
  void SetupFeatures(size_t n) {
    for (size_t i = 0; i < n; i++) {
      LearningTask::ValueDescription desc;
      desc.ordering = ordering_;
      task_.feature_descriptions.push_back(desc);
    }
  }
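
  // Train a model on |data| for |task|, waiting (via RunUntilIdle) for the
  // trainer's asynchronous completion callback to deliver it.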
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& data) {
    std::unique_ptr<Model> model;
    trainer_.Train(
        task, data,
        base::BindOnce(
            [](std::unique_ptr<Model>* model_out,
               std::unique_ptr<Model> model) { *model_out = std::move(model); },
            &model));
    task_environment_.RunUntilIdle();
    return model;
  }
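
  // Task environment whose queue is drained by RunUntilIdle() in Train(), so
  // that the trainer's posted work and reply callbacks run before we return.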
  base::test::TaskEnvironment task_environment_;
  TestRandomNumberGenerator rng_;
  ExtraTreesTrainer trainer_;
  LearningTask task_;

  // Feature ordering.
  LearningTask::Ordering ordering_;
};

TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) {
  TrainingData empty;
  auto model = Train(task_, empty);
  EXPECT_NE(model.get(), nullptr);
  EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}

TEST_P(ExtraTreesTest, FisherIrisDataset) {
  SetupFeatures(4);
  FisherIrisDataset iris;
  TrainingData training_data = iris.GetTrainingData();
  auto model = Train(task_, training_data);

  // Verify predictions on the training set, just for sanity.
  size_t num_correct = 0;
  for (const LabelledExample& example : training_data) {
    TargetHistogram distribution =
        model->PredictDistribution(example.features);
    TargetValue predicted_value;
    if (distribution.FindSingularMax(&predicted_value) &&
        predicted_value == example.target_value) {
      num_correct += example.weight;
    }
  }

  // Expect very high accuracy. We should get ~100%.
  double train_accuracy =
      static_cast<double>(num_correct) / training_data.total_weight();
  EXPECT_GT(train_accuracy, 0.95);
}

TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
  // Create a training set of inseparable data (identical features, different
  // targets), but give one example a large weight. See if that one wins.
  SetupFeatures(1);
  LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
  const size_t weight = 100;
  TrainingData training_data;
  example_1.weight = weight;
  training_data.push_back(example_1);
  // Push many |example_2|'s, which would win without the weights.
  training_data.push_back(example_2);
  training_data.push_back(example_2);
  training_data.push_back(example_2);
  training_data.push_back(example_2);
  // The training set should report itself as weighted.
  EXPECT_FALSE(training_data.is_unweighted());
  auto model = Train(task_, training_data);

  // The singular max should be |example_1|'s target, since its weight exceeds
  // the four unweighted copies of |example_2|.
  TargetHistogram distribution =
      model->PredictDistribution(example_1.features);
  TargetValue predicted_value;
  EXPECT_TRUE(distribution.FindSingularMax(&predicted_value));
  EXPECT_EQ(predicted_value, example_1.target_value);
}

TEST_P(ExtraTreesTest, RegressionWorks) {
  // Create a training set in which each feature vector appears with two
  // different numeric targets, and give one target of each pair a much larger
  // weight. The predicted average should land near the heavily weighted
  // target.
  SetupFeatures(2);
  LabelledExample example_1({FeatureValue(1), FeatureValue(123)},
                            TargetValue(1));
  LabelledExample example_1_a({FeatureValue(1), FeatureValue(123)},
                              TargetValue(5));
  LabelledExample example_2({FeatureValue(1), FeatureValue(456)},
                            TargetValue(20));
  LabelledExample example_2_a({FeatureValue(1), FeatureValue(456)},
                              TargetValue(25));
  TrainingData training_data;
  example_1.weight = 100;
  training_data.push_back(example_1);
  training_data.push_back(example_1_a);
  example_2.weight = 100;
  training_data.push_back(example_2);
  training_data.push_back(example_2_a);
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;

  // Train on the weighted set.
  auto model = Train(task_, training_data);

  // Make sure that the results are in the right range.
  TargetHistogram distribution =
      model->PredictDistribution(example_1.features);
  EXPECT_GT(distribution.Average(), example_1.target_value.value() * 0.95);
  EXPECT_LT(distribution.Average(), example_1.target_value.value() * 1.05);
  distribution = model->PredictDistribution(example_2.features);
  EXPECT_GT(distribution.Average(), example_2.target_value.value() * 0.95);
  EXPECT_LT(distribution.Average(), example_2.target_value.value() * 1.05);
}

TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
  // Create a binary classification task and a regression task that are roughly
  // the same. Verify that the results are the same, too. In particular, for
  // each set of features, we choose a regression target |frac| between 0 and
  // 1. For the corresponding binary classification problem, we add
  // 100 * |frac| true instances and 100 * (1 - |frac|) false instances. The
  // predicted averages should be roughly the same.
  SetupFeatures(3);
  TrainingData c_data, r_data;
  std::set<LabelledExample> r_examples;
  for (size_t i = 0; i < 4 * 4 * 4; i++) {
    FeatureValue f1(i & 3);
    FeatureValue f2((i >> 2) & 3);
    FeatureValue f3((i >> 4) & 3);
    double frac = (1.0 * (f1.value() + f2.value() + f3.value())) / 9;
    LabelledExample e({f1, f2, f3}, TargetValue(0));
    // TODO(liberato): Consider adding noise, and verifying that the model
    // predictions are roughly the same as each other, rather than the same as
    // the currently noise-free target.

    // Push some number of false and some number of true instances that is in
    // the right ratio for |frac|.
    const int total_examples = 100;
    const int positive_examples = total_examples * frac;
    e.weight = total_examples - positive_examples;
    if (e.weight > 0)
      c_data.push_back(e);
    e.target_value = TargetValue(1.0);
    e.weight = positive_examples;
    if (e.weight > 0)
      c_data.push_back(e);

    // For the regression data, add an example with |frac| directly. Also save
    // it so that we can look up the right answer below.
    LabelledExample r_example({f1, f2, f3}, TargetValue(frac));
    r_examples.insert(r_example);
    r_data.push_back(r_example);
  }

  // Train a model on the binary classification task and the regression task.
  auto c_model = Train(task_, c_data);
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  auto r_model = Train(task_, r_data);

  // Verify that, for all feature combinations, the models roughly agree. Since
  // the data is separable, the match should be nearly exact.
  for (auto& r_example : r_examples) {
    const FeatureVector& fv = r_example.features;
    TargetHistogram c_dist = c_model->PredictDistribution(fv);
    EXPECT_LE(c_dist.Average(), r_example.target_value.value() * 1.05);
    EXPECT_GE(c_dist.Average(), r_example.target_value.value() * 0.95);
    TargetHistogram r_dist = r_model->PredictDistribution(fv);
    EXPECT_LE(r_dist.Average(), r_example.target_value.value() * 1.05);
    EXPECT_GE(r_dist.Average(), r_example.target_value.value() * 0.95);
  }
}
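
// Instantiate the suite so that every test above runs once with unordered
// (nominal) features and once with numeric features.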
INSTANTIATE_TEST_SUITE_P(ExtraTreesTest,
                         ExtraTreesTest,
                         testing::ValuesIn({LearningTask::Ordering::kUnordered,
                                            LearningTask::Ordering::kNumeric}));

}  // namespace learning
}  // namespace media