socket_test_util.h 49 KB


  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
  5. #define NET_SOCKET_SOCKET_TEST_UTIL_H_
  6. #include <stddef.h>
  7. #include <stdint.h>
  8. #include <cstring>
  9. #include <memory>
  10. #include <string>
  11. #include <utility>
  12. #include <vector>
  13. #include "base/bind.h"
  14. #include "base/callback.h"
  15. #include "base/check_op.h"
  16. #include "base/containers/span.h"
  17. #include "base/memory/ptr_util.h"
  18. #include "base/memory/raw_ptr.h"
  19. #include "base/memory/ref_counted.h"
  20. #include "base/memory/weak_ptr.h"
  21. #include "build/build_config.h"
  22. #include "net/base/address_list.h"
  23. #include "net/base/completion_once_callback.h"
  24. #include "net/base/io_buffer.h"
  25. #include "net/base/net_errors.h"
  26. #include "net/base/test_completion_callback.h"
  27. #include "net/http/http_auth_controller.h"
  28. #include "net/log/net_log_with_source.h"
  29. #include "net/socket/client_socket_factory.h"
  30. #include "net/socket/client_socket_handle.h"
  31. #include "net/socket/client_socket_pool.h"
  32. #include "net/socket/datagram_client_socket.h"
  33. #include "net/socket/socket_performance_watcher.h"
  34. #include "net/socket/socket_tag.h"
  35. #include "net/socket/ssl_client_socket.h"
  36. #include "net/socket/transport_client_socket.h"
  37. #include "net/socket/transport_client_socket_pool.h"
  38. #include "net/ssl/ssl_config_service.h"
  39. #include "net/ssl/ssl_info.h"
  40. #include "testing/gtest/include/gtest/gtest.h"
  41. #include "third_party/abseil-cpp/absl/types/optional.h"
  42. namespace base {
  43. class RunLoop;
  44. }
  45. namespace net {
  46. struct CommonConnectJobParams;
  47. class NetLog;
  48. struct NetworkTrafficAnnotationTag;
  49. class X509Certificate;
  50. const handles::NetworkHandle kDefaultNetworkForTests = 1;
  51. const handles::NetworkHandle kNewNetworkForTests = 2;
  // Private network error codes used by the socket test utility classes.
  enum {
    // A MockRead whose |result| member equals this value is only a marker: it
    // indicates that the peer will close the connection after the next
    // MockRead. The other members of that marker MockRead are ignored.
    ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
  };
  60. class AsyncSocket;
  61. class MockClientSocket;
  62. class SSLClientSocket;
  63. class StreamSocket;
  // Completion mode of a mocked operation.
  enum IoMode {
    ASYNC,
    SYNCHRONOUS,
  };
  65. struct MockConnect {
  66. // Asynchronous connection success.
  67. // Creates a MockConnect with |mode| ASYC, |result| OK, and
  68. // |peer_addr| 192.0.2.33.
  69. MockConnect();
  70. // Creates a MockConnect with the specified mode and result, with
  71. // |peer_addr| 192.0.2.33.
  72. MockConnect(IoMode io_mode, int r);
  73. MockConnect(IoMode io_mode, int r, IPEndPoint addr);
  74. MockConnect(IoMode io_mode, int r, IPEndPoint addr, bool first_attempt_fails);
  75. ~MockConnect();
  76. IoMode mode;
  77. int result;
  78. IPEndPoint peer_addr;
  79. bool first_attempt_fails = false;
  80. };
  81. struct MockConfirm {
  82. // Asynchronous confirm success.
  83. // Creates a MockConfirm with |mode| ASYC and |result| OK.
  84. MockConfirm();
  85. // Creates a MockConfirm with the specified mode and result.
  86. MockConfirm(IoMode io_mode, int r);
  87. ~MockConfirm();
  88. IoMode mode;
  89. int result;
  90. };
  91. // MockRead and MockWrite shares the same interface and members, but we'd like
  92. // to have distinct types because we don't want to have them used
  93. // interchangably. To do this, a struct template is defined, and MockRead and
  94. // MockWrite are instantiated by using this template. Template parameter |type|
  95. // is not used in the struct definition (it purely exists for creating a new
  96. // type).
  97. //
  98. // |data| in MockRead and MockWrite has different meanings: |data| in MockRead
  99. // is the data returned from the socket when MockTCPClientSocket::Read() is
  100. // attempted, while |data| in MockWrite is the expected data that should be
  101. // given in MockTCPClientSocket::Write().
  102. enum MockReadWriteType { MOCK_READ, MOCK_WRITE };
  103. template <MockReadWriteType type>
  104. struct MockReadWrite {
  105. // Flag to indicate that the message loop should be terminated.
  106. enum { STOPLOOP = 1 << 31 };
  107. // Default
  108. MockReadWrite()
  109. : mode(SYNCHRONOUS),
  110. result(0),
  111. data(nullptr),
  112. data_len(0),
  113. sequence_number(0) {}
  114. // Read/write failure (no data).
  115. MockReadWrite(IoMode io_mode, int result)
  116. : mode(io_mode),
  117. result(result),
  118. data(nullptr),
  119. data_len(0),
  120. sequence_number(0) {}
  121. // Read/write failure (no data), with sequence information.
  122. MockReadWrite(IoMode io_mode, int result, int seq)
  123. : mode(io_mode),
  124. result(result),
  125. data(nullptr),
  126. data_len(0),
  127. sequence_number(seq) {}
  128. // Asynchronous read/write success (inferred data length).
  129. explicit MockReadWrite(const char* data)
  130. : mode(ASYNC),
  131. result(0),
  132. data(data),
  133. data_len(strlen(data)),
  134. sequence_number(0) {}
  135. // Read/write success (inferred data length).
  136. MockReadWrite(IoMode io_mode, const char* data)
  137. : mode(io_mode),
  138. result(0),
  139. data(data),
  140. data_len(strlen(data)),
  141. sequence_number(0) {}
  142. // Read/write success.
  143. MockReadWrite(IoMode io_mode, const char* data, int data_len)
  144. : mode(io_mode),
  145. result(0),
  146. data(data),
  147. data_len(data_len),
  148. sequence_number(0) {}
  149. // Read/write success (inferred data length) with sequence information.
  150. MockReadWrite(IoMode io_mode, int seq, const char* data)
  151. : mode(io_mode),
  152. result(0),
  153. data(data),
  154. data_len(strlen(data)),
  155. sequence_number(seq) {}
  156. // Read/write success with sequence information.
  157. MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq)
  158. : mode(io_mode),
  159. result(0),
  160. data(data),
  161. data_len(data_len),
  162. sequence_number(seq) {}
  163. IoMode mode;
  164. int result;
  165. const char* data;
  166. int data_len;
  167. // For data providers that only allows reads to occur in a particular
  168. // sequence. If a read occurs before the given |sequence_number| is reached,
  169. // an ERR_IO_PENDING is returned.
  170. int sequence_number; // The sequence number at which a read is allowed
  171. // to occur.
  172. };
  173. typedef MockReadWrite<MOCK_READ> MockRead;
  174. typedef MockReadWrite<MOCK_WRITE> MockWrite;
  175. struct MockWriteResult {
  176. MockWriteResult(IoMode io_mode, int result) : mode(io_mode), result(result) {}
  177. IoMode mode;
  178. int result;
  179. };
  180. // The SocketDataProvider is an interface used by the MockClientSocket
  181. // for getting data about individual reads and writes on the socket. Can be
  182. // used with at most one socket at a time.
  183. // TODO(mmenke): Do these really need to be re-useable?
  184. class SocketDataProvider {
  185. public:
  186. SocketDataProvider();
  187. SocketDataProvider(const SocketDataProvider&) = delete;
  188. SocketDataProvider& operator=(const SocketDataProvider&) = delete;
  189. virtual ~SocketDataProvider();
  190. // Returns the buffer and result code for the next simulated read.
  191. // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
  192. // that it will be called via the AsyncSocket::OnReadComplete()
  193. // function at a later time.
  194. virtual MockRead OnRead() = 0;
  195. virtual MockWriteResult OnWrite(const std::string& data) = 0;
  196. virtual bool AllReadDataConsumed() const = 0;
  197. virtual bool AllWriteDataConsumed() const = 0;
  198. virtual void CancelPendingRead() {}
  199. // Returns the last set receive buffer size, or -1 if never set.
  200. int receive_buffer_size() const { return receive_buffer_size_; }
  201. void set_receive_buffer_size(int receive_buffer_size) {
  202. receive_buffer_size_ = receive_buffer_size;
  203. }
  204. // Returns the last set send buffer size, or -1 if never set.
  205. int send_buffer_size() const { return send_buffer_size_; }
  206. void set_send_buffer_size(int send_buffer_size) {
  207. send_buffer_size_ = send_buffer_size;
  208. }
  209. // Returns the last set value of TCP no delay, or false if never set.
  210. bool no_delay() const { return no_delay_; }
  211. void set_no_delay(bool no_delay) { no_delay_ = no_delay; }
  212. // Returns whether TCP keepalives were enabled or not. Returns kDefault by
  213. // default.
  214. enum class KeepAliveState { kEnabled, kDisabled, kDefault };
  215. KeepAliveState keep_alive_state() const { return keep_alive_state_; }
  216. // Last set TCP keepalive delay.
  217. int keep_alive_delay() const { return keep_alive_delay_; }
  218. void set_keep_alive(bool enable, int delay) {
  219. keep_alive_state_ =
  220. enable ? KeepAliveState::kEnabled : KeepAliveState::kDisabled;
  221. keep_alive_delay_ = delay;
  222. }
  223. // Setters / getters for the return values of the corresponding Set*()
  224. // methods. By default, they all succeed, if the socket is connected.
  225. void set_set_receive_buffer_size_result(int receive_buffer_size_result) {
  226. set_receive_buffer_size_result_ = receive_buffer_size_result;
  227. }
  228. int set_receive_buffer_size_result() const {
  229. return set_receive_buffer_size_result_;
  230. }
  231. void set_set_send_buffer_size_result(int set_send_buffer_size_result) {
  232. set_send_buffer_size_result_ = set_send_buffer_size_result;
  233. }
  234. int set_send_buffer_size_result() const {
  235. return set_send_buffer_size_result_;
  236. }
  237. void set_set_no_delay_result(bool set_no_delay_result) {
  238. set_no_delay_result_ = set_no_delay_result;
  239. }
  240. bool set_no_delay_result() const { return set_no_delay_result_; }
  241. void set_set_keep_alive_result(bool set_keep_alive_result) {
  242. set_keep_alive_result_ = set_keep_alive_result;
  243. }
  244. bool set_keep_alive_result() const { return set_keep_alive_result_; }
  245. const absl::optional<AddressList>& expected_addresses() const {
  246. return expected_addresses_;
  247. }
  248. void set_expected_addresses(net::AddressList addresses) {
  249. expected_addresses_ = std::move(addresses);
  250. }
  251. // Returns true if the request should be considered idle, for the purposes of
  252. // IsConnectedAndIdle.
  253. virtual bool IsIdle() const;
  254. // Initializes the SocketDataProvider for use with |socket|. Must be called
  255. // before use
  256. void Initialize(AsyncSocket* socket);
  257. // Detaches the socket associated with a SocketDataProvider. Must be called
  258. // before |socket_| is destroyed, unless the SocketDataProvider has informed
  259. // |socket_| it was destroyed. Must also be called before Initialize() may
  260. // be called again with a new socket.
  261. void DetachSocket();
  262. // Accessor for the socket which is using the SocketDataProvider.
  263. AsyncSocket* socket() { return socket_; }
  264. MockConnect connect_data() const { return connect_; }
  265. void set_connect_data(const MockConnect& connect) { connect_ = connect; }
  266. private:
  267. // Called to inform subclasses of initialization.
  268. virtual void Reset() = 0;
  269. MockConnect connect_;
  270. raw_ptr<AsyncSocket> socket_ = nullptr;
  271. int receive_buffer_size_ = -1;
  272. int send_buffer_size_ = -1;
  273. // This reflects the default state of TCPClientSockets.
  274. bool no_delay_ = true;
  275. KeepAliveState keep_alive_state_ = KeepAliveState::kDefault;
  276. int keep_alive_delay_ = 0;
  277. int set_receive_buffer_size_result_ = net::OK;
  278. int set_send_buffer_size_result_ = net::OK;
  279. bool set_no_delay_result_ = true;
  280. bool set_keep_alive_result_ = true;
  281. absl::optional<AddressList> expected_addresses_;
  282. };
  283. // The AsyncSocket is an interface used by the SocketDataProvider to
  284. // complete the asynchronous read operation.
  285. class AsyncSocket {
  286. public:
  287. // If an async IO is pending because the SocketDataProvider returned
  288. // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
  289. // is called to complete the asynchronous read operation.
  290. // data.async is ignored, and this read is completed synchronously as
  291. // part of this call.
  292. // TODO(rch): this should take a StringPiece since most of the fields
  293. // are ignored.
  294. virtual void OnReadComplete(const MockRead& data) = 0;
  295. // If an async IO is pending because the SocketDataProvider returned
  296. // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
  297. // is called to complete the asynchronous read operation.
  298. virtual void OnWriteComplete(int rv) = 0;
  299. virtual void OnConnectComplete(const MockConnect& data) = 0;
  300. // Called when the SocketDataProvider associated with the socket is destroyed.
  301. // The socket may continue to be used after the data provider is destroyed,
  302. // so it should be sure not to dereference the provider after this is called.
  303. virtual void OnDataProviderDestroyed() = 0;
  304. };
  // Interface for formatting the data of a socket write in a
  // protocol-specific way.
  class SocketDataPrinter {
   public:
    // Virtual because this is a polymorphic base: without it, deleting an
    // implementation through a SocketDataPrinter pointer is undefined
    // behaviour and the derived destructor is not guaranteed to run.
    virtual ~SocketDataPrinter() = default;

    // Prints the write in |data| using some sort of protocol-specific
    // format.
    virtual std::string PrintWrite(const std::string& data) = 0;
  };
  312. // StaticSocketDataHelper manages a list of reads and writes.
  313. class StaticSocketDataHelper {
  314. public:
  315. StaticSocketDataHelper(base::span<const MockRead> reads,
  316. base::span<const MockWrite> writes);
  317. StaticSocketDataHelper(const StaticSocketDataHelper&) = delete;
  318. StaticSocketDataHelper& operator=(const StaticSocketDataHelper&) = delete;
  319. ~StaticSocketDataHelper();
  320. // These functions get access to the next available read and write data. They
  321. // CHECK fail if there is no data available.
  322. const MockRead& PeekRead() const;
  323. const MockWrite& PeekWrite() const;
  324. // Returns the current read or write, and then advances to the next one.
  325. const MockRead& AdvanceRead();
  326. const MockWrite& AdvanceWrite();
  327. // Resets the read and write indexes to 0.
  328. void Reset();
  329. // Returns true if |data| is valid data for the next write. In order
  330. // to support short writes, the next write may be longer than |data|
  331. // in which case this method will still return true.
  332. bool VerifyWriteData(const std::string& data, SocketDataPrinter* printer);
  333. size_t read_index() const { return read_index_; }
  334. size_t write_index() const { return write_index_; }
  335. size_t read_count() const { return reads_.size(); }
  336. size_t write_count() const { return writes_.size(); }
  337. bool AllReadDataConsumed() const { return read_index() >= read_count(); }
  338. bool AllWriteDataConsumed() const { return write_index() >= write_count(); }
  339. private:
  340. // Returns the next available read or write that is not a pause event. CHECK
  341. // fails if no data is available.
  342. const MockWrite& PeekRealWrite() const;
  343. const base::span<const MockRead> reads_;
  344. size_t read_index_ = 0;
  345. const base::span<const MockWrite> writes_;
  346. size_t write_index_ = 0;
  347. };
  348. // SocketDataProvider which responds based on static tables of mock reads and
  349. // writes.
  350. class StaticSocketDataProvider : public SocketDataProvider {
  351. public:
  352. StaticSocketDataProvider();
  353. StaticSocketDataProvider(base::span<const MockRead> reads,
  354. base::span<const MockWrite> writes);
  355. StaticSocketDataProvider(const StaticSocketDataProvider&) = delete;
  356. StaticSocketDataProvider& operator=(const StaticSocketDataProvider&) = delete;
  357. ~StaticSocketDataProvider() override;
  358. // Pause/resume reads from this provider.
  359. void Pause();
  360. void Resume();
  361. // From SocketDataProvider:
  362. MockRead OnRead() override;
  363. MockWriteResult OnWrite(const std::string& data) override;
  364. bool AllReadDataConsumed() const override;
  365. bool AllWriteDataConsumed() const override;
  366. size_t read_index() const { return helper_.read_index(); }
  367. size_t write_index() const { return helper_.write_index(); }
  368. size_t read_count() const { return helper_.read_count(); }
  369. size_t write_count() const { return helper_.write_count(); }
  370. void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
  371. private:
  372. // From SocketDataProvider:
  373. void Reset() override;
  374. StaticSocketDataHelper helper_;
  375. SocketDataPrinter* printer_ = nullptr;
  376. bool paused_ = false;
  377. };
  378. // SSLSocketDataProviders only need to keep track of the return code from calls
  379. // to Connect().
  380. struct SSLSocketDataProvider {
  381. SSLSocketDataProvider(IoMode mode, int result);
  382. SSLSocketDataProvider(const SSLSocketDataProvider& other);
  383. ~SSLSocketDataProvider();
  384. // Returns whether MockConnect data has been consumed.
  385. bool ConnectDataConsumed() const { return is_connect_data_consumed; }
  386. // Returns whether MockConfirm data has been consumed.
  387. bool ConfirmDataConsumed() const { return is_confirm_data_consumed; }
  388. // Returns whether a Write occurred before ConfirmHandshake completed.
  389. bool WriteBeforeConfirm() const { return write_called_before_confirm; }
  390. // Result for Connect().
  391. MockConnect connect;
  392. // Callback to run when Connect() is called. This is called at most once per
  393. // socket but is repeating because SSLSocketDataProvider is copyable.
  394. base::RepeatingClosure connect_callback;
  395. // Result for ConfirmHandshake().
  396. MockConfirm confirm;
  397. // Callback to run when ConfirmHandshake() is called. This is called at most
  398. // once per socket but is repeating because SSLSocketDataProvider is
  399. // copyable.
  400. base::RepeatingClosure confirm_callback;
  401. // Result for GetNegotiatedProtocol().
  402. NextProto next_proto = kProtoUnknown;
  403. // Result for GetPeerApplicationSettings().
  404. absl::optional<std::string> peer_application_settings;
  405. // Result for GetSSLInfo().
  406. SSLInfo ssl_info;
  407. // Result for GetSSLCertRequestInfo().
  408. SSLCertRequestInfo* cert_request_info = nullptr;
  409. // Result for GetECHRetryConfigs().
  410. std::vector<uint8_t> ech_retry_configs;
  411. absl::optional<NextProtoVector> next_protos_expected_in_ssl_config;
  412. uint16_t expected_ssl_version_min;
  413. uint16_t expected_ssl_version_max;
  414. absl::optional<bool> expected_send_client_cert;
  415. scoped_refptr<X509Certificate> expected_client_cert;
  416. absl::optional<HostPortPair> expected_host_and_port;
  417. absl::optional<NetworkIsolationKey> expected_network_isolation_key;
  418. absl::optional<bool> expected_disable_legacy_crypto;
  419. absl::optional<std::vector<uint8_t>> expected_ech_config_list;
  420. bool is_connect_data_consumed = false;
  421. bool is_confirm_data_consumed = false;
  422. bool write_called_before_confirm = false;
  423. };
  424. // Uses the sequence_number field in the mock reads and writes to
  425. // complete the operations in a specified order.
  426. class SequencedSocketData : public SocketDataProvider {
  427. public:
  428. SequencedSocketData();
  429. // |reads| is the list of MockRead completions.
  430. // |writes| is the list of MockWrite completions.
  431. SequencedSocketData(base::span<const MockRead> reads,
  432. base::span<const MockWrite> writes);
  433. // |connect| is the result for the connect phase.
  434. // |reads| is the list of MockRead completions.
  435. // |writes| is the list of MockWrite completions.
  436. SequencedSocketData(const MockConnect& connect,
  437. base::span<const MockRead> reads,
  438. base::span<const MockWrite> writes);
  439. SequencedSocketData(const SequencedSocketData&) = delete;
  440. SequencedSocketData& operator=(const SequencedSocketData&) = delete;
  441. ~SequencedSocketData() override;
  442. // From SocketDataProvider:
  443. MockRead OnRead() override;
  444. MockWriteResult OnWrite(const std::string& data) override;
  445. bool AllReadDataConsumed() const override;
  446. bool AllWriteDataConsumed() const override;
  447. bool IsIdle() const override;
  448. void CancelPendingRead() override;
  449. // An ASYNC read event with a return value of ERR_IO_PENDING will cause the
  450. // socket data to pause at that event, and advance no further, until Resume is
  451. // invoked. At that point, the socket will continue at the next event in the
  452. // sequence.
  453. //
  454. // If a request just wants to simulate a connection that stays open and never
  455. // receives any more data, instead of pausing and then resuming a request, it
  456. // should use a SYNCHRONOUS event with a return value of ERR_IO_PENDING
  457. // instead.
  458. bool IsPaused() const;
  459. // Resumes events once |this| is in the paused state. The next event will
  460. // occur synchronously with the call if it can.
  461. void Resume();
  462. void RunUntilPaused();
  463. // When true, IsConnectedAndIdle() will return false if the next event in the
  464. // sequence is a synchronous. Otherwise, the socket claims to be idle as
  465. // long as it's connected. Defaults to false.
  466. // TODO(mmenke): See if this can be made the default behavior, and consider
  467. // removing this mehtod. Need to make sure it doesn't change what code any
  468. // tests are targetted at testing.
  469. void set_busy_before_sync_reads(bool busy_before_sync_reads) {
  470. busy_before_sync_reads_ = busy_before_sync_reads;
  471. }
  472. void set_printer(SocketDataPrinter* printer) { printer_ = printer; }
  473. private:
  474. // Defines the state for the read or write path.
  475. enum class IoState {
  476. kIdle, // No async operation is in progress.
  477. kPending, // An async operation in waiting for another operation to
  478. // complete.
  479. kCompleting, // A task has been posted to complete an async operation.
  480. kPaused, // IO is paused until Resume() is called.
  481. };
  482. // From SocketDataProvider:
  483. void Reset() override;
  484. void OnReadComplete();
  485. void OnWriteComplete();
  486. void MaybePostReadCompleteTask();
  487. void MaybePostWriteCompleteTask();
  488. StaticSocketDataHelper helper_;
  489. raw_ptr<SocketDataPrinter> printer_ = nullptr;
  490. int sequence_number_ = 0;
  491. IoState read_state_ = IoState::kIdle;
  492. IoState write_state_ = IoState::kIdle;
  493. bool busy_before_sync_reads_ = false;
  494. // Used by RunUntilPaused. NULL at all other times.
  495. std::unique_ptr<base::RunLoop> run_until_paused_run_loop_;
  496. base::WeakPtrFactory<SequencedSocketData> weak_factory_{this};
  497. };
  // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}StreamSocket
  // objects get instantiated, they take their data from the i'th element of
  // this array. The array stores plain pointers and does not own the elements.
  template <typename T>
  class SocketDataProviderArray {
   public:
    SocketDataProviderArray() = default;

    // Returns the next element and advances the index. DCHECKs that an element
    // is still available.
    T* GetNext() {
      DCHECK_LT(next_index_, data_providers_.size());
      return data_providers_[next_index_++];
    }

    // Like GetNext(), but returns nullptr when the end of the array is reached,
    // instead of DCHECKing. GetNext() should generally be preferred, unless
    // having no remaining elements is expected in some cases and is handled
    // safely.
    T* GetNextWithoutAsserting() {
      if (next_index_ == data_providers_.size())
        return nullptr;
      return data_providers_[next_index_++];
    }

    // Appends |data_provider|, which must not be null.
    void Add(T* data_provider) {
      DCHECK(data_provider);
      data_providers_.push_back(data_provider);
    }

    // Index of the element the next GetNext*() call will return. Const so it
    // can be queried through a const reference to the array.
    size_t next_index() const { return next_index_; }

    // Rewinds to the first element; the stored elements are kept.
    void ResetNextIndex() { next_index_ = 0; }

   private:
    // Index of the next |data_providers_| element to use. Not an iterator
    // because those are invalidated on vector reallocation.
    size_t next_index_ = 0;

    // SocketDataProviders to be returned.
    std::vector<T*> data_providers_;
  };
  531. class MockUDPClientSocket;
  532. class MockTCPClientSocket;
  533. class MockSSLClientSocket;
  534. // ClientSocketFactory which contains arrays of sockets of each type.
  535. // You should first fill the arrays using Add{SSL,}SocketDataProvider(). When
  536. // the factory is asked to create a socket, it takes next entry from appropriate
  537. // array. You can use ResetNextMockIndexes to reset that next entry index for
  538. // all mock socket types.
  539. class MockClientSocketFactory : public ClientSocketFactory {
  540. public:
  541. MockClientSocketFactory();
  542. MockClientSocketFactory(const MockClientSocketFactory&) = delete;
  543. MockClientSocketFactory& operator=(const MockClientSocketFactory&) = delete;
  544. ~MockClientSocketFactory() override;
  545. // Adds a SocketDataProvider that can be used to served either TCP or UDP
  546. // connection requests. Sockets are returned in FIFO order.
  547. void AddSocketDataProvider(SocketDataProvider* socket);
  548. // Like AddSocketDataProvider(), except sockets will only be used to service
  549. // TCP connection requests. Sockets added with this method are used first,
  550. // before sockets added with AddSocketDataProvider(). Particularly useful for
  551. // QUIC tests with multiple sockets, where TCP connections may or may not be
  552. // made, and have no guaranteed order, relative to UDP connections.
  553. void AddTcpSocketDataProvider(SocketDataProvider* socket);
  554. void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
  555. void ResetNextMockIndexes();
  556. SocketDataProviderArray<SocketDataProvider>& mock_data() {
  557. return mock_data_;
  558. }
  559. void set_enable_read_if_ready(bool enable_read_if_ready) {
  560. enable_read_if_ready_ = enable_read_if_ready;
  561. }
  562. // ClientSocketFactory
  563. std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
  564. DatagramSocket::BindType bind_type,
  565. NetLog* net_log,
  566. const NetLogSource& source) override;
  567. std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
  568. const AddressList& addresses,
  569. std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
  570. NetworkQualityEstimator* network_quality_estimator,
  571. NetLog* net_log,
  572. const NetLogSource& source) override;
  573. std::unique_ptr<SSLClientSocket> CreateSSLClientSocket(
  574. SSLClientContext* context,
  575. std::unique_ptr<StreamSocket> stream_socket,
  576. const HostPortPair& host_and_port,
  577. const SSLConfig& ssl_config) override;
  578. const std::vector<uint16_t>& udp_client_socket_ports() const {
  579. return udp_client_socket_ports_;
  580. }
  581. private:
  582. SocketDataProviderArray<SocketDataProvider> mock_data_;
  583. SocketDataProviderArray<SocketDataProvider> mock_tcp_data_;
  584. SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
  585. std::vector<uint16_t> udp_client_socket_ports_;
  586. // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
  587. // ERR_READ_IF_READY_NOT_IMPLEMENTED.
  588. bool enable_read_if_ready_ = false;
  589. };
  590. class MockClientSocket : public TransportClientSocket {
  591. public:
  592. // The NetLogWithSource is needed to test LoadTimingInfo, which uses NetLog
  593. // IDs as
  594. // unique socket IDs.
  595. explicit MockClientSocket(const NetLogWithSource& net_log);
  596. MockClientSocket(const MockClientSocket&) = delete;
  597. MockClientSocket& operator=(const MockClientSocket&) = delete;
  598. // Socket implementation.
  599. int Read(IOBuffer* buf,
  600. int buf_len,
  601. CompletionOnceCallback callback) override = 0;
  602. int Write(IOBuffer* buf,
  603. int buf_len,
  604. CompletionOnceCallback callback,
  605. const NetworkTrafficAnnotationTag& traffic_annotation) override = 0;
  606. int SetReceiveBufferSize(int32_t size) override;
  607. int SetSendBufferSize(int32_t size) override;
  608. // TransportClientSocket implementation.
  609. int Bind(const net::IPEndPoint& local_addr) override;
  610. bool SetNoDelay(bool no_delay) override;
  611. bool SetKeepAlive(bool enable, int delay) override;
  612. // StreamSocket implementation.
  613. int Connect(CompletionOnceCallback callback) override = 0;
  614. void Disconnect() override;
  615. bool IsConnected() const override;
  616. bool IsConnectedAndIdle() const override;
  617. int GetPeerAddress(IPEndPoint* address) const override;
  618. int GetLocalAddress(IPEndPoint* address) const override;
  619. const NetLogWithSource& NetLog() const override;
  620. bool WasAlpnNegotiated() const override;
  621. NextProto GetNegotiatedProtocol() const override;
  622. int64_t GetTotalReceivedBytes() const override;
  623. void ApplySocketTag(const SocketTag& tag) override {}
  624. protected:
  625. ~MockClientSocket() override;
  626. void RunCallbackAsync(CompletionOnceCallback callback, int result);
  627. void RunCallback(CompletionOnceCallback callback, int result);
  628. // True if Connect completed successfully and Disconnect hasn't been called.
  629. bool connected_ = false;
  630. IPEndPoint local_addr_;
  631. IPEndPoint peer_addr_;
  632. NetLogWithSource net_log_;
  633. private:
  634. base::WeakPtrFactory<MockClientSocket> weak_factory_{this};
  635. };
  636. class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
  637. public:
  638. MockTCPClientSocket(const AddressList& addresses,
  639. net::NetLog* net_log,
  640. SocketDataProvider* socket);
  641. MockTCPClientSocket(const MockTCPClientSocket&) = delete;
  642. MockTCPClientSocket& operator=(const MockTCPClientSocket&) = delete;
  643. ~MockTCPClientSocket() override;
  644. const AddressList& addresses() const { return addresses_; }
  645. // Socket implementation.
  646. int Read(IOBuffer* buf,
  647. int buf_len,
  648. CompletionOnceCallback callback) override;
  649. int ReadIfReady(IOBuffer* buf,
  650. int buf_len,
  651. CompletionOnceCallback callback) override;
  652. int CancelReadIfReady() override;
  653. int Write(IOBuffer* buf,
  654. int buf_len,
  655. CompletionOnceCallback callback,
  656. const NetworkTrafficAnnotationTag& traffic_annotation) override;
  657. int SetReceiveBufferSize(int32_t size) override;
  658. int SetSendBufferSize(int32_t size) override;
  659. // TransportClientSocket implementation.
  660. bool SetNoDelay(bool no_delay) override;
  661. bool SetKeepAlive(bool enable, int delay) override;
  662. // StreamSocket implementation.
  663. void SetBeforeConnectCallback(
  664. const BeforeConnectCallback& before_connect_callback) override;
  665. int Connect(CompletionOnceCallback callback) override;
  666. void Disconnect() override;
  667. bool IsConnected() const override;
  668. bool IsConnectedAndIdle() const override;
  669. int GetPeerAddress(IPEndPoint* address) const override;
  670. bool WasEverUsed() const override;
  671. bool GetSSLInfo(SSLInfo* ssl_info) override;
  672. // AsyncSocket:
  673. void OnReadComplete(const MockRead& data) override;
  674. void OnWriteComplete(int rv) override;
  675. void OnConnectComplete(const MockConnect& data) override;
  676. void OnDataProviderDestroyed() override;
  677. void set_enable_read_if_ready(bool enable_read_if_ready) {
  678. enable_read_if_ready_ = enable_read_if_ready;
  679. }
  680. private:
  681. void RetryRead(int rv);
  682. int ReadIfReadyImpl(IOBuffer* buf,
  683. int buf_len,
  684. CompletionOnceCallback callback);
  685. // Helper method to run |pending_read_if_ready_callback_| if it is not null.
  686. void RunReadIfReadyCallback(int result);
  687. AddressList addresses_;
  688. raw_ptr<SocketDataProvider> data_;
  689. int read_offset_ = 0;
  690. MockRead read_data_;
  691. bool need_read_data_ = true;
  692. // True if the peer has closed the connection. This allows us to simulate
  693. // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
  694. // TCPClientSocket.
  695. bool peer_closed_connection_ = false;
  696. // While an asynchronous read is pending, we save our user-buffer state.
  697. scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
  698. int pending_read_buf_len_ = 0;
  699. CompletionOnceCallback pending_read_callback_;
  700. // Non-null when a ReadIfReady() is pending.
  701. CompletionOnceCallback pending_read_if_ready_callback_;
  702. CompletionOnceCallback pending_connect_callback_;
  703. CompletionOnceCallback pending_write_callback_;
  704. bool was_used_to_convey_data_ = false;
  705. // If true, ReadIfReady() is enabled; otherwise ReadIfReady() returns
  706. // ERR_READ_IF_READY_NOT_IMPLEMENTED.
  707. bool enable_read_if_ready_ = false;
  708. BeforeConnectCallback before_connect_callback_;
  709. };
  710. class MockSSLClientSocket : public AsyncSocket, public SSLClientSocket {
  711. public:
  712. MockSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,
  713. const HostPortPair& host_and_port,
  714. const SSLConfig& ssl_config,
  715. SSLSocketDataProvider* socket);
  716. MockSSLClientSocket(const MockSSLClientSocket&) = delete;
  717. MockSSLClientSocket& operator=(const MockSSLClientSocket&) = delete;
  718. ~MockSSLClientSocket() override;
  719. // Socket implementation.
  720. int Read(IOBuffer* buf,
  721. int buf_len,
  722. CompletionOnceCallback callback) override;
  723. int ReadIfReady(IOBuffer* buf,
  724. int buf_len,
  725. CompletionOnceCallback callback) override;
  726. int Write(IOBuffer* buf,
  727. int buf_len,
  728. CompletionOnceCallback callback,
  729. const NetworkTrafficAnnotationTag& traffic_annotation) override;
  730. int CancelReadIfReady() override;
  731. // StreamSocket implementation.
  732. int Connect(CompletionOnceCallback callback) override;
  733. void Disconnect() override;
  734. int ConfirmHandshake(CompletionOnceCallback callback) override;
  735. bool IsConnected() const override;
  736. bool IsConnectedAndIdle() const override;
  737. bool WasEverUsed() const override;
  738. int GetPeerAddress(IPEndPoint* address) const override;
  739. int GetLocalAddress(IPEndPoint* address) const override;
  740. bool WasAlpnNegotiated() const override;
  741. NextProto GetNegotiatedProtocol() const override;
  742. absl::optional<base::StringPiece> GetPeerApplicationSettings() const override;
  743. bool GetSSLInfo(SSLInfo* ssl_info) override;
  744. void GetSSLCertRequestInfo(
  745. SSLCertRequestInfo* cert_request_info) const override;
  746. void ApplySocketTag(const SocketTag& tag) override;
  747. const NetLogWithSource& NetLog() const override;
  748. int64_t GetTotalReceivedBytes() const override;
  749. int SetReceiveBufferSize(int32_t size) override;
  750. int SetSendBufferSize(int32_t size) override;
  751. // SSLSocket implementation.
  752. int ExportKeyingMaterial(const base::StringPiece& label,
  753. bool has_context,
  754. const base::StringPiece& context,
  755. unsigned char* out,
  756. unsigned int outlen) override;
  757. // SSLClientSocket implementation.
  758. std::vector<uint8_t> GetECHRetryConfigs() override;
  759. // This MockSocket does not implement the manual async IO feature.
  760. void OnReadComplete(const MockRead& data) override;
  761. void OnWriteComplete(int rv) override;
  762. void OnConnectComplete(const MockConnect& data) override;
  763. // SSL sockets don't need magic to deal with destruction of their data
  764. // provider.
  765. // TODO(mmenke): Probably a good idea to support it, anyways.
  766. void OnDataProviderDestroyed() override {}
  767. private:
  768. static void ConnectCallback(MockSSLClientSocket* ssl_client_socket,
  769. CompletionOnceCallback callback,
  770. int rv);
  771. void RunCallbackAsync(CompletionOnceCallback callback, int result);
  772. void RunCallback(CompletionOnceCallback callback, int result);
  773. void RunConfirmHandshakeCallback(CompletionOnceCallback callback, int result);
  774. bool connected_ = false;
  775. bool in_confirm_handshake_ = false;
  776. NetLogWithSource net_log_;
  777. std::unique_ptr<StreamSocket> stream_socket_;
  778. raw_ptr<SSLSocketDataProvider> data_;
  779. // Address of the "remote" peer we're connected to.
  780. IPEndPoint peer_addr_;
  781. base::WeakPtrFactory<MockSSLClientSocket> weak_factory_{this};
  782. };
  783. class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket {
  784. public:
  785. explicit MockUDPClientSocket(SocketDataProvider* data = nullptr,
  786. net::NetLog* net_log = nullptr);
  787. MockUDPClientSocket(const MockUDPClientSocket&) = delete;
  788. MockUDPClientSocket& operator=(const MockUDPClientSocket&) = delete;
  789. ~MockUDPClientSocket() override;
  790. // Socket implementation.
  791. int Read(IOBuffer* buf,
  792. int buf_len,
  793. CompletionOnceCallback callback) override;
  794. int Write(IOBuffer* buf,
  795. int buf_len,
  796. CompletionOnceCallback callback,
  797. const NetworkTrafficAnnotationTag& traffic_annotation) override;
  798. int SetReceiveBufferSize(int32_t size) override;
  799. int SetSendBufferSize(int32_t size) override;
  800. int SetDoNotFragment() override;
  801. // DatagramSocket implementation.
  802. void Close() override;
  803. int GetPeerAddress(IPEndPoint* address) const override;
  804. int GetLocalAddress(IPEndPoint* address) const override;
  805. void UseNonBlockingIO() override;
  806. int SetMulticastInterface(uint32_t interface_index) override;
  807. const NetLogWithSource& NetLog() const override;
  808. // DatagramClientSocket implementation.
  809. int Connect(const IPEndPoint& address) override;
  810. int ConnectUsingNetwork(handles::NetworkHandle network,
  811. const IPEndPoint& address) override;
  812. int ConnectUsingDefaultNetwork(const IPEndPoint& address) override;
  813. handles::NetworkHandle GetBoundNetwork() const override;
  814. void ApplySocketTag(const SocketTag& tag) override;
  815. void SetMsgConfirm(bool confirm) override {}
  816. // AsyncSocket implementation.
  817. void OnReadComplete(const MockRead& data) override;
  818. void OnWriteComplete(int rv) override;
  819. void OnConnectComplete(const MockConnect& data) override;
  820. void OnDataProviderDestroyed() override;
  821. void set_source_port(uint16_t port) { source_port_ = port; }
  822. uint16_t source_port() const { return source_port_; }
  823. void set_source_host(IPAddress addr) { source_host_ = addr; }
  824. IPAddress source_host() const { return source_host_; }
  825. // Returns last tag applied to socket.
  826. SocketTag tag() const { return tag_; }
  827. // Returns false if socket's tag was changed after the socket was used for
  828. // data transfer (e.g. Read/Write() called), otherwise returns true.
  829. bool tagged_before_data_transferred() const {
  830. return tagged_before_data_transferred_;
  831. }
  832. private:
  833. int CompleteRead();
  834. void RunCallbackAsync(CompletionOnceCallback callback, int result);
  835. void RunCallback(CompletionOnceCallback callback, int result);
  836. bool connected_ = false;
  837. raw_ptr<SocketDataProvider> data_;
  838. int read_offset_ = 0;
  839. MockRead read_data_;
  840. bool need_read_data_ = true;
  841. IPAddress source_host_;
  842. uint16_t source_port_ = 123; // Ephemeral source port.
  843. // Address of the "remote" peer we're connected to.
  844. IPEndPoint peer_addr_;
  845. // Network that the socket is bound to.
  846. handles::NetworkHandle network_ = handles::kInvalidNetworkHandle;
  847. // While an asynchronous IO is pending, we save our user-buffer state.
  848. scoped_refptr<IOBuffer> pending_read_buf_ = nullptr;
  849. int pending_read_buf_len_ = 0;
  850. CompletionOnceCallback pending_read_callback_;
  851. CompletionOnceCallback pending_write_callback_;
  852. NetLogWithSource net_log_;
  853. DatagramBuffers unwritten_buffers_;
  854. SocketTag tag_;
  855. bool data_transferred_ = false;
  856. bool tagged_before_data_transferred_ = true;
  857. base::WeakPtrFactory<MockUDPClientSocket> weak_factory_{this};
  858. };
  859. class TestSocketRequest : public TestCompletionCallbackBase {
  860. public:
  861. TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
  862. size_t* completion_count);
  863. TestSocketRequest(const TestSocketRequest&) = delete;
  864. TestSocketRequest& operator=(const TestSocketRequest&) = delete;
  865. ~TestSocketRequest() override;
  866. ClientSocketHandle* handle() { return &handle_; }
  867. CompletionOnceCallback callback() {
  868. return base::BindOnce(&TestSocketRequest::OnComplete,
  869. base::Unretained(this));
  870. }
  871. private:
  872. void OnComplete(int result);
  873. ClientSocketHandle handle_;
  874. raw_ptr<std::vector<TestSocketRequest*>> request_order_;
  875. raw_ptr<size_t> completion_count_;
  876. };
  877. class ClientSocketPoolTest {
  878. public:
  879. enum KeepAlive {
  880. KEEP_ALIVE,
  881. // A socket will be disconnected in addition to handle being reset.
  882. NO_KEEP_ALIVE,
  883. };
  884. static const int kIndexOutOfBounds;
  885. static const int kRequestNotFound;
  886. ClientSocketPoolTest();
  887. ClientSocketPoolTest(const ClientSocketPoolTest&) = delete;
  888. ClientSocketPoolTest& operator=(const ClientSocketPoolTest&) = delete;
  889. ~ClientSocketPoolTest();
  890. template <typename PoolType>
  891. int StartRequestUsingPool(
  892. PoolType* socket_pool,
  893. const ClientSocketPool::GroupId& group_id,
  894. RequestPriority priority,
  895. ClientSocketPool::RespectLimits respect_limits,
  896. const scoped_refptr<typename PoolType::SocketParams>& socket_params) {
  897. DCHECK(socket_pool);
  898. TestSocketRequest* request(
  899. new TestSocketRequest(&request_order_, &completion_count_));
  900. requests_.push_back(base::WrapUnique(request));
  901. int rv = request->handle()->Init(
  902. group_id, socket_params, absl::nullopt /* proxy_annotation_tag */,
  903. priority, SocketTag(), respect_limits, request->callback(),
  904. ClientSocketPool::ProxyAuthCallback(), socket_pool, NetLogWithSource());
  905. if (rv != ERR_IO_PENDING)
  906. request_order_.push_back(request);
  907. return rv;
  908. }
  909. // Provided there were n requests started, takes |index| in range 1..n
  910. // and returns order in which that request completed, in range 1..n,
  911. // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
  912. // if that request did not complete (for example was canceled).
  913. int GetOrderOfRequest(size_t index) const;
  914. // Resets first initialized socket handle from |requests_|. If found such
  915. // a handle, returns true.
  916. bool ReleaseOneConnection(KeepAlive keep_alive);
  917. // Releases connections until there is nothing to release.
  918. void ReleaseAllConnections(KeepAlive keep_alive);
  919. // Note that this uses 0-based indices, while GetOrderOfRequest takes and
  920. // returns 1-based indices.
  921. TestSocketRequest* request(int i) { return requests_[i].get(); }
  922. size_t requests_size() const { return requests_.size(); }
  923. std::vector<std::unique_ptr<TestSocketRequest>>* requests() {
  924. return &requests_;
  925. }
  926. size_t completion_count() const { return completion_count_; }
  927. private:
  928. std::vector<std::unique_ptr<TestSocketRequest>> requests_;
  929. std::vector<TestSocketRequest*> request_order_;
  930. size_t completion_count_ = 0;
  931. };
  932. class MockTransportSocketParams
  933. : public base::RefCounted<MockTransportSocketParams> {
  934. public:
  935. MockTransportSocketParams(const MockTransportSocketParams&) = delete;
  936. MockTransportSocketParams& operator=(const MockTransportSocketParams&) =
  937. delete;
  938. private:
  939. friend class base::RefCounted<MockTransportSocketParams>;
  940. ~MockTransportSocketParams() = default;
  941. };
  942. class MockTransportClientSocketPool : public TransportClientSocketPool {
  943. public:
  944. class MockConnectJob {
  945. public:
  946. MockConnectJob(std::unique_ptr<StreamSocket> socket,
  947. ClientSocketHandle* handle,
  948. const SocketTag& socket_tag,
  949. CompletionOnceCallback callback,
  950. RequestPriority priority);
  951. MockConnectJob(const MockConnectJob&) = delete;
  952. MockConnectJob& operator=(const MockConnectJob&) = delete;
  953. ~MockConnectJob();
  954. int Connect();
  955. bool CancelHandle(const ClientSocketHandle* handle);
  956. ClientSocketHandle* handle() const { return handle_; }
  957. RequestPriority priority() const { return priority_; }
  958. void set_priority(RequestPriority priority) { priority_ = priority; }
  959. private:
  960. void OnConnect(int rv);
  961. std::unique_ptr<StreamSocket> socket_;
  962. raw_ptr<ClientSocketHandle> handle_;
  963. const SocketTag socket_tag_;
  964. CompletionOnceCallback user_callback_;
  965. RequestPriority priority_;
  966. };
  967. MockTransportClientSocketPool(
  968. int max_sockets,
  969. int max_sockets_per_group,
  970. const CommonConnectJobParams* common_connect_job_params);
  971. MockTransportClientSocketPool(const MockTransportClientSocketPool&) = delete;
  972. MockTransportClientSocketPool& operator=(
  973. const MockTransportClientSocketPool&) = delete;
  974. ~MockTransportClientSocketPool() override;
  975. RequestPriority last_request_priority() const {
  976. return last_request_priority_;
  977. }
  978. const std::vector<std::unique_ptr<MockConnectJob>>& requests() const {
  979. return job_list_;
  980. }
  981. int release_count() const { return release_count_; }
  982. int cancel_count() const { return cancel_count_; }
  983. // TransportClientSocketPool implementation.
  984. int RequestSocket(
  985. const GroupId& group_id,
  986. scoped_refptr<ClientSocketPool::SocketParams> socket_params,
  987. const absl::optional<NetworkTrafficAnnotationTag>& proxy_annotation_tag,
  988. RequestPriority priority,
  989. const SocketTag& socket_tag,
  990. RespectLimits respect_limits,
  991. ClientSocketHandle* handle,
  992. CompletionOnceCallback callback,
  993. const ProxyAuthCallback& on_auth_callback,
  994. const NetLogWithSource& net_log) override;
  995. void SetPriority(const GroupId& group_id,
  996. ClientSocketHandle* handle,
  997. RequestPriority priority) override;
  998. void CancelRequest(const GroupId& group_id,
  999. ClientSocketHandle* handle,
  1000. bool cancel_connect_job) override;
  1001. void ReleaseSocket(const GroupId& group_id,
  1002. std::unique_ptr<StreamSocket> socket,
  1003. int64_t generation) override;
  1004. private:
  1005. raw_ptr<ClientSocketFactory> client_socket_factory_;
  1006. std::vector<std::unique_ptr<MockConnectJob>> job_list_;
  1007. RequestPriority last_request_priority_ = DEFAULT_PRIORITY;
  1008. int release_count_ = 0;
  1009. int cancel_count_ = 0;
  1010. };
  1011. // WrappedStreamSocket is a base class that wraps an existing StreamSocket,
  1012. // forwarding the Socket and StreamSocket interfaces to the underlying
  1013. // transport.
  1014. // This is to provide a common base class for subclasses to override specific
  1015. // StreamSocket methods for testing, while still communicating with a 'real'
  1016. // StreamSocket.
  1017. class WrappedStreamSocket : public TransportClientSocket {
  1018. public:
  1019. explicit WrappedStreamSocket(std::unique_ptr<StreamSocket> transport);
  1020. ~WrappedStreamSocket() override;
  1021. // StreamSocket implementation:
  1022. int Bind(const net::IPEndPoint& local_addr) override;
  1023. int Connect(CompletionOnceCallback callback) override;
  1024. void Disconnect() override;
  1025. bool IsConnected() const override;
  1026. bool IsConnectedAndIdle() const override;
  1027. int GetPeerAddress(IPEndPoint* address) const override;
  1028. int GetLocalAddress(IPEndPoint* address) const override;
  1029. const NetLogWithSource& NetLog() const override;
  1030. bool WasEverUsed() const override;
  1031. bool WasAlpnNegotiated() const override;
  1032. NextProto GetNegotiatedProtocol() const override;
  1033. bool GetSSLInfo(SSLInfo* ssl_info) override;
  1034. int64_t GetTotalReceivedBytes() const override;
  1035. void ApplySocketTag(const SocketTag& tag) override;
  1036. // Socket implementation:
  1037. int Read(IOBuffer* buf,
  1038. int buf_len,
  1039. CompletionOnceCallback callback) override;
  1040. int ReadIfReady(IOBuffer* buf,
  1041. int buf_len,
  1042. CompletionOnceCallback callback) override;
  1043. int Write(IOBuffer* buf,
  1044. int buf_len,
  1045. CompletionOnceCallback callback,
  1046. const NetworkTrafficAnnotationTag& traffic_annotation) override;
  1047. int SetReceiveBufferSize(int32_t size) override;
  1048. int SetSendBufferSize(int32_t size) override;
  1049. protected:
  1050. std::unique_ptr<StreamSocket> transport_;
  1051. };
  1052. // StreamSocket that wraps another StreamSocket, but keeps track of any
  1053. // SocketTag applied to the socket.
  1054. class MockTaggingStreamSocket : public WrappedStreamSocket {
  1055. public:
  1056. explicit MockTaggingStreamSocket(std::unique_ptr<StreamSocket> transport)
  1057. : WrappedStreamSocket(std::move(transport)) {}
  1058. MockTaggingStreamSocket(const MockTaggingStreamSocket&) = delete;
  1059. MockTaggingStreamSocket& operator=(const MockTaggingStreamSocket&) = delete;
  1060. ~MockTaggingStreamSocket() override = default;
  1061. // StreamSocket implementation.
  1062. int Connect(CompletionOnceCallback callback) override;
  1063. void ApplySocketTag(const SocketTag& tag) override;
  1064. // Returns false if socket's tag was changed after the socket was connected,
  1065. // otherwise returns true.
  1066. bool tagged_before_connected() const { return tagged_before_connected_; }
  1067. // Returns last tag applied to socket.
  1068. SocketTag tag() const { return tag_; }
  1069. private:
  1070. bool connected_ = false;
  1071. bool tagged_before_connected_ = true;
  1072. SocketTag tag_;
  1073. };
  1074. // Extend MockClientSocketFactory to return MockTaggingStreamSockets and
  1075. // keep track of last socket produced for test inspection.
  1076. class MockTaggingClientSocketFactory : public MockClientSocketFactory {
  1077. public:
  1078. MockTaggingClientSocketFactory() = default;
  1079. MockTaggingClientSocketFactory(const MockTaggingClientSocketFactory&) =
  1080. delete;
  1081. MockTaggingClientSocketFactory& operator=(
  1082. const MockTaggingClientSocketFactory&) = delete;
  1083. // ClientSocketFactory implementation.
  1084. std::unique_ptr<DatagramClientSocket> CreateDatagramClientSocket(
  1085. DatagramSocket::BindType bind_type,
  1086. NetLog* net_log,
  1087. const NetLogSource& source) override;
  1088. std::unique_ptr<TransportClientSocket> CreateTransportClientSocket(
  1089. const AddressList& addresses,
  1090. std::unique_ptr<SocketPerformanceWatcher> socket_performance_watcher,
  1091. NetworkQualityEstimator* network_quality_estimator,
  1092. NetLog* net_log,
  1093. const NetLogSource& source) override;
  1094. // These methods return pointers to last TCP and UDP sockets produced by this
  1095. // factory. NOTE: Socket must still exist, or pointer will be to freed memory.
  1096. MockTaggingStreamSocket* GetLastProducedTCPSocket() const {
  1097. return tcp_socket_;
  1098. }
  1099. MockUDPClientSocket* GetLastProducedUDPSocket() const { return udp_socket_; }
  1100. private:
  1101. // TODO(crbug.com/1298696): Breaks net_unittests.
  1102. raw_ptr<MockTaggingStreamSocket, DegradeToNoOpWhenMTE> tcp_socket_ = nullptr;
  1103. raw_ptr<MockUDPClientSocket> udp_socket_ = nullptr;
  1104. };
  1105. // Host / port used for SOCKS4 test strings.
  1106. extern const char kSOCKS4TestHost[];
  1107. extern const int kSOCKS4TestPort;
  1108. // Constants for a successful SOCKS v4 handshake (connecting to kSOCKS4TestHost
  1109. // on port kSOCKS4TestPort, for the request).
  1110. extern const char kSOCKS4OkRequestLocalHostPort80[];
  1111. extern const int kSOCKS4OkRequestLocalHostPort80Length;
  1112. extern const char kSOCKS4OkReply[];
  1113. extern const int kSOCKS4OkReplyLength;
  1114. // Host / port used for SOCKS5 test strings.
  1115. extern const char kSOCKS5TestHost[];
  1116. extern const int kSOCKS5TestPort;
  1117. // Constants for a successful SOCKS v5 handshake (connecting to kSOCKS5TestHost
  1118. // on port kSOCKS5TestPort, for the request).
  1119. extern const char kSOCKS5GreetRequest[];
  1120. extern const int kSOCKS5GreetRequestLength;
  1121. extern const char kSOCKS5GreetResponse[];
  1122. extern const int kSOCKS5GreetResponseLength;
  1123. extern const char kSOCKS5OkRequest[];
  1124. extern const int kSOCKS5OkRequestLength;
  1125. extern const char kSOCKS5OkResponse[];
  1126. extern const int kSOCKS5OkResponseLength;
  1127. // Helper function to get the total data size of the MockReads in |reads|.
  1128. int64_t CountReadBytes(base::span<const MockRead> reads);
  1129. // Helper function to get the total data size of the MockWrites in |writes|.
  1130. int64_t CountWriteBytes(base::span<const MockWrite> writes);
  1131. #if BUILDFLAG(IS_ANDROID)
  1132. // Returns whether the device supports calling GetTaggedBytes().
  1133. bool CanGetTaggedBytes();
  1134. // Query the system to find out how many bytes were received with tag
  1135. // |expected_tag| for our UID. Returns the count of received bytes.
  1136. uint64_t GetTaggedBytes(int32_t expected_tag);
  1137. #endif
  1138. } // namespace net
  1139. #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_