mojo_learning_task_controller_service_unittest.cc 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <memory>
  5. #include <utility>
  6. #include "base/bind.h"
  7. #include "base/memory/ptr_util.h"
  8. #include "base/memory/raw_ptr.h"
  9. #include "base/test/task_environment.h"
  10. #include "base/threading/thread.h"
  11. #include "media/learning/mojo/mojo_learning_task_controller_service.h"
  12. #include "testing/gtest/include/gtest/gtest.h"
  13. namespace {
  14. // Meaningless, but non-empty, source id.
  15. ukm::SourceId kSourceId{123};
  16. } // namespace
  17. namespace media {
  18. namespace learning {
  19. class MojoLearningTaskControllerServiceTest : public ::testing::Test {
  20. public:
  21. class FakeLearningTaskController : public LearningTaskController {
  22. public:
  23. void BeginObservation(
  24. base::UnguessableToken id,
  25. const FeatureVector& features,
  26. const absl::optional<TargetValue>& default_target,
  27. const absl::optional<ukm::SourceId>& source_id) override {
  28. begin_args_.id_ = id;
  29. begin_args_.features_ = features;
  30. begin_args_.default_target_ = default_target;
  31. begin_args_.source_id_ = source_id;
  32. }
  33. void CompleteObservation(base::UnguessableToken id,
  34. const ObservationCompletion& completion) override {
  35. complete_args_.id_ = id;
  36. complete_args_.completion_ = completion;
  37. }
  38. void CancelObservation(base::UnguessableToken id) override {
  39. cancel_args_.id_ = id;
  40. }
  41. void UpdateDefaultTarget(
  42. base::UnguessableToken id,
  43. const absl::optional<TargetValue>& default_target) override {
  44. update_default_args_.id_ = id;
  45. update_default_args_.default_target_ = default_target;
  46. }
  47. const LearningTask& GetLearningTask() override {
  48. return LearningTask::Empty();
  49. }
  50. void PredictDistribution(const FeatureVector& features,
  51. PredictionCB callback) override {
  52. predict_distribution_args_.features_ = features;
  53. predict_distribution_args_.callback_ = std::move(callback);
  54. }
  55. struct {
  56. base::UnguessableToken id_;
  57. FeatureVector features_;
  58. absl::optional<TargetValue> default_target_;
  59. absl::optional<ukm::SourceId> source_id_;
  60. } begin_args_;
  61. struct {
  62. base::UnguessableToken id_;
  63. ObservationCompletion completion_;
  64. } complete_args_;
  65. struct {
  66. base::UnguessableToken id_;
  67. } cancel_args_;
  68. struct {
  69. base::UnguessableToken id_;
  70. absl::optional<TargetValue> default_target_;
  71. } update_default_args_;
  72. struct {
  73. FeatureVector features_;
  74. PredictionCB callback_;
  75. } predict_distribution_args_;
  76. };
  77. public:
  78. MojoLearningTaskControllerServiceTest() = default;
  79. ~MojoLearningTaskControllerServiceTest() override = default;
  80. void SetUp() override {
  81. std::unique_ptr<FakeLearningTaskController> controller =
  82. std::make_unique<FakeLearningTaskController>();
  83. controller_raw_ = controller.get();
  84. // Add two features.
  85. task_.feature_descriptions.push_back({});
  86. task_.feature_descriptions.push_back({});
  87. // Tell |learning_controller_| to forward to the fake learner impl.
  88. service_ = std::make_unique<MojoLearningTaskControllerService>(
  89. task_, kSourceId, std::move(controller));
  90. }
  91. LearningTask task_;
  92. // Mojo stuff.
  93. base::test::TaskEnvironment task_environment_;
  94. raw_ptr<FakeLearningTaskController> controller_raw_ = nullptr;
  95. // The learner under test.
  96. std::unique_ptr<MojoLearningTaskControllerService> service_;
  97. };
  98. TEST_F(MojoLearningTaskControllerServiceTest, BeginComplete) {
  99. base::UnguessableToken id = base::UnguessableToken::Create();
  100. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  101. service_->BeginObservation(id, features, absl::nullopt);
  102. EXPECT_EQ(id, controller_raw_->begin_args_.id_);
  103. EXPECT_EQ(features, controller_raw_->begin_args_.features_);
  104. EXPECT_FALSE(controller_raw_->begin_args_.default_target_);
  105. EXPECT_TRUE(controller_raw_->begin_args_.source_id_);
  106. EXPECT_EQ(*controller_raw_->begin_args_.source_id_, kSourceId);
  107. ObservationCompletion completion(TargetValue(1234));
  108. service_->CompleteObservation(id, completion);
  109. EXPECT_EQ(id, controller_raw_->complete_args_.id_);
  110. EXPECT_EQ(completion.target_value,
  111. controller_raw_->complete_args_.completion_.target_value);
  112. }
  113. TEST_F(MojoLearningTaskControllerServiceTest, BeginCancel) {
  114. base::UnguessableToken id = base::UnguessableToken::Create();
  115. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  116. service_->BeginObservation(id, features, absl::nullopt);
  117. EXPECT_EQ(id, controller_raw_->begin_args_.id_);
  118. EXPECT_EQ(features, controller_raw_->begin_args_.features_);
  119. EXPECT_FALSE(controller_raw_->begin_args_.default_target_);
  120. service_->CancelObservation(id);
  121. EXPECT_EQ(id, controller_raw_->cancel_args_.id_);
  122. }
  123. TEST_F(MojoLearningTaskControllerServiceTest, BeginWithDefaultTarget) {
  124. base::UnguessableToken id = base::UnguessableToken::Create();
  125. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  126. TargetValue default_target(987);
  127. service_->BeginObservation(id, features, default_target);
  128. EXPECT_EQ(id, controller_raw_->begin_args_.id_);
  129. EXPECT_EQ(features, controller_raw_->begin_args_.features_);
  130. EXPECT_EQ(default_target, controller_raw_->begin_args_.default_target_);
  131. EXPECT_TRUE(controller_raw_->begin_args_.source_id_);
  132. EXPECT_EQ(*controller_raw_->begin_args_.source_id_, kSourceId);
  133. }
  134. TEST_F(MojoLearningTaskControllerServiceTest, TooFewFeaturesIsIgnored) {
  135. // A FeatureVector with too few elements should be ignored.
  136. base::UnguessableToken id = base::UnguessableToken::Create();
  137. FeatureVector short_features = {FeatureValue(123)};
  138. service_->BeginObservation(id, short_features, absl::nullopt);
  139. EXPECT_NE(id, controller_raw_->begin_args_.id_);
  140. EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
  141. }
  142. TEST_F(MojoLearningTaskControllerServiceTest, TooManyFeaturesIsIgnored) {
  143. // A FeatureVector with too many elements should be ignored.
  144. base::UnguessableToken id = base::UnguessableToken::Create();
  145. FeatureVector long_features = {FeatureValue(123), FeatureValue(456),
  146. FeatureValue(789)};
  147. service_->BeginObservation(id, long_features, absl::nullopt);
  148. EXPECT_NE(id, controller_raw_->begin_args_.id_);
  149. EXPECT_EQ(controller_raw_->begin_args_.features_.size(), 0u);
  150. }
  151. TEST_F(MojoLearningTaskControllerServiceTest, CompleteWithoutBeginFails) {
  152. base::UnguessableToken id = base::UnguessableToken::Create();
  153. ObservationCompletion completion(TargetValue(1234));
  154. service_->CompleteObservation(id, completion);
  155. EXPECT_NE(id, controller_raw_->complete_args_.id_);
  156. }
  157. TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
  158. base::UnguessableToken id = base::UnguessableToken::Create();
  159. service_->CancelObservation(id);
  160. EXPECT_NE(id, controller_raw_->cancel_args_.id_);
  161. }
  162. TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToValue) {
  163. base::UnguessableToken id = base::UnguessableToken::Create();
  164. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  165. service_->BeginObservation(id, features, absl::nullopt);
  166. TargetValue default_target(987);
  167. service_->UpdateDefaultTarget(id, default_target);
  168. EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
  169. EXPECT_EQ(default_target,
  170. controller_raw_->update_default_args_.default_target_);
  171. }
  172. TEST_F(MojoLearningTaskControllerServiceTest, UpdateDefaultTargetToNoValue) {
  173. base::UnguessableToken id = base::UnguessableToken::Create();
  174. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  175. TargetValue default_target(987);
  176. service_->BeginObservation(id, features, default_target);
  177. service_->UpdateDefaultTarget(id, absl::nullopt);
  178. EXPECT_EQ(id, controller_raw_->update_default_args_.id_);
  179. EXPECT_EQ(absl::nullopt,
  180. controller_raw_->update_default_args_.default_target_);
  181. }
  182. TEST_F(MojoLearningTaskControllerServiceTest, PredictDistribution) {
  183. FeatureVector features = {FeatureValue(123), FeatureValue(456)};
  184. TargetHistogram observed_prediction;
  185. service_->PredictDistribution(
  186. features, base::BindOnce(
  187. [](TargetHistogram* test_storage,
  188. const absl::optional<TargetHistogram>& predicted) {
  189. *test_storage = *predicted;
  190. },
  191. &observed_prediction));
  192. EXPECT_EQ(features, controller_raw_->predict_distribution_args_.features_);
  193. EXPECT_FALSE(controller_raw_->predict_distribution_args_.callback_.is_null());
  194. TargetHistogram expected_prediction;
  195. expected_prediction[TargetValue(1)] = 1.0;
  196. expected_prediction[TargetValue(2)] = 2.0;
  197. expected_prediction[TargetValue(3)] = 3.0;
  198. std::move(controller_raw_->predict_distribution_args_.callback_)
  199. .Run(expected_prediction);
  200. EXPECT_EQ(expected_prediction, observed_prediction);
  201. }
  202. } // namespace learning
  203. } // namespace media