random_tree_trainer.cc

// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/impl/random_tree_trainer.h"

#include <math.h>

#include <algorithm>
#include <memory>
#include <set>
#include <vector>

#include "base/bind.h"
#include "base/check_op.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "third_party/abseil-cpp/absl/types/optional.h"

namespace media {
namespace learning {

RandomTreeTrainer::Split::Split() = default;

RandomTreeTrainer::Split::Split(int index) : split_index(index) {}

RandomTreeTrainer::Split::Split(Split&& rhs) = default;

RandomTreeTrainer::Split::~Split() = default;

RandomTreeTrainer::Split& RandomTreeTrainer::Split::operator=(Split&& rhs) =
    default;

RandomTreeTrainer::Split::BranchInfo::BranchInfo() = default;

RandomTreeTrainer::Split::BranchInfo::BranchInfo(BranchInfo&& rhs) = default;

RandomTreeTrainer::Split::BranchInfo::~BranchInfo() = default;

struct InteriorNode : public Model {
  InteriorNode(const LearningTask& task,
               int split_index,
               FeatureValue split_point)
      : split_index_(split_index),
        ordering_(task.feature_descriptions[split_index].ordering),
        split_point_(split_point) {}

  // Model
  TargetHistogram PredictDistribution(const FeatureVector& features) override {
    // Figure out what feature value we should use for the split.
    FeatureValue f;
    switch (ordering_) {
      case LearningTask::Ordering::kUnordered:
        // Use 0 for "!=" and 1 for "==".
        f = FeatureValue(features[split_index_] == split_point_);
        break;
      case LearningTask::Ordering::kNumeric:
        // Use 0 for "<=" and 1 for ">".
        f = FeatureValue(features[split_index_] > split_point_);
        break;
    }
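
    // Illustrative example (made-up values): with a numeric split point of
    // 5.0, an example whose feature value is 3.2 maps to FeatureValue(0) (the
    // "<=" child) and 7.1 maps to FeatureValue(1) (the ">" child); for a
    // nominal split on "Red", any value other than "Red" maps to
    // FeatureValue(0).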
    auto iter = children_.find(f);
    // If we've never seen this feature value, then return nothing.
    if (iter == children_.end())
      return TargetHistogram();

    return iter->second->PredictDistribution(features);
  }

  TargetHistogram PredictDistributionWithMissingValues(
      const FeatureVector& features) {
    TargetHistogram total;
    for (auto& child_pair : children_) {
      TargetHistogram predicted =
          child_pair.second->PredictDistribution(features);
      // TODO(liberato): Normalize? Weight?
      total += predicted;
    }

    return total;
  }

  // Add |child| as the node for feature value |v|.
  void AddChild(FeatureValue v, std::unique_ptr<Model> child) {
    DCHECK(!children_.contains(v));
    children_.emplace(v, std::move(child));
  }

 private:
  // Feature value that we split on.
  int split_index_ = -1;
  base::flat_map<FeatureValue, std::unique_ptr<Model>> children_;

  // How is our feature value ordered?
  LearningTask::Ordering ordering_;

  // For kNumeric features, this is the split point.
  FeatureValue split_point_;
};

struct LeafNode : public Model {
  LeafNode(const TrainingData& training_data,
           const std::vector<size_t>& training_idx,
           LearningTask::Ordering ordering) {
    for (size_t idx : training_idx)
      distribution_ += training_data[idx];

    // Each leaf gets one vote.
    // See https://en.wikipedia.org/wiki/Bootstrap_aggregating . TL;DR: the
    // individual trees should average (regression) or vote (classification).
    //
    // TODO(liberato): It's unclear that a leaf should get to vote with an
    // entire distribution; we might want to take the max for kUnordered here.
    // If so, then we might also want to Average() for kNumeric targets, though
    // in that case, the results would be the same anyway. That's not, of
    // course, guaranteed for all methods of converting |distribution_| into a
    // numeric prediction. In general, we should provide a single estimate.
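    //
    // As a rough illustration (made-up counts, assuming Normalize() scales
    // the counts so that they sum to 1): a leaf that accumulated {A: 3, B: 1}
    // votes as {A: 0.75, B: 0.25}, so it carries the same total weight as a
    // leaf that accumulated {A: 100}.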
    distribution_.Normalize();
  }

  // TreeNode
  TargetHistogram PredictDistribution(const FeatureVector&) override {
    return distribution_;
  }

 private:
  TargetHistogram distribution_;
};

RandomTreeTrainer::RandomTreeTrainer(RandomNumberGenerator* rng)
    : HasRandomNumberGenerator(rng) {}

RandomTreeTrainer::~RandomTreeTrainer() {}
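
// Illustrative usage, assuming TrainedModelCB is a OnceCallback that takes a
// std::unique_ptr<Model> (hypothetical caller, for exposition only):
//
//   RandomTreeTrainer trainer(rng);
//   trainer.Train(task, training_data,
//                 base::BindOnce([](std::unique_ptr<Model> model) {
//                   // Use model->PredictDistribution(features) here.
//                 }));
//
// Note that |model_cb| is posted back to the calling sequence rather than run
// synchronously.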
void RandomTreeTrainer::Train(const LearningTask& task,
                              const TrainingData& training_data,
                              TrainedModelCB model_cb) {
  // Start with all the training data.
  std::vector<size_t> training_idx;
  training_idx.reserve(training_data.size());
  for (size_t idx = 0; idx < training_data.size(); idx++)
    training_idx.push_back(idx);

  // It's a little odd that we don't post training. Perhaps we should.
  auto model = Train(task, training_data, training_idx);
  base::SequencedTaskRunnerHandle::Get()->PostTask(
      FROM_HERE, base::BindOnce(std::move(model_cb), std::move(model)));
}

std::unique_ptr<Model> RandomTreeTrainer::Train(
    const LearningTask& task,
    const TrainingData& training_data,
    const std::vector<size_t>& training_idx) {
  if (training_data.empty()) {
    return std::make_unique<LeafNode>(training_data, std::vector<size_t>(),
                                      LearningTask::Ordering::kUnordered);
  }

  DCHECK_EQ(task.feature_descriptions.size(),
            training_data[0].features.size());

  // Start with all features unused.
  FeatureSet unused_set;
  for (size_t idx = 0; idx < task.feature_descriptions.size(); idx++)
    unused_set.insert(idx);

  return Build(task, training_data, training_idx, unused_set);
}

std::unique_ptr<Model> RandomTreeTrainer::Build(
    const LearningTask& task,
    const TrainingData& training_data,
    const std::vector<size_t>& training_idx,
    const FeatureSet& unused_set) {
  DCHECK_GT(training_idx.size(), 0u);

  // TODO: enforce a minimum number of samples. ExtraTrees uses 2 for
  // classification, and 5 for regression.

  // Remove any constant attributes in |training_data| from |unused_set|. Also
  // check if our training data has a constant target value. For both features
  // and the target value, if the Optional has a value then it's the singular
  // value that we've found so far. If we find a second one, then we'll clear
  // the Optional.
  absl::optional<TargetValue> target_value(
      training_data[training_idx[0]].target_value);
  std::vector<absl::optional<FeatureValue>> feature_values;
  feature_values.resize(training_data[0].features.size());
  for (size_t feature_idx : unused_set) {
    feature_values[feature_idx] =
        training_data[training_idx[0]].features[feature_idx];
  }

  for (size_t idx : training_idx) {
    const LabelledExample& example = training_data[idx];
    // Check whether there is more than one target value. Once we've
    // determined that it's not constant, we skip the check.
    if (target_value && target_value != example.target_value)
      target_value.reset();

    // For each feature in |unused_set|, see if it's constant in our subset of
    // the training data.
    for (size_t feature_idx : unused_set) {
      auto& value = feature_values[feature_idx];
      if (value && *value != example.features[feature_idx])
        value.reset();
    }
  }

  // Is the output constant in |training_data|? If so, then generate a leaf.
  // If we're not normalizing leaves, then this matters since this training
  // data might be split across multiple leaves.
  if (target_value) {
    return std::make_unique<LeafNode>(training_data, training_idx,
                                      task.target_description.ordering);
  }

  // Remove any constant features from the unused set, so that we don't try to
  // split on them. It would work, but it would be trivially useless. We also
  // don't want to use one of our potential splits on them.
  FeatureSet new_unused_set = unused_set;
  for (size_t feature_idx : unused_set) {
    auto& value = feature_values[feature_idx];
    if (value)
      new_unused_set.erase(feature_idx);
  }

  // Select the feature subset to consider at this leaf.
  // TODO(liberato): For nominals, with one-hot encoding, we'd give an equal
  // chance to each feature's value. For example, if F1 has {A, B} and F2 has
  // {C,D,E,F}, then we would pick uniformly over {A,B,C,D,E,F}. However, now
  // we pick between {F1, F2} then pick between either {A,B} or {C,D,E,F}. We
  // do this because it's simpler and doesn't seem to hurt anything.
  FeatureSet feature_candidates = new_unused_set;
  // TODO(liberato): Let our caller override this.
  const size_t features_per_split =
      std::max(static_cast<int>(sqrt(feature_candidates.size())), 3);
  // Note that it's okay if there are fewer features left; we'll select all of
  // them instead.
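  // For a sense of scale: with 25 usable features, sqrt(25) = 5 candidates
  // are kept per split; with 9, it's 3; with fewer than 9, the max() clamps
  // the target to 3, and if fewer than 3 features remain the loop below
  // simply keeps all of them.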
  while (feature_candidates.size() > features_per_split) {
    // Remove a random feature.
    size_t which = rng()->Generate(feature_candidates.size());
    auto iter = feature_candidates.begin();
    for (; which; which--, iter++)
      ;
    feature_candidates.erase(iter);
  }

  // TODO(liberato): Does it help if we refuse to split without an info gain?
  Split best_potential_split;

  // Find the best split among the candidates that we have.
  for (int i : feature_candidates) {
    Split potential_split =
        ConstructSplit(task, training_data, training_idx, i);
    if (potential_split.nats_remaining < best_potential_split.nats_remaining) {
      best_potential_split = std::move(potential_split);
    }
  }

  // Note that we can have a split with no index (i.e., no features left, or no
  // feature was an improvement in nats), or with a single index (had features,
  // but all had the same value). Either way, we should end up with a leaf.
  if (best_potential_split.branch_infos.size() < 2) {
    // Stop when there is no more tree.
    return std::make_unique<LeafNode>(training_data, training_idx,
                                      task.target_description.ordering);
  }

  // Build an interior node.
  std::unique_ptr<InteriorNode> node = std::make_unique<InteriorNode>(
      task, best_potential_split.split_index, best_potential_split.split_point);
  for (auto& branch_iter : best_potential_split.branch_infos) {
    node->AddChild(branch_iter.first,
                   Build(task, training_data, branch_iter.second.training_idx,
                         new_unused_set));
  }

  return node;
}

RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
    const LearningTask& task,
    const TrainingData& training_data,
    const std::vector<size_t>& training_idx,
    int split_index) {
  // We should not be given a training set of size 0, since there's no need to
  // check an empty split.
  DCHECK_GT(training_idx.size(), 0u);

  Split split(split_index);
  bool is_numeric = task.feature_descriptions[split_index].ordering ==
                    LearningTask::Ordering::kNumeric;

  // TODO(liberato): Consider removing nominal feature support and RF. That
  // would make this code somewhat simpler.

  // For a numeric split, find the split point. Otherwise, we'll split on every
  // nominal value that this feature has in |training_data|.
  if (is_numeric) {
    split.split_point =
        FindSplitPoint_Numeric(split.split_index, training_data, training_idx);
  } else {
    split.split_point =
        FindSplitPoint_Nominal(split.split_index, training_data, training_idx);
  }

  // Find the split's feature values and construct the training set for each.
  // I think we want to iterate on the underlying vector, and look up the int
  // in the training data directly. |total_weight| will hold the total weight
  // of all examples that come into this node.
  double total_weight = 0.;
  for (size_t idx : training_idx) {
    const LabelledExample& example = training_data[idx];
    total_weight += example.weight;

    // Get the value of the |index|-th feature for |example|.
    FeatureValue v_i = example.features[split.split_index];

    // Figure out what value this example would use for splitting. For nominal,
    // it's 1 or 0, based on whether |v_i| is equal to the split or not. For
    // numeric, it's whether |v_i| is <= the split point or not (0 for <=, and
    // 1 for >).
    FeatureValue split_feature;
    if (is_numeric)
      split_feature = FeatureValue(v_i > split.split_point);
    else
      split_feature = FeatureValue(v_i == split.split_point);

    // Add |v_i| to the right training set. Remember that emplace will do
    // nothing if the key already exists.
    auto result =
        split.branch_infos.emplace(split_feature, Split::BranchInfo());
    auto iter = result.first;
    Split::BranchInfo& branch_info = iter->second;

    branch_info.training_idx.push_back(idx);
    branch_info.target_histogram += example;
  }

  // Figure out how good / bad this split is.
  switch (task.target_description.ordering) {
    case LearningTask::Ordering::kUnordered:
      ComputeSplitScore_Nominal(&split, total_weight);
      break;
    case LearningTask::Ordering::kNumeric:
      ComputeSplitScore_Numeric(&split, total_weight);
      break;
  }

  return split;
}
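
// Sketch of the scoring idea for nominal targets: this is an entropy-style
// impurity measured in nats (hence natural log). Roughly, each branch
// contributes a -sum(p * log(p)) term over its target values, weighted by the
// probability |p_branch| of reaching that branch, and Build() keeps the
// candidate split with the smallest |nats_remaining|.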
void RandomTreeTrainer::ComputeSplitScore_Nominal(
    Split* split,
    double total_incoming_weight) {
  // Compute the nats given that we're at this node.
  split->nats_remaining = 0;
  for (auto& info_iter : split->branch_infos) {
    Split::BranchInfo& branch_info = info_iter.second;

    // |weight_along_branch| is the total weight of examples that would follow
    // this branch in the tree.
    const double weight_along_branch =
        branch_info.target_histogram.total_counts();
    // |p_branch| is the probability of following this branch.
    const double p_branch = weight_along_branch / total_incoming_weight;

    for (auto& iter : branch_info.target_histogram) {
      double p = iter.second / total_incoming_weight;
      // p*log(p) is the expected nats if the answer is |iter|. We multiply
      // that by the probability of being in this bucket at all.
      split->nats_remaining -= (p * log(p)) * p_branch;
    }
  }
}
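
// Sketch of the scoring idea for numeric targets: instead of entropy, the
// score is a weighted squared-error term. Each branch is approximated by the
// average of its target histogram, and the per-branch sum of
// count * (value - average)^2 is accumulated, scaled by |p_branch|. As with
// the nominal case, a smaller |nats_remaining| is better.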
void RandomTreeTrainer::ComputeSplitScore_Numeric(
    Split* split,
    double total_incoming_weight) {
  // Compute the nats given that we're at this node.
  split->nats_remaining = 0;
  for (auto& info_iter : split->branch_infos) {
    Split::BranchInfo& branch_info = info_iter.second;

    // |weight_along_branch| is the total weight of examples that would follow
    // this branch in the tree.
    const double weight_along_branch =
        branch_info.target_histogram.total_counts();
    // |p_branch| is the probability of following this branch.
    const double p_branch = weight_along_branch / total_incoming_weight;

    // Compute the average at this node. Note that we have no idea if the leaf
    // node would actually use an average, but really it should match. It
    // would be really nice if we could compute the value (or TargetHistogram)
    // as part of computing the split, and have somebody just hand that target
    // distribution to the leaf if it ends up as one.
    double average = branch_info.target_histogram.Average();

    for (auto& iter : branch_info.target_histogram) {
      // Compute the squared error for all |iter.second| counts that each have
      // a value of |iter.first|, when this leaf approximates them as
      // |average|.
      double sq_err = (iter.first.value() - average) *
                      (iter.first.value() - average) * iter.second;
      split->nats_remaining += sq_err * p_branch;
    }
  }
}
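
// A brief illustration of the ExtraTrees-style choice below (assuming
// GenerateDouble(r) returns a uniform double in [0, r)): if this feature's
// values across |training_idx| span [2.0, 10.0], the split point is drawn
// uniformly from [2.0, 10.0), so at least one example lands on each side of
// the split whenever v_min != v_max.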
FeatureValue RandomTreeTrainer::FindSplitPoint_Numeric(
    size_t split_index,
    const TrainingData& training_data,
    const std::vector<size_t>& training_idx) {
  // We should not be given a training set of size 0, since there's no need to
  // check an empty split.
  DCHECK_GT(training_idx.size(), 0u);

  // We should either (a) choose the single best split point given all our
  // training data (i.e., choosing between the splits that are equally between
  // adjacent feature values), or (b) choose the best split point by drawing
  // uniformly over the range that contains our feature values. (a) is
  // appropriate with RandomForest, while (b) is appropriate with ExtraTrees.
  FeatureValue v_min = training_data[training_idx[0]].features[split_index];
  FeatureValue v_max = training_data[training_idx[0]].features[split_index];
  for (size_t idx : training_idx) {
    const LabelledExample& example = training_data[idx];
    // Get the value of the |split_index|-th feature for this example.
    FeatureValue v_i = example.features[split_index];
    if (v_i < v_min)
      v_min = v_i;
    if (v_i > v_max)
      v_max = v_i;
  }

  FeatureValue v_split;
  if (v_max == v_min) {
    // Pick |v_split| to return a trivial split, so that this ends up as a
    // leaf node anyway.
    v_split = v_max;
  } else {
    // Choose a random split point. Note that we want to end up with two
    // buckets, so we don't have a trivial split. By picking [v_min, v_max),
    // |v_min| will always be in one bucket and |v_max| will always not be.
    v_split = FeatureValue(
        rng()->GenerateDouble(v_max.value() - v_min.value()) + v_min.value());
  }

  return v_split;
}

FeatureValue RandomTreeTrainer::FindSplitPoint_Nominal(
    size_t split_index,
    const TrainingData& training_data,
    const std::vector<size_t>& training_idx) {
  // We should not be given a training set of size 0, since there's no need to
  // check an empty split.
  DCHECK_GT(training_idx.size(), 0u);

  // Construct a set of all values for |training_idx|. We don't care about
  // their relative frequency, since one-hot encoding doesn't.
  // For example, if a feature has 10 "yes" instances and 1 "no" instance, then
  // there's a 50% chance for each to be chosen here. This is because one-hot
  // encoding would do roughly the same thing: when choosing features, the
  // "is_yes" and "is_no" features that come out of one-hot encoding would be
  // equally likely to be chosen.
  //
  // Important but subtle note: we can't choose a value that's been chosen
  // before for this feature, since that would be like splitting on the same
  // one-hot feature more than once. Luckily, we won't be asked to do that. If
  // we choose "Yes" at some level in the tree, then the "==" branch will have
  // trivial features which will be removed from consideration early (we never
  // consider features with only one value), and the != branch won't have any
  // "Yes" values for us to pick at a lower level.
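  //
  // A made-up example of the selection below: if the values present in this
  // subset are {Red, Green, Blue}, each is chosen with probability 1/3, no
  // matter how many examples carry each value.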
  std::set<FeatureValue> values;
  for (size_t idx : training_idx) {
    const LabelledExample& example = training_data[idx];
    values.insert(example.features[split_index]);
  }

  // Select one uniformly at random.
  size_t which = rng()->Generate(values.size());
  auto it = values.begin();
  for (; which > 0; it++, which--)
    ;
  return *it;
}

}  // namespace learning
}  // namespace media