buffered_socket_writer_unittest.cc 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. // Copyright 2015 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "remoting/base/buffered_socket_writer.h"
  5. #include <stddef.h>
  6. #include <stdlib.h>
  7. #include <memory>
  8. #include "base/bind.h"
  9. #include "base/callback.h"
  10. #include "base/run_loop.h"
  11. #include "base/test/task_environment.h"
  12. #include "net/base/io_buffer.h"
  13. #include "net/base/net_errors.h"
  14. #include "net/log/net_log.h"
  15. #include "net/socket/socket_test_util.h"
  16. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  17. #include "testing/gmock/include/gmock/gmock.h"
  18. #include "testing/gtest/include/gtest/gtest.h"
  19. namespace remoting {
  20. namespace {
  // Size, in bytes, of each of the two randomly-filled test buffers.
  constexpr int kTestBufferSize = 10000;
  // Per-write byte limit used by the chunked tests. Declared as int because it
  // is only ever passed to SocketDataProvider::set_write_limit(int); a size_t
  // here forced an implicit narrowing conversion at every call site.
  constexpr int kWriteChunkSize = 1024;
  23. int WriteNetSocket(net::Socket* socket,
  24. const scoped_refptr<net::IOBuffer>& buf,
  25. int buf_len,
  26. net::CompletionOnceCallback callback,
  27. const net::NetworkTrafficAnnotationTag& traffic_annotation) {
  28. return socket->Write(buf.get(), buf_len, std::move(callback),
  29. traffic_annotation);
  30. }
  31. class SocketDataProvider: public net::SocketDataProvider {
  32. public:
  33. SocketDataProvider()
  34. : write_limit_(-1), async_write_(false), next_write_error_(net::OK) {}
  35. net::MockRead OnRead() override {
  36. return net::MockRead(net::ASYNC, net::ERR_IO_PENDING);
  37. }
  38. net::MockWriteResult OnWrite(const std::string& data) override {
  39. if (next_write_error_ != net::OK) {
  40. int r = next_write_error_;
  41. next_write_error_ = net::OK;
  42. return net::MockWriteResult(async_write_ ? net::ASYNC : net::SYNCHRONOUS,
  43. r);
  44. }
  45. int size = data.size();
  46. if (write_limit_ > 0)
  47. size = std::min(write_limit_, size);
  48. written_data_.append(data, 0, size);
  49. return net::MockWriteResult(async_write_ ? net::ASYNC : net::SYNCHRONOUS,
  50. size);
  51. }
  52. bool AllReadDataConsumed() const override {
  53. return true;
  54. }
  55. bool AllWriteDataConsumed() const override {
  56. return true;
  57. }
  58. void Reset() override {}
  59. std::string written_data() { return written_data_; }
  60. void set_write_limit(int limit) { write_limit_ = limit; }
  61. void set_async_write(bool async_write) { async_write_ = async_write; }
  62. void set_next_write_error(int error) { next_write_error_ = error; }
  63. private:
  64. std::string written_data_;
  65. int write_limit_;
  66. bool async_write_;
  67. int next_write_error_;
  68. };
  69. } // namespace
  70. class BufferedSocketWriterTest : public testing::Test {
  71. public:
  72. BufferedSocketWriterTest()
  73. : write_error_(0) {
  74. }
  75. void DestroyWriter() {
  76. writer_.reset();
  77. socket_.reset();
  78. }
  79. void Unexpected() {
  80. EXPECT_TRUE(false);
  81. }
  82. protected:
  83. void SetUp() override {
  84. socket_ = std::make_unique<net::MockTCPClientSocket>(
  85. net::AddressList(), net::NetLog::Get(), &socket_data_provider_);
  86. socket_data_provider_.set_connect_data(
  87. net::MockConnect(net::SYNCHRONOUS, net::OK));
  88. EXPECT_EQ(net::OK, socket_->Connect(net::CompletionOnceCallback()));
  89. writer_ = std::make_unique<BufferedSocketWriter>();
  90. test_buffer_ = base::MakeRefCounted<net::IOBufferWithSize>(kTestBufferSize);
  91. test_buffer_2_ =
  92. base::MakeRefCounted<net::IOBufferWithSize>(kTestBufferSize);
  93. for (int i = 0; i < kTestBufferSize; ++i) {
  94. test_buffer_->data()[i] = rand() % 256;
  95. test_buffer_2_->data()[i] = rand() % 256;
  96. }
  97. }
  98. void StartWriter() {
  99. writer_->Start(base::BindRepeating(&WriteNetSocket, socket_.get()),
  100. base::BindOnce(&BufferedSocketWriterTest::OnWriteFailed,
  101. base::Unretained(this)));
  102. }
  103. void OnWriteFailed(int error) {
  104. write_error_ = error;
  105. }
  106. void VerifyWrittenData() {
  107. ASSERT_EQ(static_cast<size_t>(test_buffer_->size() +
  108. test_buffer_2_->size()),
  109. socket_data_provider_.written_data().size());
  110. EXPECT_EQ(0, memcmp(test_buffer_->data(),
  111. socket_data_provider_.written_data().data(),
  112. test_buffer_->size()));
  113. EXPECT_EQ(0, memcmp(test_buffer_2_->data(),
  114. socket_data_provider_.written_data().data() +
  115. test_buffer_->size(),
  116. test_buffer_2_->size()));
  117. }
  118. void TestWrite() {
  119. writer_->Write(test_buffer_, base::OnceClosure(),
  120. TRAFFIC_ANNOTATION_FOR_TESTS);
  121. writer_->Write(test_buffer_2_, base::OnceClosure(),
  122. TRAFFIC_ANNOTATION_FOR_TESTS);
  123. base::RunLoop().RunUntilIdle();
  124. VerifyWrittenData();
  125. }
  126. void TestAppendInCallback() {
  127. writer_->Write(
  128. test_buffer_,
  129. base::BindOnce(base::IgnoreResult(&BufferedSocketWriter::Write),
  130. base::Unretained(writer_.get()), test_buffer_2_,
  131. base::OnceClosure(), TRAFFIC_ANNOTATION_FOR_TESTS),
  132. TRAFFIC_ANNOTATION_FOR_TESTS);
  133. base::RunLoop().RunUntilIdle();
  134. VerifyWrittenData();
  135. }
  136. base::test::SingleThreadTaskEnvironment task_environment_;
  137. SocketDataProvider socket_data_provider_;
  138. std::unique_ptr<net::StreamSocket> socket_;
  139. std::unique_ptr<BufferedSocketWriter> writer_;
  140. scoped_refptr<net::IOBufferWithSize> test_buffer_;
  141. scoped_refptr<net::IOBufferWithSize> test_buffer_2_;
  142. int write_error_;
  143. };
  144. // Test synchronous write.
  145. TEST_F(BufferedSocketWriterTest, WriteFull) {
  146. StartWriter();
  147. TestWrite();
  148. }
  149. // Test synchronous write in 1k chunks.
  150. TEST_F(BufferedSocketWriterTest, WriteChunks) {
  151. StartWriter();
  152. socket_data_provider_.set_write_limit(kWriteChunkSize);
  153. TestWrite();
  154. }
  155. // Test asynchronous write.
  156. TEST_F(BufferedSocketWriterTest, WriteAsync) {
  157. StartWriter();
  158. socket_data_provider_.set_async_write(true);
  159. socket_data_provider_.set_write_limit(kWriteChunkSize);
  160. TestWrite();
  161. }
  162. // Make sure we can call Write() from the done callback.
  163. TEST_F(BufferedSocketWriterTest, AppendInCallbackSync) {
  164. StartWriter();
  165. TestAppendInCallback();
  166. }
  167. // Make sure we can call Write() from the done callback.
  168. TEST_F(BufferedSocketWriterTest, AppendInCallbackAsync) {
  169. StartWriter();
  170. socket_data_provider_.set_async_write(true);
  171. socket_data_provider_.set_write_limit(kWriteChunkSize);
  172. TestAppendInCallback();
  173. }
  174. // Test that the writer can be destroyed from callback.
  175. TEST_F(BufferedSocketWriterTest, DestroyFromCallback) {
  176. StartWriter();
  177. socket_data_provider_.set_async_write(true);
  178. writer_->Write(test_buffer_,
  179. base::BindOnce(&BufferedSocketWriterTest::DestroyWriter,
  180. base::Unretained(this)),
  181. TRAFFIC_ANNOTATION_FOR_TESTS);
  182. writer_->Write(test_buffer_2_,
  183. base::BindOnce(&BufferedSocketWriterTest::Unexpected,
  184. base::Unretained(this)),
  185. TRAFFIC_ANNOTATION_FOR_TESTS);
  186. socket_data_provider_.set_async_write(false);
  187. base::RunLoop().RunUntilIdle();
  188. ASSERT_GE(socket_data_provider_.written_data().size(),
  189. static_cast<size_t>(test_buffer_->size()));
  190. EXPECT_EQ(0, memcmp(test_buffer_->data(),
  191. socket_data_provider_.written_data().data(),
  192. test_buffer_->size()));
  193. }
  194. // Verify that it stops writing after the first error.
  195. TEST_F(BufferedSocketWriterTest, TestWriteErrorSync) {
  196. StartWriter();
  197. socket_data_provider_.set_write_limit(kWriteChunkSize);
  198. writer_->Write(test_buffer_, base::OnceClosure(),
  199. TRAFFIC_ANNOTATION_FOR_TESTS);
  200. socket_data_provider_.set_async_write(true);
  201. socket_data_provider_.set_next_write_error(net::ERR_FAILED);
  202. writer_->Write(test_buffer_2_,
  203. base::BindOnce(&BufferedSocketWriterTest::Unexpected,
  204. base::Unretained(this)),
  205. TRAFFIC_ANNOTATION_FOR_TESTS);
  206. socket_data_provider_.set_async_write(false);
  207. base::RunLoop().RunUntilIdle();
  208. EXPECT_EQ(net::ERR_FAILED, write_error_);
  209. EXPECT_EQ(static_cast<size_t>(test_buffer_->size()),
  210. socket_data_provider_.written_data().size());
  211. }
  212. // Verify that it stops writing after the first error.
  213. TEST_F(BufferedSocketWriterTest, TestWriteErrorAsync) {
  214. StartWriter();
  215. socket_data_provider_.set_write_limit(kWriteChunkSize);
  216. writer_->Write(test_buffer_, base::OnceClosure(),
  217. TRAFFIC_ANNOTATION_FOR_TESTS);
  218. socket_data_provider_.set_async_write(true);
  219. socket_data_provider_.set_next_write_error(net::ERR_FAILED);
  220. writer_->Write(test_buffer_2_,
  221. base::BindOnce(&BufferedSocketWriterTest::Unexpected,
  222. base::Unretained(this)),
  223. TRAFFIC_ANNOTATION_FOR_TESTS);
  224. base::RunLoop().RunUntilIdle();
  225. EXPECT_EQ(net::ERR_FAILED, write_error_);
  226. EXPECT_EQ(static_cast<size_t>(test_buffer_->size()),
  227. socket_data_provider_.written_data().size());
  228. }
  229. TEST_F(BufferedSocketWriterTest, WriteBeforeStart) {
  230. writer_->Write(test_buffer_, base::OnceClosure(),
  231. TRAFFIC_ANNOTATION_FOR_TESTS);
  232. writer_->Write(test_buffer_2_, base::OnceClosure(),
  233. TRAFFIC_ANNOTATION_FOR_TESTS);
  234. StartWriter();
  235. base::RunLoop().RunUntilIdle();
  236. VerifyWrittenData();
  237. }
  238. } // namespace remoting