segmentation_platform_service_test_base.cc 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/segmentation_platform/internal/segmentation_platform_service_test_base.h"
  5. #include "base/test/test_simple_task_runner.h"
  6. #include "base/time/time.h"
  7. #include "base/values.h"
  8. #include "components/prefs/scoped_user_pref_update.h"
  9. #include "components/segmentation_platform/internal/constants.h"
  10. #include "components/segmentation_platform/internal/database/segment_info_database.h"
  11. #include "components/segmentation_platform/internal/execution/mock_model_provider.h"
  12. #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
  13. #include "components/segmentation_platform/internal/ukm_data_manager.h"
  14. #include "components/segmentation_platform/public/config.h"
  15. #include "components/segmentation_platform/public/field_trial_register.h"
  16. namespace segmentation_platform {
  17. namespace {
  18. class MockFieldTrialRegister : public FieldTrialRegister {
  19. public:
  20. MOCK_METHOD2(RegisterFieldTrial,
  21. void(base::StringPiece trial_name,
  22. base::StringPiece group_name));
  23. MOCK_METHOD3(RegisterSubsegmentFieldTrialIfNeeded,
  24. void(base::StringPiece trial_name,
  25. proto::SegmentId segment_id,
  26. int subsegment_rank));
  27. };
  28. std::vector<std::unique_ptr<Config>> CreateTestConfigs() {
  29. std::vector<std::unique_ptr<Config>> configs;
  30. {
  31. std::unique_ptr<Config> config = std::make_unique<Config>();
  32. config->segmentation_key = kTestSegmentationKey1;
  33. config->segment_selection_ttl = base::Days(28);
  34. config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
  35. config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
  36. configs.push_back(std::move(config));
  37. }
  38. {
  39. std::unique_ptr<Config> config = std::make_unique<Config>();
  40. config->segmentation_key = kTestSegmentationKey2;
  41. config->segment_selection_ttl = base::Days(10);
  42. config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
  43. config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_VOICE);
  44. configs.push_back(std::move(config));
  45. }
  46. {
  47. std::unique_ptr<Config> config = std::make_unique<Config>();
  48. config->segmentation_key = kTestSegmentationKey3;
  49. config->segment_selection_ttl = base::Days(14);
  50. config->AddSegmentId(SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB);
  51. configs.push_back(std::move(config));
  52. }
  53. {
  54. // Empty config.
  55. std::unique_ptr<Config> config = std::make_unique<Config>();
  56. config->segmentation_key = "test_key";
  57. configs.push_back(std::move(config));
  58. }
  59. return configs;
  60. }
  61. } // namespace
// Segmentation keys for the first three configs built by CreateTestConfigs()
// (the fourth, empty config uses the literal "test_key").
// NOTE(review): these are referenced above this point in the file, so they are
// presumably also declared in the fixture header — confirm before changing
// their linkage or storage class.
constexpr char kTestSegmentationKey1[] = "test_key1";
constexpr char kTestSegmentationKey2[] = "test_key2";
constexpr char kTestSegmentationKey3[] = "test_key3";
// Construction and destruction do no work of their own; the platform is
// created explicitly by InitPlatform() and torn down by DestroyPlatform().
SegmentationPlatformServiceTestBase::SegmentationPlatformServiceTestBase() =
    default;
SegmentationPlatformServiceTestBase::~SegmentationPlatformServiceTestBase() =
    default;
  69. void SegmentationPlatformServiceTestBase::InitPlatform(
  70. UkmDataManager* ukm_data_manager,
  71. history::HistoryService* history_service) {
  72. task_runner_ = base::MakeRefCounted<base::TestSimpleTaskRunner>();
  73. auto segment_db =
  74. std::make_unique<leveldb_proto::test::FakeDB<proto::SegmentInfo>>(
  75. &segment_db_entries_);
  76. auto signal_db =
  77. std::make_unique<leveldb_proto::test::FakeDB<proto::SignalData>>(
  78. &signal_db_entries_);
  79. auto segment_storage_config_db = std::make_unique<
  80. leveldb_proto::test::FakeDB<proto::SignalStorageConfigs>>(
  81. &segment_storage_config_db_entries_);
  82. segment_db_ = segment_db.get();
  83. signal_db_ = signal_db.get();
  84. segment_storage_config_db_ = segment_storage_config_db.get();
  85. auto model_provider_factory =
  86. std::make_unique<TestModelProviderFactory>(&model_provider_data_);
  87. SegmentationPlatformService::RegisterProfilePrefs(pref_service_.registry());
  88. SetUpPrefs();
  89. std::vector<std::unique_ptr<Config>> configs = CreateTestConfigs();
  90. base::flat_set<SegmentId> all_segment_ids;
  91. for (const auto& config : configs) {
  92. for (const auto& segment_id : config->segments)
  93. all_segment_ids.insert(segment_id.first);
  94. }
  95. auto storage_service = std::make_unique<StorageService>(
  96. std::move(segment_db), std::move(signal_db),
  97. std::move(segment_storage_config_db), &test_clock_, ukm_data_manager,
  98. all_segment_ids, model_provider_factory.get());
  99. auto params = std::make_unique<SegmentationPlatformServiceImpl::InitParams>();
  100. params->storage_service = std::move(storage_service);
  101. params->model_provider =
  102. std::make_unique<TestModelProviderFactory>(&model_provider_data_);
  103. params->profile_prefs = &pref_service_;
  104. params->history_service = history_service;
  105. params->task_runner = task_runner_;
  106. params->clock = &test_clock_;
  107. params->configs = std::move(configs);
  108. params->field_trial_register = std::make_unique<MockFieldTrialRegister>();
  109. segmentation_platform_service_impl_ =
  110. std::make_unique<SegmentationPlatformServiceImpl>(std::move(params));
  111. }
  112. void SegmentationPlatformServiceTestBase::DestroyPlatform() {
  113. segmentation_platform_service_impl_.reset();
  114. // Allow for the SegmentationModelExecutor owned by SegmentationModelHandler
  115. // to be destroyed.
  116. task_runner_->RunUntilIdle();
  117. }
  118. void SegmentationPlatformServiceTestBase::SetUpPrefs() {
  119. DictionaryPrefUpdate update(&pref_service_, kSegmentationResultPref);
  120. base::Value* dictionary = update.Get();
  121. base::Value segmentation_result(base::Value::Type::DICTIONARY);
  122. segmentation_result.SetIntKey(
  123. "segment_id", SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE);
  124. dictionary->SetKey(kTestSegmentationKey1, std::move(segmentation_result));
  125. }
  126. std::vector<std::unique_ptr<Config>>
  127. SegmentationPlatformServiceTestBase::CreateConfigs() {
  128. return CreateTestConfigs();
  129. }
  130. } // namespace segmentation_platform