fake_ssl_client_socket_unittest.cc 13 KB


  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "components/webrtc/fake_ssl_client_socket.h"

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "base/check_op.h"
#include "base/memory/ref_counted.h"
#include "base/test/task_environment.h"
#include "net/base/completion_once_callback.h"
#include "net/base/completion_repeating_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/test_completion_callback.h"
#include "net/log/net_log_source.h"
#include "net/log/net_log_with_source.h"
#include "net/socket/socket_tag.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/stream_socket.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
  26. namespace webrtc {
  27. namespace {
  28. using ::testing::Return;
  29. using ::testing::ReturnRef;
  30. // Used by RunUnsuccessfulHandshakeTestHelper. Represents where in
  31. // the handshake step an error should be inserted.
  32. enum HandshakeErrorLocation {
  33. CONNECT_ERROR,
  34. SEND_CLIENT_HELLO_ERROR,
  35. VERIFY_SERVER_HELLO_ERROR,
  36. };
  37. // Private error codes appended to the net::Error set.
  38. enum {
  39. // An error representing a server hello that has been corrupted in
  40. // transit.
  41. ERR_MALFORMED_SERVER_HELLO = -15000,
  42. };
  43. // Used by PassThroughMethods test.
  44. class MockClientSocket : public net::StreamSocket {
  45. public:
  46. ~MockClientSocket() override {}
  47. MOCK_METHOD3(Read, int(net::IOBuffer*, int, net::CompletionOnceCallback));
  48. MOCK_METHOD4(Write,
  49. int(net::IOBuffer*,
  50. int,
  51. net::CompletionOnceCallback,
  52. const net::NetworkTrafficAnnotationTag&));
  53. MOCK_METHOD1(SetReceiveBufferSize, int(int32_t));
  54. MOCK_METHOD1(SetSendBufferSize, int(int32_t));
  55. MOCK_METHOD1(Connect, int(net::CompletionOnceCallback));
  56. MOCK_METHOD0(Disconnect, void());
  57. MOCK_CONST_METHOD0(IsConnected, bool());
  58. MOCK_CONST_METHOD0(IsConnectedAndIdle, bool());
  59. MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*));
  60. MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*));
  61. MOCK_CONST_METHOD0(NetLog, const net::NetLogWithSource&());
  62. MOCK_CONST_METHOD0(WasEverUsed, bool());
  63. MOCK_CONST_METHOD0(UsingTCPFastOpen, bool());
  64. MOCK_CONST_METHOD0(NumBytesRead, int64_t());
  65. MOCK_CONST_METHOD0(GetConnectTimeMicros, base::TimeDelta());
  66. MOCK_CONST_METHOD0(WasAlpnNegotiated, bool());
  67. MOCK_CONST_METHOD0(GetNegotiatedProtocol, net::NextProto());
  68. MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*));
  69. MOCK_CONST_METHOD0(GetTotalReceivedBytes, int64_t());
  70. MOCK_METHOD1(ApplySocketTag, void(const net::SocketTag&));
  71. };
  72. // Break up |data| into a bunch of chunked MockReads/Writes and push
  73. // them onto |ops|.
  74. template <net::MockReadWriteType type>
  75. void AddChunkedOps(base::StringPiece data,
  76. size_t chunk_size,
  77. net::IoMode mode,
  78. std::vector<net::MockReadWrite<type>>* ops) {
  79. DCHECK_GT(chunk_size, 0U);
  80. size_t offset = 0;
  81. while (offset < data.size()) {
  82. size_t bounded_chunk_size = std::min(data.size() - offset, chunk_size);
  83. ops->push_back(net::MockReadWrite<type>(mode, data.data() + offset,
  84. bounded_chunk_size));
  85. offset += bounded_chunk_size;
  86. }
  87. }
  88. class FakeSSLClientSocketTest : public testing::Test {
  89. protected:
  90. FakeSSLClientSocketTest() {}
  91. ~FakeSSLClientSocketTest() override {}
  92. std::unique_ptr<net::StreamSocket> MakeClientSocket() {
  93. return mock_client_socket_factory_.CreateTransportClientSocket(
  94. net::AddressList(), nullptr, nullptr, nullptr, net::NetLogSource());
  95. }
  96. void SetData(const net::MockConnect& mock_connect,
  97. std::vector<net::MockRead>* reads,
  98. std::vector<net::MockWrite>* writes) {
  99. static_socket_data_provider_ =
  100. std::make_unique<net::StaticSocketDataProvider>(*reads, *writes);
  101. static_socket_data_provider_->set_connect_data(mock_connect);
  102. mock_client_socket_factory_.AddSocketDataProvider(
  103. static_socket_data_provider_.get());
  104. }
  105. void ExpectStatus(net::IoMode mode,
  106. int expected_status,
  107. int immediate_status,
  108. net::TestCompletionCallback* test_completion_callback) {
  109. if (mode == net::ASYNC) {
  110. EXPECT_EQ(net::ERR_IO_PENDING, immediate_status);
  111. int status = test_completion_callback->WaitForResult();
  112. EXPECT_EQ(expected_status, status);
  113. } else {
  114. EXPECT_EQ(expected_status, immediate_status);
  115. }
  116. }
  117. // Sets up the mock socket to generate a successful handshake
  118. // (sliced up according to the parameters) and makes sure the
  119. // FakeSSLClientSocket behaves as expected.
  120. void RunSuccessfulHandshakeTest(net::IoMode mode,
  121. size_t read_chunk_size,
  122. size_t write_chunk_size,
  123. int num_resets) {
  124. base::StringPiece ssl_client_hello =
  125. FakeSSLClientSocket::GetSslClientHello();
  126. base::StringPiece ssl_server_hello =
  127. FakeSSLClientSocket::GetSslServerHello();
  128. net::MockConnect mock_connect(mode, net::OK);
  129. std::vector<net::MockRead> reads;
  130. std::vector<net::MockWrite> writes;
  131. static const char kReadTestData[] = "read test data";
  132. static const char kWriteTestData[] = "write test data";
  133. for (int i = 0; i < num_resets + 1; ++i) {
  134. SCOPED_TRACE(i);
  135. AddChunkedOps(ssl_server_hello, read_chunk_size, mode, &reads);
  136. AddChunkedOps(ssl_client_hello, write_chunk_size, mode, &writes);
  137. reads.push_back(
  138. net::MockRead(mode, kReadTestData, std::size(kReadTestData)));
  139. writes.push_back(
  140. net::MockWrite(mode, kWriteTestData, std::size(kWriteTestData)));
  141. }
  142. SetData(mock_connect, &reads, &writes);
  143. FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
  144. for (int i = 0; i < num_resets + 1; ++i) {
  145. SCOPED_TRACE(i);
  146. net::TestCompletionCallback connect_callback;
  147. int status = fake_ssl_client_socket.Connect(connect_callback.callback());
  148. if (mode == net::ASYNC) {
  149. EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
  150. }
  151. ExpectStatus(mode, net::OK, status, &connect_callback);
  152. if (fake_ssl_client_socket.IsConnected()) {
  153. int read_len = std::size(kReadTestData);
  154. int read_buf_len = 2 * read_len;
  155. auto read_buf = base::MakeRefCounted<net::IOBuffer>(read_buf_len);
  156. net::TestCompletionCallback read_callback;
  157. int read_status = fake_ssl_client_socket.Read(
  158. read_buf.get(), read_buf_len, read_callback.callback());
  159. ExpectStatus(mode, read_len, read_status, &read_callback);
  160. auto write_buf =
  161. base::MakeRefCounted<net::StringIOBuffer>(kWriteTestData);
  162. net::TestCompletionCallback write_callback;
  163. int write_status = fake_ssl_client_socket.Write(
  164. write_buf.get(), std::size(kWriteTestData),
  165. write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  166. ExpectStatus(mode, std::size(kWriteTestData), write_status,
  167. &write_callback);
  168. } else {
  169. ADD_FAILURE();
  170. }
  171. fake_ssl_client_socket.Disconnect();
  172. EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
  173. }
  174. }
  175. // Sets up the mock socket to generate an unsuccessful handshake
  176. // FakeSSLClientSocket fails as expected.
  177. void RunUnsuccessfulHandshakeTestHelper(net::IoMode mode,
  178. int error,
  179. HandshakeErrorLocation location) {
  180. DCHECK_NE(error, net::OK);
  181. base::StringPiece ssl_client_hello =
  182. FakeSSLClientSocket::GetSslClientHello();
  183. base::StringPiece ssl_server_hello =
  184. FakeSSLClientSocket::GetSslServerHello();
  185. net::MockConnect mock_connect(mode, net::OK);
  186. std::vector<net::MockRead> reads;
  187. std::vector<net::MockWrite> writes;
  188. const size_t kChunkSize = 1;
  189. AddChunkedOps(ssl_server_hello, kChunkSize, mode, &reads);
  190. AddChunkedOps(ssl_client_hello, kChunkSize, mode, &writes);
  191. switch (location) {
  192. case CONNECT_ERROR:
  193. mock_connect.result = error;
  194. writes.clear();
  195. reads.clear();
  196. break;
  197. case SEND_CLIENT_HELLO_ERROR: {
  198. // Use a fixed index for repeatability.
  199. size_t index = 100 % writes.size();
  200. writes[index].result = error;
  201. writes[index].data = NULL;
  202. writes[index].data_len = 0;
  203. writes.resize(index + 1);
  204. reads.clear();
  205. break;
  206. }
  207. case VERIFY_SERVER_HELLO_ERROR: {
  208. // Use a fixed index for repeatability.
  209. size_t index = 50 % reads.size();
  210. if (error == ERR_MALFORMED_SERVER_HELLO) {
  211. static const char kBadData[] = "BAD_DATA";
  212. reads[index].data = kBadData;
  213. reads[index].data_len = std::size(kBadData);
  214. } else {
  215. reads[index].result = error;
  216. reads[index].data = NULL;
  217. reads[index].data_len = 0;
  218. }
  219. reads.resize(index + 1);
  220. if (error == net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
  221. static const char kDummyData[] = "DUMMY";
  222. reads.push_back(net::MockRead(mode, kDummyData));
  223. }
  224. break;
  225. }
  226. }
  227. SetData(mock_connect, &reads, &writes);
  228. FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
  229. // The two errors below are interpreted by FakeSSLClientSocket as
  230. // an unexpected event.
  231. int expected_status =
  232. ((error == net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) ||
  233. (error == ERR_MALFORMED_SERVER_HELLO))
  234. ? net::ERR_UNEXPECTED
  235. : error;
  236. net::TestCompletionCallback test_completion_callback;
  237. int status =
  238. fake_ssl_client_socket.Connect(test_completion_callback.callback());
  239. EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
  240. ExpectStatus(mode, expected_status, status, &test_completion_callback);
  241. EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
  242. }
  243. void RunUnsuccessfulHandshakeTest(int error,
  244. HandshakeErrorLocation location) {
  245. RunUnsuccessfulHandshakeTestHelper(net::SYNCHRONOUS, error, location);
  246. RunUnsuccessfulHandshakeTestHelper(net::ASYNC, error, location);
  247. }
  248. // MockTCPClientSocket needs a message loop.
  249. base::test::SingleThreadTaskEnvironment task_environment_;
  250. net::MockClientSocketFactory mock_client_socket_factory_;
  251. std::unique_ptr<net::StaticSocketDataProvider> static_socket_data_provider_;
  252. };
  253. TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
  254. std::unique_ptr<MockClientSocket> mock_client_socket(new MockClientSocket());
  255. const int kReceiveBufferSize = 10;
  256. const int kSendBufferSize = 20;
  257. net::IPEndPoint ip_endpoint(net::IPAddress::IPv4AllZeros(), 80);
  258. const int kPeerAddress = 30;
  259. net::NetLogWithSource net_log;
  260. EXPECT_CALL(*mock_client_socket, SetReceiveBufferSize(kReceiveBufferSize));
  261. EXPECT_CALL(*mock_client_socket, SetSendBufferSize(kSendBufferSize));
  262. EXPECT_CALL(*mock_client_socket, GetPeerAddress(&ip_endpoint))
  263. .WillOnce(Return(kPeerAddress));
  264. EXPECT_CALL(*mock_client_socket, NetLog()).WillOnce(ReturnRef(net_log));
  265. // Takes ownership of |mock_client_socket|.
  266. FakeSSLClientSocket fake_ssl_client_socket(std::move(mock_client_socket));
  267. fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize);
  268. fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize);
  269. EXPECT_EQ(kPeerAddress, fake_ssl_client_socket.GetPeerAddress(&ip_endpoint));
  270. EXPECT_EQ(&net_log, &fake_ssl_client_socket.NetLog());
  271. }
  272. TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeSync) {
  273. for (size_t i = 1; i < 100; i += 3) {
  274. SCOPED_TRACE(i);
  275. for (size_t j = 1; j < 100; j += 5) {
  276. SCOPED_TRACE(j);
  277. RunSuccessfulHandshakeTest(net::SYNCHRONOUS, i, j, 0);
  278. }
  279. }
  280. }
  281. TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeAsync) {
  282. for (size_t i = 1; i < 100; i += 7) {
  283. SCOPED_TRACE(i);
  284. for (size_t j = 1; j < 100; j += 9) {
  285. SCOPED_TRACE(j);
  286. RunSuccessfulHandshakeTest(net::ASYNC, i, j, 0);
  287. }
  288. }
  289. }
  290. TEST_F(FakeSSLClientSocketTest, ResetSocket) {
  291. RunSuccessfulHandshakeTest(net::ASYNC, 1, 2, 3);
  292. }
  293. TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeConnectError) {
  294. RunUnsuccessfulHandshakeTest(net::ERR_ACCESS_DENIED, CONNECT_ERROR);
  295. }
  296. TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeWriteError) {
  297. RunUnsuccessfulHandshakeTest(net::ERR_OUT_OF_MEMORY, SEND_CLIENT_HELLO_ERROR);
  298. }
  299. TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeReadError) {
  300. RunUnsuccessfulHandshakeTest(net::ERR_CONNECTION_CLOSED,
  301. VERIFY_SERVER_HELLO_ERROR);
  302. }
  303. TEST_F(FakeSSLClientSocketTest, PeerClosedDuringHandshake) {
  304. RunUnsuccessfulHandshakeTest(net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ,
  305. VERIFY_SERVER_HELLO_ERROR);
  306. }
  307. TEST_F(FakeSSLClientSocketTest, MalformedServerHello) {
  308. RunUnsuccessfulHandshakeTest(ERR_MALFORMED_SERVER_HELLO,
  309. VERIFY_SERVER_HELLO_ERROR);
  310. }
  311. } // namespace
  312. } // namespace webrtc