fake_stream_socket.cc 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. // Copyright 2014 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "remoting/protocol/fake_stream_socket.h"
  5. #include <utility>
  6. #include "base/bind.h"
  7. #include "base/callback_helpers.h"
  8. #include "base/location.h"
  9. #include "base/task/single_thread_task_runner.h"
  10. #include "base/threading/thread_task_runner_handle.h"
  11. #include "net/base/address_list.h"
  12. #include "net/base/io_buffer.h"
  13. #include "net/base/net_errors.h"
  14. #include "net/traffic_annotation/network_traffic_annotation.h"
  15. #include "testing/gtest/include/gtest/gtest.h"
  16. namespace remoting {
  17. namespace protocol {
  18. FakeStreamSocket::FakeStreamSocket()
  19. : task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
  20. FakeStreamSocket::~FakeStreamSocket() {
  21. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  22. if (peer_socket_) {
  23. task_runner_->PostTask(
  24. FROM_HERE, base::BindOnce(&FakeStreamSocket::SetReadError, peer_socket_,
  25. net::ERR_CONNECTION_CLOSED));
  26. }
  27. }
  28. void FakeStreamSocket::AppendInputData(const std::string& data) {
  29. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  30. input_data_.insert(input_data_.end(), data.begin(), data.end());
  31. // Complete pending read if any.
  32. if (!read_callback_.is_null()) {
  33. int result = std::min(read_buffer_size_,
  34. static_cast<int>(input_data_.size() - input_pos_));
  35. EXPECT_GT(result, 0);
  36. memcpy(read_buffer_->data(),
  37. &(*input_data_.begin()) + input_pos_, result);
  38. input_pos_ += result;
  39. read_buffer_ = nullptr;
  40. std::move(read_callback_).Run(result);
  41. }
  42. }
  43. void FakeStreamSocket::SetReadError(int error) {
  44. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  45. // Complete pending read if any.
  46. if (!read_callback_.is_null()) {
  47. std::move(read_callback_).Run(error);
  48. } else {
  49. next_read_error_ = error;
  50. }
  51. }
  52. void FakeStreamSocket::PairWith(FakeStreamSocket* peer_socket) {
  53. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  54. peer_socket_ = peer_socket->GetWeakPtr();
  55. peer_socket->peer_socket_ = GetWeakPtr();
  56. }
  57. base::WeakPtr<FakeStreamSocket> FakeStreamSocket::GetWeakPtr() {
  58. return weak_factory_.GetWeakPtr();
  59. }
  60. int FakeStreamSocket::Read(const scoped_refptr<net::IOBuffer>& buf,
  61. int buf_len,
  62. net::CompletionOnceCallback callback) {
  63. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  64. if (input_pos_ < static_cast<int>(input_data_.size())) {
  65. int result = std::min(buf_len,
  66. static_cast<int>(input_data_.size()) - input_pos_);
  67. memcpy(buf->data(), &(*input_data_.begin()) + input_pos_, result);
  68. input_pos_ += result;
  69. return result;
  70. } else if (next_read_error_.has_value()) {
  71. int r = next_read_error_.value();
  72. next_read_error_.reset();
  73. return r;
  74. } else {
  75. read_buffer_ = buf;
  76. read_buffer_size_ = buf_len;
  77. read_callback_ = std::move(callback);
  78. return net::ERR_IO_PENDING;
  79. }
  80. }
  81. int FakeStreamSocket::Write(
  82. const scoped_refptr<net::IOBuffer>& buf,
  83. int buf_len,
  84. net::CompletionOnceCallback callback,
  85. const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  86. EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
  87. EXPECT_FALSE(write_pending_);
  88. if (write_limit_ > 0)
  89. buf_len = std::min(write_limit_, buf_len);
  90. if (async_write_) {
  91. task_runner_->PostTask(
  92. FROM_HERE, base::BindOnce(&FakeStreamSocket::DoAsyncWrite,
  93. weak_factory_.GetWeakPtr(),
  94. scoped_refptr<net::IOBuffer>(buf), buf_len,
  95. std::move(callback)));
  96. write_pending_ = true;
  97. return net::ERR_IO_PENDING;
  98. } else {
  99. if (next_write_error_ != net::OK) {
  100. int r = next_write_error_;
  101. next_write_error_ = net::OK;
  102. return r;
  103. }
  104. DoWrite(buf, buf_len);
  105. return buf_len;
  106. }
  107. }
  108. void FakeStreamSocket::DoAsyncWrite(const scoped_refptr<net::IOBuffer>& buf,
  109. int buf_len,
  110. net::CompletionOnceCallback callback) {
  111. write_pending_ = false;
  112. if (next_write_error_ != net::OK) {
  113. int r = next_write_error_;
  114. next_write_error_ = net::OK;
  115. std::move(callback).Run(r);
  116. return;
  117. }
  118. DoWrite(buf.get(), buf_len);
  119. std::move(callback).Run(buf_len);
  120. }
  121. void FakeStreamSocket::DoWrite(const scoped_refptr<net::IOBuffer>& buf,
  122. int buf_len) {
  123. written_data_.insert(written_data_.end(),
  124. buf->data(), buf->data() + buf_len);
  125. if (peer_socket_) {
  126. task_runner_->PostTask(
  127. FROM_HERE,
  128. base::BindOnce(&FakeStreamSocket::AppendInputData, peer_socket_,
  129. std::string(buf->data(), buf->data() + buf_len)));
  130. }
  131. }
  132. FakeStreamChannelFactory::FakeStreamChannelFactory()
  133. : task_runner_(base::ThreadTaskRunnerHandle::Get()) {}
  134. FakeStreamChannelFactory::~FakeStreamChannelFactory() = default;
  135. FakeStreamSocket* FakeStreamChannelFactory::GetFakeChannel(
  136. const std::string& name) {
  137. return channels_[name].get();
  138. }
  139. void FakeStreamChannelFactory::PairWith(
  140. FakeStreamChannelFactory* peer_factory) {
  141. peer_factory_ = peer_factory->weak_factory_.GetWeakPtr();
  142. peer_factory->peer_factory_ = weak_factory_.GetWeakPtr();
  143. }
  144. void FakeStreamChannelFactory::CreateChannel(const std::string& name,
  145. ChannelCreatedCallback callback) {
  146. std::unique_ptr<FakeStreamSocket> channel(new FakeStreamSocket());
  147. channels_[name] = channel->GetWeakPtr();
  148. channel->set_async_write(async_write_);
  149. if (peer_factory_) {
  150. FakeStreamSocket* peer_channel = peer_factory_->GetFakeChannel(name);
  151. if (peer_channel)
  152. channel->PairWith(peer_channel);
  153. }
  154. if (fail_create_)
  155. channel.reset();
  156. if (asynchronous_create_) {
  157. task_runner_->PostTask(
  158. FROM_HERE,
  159. base::BindOnce(&FakeStreamChannelFactory::NotifyChannelCreated,
  160. weak_factory_.GetWeakPtr(), std::move(channel), name,
  161. std::move(callback)));
  162. } else {
  163. NotifyChannelCreated(std::move(channel), name, std::move(callback));
  164. }
  165. }
  166. void FakeStreamChannelFactory::NotifyChannelCreated(
  167. std::unique_ptr<FakeStreamSocket> owned_channel,
  168. const std::string& name,
  169. ChannelCreatedCallback callback) {
  170. if (channels_.find(name) != channels_.end())
  171. std::move(callback).Run(std::move(owned_channel));
  172. }
  173. void FakeStreamChannelFactory::CancelChannelCreation(const std::string& name) {
  174. channels_.erase(name);
  175. }
  176. } // namespace protocol
  177. } // namespace remoting