websocket_deflate_stream.cc 14 KB


  1. // Copyright 2013 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/websockets/websocket_deflate_stream.h"
  5. #include <stdint.h>
  6. #include <algorithm>
  7. #include <string>
  8. #include <utility>
  9. #include <vector>
  10. #include "base/bind.h"
  11. #include "base/logging.h"
  12. #include "base/memory/scoped_refptr.h"
  13. #include "net/base/io_buffer.h"
  14. #include "net/base/net_errors.h"
  15. #include "net/websockets/websocket_deflate_parameters.h"
  16. #include "net/websockets/websocket_deflate_predictor.h"
  17. #include "net/websockets/websocket_deflater.h"
  18. #include "net/websockets/websocket_errors.h"
  19. #include "net/websockets/websocket_frame.h"
  20. #include "net/websockets/websocket_inflater.h"
  21. #include "net/websockets/websocket_stream.h"
  22. class GURL;
  23. namespace net {
  namespace {

  // Window bits passed to inflater_.Initialize() in the constructor.
  constexpr int kWindowBits = 15;

  // Size in bytes (4 KiB) used for both inflater constructor arguments and as
  // the output threshold at which the deflate/inflate paths emit a frame.
  constexpr size_t kChunkSize = 4 * 1024;

  }  // namespace
  28. WebSocketDeflateStream::WebSocketDeflateStream(
  29. std::unique_ptr<WebSocketStream> stream,
  30. const WebSocketDeflateParameters& params,
  31. std::unique_ptr<WebSocketDeflatePredictor> predictor)
  32. : stream_(std::move(stream)),
  33. deflater_(params.client_context_take_over_mode()),
  34. inflater_(kChunkSize, kChunkSize),
  35. predictor_(std::move(predictor)) {
  36. DCHECK(stream_);
  37. DCHECK(params.IsValidAsResponse());
  38. int client_max_window_bits = 15;
  39. if (params.is_client_max_window_bits_specified()) {
  40. DCHECK(params.has_client_max_window_bits_value());
  41. client_max_window_bits = params.client_max_window_bits();
  42. }
  43. deflater_.Initialize(client_max_window_bits);
  44. inflater_.Initialize(kWindowBits);
  45. }
  46. WebSocketDeflateStream::~WebSocketDeflateStream() = default;
  47. int WebSocketDeflateStream::ReadFrames(
  48. std::vector<std::unique_ptr<WebSocketFrame>>* frames,
  49. CompletionOnceCallback callback) {
  50. read_callback_ = std::move(callback);
  51. inflater_outputs_.clear();
  52. int result = stream_->ReadFrames(
  53. frames, base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
  54. base::Unretained(this), base::Unretained(frames)));
  55. if (result < 0)
  56. return result;
  57. DCHECK_EQ(OK, result);
  58. DCHECK(!frames->empty());
  59. return InflateAndReadIfNecessary(frames);
  60. }
  61. int WebSocketDeflateStream::WriteFrames(
  62. std::vector<std::unique_ptr<WebSocketFrame>>* frames,
  63. CompletionOnceCallback callback) {
  64. deflater_outputs_.clear();
  65. int result = Deflate(frames);
  66. if (result != OK)
  67. return result;
  68. if (frames->empty())
  69. return OK;
  70. return stream_->WriteFrames(frames, std::move(callback));
  71. }
  72. void WebSocketDeflateStream::Close() { stream_->Close(); }
  73. std::string WebSocketDeflateStream::GetSubProtocol() const {
  74. return stream_->GetSubProtocol();
  75. }
  76. std::string WebSocketDeflateStream::GetExtensions() const {
  77. return stream_->GetExtensions();
  78. }
  79. const NetLogWithSource& WebSocketDeflateStream::GetNetLogWithSource() const {
  80. return stream_->GetNetLogWithSource();
  81. }
  82. void WebSocketDeflateStream::OnReadComplete(
  83. std::vector<std::unique_ptr<WebSocketFrame>>* frames,
  84. int result) {
  85. if (result != OK) {
  86. frames->clear();
  87. std::move(read_callback_).Run(result);
  88. return;
  89. }
  90. int r = InflateAndReadIfNecessary(frames);
  91. if (r != ERR_IO_PENDING)
  92. std::move(read_callback_).Run(r);
  93. }
  94. int WebSocketDeflateStream::Deflate(
  95. std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
  96. std::vector<std::unique_ptr<WebSocketFrame>> frames_to_write;
  97. // Store frames of the currently processed message if writing_state_ equals to
  98. // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
  99. std::vector<std::unique_ptr<WebSocketFrame>> frames_of_message;
  100. for (size_t i = 0; i < frames->size(); ++i) {
  101. DCHECK(!(*frames)[i]->header.reserved1);
  102. if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
  103. frames_to_write.push_back(std::move((*frames)[i]));
  104. continue;
  105. }
  106. if (writing_state_ == NOT_WRITING)
  107. OnMessageStart(*frames, i);
  108. std::unique_ptr<WebSocketFrame> frame(std::move((*frames)[i]));
  109. predictor_->RecordInputDataFrame(frame.get());
  110. if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
  111. if (frame->header.final)
  112. writing_state_ = NOT_WRITING;
  113. predictor_->RecordWrittenDataFrame(frame.get());
  114. frames_to_write.push_back(std::move(frame));
  115. current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
  116. } else {
  117. if (frame->payload &&
  118. !deflater_.AddBytes(
  119. frame->payload,
  120. static_cast<size_t>(frame->header.payload_length))) {
  121. DVLOG(1) << "WebSocket protocol error. "
  122. << "deflater_.AddBytes() returns an error.";
  123. return ERR_WS_PROTOCOL_ERROR;
  124. }
  125. if (frame->header.final && !deflater_.Finish()) {
  126. DVLOG(1) << "WebSocket protocol error. "
  127. << "deflater_.Finish() returns an error.";
  128. return ERR_WS_PROTOCOL_ERROR;
  129. }
  130. if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
  131. if (deflater_.CurrentOutputSize() >= kChunkSize ||
  132. frame->header.final) {
  133. int result = AppendCompressedFrame(frame->header, &frames_to_write);
  134. if (result != OK)
  135. return result;
  136. }
  137. if (frame->header.final)
  138. writing_state_ = NOT_WRITING;
  139. } else {
  140. DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
  141. bool final = frame->header.final;
  142. frames_of_message.push_back(std::move(frame));
  143. if (final) {
  144. int result = AppendPossiblyCompressedMessage(&frames_of_message,
  145. &frames_to_write);
  146. if (result != OK)
  147. return result;
  148. frames_of_message.clear();
  149. writing_state_ = NOT_WRITING;
  150. }
  151. }
  152. }
  153. }
  154. DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
  155. frames->swap(frames_to_write);
  156. return OK;
  157. }
  158. void WebSocketDeflateStream::OnMessageStart(
  159. const std::vector<std::unique_ptr<WebSocketFrame>>& frames,
  160. size_t index) {
  161. WebSocketFrame* frame = frames[index].get();
  162. current_writing_opcode_ = frame->header.opcode;
  163. DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
  164. current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
  165. WebSocketDeflatePredictor::Result prediction =
  166. predictor_->Predict(frames, index);
  167. switch (prediction) {
  168. case WebSocketDeflatePredictor::DEFLATE:
  169. writing_state_ = WRITING_COMPRESSED_MESSAGE;
  170. return;
  171. case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
  172. writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
  173. return;
  174. case WebSocketDeflatePredictor::TRY_DEFLATE:
  175. writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
  176. return;
  177. }
  178. NOTREACHED();
  179. }
  180. int WebSocketDeflateStream::AppendCompressedFrame(
  181. const WebSocketFrameHeader& header,
  182. std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
  183. const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
  184. scoped_refptr<IOBufferWithSize> compressed_payload =
  185. deflater_.GetOutput(deflater_.CurrentOutputSize());
  186. if (!compressed_payload.get()) {
  187. DVLOG(1) << "WebSocket protocol error. "
  188. << "deflater_.GetOutput() returns an error.";
  189. return ERR_WS_PROTOCOL_ERROR;
  190. }
  191. deflater_outputs_.push_back(compressed_payload);
  192. auto compressed = std::make_unique<WebSocketFrame>(opcode);
  193. compressed->header.CopyFrom(header);
  194. compressed->header.opcode = opcode;
  195. compressed->header.final = header.final;
  196. compressed->header.reserved1 =
  197. (opcode != WebSocketFrameHeader::kOpCodeContinuation);
  198. compressed->payload = compressed_payload->data();
  199. compressed->header.payload_length = compressed_payload->size();
  200. current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
  201. predictor_->RecordWrittenDataFrame(compressed.get());
  202. frames_to_write->push_back(std::move(compressed));
  203. return OK;
  204. }
  205. int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
  206. std::vector<std::unique_ptr<WebSocketFrame>>* frames,
  207. std::vector<std::unique_ptr<WebSocketFrame>>* frames_to_write) {
  208. DCHECK(!frames->empty());
  209. const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
  210. scoped_refptr<IOBufferWithSize> compressed_payload =
  211. deflater_.GetOutput(deflater_.CurrentOutputSize());
  212. if (!compressed_payload.get()) {
  213. DVLOG(1) << "WebSocket protocol error. "
  214. << "deflater_.GetOutput() returns an error.";
  215. return ERR_WS_PROTOCOL_ERROR;
  216. }
  217. deflater_outputs_.push_back(compressed_payload);
  218. uint64_t original_payload_length = 0;
  219. for (size_t i = 0; i < frames->size(); ++i) {
  220. WebSocketFrame* frame = (*frames)[i].get();
  221. // Asserts checking that frames represent one whole data message.
  222. DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
  223. DCHECK_EQ(i == 0,
  224. WebSocketFrameHeader::kOpCodeContinuation !=
  225. frame->header.opcode);
  226. DCHECK_EQ(i == frames->size() - 1, frame->header.final);
  227. original_payload_length += frame->header.payload_length;
  228. }
  229. if (original_payload_length <=
  230. static_cast<uint64_t>(compressed_payload->size())) {
  231. // Compression is not effective. Use the original frames.
  232. for (auto& frame : *frames) {
  233. predictor_->RecordWrittenDataFrame(frame.get());
  234. frames_to_write->push_back(std::move(frame));
  235. }
  236. frames->clear();
  237. return OK;
  238. }
  239. auto compressed = std::make_unique<WebSocketFrame>(opcode);
  240. compressed->header.CopyFrom((*frames)[0]->header);
  241. compressed->header.opcode = opcode;
  242. compressed->header.final = true;
  243. compressed->header.reserved1 = true;
  244. compressed->payload = compressed_payload->data();
  245. compressed->header.payload_length = compressed_payload->size();
  246. predictor_->RecordWrittenDataFrame(compressed.get());
  247. frames_to_write->push_back(std::move(compressed));
  248. return OK;
  249. }
  250. int WebSocketDeflateStream::Inflate(
  251. std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
  252. std::vector<std::unique_ptr<WebSocketFrame>> frames_to_output;
  253. std::vector<std::unique_ptr<WebSocketFrame>> frames_passed;
  254. frames->swap(frames_passed);
  255. for (auto& frame_passed : frames_passed) {
  256. std::unique_ptr<WebSocketFrame> frame(std::move(frame_passed));
  257. frame_passed = nullptr;
  258. DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
  259. << " final=" << frame->header.final
  260. << " reserved1=" << frame->header.reserved1
  261. << " payload_length=" << frame->header.payload_length;
  262. if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
  263. frames_to_output.push_back(std::move(frame));
  264. continue;
  265. }
  266. if (reading_state_ == NOT_READING) {
  267. if (frame->header.reserved1)
  268. reading_state_ = READING_COMPRESSED_MESSAGE;
  269. else
  270. reading_state_ = READING_UNCOMPRESSED_MESSAGE;
  271. current_reading_opcode_ = frame->header.opcode;
  272. } else {
  273. if (frame->header.reserved1) {
  274. DVLOG(1) << "WebSocket protocol error. "
  275. << "Receiving a non-first frame with RSV1 flag set.";
  276. return ERR_WS_PROTOCOL_ERROR;
  277. }
  278. }
  279. if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
  280. if (frame->header.final)
  281. reading_state_ = NOT_READING;
  282. current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
  283. frames_to_output.push_back(std::move(frame));
  284. } else {
  285. DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
  286. if (frame->payload &&
  287. !inflater_.AddBytes(
  288. frame->payload,
  289. static_cast<size_t>(frame->header.payload_length))) {
  290. DVLOG(1) << "WebSocket protocol error. "
  291. << "inflater_.AddBytes() returns an error.";
  292. return ERR_WS_PROTOCOL_ERROR;
  293. }
  294. if (frame->header.final) {
  295. if (!inflater_.Finish()) {
  296. DVLOG(1) << "WebSocket protocol error. "
  297. << "inflater_.Finish() returns an error.";
  298. return ERR_WS_PROTOCOL_ERROR;
  299. }
  300. }
  301. // TODO(yhirano): Many frames can be generated by the inflater and
  302. // memory consumption can grow.
  303. // We could avoid it, but avoiding it makes this class much more
  304. // complicated.
  305. while (inflater_.CurrentOutputSize() >= kChunkSize ||
  306. frame->header.final) {
  307. size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
  308. auto inflated =
  309. std::make_unique<WebSocketFrame>(WebSocketFrameHeader::kOpCodeText);
  310. scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
  311. inflater_outputs_.push_back(data);
  312. bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
  313. if (!data.get()) {
  314. DVLOG(1) << "WebSocket protocol error. "
  315. << "inflater_.GetOutput() returns an error.";
  316. return ERR_WS_PROTOCOL_ERROR;
  317. }
  318. inflated->header.CopyFrom(frame->header);
  319. inflated->header.opcode = current_reading_opcode_;
  320. inflated->header.final = is_final;
  321. inflated->header.reserved1 = false;
  322. inflated->payload = data->data();
  323. inflated->header.payload_length = data->size();
  324. DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
  325. << " final=" << inflated->header.final
  326. << " reserved1=" << inflated->header.reserved1
  327. << " payload_length=" << inflated->header.payload_length;
  328. frames_to_output.push_back(std::move(inflated));
  329. current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
  330. if (is_final)
  331. break;
  332. }
  333. if (frame->header.final)
  334. reading_state_ = NOT_READING;
  335. }
  336. }
  337. frames->swap(frames_to_output);
  338. return frames->empty() ? ERR_IO_PENDING : OK;
  339. }
  340. int WebSocketDeflateStream::InflateAndReadIfNecessary(
  341. std::vector<std::unique_ptr<WebSocketFrame>>* frames) {
  342. int result = Inflate(frames);
  343. while (result == ERR_IO_PENDING) {
  344. DCHECK(frames->empty());
  345. result = stream_->ReadFrames(
  346. frames,
  347. base::BindOnce(&WebSocketDeflateStream::OnReadComplete,
  348. base::Unretained(this), base::Unretained(frames)));
  349. if (result < 0)
  350. break;
  351. DCHECK_EQ(OK, result);
  352. DCHECK(!frames->empty());
  353. result = Inflate(frames);
  354. }
  355. if (result < 0)
  356. frames->clear();
  357. return result;
  358. }
  359. } // namespace net