websocket_stream_create_test_base.cc 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. // Copyright 2015 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/websockets/websocket_stream_create_test_base.h"
  5. #include "base/memory/raw_ptr.h"
  6. #include <utility>
  7. #include "base/callback.h"
  8. #include "net/base/ip_endpoint.h"
  9. #include "net/http/http_request_headers.h"
  10. #include "net/http/http_response_headers.h"
  11. #include "net/log/net_log_with_source.h"
  12. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  13. #include "net/websockets/websocket_basic_handshake_stream.h"
  14. #include "net/websockets/websocket_handshake_request_info.h"
  15. #include "net/websockets/websocket_handshake_response_info.h"
  16. #include "net/websockets/websocket_stream.h"
  17. #include "url/gurl.h"
  18. #include "url/origin.h"
  19. namespace net {
  // File-local shorthand for the fixture's header key/value pair type, so the
  // out-of-line member definitions below can name their return type without
  // the WebSocketStreamCreateTestBase:: qualifier.
  using HeaderKeyValuePair = WebSocketStreamCreateTestBase::HeaderKeyValuePair;
  21. class WebSocketStreamCreateTestBase::TestConnectDelegate
  22. : public WebSocketStream::ConnectDelegate {
  23. public:
  24. TestConnectDelegate(WebSocketStreamCreateTestBase* owner,
  25. base::OnceClosure done_callback)
  26. : owner_(owner), done_callback_(std::move(done_callback)) {}
  27. TestConnectDelegate(const TestConnectDelegate&) = delete;
  28. TestConnectDelegate& operator=(const TestConnectDelegate&) = delete;
  29. void OnCreateRequest(URLRequest* request) override {
  30. owner_->url_request_ = request;
  31. }
  32. void OnSuccess(
  33. std::unique_ptr<WebSocketStream> stream,
  34. std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
  35. if (owner_->response_info_)
  36. ADD_FAILURE();
  37. owner_->response_info_ = std::move(response);
  38. stream.swap(owner_->stream_);
  39. std::move(done_callback_).Run();
  40. }
  41. void OnFailure(const std::string& message,
  42. int net_error,
  43. absl::optional<int> response_code) override {
  44. owner_->has_failed_ = true;
  45. owner_->failure_message_ = message;
  46. owner_->failure_response_code_ = response_code.value_or(-1);
  47. std::move(done_callback_).Run();
  48. }
  49. void OnStartOpeningHandshake(
  50. std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
  51. // Can be called multiple times (in the case of HTTP auth). Last call
  52. // wins.
  53. owner_->request_info_ = std::move(request);
  54. }
  55. void OnSSLCertificateError(
  56. std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
  57. ssl_error_callbacks,
  58. int net_error,
  59. const SSLInfo& ssl_info,
  60. bool fatal) override {
  61. owner_->ssl_error_callbacks_ = std::move(ssl_error_callbacks);
  62. owner_->ssl_info_ = ssl_info;
  63. owner_->ssl_fatal_ = fatal;
  64. }
  65. int OnAuthRequired(const AuthChallengeInfo& auth_info,
  66. scoped_refptr<HttpResponseHeaders> response_headers,
  67. const IPEndPoint& remote_endpoint,
  68. base::OnceCallback<void(const AuthCredentials*)> callback,
  69. absl::optional<AuthCredentials>* credentials) override {
  70. owner_->run_loop_waiting_for_on_auth_required_.Quit();
  71. owner_->auth_challenge_info_ = auth_info;
  72. *credentials = owner_->auth_credentials_;
  73. owner_->on_auth_required_callback_ = std::move(callback);
  74. return owner_->on_auth_required_rv_;
  75. }
  76. private:
  77. raw_ptr<WebSocketStreamCreateTestBase> owner_;
  78. base::OnceClosure done_callback_;
  79. };
  // Constructor and destructor are defaulted out of line. Presumably this is
  // so the header does not need complete definitions of the types held by
  // the fixture (e.g. the nested TestConnectDelegate); confirm against the
  // header.
  WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase() = default;
  WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() = default;
  82. void WebSocketStreamCreateTestBase::CreateAndConnectStream(
  83. const GURL& socket_url,
  84. const std::vector<std::string>& sub_protocols,
  85. const url::Origin& origin,
  86. const SiteForCookies& site_for_cookies,
  87. const IsolationInfo& isolation_info,
  88. const HttpRequestHeaders& additional_headers,
  89. std::unique_ptr<base::OneShotTimer> timer) {
  90. auto connect_delegate = std::make_unique<TestConnectDelegate>(
  91. this, connect_run_loop_.QuitClosure());
  92. auto api_delegate = std::make_unique<TestWebSocketStreamRequestAPI>();
  93. stream_request_ = WebSocketStream::CreateAndConnectStreamForTesting(
  94. socket_url, sub_protocols, origin, site_for_cookies, isolation_info,
  95. additional_headers, url_request_context_host_.GetURLRequestContext(),
  96. NetLogWithSource(), TRAFFIC_ANNOTATION_FOR_TESTS,
  97. std::move(connect_delegate),
  98. timer ? std::move(timer) : std::make_unique<base::OneShotTimer>(),
  99. std::move(api_delegate));
  100. }
  101. std::vector<HeaderKeyValuePair>
  102. WebSocketStreamCreateTestBase::RequestHeadersToVector(
  103. const HttpRequestHeaders& headers) {
  104. HttpRequestHeaders::Iterator it(headers);
  105. std::vector<HeaderKeyValuePair> result;
  106. while (it.GetNext())
  107. result.emplace_back(it.name(), it.value());
  108. return result;
  109. }
  110. std::vector<HeaderKeyValuePair>
  111. WebSocketStreamCreateTestBase::ResponseHeadersToVector(
  112. const HttpResponseHeaders& headers) {
  113. size_t iter = 0;
  114. std::string name, value;
  115. std::vector<HeaderKeyValuePair> result;
  116. while (headers.EnumerateHeaderLines(&iter, &name, &value))
  117. result.emplace_back(name, value);
  118. return result;
  119. }
  120. void WebSocketStreamCreateTestBase::WaitUntilConnectDone() {
  121. connect_run_loop_.Run();
  122. }
  123. void WebSocketStreamCreateTestBase::WaitUntilOnAuthRequired() {
  124. run_loop_waiting_for_on_auth_required_.Run();
  125. }
  126. std::vector<std::string> WebSocketStreamCreateTestBase::NoSubProtocols() {
  127. return std::vector<std::string>();
  128. }
  129. } // namespace net