learning_task_controller_helper.cc 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/learning_task_controller_helper.h"
  5. #include <memory>
  6. #include <utility>
  7. #include "base/bind.h"
  8. #include "base/threading/sequenced_task_runner_handle.h"
  9. namespace media {
  10. namespace learning {
  11. LearningTaskControllerHelper::LearningTaskControllerHelper(
  12. const LearningTask& task,
  13. AddExampleCB add_example_cb,
  14. SequenceBoundFeatureProvider feature_provider)
  15. : task_(task),
  16. feature_provider_(std::move(feature_provider)),
  17. task_runner_(base::SequencedTaskRunnerHandle::Get()),
  18. add_example_cb_(std::move(add_example_cb)) {}
  19. LearningTaskControllerHelper::~LearningTaskControllerHelper() = default;
  20. void LearningTaskControllerHelper::BeginObservation(
  21. base::UnguessableToken id,
  22. FeatureVector features,
  23. absl::optional<ukm::SourceId> source_id) {
  24. auto& pending_example = pending_examples_[id];
  25. if (source_id)
  26. pending_example.source_id = *source_id;
  27. // Start feature prediction, so that we capture the current values.
  28. if (!feature_provider_.is_null()) {
  29. // TODO(dcheng): Convert this to use Then() helper.
  30. feature_provider_.AsyncCall(&FeatureProvider::AddFeatures)
  31. .WithArgs(std::move(features),
  32. base::BindOnce(
  33. &LearningTaskControllerHelper::OnFeaturesReadyTrampoline,
  34. task_runner_, AsWeakPtr(), id));
  35. } else {
  36. pending_example.example.features = std::move(features);
  37. pending_example.features_done = true;
  38. }
  39. }
  40. void LearningTaskControllerHelper::CompleteObservation(
  41. base::UnguessableToken id,
  42. const ObservationCompletion& completion) {
  43. auto iter = pending_examples_.find(id);
  44. if (iter == pending_examples_.end())
  45. return;
  46. iter->second.example.target_value = completion.target_value;
  47. iter->second.example.weight = completion.weight;
  48. iter->second.target_done = true;
  49. ProcessExampleIfFinished(std::move(iter));
  50. }
  51. void LearningTaskControllerHelper::CancelObservation(
  52. base::UnguessableToken id) {
  53. auto iter = pending_examples_.find(id);
  54. if (iter == pending_examples_.end())
  55. return;
  56. // This would have to check for pending predictions, if we supported them, and
  57. // defer destruction until the features arrive.
  58. pending_examples_.erase(iter);
  59. }
  60. // static
  61. void LearningTaskControllerHelper::OnFeaturesReadyTrampoline(
  62. scoped_refptr<base::SequencedTaskRunner> task_runner,
  63. base::WeakPtr<LearningTaskControllerHelper> weak_this,
  64. base::UnguessableToken id,
  65. FeatureVector features) {
  66. // TODO(liberato): this would benefit from promises / deferred data.
  67. auto cb = base::BindOnce(&LearningTaskControllerHelper::OnFeaturesReady,
  68. std::move(weak_this), id, std::move(features));
  69. if (!task_runner->RunsTasksInCurrentSequence()) {
  70. task_runner->PostTask(FROM_HERE, std::move(cb));
  71. } else {
  72. std::move(cb).Run();
  73. }
  74. }
  75. void LearningTaskControllerHelper::OnFeaturesReady(base::UnguessableToken id,
  76. FeatureVector features) {
  77. PendingExampleMap::iterator iter = pending_examples_.find(id);
  78. // It's possible that OnLabelCallbackDestroyed has already run. That's okay
  79. // since we don't support prediction right now.
  80. if (iter == pending_examples_.end())
  81. return;
  82. iter->second.example.features = std::move(features);
  83. iter->second.features_done = true;
  84. ProcessExampleIfFinished(std::move(iter));
  85. }
  86. void LearningTaskControllerHelper::ProcessExampleIfFinished(
  87. PendingExampleMap::iterator iter) {
  88. if (!iter->second.features_done || !iter->second.target_done)
  89. return;
  90. add_example_cb_.Run(std::move(iter->second.example), iter->second.source_id);
  91. pending_examples_.erase(iter);
  92. // TODO(liberato): If we receive FeatureVector f1 then f2, and start filling
  93. // in features for a prediction, and if features become available in the order
  94. // f2, f1, and we receive a target value for f2 before f1's features are
  95. // complete, should we insist on deferring training with f2 until we start
  96. // prediction on f1? I suppose that we could just insist that features are
  97. // provided in the same order they're received, and it's automatic.
  98. }
  99. } // namespace learning
  100. } // namespace media