oauth_token_getter_impl.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Copyright 2014 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "remoting/base/oauth_token_getter_impl.h"
  5. #include <memory>
  6. #include <utility>
  7. #include "base/bind.h"
  8. #include "base/callback.h"
  9. #include "base/containers/queue.h"
  10. #include "base/strings/string_util.h"
  11. #include "google_apis/google_api_keys.h"
  12. #include "remoting/base/logging.h"
  13. #include "services/network/public/cpp/shared_url_loader_factory.h"
  14. namespace remoting {
  15. namespace {
  16. // Maximum number of retries on network/500 errors.
  17. const int kMaxRetries = 3;
  18. // Time when we we try to update OAuth token before its expiration.
  19. const int kTokenUpdateTimeBeforeExpirySeconds = 120;
  20. // Max time we wait for the response before giving up.
  21. constexpr base::TimeDelta kResponseTimeoutDuration = base::Seconds(30);
  22. } // namespace
  23. OAuthTokenGetterImpl::OAuthTokenGetterImpl(
  24. std::unique_ptr<OAuthIntermediateCredentials> intermediate_credentials,
  25. const OAuthTokenGetter::CredentialsUpdatedCallback& on_credentials_update,
  26. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  27. bool auto_refresh)
  28. : url_loader_factory_(url_loader_factory),
  29. intermediate_credentials_(std::move(intermediate_credentials)),
  30. gaia_oauth_client_(new gaia::GaiaOAuthClient(url_loader_factory)),
  31. credentials_updated_callback_(on_credentials_update) {
  32. if (auto_refresh) {
  33. refresh_timer_ = std::make_unique<base::OneShotTimer>();
  34. }
  35. }
  36. OAuthTokenGetterImpl::OAuthTokenGetterImpl(
  37. std::unique_ptr<OAuthAuthorizationCredentials> authorization_credentials,
  38. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  39. bool auto_refresh)
  40. : url_loader_factory_(url_loader_factory),
  41. authorization_credentials_(std::move(authorization_credentials)),
  42. gaia_oauth_client_(new gaia::GaiaOAuthClient(url_loader_factory)) {
  43. if (auto_refresh) {
  44. refresh_timer_ = std::make_unique<base::OneShotTimer>();
  45. }
  46. }
  47. OAuthTokenGetterImpl::~OAuthTokenGetterImpl() {
  48. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  49. }
  50. void OAuthTokenGetterImpl::OnGetTokensResponse(const std::string& refresh_token,
  51. const std::string& access_token,
  52. int expires_seconds) {
  53. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  54. DCHECK(intermediate_credentials_);
  55. VLOG(1) << "Received OAuth tokens.";
  56. // Update the access token and any other auto-update timers.
  57. UpdateAccessToken(access_token, expires_seconds);
  58. // Keep the refresh token in the authorization_credentials.
  59. authorization_credentials_ =
  60. std::make_unique<OAuthTokenGetter::OAuthAuthorizationCredentials>(
  61. std::string(), refresh_token,
  62. intermediate_credentials_->is_service_account);
  63. // Clear out the one time use token.
  64. intermediate_credentials_.reset();
  65. // At this point we don't know the email address so we need to fetch it.
  66. email_discovery_ = true;
  67. gaia_oauth_client_->GetUserEmail(access_token, kMaxRetries, this);
  68. }
  69. void OAuthTokenGetterImpl::OnRefreshTokenResponse(
  70. const std::string& access_token,
  71. int expires_seconds) {
  72. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  73. DCHECK(authorization_credentials_);
  74. VLOG(1) << "Received OAuth token.";
  75. // Update the access token and any other auto-update timers.
  76. UpdateAccessToken(access_token, expires_seconds);
  77. if (!authorization_credentials_->is_service_account && !email_verified_) {
  78. gaia_oauth_client_->GetUserEmail(access_token, kMaxRetries, this);
  79. } else {
  80. NotifyTokenCallbacks(OAuthTokenGetterImpl::SUCCESS,
  81. authorization_credentials_->login,
  82. oauth_access_token_);
  83. }
  84. }
  85. void OAuthTokenGetterImpl::OnGetUserEmailResponse(
  86. const std::string& user_email) {
  87. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  88. DCHECK(authorization_credentials_);
  89. VLOG(1) << "Received user info.";
  90. if (email_discovery_) {
  91. authorization_credentials_->login = user_email;
  92. email_discovery_ = false;
  93. NotifyUpdatedCallbacks(authorization_credentials_->login,
  94. authorization_credentials_->refresh_token);
  95. } else if (user_email != authorization_credentials_->login) {
  96. LOG(ERROR) << "OAuth token and email address do not refer to "
  97. "the same account.";
  98. OnOAuthError();
  99. return;
  100. }
  101. email_verified_ = true;
  102. NotifyTokenCallbacks(OAuthTokenGetterImpl::SUCCESS,
  103. authorization_credentials_->login, oauth_access_token_);
  104. }
  105. void OAuthTokenGetterImpl::UpdateAccessToken(const std::string& access_token,
  106. int expires_seconds) {
  107. oauth_access_token_ = access_token;
  108. base::TimeDelta token_expiration =
  109. base::Seconds(expires_seconds) -
  110. base::Seconds(kTokenUpdateTimeBeforeExpirySeconds);
  111. access_token_expiry_time_ = base::Time::Now() + token_expiration;
  112. if (refresh_timer_) {
  113. refresh_timer_->Stop();
  114. refresh_timer_->Start(FROM_HERE, token_expiration, this,
  115. &OAuthTokenGetterImpl::RefreshAccessToken);
  116. }
  117. }
  118. void OAuthTokenGetterImpl::NotifyTokenCallbacks(
  119. Status status,
  120. const std::string& user_email,
  121. const std::string& access_token) {
  122. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  123. SetResponsePending(false);
  124. base::queue<TokenCallback> callbacks;
  125. callbacks.swap(pending_callbacks_);
  126. while (!callbacks.empty()) {
  127. std::move(callbacks.front()).Run(status, user_email, access_token);
  128. callbacks.pop();
  129. }
  130. }
  131. void OAuthTokenGetterImpl::NotifyUpdatedCallbacks(
  132. const std::string& user_email,
  133. const std::string& refresh_token) {
  134. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  135. if (credentials_updated_callback_) {
  136. credentials_updated_callback_.Run(user_email, refresh_token);
  137. }
  138. }
  139. void OAuthTokenGetterImpl::OnOAuthError() {
  140. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  141. LOG(ERROR) << "OAuth: invalid credentials.";
  142. // Throw away invalid credentials and force a refresh.
  143. oauth_access_token_.clear();
  144. access_token_expiry_time_ = base::Time();
  145. email_verified_ = false;
  146. NotifyTokenCallbacks(OAuthTokenGetterImpl::AUTH_ERROR, std::string(),
  147. std::string());
  148. }
  149. void OAuthTokenGetterImpl::OnNetworkError(int response_code) {
  150. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  151. LOG(ERROR) << "Network error when trying to update OAuth token: "
  152. << response_code;
  153. NotifyTokenCallbacks(OAuthTokenGetterImpl::NETWORK_ERROR, std::string(),
  154. std::string());
  155. }
  156. void OAuthTokenGetterImpl::CallWithToken(TokenCallback on_access_token) {
  157. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  158. pending_callbacks_.push(std::move(on_access_token));
  159. if (intermediate_credentials_) {
  160. if (!IsResponsePending()) {
  161. GetOauthTokensFromAuthCode();
  162. }
  163. } else {
  164. bool need_new_auth_token =
  165. access_token_expiry_time_.is_null() ||
  166. base::Time::Now() >= access_token_expiry_time_ ||
  167. (!authorization_credentials_->is_service_account && !email_verified_);
  168. if (need_new_auth_token) {
  169. if (!IsResponsePending()) {
  170. RefreshAccessToken();
  171. }
  172. } else {
  173. // If IsResponsePending() is true here, |on_access_token| will be called
  174. // when the response is received.
  175. if (!IsResponsePending()) {
  176. NotifyTokenCallbacks(OAuthTokenGetterImpl::SUCCESS,
  177. authorization_credentials_->login,
  178. oauth_access_token_);
  179. }
  180. }
  181. }
  182. }
  183. void OAuthTokenGetterImpl::InvalidateCache() {
  184. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  185. access_token_expiry_time_ = base::Time();
  186. }
  187. base::WeakPtr<OAuthTokenGetterImpl> OAuthTokenGetterImpl::GetWeakPtr() {
  188. return weak_factory_.GetWeakPtr();
  189. }
  190. void OAuthTokenGetterImpl::GetOauthTokensFromAuthCode() {
  191. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  192. VLOG(1) << "Fetching OAuth token from Auth Code.";
  193. DCHECK(!IsResponsePending());
  194. // Service accounts use different API keys, as they use the client app flow.
  195. google_apis::OAuth2Client oauth2_client =
  196. intermediate_credentials_->is_service_account
  197. ? google_apis::CLIENT_REMOTING_HOST
  198. : google_apis::CLIENT_REMOTING;
  199. // For the case of fetching an OAuth token from a one-time-use code, the
  200. // caller should provide a redirect URI.
  201. std::string redirect_uri = intermediate_credentials_->oauth_redirect_uri;
  202. DCHECK(!redirect_uri.empty());
  203. gaia::OAuthClientInfo client_info = {
  204. google_apis::GetOAuth2ClientID(oauth2_client),
  205. google_apis::GetOAuth2ClientSecret(oauth2_client), redirect_uri};
  206. SetResponsePending(true);
  207. gaia_oauth_client_->GetTokensFromAuthCode(
  208. client_info, intermediate_credentials_->authorization_code, kMaxRetries,
  209. this);
  210. }
  211. void OAuthTokenGetterImpl::RefreshAccessToken() {
  212. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  213. VLOG(1) << "Refreshing OAuth Access token.";
  214. DCHECK(!IsResponsePending());
  215. // Service accounts use different API keys, as they use the client app flow.
  216. google_apis::OAuth2Client oauth2_client =
  217. authorization_credentials_->is_service_account
  218. ? google_apis::CLIENT_REMOTING_HOST
  219. : google_apis::CLIENT_REMOTING;
  220. gaia::OAuthClientInfo client_info = {
  221. google_apis::GetOAuth2ClientID(oauth2_client),
  222. google_apis::GetOAuth2ClientSecret(oauth2_client),
  223. // Redirect URL is only used when getting tokens from auth code. It
  224. // is not required when getting access tokens from refresh tokens.
  225. ""};
  226. SetResponsePending(true);
  227. std::vector<std::string> empty_scope_list; // Use scope from refresh token.
  228. gaia_oauth_client_->RefreshToken(client_info,
  229. authorization_credentials_->refresh_token,
  230. empty_scope_list, kMaxRetries, this);
  231. }
  232. bool OAuthTokenGetterImpl::IsResponsePending() const {
  233. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  234. return response_timeout_timer_.IsRunning();
  235. }
  236. void OAuthTokenGetterImpl::SetResponsePending(bool is_pending) {
  237. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  238. if (is_pending) {
  239. if (IsResponsePending()) {
  240. LOG(DFATAL) << "The response is already pending.";
  241. return;
  242. }
  243. response_timeout_timer_.Start(FROM_HERE, kResponseTimeoutDuration, this,
  244. &OAuthTokenGetterImpl::OnResponseTimeout);
  245. } else {
  246. response_timeout_timer_.Stop();
  247. }
  248. }
  249. void OAuthTokenGetterImpl::OnResponseTimeout() {
  250. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  251. LOG(ERROR) << "GaiaOAuthClient response timeout";
  252. gaia_oauth_client_ =
  253. std::make_unique<gaia::GaiaOAuthClient>(url_loader_factory_);
  254. NotifyTokenCallbacks(OAuthTokenGetterImpl::NETWORK_ERROR, {}, {});
  255. }
  256. } // namespace remoting