// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/learning/impl/learning_task_controller_impl.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/check_op.h"
#include "base/notreached.h"
#include "media/learning/impl/distribution_reporter.h"
#include "media/learning/impl/extra_trees_trainer.h"
#include "media/learning/impl/lookup_table_trainer.h"

namespace media {
namespace learning {

LearningTaskControllerImpl::LearningTaskControllerImpl(
    const LearningTask& task,
    std::unique_ptr<DistributionReporter> reporter,
    SequenceBoundFeatureProvider feature_provider)
    : task_(task),
      training_data_(std::make_unique<TrainingData>()),
      reporter_(std::move(reporter)),
      helper_(std::make_unique<LearningTaskControllerHelper>(
          task,
          base::BindRepeating(&LearningTaskControllerImpl::AddFinishedExample,
                              AsWeakPtr()),
          std::move(feature_provider))),
      expected_feature_count_(task_.feature_descriptions.size()) {
  // Note that |helper_| uses the full set of features.
  // TODO(liberato): Make this compositional.  FeatureSubsetTaskController?
  if (task_.feature_subset_size)
    DoFeatureSubsetSelection();

  switch (task_.model) {
    case LearningTask::Model::kExtraTrees:
      trainer_ = std::make_unique<ExtraTreesTrainer>();
      break;
    case LearningTask::Model::kLookupTable:
      trainer_ = std::make_unique<LookupTableTrainer>();
      break;
  }
}

LearningTaskControllerImpl::~LearningTaskControllerImpl() = default;

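// Illustrative observation flow (a sketch only; |controller|, |features|, and
// |source_id| are hypothetical values, not defined in this file):
//
//   base::UnguessableToken id = base::UnguessableToken::Create();
//   controller->BeginObservation(id, features, absl::nullopt, source_id);
//   // ... later, once the target value is known ...
//   controller->CompleteObservation(
//       id, ObservationCompletion(TargetValue(1)));
//   // ... or, if it will never be known ...
//   controller->CancelObservation(id);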
void LearningTaskControllerImpl::BeginObservation(
    base::UnguessableToken id,
    const FeatureVector& features,
    const absl::optional<TargetValue>& default_target,
    const absl::optional<ukm::SourceId>& source_id) {
  // TODO(liberato): Should we enforce that the right number of features are
  // present here?  Right now, we allow it to be shorter, so that features
  // from a FeatureProvider may be omitted.  Of course, they have to be at the
  // end in that case.  If we start enforcing it here, make sure that
  // LearningHelper starts adding the placeholder features.
  if (!trainer_)
    return;

  // We don't support default targets, since we're the base learner and can't
  // easily do that.  However, defaults are handled by the (weak) controllers
  // handed out by LearningSessionImpl, so they never reach this class anyway.
  DCHECK(!default_target);

  helper_->BeginObservation(id, features, source_id);
}

void LearningTaskControllerImpl::CompleteObservation(
    base::UnguessableToken id,
    const ObservationCompletion& completion) {
  if (!trainer_)
    return;
  helper_->CompleteObservation(id, completion);
}

void LearningTaskControllerImpl::CancelObservation(base::UnguessableToken id) {
  if (!trainer_)
    return;
  helper_->CancelObservation(id);
}

void LearningTaskControllerImpl::UpdateDefaultTarget(
    base::UnguessableToken id,
    const absl::optional<TargetValue>& default_target) {
  NOTREACHED();
}

const LearningTask& LearningTaskControllerImpl::GetLearningTask() {
  return task_;
}

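// Until the first training pass completes, |model_| is null and callers get
// absl::nullopt rather than a prediction.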
void LearningTaskControllerImpl::PredictDistribution(
    const FeatureVector& features,
    PredictionCB callback) {
  if (model_)
    std::move(callback).Run(model_->PredictDistribution(features));
  else
    std::move(callback).Run(absl::nullopt);
}

void LearningTaskControllerImpl::AddFinishedExample(LabelledExample example,
                                                    ukm::SourceId source_id) {
  // Verify that we have a trainer and that we got the right number of
  // features.  We don't compare to |task_.feature_descriptions.size()| since
  // that has been adjusted to the subset size already.  We expect the
  // original count.
  if (!trainer_ || example.features.size() != expected_feature_count_)
    return;

  // Now that we have the whole set of features, select the subset we want.
  FeatureVector new_features;
  if (task_.feature_subset_size) {
    for (auto& iter : feature_indices_)
      new_features.push_back(example.features[iter]);
    example.features = std::move(new_features);
  }  // else use them all.

  // The features should now match the task.
  DCHECK_EQ(example.features.size(), task_.feature_descriptions.size());

  if (training_data_->size() >= task_.max_data_set_size) {
    // Replace a random example.  We don't necessarily want to replace the
    // oldest, since we don't necessarily want to enforce an ad-hoc recency
    // constraint here.  That's a different issue.
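    // Note that this is not reservoir sampling: the incoming example is
    // always stored, overwriting a uniformly random slot, so new examples
    // are never dropped outright.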
    (*training_data_)[rng()->Generate(training_data_->size())] = example;
  } else {
    training_data_->push_back(example);
  }

  // Either way, we have one more example that we haven't used for training
  // yet.
  num_untrained_examples_++;

  // Once we have a model, see if we'd get |example| correct.
  if (model_ && reporter_) {
    TargetHistogram predicted = model_->PredictDistribution(example.features);

    DistributionReporter::PredictionInfo info;
    info.observed = example.target_value;
    info.source_id = source_id;
    info.total_training_weight = last_training_weight_;
    info.total_training_examples = last_training_size_;
    reporter_->GetPredictionCallback(info).Run(predicted);
  }

  // Can't train more than one model concurrently.
  if (training_is_in_progress_)
    return;

  // Train every time we get enough new examples.  Note that this works even
  // if we are replacing old examples rather than adding new ones.
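  // For example, with min_new_data_fraction = 0.25 and 100 stored examples,
  // training kicks off once 25 examples have arrived since the last run.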
  double frac = static_cast<double>(num_untrained_examples_) /
                training_data_->size();
  if (frac < task_.min_new_data_fraction)
    return;

  num_untrained_examples_ = 0;

  // Record these for metrics.
  last_training_weight_ = training_data_->total_weight();
  last_training_size_ = training_data_->size();

  TrainedModelCB model_cb =
      base::BindOnce(&LearningTaskControllerImpl::OnModelTrained, AsWeakPtr(),
                     training_data_->total_weight(), training_data_->size());
  training_is_in_progress_ = true;
  // Note that this copies the training data, so it's okay if we add more
  // examples to our copy before this returns.
  // TODO(liberato): Post to a background task runner, and bind |model_cb| to
  // the current one.  Be careful about ownership if we invalidate |trainer_|
  // on this thread.  Be sure to post destruction to that sequence.
  trainer_->Train(task_, *training_data_, std::move(model_cb));
}

void LearningTaskControllerImpl::OnModelTrained(double training_weight,
                                                int training_size,
                                                std::unique_ptr<Model> model) {
  DCHECK(training_is_in_progress_);
  training_is_in_progress_ = false;
  model_ = std::move(model);

  // Record these for metrics.
  last_training_weight_ = training_weight;
  last_training_size_ = training_size;
}

void LearningTaskControllerImpl::SetTrainerForTesting(
    std::unique_ptr<TrainingAlgorithm> trainer) {
  trainer_ = std::move(trainer);
}

void LearningTaskControllerImpl::DoFeatureSubsetSelection() {
  // Choose a random subset of features, and trim the descriptions to match.
  std::vector<size_t> features;
  for (size_t i = 0; i < task_.feature_descriptions.size(); i++)
    features.push_back(i);

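  // This is a partial Fisher-Yates shuffle: after |feature_subset_size|
  // iterations, the first |feature_subset_size| entries of |features| are a
  // uniformly random sample, without replacement, of the feature indices.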
  for (int i = 0; i < *task_.feature_subset_size; i++) {
    // Pick an element from |i| to the end of the list, inclusive.
    // TODO(liberato): For tests, this will happen before any rng is provided
    // by the test; we'll use an actual rng.
    int r = rng()->Generate(features.size() - i) + i;

    // Swap them.
    std::swap(features[i], features[r]);
  }

  // Construct the feature subset from the first few elements, and adjust the
  // task's descriptions to match.  We do this in two steps so that the
  // descriptions are added by iterating over |feature_indices_|; that way the
  // enumeration order matches the one used in AddFinishedExample() when we
  // adjust the feature values of incoming examples.  In both cases, iterating
  // over |feature_indices_| might (will) re-order them with respect to
  // |features|.
  for (int i = 0; i < *task_.feature_subset_size; i++)
    feature_indices_.insert(features[i]);

  std::vector<LearningTask::ValueDescription> adjusted_descriptions;
  for (auto& iter : feature_indices_)
    adjusted_descriptions.push_back(task_.feature_descriptions[iter]);
  task_.feature_descriptions = adjusted_descriptions;

  if (reporter_)
    reporter_->SetFeatureSubset(feature_indices_);
}

}  // namespace learning
}  // namespace media