123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- // Copyright 2018 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "media/learning/impl/extra_trees_trainer.h"
- #include "base/bind.h"
- #include "base/memory/ref_counted.h"
- #include "base/test/task_environment.h"
- #include "media/learning/impl/fisher_iris_dataset.h"
- #include "media/learning/impl/test_random_number_generator.h"
- #include "testing/gtest/include/gtest/gtest.h"
- namespace media {
- namespace learning {
- class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> {
- public:
- ExtraTreesTest() : rng_(0), ordering_(GetParam()) {
- trainer_.SetRandomNumberGeneratorForTesting(&rng_);
- }
- // Set up |task_| to have |n| features with the given ordering.
- void SetupFeatures(size_t n) {
- for (size_t i = 0; i < n; i++) {
- LearningTask::ValueDescription desc;
- desc.ordering = ordering_;
- task_.feature_descriptions.push_back(desc);
- }
- }
- std::unique_ptr<Model> Train(const LearningTask& task,
- const TrainingData& data) {
- std::unique_ptr<Model> model;
- trainer_.Train(
- task_, data,
- base::BindOnce(
- [](std::unique_ptr<Model>* model_out,
- std::unique_ptr<Model> model) { *model_out = std::move(model); },
- &model));
- task_environment_.RunUntilIdle();
- return model;
- }
- base::test::TaskEnvironment task_environment_;
- TestRandomNumberGenerator rng_;
- ExtraTreesTrainer trainer_;
- LearningTask task_;
- // Feature ordering.
- LearningTask::Ordering ordering_;
- };
- TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) {
- TrainingData empty;
- auto model = Train(task_, empty);
- EXPECT_NE(model.get(), nullptr);
- EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
- }
- TEST_P(ExtraTreesTest, FisherIrisDataset) {
- SetupFeatures(4);
- FisherIrisDataset iris;
- TrainingData training_data = iris.GetTrainingData();
- auto model = Train(task_, training_data);
- // Verify predictions on the training set, just for sanity.
- size_t num_correct = 0;
- for (const LabelledExample& example : training_data) {
- TargetHistogram distribution = model->PredictDistribution(example.features);
- TargetValue predicted_value;
- if (distribution.FindSingularMax(&predicted_value) &&
- predicted_value == example.target_value) {
- num_correct += example.weight;
- }
- }
- // Expect very high accuracy. We should get ~100%.
- double train_accuracy = ((double)num_correct) / training_data.total_weight();
- EXPECT_GT(train_accuracy, 0.95);
- }
- TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
- // Create a training set with unseparable data, but give one of them a large
- // weight. See if that one wins.
- SetupFeatures(1);
- LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
- LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
- const size_t weight = 100;
- TrainingData training_data;
- example_1.weight = weight;
- training_data.push_back(example_1);
- // Push many |example_2|'s, which will win without the weights.
- training_data.push_back(example_2);
- training_data.push_back(example_2);
- training_data.push_back(example_2);
- training_data.push_back(example_2);
- // Create a weighed set with |weight| for each example's weight.
- EXPECT_FALSE(training_data.is_unweighted());
- auto model = Train(task_, training_data);
- // The singular max should be example_1.
- TargetHistogram distribution = model->PredictDistribution(example_1.features);
- TargetValue predicted_value;
- EXPECT_TRUE(distribution.FindSingularMax(&predicted_value));
- EXPECT_EQ(predicted_value, example_1.target_value);
- }
- TEST_P(ExtraTreesTest, RegressionWorks) {
- // Create a training set with unseparable data, but give one of them a large
- // weight. See if that one wins.
- SetupFeatures(2);
- LabelledExample example_1({FeatureValue(1), FeatureValue(123)},
- TargetValue(1));
- LabelledExample example_1_a({FeatureValue(1), FeatureValue(123)},
- TargetValue(5));
- LabelledExample example_2({FeatureValue(1), FeatureValue(456)},
- TargetValue(20));
- LabelledExample example_2_a({FeatureValue(1), FeatureValue(456)},
- TargetValue(25));
- TrainingData training_data;
- example_1.weight = 100;
- training_data.push_back(example_1);
- training_data.push_back(example_1_a);
- example_2.weight = 100;
- training_data.push_back(example_2);
- training_data.push_back(example_2_a);
- task_.target_description.ordering = LearningTask::Ordering::kNumeric;
- // Create a weighed set with |weight| for each example's weight.
- auto model = Train(task_, training_data);
- // Make sure that the results are in the right range.
- TargetHistogram distribution = model->PredictDistribution(example_1.features);
- EXPECT_GT(distribution.Average(), example_1.target_value.value() * 0.95);
- EXPECT_LT(distribution.Average(), example_1.target_value.value() * 1.05);
- distribution = model->PredictDistribution(example_2.features);
- EXPECT_GT(distribution.Average(), example_2.target_value.value() * 0.95);
- EXPECT_LT(distribution.Average(), example_2.target_value.value() * 1.05);
- }
- TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
- // Create a binary classification task and a regression task that are roughly
- // the same. Verify that the results are the same, too. In particular, for
- // each set of features, we choose a regression target |pct| between 0 and
- // 100. For the corresponding binary classification problem, we add |pct|
- // true instances, and 100-|pct| false instances. The predicted averages
- // should be roughly the same.
- SetupFeatures(3);
- TrainingData c_data, r_data;
- std::set<LabelledExample> r_examples;
- for (size_t i = 0; i < 4 * 4 * 4; i++) {
- FeatureValue f1(i & 3);
- FeatureValue f2((i >> 2) & 3);
- FeatureValue f3((i >> 4) & 3);
- int frac = (1.0 * (f1.value() + f2.value() + f3.value())) / 9;
- LabelledExample e({f1, f2, f3}, TargetValue(0));
- // TODO(liberato): Consider adding noise, and verifying that the model
- // predictions are roughly the same as each other, rather than the same as
- // the currently noise-free target.
- // Push some number of false and some number of true instances that is in
- // the right ratio for |frac|.
- const int total_examples = 100;
- const int positive_examples = total_examples * frac;
- e.weight = total_examples - positive_examples;
- if (e.weight > 0)
- c_data.push_back(e);
- e.target_value = TargetValue(1.0);
- e.weight = positive_examples;
- if (e.weight > 0)
- c_data.push_back(e);
- // For the regression data, add an example with |frac| directly. Also save
- // it so that we can look up the right answer below.
- LabelledExample r_example(LabelledExample({f1, f2, f3}, TargetValue(frac)));
- r_examples.insert(r_example);
- r_data.push_back(r_example);
- }
- // Train a model on the binary classification task and the regression task.
- auto c_model = Train(task_, c_data);
- task_.target_description.ordering = LearningTask::Ordering::kNumeric;
- auto r_model = Train(task_, r_data);
- // Verify that, for all feature combinations, the models roughly agree. Since
- // the data is separable, it probably should be exact.
- for (auto& r_example : r_examples) {
- const FeatureVector& fv = r_example.features;
- TargetHistogram c_dist = c_model->PredictDistribution(fv);
- EXPECT_LE(c_dist.Average(), r_example.target_value.value() * 1.05);
- EXPECT_GE(c_dist.Average(), r_example.target_value.value() * 0.95);
- TargetHistogram r_dist = r_model->PredictDistribution(fv);
- EXPECT_LE(r_dist.Average(), r_example.target_value.value() * 1.05);
- EXPECT_GE(r_dist.Average(), r_example.target_value.value() * 0.95);
- }
- }
- INSTANTIATE_TEST_SUITE_P(ExtraTreesTest,
- ExtraTreesTest,
- testing::ValuesIn({LearningTask::Ordering::kUnordered,
- LearningTask::Ordering::kNumeric}));
- } // namespace learning
- } // namespace media
|