channel_multiplexer.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "remoting/protocol/channel_multiplexer.h"

#include <stddef.h>
#include <string.h>

#include <algorithm>
#include <list>
#include <memory>
#include <string>
#include <utility>

#include "base/bind.h"
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/location.h"
#include "base/memory/raw_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h"
#include "net/base/net_errors.h"
#include "remoting/protocol/message_serialization.h"
#include "remoting/protocol/p2p_stream_socket.h"
  19. namespace remoting {
  20. namespace protocol {
  21. namespace {
  22. const int kChannelIdUnknown = -1;
  23. const int kMaxPacketSize = 1024;
  24. class PendingPacket {
  25. public:
  26. PendingPacket(std::unique_ptr<MultiplexPacket> packet)
  27. : packet(std::move(packet)) {}
  28. PendingPacket(const PendingPacket&) = delete;
  29. PendingPacket& operator=(const PendingPacket&) = delete;
  30. ~PendingPacket() = default;
  31. bool is_empty() { return pos >= packet->data().size(); }
  32. int Read(char* buffer, size_t size) {
  33. size = std::min(size, packet->data().size() - pos);
  34. memcpy(buffer, packet->data().data() + pos, size);
  35. pos += size;
  36. return size;
  37. }
  38. private:
  39. std::unique_ptr<MultiplexPacket> packet;
  40. size_t pos = 0U;
  41. };
  42. } // namespace
  43. const char ChannelMultiplexer::kMuxChannelName[] = "mux";
  44. struct ChannelMultiplexer::PendingChannel {
  45. PendingChannel(const std::string& name, ChannelCreatedCallback callback)
  46. : name(name), callback(std::move(callback)) {}
  47. std::string name;
  48. ChannelCreatedCallback callback;
  49. };
  50. class ChannelMultiplexer::MuxChannel {
  51. public:
  52. MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name,
  53. int send_id);
  54. MuxChannel(const MuxChannel&) = delete;
  55. MuxChannel& operator=(const MuxChannel&) = delete;
  56. ~MuxChannel();
  57. const std::string& name() { return name_; }
  58. int receive_id() { return receive_id_; }
  59. void set_receive_id(int id) { receive_id_ = id; }
  60. // Called by ChannelMultiplexer.
  61. std::unique_ptr<P2PStreamSocket> CreateSocket();
  62. void OnIncomingPacket(std::unique_ptr<MultiplexPacket> packet);
  63. void OnBaseChannelError(int error);
  64. // Called by MuxSocket.
  65. void OnSocketDestroyed();
  66. void DoWrite(std::unique_ptr<MultiplexPacket> packet,
  67. base::OnceClosure done_task,
  68. const net::NetworkTrafficAnnotationTag& traffic_annotation);
  69. int DoRead(const scoped_refptr<net::IOBuffer>& buffer, int buffer_len);
  70. private:
  71. raw_ptr<ChannelMultiplexer> multiplexer_;
  72. std::string name_;
  73. int send_id_;
  74. bool id_sent_;
  75. int receive_id_;
  76. raw_ptr<MuxSocket> socket_;
  77. std::list<std::unique_ptr<PendingPacket>> pending_packets_;
  78. };
  79. class ChannelMultiplexer::MuxSocket : public P2PStreamSocket {
  80. public:
  81. MuxSocket(MuxChannel* channel);
  82. MuxSocket(const MuxSocket&) = delete;
  83. MuxSocket& operator=(const MuxSocket&) = delete;
  84. ~MuxSocket() override;
  85. void OnWriteComplete();
  86. void OnBaseChannelError(int error);
  87. void OnPacketReceived();
  88. // P2PStreamSocket interface.
  89. int Read(const scoped_refptr<net::IOBuffer>& buffer,
  90. int buffer_len,
  91. net::CompletionOnceCallback callback) override;
  92. int Write(
  93. const scoped_refptr<net::IOBuffer>& buffer,
  94. int buffer_len,
  95. net::CompletionOnceCallback callback,
  96. const net::NetworkTrafficAnnotationTag& traffic_annotation) override;
  97. private:
  98. raw_ptr<MuxChannel> channel_;
  99. int base_channel_error_ = net::OK;
  100. net::CompletionOnceCallback read_callback_;
  101. scoped_refptr<net::IOBuffer> read_buffer_;
  102. int read_buffer_size_;
  103. bool write_pending_;
  104. int write_result_;
  105. net::CompletionOnceCallback write_callback_;
  106. SEQUENCE_CHECKER(sequence_checker_);
  107. base::WeakPtrFactory<MuxSocket> weak_factory_{this};
  108. };
  109. ChannelMultiplexer::MuxChannel::MuxChannel(ChannelMultiplexer* multiplexer,
  110. const std::string& name,
  111. int send_id)
  112. : multiplexer_(multiplexer),
  113. name_(name),
  114. send_id_(send_id),
  115. id_sent_(false),
  116. receive_id_(kChannelIdUnknown),
  117. socket_(nullptr) {}
  118. ChannelMultiplexer::MuxChannel::~MuxChannel() {
  119. // Socket must be destroyed before the channel.
  120. DCHECK(!socket_);
  121. }
  122. std::unique_ptr<P2PStreamSocket>
  123. ChannelMultiplexer::MuxChannel::CreateSocket() {
  124. DCHECK(!socket_); // Can't create more than one socket per channel.
  125. std::unique_ptr<MuxSocket> result(new MuxSocket(this));
  126. socket_ = result.get();
  127. return std::move(result);
  128. }
  129. void ChannelMultiplexer::MuxChannel::OnIncomingPacket(
  130. std::unique_ptr<MultiplexPacket> packet) {
  131. DCHECK_EQ(packet->channel_id(), receive_id_);
  132. if (packet->data().size() > 0) {
  133. pending_packets_.push_back(
  134. std::make_unique<PendingPacket>(std::move(packet)));
  135. if (socket_) {
  136. // Notify the socket that we have more data.
  137. socket_->OnPacketReceived();
  138. }
  139. }
  140. }
  141. void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
  142. if (socket_)
  143. socket_->OnBaseChannelError(error);
  144. }
  145. void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
  146. DCHECK(socket_);
  147. socket_ = nullptr;
  148. }
  149. void ChannelMultiplexer::MuxChannel::DoWrite(
  150. std::unique_ptr<MultiplexPacket> packet,
  151. base::OnceClosure done_task,
  152. const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  153. packet->set_channel_id(send_id_);
  154. if (!id_sent_) {
  155. packet->set_channel_name(name_);
  156. id_sent_ = true;
  157. }
  158. multiplexer_->DoWrite(std::move(packet), std::move(done_task),
  159. traffic_annotation);
  160. }
  161. int ChannelMultiplexer::MuxChannel::DoRead(
  162. const scoped_refptr<net::IOBuffer>& buffer,
  163. int buffer_len) {
  164. int pos = 0;
  165. while (buffer_len > 0 && !pending_packets_.empty()) {
  166. DCHECK(!pending_packets_.front()->is_empty());
  167. int result = pending_packets_.front()->Read(
  168. buffer->data() + pos, buffer_len);
  169. DCHECK_LE(result, buffer_len);
  170. pos += result;
  171. buffer_len -= pos;
  172. if (pending_packets_.front()->is_empty())
  173. pending_packets_.pop_front();
  174. }
  175. return pos;
  176. }
  177. ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel)
  178. : channel_(channel),
  179. read_buffer_size_(0),
  180. write_pending_(false),
  181. write_result_(0) {}
  182. ChannelMultiplexer::MuxSocket::~MuxSocket() {
  183. channel_->OnSocketDestroyed();
  184. }
  185. int ChannelMultiplexer::MuxSocket::Read(
  186. const scoped_refptr<net::IOBuffer>& buffer,
  187. int buffer_len,
  188. net::CompletionOnceCallback callback) {
  189. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  190. DCHECK(read_callback_.is_null());
  191. if (base_channel_error_ != net::OK)
  192. return base_channel_error_;
  193. int result = channel_->DoRead(buffer, buffer_len);
  194. if (result == 0) {
  195. read_buffer_ = buffer;
  196. read_buffer_size_ = buffer_len;
  197. read_callback_ = std::move(callback);
  198. return net::ERR_IO_PENDING;
  199. }
  200. return result;
  201. }
  202. int ChannelMultiplexer::MuxSocket::Write(
  203. const scoped_refptr<net::IOBuffer>& buffer,
  204. int buffer_len,
  205. net::CompletionOnceCallback callback,
  206. const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  207. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  208. DCHECK(write_callback_.is_null());
  209. if (base_channel_error_ != net::OK)
  210. return base_channel_error_;
  211. std::unique_ptr<MultiplexPacket> packet(new MultiplexPacket());
  212. size_t size = std::min(kMaxPacketSize, buffer_len);
  213. packet->mutable_data()->assign(buffer->data(), size);
  214. write_pending_ = true;
  215. channel_->DoWrite(
  216. std::move(packet),
  217. base::BindOnce(&ChannelMultiplexer::MuxSocket::OnWriteComplete,
  218. weak_factory_.GetWeakPtr()),
  219. traffic_annotation);
  220. // OnWriteComplete() might be called above synchronously.
  221. if (write_pending_) {
  222. DCHECK(write_callback_.is_null());
  223. write_callback_ = std::move(callback);
  224. write_result_ = size;
  225. return net::ERR_IO_PENDING;
  226. }
  227. return size;
  228. }
  229. void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
  230. write_pending_ = false;
  231. if (!write_callback_.is_null())
  232. std::move(write_callback_).Run(write_result_);
  233. }
  234. void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
  235. base_channel_error_ = error;
  236. // Here only one of the read and write callbacks is called if both of them are
  237. // pending. Ideally both of them should be called in that case, but that would
  238. // require the second one to be called asynchronously which would complicate
  239. // this code. Channels handle read and write errors the same way (see
  240. // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
  241. // callbacks is enough.
  242. if (!read_callback_.is_null()) {
  243. std::move(read_callback_).Run(error);
  244. return;
  245. }
  246. if (!write_callback_.is_null())
  247. std::move(write_callback_).Run(error);
  248. }
  249. void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
  250. if (!read_callback_.is_null()) {
  251. int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
  252. read_buffer_ = nullptr;
  253. DCHECK_GT(result, 0);
  254. std::move(read_callback_).Run(result);
  255. }
  256. }
  257. ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
  258. const std::string& base_channel_name)
  259. : base_channel_factory_(factory),
  260. base_channel_name_(base_channel_name),
  261. next_channel_id_(0) {}
  262. ChannelMultiplexer::~ChannelMultiplexer() {
  263. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  264. DCHECK(pending_channels_.empty());
  265. // Cancel creation of the base channel if it hasn't finished.
  266. if (base_channel_factory_)
  267. base_channel_factory_->CancelChannelCreation(base_channel_name_);
  268. }
  269. void ChannelMultiplexer::CreateChannel(const std::string& name,
  270. ChannelCreatedCallback callback) {
  271. if (base_channel_.get()) {
  272. // Already have |base_channel_|. Create new multiplexed channel
  273. // synchronously.
  274. std::move(callback).Run(GetOrCreateChannel(name)->CreateSocket());
  275. } else if (!base_channel_.get() && !base_channel_factory_) {
  276. // Fail synchronously if we failed to create |base_channel_|.
  277. std::move(callback).Run(nullptr);
  278. } else {
  279. // Still waiting for the |base_channel_|.
  280. pending_channels_.emplace_back(name, std::move(callback));
  281. // If this is the first multiplexed channel then create the base channel.
  282. if (pending_channels_.size() == 1U) {
  283. base_channel_factory_->CreateChannel(
  284. base_channel_name_,
  285. base::BindOnce(&ChannelMultiplexer::OnBaseChannelReady,
  286. base::Unretained(this)));
  287. }
  288. }
  289. }
  290. void ChannelMultiplexer::CancelChannelCreation(const std::string& name) {
  291. for (auto it = pending_channels_.begin(); it != pending_channels_.end();
  292. ++it) {
  293. if (it->name == name) {
  294. pending_channels_.erase(it);
  295. return;
  296. }
  297. }
  298. }
  299. void ChannelMultiplexer::OnBaseChannelReady(
  300. std::unique_ptr<P2PStreamSocket> socket) {
  301. base_channel_factory_ = nullptr;
  302. base_channel_ = std::move(socket);
  303. if (base_channel_.get()) {
  304. // Initialize reader and writer.
  305. reader_.StartReading(
  306. base_channel_.get(),
  307. base::BindRepeating(&ChannelMultiplexer::OnIncomingPacket,
  308. base::Unretained(this)),
  309. base::BindOnce(&ChannelMultiplexer::OnBaseChannelError,
  310. base::Unretained(this)));
  311. writer_.Start(base::BindRepeating(&P2PStreamSocket::Write,
  312. base::Unretained(base_channel_.get())),
  313. base::BindOnce(&ChannelMultiplexer::OnBaseChannelError,
  314. base::Unretained(this)));
  315. }
  316. DoCreatePendingChannels();
  317. }
  318. void ChannelMultiplexer::DoCreatePendingChannels() {
  319. if (pending_channels_.empty())
  320. return;
  321. // Every time this function is called it connects a single channel and posts a
  322. // separate task to connect other channels. This is necessary because the
  323. // callback may destroy the multiplexer or somehow else modify
  324. // |pending_channels_| list (e.g. call CancelChannelCreation()).
  325. base::ThreadTaskRunnerHandle::Get()->PostTask(
  326. FROM_HERE, base::BindOnce(&ChannelMultiplexer::DoCreatePendingChannels,
  327. weak_factory_.GetWeakPtr()));
  328. PendingChannel c = std::move(pending_channels_.front());
  329. pending_channels_.erase(pending_channels_.begin());
  330. std::unique_ptr<P2PStreamSocket> socket;
  331. if (base_channel_.get())
  332. socket = GetOrCreateChannel(c.name)->CreateSocket();
  333. std::move(c.callback).Run(std::move(socket));
  334. }
  335. ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel(
  336. const std::string& name) {
  337. std::unique_ptr<MuxChannel>& channel = channels_[name];
  338. if (!channel) {
  339. // Create a new channel if we haven't found existing one.
  340. channel = std::make_unique<MuxChannel>(this, name, next_channel_id_);
  341. ++next_channel_id_;
  342. }
  343. return channel.get();
  344. }
  345. void ChannelMultiplexer::OnBaseChannelError(int error) {
  346. for (auto it = channels_.begin(); it != channels_.end(); ++it) {
  347. base::ThreadTaskRunnerHandle::Get()->PostTask(
  348. FROM_HERE,
  349. base::BindOnce(&ChannelMultiplexer::NotifyBaseChannelError,
  350. weak_factory_.GetWeakPtr(), it->second->name(), error));
  351. }
  352. }
  353. void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
  354. int error) {
  355. auto it = channels_.find(name);
  356. if (it != channels_.end())
  357. it->second->OnBaseChannelError(error);
  358. }
  359. void ChannelMultiplexer::OnIncomingPacket(
  360. std::unique_ptr<CompoundBuffer> buffer) {
  361. std::unique_ptr<MultiplexPacket> packet =
  362. ParseMessage<MultiplexPacket>(buffer.get());
  363. if (!packet)
  364. return;
  365. DCHECK(packet->has_channel_id());
  366. if (!packet->has_channel_id()) {
  367. LOG(ERROR) << "Received packet without channel_id.";
  368. return;
  369. }
  370. int receive_id = packet->channel_id();
  371. MuxChannel* channel = nullptr;
  372. auto it = channels_by_receive_id_.find(receive_id);
  373. if (it != channels_by_receive_id_.end()) {
  374. channel = it->second;
  375. } else {
  376. // This is a new |channel_id| we haven't seen before. Look it up by name.
  377. if (!packet->has_channel_name()) {
  378. LOG(ERROR) << "Received packet with unknown channel_id and "
  379. "without channel_name.";
  380. return;
  381. }
  382. channel = GetOrCreateChannel(packet->channel_name());
  383. channel->set_receive_id(receive_id);
  384. channels_by_receive_id_[receive_id] = channel;
  385. }
  386. channel->OnIncomingPacket(std::move(packet));
  387. }
  388. void ChannelMultiplexer::DoWrite(
  389. std::unique_ptr<MultiplexPacket> packet,
  390. base::OnceClosure done_task,
  391. const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  392. writer_.Write(SerializeAndFrameMessage(*packet), std::move(done_task),
  393. traffic_annotation);
  394. }
  395. } // namespace protocol
  396. } // namespace remoting