// websocket_http2_handshake_stream.cc (13 KB)
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/websockets/websocket_http2_handshake_stream.h"
  5. #include <cstddef>
  6. #include <set>
  7. #include <utility>
  8. #include "base/bind.h"
  9. #include "base/check_op.h"
  10. #include "base/notreached.h"
  11. #include "base/strings/stringprintf.h"
  12. #include "base/time/time.h"
  13. #include "net/base/ip_endpoint.h"
  14. #include "net/http/http_request_headers.h"
  15. #include "net/http/http_request_info.h"
  16. #include "net/http/http_response_headers.h"
  17. #include "net/http/http_status_code.h"
  18. #include "net/spdy/spdy_http_utils.h"
  19. #include "net/spdy/spdy_session.h"
  20. #include "net/traffic_annotation/network_traffic_annotation.h"
  21. #include "net/websockets/websocket_basic_stream.h"
  22. #include "net/websockets/websocket_deflate_parameters.h"
  23. #include "net/websockets/websocket_deflate_predictor_impl.h"
  24. #include "net/websockets/websocket_deflate_stream.h"
  25. #include "net/websockets/websocket_deflater.h"
  26. #include "net/websockets/websocket_handshake_constants.h"
  27. #include "net/websockets/websocket_handshake_request_info.h"
  28. namespace net {
  29. namespace {
  30. bool ValidateStatus(const HttpResponseHeaders* headers) {
  31. return headers->GetStatusLine() == "HTTP/1.1 200";
  32. }
  33. } // namespace
  34. WebSocketHttp2HandshakeStream::WebSocketHttp2HandshakeStream(
  35. base::WeakPtr<SpdySession> session,
  36. WebSocketStream::ConnectDelegate* connect_delegate,
  37. std::vector<std::string> requested_sub_protocols,
  38. std::vector<std::string> requested_extensions,
  39. WebSocketStreamRequestAPI* request,
  40. std::set<std::string> dns_aliases)
  41. : session_(session),
  42. connect_delegate_(connect_delegate),
  43. requested_sub_protocols_(requested_sub_protocols),
  44. requested_extensions_(requested_extensions),
  45. stream_request_(request),
  46. dns_aliases_(std::move(dns_aliases)) {
  47. DCHECK(connect_delegate);
  48. DCHECK(request);
  49. }
  50. WebSocketHttp2HandshakeStream::~WebSocketHttp2HandshakeStream() {
  51. spdy_stream_request_.reset();
  52. RecordHandshakeResult(result_);
  53. }
  54. void WebSocketHttp2HandshakeStream::RegisterRequest(
  55. const HttpRequestInfo* request_info) {
  56. DCHECK(request_info);
  57. DCHECK(request_info->traffic_annotation.is_valid());
  58. request_info_ = request_info;
  59. }
  60. int WebSocketHttp2HandshakeStream::InitializeStream(
  61. bool can_send_early,
  62. RequestPriority priority,
  63. const NetLogWithSource& net_log,
  64. CompletionOnceCallback callback) {
  65. priority_ = priority;
  66. net_log_ = net_log;
  67. return OK;
  68. }
  69. int WebSocketHttp2HandshakeStream::SendRequest(
  70. const HttpRequestHeaders& headers,
  71. HttpResponseInfo* response,
  72. CompletionOnceCallback callback) {
  73. DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
  74. DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
  75. DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
  76. DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
  77. DCHECK(headers.HasHeader(websockets::kUpgrade));
  78. DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
  79. DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
  80. if (!session_) {
  81. const int rv = ERR_CONNECTION_CLOSED;
  82. OnFailure("Connection closed before sending request.", rv, absl::nullopt);
  83. return rv;
  84. }
  85. http_response_info_ = response;
  86. IPEndPoint address;
  87. int result = session_->GetPeerAddress(&address);
  88. if (result != OK) {
  89. OnFailure("Error getting IP address.", result, absl::nullopt);
  90. return result;
  91. }
  92. http_response_info_->remote_endpoint = address;
  93. auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
  94. request_info_->url, base::Time::Now());
  95. request->headers.CopyFrom(headers);
  96. AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
  97. requested_extensions_, &request->headers);
  98. AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
  99. requested_sub_protocols_, &request->headers);
  100. CreateSpdyHeadersFromHttpRequestForWebSocket(
  101. request_info_->url, request->headers, &http2_request_headers_);
  102. connect_delegate_->OnStartOpeningHandshake(std::move(request));
  103. callback_ = std::move(callback);
  104. spdy_stream_request_ = std::make_unique<SpdyStreamRequest>();
  105. // The initial request for the WebSocket is a CONNECT, so there is no need to
  106. // call ConfirmHandshake().
  107. int rv = spdy_stream_request_->StartRequest(
  108. SPDY_BIDIRECTIONAL_STREAM, session_, request_info_->url, true, priority_,
  109. request_info_->socket_tag, net_log_,
  110. base::BindOnce(&WebSocketHttp2HandshakeStream::StartRequestCallback,
  111. base::Unretained(this)),
  112. NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
  113. if (rv == OK) {
  114. StartRequestCallback(rv);
  115. return ERR_IO_PENDING;
  116. }
  117. return rv;
  118. }
  119. int WebSocketHttp2HandshakeStream::ReadResponseHeaders(
  120. CompletionOnceCallback callback) {
  121. if (stream_closed_)
  122. return stream_error_;
  123. if (response_headers_complete_)
  124. return ValidateResponse();
  125. callback_ = std::move(callback);
  126. return ERR_IO_PENDING;
  127. }
  128. int WebSocketHttp2HandshakeStream::ReadResponseBody(
  129. IOBuffer* buf,
  130. int buf_len,
  131. CompletionOnceCallback callback) {
  132. // Callers should instead call Upgrade() to get a WebSocketStream
  133. // and call ReadFrames() on that.
  134. NOTREACHED();
  135. return OK;
  136. }
  137. void WebSocketHttp2HandshakeStream::Close(bool not_reusable) {
  138. spdy_stream_request_.reset();
  139. if (stream_) {
  140. stream_ = nullptr;
  141. stream_closed_ = true;
  142. stream_error_ = ERR_CONNECTION_CLOSED;
  143. }
  144. stream_adapter_.reset();
  145. }
  146. bool WebSocketHttp2HandshakeStream::IsResponseBodyComplete() const {
  147. return false;
  148. }
  149. bool WebSocketHttp2HandshakeStream::IsConnectionReused() const {
  150. return true;
  151. }
  152. void WebSocketHttp2HandshakeStream::SetConnectionReused() {}
  153. bool WebSocketHttp2HandshakeStream::CanReuseConnection() const {
  154. return false;
  155. }
  156. int64_t WebSocketHttp2HandshakeStream::GetTotalReceivedBytes() const {
  157. return stream_ ? stream_->raw_received_bytes() : 0;
  158. }
  159. int64_t WebSocketHttp2HandshakeStream::GetTotalSentBytes() const {
  160. return stream_ ? stream_->raw_sent_bytes() : 0;
  161. }
  162. bool WebSocketHttp2HandshakeStream::GetAlternativeService(
  163. AlternativeService* alternative_service) const {
  164. return false;
  165. }
  166. bool WebSocketHttp2HandshakeStream::GetLoadTimingInfo(
  167. LoadTimingInfo* load_timing_info) const {
  168. return stream_ && stream_->GetLoadTimingInfo(load_timing_info);
  169. }
  170. void WebSocketHttp2HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
  171. if (stream_)
  172. stream_->GetSSLInfo(ssl_info);
  173. }
  174. void WebSocketHttp2HandshakeStream::GetSSLCertRequestInfo(
  175. SSLCertRequestInfo* cert_request_info) {
  176. // A multiplexed stream cannot request client certificates. Client
  177. // authentication may only occur during the initial SSL handshake.
  178. NOTREACHED();
  179. }
  180. int WebSocketHttp2HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
  181. if (!session_)
  182. return ERR_SOCKET_NOT_CONNECTED;
  183. return session_->GetRemoteEndpoint(endpoint);
  184. }
  185. void WebSocketHttp2HandshakeStream::PopulateNetErrorDetails(
  186. NetErrorDetails* /*details*/) {
  187. return;
  188. }
  189. void WebSocketHttp2HandshakeStream::Drain(HttpNetworkSession* session) {
  190. Close(true /* not_reusable */);
  191. }
  192. void WebSocketHttp2HandshakeStream::SetPriority(RequestPriority priority) {
  193. priority_ = priority;
  194. if (stream_)
  195. stream_->SetPriority(priority_);
  196. }
  197. std::unique_ptr<HttpStream>
  198. WebSocketHttp2HandshakeStream::RenewStreamForAuth() {
  199. // Renewing the stream is not supported.
  200. return nullptr;
  201. }
  202. const std::set<std::string>& WebSocketHttp2HandshakeStream::GetDnsAliases()
  203. const {
  204. return dns_aliases_;
  205. }
  206. base::StringPiece WebSocketHttp2HandshakeStream::GetAcceptChViaAlps() const {
  207. return {};
  208. }
  209. std::unique_ptr<WebSocketStream> WebSocketHttp2HandshakeStream::Upgrade() {
  210. DCHECK(extension_params_.get());
  211. stream_adapter_->DetachDelegate();
  212. std::unique_ptr<WebSocketStream> basic_stream =
  213. std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
  214. nullptr, sub_protocol_,
  215. extensions_, net_log_);
  216. if (!extension_params_->deflate_enabled)
  217. return basic_stream;
  218. return std::make_unique<WebSocketDeflateStream>(
  219. std::move(basic_stream), extension_params_->deflate_parameters,
  220. std::make_unique<WebSocketDeflatePredictorImpl>());
  221. }
  222. base::WeakPtr<WebSocketHandshakeStreamBase>
  223. WebSocketHttp2HandshakeStream::GetWeakPtr() {
  224. return weak_ptr_factory_.GetWeakPtr();
  225. }
  226. void WebSocketHttp2HandshakeStream::OnHeadersSent() {
  227. std::move(callback_).Run(OK);
  228. }
  229. void WebSocketHttp2HandshakeStream::OnHeadersReceived(
  230. const spdy::Http2HeaderBlock& response_headers) {
  231. DCHECK(!response_headers_complete_);
  232. DCHECK(http_response_info_);
  233. response_headers_complete_ = true;
  234. const int rv =
  235. SpdyHeadersToHttpResponse(response_headers, http_response_info_);
  236. DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
  237. http_response_info_->response_time = stream_->response_time();
  238. // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
  239. // care of that part.
  240. http_response_info_->was_alpn_negotiated = true;
  241. http_response_info_->request_time = stream_->GetRequestTime();
  242. http_response_info_->connection_info =
  243. HttpResponseInfo::CONNECTION_INFO_HTTP2;
  244. http_response_info_->alpn_negotiated_protocol =
  245. HttpResponseInfo::ConnectionInfoToString(
  246. http_response_info_->connection_info);
  247. if (callback_)
  248. std::move(callback_).Run(ValidateResponse());
  249. }
  250. void WebSocketHttp2HandshakeStream::OnClose(int status) {
  251. DCHECK(stream_adapter_);
  252. DCHECK_GT(ERR_IO_PENDING, status);
  253. stream_closed_ = true;
  254. stream_error_ = status;
  255. stream_ = nullptr;
  256. stream_adapter_.reset();
  257. // If response headers have already been received,
  258. // then ValidateResponse() sets |result_|.
  259. if (!response_headers_complete_)
  260. result_ = HandshakeResult::HTTP2_FAILED;
  261. OnFailure(std::string("Stream closed with error: ") + ErrorToString(status),
  262. status, absl::nullopt);
  263. if (callback_)
  264. std::move(callback_).Run(status);
  265. }
  266. void WebSocketHttp2HandshakeStream::StartRequestCallback(int rv) {
  267. DCHECK(callback_);
  268. if (rv != OK) {
  269. spdy_stream_request_.reset();
  270. std::move(callback_).Run(rv);
  271. return;
  272. }
  273. stream_ = spdy_stream_request_->ReleaseStream();
  274. spdy_stream_request_.reset();
  275. stream_adapter_ =
  276. std::make_unique<WebSocketSpdyStreamAdapter>(stream_, this, net_log_);
  277. rv = stream_->SendRequestHeaders(std::move(http2_request_headers_),
  278. MORE_DATA_TO_SEND);
  279. // SendRequestHeaders() always returns asynchronously,
  280. // and instead of taking a callback, it calls OnHeadersSent().
  281. DCHECK_EQ(ERR_IO_PENDING, rv);
  282. }
  283. int WebSocketHttp2HandshakeStream::ValidateResponse() {
  284. DCHECK(http_response_info_);
  285. const HttpResponseHeaders* headers = http_response_info_->headers.get();
  286. const int response_code = headers->response_code();
  287. switch (response_code) {
  288. case HTTP_OK:
  289. return ValidateUpgradeResponse(headers);
  290. // We need to pass these through for authentication to work.
  291. case HTTP_UNAUTHORIZED:
  292. case HTTP_PROXY_AUTHENTICATION_REQUIRED:
  293. return OK;
  294. // Other status codes are potentially risky (see the warnings in the
  295. // WHATWG WebSocket API spec) and so are dropped by default.
  296. default:
  297. OnFailure(
  298. base::StringPrintf(
  299. "Error during WebSocket handshake: Unexpected response code: %d",
  300. headers->response_code()),
  301. ERR_FAILED, headers->response_code());
  302. result_ = HandshakeResult::HTTP2_INVALID_STATUS;
  303. return ERR_INVALID_RESPONSE;
  304. }
  305. }
  306. int WebSocketHttp2HandshakeStream::ValidateUpgradeResponse(
  307. const HttpResponseHeaders* headers) {
  308. extension_params_ = std::make_unique<WebSocketExtensionParams>();
  309. std::string failure_message;
  310. if (!ValidateStatus(headers)) {
  311. result_ = HandshakeResult::HTTP2_INVALID_STATUS;
  312. } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
  313. &sub_protocol_, &failure_message)) {
  314. result_ = HandshakeResult::HTTP2_FAILED_SUBPROTO;
  315. } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
  316. extension_params_.get())) {
  317. result_ = HandshakeResult::HTTP2_FAILED_EXTENSIONS;
  318. } else {
  319. result_ = HandshakeResult::HTTP2_CONNECTED;
  320. return OK;
  321. }
  322. const int rv = ERR_INVALID_RESPONSE;
  323. OnFailure("Error during WebSocket handshake: " + failure_message, rv,
  324. absl::nullopt);
  325. return rv;
  326. }
  327. void WebSocketHttp2HandshakeStream::OnFailure(
  328. const std::string& message,
  329. int net_error,
  330. absl::optional<int> response_code) {
  331. stream_request_->OnFailure(message, net_error, response_code);
  332. }
  333. } // namespace net