prediction_model_fetcher_impl.cc 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/optimization_guide/core/prediction_model_fetcher_impl.h"
  5. #include <memory>
  6. #include <string>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/feature_list.h"
  10. #include "base/metrics/histogram_functions.h"
  11. #include "base/metrics/histogram_macros.h"
  12. #include "components/optimization_guide/core/model_util.h"
  13. #include "components/optimization_guide/core/optimization_guide_features.h"
  14. #include "components/optimization_guide/core/optimization_guide_util.h"
  15. #include "components/optimization_guide/proto/models.pb.h"
  16. #include "components/variations/net/variations_http_headers.h"
  17. #include "net/base/load_flags.h"
  18. #include "net/base/url_util.h"
  19. #include "net/http/http_request_headers.h"
  20. #include "net/http/http_response_headers.h"
  21. #include "net/http/http_status_code.h"
  22. #include "net/traffic_annotation/network_traffic_annotation.h"
  23. #include "services/network/public/cpp/resource_request.h"
  24. #include "services/network/public/cpp/shared_url_loader_factory.h"
  25. #include "services/network/public/cpp/simple_url_loader.h"
  26. #include "services/network/public/mojom/url_response_head.mojom.h"
  27. namespace optimization_guide {
  28. PredictionModelFetcherImpl::PredictionModelFetcherImpl(
  29. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  30. const GURL& optimization_guide_service_get_models_url)
  31. : optimization_guide_service_get_models_url_(
  32. net::AppendOrReplaceQueryParameter(
  33. optimization_guide_service_get_models_url,
  34. "key",
  35. optimization_guide::features::
  36. GetOptimizationGuideServiceAPIKey())),
  37. url_loader_factory_(url_loader_factory) {
  38. CHECK(optimization_guide_service_get_models_url_.SchemeIs(url::kHttpsScheme));
  39. }
  40. PredictionModelFetcherImpl::~PredictionModelFetcherImpl() = default;
  41. bool PredictionModelFetcherImpl::FetchOptimizationGuideServiceModels(
  42. const std::vector<proto::ModelInfo>& models_request_info,
  43. proto::RequestContext request_context,
  44. const std::string& locale,
  45. ModelsFetchedCallback models_fetched_callback) {
  46. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  47. if (url_loader_)
  48. return false;
  49. // If there are no models to request, do not make a GetModelsRequest.
  50. if (models_request_info.empty()) {
  51. std::move(models_fetched_callback).Run(absl::nullopt);
  52. return false;
  53. }
  54. pending_models_request_ =
  55. std::make_unique<optimization_guide::proto::GetModelsRequest>();
  56. pending_models_request_->set_request_context(request_context);
  57. pending_models_request_->set_locale(locale);
  58. *pending_models_request_->mutable_origin_info() =
  59. optimization_guide::GetClientOriginInfo();
  60. for (const auto& model_request_info : models_request_info)
  61. *pending_models_request_->add_requested_models() = model_request_info;
  62. std::string serialized_request;
  63. pending_models_request_->SerializeToString(&serialized_request);
  64. net::NetworkTrafficAnnotationTag traffic_annotation =
  65. net::DefineNetworkTrafficAnnotation("optimization_guide_model",
  66. R"(
  67. semantics {
  68. sender: "Optimization Guide"
  69. description:
  70. "Requests the updated set of machine learning models from the "
  71. "Optimization Guide Service that are applicable to the current "
  72. "client version."
  73. trigger:
  74. "Requested at the beginning of each session if there are features "
  75. "enabled by the current client version that require machine "
  76. "learning models."
  77. data: "A list of models supported by the client."
  78. destination: GOOGLE_OWNED_SERVICE
  79. }
  80. policy {
  81. cookies_allowed: NO
  82. setting: "This feature cannot be disabled."
  83. chrome_policy {
  84. ComponentUpdatesEnabled {
  85. policy_options {mode: MANDATORY}
  86. ComponentUpdatesEnabled: false
  87. }
  88. }
  89. })");
  90. auto resource_request = std::make_unique<network::ResourceRequest>();
  91. resource_request->url = optimization_guide_service_get_models_url_;
  92. // POST request for providing the GetModelsRequest proto to the remote
  93. // Optimization Guide Service.
  94. resource_request->method = "POST";
  95. resource_request->credentials_mode = network::mojom::CredentialsMode::kOmit;
  96. url_loader_ = variations::CreateSimpleURLLoaderWithVariationsHeader(
  97. std::move(resource_request),
  98. // This is always InIncognito::kNo as the OptimizationGuideKeyedService is
  99. // not enabled on incognito sessions and is rechecked before each fetch.
  100. variations::InIncognito::kNo, variations::SignedIn::kNo,
  101. traffic_annotation);
  102. url_loader_->AttachStringForUpload(serialized_request,
  103. "application/x-protobuf");
  104. url_loader_->DownloadToStringOfUnboundedSizeUntilCrashAndDie(
  105. url_loader_factory_.get(),
  106. base::BindOnce(&PredictionModelFetcherImpl::OnURLLoadComplete,
  107. base::Unretained(this)));
  108. models_fetched_callback_ = std::move(models_fetched_callback);
  109. return true;
  110. }
  111. void PredictionModelFetcherImpl::HandleResponse(
  112. const std::string& get_models_response_data,
  113. int net_status,
  114. int response_code) {
  115. std::unique_ptr<optimization_guide::proto::GetModelsResponse>
  116. get_models_response =
  117. std::make_unique<optimization_guide::proto::GetModelsResponse>();
  118. UMA_HISTOGRAM_ENUMERATION(
  119. "OptimizationGuide.PredictionModelFetcher."
  120. "GetModelsResponse.Status",
  121. static_cast<net::HttpStatusCode>(response_code),
  122. net::HTTP_VERSION_NOT_SUPPORTED);
  123. // Net error codes are negative but histogram enums must be positive.
  124. base::UmaHistogramSparse(
  125. "OptimizationGuide.PredictionModelFetcher."
  126. "GetModelsResponse.NetErrorCode",
  127. -net_status);
  128. for (const auto& model_info : pending_models_request_->requested_models()) {
  129. if (response_code >= 0 &&
  130. response_code <= net::HTTP_VERSION_NOT_SUPPORTED) {
  131. base::UmaHistogramEnumeration(
  132. "OptimizationGuide.PredictionModelFetcher."
  133. "GetModelsResponse.Status." +
  134. optimization_guide::GetStringNameForOptimizationTarget(
  135. model_info.optimization_target()),
  136. static_cast<net::HttpStatusCode>(response_code),
  137. net::HTTP_VERSION_NOT_SUPPORTED);
  138. }
  139. // Net error codes are negative but histogram enums must be positive.
  140. base::UmaHistogramSparse(
  141. "OptimizationGuide.PredictionModelFetcher."
  142. "GetModelsResponse.NetErrorCode." +
  143. optimization_guide::GetStringNameForOptimizationTarget(
  144. model_info.optimization_target()),
  145. -net_status);
  146. }
  147. if (net_status == net::OK && response_code == net::HTTP_OK &&
  148. get_models_response->ParseFromString(get_models_response_data)) {
  149. std::move(models_fetched_callback_).Run(std::move(get_models_response));
  150. } else {
  151. std::move(models_fetched_callback_).Run(absl::nullopt);
  152. }
  153. }
  154. void PredictionModelFetcherImpl::OnURLLoadComplete(
  155. std::unique_ptr<std::string> response_body) {
  156. int response_code = -1;
  157. if (url_loader_->ResponseInfo() && url_loader_->ResponseInfo()->headers) {
  158. response_code = url_loader_->ResponseInfo()->headers->response_code();
  159. }
  160. HandleResponse(response_body ? *response_body : "", url_loader_->NetError(),
  161. response_code);
  162. url_loader_.reset();
  163. }
  164. } // namespace optimization_guide