learning_helper.cc

// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "media/capabilities/learning_helper.h"

#include <memory>

#include "base/task/thread_pool.h"
#include "base/unguessable_token.h"
#include "media/learning/common/feature_library.h"
#include "media/learning/common/learning_task.h"

namespace media {

using learning::FeatureLibrary;
using learning::FeatureProviderFactoryCB;
using learning::FeatureValue;
using learning::LabelledExample;
using learning::LearningSessionImpl;
using learning::LearningTask;
using learning::LearningTaskController;
using learning::ObservationCompletion;
using learning::SequenceBoundFeatureProvider;
using learning::TargetValue;

// Remember that these are used to construct UMA histogram names! Be sure to
// update histograms.xml if you change them!

// Dropped frame ratio, default features, unweighted regression tree.
const char* const kDroppedFrameRatioBaseUnweightedTreeTaskName =
    "BaseUnweightedTree";

// Dropped frame ratio, default features, unweighted examples, lookup table.
const char* const kDroppedFrameRatioBaseUnweightedTableTaskName =
    "BaseUnweightedTable";

// Same as BaseUnweightedTree, but with 200 training examples max.
const char* const kDroppedFrameRatioBaseUnweightedTree200TaskName =
    "BaseUnweightedTree200";

// Dropped frame ratio, default+FeatureLibrary features, regression tree with
// unweighted examples and 200 training examples max.
const char* const kDroppedFrameRatioEnhancedUnweightedTree200TaskName =
    "EnhancedUnweightedTree200";

// Threshold for the dropped frame to total frame ratio, at which we'll decide
// that the playback was not smooth.
constexpr double kSmoothnessThreshold = 0.1;

LearningHelper::LearningHelper(FeatureProviderFactoryCB feature_factory) {
  // Create the LearningSession on a background task runner. In the future,
  // it's likely that the session will live on the main thread, and handle
  // delegation of LearningTaskControllers to other threads. However, for now,
  // do it here.
  learning_session_ = std::make_unique<LearningSessionImpl>(
      base::ThreadPool::CreateSequencedTaskRunner(
          {base::TaskPriority::BEST_EFFORT,
           base::TaskShutdownBehavior::SKIP_ON_SHUTDOWN}));
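  // BEST_EFFORT keeps this learning work off the critical path, and
  // SKIP_ON_SHUTDOWN lets queued (not yet started) training work be dropped
  // at shutdown rather than delaying it.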

  // Register a few learning tasks.
  //
  // We only do this here since we own the session. Normally, whatever creates
  // the session would register all the learning tasks.
  LearningTask dropped_frame_task(
      "no name", LearningTask::Model::kLookupTable,
      {
          {"codec_profile",
           ::media::learning::LearningTask::Ordering::kUnordered},
          {"width", ::media::learning::LearningTask::Ordering::kNumeric},
          {"height", ::media::learning::LearningTask::Ordering::kNumeric},
          {"frame_rate", ::media::learning::LearningTask::Ordering::kNumeric},
      },
      LearningTask::ValueDescription(
          {"dropped_ratio", LearningTask::Ordering::kNumeric}));

  // Report results hackily both in aggregate and by training data weight.
  dropped_frame_task.smoothness_threshold = kSmoothnessThreshold;
  dropped_frame_task.uma_hacky_aggregate_confusion_matrix = true;
  dropped_frame_task.uma_hacky_by_training_weight_confusion_matrix = true;

  // Buckets will have 10 examples each, or 20 for the 200-set tasks.
  const double data_set_size = 100;
  const double big_data_set_size = 200;
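  // |max_data_set_size| caps how many training examples each task retains;
  // that cap is what the "200" suffix in the task names refers to.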

  // Unweighted table.
  dropped_frame_task.name = kDroppedFrameRatioBaseUnweightedTableTaskName;
  dropped_frame_task.max_data_set_size = data_set_size;
  learning_session_->RegisterTask(dropped_frame_task,
                                  SequenceBoundFeatureProvider());
  base_unweighted_table_controller_ =
      learning_session_->GetController(dropped_frame_task.name);
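  // The same LearningTask struct is reused for every registration below; only
  // the name, model, and data set size change between tasks. GetController()
  // returns the LearningTaskController that AppendStats() later feeds.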

  // Unweighted base tree.
  dropped_frame_task.name = kDroppedFrameRatioBaseUnweightedTreeTaskName;
  dropped_frame_task.model = LearningTask::Model::kExtraTrees;
  dropped_frame_task.max_data_set_size = data_set_size;
  learning_session_->RegisterTask(dropped_frame_task,
                                  SequenceBoundFeatureProvider());
  base_unweighted_tree_controller_ =
      learning_session_->GetController(dropped_frame_task.name);

  // Unweighted tree with a larger training set.
  dropped_frame_task.name = kDroppedFrameRatioBaseUnweightedTree200TaskName;
  dropped_frame_task.max_data_set_size = big_data_set_size;
  learning_session_->RegisterTask(dropped_frame_task,
                                  SequenceBoundFeatureProvider());
  base_unweighted_tree_200_controller_ =
      learning_session_->GetController(dropped_frame_task.name);

  // Add common features, if we have a factory.
  if (feature_factory) {
    dropped_frame_task.name =
        kDroppedFrameRatioEnhancedUnweightedTree200TaskName;
    dropped_frame_task.max_data_set_size = big_data_set_size;
    dropped_frame_task.feature_descriptions.push_back(
        {"origin", ::media::learning::LearningTask::Ordering::kUnordered});
    dropped_frame_task.feature_descriptions.push_back(
        FeatureLibrary::NetworkType());
    dropped_frame_task.feature_descriptions.push_back(
        FeatureLibrary::BatteryPower());
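    // The enhanced task thus sees codec_profile, width, height, frame_rate,
    // origin, network type, and battery power, in that order. AppendStats()
    // only appends |origin| itself; the FeatureLibrary values are expected to
    // be filled in by the feature provider built by |feature_factory|.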
    learning_session_->RegisterTask(dropped_frame_task,
                                    feature_factory.Run(dropped_frame_task));
    enhanced_unweighted_tree_200_controller_ =
        learning_session_->GetController(dropped_frame_task.name);
  }
}

LearningHelper::~LearningHelper() = default;

void LearningHelper::AppendStats(
    const VideoDecodeStatsDB::VideoDescKey& video_key,
    learning::FeatureValue origin,
    const VideoDecodeStatsDB::DecodeStatsEntry& new_stats) {
  // If no frames were recorded, then do nothing.
  if (new_stats.frames_decoded == 0)
    return;

  // Sanity check: we can't have dropped more frames than we decoded.
  if (new_stats.frames_dropped > new_stats.frames_decoded)
    return;

  // Add a training example for |new_stats|.
  LabelledExample example;

  // Extract features from |video_key|.
  example.features.push_back(FeatureValue(video_key.codec_profile));
  example.features.push_back(FeatureValue(video_key.size.width()));
  example.features.push_back(FeatureValue(video_key.size.height()));
  example.features.push_back(FeatureValue(video_key.frame_rate));
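  // Note that this order must match the feature descriptions registered in
  // the constructor.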

  // Record the ratio of dropped frames to decoded frames as the target. All
  // of the tasks registered above use unweighted examples, so each playback
  // counts once (weight 1) regardless of how many frames it contained.
  //
  // It's not obvious that this is the right choice; we might instead want to
  // weight each example by its total frame count, so that the model predicts
  // the aggregate dropped-frame ratio and can be compared with the current
  // implementation directly. For example, if there is a dependence on video
  // length, then it's unclear that unweighted examples are the right thing
  // to use.
  example.target_value = TargetValue(
      static_cast<double>(new_stats.frames_dropped) / new_stats.frames_decoded);
  example.weight = 1u;

  // Add this example to all tasks.
  AddExample(base_unweighted_table_controller_.get(), example);
  AddExample(base_unweighted_tree_controller_.get(), example);
  AddExample(base_unweighted_tree_200_controller_.get(), example);
  if (enhanced_unweighted_tree_200_controller_) {
    example.features.push_back(origin);
    AddExample(enhanced_unweighted_tree_200_controller_.get(), example);
  }
}
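
// Reports |example| to |controller| as a one-shot observation: the features
// are registered under a fresh UnguessableToken, and the target and weight
// are supplied immediately, since the outcome is already known by the time
// this is called.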
void LearningHelper::AddExample(LearningTaskController* controller,
                                const LabelledExample& example) {
  base::UnguessableToken id = base::UnguessableToken::Create();
  controller->BeginObservation(id, example.features);
  controller->CompleteObservation(
      id, ObservationCompletion(example.target_value, example.weight));
}

}  // namespace media