segmentation_platform_service_impl.cc 11 KB


  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
  5. #include <string>
  6. #include "base/bind.h"
  7. #include "base/callback_helpers.h"
  8. #include "base/command_line.h"
  9. #include "base/files/file_path.h"
  10. #include "base/memory/scoped_refptr.h"
  11. #include "base/metrics/histogram_functions.h"
  12. #include "base/system/sys_info.h"
  13. #include "base/task/sequenced_task_runner.h"
  14. #include "base/threading/sequenced_task_runner_handle.h"
  15. #include "base/threading/thread_task_runner_handle.h"
  16. #include "base/time/clock.h"
  17. #include "components/prefs/pref_registry_simple.h"
  18. #include "components/segmentation_platform/internal/constants.h"
  19. #include "components/segmentation_platform/internal/database/storage_service.h"
  20. #include "components/segmentation_platform/internal/platform_options.h"
  21. #include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
  22. #include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h"
  23. #include "components/segmentation_platform/internal/selection/segment_score_provider.h"
  24. #include "components/segmentation_platform/internal/selection/segment_selector_impl.h"
  25. #include "components/segmentation_platform/internal/selection/segmentation_result_prefs.h"
  26. #include "components/segmentation_platform/internal/stats.h"
  27. #include "components/segmentation_platform/public/config.h"
  28. #include "components/segmentation_platform/public/field_trial_register.h"
  29. #include "components/segmentation_platform/public/input_context.h"
  30. #include "components/segmentation_platform/public/input_delegate.h"
  31. #include "components/segmentation_platform/public/model_provider.h"
  32. namespace segmentation_platform {
  33. namespace {
  34. using proto::SegmentId;
  35. base::flat_set<SegmentId> GetAllSegmentIds(
  36. const std::vector<std::unique_ptr<Config>>& configs) {
  37. base::flat_set<SegmentId> all_segment_ids;
  38. for (const auto& config : configs) {
  39. for (const auto& segment_id : config->segments)
  40. all_segment_ids.insert(segment_id.first);
  41. }
  42. return all_segment_ids;
  43. }
  44. } // namespace
  // InitParams carries the dependencies consumed by the
  // SegmentationPlatformServiceImpl constructor; its constructor and
  // destructor are defaulted out of line.
  SegmentationPlatformServiceImpl::InitParams::InitParams() = default;
  SegmentationPlatformServiceImpl::InitParams::~InitParams() = default;
  47. SegmentationPlatformServiceImpl::SegmentationPlatformServiceImpl(
  48. std::unique_ptr<InitParams> init_params)
  49. : model_provider_factory_(std::move(init_params->model_provider)),
  50. task_runner_(init_params->task_runner),
  51. clock_(init_params->clock.get()),
  52. platform_options_(PlatformOptions::CreateDefault()),
  53. input_delegate_holder_(std::move(init_params->input_delegate_holder)),
  54. configs_(std::move(init_params->configs)),
  55. all_segment_ids_(GetAllSegmentIds(configs_)),
  56. field_trial_register_(std::move(init_params->field_trial_register)),
  57. profile_prefs_(init_params->profile_prefs.get()),
  58. creation_time_(clock_->Now()) {
  59. base::UmaHistogramMediumTimes(
  60. "SegmentationPlatform.Init.ProcessCreationToServiceCreationLatency",
  61. base::SysInfo::Uptime());
  62. DCHECK(task_runner_);
  63. DCHECK(clock);
  64. DCHECK(init_params->profile_prefs);
  65. if (init_params->storage_service) {
  66. // Test only:
  67. storage_service_ = std::move(init_params->storage_service);
  68. } else {
  69. DCHECK(model_provider_factory_ && init_params->db_provider);
  70. DCHECK(!init_params->storage_dir.empty() && init_params->ukm_data_manager);
  71. storage_service_ = std::make_unique<StorageService>(
  72. init_params->storage_dir, init_params->db_provider,
  73. init_params->task_runner, init_params->clock,
  74. init_params->ukm_data_manager, all_segment_ids_,
  75. model_provider_factory_.get());
  76. }
  77. // Construct signal processors.
  78. signal_handler_.Initialize(
  79. storage_service_.get(), init_params->history_service, all_segment_ids_,
  80. base::BindRepeating(
  81. &SegmentationPlatformServiceImpl::OnModelRefreshNeeded,
  82. weak_ptr_factory_.GetWeakPtr()));
  83. for (const auto& config : configs_) {
  84. segment_selectors_[config->segmentation_key] =
  85. std::make_unique<SegmentSelectorImpl>(
  86. storage_service_->segment_info_database(),
  87. storage_service_->signal_storage_config(),
  88. init_params->profile_prefs, config.get(),
  89. field_trial_register_.get(), init_params->clock, platform_options_,
  90. storage_service_->default_model_manager());
  91. }
  92. proxy_ = std::make_unique<ServiceProxyImpl>(
  93. storage_service_->segment_info_database(),
  94. storage_service_->signal_storage_config(), &configs_,
  95. &segment_selectors_);
  96. segment_score_provider_ =
  97. SegmentScoreProvider::Create(storage_service_->segment_info_database());
  98. // Kick off initialization of all databases. Internal operations will be
  99. // delayed until they are all complete.
  100. storage_service_->Initialize(
  101. base::BindOnce(&SegmentationPlatformServiceImpl::OnDatabaseInitialized,
  102. weak_ptr_factory_.GetWeakPtr()));
  103. }
  // Tears down the signal handler that the constructor set up via
  // signal_handler_.Initialize().
  SegmentationPlatformServiceImpl::~SegmentationPlatformServiceImpl() {
    signal_handler_.TearDown();
  }
  107. void SegmentationPlatformServiceImpl::GetSelectedSegment(
  108. const std::string& segmentation_key,
  109. SegmentSelectionCallback callback) {
  110. CHECK(segment_selectors_.find(segmentation_key) != segment_selectors_.end());
  111. auto& selector = segment_selectors_.at(segmentation_key);
  112. selector->GetSelectedSegment(std::move(callback));
  113. }
  114. SegmentSelectionResult SegmentationPlatformServiceImpl::GetCachedSegmentResult(
  115. const std::string& segmentation_key) {
  116. CHECK(segment_selectors_.find(segmentation_key) != segment_selectors_.end());
  117. auto& selector = segment_selectors_.at(segmentation_key);
  118. return selector->GetCachedSegmentResult();
  119. }
  120. void SegmentationPlatformServiceImpl::GetSelectedSegmentOnDemand(
  121. const std::string& segmentation_key,
  122. scoped_refptr<InputContext> input_context,
  123. SegmentSelectionCallback callback) {
  124. if (!storage_initialized_) {
  125. // If the platform isn't fully initialized, cache the input arguments to run
  126. // later.
  127. pending_actions_.push_back(base::BindOnce(
  128. &SegmentationPlatformServiceImpl::GetSelectedSegmentOnDemand,
  129. weak_ptr_factory_.GetWeakPtr(), segmentation_key,
  130. std::move(input_context), std::move(callback)));
  131. return;
  132. }
  133. CHECK(segment_selectors_.find(segmentation_key) != segment_selectors_.end());
  134. auto& selector = segment_selectors_.at(segmentation_key);
  135. // Wrap callback to record metrics.
  136. auto wrapped_callback = base::BindOnce(
  137. [](const std::string& segmentation_key, base::Time start_time,
  138. SegmentSelectionCallback callback,
  139. const SegmentSelectionResult& result) -> void {
  140. stats::RecordOnDemandSegmentSelectionDuration(
  141. segmentation_key, result, base::Time::Now() - start_time);
  142. std::move(callback).Run(result);
  143. },
  144. segmentation_key, base::Time::Now(), std::move(callback));
  145. selector->GetSelectedSegmentOnDemand(input_context,
  146. std::move(wrapped_callback));
  147. }
  // Forwards whether signal collection is allowed to the signal handler.
  void SegmentationPlatformServiceImpl::EnableMetrics(
      bool signal_collection_allowed) {
    signal_handler_.EnableMetrics(signal_collection_allowed);
  }
  // Returns a non-owning pointer to the proxy created in the constructor; the
  // service keeps ownership through proxy_.
  ServiceProxy* SegmentationPlatformServiceImpl::GetServiceProxy() {
    return proxy_.get();
  }
  // Returns true once OnDatabaseInitialized() has run. Note that the flag is
  // set there regardless of whether database initialization succeeded.
  bool SegmentationPlatformServiceImpl::IsPlatformInitialized() {
    return storage_initialized_;
  }
  // Completion callback for storage_service_->Initialize(), started in the
  // constructor. On failure, records a kDBInitFailure for every config and
  // stops. On success, finishes platform setup: initializes the score provider
  // and execution service, hands the execution service to the proxy and the
  // selectors, replays requests queued before initialization, starts the
  // daily task cycle, and records creation-to-initialization latency.
  void SegmentationPlatformServiceImpl::OnDatabaseInitialized(bool success) {
    // Set before checking |success|, so the proxy (via
    // OnServiceStatusChanged) and IsPlatformInitialized() see "initialized"
    // even when database initialization failed.
    storage_initialized_ = true;
    OnServiceStatusChanged();
    if (!success) {
      for (const auto& config : configs_) {
        stats::RecordSegmentSelectionFailure(
            config->segmentation_key,
            stats::SegmentationSelectionFailureReason::kDBInitFailure);
      }
      // NOTE(review): this early return leaves pending_actions_ unflushed, so
      // requests queued by GetSelectedSegmentOnDemand() before initialization
      // never run their callbacks. Requests arriving after this point go
      // straight to the selectors, which never received
      // OnPlatformInitialized(). Confirm the intended failure behavior.
      return;
    }
    segment_score_provider_->Initialize(base::DoNothing());
    signal_handler_.OnSignalListUpdated();
    // Every selector and the proxy observe model execution results.
    std::vector<ModelExecutionSchedulerImpl::Observer*> observers;
    for (auto& key_and_selector : segment_selectors_)
      observers.push_back(key_and_selector.second.get());
    observers.push_back(proxy_.get());
    // input_delegate_holder_ is moved out here, so this setup must not run
    // more than once.
    execution_service_.Initialize(
        storage_service_.get(), &signal_handler_, clock_,
        base::BindRepeating(
            &SegmentationPlatformServiceImpl::OnSegmentationModelUpdated,
            weak_ptr_factory_.GetWeakPtr()),
        task_runner_, all_segment_ids_, model_provider_factory_.get(),
        std::move(observers), platform_options_,
        std::move(input_delegate_holder_), &configs_, profile_prefs_);
    proxy_->SetExecutionService(&execution_service_);
    for (auto& selector : segment_selectors_) {
      selector.second->OnPlatformInitialized(&execution_service_);
    }
    // Run any method calls that were received during initialization. Each is
    // posted as its own task, so it runs after this method returns.
    while (!pending_actions_.empty()) {
      auto callback = std::move(pending_actions_.front());
      pending_actions_.pop_front();
      base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
                                                    std::move(callback));
    }
    // Run any daily maintenance tasks. RunDailyTasks() re-posts itself, so
    // this also starts the recurring daily cycle.
    RunDailyTasks(/*is_startup=*/true);
    init_time_ = clock_->Now();
    base::UmaHistogramMediumTimes(
        "SegmentationPlatform.Init.CreationToInitializationLatency",
        init_time_ - creation_time_);
  }
  201. void SegmentationPlatformServiceImpl::OnSegmentationModelUpdated(
  202. proto::SegmentInfo segment_info) {
  203. DCHECK(metadata_utils::ValidateSegmentInfoMetadataAndFeatures(segment_info) ==
  204. metadata_utils::ValidationResult::kValidationSuccess);
  205. signal_handler_.OnSignalListUpdated();
  206. execution_service_.OnNewModelInfoReady(segment_info);
  207. // Update the service status for proxy.
  208. base::ThreadTaskRunnerHandle::Get()->PostTask(
  209. FROM_HERE,
  210. base::BindOnce(&SegmentationPlatformServiceImpl::OnServiceStatusChanged,
  211. weak_ptr_factory_.GetWeakPtr()));
  212. }
  // Callback handed to signal_handler_.Initialize() in the constructor; asks
  // the execution service to refresh model results.
  void SegmentationPlatformServiceImpl::OnModelRefreshNeeded() {
    execution_service_.RefreshModelResults();
  }
  216. void SegmentationPlatformServiceImpl::OnServiceStatusChanged() {
  217. proxy_->OnServiceStatusChanged(storage_initialized_,
  218. storage_service_->GetServiceStatus());
  219. }
  220. void SegmentationPlatformServiceImpl::RunDailyTasks(bool is_startup) {
  221. execution_service_.RunDailyTasks(is_startup);
  222. storage_service_->ExecuteDatabaseMaintenanceTasks(is_startup);
  223. base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
  224. FROM_HERE,
  225. base::BindOnce(&SegmentationPlatformServiceImpl::RunDailyTasks,
  226. weak_ptr_factory_.GetWeakPtr(), /*is_startup=*/false),
  227. base::Days(1));
  228. }
  // static
  // Registers the profile-scoped dictionary pref kSegmentationResultPref
  // (presumably backing the cached selection results managed in
  // segmentation_result_prefs — TODO confirm).
  void SegmentationPlatformService::RegisterProfilePrefs(
      PrefRegistrySimple* registry) {
    registry->RegisterDictionaryPref(kSegmentationResultPref);
  }
  234. // static
  235. void SegmentationPlatformService::RegisterLocalStatePrefs(
  236. PrefRegistrySimple* registry) {
  237. registry->RegisterTimePref(kSegmentationUkmMostRecentAllowedTimeKey,
  238. base::Time());
  239. registry->RegisterTimePref(kSegmentationLastCollectionTimePref, base::Time());
  240. }
  241. } // namespace segmentation_platform