ukm_data_manager_impl_unittest.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/segmentation_platform/internal/ukm_data_manager_impl.h"
  5. #include "base/files/scoped_temp_dir.h"
  6. #include "base/metrics/metrics_hashes.h"
  7. #include "base/test/task_environment.h"
  8. #include "components/history/core/browser/history_service.h"
  9. #include "components/history/core/test/history_service_test_util.h"
  10. #include "components/prefs/testing_pref_service.h"
  11. #include "components/segmentation_platform/internal/database/mock_ukm_database.h"
  12. #include "components/segmentation_platform/internal/database/ukm_types.h"
  13. #include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h"
  14. #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
  15. #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
  16. #include "components/segmentation_platform/internal/segmentation_platform_service_test_base.h"
  17. #include "components/segmentation_platform/internal/signals/ukm_observer.h"
  18. #include "components/segmentation_platform/public/local_state_helper.h"
  19. #include "components/ukm/test_ukm_recorder.h"
  20. #include "services/metrics/public/cpp/ukm_builders.h"
  21. #include "testing/gmock/include/gmock/gmock.h"
  22. #include "testing/gtest/include/gtest/gtest.h"
  23. namespace segmentation_platform {
  24. namespace {
  25. using testing::_;
  26. using ukm::builders::PageLoad;
  27. using ukm::builders::PaintPreviewCapture;
  28. constexpr ukm::SourceId kSourceId = 10;
  29. constexpr ukm::SourceId kSourceId2 = 12;
  30. ukm::mojom::UkmEntryPtr GetSamplePageLoadEntry(
  31. ukm::SourceId source_id = kSourceId) {
  32. ukm::mojom::UkmEntryPtr entry = ukm::mojom::UkmEntry::New();
  33. entry->source_id = source_id;
  34. entry->event_hash = PageLoad::kEntryNameHash;
  35. entry->metrics[PageLoad::kCpuTimeNameHash] = 10;
  36. entry->metrics[PageLoad::kIsNewBookmarkNameHash] = 20;
  37. entry->metrics[PageLoad::kIsNTPCustomLinkNameHash] = 30;
  38. return entry;
  39. }
  40. ukm::mojom::UkmEntryPtr GetSamplePaintPreviewEntry(
  41. ukm::SourceId source_id = kSourceId) {
  42. ukm::mojom::UkmEntryPtr entry = ukm::mojom::UkmEntry::New();
  43. entry->source_id = source_id;
  44. entry->event_hash = PaintPreviewCapture::kEntryNameHash;
  45. entry->metrics[PaintPreviewCapture::kBlinkCaptureTimeNameHash] = 5;
  46. entry->metrics[PaintPreviewCapture::kCompressedOnDiskSizeNameHash] = 15;
  47. return entry;
  48. }
  49. proto::SegmentationModelMetadata PageLoadModelMetadata() {
  50. proto::SegmentationModelMetadata metadata;
  51. metadata.set_time_unit(proto::TimeUnit::DAY);
  52. metadata.set_bucket_duration(42u);
  53. auto* feature = metadata.add_input_features();
  54. auto* sql_feature = feature->mutable_sql_feature();
  55. sql_feature->set_sql("SELECT COUNT(*) from metrics;");
  56. auto* ukm_event = sql_feature->mutable_signal_filter()->add_ukm_events();
  57. ukm_event->set_event_hash(PageLoad::kEntryNameHash);
  58. ukm_event->add_metric_hash_filter(PageLoad::kCpuTimeNameHash);
  59. ukm_event->add_metric_hash_filter(PageLoad::kIsNewBookmarkNameHash);
  60. return metadata;
  61. }
  62. proto::SegmentationModelMetadata PaintPreviewModelMetadata() {
  63. proto::SegmentationModelMetadata metadata;
  64. metadata.set_time_unit(proto::TimeUnit::DAY);
  65. metadata.set_bucket_duration(42u);
  66. auto* feature = metadata.add_input_features();
  67. auto* sql_feature = feature->mutable_sql_feature();
  68. sql_feature->set_sql("SELECT COUNT(*) from metrics;");
  69. auto* ukm_event2 = sql_feature->mutable_signal_filter()->add_ukm_events();
  70. ukm_event2->set_event_hash(PaintPreviewCapture::kEntryNameHash);
  71. ukm_event2->add_metric_hash_filter(
  72. PaintPreviewCapture::kBlinkCaptureTimeNameHash);
  73. return metadata;
  74. }
  75. } // namespace
  76. class TestServicesForPlatform : public SegmentationPlatformServiceTestBase {
  77. public:
  78. explicit TestServicesForPlatform(UkmDataManagerImpl* ukm_data_manager) {
  79. EXPECT_TRUE(profile_dir.CreateUniqueTempDir());
  80. history_service = history::CreateHistoryService(profile_dir.GetPath(),
  81. /*create_db=*/true);
  82. InitPlatform(ukm_data_manager, history_service.get());
  83. segment_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  84. signal_db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  85. segment_storage_config_db_->InitStatusCallback(
  86. leveldb_proto::Enums::InitStatus::kOK);
  87. segment_storage_config_db_->LoadCallback(true);
  88. // If initialization is succeeded, model execution scheduler should start
  89. // querying segment db.
  90. segment_db_->LoadCallback(true);
  91. }
  92. ~TestServicesForPlatform() override {
  93. DestroyPlatform();
  94. history_service.reset();
  95. }
  96. void AddModel(const proto::SegmentationModelMetadata& metadata) {
  97. auto& callback = model_provider_data_.model_providers_callbacks
  98. [SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE];
  99. callback.Run(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE, metadata,
  100. 0);
  101. segment_db_->GetCallback(true);
  102. segment_db_->UpdateCallback(true);
  103. segment_db_->LoadCallback(true);
  104. base::RunLoop().RunUntilIdle();
  105. }
  106. SegmentationPlatformServiceImpl& platform() {
  107. return *segmentation_platform_service_impl_;
  108. }
  109. void SaveSegmentResult(SegmentId segment_id,
  110. absl::optional<proto::PredictionResult> result) {
  111. const std::string key = base::NumberToString(static_cast<int>(segment_id));
  112. auto& segment_info = segment_db_entries_[key];
  113. // Assume that test already created the segment info, this method only
  114. // writes result.
  115. ASSERT_EQ(segment_info.segment_id(), segment_id);
  116. if (result) {
  117. *segment_info.mutable_prediction_result() = std::move(*result);
  118. } else {
  119. segment_info.clear_prediction_result();
  120. }
  121. }
  122. bool HasSegmentResult(SegmentId segment_id) {
  123. const std::string key = base::NumberToString(static_cast<int>(segment_id));
  124. const auto it = segment_db_entries_.find(key);
  125. if (it == segment_db_entries_.end())
  126. return false;
  127. return it->second.has_prediction_result();
  128. }
  129. base::ScopedTempDir profile_dir;
  130. std::unique_ptr<history::HistoryService> history_service;
  131. };
  132. class UkmDataManagerImplTest : public testing::Test {
  133. public:
  134. UkmDataManagerImplTest() = default;
  135. ~UkmDataManagerImplTest() override = default;
  136. void SetUp() override {
  137. SegmentationPlatformService::RegisterLocalStatePrefs(prefs_.registry());
  138. LocalStateHelper::GetInstance().Initialize(&prefs_);
  139. data_manager_ = std::make_unique<UkmDataManagerImpl>();
  140. ukm_recorder_ = std::make_unique<ukm::TestUkmRecorder>();
  141. auto ukm_db = std::make_unique<MockUkmDatabase>();
  142. ukm_database_ = ukm_db.get();
  143. ukm_observer_ = std::make_unique<UkmObserver>(ukm_recorder_.get());
  144. data_manager_->InitializeForTesting(std::move(ukm_db), ukm_observer_.get());
  145. }
  146. void TearDown() override {
  147. ukm_observer_.reset();
  148. ukm_recorder_.reset();
  149. ukm_database_ = nullptr;
  150. data_manager_.reset();
  151. }
  152. void RecordUkmAndWaitForDatabase(ukm::mojom::UkmEntryPtr entry) {}
  153. TestServicesForPlatform& CreatePlatform() {
  154. platform_services_.push_back(
  155. std::make_unique<TestServicesForPlatform>(data_manager_.get()));
  156. return *platform_services_.back();
  157. }
  158. void RemovePlatform(const TestServicesForPlatform* platform) {
  159. auto it = platform_services_.begin();
  160. while (it != platform_services_.end()) {
  161. if (it->get() == platform) {
  162. platform_services_.erase(it);
  163. return;
  164. }
  165. it++;
  166. }
  167. }
  168. protected:
  169. // Use system time to avoid history service expiration tasks to go into an
  170. // infinite loop.
  171. base::test::TaskEnvironment task_environment_{
  172. base::test::TaskEnvironment::TimeSource::SYSTEM_TIME};
  173. std::unique_ptr<UkmObserver> ukm_observer_;
  174. std::unique_ptr<ukm::TestUkmRecorder> ukm_recorder_;
  175. raw_ptr<MockUkmDatabase> ukm_database_;
  176. std::unique_ptr<UkmDataManagerImpl> data_manager_;
  177. std::vector<std::unique_ptr<TestServicesForPlatform>> platform_services_;
  178. std::vector<ukm::mojom::UkmEntryPtr> db_entries_;
  179. TestingPrefServiceSimple prefs_;
  180. };
  181. MATCHER_P(HasEventHash, event_hash, "") {
  182. return arg->event_hash == event_hash;
  183. }
  184. TEST_F(UkmDataManagerImplTest, HistoryNotification) {
  185. const GURL kUrl1 = GURL("https://www.url1.com/");
  186. const SegmentId kSegmentId =
  187. SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
  188. TestServicesForPlatform& platform1 = CreatePlatform();
  189. platform1.AddModel(PageLoadModelMetadata());
  190. proto::PredictionResult prediction_result;
  191. prediction_result.set_result(10);
  192. prediction_result.set_timestamp_us(1000);
  193. platform1.SaveSegmentResult(kSegmentId, prediction_result);
  194. EXPECT_TRUE(platform1.HasSegmentResult(kSegmentId));
  195. // Add a page to history and check that the notification is sent to
  196. // UkmDatabase. All notifications should be sent.
  197. base::RunLoop wait_for_add1;
  198. EXPECT_CALL(*ukm_database_, OnUrlValidated(kUrl1))
  199. .WillOnce([&wait_for_add1]() { wait_for_add1.QuitClosure().Run(); });
  200. platform1.history_service->AddPage(kUrl1, base::Time::Now(),
  201. history::VisitSource::SOURCE_BROWSED);
  202. wait_for_add1.Run();
  203. platform1.history_service->DeleteURLs({kUrl1});
  204. // Check that RemoveUrls() notification is sent to UkmDatabase.
  205. base::RunLoop wait_for_remove1;
  206. EXPECT_CALL(*ukm_database_, RemoveUrls(std::vector({kUrl1}), false))
  207. .WillOnce(
  208. [&wait_for_remove1]() { wait_for_remove1.QuitClosure().Run(); });
  209. wait_for_remove1.Run();
  210. // Run segment info callbacks that were posted to remove results.
  211. platform1.segment_db().GetCallback(true);
  212. platform1.segment_db().UpdateCallback(true);
  213. // History based segment results should be removed.
  214. EXPECT_FALSE(platform1.HasSegmentResult(
  215. SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE));
  216. RemovePlatform(&platform1);
  217. }
  218. TEST_F(UkmDataManagerImplTest, UkmSourceObservation) {
  219. const GURL kUrl1 = GURL("https://www.url1.com/");
  220. // Create a platform that observes PageLoad events.
  221. TestServicesForPlatform& platform1 = CreatePlatform();
  222. platform1.AddModel(PageLoadModelMetadata());
  223. // Source updates are notified to the database.
  224. base::RunLoop wait_for_source;
  225. EXPECT_CALL(*ukm_database_,
  226. UpdateUrlForUkmSource(kSourceId, kUrl1, /*is_validated=*/false))
  227. .WillOnce([&wait_for_source](ukm::SourceId source_id, const GURL& url,
  228. bool is_validated) {
  229. wait_for_source.QuitClosure().Run();
  230. });
  231. ukm_recorder_->UpdateSourceURL(kSourceId, kUrl1);
  232. wait_for_source.Run();
  233. RemovePlatform(&platform1);
  234. }
  235. TEST_F(UkmDataManagerImplTest, UkmEntryObservation) {
  236. const GURL kUrl1 = GURL("https://www.url1.com/");
  237. // UKM added before creating platform do not get recorded.
  238. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  239. ukm_recorder_->AddEntry(GetSamplePaintPreviewEntry());
  240. // Create a platform that observes PageLoad events.
  241. TestServicesForPlatform& platform1 = CreatePlatform();
  242. platform1.AddModel(PageLoadModelMetadata());
  243. // Not added since UkmDataManager is not notified for UKM observation.
  244. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  245. // Not added since it is not PageLoad event.
  246. ukm_recorder_->AddEntry(GetSamplePaintPreviewEntry());
  247. // PageLoad event gets recorded in UkmDatabase.
  248. base::RunLoop wait_for_record;
  249. EXPECT_CALL(*ukm_database_,
  250. StoreUkmEntry(HasEventHash(PageLoad::kEntryNameHash)))
  251. .WillOnce([&wait_for_record](ukm::mojom::UkmEntryPtr entry) {
  252. wait_for_record.QuitClosure().Run();
  253. });
  254. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  255. wait_for_record.Run();
  256. RemovePlatform(&platform1);
  257. }
  258. TEST_F(UkmDataManagerImplTest, UkmServiceCreatedBeforePlatform) {
  259. const GURL kUrl1 = GURL("https://www.url1.com/");
  260. TestServicesForPlatform& platform1 = CreatePlatform();
  261. platform1.AddModel(PageLoadModelMetadata());
  262. // Entry should be recorded, This step does not wait for the database record
  263. // here since it is waits for the next observation below.
  264. EXPECT_CALL(*ukm_database_,
  265. StoreUkmEntry(HasEventHash(PageLoad::kEntryNameHash)));
  266. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  267. // Source updates should be notified.
  268. base::RunLoop wait_for_source;
  269. EXPECT_CALL(*ukm_database_,
  270. UpdateUrlForUkmSource(kSourceId, kUrl1, /*is_validated=*/false))
  271. .WillOnce([&wait_for_source](ukm::SourceId source_id, const GURL& url,
  272. bool is_validated) {
  273. wait_for_source.QuitClosure().Run();
  274. });
  275. ukm_recorder_->UpdateSourceURL(kSourceId, kUrl1);
  276. wait_for_source.Run();
  277. RemovePlatform(&platform1);
  278. }
  279. TEST_F(UkmDataManagerImplTest, UrlValidationWithHistory) {
  280. const GURL kUrl1 = GURL("https://www.url1.com/");
  281. TestServicesForPlatform& platform1 = CreatePlatform();
  282. platform1.AddModel(PageLoadModelMetadata());
  283. // History page is added before source update.
  284. base::RunLoop wait_for_add1;
  285. EXPECT_CALL(*ukm_database_, OnUrlValidated(kUrl1))
  286. .WillOnce([&wait_for_add1]() { wait_for_add1.QuitClosure().Run(); });
  287. platform1.history_service->AddPage(kUrl1, base::Time::Now(),
  288. history::VisitSource::SOURCE_BROWSED);
  289. wait_for_add1.Run();
  290. // Source update should have a validated URL.
  291. base::RunLoop wait_for_source;
  292. EXPECT_CALL(*ukm_database_,
  293. UpdateUrlForUkmSource(kSourceId, kUrl1, /*is_validated=*/true))
  294. .WillOnce([&wait_for_source](ukm::SourceId source_id, const GURL& url,
  295. bool is_validated) {
  296. wait_for_source.QuitClosure().Run();
  297. });
  298. ukm_recorder_->UpdateSourceURL(kSourceId, kUrl1);
  299. wait_for_source.Run();
  300. RemovePlatform(&platform1);
  301. }
  302. TEST_F(UkmDataManagerImplTest, MultiplePlatforms) {
  303. const GURL kUrl1 = GURL("https://www.url1.com/");
  304. const GURL kUrl2 = GURL("https://www.url2.com/");
  305. // Create 2 platforms, and 1 of them observing UKM events.
  306. TestServicesForPlatform& platform1 = CreatePlatform();
  307. TestServicesForPlatform& platform3 = CreatePlatform();
  308. platform1.AddModel(PageLoadModelMetadata());
  309. // Only page load should be added to database.
  310. EXPECT_CALL(*ukm_database_,
  311. StoreUkmEntry(HasEventHash(PageLoad::kEntryNameHash)));
  312. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  313. ukm_recorder_->AddEntry(GetSamplePaintPreviewEntry());
  314. // Create another platform observing paint preview.
  315. TestServicesForPlatform& platform2 = CreatePlatform();
  316. platform2.AddModel(PaintPreviewModelMetadata());
  317. // Both should be added to database.
  318. EXPECT_CALL(*ukm_database_,
  319. StoreUkmEntry(HasEventHash(PageLoad::kEntryNameHash)));
  320. EXPECT_CALL(*ukm_database_,
  321. StoreUkmEntry(HasEventHash(PaintPreviewCapture::kEntryNameHash)));
  322. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  323. ukm_recorder_->AddEntry(GetSamplePaintPreviewEntry());
  324. // Sources should still be updated.
  325. base::RunLoop wait_for_source;
  326. EXPECT_CALL(*ukm_database_,
  327. UpdateUrlForUkmSource(kSourceId, kUrl1, /*is_validated=*/false))
  328. .WillOnce([&wait_for_source](ukm::SourceId source_id, const GURL& url,
  329. bool is_validated) {
  330. wait_for_source.QuitClosure().Run();
  331. });
  332. ukm_recorder_->UpdateSourceURL(kSourceId, kUrl1);
  333. wait_for_source.Run();
  334. // Removing platform1 does not stop observing metrics.
  335. RemovePlatform(&platform1);
  336. EXPECT_CALL(*ukm_database_,
  337. StoreUkmEntry(HasEventHash(PageLoad::kEntryNameHash)));
  338. EXPECT_CALL(*ukm_database_,
  339. StoreUkmEntry(HasEventHash(PaintPreviewCapture::kEntryNameHash)));
  340. ukm_recorder_->AddEntry(GetSamplePageLoadEntry());
  341. ukm_recorder_->AddEntry(GetSamplePaintPreviewEntry());
  342. // Update history service on one of the platforms, and the database should get
  343. // a validated URL.
  344. base::RunLoop wait_for_add1;
  345. EXPECT_CALL(*ukm_database_, OnUrlValidated(kUrl2))
  346. .WillOnce([&wait_for_add1]() { wait_for_add1.QuitClosure().Run(); });
  347. platform2.history_service->AddPage(kUrl2, base::Time::Now(),
  348. history::VisitSource::SOURCE_BROWSED);
  349. wait_for_add1.Run();
  350. base::RunLoop wait_for_source2;
  351. EXPECT_CALL(*ukm_database_,
  352. UpdateUrlForUkmSource(kSourceId2, kUrl2, /*is_validated=*/true))
  353. .WillOnce([&wait_for_source2](ukm::SourceId source_id, const GURL& url,
  354. bool is_validated) {
  355. wait_for_source2.QuitClosure().Run();
  356. });
  357. ukm_recorder_->UpdateSourceURL(kSourceId2, kUrl2);
  358. wait_for_source2.Run();
  359. RemovePlatform(&platform2);
  360. RemovePlatform(&platform3);
  361. }
  362. } // namespace segmentation_platform