socks5_client_socket_unittest.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/socket/socks5_client_socket.h"
  5. #include <algorithm>
  6. #include <iterator>
  7. #include <map>
  8. #include <memory>
  9. #include <utility>
  10. #include "base/containers/span.h"
  11. #include "base/memory/ptr_util.h"
  12. #include "base/memory/raw_ptr.h"
  13. #include "base/sys_byteorder.h"
  14. #include "build/build_config.h"
  15. #include "net/base/address_list.h"
  16. #include "net/base/test_completion_callback.h"
  17. #include "net/base/winsock_init.h"
  18. #include "net/log/net_log_event_type.h"
  19. #include "net/log/test_net_log.h"
  20. #include "net/log/test_net_log_util.h"
  21. #include "net/socket/client_socket_factory.h"
  22. #include "net/socket/socket_test_util.h"
  23. #include "net/socket/tcp_client_socket.h"
  24. #include "net/test/gtest_util.h"
  25. #include "net/test/test_with_task_environment.h"
  26. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  27. #include "testing/gmock/include/gmock/gmock.h"
  28. #include "testing/gtest/include/gtest/gtest.h"
  29. #include "testing/platform_test.h"
  30. using net::test::IsError;
  31. using net::test::IsOk;
  32. //-----------------------------------------------------------------------------
  33. namespace net {
  34. class NetLog;
  35. namespace {
  36. // Base class to test SOCKS5ClientSocket
  37. class SOCKS5ClientSocketTest : public PlatformTest, public WithTaskEnvironment {
  38. public:
  39. SOCKS5ClientSocketTest();
  40. SOCKS5ClientSocketTest(const SOCKS5ClientSocketTest&) = delete;
  41. SOCKS5ClientSocketTest& operator=(const SOCKS5ClientSocketTest&) = delete;
  42. // Create a SOCKSClientSocket on top of a MockSocket.
  43. std::unique_ptr<SOCKS5ClientSocket> BuildMockSocket(
  44. base::span<const MockRead> reads,
  45. base::span<const MockWrite> writes,
  46. const std::string& hostname,
  47. int port,
  48. NetLog* net_log);
  49. void SetUp() override;
  50. protected:
  51. const uint16_t kNwPort;
  52. RecordingNetLogObserver net_log_observer_;
  53. std::unique_ptr<SOCKS5ClientSocket> user_sock_;
  54. AddressList address_list_;
  55. // Filled in by BuildMockSocket() and owned by its return value
  56. // (which |user_sock| is set to).
  57. raw_ptr<StreamSocket> tcp_sock_;
  58. TestCompletionCallback callback_;
  59. std::unique_ptr<SocketDataProvider> data_;
  60. };
  61. SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
  62. : kNwPort(base::HostToNet16(80)) {}
  63. // Set up platform before every test case
  64. void SOCKS5ClientSocketTest::SetUp() {
  65. PlatformTest::SetUp();
  66. // Create the "localhost" AddressList used by the TCP connection to connect.
  67. address_list_ =
  68. AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), 1080);
  69. }
  70. std::unique_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket(
  71. base::span<const MockRead> reads,
  72. base::span<const MockWrite> writes,
  73. const std::string& hostname,
  74. int port,
  75. NetLog* net_log) {
  76. TestCompletionCallback callback;
  77. data_ = std::make_unique<StaticSocketDataProvider>(reads, writes);
  78. auto tcp_sock = std::make_unique<MockTCPClientSocket>(address_list_, net_log,
  79. data_.get());
  80. tcp_sock_ = tcp_sock.get();
  81. int rv = tcp_sock_->Connect(callback.callback());
  82. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  83. rv = callback.WaitForResult();
  84. EXPECT_THAT(rv, IsOk());
  85. EXPECT_TRUE(tcp_sock_->IsConnected());
  86. // The SOCKS5ClientSocket takes ownership of |tcp_sock_|, but keep a
  87. // non-owning pointer to it.
  88. return std::make_unique<SOCKS5ClientSocket>(std::move(tcp_sock),
  89. HostPortPair(hostname, port),
  90. TRAFFIC_ANNOTATION_FOR_TESTS);
  91. }
  92. // Tests a complete SOCKS5 handshake and the disconnection.
  93. TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
  94. const std::string payload_write = "random data";
  95. const std::string payload_read = "moar random data";
  96. const char kOkRequest[] = {
  97. 0x05, // Version
  98. 0x01, // Command (CONNECT)
  99. 0x00, // Reserved.
  100. 0x03, // Address type (DOMAINNAME).
  101. 0x09, // Length of domain (9)
  102. // Domain string:
  103. 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
  104. 0x00, 0x50, // 16-bit port (80)
  105. };
  106. MockWrite data_writes[] = {
  107. MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
  108. MockWrite(ASYNC, kOkRequest, std::size(kOkRequest)),
  109. MockWrite(ASYNC, payload_write.data(), payload_write.size())};
  110. MockRead data_reads[] = {
  111. MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
  112. MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
  113. MockRead(ASYNC, payload_read.data(), payload_read.size()) };
  114. user_sock_ =
  115. BuildMockSocket(data_reads, data_writes, "localhost", 80, NetLog::Get());
  116. // At this state the TCP connection is completed but not the SOCKS handshake.
  117. EXPECT_TRUE(tcp_sock_->IsConnected());
  118. EXPECT_FALSE(user_sock_->IsConnected());
  119. int rv = user_sock_->Connect(callback_.callback());
  120. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  121. EXPECT_FALSE(user_sock_->IsConnected());
  122. auto net_log_entries = net_log_observer_.GetEntries();
  123. EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
  124. NetLogEventType::SOCKS5_CONNECT));
  125. rv = callback_.WaitForResult();
  126. EXPECT_THAT(rv, IsOk());
  127. EXPECT_TRUE(user_sock_->IsConnected());
  128. net_log_entries = net_log_observer_.GetEntries();
  129. EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
  130. NetLogEventType::SOCKS5_CONNECT));
  131. scoped_refptr<IOBuffer> buffer =
  132. base::MakeRefCounted<IOBuffer>(payload_write.size());
  133. memcpy(buffer->data(), payload_write.data(), payload_write.size());
  134. rv = user_sock_->Write(buffer.get(), payload_write.size(),
  135. callback_.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  136. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  137. rv = callback_.WaitForResult();
  138. EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
  139. buffer = base::MakeRefCounted<IOBuffer>(payload_read.size());
  140. rv =
  141. user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
  142. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  143. rv = callback_.WaitForResult();
  144. EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
  145. EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
  146. user_sock_->Disconnect();
  147. EXPECT_FALSE(tcp_sock_->IsConnected());
  148. EXPECT_FALSE(user_sock_->IsConnected());
  149. }
  150. // Test that you can call Connect() again after having called Disconnect().
  151. TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
  152. const std::string hostname = "my-host-name";
  153. const char kSOCKS5DomainRequest[] = {
  154. 0x05, // VER
  155. 0x01, // CMD
  156. 0x00, // RSV
  157. 0x03, // ATYPE
  158. };
  159. std::string request(kSOCKS5DomainRequest, std::size(kSOCKS5DomainRequest));
  160. request.push_back(static_cast<char>(hostname.size()));
  161. request.append(hostname);
  162. request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
  163. for (int i = 0; i < 2; ++i) {
  164. MockWrite data_writes[] = {
  165. MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
  166. MockWrite(SYNCHRONOUS, request.data(), request.size())
  167. };
  168. MockRead data_reads[] = {
  169. MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
  170. MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
  171. };
  172. user_sock_ =
  173. BuildMockSocket(data_reads, data_writes, hostname, 80, nullptr);
  174. int rv = user_sock_->Connect(callback_.callback());
  175. EXPECT_THAT(rv, IsOk());
  176. EXPECT_TRUE(user_sock_->IsConnected());
  177. user_sock_->Disconnect();
  178. EXPECT_FALSE(user_sock_->IsConnected());
  179. }
  180. }
  181. // Test that we fail trying to connect to a hostname longer than 255 bytes.
  182. TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
  183. // Create a string of length 256, where each character is 'x'.
  184. std::string large_host_name;
  185. std::fill_n(std::back_inserter(large_host_name), 256, 'x');
  186. // Create a SOCKS socket, with mock transport socket.
  187. MockWrite data_writes[] = {MockWrite()};
  188. MockRead data_reads[] = {MockRead()};
  189. user_sock_ =
  190. BuildMockSocket(data_reads, data_writes, large_host_name, 80, nullptr);
  191. // Try to connect -- should fail (without having read/written anything to
  192. // the transport socket first) because the hostname is too long.
  193. TestCompletionCallback callback;
  194. int rv = user_sock_->Connect(callback.callback());
  195. EXPECT_THAT(rv, IsError(ERR_SOCKS_CONNECTION_FAILED));
  196. }
  197. TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
  198. const std::string hostname = "www.google.com";
  199. const char kOkRequest[] = {
  200. 0x05, // Version
  201. 0x01, // Command (CONNECT)
  202. 0x00, // Reserved.
  203. 0x03, // Address type (DOMAINNAME).
  204. 0x0E, // Length of domain (14)
  205. // Domain string:
  206. 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
  207. 0x00, 0x50, // 16-bit port (80)
  208. };
  209. // Test for partial greet request write
  210. {
  211. const char partial1[] = { 0x05, 0x01 };
  212. const char partial2[] = { 0x00 };
  213. MockWrite data_writes[] = {
  214. MockWrite(ASYNC, partial1, std::size(partial1)),
  215. MockWrite(ASYNC, partial2, std::size(partial2)),
  216. MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
  217. MockRead data_reads[] = {
  218. MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
  219. MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
  220. user_sock_ =
  221. BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
  222. int rv = user_sock_->Connect(callback_.callback());
  223. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  224. auto net_log_entries = net_log_observer_.GetEntries();
  225. EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
  226. NetLogEventType::SOCKS5_CONNECT));
  227. rv = callback_.WaitForResult();
  228. EXPECT_THAT(rv, IsOk());
  229. EXPECT_TRUE(user_sock_->IsConnected());
  230. net_log_entries = net_log_observer_.GetEntries();
  231. EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
  232. NetLogEventType::SOCKS5_CONNECT));
  233. }
  234. // Test for partial greet response read
  235. {
  236. const char partial1[] = { 0x05 };
  237. const char partial2[] = { 0x00 };
  238. MockWrite data_writes[] = {
  239. MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
  240. MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
  241. MockRead data_reads[] = {
  242. MockRead(ASYNC, partial1, std::size(partial1)),
  243. MockRead(ASYNC, partial2, std::size(partial2)),
  244. MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength)};
  245. user_sock_ =
  246. BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
  247. int rv = user_sock_->Connect(callback_.callback());
  248. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  249. auto net_log_entries = net_log_observer_.GetEntries();
  250. EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
  251. NetLogEventType::SOCKS5_CONNECT));
  252. rv = callback_.WaitForResult();
  253. EXPECT_THAT(rv, IsOk());
  254. EXPECT_TRUE(user_sock_->IsConnected());
  255. net_log_entries = net_log_observer_.GetEntries();
  256. EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
  257. NetLogEventType::SOCKS5_CONNECT));
  258. }
  259. // Test for partial handshake request write.
  260. {
  261. const int kSplitPoint = 3; // Break handshake write into two parts.
  262. MockWrite data_writes[] = {
  263. MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
  264. MockWrite(ASYNC, kOkRequest, kSplitPoint),
  265. MockWrite(ASYNC, kOkRequest + kSplitPoint,
  266. std::size(kOkRequest) - kSplitPoint)};
  267. MockRead data_reads[] = {
  268. MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
  269. MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
  270. user_sock_ =
  271. BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
  272. int rv = user_sock_->Connect(callback_.callback());
  273. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  274. auto net_log_entries = net_log_observer_.GetEntries();
  275. EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
  276. NetLogEventType::SOCKS5_CONNECT));
  277. rv = callback_.WaitForResult();
  278. EXPECT_THAT(rv, IsOk());
  279. EXPECT_TRUE(user_sock_->IsConnected());
  280. net_log_entries = net_log_observer_.GetEntries();
  281. EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
  282. NetLogEventType::SOCKS5_CONNECT));
  283. }
  284. // Test for partial handshake response read
  285. {
  286. const int kSplitPoint = 6; // Break the handshake read into two parts.
  287. MockWrite data_writes[] = {
  288. MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
  289. MockWrite(ASYNC, kOkRequest, std::size(kOkRequest))};
  290. MockRead data_reads[] = {
  291. MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
  292. MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
  293. MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
  294. kSOCKS5OkResponseLength - kSplitPoint)
  295. };
  296. user_sock_ =
  297. BuildMockSocket(data_reads, data_writes, hostname, 80, NetLog::Get());
  298. int rv = user_sock_->Connect(callback_.callback());
  299. EXPECT_THAT(rv, IsError(ERR_IO_PENDING));
  300. auto net_log_entries = net_log_observer_.GetEntries();
  301. EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
  302. NetLogEventType::SOCKS5_CONNECT));
  303. rv = callback_.WaitForResult();
  304. EXPECT_THAT(rv, IsOk());
  305. EXPECT_TRUE(user_sock_->IsConnected());
  306. net_log_entries = net_log_observer_.GetEntries();
  307. EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
  308. NetLogEventType::SOCKS5_CONNECT));
  309. }
  310. }
  311. TEST_F(SOCKS5ClientSocketTest, Tag) {
  312. StaticSocketDataProvider data;
  313. auto tagging_sock = std::make_unique<MockTaggingStreamSocket>(
  314. std::make_unique<MockTCPClientSocket>(address_list_, NetLog::Get(),
  315. &data));
  316. auto* tagging_sock_ptr = tagging_sock.get();
  317. // |socket| takes ownership of |tagging_sock|, but keep a non-owning pointer
  318. // to it.
  319. SOCKS5ClientSocket socket(std::move(tagging_sock),
  320. HostPortPair("localhost", 80),
  321. TRAFFIC_ANNOTATION_FOR_TESTS);
  322. EXPECT_EQ(tagging_sock_ptr->tag(), SocketTag());
  323. #if BUILDFLAG(IS_ANDROID)
  324. SocketTag tag(0x12345678, 0x87654321);
  325. socket.ApplySocketTag(tag);
  326. EXPECT_EQ(tagging_sock_ptr->tag(), tag);
  327. #endif // BUILDFLAG(IS_ANDROID)
  328. }
  329. } // namespace
  330. } // namespace net