learning_session_impl.cc 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/learning/impl/learning_session_impl.h"
  5. #include <set>
  6. #include <utility>
  7. #include "base/bind.h"
  8. #include "base/check.h"
  9. #include "base/memory/raw_ptr.h"
  10. #include "media/learning/impl/distribution_reporter.h"
  11. #include "media/learning/impl/learning_task_controller_impl.h"
  12. namespace media {
  13. namespace learning {
  14. // Allow multiple clients to own an LTC that points to the same underlying LTC.
  15. // Since we don't own the LTC, we also keep track of in-flight observations and
  16. // explicitly cancel them on destruction, since dropping an LTC implies that.
  17. class WeakLearningTaskController : public LearningTaskController {
  18. public:
  19. WeakLearningTaskController(
  20. base::WeakPtr<LearningSessionImpl> weak_session,
  21. base::SequenceBound<LearningTaskController>* controller,
  22. const LearningTask& task)
  23. : weak_session_(std::move(weak_session)),
  24. controller_(controller),
  25. task_(task) {}
  26. ~WeakLearningTaskController() override {
  27. if (!weak_session_)
  28. return;
  29. // Cancel any outstanding observation, unless they have a default value. In
  30. // that case, complete them.
  31. for (auto& id : outstanding_observations_) {
  32. const absl::optional<TargetValue>& default_value = id.second;
  33. if (default_value) {
  34. controller_->AsyncCall(&LearningTaskController::CompleteObservation)
  35. .WithArgs(id.first, *default_value);
  36. } else {
  37. controller_->AsyncCall(&LearningTaskController::CancelObservation)
  38. .WithArgs(id.first);
  39. }
  40. }
  41. }
  42. void BeginObservation(
  43. base::UnguessableToken id,
  44. const FeatureVector& features,
  45. const absl::optional<TargetValue>& default_target,
  46. const absl::optional<ukm::SourceId>& source_id) override {
  47. if (!weak_session_)
  48. return;
  49. outstanding_observations_[id] = default_target;
  50. // We don't send along the default value because LearningTaskControllerImpl
  51. // doesn't support it. Since all client calls eventually come through us
  52. // anyway, it seems okay to handle it here.
  53. controller_->AsyncCall(&LearningTaskController::BeginObservation)
  54. .WithArgs(id, features, absl::nullopt, source_id);
  55. }
  56. void CompleteObservation(base::UnguessableToken id,
  57. const ObservationCompletion& completion) override {
  58. if (!weak_session_)
  59. return;
  60. outstanding_observations_.erase(id);
  61. controller_->AsyncCall(&LearningTaskController::CompleteObservation)
  62. .WithArgs(id, completion);
  63. }
  64. void CancelObservation(base::UnguessableToken id) override {
  65. if (!weak_session_)
  66. return;
  67. outstanding_observations_.erase(id);
  68. controller_->AsyncCall(&LearningTaskController::CancelObservation)
  69. .WithArgs(id);
  70. }
  71. void UpdateDefaultTarget(
  72. base::UnguessableToken id,
  73. const absl::optional<TargetValue>& default_target) override {
  74. if (!weak_session_)
  75. return;
  76. outstanding_observations_[id] = default_target;
  77. }
  78. const LearningTask& GetLearningTask() override { return task_; }
  79. void PredictDistribution(const FeatureVector& features,
  80. PredictionCB callback) override {
  81. if (!weak_session_)
  82. return;
  83. controller_->AsyncCall(&LearningTaskController::PredictDistribution)
  84. .WithArgs(features, std::move(callback));
  85. }
  86. base::WeakPtr<LearningSessionImpl> weak_session_;
  87. raw_ptr<base::SequenceBound<LearningTaskController>> controller_;
  88. LearningTask task_;
  89. // Set of ids that have been started but not completed / cancelled yet, and
  90. // any default target value.
  91. std::map<base::UnguessableToken, absl::optional<TargetValue>>
  92. outstanding_observations_;
  93. };
  94. LearningSessionImpl::LearningSessionImpl(
  95. scoped_refptr<base::SequencedTaskRunner> task_runner)
  96. : task_runner_(std::move(task_runner)),
  97. controller_factory_(base::BindRepeating(
  98. [](scoped_refptr<base::SequencedTaskRunner> task_runner,
  99. const LearningTask& task,
  100. SequenceBoundFeatureProvider feature_provider)
  101. -> base::SequenceBound<LearningTaskController> {
  102. return base::SequenceBound<LearningTaskControllerImpl>(
  103. task_runner, task, DistributionReporter::Create(task),
  104. std::move(feature_provider));
  105. })) {}
  106. LearningSessionImpl::~LearningSessionImpl() = default;
  107. void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
  108. CreateTaskControllerCB cb) {
  109. controller_factory_ = std::move(cb);
  110. }
  111. std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController(
  112. const std::string& task_name) {
  113. auto iter = controller_map_.find(task_name);
  114. if (iter == controller_map_.end())
  115. return nullptr;
  116. // If there were any way to replace / destroy a controller other than when we
  117. // destroy |this|, then this wouldn't be such a good idea.
  118. return std::make_unique<WeakLearningTaskController>(
  119. weak_factory_.GetWeakPtr(), &iter->second, task_map_[task_name]);
  120. }
  121. void LearningSessionImpl::RegisterTask(
  122. const LearningTask& task,
  123. SequenceBoundFeatureProvider feature_provider) {
  124. DCHECK(controller_map_.count(task.name) == 0);
  125. controller_map_.emplace(
  126. task.name,
  127. controller_factory_.Run(task_runner_, task, std::move(feature_provider)));
  128. task_map_.emplace(task.name, task);
  129. }
  130. } // namespace learning
  131. } // namespace media