cast_transport.cc 13 KB


  1. // Copyright 2014 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/cast_channel/cast_transport.h"
  5. #include <stddef.h>
  6. #include <stdint.h>
  7. #include <memory>
  8. #include <string>
  9. #include <utility>
  10. #include "base/bind.h"
  11. #include "base/format_macros.h"
  12. #include "base/location.h"
  13. #include "base/numerics/safe_conversions.h"
  14. #include "base/task/single_thread_task_runner.h"
  15. #include "base/threading/thread_task_runner_handle.h"
  16. #include "components/cast_channel/cast_framer.h"
  17. #include "components/cast_channel/cast_message_util.h"
  18. #include "components/cast_channel/logger.h"
  19. #include "net/base/net_errors.h"
  20. #include "third_party/openscreen/src/cast/common/channel/proto/cast_channel.pb.h"
  // Starts a VLOG statement at |level| whose text is prefixed with this
  // transport's remote endpoint. The expansion reads |ip_endpoint_|, so the
  // macro can only be used where that name is in scope. Callers append the
  // rest of the message with <<.
  #define VLOG_WITH_CONNECTION(level) \
    VLOG(level) << "[" << ip_endpoint_.ToString() << ", auth=SSL_VERIFIED] "
  23. namespace cast_channel {
  24. namespace {
  25. #if DCHECK_IS_ON()
  26. // Used to filter out PING and PONG message from logs, since there are a lot of
  27. // them and they're not interesting.
  28. bool IsPingPong(const CastMessage& message) {
  29. return message.has_payload_utf8() &&
  30. (message.payload_utf8() == R"({"type":"PING"})" ||
  31. message.payload_utf8() == R"({"type":"PONG"})");
  32. }
  33. #endif // DCHECK_IS_ON()
  34. } // namespace
  35. CastTransportImpl::CastTransportImpl(Channel* channel,
  36. int channel_id,
  37. const net::IPEndPoint& ip_endpoint,
  38. scoped_refptr<Logger> logger)
  39. : started_(false),
  40. channel_(channel),
  41. write_state_(WriteState::IDLE),
  42. read_state_(ReadState::READ),
  43. error_state_(ChannelError::NONE),
  44. channel_id_(channel_id),
  45. ip_endpoint_(ip_endpoint),
  46. logger_(logger) {
  47. // Buffer is reused across messages to minimize unnecessary buffer
  48. // [re]allocations.
  49. read_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>();
  50. read_buffer_->SetCapacity(MessageFramer::MessageHeader::max_message_size());
  51. framer_ = std::make_unique<MessageFramer>(read_buffer_);
  52. }
  53. CastTransportImpl::~CastTransportImpl() {
  54. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  55. FlushWriteQueue();
  56. }
  57. bool CastTransportImpl::IsTerminalWriteState(WriteState write_state) {
  58. return write_state == WriteState::WRITE_ERROR ||
  59. write_state == WriteState::IDLE;
  60. }
  61. bool CastTransportImpl::IsTerminalReadState(ReadState read_state) {
  62. return read_state == ReadState::READ_ERROR;
  63. }
  64. void CastTransportImpl::SetReadDelegate(std::unique_ptr<Delegate> delegate) {
  65. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  66. DCHECK(delegate);
  67. delegate_ = std::move(delegate);
  68. if (started_) {
  69. delegate_->Start();
  70. }
  71. }
  72. void CastTransportImpl::FlushWriteQueue() {
  73. for (; !write_queue_.empty(); write_queue_.pop()) {
  74. base::ThreadTaskRunnerHandle::Get()->PostTask(
  75. FROM_HERE, base::BindOnce(std::move(write_queue_.front().callback),
  76. net::ERR_FAILED));
  77. }
  78. }
  79. void CastTransportImpl::SendMessage(const CastMessage& message,
  80. net::CompletionOnceCallback callback) {
  81. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  82. DCHECK(IsCastMessageValid(message));
  83. DVLOG_IF(1, !IsPingPong(message)) << "Sending: " << message;
  84. std::string serialized_message;
  85. if (!MessageFramer::Serialize(message, &serialized_message)) {
  86. base::ThreadTaskRunnerHandle::Get()->PostTask(
  87. FROM_HERE, base::BindOnce(std::move(callback), net::ERR_FAILED));
  88. return;
  89. }
  90. write_queue_.emplace(message.namespace_(), serialized_message,
  91. std::move(callback));
  92. if (write_state_ == WriteState::IDLE) {
  93. SetWriteState(WriteState::WRITE);
  94. OnWriteResult(net::OK);
  95. }
  96. }
  97. CastTransportImpl::WriteRequest::WriteRequest(
  98. const std::string& namespace_,
  99. const std::string& payload,
  100. net::CompletionOnceCallback callback)
  101. : message_namespace(namespace_), callback(std::move(callback)) {
  102. VLOG(2) << "WriteRequest size: " << payload.size();
  103. io_buffer = base::MakeRefCounted<net::DrainableIOBuffer>(
  104. base::MakeRefCounted<net::StringIOBuffer>(payload), payload.size());
  105. }
  106. CastTransportImpl::WriteRequest::WriteRequest(WriteRequest&& other) = default;
  107. CastTransportImpl::WriteRequest::~WriteRequest() {}
  108. void CastTransportImpl::SetReadState(ReadState read_state) {
  109. if (read_state_ != read_state)
  110. read_state_ = read_state;
  111. }
  112. void CastTransportImpl::SetWriteState(WriteState write_state) {
  113. if (write_state_ != write_state)
  114. write_state_ = write_state;
  115. }
  116. void CastTransportImpl::SetErrorState(ChannelError error_state) {
  117. VLOG_WITH_CONNECTION(2) << "SetErrorState: "
  118. << ::cast_channel::ChannelErrorToString(error_state);
  119. error_state_ = error_state;
  120. }
  121. void CastTransportImpl::OnWriteResult(int result) {
  122. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  123. DCHECK_NE(WriteState::IDLE, write_state_);
  124. if (write_queue_.empty()) {
  125. SetWriteState(WriteState::IDLE);
  126. return;
  127. }
  128. // Network operations can either finish synchronously or asynchronously.
  129. // This method executes the state machine transitions in a loop so that
  130. // write state transitions happen even when network operations finish
  131. // synchronously.
  132. int rv = result;
  133. do {
  134. VLOG_WITH_CONNECTION(2)
  135. << "OnWriteResult (state=" << AsInteger(write_state_) << ", "
  136. << "result=" << rv << ", "
  137. << "queue size=" << write_queue_.size() << ")";
  138. WriteState state = write_state_;
  139. write_state_ = WriteState::UNKNOWN;
  140. switch (state) {
  141. case WriteState::WRITE:
  142. rv = DoWrite();
  143. break;
  144. case WriteState::WRITE_COMPLETE:
  145. rv = DoWriteComplete(rv);
  146. break;
  147. case WriteState::DO_CALLBACK:
  148. rv = DoWriteCallback();
  149. break;
  150. case WriteState::HANDLE_ERROR:
  151. rv = DoWriteHandleError(rv);
  152. DCHECK_EQ(WriteState::WRITE_ERROR, write_state_);
  153. break;
  154. default:
  155. NOTREACHED() << "Unknown state in write state machine: "
  156. << AsInteger(state);
  157. SetWriteState(WriteState::WRITE_ERROR);
  158. SetErrorState(ChannelError::UNKNOWN);
  159. rv = net::ERR_FAILED;
  160. break;
  161. }
  162. } while (rv != net::ERR_IO_PENDING && !IsTerminalWriteState(write_state_));
  163. if (write_state_ == WriteState::WRITE_ERROR) {
  164. FlushWriteQueue();
  165. DCHECK_NE(ChannelError::NONE, error_state_);
  166. VLOG_WITH_CONNECTION(2) << "Sending OnError().";
  167. delegate_->OnError(error_state_);
  168. }
  169. }
  170. int CastTransportImpl::DoWrite() {
  171. DCHECK(!write_queue_.empty());
  172. net::DrainableIOBuffer* io_buffer = write_queue_.front().io_buffer.get();
  173. VLOG_WITH_CONNECTION(2) << "WriteData byte_count = " << io_buffer->size()
  174. << " bytes_written " << io_buffer->BytesConsumed();
  175. SetWriteState(WriteState::WRITE_COMPLETE);
  176. // TODO(mfoltz): Improve APIs for CastTransportImpl::Channel::{Read|Write} so
  177. // that they don't expect raw pointers but handle movable parameters instead.
  178. channel_->Write(io_buffer, io_buffer->BytesRemaining(),
  179. base::BindOnce(&CastTransportImpl::OnWriteResult,
  180. base::Unretained(this)));
  181. return net::ERR_IO_PENDING;
  182. }
  183. int CastTransportImpl::DoWriteComplete(int result) {
  184. VLOG_WITH_CONNECTION(2) << "DoWriteComplete result=" << result;
  185. DCHECK(!write_queue_.empty());
  186. if (result <= 0) { // NOTE that 0 also indicates an error
  187. logger_->LogSocketEventWithRv(channel_id_, ChannelEvent::SOCKET_WRITE,
  188. result);
  189. SetErrorState(ChannelError::CAST_SOCKET_ERROR);
  190. SetWriteState(WriteState::HANDLE_ERROR);
  191. return result == 0 ? net::ERR_FAILED : result;
  192. }
  193. // Some bytes were successfully written
  194. net::DrainableIOBuffer* io_buffer = write_queue_.front().io_buffer.get();
  195. io_buffer->DidConsume(result);
  196. if (io_buffer->BytesRemaining() == 0) { // Message fully sent
  197. SetWriteState(WriteState::DO_CALLBACK);
  198. } else {
  199. SetWriteState(WriteState::WRITE);
  200. }
  201. return net::OK;
  202. }
  203. int CastTransportImpl::DoWriteCallback() {
  204. VLOG_WITH_CONNECTION(2) << "DoWriteCallback";
  205. DCHECK(!write_queue_.empty());
  206. base::ThreadTaskRunnerHandle::Get()->PostTask(
  207. FROM_HERE,
  208. base::BindOnce(std::move(write_queue_.front().callback), net::OK));
  209. write_queue_.pop();
  210. if (write_queue_.empty()) {
  211. SetWriteState(WriteState::IDLE);
  212. } else {
  213. SetWriteState(WriteState::WRITE);
  214. }
  215. return net::OK;
  216. }
  217. int CastTransportImpl::DoWriteHandleError(int result) {
  218. VLOG_WITH_CONNECTION(2) << "DoWriteHandleError result=" << result;
  219. DCHECK_NE(ChannelError::NONE, error_state_);
  220. DCHECK_LT(result, 0);
  221. SetWriteState(WriteState::WRITE_ERROR);
  222. return net::ERR_FAILED;
  223. }
  224. void CastTransportImpl::Start() {
  225. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  226. DCHECK(!started_);
  227. DCHECK_EQ(ReadState::READ, read_state_);
  228. DCHECK(delegate_) << "Read delegate must be set prior to calling Start()";
  229. started_ = true;
  230. delegate_->Start();
  231. SetReadState(ReadState::READ);
  232. // Start the read state machine.
  233. OnReadResult(net::OK);
  234. }
  235. void CastTransportImpl::OnReadResult(int result) {
  236. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  237. // Network operations can either finish synchronously or asynchronously.
  238. // This method executes the state machine transitions in a loop so that
  239. // write state transitions happen even when network operations finish
  240. // synchronously.
  241. int rv = result;
  242. do {
  243. VLOG_WITH_CONNECTION(2) << "OnReadResult(state=" << AsInteger(read_state_)
  244. << ", result=" << rv << ")";
  245. ReadState state = read_state_;
  246. read_state_ = ReadState::UNKNOWN;
  247. switch (state) {
  248. case ReadState::READ:
  249. rv = DoRead();
  250. break;
  251. case ReadState::READ_COMPLETE:
  252. rv = DoReadComplete(rv);
  253. break;
  254. case ReadState::DO_CALLBACK:
  255. rv = DoReadCallback();
  256. break;
  257. case ReadState::HANDLE_ERROR:
  258. rv = DoReadHandleError(rv);
  259. DCHECK_EQ(read_state_, ReadState::READ_ERROR);
  260. break;
  261. default:
  262. NOTREACHED() << "Unknown state in read state machine: "
  263. << AsInteger(state);
  264. SetReadState(ReadState::READ_ERROR);
  265. SetErrorState(ChannelError::UNKNOWN);
  266. rv = net::ERR_FAILED;
  267. break;
  268. }
  269. } while (rv != net::ERR_IO_PENDING && !IsTerminalReadState(read_state_));
  270. if (IsTerminalReadState(read_state_)) {
  271. DCHECK_EQ(ReadState::READ_ERROR, read_state_);
  272. VLOG_WITH_CONNECTION(2) << "Sending OnError().";
  273. delegate_->OnError(error_state_);
  274. }
  275. }
  276. int CastTransportImpl::DoRead() {
  277. VLOG_WITH_CONNECTION(2) << "DoRead";
  278. SetReadState(ReadState::READ_COMPLETE);
  279. // Determine how many bytes need to be read.
  280. size_t num_bytes_to_read = framer_->BytesRequested();
  281. DCHECK_GT(num_bytes_to_read, 0u);
  282. // Read up to num_bytes_to_read into |current_read_buffer_|.
  283. channel_->Read(
  284. read_buffer_.get(), base::checked_cast<uint32_t>(num_bytes_to_read),
  285. base::BindOnce(&CastTransportImpl::OnReadResult, base::Unretained(this)));
  286. return net::ERR_IO_PENDING;
  287. }
  288. int CastTransportImpl::DoReadComplete(int result) {
  289. VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
  290. if (result <= 0) {
  291. logger_->LogSocketEventWithRv(channel_id_, ChannelEvent::SOCKET_READ,
  292. result);
  293. VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket.";
  294. SetErrorState(ChannelError::CAST_SOCKET_ERROR);
  295. SetReadState(ReadState::HANDLE_ERROR);
  296. return result == 0 ? net::ERR_FAILED : result;
  297. }
  298. size_t message_size;
  299. DCHECK(!current_message_);
  300. ChannelError framing_error;
  301. current_message_ = framer_->Ingest(result, &message_size, &framing_error);
  302. if (current_message_.get() && (framing_error == ChannelError::NONE)) {
  303. DCHECK_GT(message_size, static_cast<size_t>(0));
  304. SetReadState(ReadState::DO_CALLBACK);
  305. } else if (framing_error != ChannelError::NONE) {
  306. DCHECK(!current_message_);
  307. SetErrorState(ChannelError::INVALID_MESSAGE);
  308. SetReadState(ReadState::HANDLE_ERROR);
  309. } else {
  310. DCHECK(!current_message_);
  311. SetReadState(ReadState::READ);
  312. }
  313. return net::OK;
  314. }
  315. int CastTransportImpl::DoReadCallback() {
  316. VLOG_WITH_CONNECTION(2) << "DoReadCallback";
  317. if (!IsCastMessageValid(*current_message_)) {
  318. SetReadState(ReadState::HANDLE_ERROR);
  319. SetErrorState(ChannelError::INVALID_MESSAGE);
  320. return net::ERR_INVALID_RESPONSE;
  321. }
  322. SetReadState(ReadState::READ);
  323. DVLOG_IF(1, !IsPingPong(*current_message_))
  324. << "Received: " << *current_message_;
  325. delegate_->OnMessage(*current_message_);
  326. current_message_.reset();
  327. return net::OK;
  328. }
  329. int CastTransportImpl::DoReadHandleError(int result) {
  330. VLOG_WITH_CONNECTION(2) << "DoReadHandleError";
  331. DCHECK_NE(ChannelError::NONE, error_state_);
  332. DCHECK_LE(result, 0);
  333. SetReadState(ReadState::READ_ERROR);
  334. return net::ERR_FAILED;
  335. }
  336. } // namespace cast_channel