service_proxy_impl.cc 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. // Copyright 2021 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/segmentation_platform/internal/service_proxy_impl.h"
  5. #include <inttypes.h>
  6. #include <sstream>
  7. #include "base/observer_list.h"
  8. #include "base/strings/string_number_conversions.h"
  9. #include "base/strings/stringprintf.h"
  10. #include "base/time/time.h"
  11. #include "components/segmentation_platform/internal/database/signal_storage_config.h"
  12. #include "components/segmentation_platform/internal/metadata/metadata_utils.h"
  13. #include "components/segmentation_platform/internal/scheduler/execution_service.h"
  14. #include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
  15. #include "components/segmentation_platform/internal/selection/segment_selector_impl.h"
  16. #include "components/segmentation_platform/public/config.h"
  17. #include "base/logging.h"
  18. namespace segmentation_platform {
  19. namespace {
  20. std::string SegmentMetadataToString(const proto::SegmentInfo& segment_info) {
  21. if (!segment_info.has_model_metadata())
  22. return std::string();
  23. return "model_metadata: { " +
  24. metadata_utils::SegmetationModelMetadataToString(
  25. segment_info.model_metadata()) +
  26. " }";
  27. }
  28. std::string PredictionResultToString(const proto::SegmentInfo& segment_info) {
  29. if (!segment_info.has_prediction_result())
  30. return std::string();
  31. const auto prediction_result = segment_info.prediction_result();
  32. base::Time time;
  33. if (prediction_result.has_timestamp_us()) {
  34. time = base::Time::FromDeltaSinceWindowsEpoch(
  35. base::Microseconds(prediction_result.timestamp_us()));
  36. }
  37. std::ostringstream time_string;
  38. time_string << time;
  39. return base::StringPrintf(
  40. "result: %f, time: %s",
  41. prediction_result.has_result() ? prediction_result.result() : 0,
  42. time_string.str().c_str());
  43. }
  44. } // namespace
  45. ServiceProxyImpl::ServiceProxyImpl(
  46. SegmentInfoDatabase* segment_db,
  47. SignalStorageConfig* signal_storage_config,
  48. std::vector<std::unique_ptr<Config>>* configs,
  49. base::flat_map<std::string, std::unique_ptr<SegmentSelectorImpl>>*
  50. segment_selectors)
  51. : segment_db_(segment_db),
  52. signal_storage_config_(signal_storage_config),
  53. configs_(configs),
  54. segment_selectors_(segment_selectors) {}
  55. ServiceProxyImpl::~ServiceProxyImpl() = default;
  56. void ServiceProxyImpl::AddObserver(ServiceProxy::Observer* observer) {
  57. observers_.AddObserver(observer);
  58. }
  59. void ServiceProxyImpl::RemoveObserver(ServiceProxy::Observer* observer) {
  60. observers_.RemoveObserver(observer);
  61. }
  62. void ServiceProxyImpl::OnServiceStatusChanged(bool is_initialized,
  63. int status_flag) {
  64. bool changed = (is_service_initialized_ != is_initialized) ||
  65. (service_status_flag_ != status_flag);
  66. is_service_initialized_ = is_initialized;
  67. service_status_flag_ = status_flag;
  68. UpdateObservers(changed);
  69. }
  70. void ServiceProxyImpl::UpdateObservers(bool update_service_status) {
  71. if (observers_.empty())
  72. return;
  73. if (update_service_status) {
  74. for (auto& obs : observers_)
  75. obs.OnServiceStatusChanged(is_service_initialized_, service_status_flag_);
  76. }
  77. if (segment_db_ &&
  78. (static_cast<int>(ServiceStatus::kSegmentationInfoDbInitialized) &
  79. service_status_flag_)) {
  80. segment_db_->GetAllSegmentInfo(
  81. base::BindOnce(&ServiceProxyImpl::OnGetAllSegmentationInfo,
  82. weak_ptr_factory_.GetWeakPtr()));
  83. }
  84. }
  85. void ServiceProxyImpl::SetExecutionService(
  86. ExecutionService* model_execution_scheduler) {
  87. execution_service = model_execution_scheduler;
  88. }
  89. void ServiceProxyImpl::GetServiceStatus() {
  90. UpdateObservers(true /* update_service_status */);
  91. }
  92. void ServiceProxyImpl::ExecuteModel(SegmentId segment_id) {
  93. if (!execution_service ||
  94. segment_id == SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
  95. return;
  96. }
  97. segment_db_->GetSegmentInfo(
  98. segment_id,
  99. base::BindOnce(&ServiceProxyImpl::OnSegmentInfoFetchedForExecution,
  100. weak_ptr_factory_.GetWeakPtr()));
  101. }
  102. void ServiceProxyImpl::OnSegmentInfoFetchedForExecution(
  103. absl::optional<proto::SegmentInfo> segment_info) {
  104. if (!segment_info)
  105. return;
  106. auto request = std::make_unique<ExecutionRequest>();
  107. request->record_metrics_for_default = false;
  108. request->save_result_to_db = true;
  109. request->segment_info = &segment_info.value();
  110. execution_service->RequestModelExecution(std::move(request));
  111. }
  112. void ServiceProxyImpl::OverwriteResult(SegmentId segment_id, float result) {
  113. if (!execution_service)
  114. return;
  115. if (result < 0 || result > 1)
  116. return;
  117. if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
  118. execution_service->OverwriteModelExecutionResult(
  119. segment_id, std::make_pair(result, ModelExecutionStatus::kSuccess));
  120. }
  121. }
  122. void ServiceProxyImpl::SetSelectedSegment(const std::string& segmentation_key,
  123. SegmentId segment_id) {
  124. if (!segment_selectors_ ||
  125. segment_selectors_->find(segmentation_key) == segment_selectors_->end()) {
  126. return;
  127. }
  128. if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
  129. auto& selector = segment_selectors_->at(segmentation_key);
  130. selector->UpdateSelectedSegment(segment_id);
  131. }
  132. }
  133. void ServiceProxyImpl::OnGetAllSegmentationInfo(
  134. std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_info) {
  135. if (!configs_)
  136. return;
  137. // Convert the |segment_info| vector to a map for quick lookup.
  138. base::flat_map<SegmentId, proto::SegmentInfo> segment_ids;
  139. for (const auto& info : *segment_info) {
  140. segment_ids[info.first] = info.second;
  141. }
  142. std::vector<ServiceProxy::ClientInfo> result;
  143. for (const auto& config : *configs_) {
  144. SegmentId selected = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
  145. if (segment_selectors_ &&
  146. segment_selectors_->find(config->segmentation_key) !=
  147. segment_selectors_->end()) {
  148. absl::optional<proto::SegmentId> target =
  149. segment_selectors_->at(config->segmentation_key)
  150. ->GetCachedSegmentResult()
  151. .segment;
  152. if (target.has_value()) {
  153. selected = *target;
  154. }
  155. }
  156. result.emplace_back(config->segmentation_key, selected);
  157. for (const auto& segment_id : config->segments) {
  158. if (!segment_ids.contains(segment_id.first))
  159. continue;
  160. const auto& info = segment_ids[segment_id.first];
  161. result.back().segment_status.emplace_back(
  162. segment_id.first, SegmentMetadataToString(info),
  163. PredictionResultToString(info),
  164. signal_storage_config_
  165. ? signal_storage_config_->MeetsSignalCollectionRequirement(
  166. info.model_metadata())
  167. : false);
  168. }
  169. }
  170. for (auto& obs : observers_)
  171. obs.OnClientInfoAvailable(result);
  172. }
  173. void ServiceProxyImpl::OnModelExecutionCompleted(SegmentId segment_id) {
  174. // Update the observers with the new execution results.
  175. UpdateObservers(false);
  176. }
  177. } // namespace segmentation_platform