web_socket.cc 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/server/web_socket.h"
  5. #include <vector>
  6. #include "base/base64.h"
  7. #include "base/check.h"
  8. #include "base/hash/sha1.h"
  9. #include "base/strings/string_number_conversions.h"
  10. #include "base/strings/stringprintf.h"
  11. #include "base/sys_byteorder.h"
  12. #include "net/server/http_connection.h"
  13. #include "net/server/http_server.h"
  14. #include "net/server/http_server_request_info.h"
  15. #include "net/server/http_server_response_info.h"
  16. #include "net/server/web_socket_encoder.h"
  17. #include "net/websockets/websocket_deflate_parameters.h"
  18. #include "net/websockets/websocket_extension.h"
  19. #include "net/websockets/websocket_handshake_constants.h"
  20. namespace net {
  21. namespace {
  22. std::string ExtensionsHeaderString(
  23. const std::vector<WebSocketExtension>& extensions) {
  24. if (extensions.empty())
  25. return std::string();
  26. std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString();
  27. for (size_t i = 1; i < extensions.size(); ++i)
  28. result += ", " + extensions[i].ToString();
  29. return result + "\r\n";
  30. }
  31. std::string ValidResponseString(
  32. const std::string& accept_hash,
  33. const std::vector<WebSocketExtension> extensions) {
  34. return base::StringPrintf(
  35. "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
  36. "Upgrade: WebSocket\r\n"
  37. "Connection: Upgrade\r\n"
  38. "Sec-WebSocket-Accept: %s\r\n"
  39. "%s"
  40. "\r\n",
  41. accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str());
  42. }
  43. } // namespace
  44. WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
  45. : server_(server), connection_(connection) {}
  46. WebSocket::~WebSocket() = default;
  47. void WebSocket::Accept(const HttpServerRequestInfo& request,
  48. const NetworkTrafficAnnotationTag traffic_annotation) {
  49. std::string version = request.GetHeaderValue("sec-websocket-version");
  50. if (version != "8" && version != "13") {
  51. SendErrorResponse("Invalid request format. The version is not valid.",
  52. traffic_annotation);
  53. return;
  54. }
  55. std::string key = request.GetHeaderValue("sec-websocket-key");
  56. if (key.empty()) {
  57. SendErrorResponse(
  58. "Invalid request format. Sec-WebSocket-Key is empty or isn't "
  59. "specified.",
  60. traffic_annotation);
  61. return;
  62. }
  63. std::string encoded_hash;
  64. base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid),
  65. &encoded_hash);
  66. std::vector<WebSocketExtension> response_extensions;
  67. auto i = request.headers.find("sec-websocket-extensions");
  68. if (i == request.headers.end()) {
  69. encoder_ = WebSocketEncoder::CreateServer();
  70. } else {
  71. WebSocketDeflateParameters params;
  72. encoder_ = WebSocketEncoder::CreateServer(i->second, &params);
  73. if (!encoder_) {
  74. Fail();
  75. return;
  76. }
  77. if (encoder_->deflate_enabled()) {
  78. DCHECK(params.IsValidAsResponse());
  79. response_extensions.push_back(params.AsExtension());
  80. }
  81. }
  82. server_->SendRaw(connection_->id(),
  83. ValidResponseString(encoded_hash, response_extensions),
  84. traffic_annotation);
  85. traffic_annotation_ = std::make_unique<NetworkTrafficAnnotationTag>(
  86. NetworkTrafficAnnotationTag(traffic_annotation));
  87. }
  88. WebSocket::ParseResult WebSocket::Read(std::string* message) {
  89. if (closed_)
  90. return FRAME_CLOSE;
  91. if (!encoder_) {
  92. // RFC6455, section 4.1 says "Once the client's opening handshake has been
  93. // sent, the client MUST wait for a response from the server before sending
  94. // any further data". If |encoder_| is null here, ::Accept either has not
  95. // been called at all, or has rejected a request rather than producing
  96. // a server handshake. Either way, the client clearly couldn't have gotten
  97. // a proper server handshake, so error out, especially since this method
  98. // can't proceed without an |encoder_|.
  99. return FRAME_ERROR;
  100. }
  101. ParseResult result = FRAME_OK_MIDDLE;
  102. HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
  103. base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
  104. int bytes_consumed = 0;
  105. result = encoder_->DecodeFrame(frame, &bytes_consumed, message);
  106. read_buf->DidConsume(bytes_consumed);
  107. if (result == FRAME_CLOSE)
  108. closed_ = true;
  109. if (result == FRAME_PING) {
  110. if (!traffic_annotation_)
  111. return FRAME_ERROR;
  112. Send(*message, WebSocketFrameHeader::kOpCodePong, *traffic_annotation_);
  113. }
  114. return result;
  115. }
  116. void WebSocket::Send(base::StringPiece message,
  117. WebSocketFrameHeader::OpCodeEnum op_code,
  118. const NetworkTrafficAnnotationTag traffic_annotation) {
  119. if (closed_)
  120. return;
  121. std::string encoded;
  122. switch (op_code) {
  123. case WebSocketFrameHeader::kOpCodeText:
  124. encoder_->EncodeTextFrame(message, 0, &encoded);
  125. break;
  126. case WebSocketFrameHeader::kOpCodePong:
  127. encoder_->EncodePongFrame(message, 0, &encoded);
  128. break;
  129. default:
  130. // Only Pong and Text frame types are supported.
  131. NOTREACHED();
  132. }
  133. server_->SendRaw(connection_->id(), encoded, traffic_annotation);
  134. }
  135. void WebSocket::Fail() {
  136. closed_ = true;
  137. // TODO(yhirano): The server SHOULD log the problem.
  138. server_->Close(connection_->id());
  139. }
  140. void WebSocket::SendErrorResponse(
  141. const std::string& message,
  142. const NetworkTrafficAnnotationTag traffic_annotation) {
  143. if (closed_)
  144. return;
  145. closed_ = true;
  146. server_->Send500(connection_->id(), message, traffic_annotation);
  147. }
  148. } // namespace net