segmentation_ukm_helper_unittest.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
  5. #include <cmath>
  6. #include "base/bit_cast.h"
  7. #include "base/test/metrics/histogram_tester.h"
  8. #include "base/test/scoped_feature_list.h"
  9. #include "base/test/simple_test_clock.h"
  10. #include "base/test/task_environment.h"
  11. #include "components/prefs/testing_pref_service.h"
  12. #include "components/segmentation_platform/internal/constants.h"
  13. #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
  14. #include "components/segmentation_platform/public/config.h"
  15. #include "components/segmentation_platform/public/features.h"
  16. #include "components/segmentation_platform/public/local_state_helper.h"
  17. #include "components/segmentation_platform/public/segmentation_platform_service.h"
  18. #include "components/ukm/test_ukm_recorder.h"
  19. #include "services/metrics/public/cpp/ukm_builders.h"
  20. #include "testing/gtest/include/gtest/gtest.h"
  21. using Segmentation_ModelExecution = ukm::builders::Segmentation_ModelExecution;
  22. namespace segmentation_platform {
  23. namespace {
  24. // Round errors allowed during conversion.
  25. static const double kRoundingError = 1E-5;
  26. float Int64ToFloat(int64_t encoded) {
  27. return static_cast<float>(base::bit_cast<double>(encoded));
  28. }
  29. void CompareEncodeDecodeDifference(float tensor) {
  30. ASSERT_LT(
  31. std::abs(tensor -
  32. Int64ToFloat(
  33. segmentation_platform::SegmentationUkmHelper::FloatToInt64(
  34. tensor))),
  35. kRoundingError);
  36. }
  37. absl::optional<proto::PredictionResult> GetPredictionResult() {
  38. proto::PredictionResult result;
  39. result.set_result(0.5);
  40. return result;
  41. }
  42. } // namespace
  43. class SegmentationUkmHelperTest : public testing::Test {
  44. public:
  45. SegmentationUkmHelperTest() = default;
  46. SegmentationUkmHelperTest(const SegmentationUkmHelperTest&) = delete;
  47. SegmentationUkmHelperTest& operator=(const SegmentationUkmHelperTest&) =
  48. delete;
  49. ~SegmentationUkmHelperTest() override = default;
  50. void SetUp() override { test_recorder_.Purge(); }
  51. void ExpectUkmMetrics(const base::StringPiece entry_name,
  52. const std::vector<base::StringPiece>& keys,
  53. const std::vector<int64_t>& values) {
  54. const auto& entries = test_recorder_.GetEntriesByName(entry_name);
  55. EXPECT_EQ(1u, entries.size());
  56. for (const auto* entry : entries) {
  57. const size_t keys_size = keys.size();
  58. EXPECT_EQ(keys_size, values.size());
  59. for (size_t i = 0; i < keys_size; ++i) {
  60. test_recorder_.ExpectEntryMetric(entry, keys[i], values[i]);
  61. }
  62. }
  63. }
  64. void ExpectEmptyUkmMetrics(const base::StringPiece entry_name) {
  65. EXPECT_EQ(0u, test_recorder_.GetEntriesByName(entry_name).size());
  66. }
  67. void InitializeAllowedSegmentIds(const std::string& allowed_ids) {
  68. std::map<std::string, std::string> params = {
  69. {kSegmentIdsAllowedForReportingKey, allowed_ids}};
  70. feature_list_.InitAndEnableFeatureWithParameters(
  71. features::kSegmentationStructuredMetricsFeature, params);
  72. SegmentationUkmHelper::GetInstance()->Initialize();
  73. }
  74. void DisableStructureMetrics() {
  75. feature_list_.InitAndDisableFeature(
  76. features::kSegmentationStructuredMetricsFeature);
  77. SegmentationUkmHelper::GetInstance()->Initialize();
  78. }
  79. protected:
  80. base::test::TaskEnvironment task_environment_;
  81. ukm::TestAutoSetUkmRecorder test_recorder_;
  82. base::test::ScopedFeatureList feature_list_;
  83. };
  84. // Tests that basic execution results recording works properly.
  85. TEST_F(SegmentationUkmHelperTest, TestExecutionResultReporting) {
  86. // Allow results for OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB to be recorded.
  87. InitializeAllowedSegmentIds("4");
  88. std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
  89. SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
  90. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
  91. ExpectUkmMetrics(Segmentation_ModelExecution::kEntryName,
  92. {Segmentation_ModelExecution::kOptimizationTargetName,
  93. Segmentation_ModelExecution::kModelVersionName,
  94. Segmentation_ModelExecution::kInput0Name,
  95. Segmentation_ModelExecution::kInput1Name,
  96. Segmentation_ModelExecution::kInput2Name,
  97. Segmentation_ModelExecution::kInput3Name,
  98. Segmentation_ModelExecution::kPredictionResultName},
  99. {
  100. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
  101. 101,
  102. SegmentationUkmHelper::FloatToInt64(0.1),
  103. SegmentationUkmHelper::FloatToInt64(0.7),
  104. SegmentationUkmHelper::FloatToInt64(0.8),
  105. SegmentationUkmHelper::FloatToInt64(0.5),
  106. SegmentationUkmHelper::FloatToInt64(0.6),
  107. });
  108. }
  109. // Tests that the training data collection recording works properly.
  110. TEST_F(SegmentationUkmHelperTest, TestTrainingDataCollectionReporting) {
  111. // Allow results for OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB to be recorded.
  112. InitializeAllowedSegmentIds("4");
  113. std::vector<float> input_tensors = {0.1};
  114. std::vector<float> outputs = {1.0, 0.0};
  115. std::vector<int> output_indexes = {2, 3};
  116. SelectedSegment selected_segment(
  117. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
  118. selected_segment.selection_time = base::Time::Now() - base::Seconds(10);
  119. SegmentationUkmHelper::GetInstance()->RecordTrainingData(
  120. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
  121. outputs, output_indexes, GetPredictionResult(), selected_segment);
  122. ExpectUkmMetrics(Segmentation_ModelExecution::kEntryName,
  123. {Segmentation_ModelExecution::kOptimizationTargetName,
  124. Segmentation_ModelExecution::kModelVersionName,
  125. Segmentation_ModelExecution::kInput0Name,
  126. Segmentation_ModelExecution::kActualResult3Name,
  127. Segmentation_ModelExecution::kActualResult4Name,
  128. Segmentation_ModelExecution::kPredictionResultName,
  129. Segmentation_ModelExecution::kSelectionResultName,
  130. Segmentation_ModelExecution::kOutputDelaySecName},
  131. {
  132. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
  133. 101,
  134. SegmentationUkmHelper::FloatToInt64(0.1),
  135. SegmentationUkmHelper::FloatToInt64(1.0),
  136. SegmentationUkmHelper::FloatToInt64(0.0),
  137. SegmentationUkmHelper::FloatToInt64(0.5),
  138. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB,
  139. 10,
  140. });
  141. }
  142. // Tests that recording is disabled if kSegmentationStructuredMetricsFeature
  143. // is disabled.
  144. TEST_F(SegmentationUkmHelperTest, TestDisabledStructuredMetrics) {
  145. DisableStructureMetrics();
  146. std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
  147. SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
  148. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
  149. ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
  150. }
  151. // Tests that recording is disabled for segment IDs that are not in the allowed
  152. // list.
  153. TEST_F(SegmentationUkmHelperTest, TestNotAllowedSegmentId) {
  154. InitializeAllowedSegmentIds("7, 8");
  155. std::vector<float> input_tensors = {0.1, 0.7, 0.8, 0.5};
  156. SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
  157. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors, 0.6);
  158. ExpectEmptyUkmMetrics(Segmentation_ModelExecution::kEntryName);
  159. }
  160. // Tests that float encoding works properly.
  161. TEST_F(SegmentationUkmHelperTest, TestFloatEncoding) {
  162. // Compare the numbers with their IEEE754 binary representations in double.
  163. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(0.5), 0x3FE0000000000000);
  164. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(0.25), 0x3FD0000000000000);
  165. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(0.125), 0x3FC0000000000000);
  166. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(0.75), 0x3FE8000000000000);
  167. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(1), 0x3FF0000000000000);
  168. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(0), 0);
  169. ASSERT_EQ(SegmentationUkmHelper::FloatToInt64(10), 0x4024000000000000);
  170. }
  171. // Tests that floats encoded can be properly decoded later.
  172. TEST_F(SegmentationUkmHelperTest, FloatEncodeDeocode) {
  173. CompareEncodeDecodeDifference(0.1);
  174. CompareEncodeDecodeDifference(0.5);
  175. CompareEncodeDecodeDifference(0.88);
  176. CompareEncodeDecodeDifference(0.01);
  177. ASSERT_EQ(0, Int64ToFloat(SegmentationUkmHelper::FloatToInt64(0)));
  178. ASSERT_EQ(1, Int64ToFloat(SegmentationUkmHelper::FloatToInt64(1)));
  179. }
  180. // Tests that there are too many input tensors to record.
  181. TEST_F(SegmentationUkmHelperTest, TooManyInputTensors) {
  182. base::HistogramTester tester;
  183. std::string histogram_name(
  184. "SegmentationPlatform.StructuredMetrics.TooManyTensors.Count");
  185. InitializeAllowedSegmentIds("4");
  186. std::vector<float> input_tensors(100, 0.1);
  187. ukm::SourceId source_id =
  188. SegmentationUkmHelper::GetInstance()->RecordModelExecutionResult(
  189. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
  190. 0.6);
  191. ASSERT_EQ(source_id, ukm::kInvalidSourceId);
  192. tester.ExpectTotalCount(histogram_name, 1);
  193. ASSERT_EQ(tester.GetTotalSum(histogram_name), 100);
  194. }
  195. // Tests output validation for |RecordTrainingData|.
  196. TEST_F(SegmentationUkmHelperTest, OutputsValidation) {
  197. InitializeAllowedSegmentIds("4");
  198. std::vector<float> input_tensors{0.1};
  199. // outputs, output_indexes size doesn't match.
  200. std::vector<float> outputs{1.0, 0.0};
  201. std::vector<int> output_indexes{0};
  202. ukm::SourceId source_id =
  203. SegmentationUkmHelper::GetInstance()->RecordTrainingData(
  204. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
  205. outputs, output_indexes, GetPredictionResult(), absl::nullopt);
  206. ASSERT_EQ(source_id, ukm::kInvalidSourceId);
  207. // output_indexes value too large.
  208. output_indexes = {100, 1000};
  209. source_id = SegmentationUkmHelper::GetInstance()->RecordTrainingData(
  210. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
  211. outputs, output_indexes, GetPredictionResult(), absl::nullopt);
  212. ASSERT_EQ(source_id, ukm::kInvalidSourceId);
  213. // Valid outputs.
  214. output_indexes = {3, 0};
  215. source_id = SegmentationUkmHelper::GetInstance()->RecordTrainingData(
  216. proto::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB, 101, input_tensors,
  217. outputs, output_indexes, GetPredictionResult(), absl::nullopt);
  218. ASSERT_NE(source_id, ukm::kInvalidSourceId);
  219. }
  220. TEST_F(SegmentationUkmHelperTest, AllowedToUploadData) {
  221. TestingPrefServiceSimple prefs;
  222. SegmentationPlatformService::RegisterLocalStatePrefs(prefs.registry());
  223. LocalStateHelper::GetInstance().Initialize(&prefs);
  224. base::SimpleTestClock clock;
  225. clock.SetNow(base::Time::Now());
  226. // If pref is not initialized, AllowedToUploadData() always return false.
  227. ASSERT_FALSE(
  228. SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
  229. LocalStateHelper::GetInstance().SetPrefTime(
  230. kSegmentationUkmMostRecentAllowedTimeKey, clock.Now());
  231. ASSERT_FALSE(
  232. SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
  233. clock.Advance(base::Seconds(10));
  234. ASSERT_TRUE(
  235. SegmentationUkmHelper::AllowedToUploadData(base::Seconds(1), &clock));
  236. ASSERT_FALSE(
  237. SegmentationUkmHelper::AllowedToUploadData(base::Seconds(11), &clock));
  238. }
  239. } // namespace segmentation_platform