pseudotcp_adapter_unittest.cc 13 KB


  1. // Copyright 2015 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "remoting/protocol/pseudotcp_adapter.h"
  5. #include <memory>
  6. #include <utility>
  7. #include <vector>
  8. #include "base/bind.h"
  9. #include "base/callback_helpers.h"
  10. #include "base/compiler_specific.h"
  11. #include "base/containers/circular_deque.h"
  12. #include "base/location.h"
  13. #include "base/memory/ptr_util.h"
  14. #include "base/memory/raw_ptr.h"
  15. #include "base/run_loop.h"
  16. #include "base/task/single_thread_task_runner.h"
  17. #include "base/test/task_environment.h"
  18. #include "base/threading/thread_task_runner_handle.h"
  19. #include "base/time/time.h"
  20. #include "components/webrtc/thread_wrapper.h"
  21. #include "net/base/io_buffer.h"
  22. #include "net/base/net_errors.h"
  23. #include "net/base/test_completion_callback.h"
  24. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  25. #include "remoting/protocol/p2p_datagram_socket.h"
  26. #include "remoting/protocol/p2p_stream_socket.h"
  27. #include "testing/gmock/include/gmock/gmock.h"
  28. #include "testing/gtest/include/gtest/gtest.h"
  29. namespace remoting {
  30. namespace protocol {
  31. namespace {
  // Largest number of bytes handed to a single socket Write()/Read() call.
  constexpr int kMessageSize = 1024;
  // Number of kMessageSize-sized chunks that make up the test payload.
  constexpr int kMessages = 100;
  // Total payload size, in bytes, transferred by TCPChannelTester.
  constexpr int kTestDataSize = kMessages * kMessageSize;
  // Interface that decides, packet by packet, whether an incoming packet is
  // delivered or discarded.
  class RateLimiter {
   public:
    // Reports whether the next packet must be discarded (true) or may be
    // delivered (false).
    virtual bool DropNextPacket() = 0;

    virtual ~RateLimiter() = default;
  };
  41. class LeakyBucket : public RateLimiter {
  42. public:
  43. // |rate| is in drops per second.
  44. LeakyBucket(double volume, double rate)
  45. : volume_(volume),
  46. rate_(rate),
  47. level_(0.0),
  48. last_update_(base::TimeTicks::Now()) {
  49. }
  50. ~LeakyBucket() override = default;
  51. bool DropNextPacket() override {
  52. base::TimeTicks now = base::TimeTicks::Now();
  53. double interval = (now - last_update_).InSecondsF();
  54. last_update_ = now;
  55. level_ = level_ + 1.0 - interval * rate_;
  56. if (level_ > volume_) {
  57. level_ = volume_;
  58. return true;
  59. } else if (level_ < 0.0) {
  60. level_ = 0.0;
  61. }
  62. return false;
  63. }
  64. private:
  65. double volume_;
  66. double rate_;
  67. double level_;
  68. base::TimeTicks last_update_;
  69. };
  70. class FakeSocket : public P2PDatagramSocket {
  71. public:
  72. FakeSocket() : rate_limiter_(nullptr), latency_ms_(0) {}
  73. ~FakeSocket() override = default;
  74. void AppendInputPacket(const std::vector<char>& data) {
  75. if (rate_limiter_ && rate_limiter_->DropNextPacket())
  76. return; // Lose the packet.
  77. if (!read_callback_.is_null()) {
  78. int size = std::min(read_buffer_size_, static_cast<int>(data.size()));
  79. memcpy(read_buffer_->data(), &data[0], data.size());
  80. net::CompletionRepeatingCallback cb = read_callback_;
  81. read_callback_.Reset();
  82. read_buffer_.reset();
  83. cb.Run(size);
  84. } else {
  85. incoming_packets_.push_back(data);
  86. }
  87. }
  88. void Connect(FakeSocket* peer_socket) {
  89. peer_socket_ = peer_socket;
  90. }
  91. void set_rate_limiter(RateLimiter* rate_limiter) {
  92. rate_limiter_ = rate_limiter;
  93. }
  94. void set_latency(int latency_ms) { latency_ms_ = latency_ms; }
  95. // P2PDatagramSocket interface.
  96. int Recv(const scoped_refptr<net::IOBuffer>& buf,
  97. int buf_len,
  98. const net::CompletionRepeatingCallback& callback) override {
  99. CHECK(read_callback_.is_null());
  100. CHECK(buf);
  101. if (incoming_packets_.size() > 0) {
  102. scoped_refptr<net::IOBuffer> buffer(buf);
  103. int size = std::min(
  104. static_cast<int>(incoming_packets_.front().size()), buf_len);
  105. memcpy(buffer->data(), &*incoming_packets_.front().begin(), size);
  106. incoming_packets_.pop_front();
  107. return size;
  108. } else {
  109. read_callback_ = callback;
  110. read_buffer_ = buf;
  111. read_buffer_size_ = buf_len;
  112. return net::ERR_IO_PENDING;
  113. }
  114. }
  115. int Send(const scoped_refptr<net::IOBuffer>& buf,
  116. int buf_len,
  117. const net::CompletionRepeatingCallback& callback) override {
  118. DCHECK(buf);
  119. if (peer_socket_) {
  120. base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
  121. FROM_HERE,
  122. base::BindOnce(&FakeSocket::AppendInputPacket,
  123. base::Unretained(peer_socket_),
  124. std::vector<char>(buf->data(), buf->data() + buf_len)),
  125. base::Milliseconds(latency_ms_));
  126. }
  127. return buf_len;
  128. }
  129. private:
  130. scoped_refptr<net::IOBuffer> read_buffer_;
  131. int read_buffer_size_;
  132. net::CompletionRepeatingCallback read_callback_;
  133. base::circular_deque<std::vector<char>> incoming_packets_;
  134. raw_ptr<FakeSocket> peer_socket_;
  135. raw_ptr<RateLimiter> rate_limiter_;
  136. int latency_ms_;
  137. };
  138. class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> {
  139. public:
  140. TCPChannelTester(scoped_refptr<base::SingleThreadTaskRunner> task_runner,
  141. P2PStreamSocket* client_socket,
  142. P2PStreamSocket* host_socket)
  143. : task_runner_(std::move(task_runner)),
  144. host_socket_(host_socket),
  145. client_socket_(client_socket),
  146. done_(false),
  147. write_errors_(0),
  148. read_errors_(0) {}
  149. void Start() {
  150. task_runner_->PostTask(FROM_HERE,
  151. base::BindOnce(&TCPChannelTester::DoStart, this));
  152. }
  153. void CheckResults() {
  154. EXPECT_EQ(0, write_errors_);
  155. EXPECT_EQ(0, read_errors_);
  156. ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity());
  157. output_buffer_->SetOffset(0);
  158. ASSERT_EQ(kTestDataSize, output_buffer_->size());
  159. EXPECT_EQ(0, memcmp(output_buffer_->data(),
  160. input_buffer_->StartOfBuffer(), kTestDataSize));
  161. }
  162. protected:
  163. virtual ~TCPChannelTester() = default;
  164. void Done() {
  165. done_ = true;
  166. task_runner_->PostTask(
  167. FROM_HERE, base::RunLoop::QuitCurrentWhenIdleClosureDeprecated());
  168. }
  169. void DoStart() {
  170. InitBuffers();
  171. DoRead();
  172. DoWrite();
  173. }
  174. void InitBuffers() {
  175. output_buffer_ = base::MakeRefCounted<net::DrainableIOBuffer>(
  176. base::MakeRefCounted<net::IOBuffer>(kTestDataSize), kTestDataSize);
  177. memset(output_buffer_->data(), 123, kTestDataSize);
  178. input_buffer_ = base::MakeRefCounted<net::GrowableIOBuffer>();
  179. // Always keep kMessageSize bytes available at the end of the input buffer.
  180. input_buffer_->SetCapacity(kMessageSize);
  181. }
  182. void DoWrite() {
  183. int result = 1;
  184. while (result > 0) {
  185. if (output_buffer_->BytesRemaining() == 0)
  186. break;
  187. int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
  188. kMessageSize);
  189. result = client_socket_->Write(
  190. output_buffer_.get(), bytes_to_write,
  191. base::BindOnce(&TCPChannelTester::OnWritten, base::Unretained(this)),
  192. TRAFFIC_ANNOTATION_FOR_TESTS);
  193. HandleWriteResult(result);
  194. }
  195. }
  196. void OnWritten(int result) {
  197. HandleWriteResult(result);
  198. DoWrite();
  199. }
  200. void HandleWriteResult(int result) {
  201. if (result <= 0 && result != net::ERR_IO_PENDING) {
  202. LOG(ERROR) << "Received error " << result << " when trying to write";
  203. write_errors_++;
  204. Done();
  205. } else if (result > 0) {
  206. output_buffer_->DidConsume(result);
  207. }
  208. }
  209. void DoRead() {
  210. int result = 1;
  211. while (result > 0) {
  212. input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize);
  213. result = host_socket_->Read(
  214. input_buffer_.get(), kMessageSize,
  215. base::BindOnce(&TCPChannelTester::OnRead, base::Unretained(this)));
  216. HandleReadResult(result);
  217. };
  218. }
  219. void OnRead(int result) {
  220. HandleReadResult(result);
  221. DoRead();
  222. }
  223. void HandleReadResult(int result) {
  224. if (result <= 0 && result != net::ERR_IO_PENDING) {
  225. if (!done_) {
  226. LOG(ERROR) << "Received error " << result << " when trying to read";
  227. read_errors_++;
  228. Done();
  229. }
  230. } else if (result > 0) {
  231. // Allocate memory for the next read.
  232. input_buffer_->SetCapacity(input_buffer_->capacity() + result);
  233. if (input_buffer_->capacity() == kTestDataSize + kMessageSize)
  234. Done();
  235. }
  236. }
  237. private:
  238. friend class base::RefCountedThreadSafe<TCPChannelTester>;
  239. scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
  240. raw_ptr<P2PStreamSocket> host_socket_;
  241. raw_ptr<P2PStreamSocket> client_socket_;
  242. bool done_;
  243. scoped_refptr<net::DrainableIOBuffer> output_buffer_;
  244. scoped_refptr<net::GrowableIOBuffer> input_buffer_;
  245. int write_errors_;
  246. int read_errors_;
  247. };
  248. class PseudoTcpAdapterTest : public testing::Test {
  249. protected:
  250. void SetUp() override {
  251. webrtc::ThreadWrapper::EnsureForCurrentMessageLoop();
  252. host_socket_ = new FakeSocket();
  253. client_socket_ = new FakeSocket();
  254. host_socket_->Connect(client_socket_);
  255. client_socket_->Connect(host_socket_);
  256. host_pseudotcp_ = std::make_unique<PseudoTcpAdapter>(
  257. base::WrapUnique(host_socket_.get()));
  258. client_pseudotcp_ = std::make_unique<PseudoTcpAdapter>(
  259. base::WrapUnique(client_socket_.get()));
  260. }
  261. raw_ptr<FakeSocket> host_socket_;
  262. raw_ptr<FakeSocket> client_socket_;
  263. std::unique_ptr<PseudoTcpAdapter> host_pseudotcp_;
  264. std::unique_ptr<PseudoTcpAdapter> client_pseudotcp_;
  265. base::test::SingleThreadTaskEnvironment task_environment_;
  266. };
  267. TEST_F(PseudoTcpAdapterTest, DataTransfer) {
  268. net::TestCompletionCallback host_connect_cb;
  269. net::TestCompletionCallback client_connect_cb;
  270. net::CompletionOnceCallback rv1 =
  271. host_pseudotcp_->Connect(host_connect_cb.callback());
  272. ASSERT_FALSE(rv1);
  273. net::CompletionOnceCallback rv2 =
  274. client_pseudotcp_->Connect(client_connect_cb.callback());
  275. ASSERT_FALSE(rv2);
  276. EXPECT_EQ(net::OK, host_connect_cb.WaitForResult());
  277. EXPECT_EQ(net::OK, client_connect_cb.WaitForResult());
  278. scoped_refptr<TCPChannelTester> tester =
  279. new TCPChannelTester(base::ThreadTaskRunnerHandle::Get(),
  280. host_pseudotcp_.get(), client_pseudotcp_.get());
  281. tester->Start();
  282. base::RunLoop().Run();
  283. tester->CheckResults();
  284. }
  285. TEST_F(PseudoTcpAdapterTest, LimitedChannel) {
  286. const int kLatencyMs = 20;
  287. const int kPacketsPerSecond = 400;
  288. const int kBurstPackets = 10;
  289. LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond);
  290. host_socket_->set_latency(kLatencyMs);
  291. host_socket_->set_rate_limiter(&host_limiter);
  292. LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond);
  293. host_socket_->set_latency(kLatencyMs);
  294. client_socket_->set_rate_limiter(&client_limiter);
  295. net::TestCompletionCallback host_connect_cb;
  296. net::TestCompletionCallback client_connect_cb;
  297. net::CompletionOnceCallback rv1 =
  298. host_pseudotcp_->Connect(host_connect_cb.callback());
  299. ASSERT_FALSE(rv1);
  300. net::CompletionOnceCallback rv2 =
  301. client_pseudotcp_->Connect(client_connect_cb.callback());
  302. ASSERT_FALSE(rv2);
  303. EXPECT_EQ(net::OK, host_connect_cb.WaitForResult());
  304. EXPECT_EQ(net::OK, client_connect_cb.WaitForResult());
  305. scoped_refptr<TCPChannelTester> tester =
  306. new TCPChannelTester(base::ThreadTaskRunnerHandle::Get(),
  307. host_pseudotcp_.get(), client_pseudotcp_.get());
  308. tester->Start();
  309. base::RunLoop().Run();
  310. tester->CheckResults();
  311. }
  312. class DeleteOnConnected {
  313. public:
  314. DeleteOnConnected(scoped_refptr<base::SingleThreadTaskRunner> task_runner,
  315. std::unique_ptr<PseudoTcpAdapter>* adapter)
  316. : task_runner_(std::move(task_runner)), adapter_(adapter) {}
  317. void OnConnected(int error) {
  318. adapter_->reset();
  319. task_runner_->PostTask(
  320. FROM_HERE, base::RunLoop::QuitCurrentWhenIdleClosureDeprecated());
  321. }
  322. scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
  323. raw_ptr<std::unique_ptr<PseudoTcpAdapter>> adapter_;
  324. };
  325. TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) {
  326. // This test verifies that deleting the adapter mid-callback doesn't lead
  327. // to deleted structures being touched as the stack unrolls, so the failure
  328. // mode is a crash rather than a normal test failure.
  329. net::TestCompletionCallback client_connect_cb;
  330. DeleteOnConnected host_delete(base::ThreadTaskRunnerHandle::Get(),
  331. &host_pseudotcp_);
  332. host_pseudotcp_->Connect(base::BindOnce(&DeleteOnConnected::OnConnected,
  333. base::Unretained(&host_delete)));
  334. client_pseudotcp_->Connect(client_connect_cb.callback());
  335. base::RunLoop().Run();
  336. ASSERT_EQ(NULL, host_pseudotcp_.get());
  337. }
  338. // Verify that we can send/receive data with the write-waits-for-send
  339. // flag set.
  340. TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) {
  341. net::TestCompletionCallback host_connect_cb;
  342. net::TestCompletionCallback client_connect_cb;
  343. host_pseudotcp_->SetWriteWaitsForSend(true);
  344. client_pseudotcp_->SetWriteWaitsForSend(true);
  345. // Disable Nagle's algorithm because the test is slow when it is
  346. // enabled.
  347. host_pseudotcp_->SetNoDelay(true);
  348. net::CompletionOnceCallback rv1 =
  349. host_pseudotcp_->Connect(host_connect_cb.callback());
  350. ASSERT_FALSE(rv1);
  351. net::CompletionOnceCallback rv2 =
  352. client_pseudotcp_->Connect(client_connect_cb.callback());
  353. ASSERT_FALSE(rv2);
  354. EXPECT_EQ(net::OK, host_connect_cb.WaitForResult());
  355. EXPECT_EQ(net::OK, client_connect_cb.WaitForResult());
  356. scoped_refptr<TCPChannelTester> tester =
  357. new TCPChannelTester(base::ThreadTaskRunnerHandle::Get(),
  358. host_pseudotcp_.get(), client_pseudotcp_.get());
  359. tester->Start();
  360. base::RunLoop().Run();
  361. tester->CheckResults();
  362. }
  363. } // namespace
  364. } // namespace protocol
  365. } // namespace remoting