base_requests.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "google_apis/common/base_requests.h"
  5. #include <stddef.h>
  6. #include <algorithm>
  7. #include <memory>
  8. #include <utility>
  9. #include "base/bind.h"
  10. #include "base/json/json_reader.h"
  11. #include "base/strings/string_piece.h"
  12. #include "base/strings/stringprintf.h"
  13. #include "base/task/task_runner_util.h"
  14. #include "base/threading/thread_task_runner_handle.h"
  15. #include "base/values.h"
  16. #include "google_apis/common/request_sender.h"
  17. #include "google_apis/common/task_util.h"
  18. #include "net/base/load_flags.h"
  19. #include "net/http/http_util.h"
  20. #include "services/network/public/cpp/resource_request.h"
  21. #include "services/network/public/mojom/url_response_head.mojom.h"
  22. namespace {
  23. // Template for optional OAuth2 authorization HTTP header.
  24. const char kAuthorizationHeaderFormat[] = "Authorization: Bearer %s";
  25. // Template for GData API version HTTP header.
  26. const char kGDataVersionHeader[] = "GData-Version: 3.0";
  27. // Maximum number of attempts for re-authentication per request.
  28. const int kMaxReAuthenticateAttemptsPerRequest = 1;
  29. // Returns response headers as a string. Returns a warning message if
  30. // |response_head| does not contain a valid response. Used only for debugging.
  31. std::string GetResponseHeadersAsString(
  32. const network::mojom::URLResponseHead& response_head) {
  33. // Check that response code indicates response headers are valid (i.e. not
  34. // malformed) before we retrieve the headers.
  35. if (response_head.headers->response_code() == -1)
  36. return "Response headers are malformed!!";
  37. return response_head.headers->raw_headers();
  38. }
  39. } // namespace
  40. namespace google_apis {
  41. absl::optional<std::string> MapJsonErrorToReason(
  42. const std::string& error_body) {
  43. DVLOG(1) << error_body;
  44. const char kErrorKey[] = "error";
  45. const char kErrorErrorsKey[] = "errors";
  46. const char kErrorReasonKey[] = "reason";
  47. const char kErrorMessageKey[] = "message";
  48. const char kErrorCodeKey[] = "code";
  49. std::unique_ptr<const base::Value> value(google_apis::ParseJson(error_body));
  50. const base::Value::Dict* dictionary = value ? value->GetIfDict() : nullptr;
  51. const base::Value::Dict* error =
  52. dictionary ? dictionary->FindDict(kErrorKey) : nullptr;
  53. if (error) {
  54. // Get error message and code.
  55. const std::string* message = error->FindString(kErrorMessageKey);
  56. absl::optional<int> code = error->FindInt(kErrorCodeKey);
  57. DLOG(ERROR) << "code: " << (code ? code.value() : OTHER_ERROR)
  58. << ", message: " << (message ? *message : "");
  59. // Returns the reason of the first error.
  60. if (const base::Value::List* errors = error->FindList(kErrorErrorsKey)) {
  61. const base::Value& first_error = (*errors)[0];
  62. if (first_error.is_dict()) {
  63. const std::string* reason = first_error.FindStringKey(kErrorReasonKey);
  64. if (reason)
  65. return *reason;
  66. }
  67. }
  68. }
  69. return absl::nullopt;
  70. }
  71. std::unique_ptr<base::Value> ParseJson(const std::string& json) {
  72. auto parsed_json = base::JSONReader::ReadAndReturnValueWithError(json);
  73. if (!parsed_json.has_value()) {
  74. std::string trimmed_json;
  75. if (json.size() < 80) {
  76. trimmed_json = json;
  77. } else {
  78. // Take the first 50 and the last 10 bytes.
  79. trimmed_json =
  80. base::StringPrintf("%s [%s bytes] %s", json.substr(0, 50).c_str(),
  81. base::NumberToString(json.size() - 60).c_str(),
  82. json.substr(json.size() - 10).c_str());
  83. }
  84. LOG(WARNING) << "Error while parsing entry response: "
  85. << parsed_json.error().message << ", json:\n"
  86. << trimmed_json;
  87. return nullptr;
  88. }
  89. return base::Value::ToUniquePtrValue(std::move(*parsed_json));
  90. }
  91. UrlFetchRequestBase::UrlFetchRequestBase(
  92. RequestSender* sender,
  93. ProgressCallback upload_progress_callback,
  94. ProgressCallback download_progress_callback)
  95. : re_authenticate_count_(0),
  96. sender_(sender),
  97. upload_progress_callback_(upload_progress_callback),
  98. download_progress_callback_(download_progress_callback),
  99. response_content_length_(-1) {}
  100. UrlFetchRequestBase::~UrlFetchRequestBase() = default;
  101. void UrlFetchRequestBase::Start(const std::string& access_token,
  102. const std::string& custom_user_agent,
  103. ReAuthenticateCallback callback) {
  104. DCHECK(CalledOnValidThread());
  105. DCHECK(!access_token.empty());
  106. DCHECK(callback);
  107. DCHECK(re_authenticate_callback_.is_null());
  108. Prepare(base::BindOnce(&UrlFetchRequestBase::StartAfterPrepare,
  109. weak_ptr_factory_.GetWeakPtr(), access_token,
  110. custom_user_agent, std::move(callback)));
  111. }
  112. void UrlFetchRequestBase::Prepare(PrepareCallback callback) {
  113. DCHECK(CalledOnValidThread());
  114. DCHECK(!callback.is_null());
  115. std::move(callback).Run(HTTP_SUCCESS);
  116. }
  117. void UrlFetchRequestBase::StartAfterPrepare(
  118. const std::string& access_token,
  119. const std::string& custom_user_agent,
  120. ReAuthenticateCallback callback,
  121. ApiErrorCode code) {
  122. DCHECK(CalledOnValidThread());
  123. DCHECK(!access_token.empty());
  124. DCHECK(callback);
  125. DCHECK(re_authenticate_callback_.is_null());
  126. const GURL url = GetURL();
  127. ApiErrorCode error_code;
  128. if (IsSuccessfulErrorCode(code))
  129. error_code = code;
  130. else if (url.is_empty())
  131. error_code = OTHER_ERROR;
  132. else
  133. error_code = HTTP_SUCCESS;
  134. if (error_code != HTTP_SUCCESS) {
  135. // Error is found on generating the url or preparing the request. Send the
  136. // error message to the callback, and then return immediately without trying
  137. // to connect to the server. We need to call CompleteRequestWithError
  138. // asynchronously because client code does not assume result callback is
  139. // called synchronously.
  140. base::ThreadTaskRunnerHandle::Get()->PostTask(
  141. FROM_HERE,
  142. base::BindOnce(&UrlFetchRequestBase::CompleteRequestWithError,
  143. weak_ptr_factory_.GetWeakPtr(), error_code));
  144. return;
  145. }
  146. re_authenticate_callback_ = callback;
  147. DVLOG(1) << "URL: " << url.spec();
  148. auto request = std::make_unique<network::ResourceRequest>();
  149. request->url = url;
  150. request->method = GetRequestType();
  151. request->load_flags = net::LOAD_DISABLE_CACHE;
  152. request->credentials_mode = network::mojom::CredentialsMode::kOmit;
  153. // Add request headers.
  154. // Note that SetHeader clears the current headers and sets it to the passed-in
  155. // headers, so calling it for each header will result in only the last header
  156. // being set in request headers.
  157. if (!custom_user_agent.empty())
  158. request->headers.SetHeader("User-Agent", custom_user_agent);
  159. request->headers.AddHeaderFromString(kGDataVersionHeader);
  160. request->headers.AddHeaderFromString(
  161. base::StringPrintf(kAuthorizationHeaderFormat, access_token.data()));
  162. for (const auto& header : GetExtraRequestHeaders()) {
  163. request->headers.AddHeaderFromString(header);
  164. DVLOG(1) << "Extra header: " << header;
  165. }
  166. url_loader_ = network::SimpleURLLoader::Create(
  167. std::move(request), sender_->get_traffic_annotation_tag());
  168. url_loader_->SetAllowHttpErrorResults(true /* allow */);
  169. download_data_ = std::make_unique<DownloadData>(blocking_task_runner());
  170. GetOutputFilePath(&download_data_->output_file_path,
  171. &download_data_->get_content_callback);
  172. if (!download_data_->get_content_callback.is_null()) {
  173. download_data_->get_content_callback =
  174. CreateRelayCallback(download_data_->get_content_callback);
  175. }
  176. // Set upload data if available.
  177. std::string upload_content_type;
  178. std::string upload_content;
  179. if (GetContentData(&upload_content_type, &upload_content)) {
  180. url_loader_->AttachStringForUpload(upload_content, upload_content_type);
  181. } else {
  182. base::FilePath local_file_path;
  183. int64_t range_offset = 0;
  184. int64_t range_length = 0;
  185. if (GetContentFile(&local_file_path, &range_offset, &range_length,
  186. &upload_content_type)) {
  187. url_loader_->AttachFileForUpload(local_file_path, upload_content_type,
  188. range_offset, range_length);
  189. }
  190. }
  191. if (!upload_progress_callback_.is_null()) {
  192. url_loader_->SetOnUploadProgressCallback(base::BindRepeating(
  193. &UrlFetchRequestBase::OnUploadProgress, weak_ptr_factory_.GetWeakPtr(),
  194. upload_progress_callback_));
  195. }
  196. if (!download_progress_callback_.is_null()) {
  197. url_loader_->SetOnDownloadProgressCallback(base::BindRepeating(
  198. &UrlFetchRequestBase::OnDownloadProgress,
  199. weak_ptr_factory_.GetWeakPtr(), download_progress_callback_));
  200. }
  201. url_loader_->SetOnResponseStartedCallback(base::BindOnce(
  202. &UrlFetchRequestBase::OnResponseStarted, weak_ptr_factory_.GetWeakPtr()));
  203. url_loader_->DownloadAsStream(sender_->url_loader_factory(), this);
  204. }
  205. void UrlFetchRequestBase::OnDownloadProgress(ProgressCallback progress_callback,
  206. uint64_t current) {
  207. progress_callback.Run(static_cast<int64_t>(current),
  208. response_content_length_);
  209. }
  210. void UrlFetchRequestBase::OnUploadProgress(ProgressCallback progress_callback,
  211. uint64_t position,
  212. uint64_t total) {
  213. progress_callback.Run(static_cast<int64_t>(position),
  214. static_cast<int64_t>(total));
  215. }
  216. void UrlFetchRequestBase::OnResponseStarted(
  217. const GURL& final_url,
  218. const network::mojom::URLResponseHead& response_head) {
  219. DVLOG(1) << "Response headers:\n"
  220. << GetResponseHeadersAsString(response_head);
  221. response_content_length_ = response_head.content_length;
  222. }
  223. UrlFetchRequestBase::DownloadData::DownloadData(
  224. scoped_refptr<base::SequencedTaskRunner> blocking_task_runner)
  225. : blocking_task_runner_(blocking_task_runner) {}
  226. UrlFetchRequestBase::DownloadData::~DownloadData() {
  227. if (output_file.IsValid()) {
  228. blocking_task_runner_->PostTask(
  229. FROM_HERE,
  230. base::BindOnce([](base::File file) {}, std::move(output_file)));
  231. }
  232. }
  233. // static
  234. bool UrlFetchRequestBase::WriteFileData(std::string file_data,
  235. DownloadData* download_data) {
  236. if (!download_data->output_file.IsValid()) {
  237. download_data->output_file.Initialize(
  238. download_data->output_file_path,
  239. base::File::FLAG_CREATE_ALWAYS | base::File::FLAG_WRITE);
  240. if (!download_data->output_file.IsValid())
  241. return false;
  242. }
  243. if (download_data->output_file.WriteAtCurrentPos(file_data.data(),
  244. file_data.size()) == -1) {
  245. download_data->output_file.Close();
  246. return false;
  247. }
  248. // Even when writing response to a file save the first 1 MiB of the response
  249. // body so that it can be used to get error information in case of server side
  250. // errors. The size limit is to avoid consuming too much redundant memory.
  251. const size_t kMaxStringSize = 1024 * 1024;
  252. if (download_data->response_body.size() < kMaxStringSize) {
  253. size_t bytes_to_copy = std::min(
  254. file_data.size(), kMaxStringSize - download_data->response_body.size());
  255. download_data->response_body.append(file_data.data(), bytes_to_copy);
  256. }
  257. return true;
  258. }
  259. void UrlFetchRequestBase::OnWriteComplete(
  260. std::unique_ptr<DownloadData> download_data,
  261. base::OnceClosure resume,
  262. bool write_success) {
  263. download_data_ = std::move(download_data);
  264. if (!write_success) {
  265. error_code_ = OTHER_ERROR;
  266. url_loader_.reset(); // Cancel the request
  267. // No SimpleURLLoader to call OnComplete() so call it directly.
  268. OnComplete(false);
  269. return;
  270. }
  271. std::move(resume).Run();
  272. }
  273. void UrlFetchRequestBase::OnDataReceived(base::StringPiece string_piece,
  274. base::OnceClosure resume) {
  275. if (!download_data_->get_content_callback.is_null()) {
  276. download_data_->get_content_callback.Run(
  277. HTTP_SUCCESS, std::make_unique<std::string>(string_piece),
  278. download_data_->response_body.empty());
  279. }
  280. if (!download_data_->output_file_path.empty()) {
  281. DownloadData* download_data_ptr = download_data_.get();
  282. base::PostTaskAndReplyWithResult(
  283. blocking_task_runner(), FROM_HERE,
  284. base::BindOnce(&UrlFetchRequestBase::WriteFileData,
  285. std::string(string_piece), download_data_ptr),
  286. base::BindOnce(&UrlFetchRequestBase::OnWriteComplete,
  287. weak_ptr_factory_.GetWeakPtr(),
  288. std::move(download_data_), std::move(resume)));
  289. return;
  290. }
  291. download_data_->response_body.append(string_piece.data(),
  292. string_piece.size());
  293. std::move(resume).Run();
  294. }
  295. void UrlFetchRequestBase::OnComplete(bool success) {
  296. DCHECK(download_data_);
  297. blocking_task_runner()->PostTaskAndReply(
  298. FROM_HERE,
  299. base::BindOnce([](base::File file) {},
  300. std::move(download_data_->output_file)),
  301. base::BindOnce(&UrlFetchRequestBase::OnOutputFileClosed,
  302. weak_ptr_factory_.GetWeakPtr(), success));
  303. }
  304. void UrlFetchRequestBase::OnOutputFileClosed(bool success) {
  305. DCHECK(download_data_);
  306. const network::mojom::URLResponseHead* response_info;
  307. if (url_loader_) {
  308. response_info = url_loader_->ResponseInfo();
  309. if (response_info) {
  310. error_code_ =
  311. static_cast<ApiErrorCode>(response_info->headers->response_code());
  312. } else {
  313. error_code_ =
  314. NetError() == net::ERR_NETWORK_CHANGED ? NO_CONNECTION : OTHER_ERROR;
  315. }
  316. if (!download_data_->response_body.empty()) {
  317. if (!IsSuccessfulErrorCode(error_code_.value())) {
  318. absl::optional<std::string> reason =
  319. MapJsonErrorToReason(download_data_->response_body);
  320. if (reason.has_value())
  321. error_code_ = MapReasonToError(error_code_.value(), reason.value());
  322. }
  323. }
  324. } else {
  325. // If the request is cancelled then error_code_ must be set.
  326. DCHECK(error_code_.has_value());
  327. response_info = nullptr;
  328. }
  329. if (error_code_.value() == HTTP_UNAUTHORIZED) {
  330. if (++re_authenticate_count_ <= kMaxReAuthenticateAttemptsPerRequest) {
  331. // Reset re_authenticate_callback_ so Start() can be called again.
  332. std::move(re_authenticate_callback_).Run(this);
  333. return;
  334. }
  335. OnAuthFailed(GetErrorCode());
  336. return;
  337. }
  338. // Overridden by each specialization
  339. ProcessURLFetchResults(response_info,
  340. std::move(download_data_->output_file_path),
  341. std::move(download_data_->response_body));
  342. }
  343. void UrlFetchRequestBase::OnRetry(base::OnceClosure start_retry) {
  344. NOTREACHED();
  345. }
  346. std::string UrlFetchRequestBase::GetRequestType() const {
  347. return "GET";
  348. }
  349. std::vector<std::string> UrlFetchRequestBase::GetExtraRequestHeaders() const {
  350. return std::vector<std::string>();
  351. }
  352. bool UrlFetchRequestBase::GetContentData(std::string* upload_content_type,
  353. std::string* upload_content) {
  354. return false;
  355. }
  356. bool UrlFetchRequestBase::GetContentFile(base::FilePath* local_file_path,
  357. int64_t* range_offset,
  358. int64_t* range_length,
  359. std::string* upload_content_type) {
  360. return false;
  361. }
  362. void UrlFetchRequestBase::GetOutputFilePath(
  363. base::FilePath* local_file_path,
  364. GetContentCallback* get_content_callback) {}
  365. void UrlFetchRequestBase::Cancel() {
  366. url_loader_.reset();
  367. CompleteRequestWithError(CANCELLED);
  368. }
  369. ApiErrorCode UrlFetchRequestBase::GetErrorCode() const {
  370. DCHECK(error_code_.has_value()) << "GetErrorCode only valid after "
  371. "resource load complete.";
  372. return error_code_.value();
  373. }
  374. int UrlFetchRequestBase::NetError() const {
  375. if (!url_loader_) // If resource load cancelled?
  376. return net::ERR_FAILED;
  377. return url_loader_->NetError();
  378. }
  379. bool UrlFetchRequestBase::CalledOnValidThread() {
  380. return thread_checker_.CalledOnValidThread();
  381. }
  382. base::SequencedTaskRunner* UrlFetchRequestBase::blocking_task_runner() const {
  383. return sender_->blocking_task_runner();
  384. }
  385. void UrlFetchRequestBase::OnProcessURLFetchResultsComplete() {
  386. sender_->RequestFinished(this);
  387. }
  388. void UrlFetchRequestBase::CompleteRequestWithError(ApiErrorCode code) {
  389. RunCallbackOnPrematureFailure(code);
  390. sender_->RequestFinished(this);
  391. }
  392. void UrlFetchRequestBase::OnAuthFailed(ApiErrorCode code) {
  393. CompleteRequestWithError(code);
  394. }
  395. base::WeakPtr<AuthenticatedRequestInterface> UrlFetchRequestBase::GetWeakPtr() {
  396. return weak_ptr_factory_.GetWeakPtr();
  397. }
  398. } // namespace google_apis