distribution_reporter_unittest.cc 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <memory>
  5. #include <vector>
  6. #include "base/bind.h"
  7. #include "base/test/task_environment.h"
  8. #include "components/ukm/test_ukm_recorder.h"
  9. #include "media/learning/common/learning_task.h"
  10. #include "media/learning/impl/distribution_reporter.h"
  11. #include "testing/gtest/include/gtest/gtest.h"
  12. namespace media {
  13. namespace learning {
  14. class DistributionReporterTest : public testing::Test {
  15. public:
  16. DistributionReporterTest()
  17. : ukm_recorder_(std::make_unique<ukm::TestAutoSetUkmRecorder>()),
  18. source_id_(123) {
  19. task_.name = "TaskName";
  20. // UMA reporting requires a numeric target.
  21. task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  22. }
  23. base::test::TaskEnvironment task_environment_;
  24. std::unique_ptr<ukm::TestAutoSetUkmRecorder> ukm_recorder_;
  25. LearningTask task_;
  26. ukm::SourceId source_id_;
  27. std::unique_ptr<DistributionReporter> reporter_;
  28. TargetHistogram HistogramFor(double value) {
  29. TargetHistogram histogram;
  30. histogram += TargetValue(value);
  31. return histogram;
  32. }
  33. };
  34. TEST_F(DistributionReporterTest, DistributionReporterDoesNotCrash) {
  35. // Make sure that we request some sort of reporting.
  36. task_.uma_hacky_aggregate_confusion_matrix = true;
  37. reporter_ = DistributionReporter::Create(task_);
  38. EXPECT_NE(reporter_, nullptr);
  39. // Observe an average of 2 / 3.
  40. DistributionReporter::PredictionInfo info;
  41. info.observed = TargetValue(2.0 / 3.0);
  42. auto cb = reporter_->GetPredictionCallback(info);
  43. TargetHistogram predicted;
  44. const TargetValue Zero(0);
  45. const TargetValue One(1);
  46. // Predict an average of 5 / 9.
  47. predicted[Zero] = 40;
  48. predicted[One] = 50;
  49. std::move(cb).Run(predicted);
  50. }
  51. TEST_F(DistributionReporterTest, CallbackRecordsRegressionPredictions) {
  52. // Make sure that |reporter_| records everything correctly for regressions.
  53. task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  54. // Scale 1-2 => 0->100.
  55. task_.ukm_min_input_value = 1.;
  56. task_.ukm_max_input_value = 2.;
  57. task_.report_via_ukm = true;
  58. reporter_ = DistributionReporter::Create(task_);
  59. EXPECT_NE(reporter_, nullptr);
  60. DistributionReporter::PredictionInfo info;
  61. info.observed = TargetValue(1.1); // => 10
  62. info.source_id = source_id_;
  63. auto cb = reporter_->GetPredictionCallback(info);
  64. TargetHistogram predicted;
  65. const TargetValue One(1);
  66. const TargetValue Five(5);
  67. // Predict an average of 1.5 => 50 in the 0-100 scale.
  68. predicted[One] = 70;
  69. predicted[Five] = 10;
  70. ASSERT_EQ(predicted.Average(), 1.5);
  71. std::move(cb).Run(predicted);
  72. // The record should show the correct averages, scaled by |fixed_point_scale|.
  73. std::vector<const ukm::mojom::UkmEntry*> entries =
  74. ukm_recorder_->GetEntriesByName("Media.Learning.PredictionRecord");
  75. EXPECT_EQ(entries.size(), 1u);
  76. ukm::TestUkmRecorder::ExpectEntryMetric(entries[0], "LearningTask",
  77. task_.GetId());
  78. ukm::TestUkmRecorder::ExpectEntryMetric(entries[0], "ObservedValue", 10);
  79. ukm::TestUkmRecorder::ExpectEntryMetric(entries[0], "PredictedValue", 50);
  80. }
  81. TEST_F(DistributionReporterTest, DistributionReporterNeedsUmaNameOrUkm) {
  82. // Make sure that we don't get a reporter if we don't request any reporting.
  83. task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  84. task_.uma_hacky_aggregate_confusion_matrix = false;
  85. task_.uma_hacky_by_training_weight_confusion_matrix = false;
  86. task_.uma_hacky_by_feature_subset_confusion_matrix = false;
  87. task_.report_via_ukm = false;
  88. reporter_ = DistributionReporter::Create(task_);
  89. EXPECT_EQ(reporter_, nullptr);
  90. }
  91. TEST_F(DistributionReporterTest,
  92. DistributionReporterHackyConfusionMatrixNeedsRegression) {
  93. // Hacky confusion matrix reporting only works with regression.
  94. task_.target_description.ordering = LearningTask::Ordering::kUnordered;
  95. task_.uma_hacky_aggregate_confusion_matrix = true;
  96. reporter_ = DistributionReporter::Create(task_);
  97. EXPECT_EQ(reporter_, nullptr);
  98. }
  99. TEST_F(DistributionReporterTest, ProvidesAggregateReporter) {
  100. task_.uma_hacky_aggregate_confusion_matrix = true;
  101. reporter_ = DistributionReporter::Create(task_);
  102. EXPECT_NE(reporter_, nullptr);
  103. }
  104. TEST_F(DistributionReporterTest, ProvidesByTrainingWeightReporter) {
  105. task_.uma_hacky_by_training_weight_confusion_matrix = true;
  106. reporter_ = DistributionReporter::Create(task_);
  107. EXPECT_NE(reporter_, nullptr);
  108. }
  109. TEST_F(DistributionReporterTest, ProvidesByFeatureSubsetReporter) {
  110. task_.uma_hacky_by_feature_subset_confusion_matrix = true;
  111. reporter_ = DistributionReporter::Create(task_);
  112. EXPECT_NE(reporter_, nullptr);
  113. }
  114. TEST_F(DistributionReporterTest, UkmBucketizesProperly) {
  115. task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  116. // Scale [1000, 2000] => [0, 100]
  117. task_.ukm_min_input_value = 1000;
  118. task_.ukm_max_input_value = 2000;
  119. task_.report_via_ukm = true;
  120. reporter_ = DistributionReporter::Create(task_);
  121. DistributionReporter::PredictionInfo info;
  122. info.source_id = source_id_;
  123. // Add a few predictions / observations. We rotate the predicted / observed
  124. // just to be sure they end up in the right UKM field.
  125. // Inputs less than min scale to 0.
  126. info.observed = TargetValue(900);
  127. reporter_->GetPredictionCallback(info).Run(HistogramFor(1500));
  128. // Inputs exactly at min scale to 0.
  129. info.observed = TargetValue(1000);
  130. reporter_->GetPredictionCallback(info).Run(HistogramFor(2000));
  131. // Inputs in the middle scale to 50.
  132. info.observed = TargetValue(1500);
  133. reporter_->GetPredictionCallback(info).Run(HistogramFor(2100));
  134. // Inputs at max scale to 100.
  135. info.observed = TargetValue(2000);
  136. reporter_->GetPredictionCallback(info).Run(HistogramFor(900));
  137. // Inputs greater than max scale to 100.
  138. info.observed = TargetValue(2100);
  139. reporter_->GetPredictionCallback(info).Run(HistogramFor(1000));
  140. std::vector<const ukm::mojom::UkmEntry*> entries =
  141. ukm_recorder_->GetEntriesByName("Media.Learning.PredictionRecord");
  142. EXPECT_EQ(entries.size(), 5u);
  143. ukm::TestUkmRecorder::ExpectEntryMetric(entries[0], "ObservedValue", 0);
  144. ukm::TestUkmRecorder::ExpectEntryMetric(entries[0], "PredictedValue", 50);
  145. ukm::TestUkmRecorder::ExpectEntryMetric(entries[1], "ObservedValue", 0);
  146. ukm::TestUkmRecorder::ExpectEntryMetric(entries[1], "PredictedValue", 100);
  147. ukm::TestUkmRecorder::ExpectEntryMetric(entries[2], "ObservedValue", 50);
  148. ukm::TestUkmRecorder::ExpectEntryMetric(entries[2], "PredictedValue", 100);
  149. ukm::TestUkmRecorder::ExpectEntryMetric(entries[3], "ObservedValue", 100);
  150. ukm::TestUkmRecorder::ExpectEntryMetric(entries[3], "PredictedValue", 0);
  151. ukm::TestUkmRecorder::ExpectEntryMetric(entries[4], "ObservedValue", 100);
  152. ukm::TestUkmRecorder::ExpectEntryMetric(entries[4], "PredictedValue", 0);
  153. }
  154. } // namespace learning
  155. } // namespace media