proxy_resolving_socket_mojo_unittest.cc 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. // Copyright 2018 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <string>
  5. #include <utility>
  6. #include <vector>
  7. #include "base/containers/span.h"
  8. #include "base/run_loop.h"
  9. #include "base/strings/stringprintf.h"
  10. #include "base/test/bind.h"
  11. #include "base/test/task_environment.h"
  12. #include "components/webrtc/fake_ssl_client_socket.h"
  13. #include "mojo/public/cpp/bindings/pending_receiver.h"
  14. #include "mojo/public/cpp/bindings/pending_remote.h"
  15. #include "mojo/public/cpp/bindings/receiver.h"
  16. #include "mojo/public/cpp/bindings/remote.h"
  17. #include "mojo/public/cpp/system/data_pipe_utils.h"
  18. #include "net/base/net_errors.h"
  19. #include "net/base/network_isolation_key.h"
  20. #include "net/base/test_completion_callback.h"
  21. #include "net/dns/mock_host_resolver.h"
  22. #include "net/proxy_resolution/configured_proxy_resolution_service.h"
  23. #include "net/socket/socket_test_util.h"
  24. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  25. #include "net/url_request/url_request_context_builder.h"
  26. #include "net/url_request/url_request_test_util.h"
  27. #include "services/network/mojo_socket_test_util.h"
  28. #include "services/network/proxy_resolving_socket_factory_mojo.h"
  29. #include "services/network/proxy_resolving_socket_mojo.h"
  30. #include "services/network/socket_factory.h"
  31. #include "testing/gtest/include/gtest/gtest.h"
  32. namespace network {
  33. class ProxyResolvingSocketTestBase {
  34. public:
  35. ProxyResolvingSocketTestBase(bool use_tls)
  36. : use_tls_(use_tls),
  37. fake_tls_handshake_(false),
  38. task_environment_(base::test::TaskEnvironment::MainThreadType::IO) {}
  39. ProxyResolvingSocketTestBase(const ProxyResolvingSocketTestBase&) = delete;
  40. ProxyResolvingSocketTestBase& operator=(const ProxyResolvingSocketTestBase&) =
  41. delete;
  42. ~ProxyResolvingSocketTestBase() {}
  43. void Init(const std::string& pac_result) {
  44. // Init() can be called multiple times in a test. Reset the members for each
  45. // invocation. `context_` must outlive `factory_impl_`, which uses the
  46. // URLRequestContext.
  47. factory_receiver_ = nullptr;
  48. factory_impl_ = nullptr;
  49. factory_remote_.reset();
  50. context_ = nullptr;
  51. mock_client_socket_factory_ =
  52. std::make_unique<net::MockClientSocketFactory>();
  53. mock_client_socket_factory_->set_enable_read_if_ready(true);
  54. auto context_builder = net::CreateTestURLRequestContextBuilder();
  55. context_builder->set_proxy_resolution_service(
  56. net::ConfiguredProxyResolutionService::CreateFixedFromPacResultForTest(
  57. pac_result, TRAFFIC_ANNOTATION_FOR_TESTS));
  58. context_builder->set_client_socket_factory_for_testing(
  59. mock_client_socket_factory_.get());
  60. context_ = context_builder->Build();
  61. factory_impl_ =
  62. std::make_unique<ProxyResolvingSocketFactoryMojo>(context_.get());
  63. factory_receiver_ =
  64. std::make_unique<mojo::Receiver<mojom::ProxyResolvingSocketFactory>>(
  65. factory_impl_.get(), factory_remote_.BindNewPipeAndPassReceiver());
  66. }
  67. // Reads |num_bytes| from |handle| or reads until an error occurs. Returns the
  68. // bytes read as a string.
  69. std::string Read(mojo::ScopedDataPipeConsumerHandle* handle,
  70. size_t num_bytes) {
  71. std::string received_contents;
  72. while (received_contents.size() < num_bytes) {
  73. base::RunLoop().RunUntilIdle();
  74. std::vector<char> buffer(num_bytes);
  75. uint32_t read_size =
  76. static_cast<uint32_t>(num_bytes - received_contents.size());
  77. MojoResult result = handle->get().ReadData(buffer.data(), &read_size,
  78. MOJO_READ_DATA_FLAG_NONE);
  79. if (result == MOJO_RESULT_SHOULD_WAIT)
  80. continue;
  81. if (result != MOJO_RESULT_OK)
  82. return received_contents;
  83. received_contents.append(buffer.data(), read_size);
  84. }
  85. return received_contents;
  86. }
  87. int CreateSocketSync(
  88. mojo::PendingReceiver<mojom::ProxyResolvingSocket> receiver,
  89. mojo::PendingRemote<mojom::SocketObserver> socket_observer,
  90. net::IPEndPoint* peer_addr_out,
  91. const GURL& url,
  92. mojo::ScopedDataPipeConsumerHandle* receive_pipe_handle_out,
  93. mojo::ScopedDataPipeProducerHandle* send_pipe_handle_out) {
  94. base::RunLoop run_loop;
  95. int net_error = net::ERR_FAILED;
  96. network::mojom::ProxyResolvingSocketOptionsPtr options =
  97. network::mojom::ProxyResolvingSocketOptions::New();
  98. options->use_tls = use_tls_;
  99. options->fake_tls_handshake = fake_tls_handshake_;
  100. factory_remote_->CreateProxyResolvingSocket(
  101. url, net::NetworkIsolationKey(), std::move(options),
  102. net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
  103. std::move(receiver), std::move(socket_observer),
  104. base::BindLambdaForTesting(
  105. [&](int result, const absl::optional<net::IPEndPoint>& local_addr,
  106. const absl::optional<net::IPEndPoint>& peer_addr,
  107. mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
  108. mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
  109. net_error = result;
  110. if (net_error == net::OK)
  111. EXPECT_NE(0, local_addr.value().port());
  112. if (peer_addr_out && peer_addr)
  113. *peer_addr_out = peer_addr.value();
  114. *receive_pipe_handle_out = std::move(receive_pipe_handle);
  115. *send_pipe_handle_out = std::move(send_pipe_handle);
  116. run_loop.Quit();
  117. }));
  118. run_loop.Run();
  119. return net_error;
  120. }
  121. net::MockClientSocketFactory* mock_client_socket_factory() {
  122. return mock_client_socket_factory_.get();
  123. }
  124. bool use_tls() const { return use_tls_; }
  125. void set_fake_tls_handshake(bool val) { fake_tls_handshake_ = val; }
  126. mojom::ProxyResolvingSocketFactory* factory() {
  127. return factory_remote_.get();
  128. }
  129. private:
  130. const bool use_tls_;
  131. bool fake_tls_handshake_;
  132. base::test::TaskEnvironment task_environment_;
  133. std::unique_ptr<net::MockClientSocketFactory> mock_client_socket_factory_;
  134. std::unique_ptr<net::URLRequestContext> context_;
  135. mojo::Remote<mojom::ProxyResolvingSocketFactory> factory_remote_;
  136. std::unique_ptr<mojo::Receiver<mojom::ProxyResolvingSocketFactory>>
  137. factory_receiver_;
  138. std::unique_ptr<ProxyResolvingSocketFactoryMojo> factory_impl_;
  139. };
  140. class ProxyResolvingSocketTest : public ProxyResolvingSocketTestBase,
  141. public testing::TestWithParam<bool> {
  142. public:
  143. ProxyResolvingSocketTest() : ProxyResolvingSocketTestBase(GetParam()) {}
  144. ProxyResolvingSocketTest(const ProxyResolvingSocketTest&) = delete;
  145. ProxyResolvingSocketTest& operator=(const ProxyResolvingSocketTest&) = delete;
  146. ~ProxyResolvingSocketTest() override {}
  147. };
  148. INSTANTIATE_TEST_SUITE_P(All,
  149. ProxyResolvingSocketTest,
  150. ::testing::Bool());
  151. // Tests that the connection is established to the proxy.
  152. TEST_P(ProxyResolvingSocketTest, ConnectToProxy) {
  153. const GURL kDestination("https://example.com:443");
  154. const int kProxyPort = 8009;
  155. const int kDirectPort = 443;
  156. for (bool is_direct : {true, false}) {
  157. net::MockClientSocketFactory socket_factory;
  158. std::unique_ptr<net::URLRequestContext> context;
  159. if (is_direct) {
  160. Init("DIRECT");
  161. } else {
  162. Init(base::StringPrintf("PROXY myproxy.com:%d", kProxyPort));
  163. }
  164. // Note that this read is not consumed when |!is_direct|.
  165. net::MockRead reads[] = {net::MockRead("HTTP/1.1 200 Success\r\n\r\n"),
  166. net::MockRead(net::ASYNC, net::OK)};
  167. // Note that this write is not consumed when |is_direct|.
  168. net::MockWrite writes[] = {
  169. net::MockWrite("CONNECT example.com:443 HTTP/1.1\r\n"
  170. "Host: example.com:443\r\n"
  171. "Proxy-Connection: keep-alive\r\n\r\n")};
  172. net::SSLSocketDataProvider ssl_socket(net::ASYNC, net::OK);
  173. mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_socket);
  174. net::StaticSocketDataProvider socket_data(reads, writes);
  175. net::IPEndPoint remote_addr(net::IPAddress(127, 0, 0, 1),
  176. is_direct ? kDirectPort : kProxyPort);
  177. socket_data.set_connect_data(
  178. net::MockConnect(net::ASYNC, net::OK, remote_addr));
  179. mock_client_socket_factory()->AddSocketDataProvider(&socket_data);
  180. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  181. mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle;
  182. mojo::ScopedDataPipeProducerHandle client_socket_send_handle;
  183. net::IPEndPoint actual_remote_addr;
  184. EXPECT_EQ(net::OK, CreateSocketSync(socket.InitWithNewPipeAndPassReceiver(),
  185. mojo::NullRemote() /* socket_observer*/,
  186. &actual_remote_addr, kDestination,
  187. &client_socket_receive_handle,
  188. &client_socket_send_handle));
  189. // Consume all read data.
  190. base::RunLoop().RunUntilIdle();
  191. if (!is_direct) {
  192. EXPECT_EQ(net::IPEndPoint(), actual_remote_addr);
  193. EXPECT_TRUE(socket_data.AllReadDataConsumed());
  194. EXPECT_TRUE(socket_data.AllWriteDataConsumed());
  195. } else {
  196. EXPECT_EQ(remote_addr.ToString(), actual_remote_addr.ToString());
  197. EXPECT_TRUE(socket_data.AllReadDataConsumed());
  198. EXPECT_FALSE(socket_data.AllWriteDataConsumed());
  199. }
  200. EXPECT_EQ(use_tls(), ssl_socket.ConnectDataConsumed());
  201. }
  202. }
  203. TEST_P(ProxyResolvingSocketTest, ConnectError) {
  204. const struct TestData {
  205. // Whether the error is encountered synchronously as opposed to
  206. // asynchronously.
  207. bool is_error_sync;
  208. // Whether it is using a direct connection as opposed to a proxy connection.
  209. bool is_direct;
  210. } kTestCases[] = {
  211. {true, true}, {true, false}, {false, true}, {false, false},
  212. };
  213. const GURL kDestination("https://example.com:443");
  214. for (auto test : kTestCases) {
  215. std::unique_ptr<net::URLRequestContext> context;
  216. if (test.is_direct) {
  217. Init("DIRECT");
  218. } else {
  219. Init("PROXY myproxy.com:89");
  220. }
  221. net::StaticSocketDataProvider socket_data;
  222. socket_data.set_connect_data(net::MockConnect(
  223. test.is_error_sync ? net::SYNCHRONOUS : net::ASYNC, net::ERR_FAILED));
  224. mock_client_socket_factory()->AddSocketDataProvider(&socket_data);
  225. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  226. mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle;
  227. mojo::ScopedDataPipeProducerHandle client_socket_send_handle;
  228. int status = CreateSocketSync(socket.InitWithNewPipeAndPassReceiver(),
  229. mojo::NullRemote() /* socket_observer*/,
  230. nullptr /* peer_addr_out */, kDestination,
  231. &client_socket_receive_handle,
  232. &client_socket_send_handle);
  233. if (test.is_direct) {
  234. EXPECT_EQ(net::ERR_FAILED, status);
  235. } else {
  236. EXPECT_EQ(net::ERR_PROXY_CONNECTION_FAILED, status);
  237. }
  238. EXPECT_TRUE(socket_data.AllReadDataConsumed());
  239. EXPECT_TRUE(socket_data.AllWriteDataConsumed());
  240. }
  241. }
  242. // Tests writing to and reading from a mojom::ProxyResolvingSocket.
  243. TEST_P(ProxyResolvingSocketTest, BasicReadWrite) {
  244. Init("DIRECT");
  245. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  246. const char kTestMsg[] = "abcdefghij";
  247. const size_t kMsgSize = strlen(kTestMsg);
  248. const int kNumIterations = 3;
  249. std::vector<net::MockRead> reads;
  250. std::vector<net::MockWrite> writes;
  251. int sequence_number = 0;
  252. for (int j = 0; j < kNumIterations; ++j) {
  253. for (size_t i = 0; i < kMsgSize; ++i) {
  254. reads.push_back(
  255. net::MockRead(net::ASYNC, &kTestMsg[i], 1, sequence_number++));
  256. }
  257. if (j == kNumIterations - 1) {
  258. reads.push_back(net::MockRead(net::ASYNC, net::OK, sequence_number++));
  259. }
  260. for (size_t i = 0; i < kMsgSize; ++i) {
  261. writes.push_back(
  262. net::MockWrite(net::ASYNC, &kTestMsg[i], 1, sequence_number++));
  263. }
  264. }
  265. net::StaticSocketDataProvider data_provider(reads, writes);
  266. data_provider.set_connect_data(net::MockConnect(net::SYNCHRONOUS, net::OK));
  267. mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  268. net::SSLSocketDataProvider ssl_data(net::ASYNC, net::OK);
  269. mock_client_socket_factory()->AddSSLSocketDataProvider(&ssl_data);
  270. mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle;
  271. mojo::ScopedDataPipeProducerHandle client_socket_send_handle;
  272. const GURL kDestination("http://example.com");
  273. EXPECT_EQ(net::OK, CreateSocketSync(socket.InitWithNewPipeAndPassReceiver(),
  274. mojo::NullRemote() /* socket_observer */,
  275. nullptr /* peer_addr_out */, kDestination,
  276. &client_socket_receive_handle,
  277. &client_socket_send_handle));
  278. // Loop kNumIterations times to test that writes can follow reads, and reads
  279. // can follow writes.
  280. for (int j = 0; j < kNumIterations; ++j) {
  281. // Reading kMsgSize should coalesce the 1-byte mock reads.
  282. EXPECT_EQ(kTestMsg, Read(&client_socket_receive_handle, kMsgSize));
  283. // Write multiple times.
  284. for (size_t i = 0; i < kMsgSize; ++i) {
  285. uint32_t num_bytes = 1;
  286. EXPECT_EQ(MOJO_RESULT_OK,
  287. client_socket_send_handle->WriteData(
  288. &kTestMsg[i], &num_bytes, MOJO_WRITE_DATA_FLAG_NONE));
  289. // Flush the 1 byte write.
  290. base::RunLoop().RunUntilIdle();
  291. }
  292. }
  293. EXPECT_TRUE(data_provider.AllReadDataConsumed());
  294. EXPECT_TRUE(data_provider.AllWriteDataConsumed());
  295. EXPECT_EQ(use_tls(), ssl_data.ConnectDataConsumed());
  296. }
  297. // Tests that exercise logic related to mojo.
  298. class ProxyResolvingSocketMojoTest : public ProxyResolvingSocketTestBase,
  299. public testing::Test {
  300. public:
  301. ProxyResolvingSocketMojoTest() : ProxyResolvingSocketTestBase(false) {}
  302. ProxyResolvingSocketMojoTest(const ProxyResolvingSocketMojoTest&) = delete;
  303. ProxyResolvingSocketMojoTest& operator=(const ProxyResolvingSocketMojoTest&) =
  304. delete;
  305. ~ProxyResolvingSocketMojoTest() override {}
  306. };
  307. TEST_F(ProxyResolvingSocketMojoTest, ConnectWithFakeTLSHandshake) {
  308. const GURL kDestination("https://example.com:443");
  309. const char kTestMsg[] = "abcdefghij";
  310. const size_t kMsgSize = strlen(kTestMsg);
  311. Init("DIRECT");
  312. set_fake_tls_handshake(true);
  313. base::StringPiece client_hello =
  314. webrtc::FakeSSLClientSocket::GetSslClientHello();
  315. base::StringPiece server_hello =
  316. webrtc::FakeSSLClientSocket::GetSslServerHello();
  317. std::vector<net::MockRead> reads = {
  318. net::MockRead(net::ASYNC, server_hello.data(), server_hello.length(), 1),
  319. net::MockRead(net::ASYNC, 2, kTestMsg),
  320. net::MockRead(net::ASYNC, net::OK, 3)};
  321. std::vector<net::MockWrite> writes = {net::MockWrite(
  322. net::ASYNC, client_hello.data(), client_hello.length(), 0)};
  323. net::StaticSocketDataProvider data_provider(reads, writes);
  324. data_provider.set_connect_data(net::MockConnect(net::ASYNC, net::OK));
  325. mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  326. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  327. mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle;
  328. mojo::ScopedDataPipeProducerHandle client_socket_send_handle;
  329. net::IPEndPoint actual_remote_addr;
  330. EXPECT_EQ(net::OK, CreateSocketSync(socket.InitWithNewPipeAndPassReceiver(),
  331. mojo::NullRemote() /* socket_observer*/,
  332. &actual_remote_addr, kDestination,
  333. &client_socket_receive_handle,
  334. &client_socket_send_handle));
  335. EXPECT_EQ(kTestMsg, Read(&client_socket_receive_handle, kMsgSize));
  336. EXPECT_TRUE(data_provider.AllReadDataConsumed());
  337. EXPECT_TRUE(data_provider.AllWriteDataConsumed());
  338. }
  339. // Tests that when ProxyResolvingSocket remote is destroyed but not the
  340. // ProxyResolvingSocketFactory, the connect callback is not dropped.
  341. // Regression test for https://crbug.com/862608.
  342. TEST_F(ProxyResolvingSocketMojoTest, SocketDestroyedBeforeConnectCompletes) {
  343. Init("DIRECT");
  344. std::vector<net::MockRead> reads;
  345. std::vector<net::MockWrite> writes;
  346. net::StaticSocketDataProvider data_provider(reads, writes);
  347. data_provider.set_connect_data(net::MockConnect(net::ASYNC, net::OK));
  348. mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  349. const GURL kDestination("http://example.com");
  350. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  351. base::RunLoop run_loop;
  352. int net_error = net::OK;
  353. factory()->CreateProxyResolvingSocket(
  354. kDestination, net::NetworkIsolationKey(), nullptr,
  355. net::MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS),
  356. socket.InitWithNewPipeAndPassReceiver(),
  357. mojo::NullRemote() /* observer */,
  358. base::BindLambdaForTesting(
  359. [&](int result, const absl::optional<net::IPEndPoint>& local_addr,
  360. const absl::optional<net::IPEndPoint>& peer_addr,
  361. mojo::ScopedDataPipeConsumerHandle receive_pipe_handle,
  362. mojo::ScopedDataPipeProducerHandle send_pipe_handle) {
  363. net_error = result;
  364. }));
  365. socket.reset();
  366. base::RunLoop().RunUntilIdle();
  367. EXPECT_EQ(net::ERR_ABORTED, net_error);
  368. }
  369. TEST_F(ProxyResolvingSocketMojoTest, SocketObserver) {
  370. Init("DIRECT");
  371. const char kMsg[] = "message!";
  372. const char kMsgLen = strlen(kMsg);
  373. std::vector<net::MockRead> reads = {
  374. net::MockRead(kMsg),
  375. net::MockRead(net::ASYNC, net::ERR_CONNECTION_ABORTED)};
  376. std::vector<net::MockWrite> writes = {
  377. net::MockWrite(net::ASYNC, net::ERR_TIMED_OUT)};
  378. net::StaticSocketDataProvider data_provider(reads, writes);
  379. data_provider.set_connect_data(net::MockConnect(net::ASYNC, net::OK));
  380. mock_client_socket_factory()->AddSocketDataProvider(&data_provider);
  381. const GURL kDestination("http://example.com");
  382. mojo::PendingRemote<mojom::ProxyResolvingSocket> socket;
  383. mojo::ScopedDataPipeConsumerHandle client_socket_receive_handle;
  384. mojo::ScopedDataPipeProducerHandle client_socket_send_handle;
  385. TestSocketObserver test_observer;
  386. int status = CreateSocketSync(
  387. socket.InitWithNewPipeAndPassReceiver(),
  388. test_observer.GetObserverRemote(), nullptr /* peer_addr_out */,
  389. kDestination, &client_socket_receive_handle, &client_socket_send_handle);
  390. EXPECT_EQ(net::OK, status);
  391. EXPECT_EQ(kMsg, Read(&client_socket_receive_handle, kMsgLen));
  392. EXPECT_EQ(net::ERR_CONNECTION_ABORTED, test_observer.WaitForReadError());
  393. EXPECT_TRUE(mojo::BlockingCopyFromString(kMsg, client_socket_send_handle));
  394. EXPECT_EQ(net::ERR_TIMED_OUT, test_observer.WaitForWriteError());
  395. EXPECT_TRUE(data_provider.AllReadDataConsumed());
  396. EXPECT_TRUE(data_provider.AllWriteDataConsumed());
  397. }
  398. } // namespace network