oauth2_access_token_manager.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef GOOGLE_APIS_GAIA_OAUTH2_ACCESS_TOKEN_MANAGER_H_
  5. #define GOOGLE_APIS_GAIA_OAUTH2_ACCESS_TOKEN_MANAGER_H_
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <vector>

  #include "base/gtest_prod_util.h"
  #include "base/memory/raw_ptr.h"
  #include "base/memory/scoped_refptr.h"
  #include "base/memory/weak_ptr.h"
  #include "base/observer_list.h"
  #include "base/sequence_checker.h"
  #include "base/time/time.h"
  #include "google_apis/gaia/core_account_id.h"
  #include "google_apis/gaia/google_service_auth_error.h"
  #include "google_apis/gaia/oauth2_access_token_consumer.h"
  16. namespace network {
  17. class SharedURLLoaderFactory;
  18. }
  19. class OAuth2AccessTokenFetcher;
  20. // Class that manages requests for OAuth2 access tokens.
  21. class OAuth2AccessTokenManager {
  22. public:
  23. // A set of scopes in OAuth2 authentication.
  24. typedef std::set<std::string> ScopeSet;
  25. class RequestImpl;
  26. class Delegate {
  27. public:
  28. Delegate();
  29. virtual ~Delegate();
  30. // Creates and returns an OAuth2AccessTokenFetcher.
  31. [[nodiscard]] virtual std::unique_ptr<OAuth2AccessTokenFetcher>
  32. CreateAccessTokenFetcher(
  33. const CoreAccountId& account_id,
  34. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  35. OAuth2AccessTokenConsumer* consumer) = 0;
  36. // Returns |true| if a refresh token is available for |account_id|, and
  37. // |false| otherwise.
  38. virtual bool HasRefreshToken(const CoreAccountId& account_id) const = 0;
  39. // Attempts to fix the error if possible. Returns true if the error was
  40. // fixed and false otherwise. Default implementation returns false.
  41. virtual bool FixRequestErrorIfPossible();
  42. // Returns a SharedURLLoaderFactory object that will be used as part of
  43. // fetching access tokens. Default implementation returns nullptr.
  44. virtual scoped_refptr<network::SharedURLLoaderFactory> GetURLLoaderFactory()
  45. const;
  46. // Gives the delegate a chance to handle the access token request before
  47. // the manager sends the request over the network. Returns true if the
  48. // request was handled by the delegate (in which case the manager will not
  49. // send the request) and false otherwise.
  50. virtual bool HandleAccessTokenFetch(
  51. RequestImpl* request,
  52. const CoreAccountId& account_id,
  53. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  54. const std::string& client_id,
  55. const std::string& client_secret,
  56. const ScopeSet& scopes);
  57. // Called when an access token is invalidated.
  58. virtual void OnAccessTokenInvalidated(const CoreAccountId& account_id,
  59. const std::string& client_id,
  60. const ScopeSet& scopes,
  61. const std::string& access_token) {}
  62. // Called when an access token is fetched.
  63. virtual void OnAccessTokenFetched(const CoreAccountId& account_id,
  64. const GoogleServiceAuthError& error) {}
  65. };
  66. // Class representing a request that fetches an OAuth2 access token.
  67. class Request {
  68. public:
  69. virtual ~Request();
  70. virtual CoreAccountId GetAccountId() const = 0;
  71. protected:
  72. Request();
  73. };
  74. // Class representing the consumer of a Request passed to |StartRequest|,
  75. // which will be called back when the request completes.
  76. class Consumer {
  77. public:
  78. explicit Consumer(const std::string& id);
  79. virtual ~Consumer();
  80. std::string id() const { return id_; }
  81. // |request| is a Request that is started by this consumer and has
  82. // completed.
  83. virtual void OnGetTokenSuccess(
  84. const Request* request,
  85. const OAuth2AccessTokenConsumer::TokenResponse& token_response) = 0;
  86. virtual void OnGetTokenFailure(const Request* request,
  87. const GoogleServiceAuthError& error) = 0;
  88. private:
  89. std::string id_;
  90. };
  91. // Implements a cancelable |OAuth2AccessTokenManager::Request|, which should
  92. // be operated on the UI thread.
  93. // TODO(davidroche): move this out of header file.
  94. class RequestImpl : public base::SupportsWeakPtr<RequestImpl>,
  95. public Request {
  96. public:
  97. // |consumer| is required to outlive this.
  98. RequestImpl(const CoreAccountId& account_id, Consumer* consumer);
  99. ~RequestImpl() override;
  100. // Overridden from Request:
  101. CoreAccountId GetAccountId() const override;
  102. std::string GetConsumerId() const;
  103. // Informs |consumer_| that this request is completed.
  104. void InformConsumer(
  105. const GoogleServiceAuthError& error,
  106. const OAuth2AccessTokenConsumer::TokenResponse& token_response);
  107. private:
  108. const CoreAccountId account_id_;
  109. // |consumer_| to call back when this request completes.
  110. const raw_ptr<Consumer> consumer_;
  111. SEQUENCE_CHECKER(sequence_checker_);
  112. };
  113. // Classes that want to monitor status of access token and access token
  114. // request should implement this interface and register with the
  115. // AddDiagnosticsObserver() call.
  116. class DiagnosticsObserver {
  117. public:
  118. // Called when receiving request for access token.
  119. virtual void OnAccessTokenRequested(const CoreAccountId& account_id,
  120. const std::string& consumer_id,
  121. const ScopeSet& scopes) {}
  122. // Called when access token fetching finished successfully or
  123. // unsuccessfully. |expiration_time| are only valid with
  124. // successful completion.
  125. virtual void OnFetchAccessTokenComplete(const CoreAccountId& account_id,
  126. const std::string& consumer_id,
  127. const ScopeSet& scopes,
  128. GoogleServiceAuthError error,
  129. base::Time expiration_time) {}
  130. // Called when an access token was removed.
  131. virtual void OnAccessTokenRemoved(const CoreAccountId& account_id,
  132. const ScopeSet& scopes) {}
  133. };
  134. // The parameters used to fetch an OAuth2 access token.
  135. struct RequestParameters {
  136. RequestParameters(const std::string& client_id,
  137. const CoreAccountId& account_id,
  138. const ScopeSet& scopes);
  139. RequestParameters(const RequestParameters& other);
  140. ~RequestParameters();
  141. bool operator<(const RequestParameters& params) const;
  142. // OAuth2 client id.
  143. std::string client_id;
  144. // Account id for which the request is made.
  145. CoreAccountId account_id;
  146. // URL scopes for the requested access token.
  147. ScopeSet scopes;
  148. };
  149. typedef std::map<RequestParameters, OAuth2AccessTokenConsumer::TokenResponse>
  150. TokenCache;
  151. explicit OAuth2AccessTokenManager(
  152. OAuth2AccessTokenManager::Delegate* delegate);
  153. OAuth2AccessTokenManager(const OAuth2AccessTokenManager&) = delete;
  154. OAuth2AccessTokenManager& operator=(const OAuth2AccessTokenManager&) = delete;
  155. virtual ~OAuth2AccessTokenManager();
  156. OAuth2AccessTokenManager::Delegate* GetDelegate();
  157. const OAuth2AccessTokenManager::Delegate* GetDelegate() const;
  158. // Add or remove observers of this token manager.
  159. void AddDiagnosticsObserver(DiagnosticsObserver* observer);
  160. void RemoveDiagnosticsObserver(DiagnosticsObserver* observer);
  161. // Checks in the cache for a valid access token for a specified |account_id|
  162. // and |scopes|, and if not found starts a request for an OAuth2 access token
  163. // using the OAuth2 refresh token maintained by this instance for that
  164. // |account_id|. The caller owns the returned Request.
  165. // |scopes| is the set of scopes to get an access token for, |consumer| is
  166. // the object that will be called back with results if the returned request
  167. // is not deleted.
  168. std::unique_ptr<Request> StartRequest(const CoreAccountId& account_id,
  169. const ScopeSet& scopes,
  170. Consumer* consumer);
  171. // This method does the same as |StartRequest| except it uses |client_id| and
  172. // |client_secret| to identify OAuth client app instead of using
  173. // Chrome's default values.
  174. std::unique_ptr<Request> StartRequestForClient(
  175. const CoreAccountId& account_id,
  176. const std::string& client_id,
  177. const std::string& client_secret,
  178. const ScopeSet& scopes,
  179. Consumer* consumer);
  180. // This method does the same as |StartRequest| except it uses the
  181. // URLLoaderfactory given by |url_loader_factory| instead of using the one
  182. // returned by |GetURLLoaderFactory| implemented by the delegate.
  183. std::unique_ptr<Request> StartRequestWithContext(
  184. const CoreAccountId& account_id,
  185. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  186. const ScopeSet& scopes,
  187. Consumer* consumer);
  188. // Fetches an OAuth token for the specified client/scopes. Virtual so it can
  189. // be overridden for tests.
  190. virtual void FetchOAuth2Token(
  191. RequestImpl* request,
  192. const CoreAccountId& account_id,
  193. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  194. const std::string& client_id,
  195. const std::string& client_secret,
  196. const std::string& consumer_name,
  197. const ScopeSet& scopes);
  198. // Returns a currently valid OAuth2 access token for the given set of scopes,
  199. // or NULL if none have been cached. Note the user of this method should
  200. // ensure no entry with the same |client_scopes| is added before the usage of
  201. // the returned entry is done.
  202. const OAuth2AccessTokenConsumer::TokenResponse* GetCachedTokenResponse(
  203. const RequestParameters& client_scopes);
  204. // Clears the internal token cache.
  205. void ClearCache();
  206. // Clears all of the tokens belonging to |account_id| from the internal token
  207. // cache. It does not matter what other parameters, like |client_id| were
  208. // used to request the tokens.
  209. void ClearCacheForAccount(const CoreAccountId& account_id);
  210. // Cancels all requests that are currently in progress. Virtual so it can be
  211. // overridden for tests.
  212. virtual void CancelAllRequests();
  213. // Cancels all requests related to a given |account_id|. Virtual so it can be
  214. // overridden for tests.
  215. virtual void CancelRequestsForAccount(const CoreAccountId& account_id);
  216. // Mark an OAuth2 |access_token| issued for |account_id| and |scopes| as
  217. // invalid. This should be done if the token was received from this class,
  218. // but was not accepted by the server (e.g., the server returned
  219. // 401 Unauthorized). The token will be removed from the cache for the given
  220. // scopes.
  221. void InvalidateAccessToken(const CoreAccountId& account_id,
  222. const ScopeSet& scopes,
  223. const std::string& access_token);
  224. void set_max_authorization_token_fetch_retries_for_testing(int max_retries);
  225. // Returns the current number of pending fetchers matching given params.
  226. size_t GetNumPendingRequestsForTesting(const std::string& client_id,
  227. const CoreAccountId& account_id,
  228. const ScopeSet& scopes) const;
  229. // Returns a list of DiagnosticsObservers.
  230. const base::ObserverList<DiagnosticsObserver, true>::Unchecked&
  231. GetDiagnosticsObserversForTesting();
  232. protected:
  233. // Invalidates the |access_token| issued for |account_id|, |client_id| and
  234. // |scopes|. Virtual so it can be overridden for tests.
  235. virtual void InvalidateAccessTokenImpl(const CoreAccountId& account_id,
  236. const std::string& client_id,
  237. const ScopeSet& scopes,
  238. const std::string& access_token);
  239. private:
  240. class Fetcher;
  241. friend class Fetcher;
  242. TokenCache& token_cache() { return token_cache_; }
  243. // Create an access token fetcher for the given account id.
  244. std::unique_ptr<OAuth2AccessTokenFetcher> CreateAccessTokenFetcher(
  245. const CoreAccountId& account_id,
  246. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  247. OAuth2AccessTokenConsumer* consumer);
  248. // This method does the same as |StartRequestWithContext| except it
  249. // uses |client_id| and |client_secret| to identify OAuth
  250. // client app instead of using Chrome's default values.
  251. std::unique_ptr<Request> StartRequestForClientWithContext(
  252. const CoreAccountId& account_id,
  253. scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
  254. const std::string& client_id,
  255. const std::string& client_secret,
  256. const ScopeSet& scopes,
  257. Consumer* consumer);
  258. // Posts a task to fire the Consumer callback with the cached token response.
  259. void InformConsumerWithCachedTokenResponse(
  260. const OAuth2AccessTokenConsumer::TokenResponse* token_response,
  261. RequestImpl* request,
  262. const RequestParameters& client_scopes);
  263. // Add a new entry to the cache.
  264. void RegisterTokenResponse(
  265. const std::string& client_id,
  266. const CoreAccountId& account_id,
  267. const ScopeSet& scopes,
  268. const OAuth2AccessTokenConsumer::TokenResponse& token_response);
  269. // Removes an access token for the given set of scopes from the cache.
  270. // Returns true if the entry was removed, otherwise false.
  271. bool RemoveCachedTokenResponse(const RequestParameters& client_scopes,
  272. const std::string& token_to_remove);
  273. // Called when |fetcher| finishes fetching.
  274. void OnFetchComplete(Fetcher* fetcher);
  275. // Called when a number of fetchers need to be canceled.
  276. void CancelFetchers(std::vector<Fetcher*> fetchers_to_cancel);
  277. // The cache of currently valid tokens.
  278. TokenCache token_cache_;
  279. // List of observers to notify when access token status changes.
  280. base::ObserverList<DiagnosticsObserver, true>::Unchecked
  281. diagnostics_observer_list_;
  282. raw_ptr<Delegate> delegate_;
  283. // A map from fetch parameters to a fetcher that is fetching an OAuth2 access
  284. // token using these parameters.
  285. std::map<RequestParameters, std::unique_ptr<Fetcher>> pending_fetchers_;
  286. // Maximum number of retries in fetching an OAuth2 access token.
  287. static int max_fetch_retry_num_;
  288. SEQUENCE_CHECKER(sequence_checker_);
  289. FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenManagerTest, ClearCache);
  290. FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenManagerTest, ClearCacheForAccount);
  291. FRIEND_TEST_ALL_PREFIXES(OAuth2AccessTokenManagerTest, OnAccessTokenRemoved);
  292. };
  293. #endif // GOOGLE_APIS_GAIA_OAUTH2_ACCESS_TOKEN_MANAGER_H_