learning_task_controller_helper_unittest.cc 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <memory>
  5. #include <utility>
  6. #include <vector>
  7. #include "base/bind.h"
  8. #include "base/memory/raw_ptr.h"
  9. #include "base/test/task_environment.h"
  10. #include "base/threading/sequenced_task_runner_handle.h"
  11. #include "media/learning/impl/learning_task_controller_helper.h"
  12. #include "testing/gtest/include/gtest/gtest.h"
  13. namespace media {
  14. namespace learning {
  15. class LearningTaskControllerHelperTest : public testing::Test {
  16. public:
  17. class FakeFeatureProvider : public FeatureProvider {
  18. public:
  19. FakeFeatureProvider(FeatureVector* features_out,
  20. FeatureProvider::FeatureVectorCB* cb_out)
  21. : features_out_(features_out), cb_out_(cb_out) {}
  22. // Do nothing, except note that we were called.
  23. void AddFeatures(FeatureVector features,
  24. FeatureProvider::FeatureVectorCB cb) override {
  25. *features_out_ = std::move(features);
  26. *cb_out_ = std::move(cb);
  27. }
  28. raw_ptr<FeatureVector> features_out_;
  29. raw_ptr<FeatureProvider::FeatureVectorCB> cb_out_;
  30. };
  31. LearningTaskControllerHelperTest() {
  32. task_runner_ = base::SequencedTaskRunnerHandle::Get();
  33. task_.name = "example_task";
  34. example_.features.push_back(FeatureValue(1));
  35. example_.features.push_back(FeatureValue(2));
  36. example_.features.push_back(FeatureValue(3));
  37. example_.target_value = TargetValue(123);
  38. example_.weight = 100u;
  39. id_ = base::UnguessableToken::Create();
  40. }
  41. ~LearningTaskControllerHelperTest() override {
  42. // To prevent a memory leak, reset the helper. This will post destruction
  43. // of other objects, so RunUntilIdle().
  44. helper_.reset();
  45. task_environment_.RunUntilIdle();
  46. }
  47. void CreateClient(bool include_fp) {
  48. // Create the fake feature provider, and get a pointer to it.
  49. base::SequenceBound<FakeFeatureProvider> sb_fp;
  50. if (include_fp) {
  51. sb_fp = base::SequenceBound<FakeFeatureProvider>(task_runner_,
  52. &fp_features_, &fp_cb_);
  53. task_environment_.RunUntilIdle();
  54. }
  55. // TODO(liberato): make sure this works without a fp.
  56. helper_ = std::make_unique<LearningTaskControllerHelper>(
  57. task_,
  58. base::BindRepeating(
  59. &LearningTaskControllerHelperTest::OnLabelledExample,
  60. base::Unretained(this)),
  61. std::move(sb_fp));
  62. }
  63. void OnLabelledExample(LabelledExample example, ukm::SourceId source_id) {
  64. most_recent_example_ = std::move(example);
  65. most_recent_source_id_ = source_id;
  66. }
  67. // Since we're friends but the tests aren't.
  68. size_t pending_example_count() const {
  69. return helper_->pending_example_count_for_testing();
  70. }
  71. base::test::TaskEnvironment task_environment_;
  72. scoped_refptr<base::SequencedTaskRunner> task_runner_;
  73. std::unique_ptr<LearningTaskControllerHelper> helper_;
  74. // Most recent features / cb given to our FakeFeatureProvider.
  75. FeatureVector fp_features_;
  76. FeatureProvider::FeatureVectorCB fp_cb_;
  77. // Most recently added example via OnLabelledExample, if any.
  78. absl::optional<LabelledExample> most_recent_example_;
  79. ukm::SourceId most_recent_source_id_;
  80. LearningTask task_;
  81. base::UnguessableToken id_;
  82. LabelledExample example_;
  83. };
  84. TEST_F(LearningTaskControllerHelperTest, AddingAnExampleWithoutFPWorks) {
  85. // A helper that doesn't use a FeatureProvider should forward examples as soon
  86. // as they're done.
  87. CreateClient(false);
  88. ukm::SourceId source_id = 2;
  89. helper_->BeginObservation(id_, example_.features, source_id);
  90. EXPECT_EQ(pending_example_count(), 1u);
  91. helper_->CompleteObservation(
  92. id_, ObservationCompletion(example_.target_value, example_.weight));
  93. EXPECT_TRUE(most_recent_example_);
  94. EXPECT_EQ(*most_recent_example_, example_);
  95. EXPECT_EQ(most_recent_example_->weight, example_.weight);
  96. EXPECT_EQ(most_recent_source_id_, source_id);
  97. EXPECT_EQ(pending_example_count(), 0u);
  98. }
  99. TEST_F(LearningTaskControllerHelperTest, DropTargetValueWithoutFPWorks) {
  100. // Verify that we can drop an example without labelling it.
  101. CreateClient(false);
  102. helper_->BeginObservation(id_, example_.features, absl::nullopt);
  103. EXPECT_EQ(pending_example_count(), 1u);
  104. helper_->CancelObservation(id_);
  105. task_environment_.RunUntilIdle();
  106. EXPECT_FALSE(most_recent_example_);
  107. EXPECT_EQ(pending_example_count(), 0u);
  108. }
  109. TEST_F(LearningTaskControllerHelperTest, AddTargetValueBeforeFP) {
  110. // Verify that an example is added if the target value arrives first.
  111. CreateClient(true);
  112. helper_->BeginObservation(id_, example_.features, absl::nullopt);
  113. EXPECT_EQ(pending_example_count(), 1u);
  114. task_environment_.RunUntilIdle();
  115. // The feature provider should know about the example.
  116. EXPECT_EQ(fp_features_, example_.features);
  117. // Add the targe value and verify that the example wasn't added yet.
  118. helper_->CompleteObservation(
  119. id_, ObservationCompletion(example_.target_value, example_.weight));
  120. EXPECT_FALSE(most_recent_example_);
  121. EXPECT_EQ(pending_example_count(), 1u);
  122. // Add the features, and verify that they arrive at the AddExampleCB.
  123. example_.features[0] = FeatureValue(456);
  124. std::move(fp_cb_).Run(example_.features);
  125. task_environment_.RunUntilIdle();
  126. EXPECT_EQ(pending_example_count(), 0u);
  127. EXPECT_TRUE(most_recent_example_);
  128. EXPECT_EQ(*most_recent_example_, example_);
  129. EXPECT_EQ(most_recent_example_->weight, example_.weight);
  130. }
  131. TEST_F(LearningTaskControllerHelperTest, DropTargetValueBeforeFP) {
  132. // Verify that an example is correctly dropped before the FP adds features.
  133. CreateClient(true);
  134. helper_->BeginObservation(id_, example_.features, absl::nullopt);
  135. EXPECT_EQ(pending_example_count(), 1u);
  136. task_environment_.RunUntilIdle();
  137. // The feature provider should know about the example.
  138. EXPECT_EQ(fp_features_, example_.features);
  139. // Cancel the observation.
  140. helper_->CancelObservation(id_);
  141. // We don't care if the example is still queued or not, only that we can
  142. // add features and have it be zero by then.
  143. // Add the features, and verify that the pending example is removed and no
  144. // example was sent to us.
  145. example_.features[0] = FeatureValue(456);
  146. std::move(fp_cb_).Run(example_.features);
  147. task_environment_.RunUntilIdle();
  148. EXPECT_EQ(pending_example_count(), 0u);
  149. EXPECT_FALSE(most_recent_example_);
  150. }
  151. TEST_F(LearningTaskControllerHelperTest, AddTargetValueAfterFP) {
  152. // Verify that an example is added if the target value arrives second.
  153. CreateClient(true);
  154. helper_->BeginObservation(id_, example_.features, absl::nullopt);
  155. EXPECT_EQ(pending_example_count(), 1u);
  156. task_environment_.RunUntilIdle();
  157. // The feature provider should know about the example.
  158. EXPECT_EQ(fp_features_, example_.features);
  159. EXPECT_EQ(pending_example_count(), 1u);
  160. // Add the features, and verify that the example isn't sent yet.
  161. example_.features[0] = FeatureValue(456);
  162. std::move(fp_cb_).Run(example_.features);
  163. task_environment_.RunUntilIdle();
  164. EXPECT_FALSE(most_recent_example_);
  165. EXPECT_EQ(pending_example_count(), 1u);
  166. // Add the targe value and verify that the example is added.
  167. helper_->CompleteObservation(
  168. id_, ObservationCompletion(example_.target_value, example_.weight));
  169. EXPECT_TRUE(most_recent_example_);
  170. EXPECT_EQ(*most_recent_example_, example_);
  171. EXPECT_EQ(most_recent_example_->weight, example_.weight);
  172. EXPECT_EQ(pending_example_count(), 0u);
  173. }
  174. TEST_F(LearningTaskControllerHelperTest, DropTargetValueAfterFP) {
  175. // Verify that we can cancel the observationc after sending features.
  176. CreateClient(true);
  177. helper_->BeginObservation(id_, example_.features, absl::nullopt);
  178. EXPECT_EQ(pending_example_count(), 1u);
  179. task_environment_.RunUntilIdle();
  180. // The feature provider should know about the example.
  181. EXPECT_EQ(fp_features_, example_.features);
  182. EXPECT_EQ(pending_example_count(), 1u);
  183. // Add the features, and verify that the example isn't sent yet. We do care
  184. // that the example is still pending, since we haven't actually dropped the
  185. // callback yet; we might send a TargetValue.
  186. example_.features[0] = FeatureValue(456);
  187. std::move(fp_cb_).Run(example_.features);
  188. task_environment_.RunUntilIdle();
  189. EXPECT_FALSE(most_recent_example_);
  190. EXPECT_EQ(pending_example_count(), 1u);
  191. // Cancel the observation, and verify that the pending example has been
  192. // removed, and no example was sent to us.
  193. helper_->CancelObservation(id_);
  194. task_environment_.RunUntilIdle();
  195. EXPECT_FALSE(most_recent_example_);
  196. EXPECT_EQ(pending_example_count(), 0u);
  197. }
  198. } // namespace learning
  199. } // namespace media