websocket_handshake_stream_create_helper_test.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. // Copyright 2013 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/websockets/websocket_handshake_stream_create_helper.h"
  5. #include <string>
  6. #include <utility>
  7. #include <vector>
  8. #include "base/memory/scoped_refptr.h"
  9. #include "net/base/completion_once_callback.h"
  10. #include "net/base/host_port_pair.h"
  11. #include "net/base/ip_endpoint.h"
  12. #include "net/base/load_flags.h"
  13. #include "net/base/net_errors.h"
  14. #include "net/base/privacy_mode.h"
  15. #include "net/base/proxy_server.h"
  16. #include "net/dns/public/secure_dns_policy.h"
  17. #include "net/http/http_network_session.h"
  18. #include "net/http/http_request_headers.h"
  19. #include "net/http/http_request_info.h"
  20. #include "net/http/http_response_headers.h"
  21. #include "net/http/http_response_info.h"
  22. #include "net/log/net_log_with_source.h"
  23. #include "net/socket/client_socket_handle.h"
  24. #include "net/socket/connect_job.h"
  25. #include "net/socket/socket_tag.h"
  26. #include "net/socket/socket_test_util.h"
  27. #include "net/socket/ssl_client_socket.h"
  28. #include "net/socket/websocket_endpoint_lock_manager.h"
  29. #include "net/spdy/spdy_session.h"
  30. #include "net/spdy/spdy_session_key.h"
  31. #include "net/spdy/spdy_test_util_common.h"
  32. #include "net/ssl/ssl_config.h"
  33. #include "net/ssl/ssl_info.h"
  34. #include "net/test/cert_test_util.h"
  35. #include "net/test/gtest_util.h"
  36. #include "net/test/test_data_directory.h"
  37. #include "net/test/test_with_task_environment.h"
  38. #include "net/traffic_annotation/network_traffic_annotation.h"
  39. #include "net/websockets/websocket_basic_handshake_stream.h"
  40. #include "net/websockets/websocket_stream.h"
  41. #include "net/websockets/websocket_test_util.h"
  42. #include "testing/gmock/include/gmock/gmock.h"
  43. #include "testing/gtest/include/gtest/gtest.h"
  44. #include "third_party/abseil-cpp/absl/types/optional.h"
  45. #include "url/gurl.h"
  46. #include "url/origin.h"
  47. #include "url/scheme_host_port.h"
  48. #include "url/url_constants.h"
  49. using ::net::test::IsError;
  50. using ::net::test::IsOk;
  51. using ::testing::StrictMock;
  52. using ::testing::TestWithParam;
  53. using ::testing::Values;
  54. using ::testing::_;
  55. namespace net {
  56. namespace {
  // Identifies which kind of handshake stream a parameterized test case
  // exercises. Used as the gtest parameter type of the fixture in this file.
  enum HandshakeStreamType {
    // Drive the WebSocketBasicHandshakeStream path.
    BASIC_HANDSHAKE_STREAM,
    // Drive the WebSocketHttp2HandshakeStream path.
    HTTP2_HANDSHAKE_STREAM,
  };
  58. // This class encapsulates the details of creating a mock ClientSocketHandle.
  59. class MockClientSocketHandleFactory {
  60. public:
  61. MockClientSocketHandleFactory()
  62. : common_connect_job_params_(
  63. socket_factory_maker_.factory(),
  64. nullptr /* host_resolver */,
  65. nullptr /* http_auth_cache */,
  66. nullptr /* http_auth_handler_factory */,
  67. nullptr /* spdy_session_pool */,
  68. nullptr /* quic_supported_versions */,
  69. nullptr /* quic_stream_factory */,
  70. nullptr /* proxy_delegate */,
  71. nullptr /* http_user_agent_settings */,
  72. nullptr /* ssl_client_context */,
  73. nullptr /* socket_performance_watcher_factory */,
  74. nullptr /* network_quality_estimator */,
  75. nullptr /* net_log */,
  76. nullptr /* websocket_endpoint_lock_manager */),
  77. pool_(1, 1, &common_connect_job_params_) {}
  78. MockClientSocketHandleFactory(const MockClientSocketHandleFactory&) = delete;
  79. MockClientSocketHandleFactory& operator=(
  80. const MockClientSocketHandleFactory&) = delete;
  81. // The created socket expects |expect_written| to be written to the socket,
  82. // and will respond with |return_to_read|. The test will fail if the expected
  83. // text is not written, or if all the bytes are not read.
  84. std::unique_ptr<ClientSocketHandle> CreateClientSocketHandle(
  85. const std::string& expect_written,
  86. const std::string& return_to_read) {
  87. socket_factory_maker_.SetExpectations(expect_written, return_to_read);
  88. auto socket_handle = std::make_unique<ClientSocketHandle>();
  89. socket_handle->Init(
  90. ClientSocketPool::GroupId(
  91. url::SchemeHostPort(url::kHttpScheme, "a", 80),
  92. PrivacyMode::PRIVACY_MODE_DISABLED, NetworkIsolationKey(),
  93. SecureDnsPolicy::kAllow),
  94. scoped_refptr<ClientSocketPool::SocketParams>(),
  95. absl::nullopt /* proxy_annotation_tag */, MEDIUM, SocketTag(),
  96. ClientSocketPool::RespectLimits::ENABLED, CompletionOnceCallback(),
  97. ClientSocketPool::ProxyAuthCallback(), &pool_, NetLogWithSource());
  98. return socket_handle;
  99. }
  100. private:
  101. WebSocketMockClientSocketFactoryMaker socket_factory_maker_;
  102. const CommonConnectJobParams common_connect_job_params_;
  103. MockTransportClientSocketPool pool_;
  104. };
  105. class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
  106. public:
  107. ~TestConnectDelegate() override = default;
  108. void OnCreateRequest(URLRequest* request) override {}
  109. void OnSuccess(
  110. std::unique_ptr<WebSocketStream> stream,
  111. std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {}
  112. void OnFailure(const std::string& failure_message,
  113. int net_error,
  114. absl::optional<int> response_code) override {}
  115. void OnStartOpeningHandshake(
  116. std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {}
  117. void OnSSLCertificateError(
  118. std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
  119. ssl_error_callbacks,
  120. int net_error,
  121. const SSLInfo& ssl_info,
  122. bool fatal) override {}
  123. int OnAuthRequired(const AuthChallengeInfo& auth_info,
  124. scoped_refptr<HttpResponseHeaders> response_headers,
  125. const IPEndPoint& host_port_pair,
  126. base::OnceCallback<void(const AuthCredentials*)> callback,
  127. absl::optional<AuthCredentials>* credentials) override {
  128. *credentials = absl::nullopt;
  129. return OK;
  130. }
  131. };
  132. class MockWebSocketStreamRequestAPI : public WebSocketStreamRequestAPI {
  133. public:
  134. ~MockWebSocketStreamRequestAPI() override = default;
  135. MOCK_METHOD1(OnBasicHandshakeStreamCreated,
  136. void(WebSocketBasicHandshakeStream* handshake_stream));
  137. MOCK_METHOD1(OnHttp2HandshakeStreamCreated,
  138. void(WebSocketHttp2HandshakeStream* handshake_stream));
  139. MOCK_METHOD3(OnFailure,
  140. void(const std::string& message,
  141. int net_error,
  142. absl::optional<int> response_code));
  143. };
  144. class WebSocketHandshakeStreamCreateHelperTest
  145. : public TestWithParam<HandshakeStreamType>,
  146. public WithTaskEnvironment {
  147. protected:
  148. std::unique_ptr<WebSocketStream> CreateAndInitializeStream(
  149. const std::vector<std::string>& sub_protocols,
  150. const WebSocketExtraHeaders& extra_request_headers,
  151. const WebSocketExtraHeaders& extra_response_headers) {
  152. const char kPath[] = "/";
  153. const char kOrigin[] = "http://origin.example.org";
  154. const GURL url("wss://www.example.org/");
  155. NetLogWithSource net_log;
  156. WebSocketHandshakeStreamCreateHelper create_helper(
  157. &connect_delegate_, sub_protocols, &stream_request_);
  158. switch (GetParam()) {
  159. case BASIC_HANDSHAKE_STREAM:
  160. EXPECT_CALL(stream_request_, OnBasicHandshakeStreamCreated(_)).Times(1);
  161. break;
  162. case HTTP2_HANDSHAKE_STREAM:
  163. EXPECT_CALL(stream_request_, OnHttp2HandshakeStreamCreated(_)).Times(1);
  164. break;
  165. default:
  166. NOTREACHED();
  167. }
  168. EXPECT_CALL(stream_request_, OnFailure(_, _, _)).Times(0);
  169. HttpRequestInfo request_info;
  170. request_info.url = url;
  171. request_info.method = "GET";
  172. request_info.load_flags = LOAD_DISABLE_CACHE;
  173. request_info.traffic_annotation =
  174. MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
  175. auto headers = WebSocketCommonTestHeaders();
  176. switch (GetParam()) {
  177. case BASIC_HANDSHAKE_STREAM: {
  178. std::unique_ptr<ClientSocketHandle> socket_handle =
  179. socket_handle_factory_.CreateClientSocketHandle(
  180. WebSocketStandardRequest(kPath, "www.example.org",
  181. url::Origin::Create(GURL(kOrigin)),
  182. /*send_additional_request_headers=*/{},
  183. extra_request_headers),
  184. WebSocketStandardResponse(
  185. WebSocketExtraHeadersToString(extra_response_headers)));
  186. std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
  187. create_helper.CreateBasicStream(std::move(socket_handle), false,
  188. &websocket_endpoint_lock_manager_);
  189. // If in future the implementation type returned by CreateBasicStream()
  190. // changes, this static_cast will be wrong. However, in that case the
  191. // test will fail and AddressSanitizer should identify the issue.
  192. static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
  193. ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
  194. handshake->RegisterRequest(&request_info);
  195. int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY, net_log,
  196. CompletionOnceCallback());
  197. EXPECT_THAT(rv, IsOk());
  198. HttpResponseInfo response;
  199. TestCompletionCallback request_callback;
  200. rv = handshake->SendRequest(headers, &response,
  201. request_callback.callback());
  202. EXPECT_THAT(rv, IsOk());
  203. TestCompletionCallback response_callback;
  204. rv = handshake->ReadResponseHeaders(response_callback.callback());
  205. EXPECT_THAT(rv, IsOk());
  206. EXPECT_EQ(101, response.headers->response_code());
  207. EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
  208. EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
  209. return handshake->Upgrade();
  210. }
  211. case HTTP2_HANDSHAKE_STREAM: {
  212. SpdyTestUtil spdy_util;
  213. spdy::Http2HeaderBlock request_header_block = WebSocketHttp2Request(
  214. kPath, "www.example.org", kOrigin, extra_request_headers);
  215. spdy::SpdySerializedFrame request_headers(
  216. spdy_util.ConstructSpdyHeaders(1, std::move(request_header_block),
  217. DEFAULT_PRIORITY, false));
  218. MockWrite writes[] = {CreateMockWrite(request_headers, 0)};
  219. spdy::Http2HeaderBlock response_header_block =
  220. WebSocketHttp2Response(extra_response_headers);
  221. spdy::SpdySerializedFrame response_headers(
  222. spdy_util.ConstructSpdyResponseHeaders(
  223. 1, std::move(response_header_block), false));
  224. MockRead reads[] = {CreateMockRead(response_headers, 1),
  225. MockRead(ASYNC, 0, 2)};
  226. SequencedSocketData data(reads, writes);
  227. SSLSocketDataProvider ssl(ASYNC, OK);
  228. ssl.ssl_info.cert =
  229. ImportCertFromFile(GetTestCertsDirectory(), "wildcard.pem");
  230. SpdySessionDependencies session_deps;
  231. session_deps.socket_factory->AddSocketDataProvider(&data);
  232. session_deps.socket_factory->AddSSLSocketDataProvider(&ssl);
  233. std::unique_ptr<HttpNetworkSession> http_network_session =
  234. SpdySessionDependencies::SpdyCreateSession(&session_deps);
  235. const SpdySessionKey key(
  236. HostPortPair::FromURL(url), ProxyServer::Direct(),
  237. PRIVACY_MODE_DISABLED, SpdySessionKey::IsProxySession::kFalse,
  238. SocketTag(), NetworkIsolationKey(), SecureDnsPolicy::kAllow);
  239. base::WeakPtr<SpdySession> spdy_session =
  240. CreateSpdySession(http_network_session.get(), key, net_log);
  241. std::unique_ptr<WebSocketHandshakeStreamBase> handshake =
  242. create_helper.CreateHttp2Stream(spdy_session, {} /* dns_aliases */);
  243. handshake->RegisterRequest(&request_info);
  244. int rv = handshake->InitializeStream(true, DEFAULT_PRIORITY,
  245. NetLogWithSource(),
  246. CompletionOnceCallback());
  247. EXPECT_THAT(rv, IsOk());
  248. HttpResponseInfo response;
  249. TestCompletionCallback request_callback;
  250. rv = handshake->SendRequest(headers, &response,
  251. request_callback.callback());
  252. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  253. rv = request_callback.WaitForResult();
  254. EXPECT_THAT(rv, IsOk());
  255. TestCompletionCallback response_callback;
  256. rv = handshake->ReadResponseHeaders(response_callback.callback());
  257. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  258. rv = response_callback.WaitForResult();
  259. EXPECT_THAT(rv, IsOk());
  260. EXPECT_EQ(200, response.headers->response_code());
  261. return handshake->Upgrade();
  262. }
  263. default:
  264. NOTREACHED();
  265. return nullptr;
  266. }
  267. }
  268. private:
  269. MockClientSocketHandleFactory socket_handle_factory_;
  270. TestConnectDelegate connect_delegate_;
  271. StrictMock<MockWebSocketStreamRequestAPI> stream_request_;
  272. WebSocketEndpointLockManager websocket_endpoint_lock_manager_;
  273. };
  274. INSTANTIATE_TEST_SUITE_P(All,
  275. WebSocketHandshakeStreamCreateHelperTest,
  276. Values(BASIC_HANDSHAKE_STREAM,
  277. HTTP2_HANDSHAKE_STREAM));
  278. // Confirm that the basic case works as expected.
  279. TEST_P(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
  280. std::unique_ptr<WebSocketStream> stream =
  281. CreateAndInitializeStream({}, {}, {});
  282. EXPECT_EQ("", stream->GetExtensions());
  283. EXPECT_EQ("", stream->GetSubProtocol());
  284. }
  285. // Verify that the sub-protocols are passed through.
  286. TEST_P(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
  287. std::vector<std::string> sub_protocols;
  288. sub_protocols.push_back("chat");
  289. sub_protocols.push_back("superchat");
  290. std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
  291. sub_protocols, {{"Sec-WebSocket-Protocol", "chat, superchat"}},
  292. {{"Sec-WebSocket-Protocol", "superchat"}});
  293. EXPECT_EQ("superchat", stream->GetSubProtocol());
  294. }
  295. // Verify that extension name is available. Bad extension names are tested in
  296. // websocket_stream_test.cc.
  297. TEST_P(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
  298. std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
  299. {}, {}, {{"Sec-WebSocket-Extensions", "permessage-deflate"}});
  300. EXPECT_EQ("permessage-deflate", stream->GetExtensions());
  301. }
  302. // Verify that extension parameters are available. Bad parameters are tested in
  303. // websocket_stream_test.cc.
  304. TEST_P(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
  305. std::unique_ptr<WebSocketStream> stream = CreateAndInitializeStream(
  306. {}, {},
  307. {{"Sec-WebSocket-Extensions",
  308. "permessage-deflate;"
  309. " client_max_window_bits=14; server_max_window_bits=14;"
  310. " server_no_context_takeover; client_no_context_takeover"}});
  311. EXPECT_EQ(
  312. "permessage-deflate;"
  313. " client_max_window_bits=14; server_max_window_bits=14;"
  314. " server_no_context_takeover; client_no_context_takeover",
  315. stream->GetExtensions());
  316. }
  317. } // namespace
  318. } // namespace net