websocket_frame_parser.cc 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/websockets/websocket_frame_parser.h"
  5. #include <algorithm>
  6. #include <limits>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/big_endian.h"
  10. #include "base/logging.h"
  11. #include "base/memory/scoped_refptr.h"
  12. #include "net/base/io_buffer.h"
  13. #include "net/websockets/websocket_frame.h"
  14. namespace {
  15. const uint8_t kFinalBit = 0x80;
  16. const uint8_t kReserved1Bit = 0x40;
  17. const uint8_t kReserved2Bit = 0x20;
  18. const uint8_t kReserved3Bit = 0x10;
  19. const uint8_t kOpCodeMask = 0xF;
  20. const uint8_t kMaskBit = 0x80;
  21. const uint8_t kPayloadLengthMask = 0x7F;
  22. const uint64_t kMaxPayloadLengthWithoutExtendedLengthField = 125;
  23. const uint64_t kPayloadLengthWithTwoByteExtendedLengthField = 126;
  24. const uint64_t kPayloadLengthWithEightByteExtendedLengthField = 127;
  25. const size_t kMaximumFrameHeaderSize =
  26. net::WebSocketFrameHeader::kBaseHeaderSize +
  27. net::WebSocketFrameHeader::kMaximumExtendedLengthSize +
  28. net::WebSocketFrameHeader::kMaskingKeyLength;
  29. } // namespace.
  30. namespace net {
  31. WebSocketFrameParser::WebSocketFrameParser() = default;
  32. WebSocketFrameParser::~WebSocketFrameParser() = default;
  33. bool WebSocketFrameParser::Decode(
  34. const char* data,
  35. size_t length,
  36. std::vector<std::unique_ptr<WebSocketFrameChunk>>* frame_chunks) {
  37. if (websocket_error_ != kWebSocketNormalClosure)
  38. return false;
  39. if (!length)
  40. return true;
  41. base::span<const char> data_span = base::make_span(data, length);
  42. // If we have incomplete frame header, try to decode a header combining with
  43. // |data|.
  44. bool first_chunk = false;
  45. if (incomplete_header_buffer_.size() > 0) {
  46. DCHECK(!current_frame_header_.get());
  47. const size_t original_size = incomplete_header_buffer_.size();
  48. DCHECK_LE(original_size, kMaximumFrameHeaderSize);
  49. incomplete_header_buffer_.insert(
  50. incomplete_header_buffer_.end(), data,
  51. data + std::min(length, kMaximumFrameHeaderSize - original_size));
  52. const size_t consumed = DecodeFrameHeader(incomplete_header_buffer_);
  53. if (websocket_error_ != kWebSocketNormalClosure)
  54. return false;
  55. if (!current_frame_header_.get())
  56. return true;
  57. DCHECK_GE(consumed, original_size);
  58. data_span = data_span.subspan(consumed - original_size);
  59. incomplete_header_buffer_.clear();
  60. first_chunk = true;
  61. }
  62. DCHECK(incomplete_header_buffer_.empty());
  63. while (data_span.size() > 0 || first_chunk) {
  64. if (!current_frame_header_.get()) {
  65. const size_t consumed = DecodeFrameHeader(data_span);
  66. if (websocket_error_ != kWebSocketNormalClosure)
  67. return false;
  68. // If frame header is incomplete, then carry over the remaining
  69. // data to the next round of Decode().
  70. if (!current_frame_header_.get()) {
  71. DCHECK(!consumed);
  72. incomplete_header_buffer_.insert(incomplete_header_buffer_.end(),
  73. data_span.data(),
  74. data_span.data() + data_span.size());
  75. // Sanity check: the size of carried-over data should not exceed
  76. // the maximum possible length of a frame header.
  77. DCHECK_LT(incomplete_header_buffer_.size(), kMaximumFrameHeaderSize);
  78. return true;
  79. }
  80. DCHECK_GE(data_span.size(), consumed);
  81. data_span = data_span.subspan(consumed);
  82. first_chunk = true;
  83. }
  84. DCHECK(incomplete_header_buffer_.empty());
  85. std::unique_ptr<WebSocketFrameChunk> frame_chunk =
  86. DecodeFramePayload(first_chunk, &data_span);
  87. first_chunk = false;
  88. DCHECK(frame_chunk.get());
  89. frame_chunks->push_back(std::move(frame_chunk));
  90. }
  91. return true;
  92. }
  93. size_t WebSocketFrameParser::DecodeFrameHeader(base::span<const char> data) {
  94. DVLOG(3) << "DecodeFrameHeader buffer size:"
  95. << ", data size:" << data.size();
  96. typedef WebSocketFrameHeader::OpCode OpCode;
  97. DCHECK(!current_frame_header_.get());
  98. // Header needs 2 bytes at minimum.
  99. if (data.size() < 2)
  100. return 0;
  101. size_t current = 0;
  102. const uint8_t first_byte = data[current++];
  103. const uint8_t second_byte = data[current++];
  104. const bool final = (first_byte & kFinalBit) != 0;
  105. const bool reserved1 = (first_byte & kReserved1Bit) != 0;
  106. const bool reserved2 = (first_byte & kReserved2Bit) != 0;
  107. const bool reserved3 = (first_byte & kReserved3Bit) != 0;
  108. const OpCode opcode = first_byte & kOpCodeMask;
  109. uint64_t payload_length = second_byte & kPayloadLengthMask;
  110. if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) {
  111. if (data.size() < current + 2)
  112. return 0;
  113. uint16_t payload_length_16;
  114. base::ReadBigEndian(reinterpret_cast<const uint8_t*>(&data[current]),
  115. &payload_length_16);
  116. current += 2;
  117. payload_length = payload_length_16;
  118. if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) {
  119. websocket_error_ = kWebSocketErrorProtocolError;
  120. return 0;
  121. }
  122. } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) {
  123. if (data.size() < current + 8)
  124. return 0;
  125. base::ReadBigEndian(reinterpret_cast<const uint8_t*>(&data[current]),
  126. &payload_length);
  127. current += 8;
  128. if (payload_length <= UINT16_MAX ||
  129. payload_length > static_cast<uint64_t>(INT64_MAX)) {
  130. websocket_error_ = kWebSocketErrorProtocolError;
  131. return 0;
  132. }
  133. if (payload_length > static_cast<uint64_t>(INT32_MAX)) {
  134. websocket_error_ = kWebSocketErrorMessageTooBig;
  135. return 0;
  136. }
  137. }
  138. DCHECK_EQ(websocket_error_, kWebSocketNormalClosure);
  139. WebSocketMaskingKey masking_key = {};
  140. const bool masked = (second_byte & kMaskBit) != 0;
  141. static const int kMaskingKeyLength = WebSocketFrameHeader::kMaskingKeyLength;
  142. if (masked) {
  143. if (data.size() < current + kMaskingKeyLength)
  144. return 0;
  145. std::copy(&data[current], &data[current] + kMaskingKeyLength,
  146. masking_key.key);
  147. current += kMaskingKeyLength;
  148. }
  149. current_frame_header_ = std::make_unique<WebSocketFrameHeader>(opcode);
  150. current_frame_header_->final = final;
  151. current_frame_header_->reserved1 = reserved1;
  152. current_frame_header_->reserved2 = reserved2;
  153. current_frame_header_->reserved3 = reserved3;
  154. current_frame_header_->masked = masked;
  155. current_frame_header_->masking_key = masking_key;
  156. current_frame_header_->payload_length = payload_length;
  157. DCHECK_EQ(0u, frame_offset_);
  158. return current;
  159. }
  160. std::unique_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload(
  161. bool first_chunk,
  162. base::span<const char>* data) {
  163. // The cast here is safe because |payload_length| is already checked to be
  164. // less than std::numeric_limits<int>::max() when the header is parsed.
  165. const int chunk_data_size = static_cast<int>(
  166. std::min(static_cast<uint64_t>(data->size()),
  167. current_frame_header_->payload_length - frame_offset_));
  168. auto frame_chunk = std::make_unique<WebSocketFrameChunk>();
  169. if (first_chunk) {
  170. frame_chunk->header = current_frame_header_->Clone();
  171. }
  172. frame_chunk->final_chunk = false;
  173. if (chunk_data_size > 0) {
  174. frame_chunk->payload = data->subspan(0, chunk_data_size);
  175. *data = data->subspan(chunk_data_size);
  176. frame_offset_ += chunk_data_size;
  177. }
  178. DCHECK_LE(frame_offset_, current_frame_header_->payload_length);
  179. if (frame_offset_ == current_frame_header_->payload_length) {
  180. frame_chunk->final_chunk = true;
  181. current_frame_header_.reset();
  182. frame_offset_ = 0;
  183. }
  184. return frame_chunk;
  185. }
  186. } // namespace net