cast_socket_unittest.cc 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168
  1. // Copyright 2014 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "components/cast_channel/cast_socket.h"
  5. #include <stdint.h>
  6. #include <memory>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/bind.h"
  10. #include "base/callback_helpers.h"
  11. #include "base/files/file_util.h"
  12. #include "base/location.h"
  13. #include "base/memory/ptr_util.h"
  14. #include "base/memory/raw_ptr.h"
  15. #include "base/memory/weak_ptr.h"
  16. #include "base/path_service.h"
  17. #include "base/run_loop.h"
  18. #include "base/strings/string_number_conversions.h"
  19. #include "base/sys_byteorder.h"
  20. #include "base/task/single_thread_task_runner.h"
  21. #include "base/test/bind.h"
  22. #include "base/threading/thread_task_runner_handle.h"
  23. #include "base/timer/mock_timer.h"
  24. #include "build/build_config.h"
  25. #include "components/cast_channel/cast_auth_util.h"
  26. #include "components/cast_channel/cast_framer.h"
  27. #include "components/cast_channel/cast_message_util.h"
  28. #include "components/cast_channel/cast_test_util.h"
  29. #include "components/cast_channel/cast_transport.h"
  30. #include "components/cast_channel/logger.h"
  31. #include "content/public/test/browser_task_environment.h"
  32. #include "crypto/rsa_private_key.h"
  33. #include "mojo/public/cpp/bindings/remote.h"
  34. #include "net/base/address_list.h"
  35. #include "net/base/net_errors.h"
  36. #include "net/cert/pem.h"
  37. #include "net/socket/client_socket_factory.h"
  38. #include "net/socket/socket_test_util.h"
  39. #include "net/socket/ssl_client_socket.h"
  40. #include "net/socket/ssl_server_socket.h"
  41. #include "net/socket/tcp_client_socket.h"
  42. #include "net/socket/tcp_server_socket.h"
  43. #include "net/ssl/ssl_info.h"
  44. #include "net/ssl/ssl_server_config.h"
  45. #include "net/test/cert_test_util.h"
  46. #include "net/test/test_data_directory.h"
  47. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  48. #include "net/url_request/url_request_context.h"
  49. #include "net/url_request/url_request_context_builder.h"
  50. #include "net/url_request/url_request_test_util.h"
  51. #include "services/network/network_context.h"
  52. #include "testing/gmock/include/gmock/gmock.h"
  53. #include "testing/gtest/include/gtest/gtest.h"
  54. #include "third_party/openscreen/src/cast/common/channel/proto/cast_channel.pb.h"
  // Connect timeout handed to every CastSocketOpenParams in this file:
  // 100 seconds, far longer than any test runs, so it is never hit.
  constexpr int64_t kDistantTimeoutMillis = 100'000;
  56. using ::testing::A;
  57. using ::testing::DoAll;
  58. using ::testing::Invoke;
  59. using ::testing::InvokeArgument;
  60. using ::testing::NotNull;
  61. using ::testing::Return;
  62. using ::testing::SaveArg;
  63. using ::testing::_;
  64. using ::cast::channel::CastMessage;
  65. namespace cast_channel {
  66. namespace {
  // Message namespace used for the device-auth challenge/reply exchange.
  constexpr char kAuthNamespace[] = "urn:x-cast:com.google.cast.tp.deviceauth";
  68. // Returns an auth challenge message inline.
  69. CastMessage CreateAuthChallenge() {
  70. CastMessage output;
  71. CreateAuthChallengeMessage(&output, AuthContext::Create());
  72. return output;
  73. }
  74. // Returns an auth challenge response message inline.
  75. CastMessage CreateAuthReply() {
  76. CastMessage output;
  77. output.set_protocol_version(CastMessage::CASTV2_1_0);
  78. output.set_source_id("sender-0");
  79. output.set_destination_id("receiver-0");
  80. output.set_payload_type(CastMessage::BINARY);
  81. output.set_payload_binary("abcd");
  82. output.set_namespace_(kAuthNamespace);
  83. return output;
  84. }
  85. CastMessage CreateTestMessage() {
  86. CastMessage test_message;
  87. test_message.set_protocol_version(CastMessage::CASTV2_1_0);
  88. test_message.set_namespace_("ns");
  89. test_message.set_source_id("source");
  90. test_message.set_destination_id("dest");
  91. test_message.set_payload_type(CastMessage::STRING);
  92. test_message.set_payload_utf8("payload");
  93. return test_message;
  94. }
  95. base::FilePath GetTestCertsDirectory() {
  96. base::FilePath path;
  97. base::PathService::Get(base::DIR_SOURCE_ROOT, &path);
  98. path = path.Append(FILE_PATH_LITERAL("components"));
  99. path = path.Append(FILE_PATH_LITERAL("test"));
  100. path = path.Append(FILE_PATH_LITERAL("data"));
  101. path = path.Append(FILE_PATH_LITERAL("cast_channel"));
  102. return path;
  103. }
  104. class MockTCPSocket : public net::MockTCPClientSocket {
  105. public:
  106. MockTCPSocket(bool do_nothing, net::SocketDataProvider* socket_provider)
  107. : net::MockTCPClientSocket(net::AddressList(), nullptr, socket_provider) {
  108. do_nothing_ = do_nothing;
  109. set_enable_read_if_ready(true);
  110. }
  111. MockTCPSocket(const MockTCPSocket&) = delete;
  112. MockTCPSocket& operator=(const MockTCPSocket&) = delete;
  113. int Connect(net::CompletionOnceCallback callback) override {
  114. if (do_nothing_) {
  115. // Stall the I/O event loop.
  116. return net::ERR_IO_PENDING;
  117. }
  118. return net::MockTCPClientSocket::Connect(std::move(callback));
  119. }
  120. private:
  121. bool do_nothing_;
  122. };
  123. class CompleteHandler {
  124. public:
  125. CompleteHandler() {}
  126. CompleteHandler(const CompleteHandler&) = delete;
  127. CompleteHandler& operator=(const CompleteHandler&) = delete;
  128. MOCK_METHOD1(OnCloseComplete, void(int result));
  129. MOCK_METHOD1(OnConnectComplete, void(CastSocket* socket));
  130. MOCK_METHOD1(OnWriteComplete, void(int result));
  131. MOCK_METHOD1(OnReadComplete, void(int result));
  132. };
  133. class TestCastSocketBase : public CastSocketImpl {
  134. public:
  135. TestCastSocketBase(network::mojom::NetworkContext* network_context,
  136. const CastSocketOpenParams& open_params,
  137. Logger* logger)
  138. : CastSocketImpl(base::BindRepeating(
  139. [](network::mojom::NetworkContext* network_context) {
  140. return network_context;
  141. },
  142. network_context),
  143. open_params,
  144. logger,
  145. AuthContext::Create()),
  146. verify_challenge_result_(true),
  147. verify_challenge_disallow_(false),
  148. mock_timer_(new base::MockOneShotTimer()) {
  149. SetPeerCertForTesting(
  150. net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem"));
  151. }
  152. TestCastSocketBase(const TestCastSocketBase&) = delete;
  153. TestCastSocketBase& operator=(const TestCastSocketBase&) = delete;
  154. ~TestCastSocketBase() override {}
  155. void SetVerifyChallengeResult(bool value) {
  156. verify_challenge_result_ = value;
  157. }
  158. void TriggerTimeout() { mock_timer_->Fire(); }
  159. bool TestVerifyChannelPolicyNone() {
  160. AuthResult authResult;
  161. return VerifyChannelPolicy(authResult);
  162. }
  163. void DisallowVerifyChallengeResult() { verify_challenge_disallow_ = true; }
  164. protected:
  165. bool VerifyChallengeReply() override {
  166. EXPECT_FALSE(verify_challenge_disallow_);
  167. return verify_challenge_result_;
  168. }
  169. base::OneShotTimer* GetTimer() override { return mock_timer_.get(); }
  170. // Simulated result of verifying challenge reply.
  171. bool verify_challenge_result_;
  172. bool verify_challenge_disallow_;
  173. std::unique_ptr<base::MockOneShotTimer> mock_timer_;
  174. };
  175. class MockTestCastSocket : public TestCastSocketBase {
  176. public:
  177. static std::unique_ptr<MockTestCastSocket> CreateSecure(
  178. network::mojom::NetworkContext* network_context,
  179. const CastSocketOpenParams& open_params,
  180. Logger* logger) {
  181. return std::make_unique<MockTestCastSocket>(network_context, open_params,
  182. logger);
  183. }
  184. using TestCastSocketBase::TestCastSocketBase;
  185. MockTestCastSocket(network::mojom::NetworkContext* network_context,
  186. const CastSocketOpenParams& open_params,
  187. Logger* logger)
  188. : TestCastSocketBase(network_context, open_params, logger) {}
  189. MockTestCastSocket(const MockTestCastSocket&) = delete;
  190. MockTestCastSocket& operator=(const MockTestCastSocket&) = delete;
  191. ~MockTestCastSocket() override {}
  192. void SetupMockTransport() {
  193. mock_transport_ = new MockCastTransport;
  194. SetTransportForTesting(base::WrapUnique(mock_transport_.get()));
  195. }
  196. bool TestVerifyChannelPolicyAudioOnly() {
  197. AuthResult authResult;
  198. authResult.channel_policies |= AuthResult::POLICY_AUDIO_ONLY;
  199. return VerifyChannelPolicy(authResult);
  200. }
  201. MockCastTransport* GetMockTransport() {
  202. CHECK(mock_transport_);
  203. return mock_transport_;
  204. }
  205. private:
  206. raw_ptr<MockCastTransport> mock_transport_ = nullptr;
  207. };
  208. // TODO(https://crbug.com/928467): Remove this class.
  209. class TestSocketFactory : public net::ClientSocketFactory {
  210. public:
  211. explicit TestSocketFactory(net::IPEndPoint ip) : ip_(ip) {}
  212. TestSocketFactory(const TestSocketFactory&) = delete;
  213. TestSocketFactory& operator=(const TestSocketFactory&) = delete;
  214. ~TestSocketFactory() override = default;
  215. // Socket connection helpers.
  216. void SetupTcpConnect(net::IoMode mode, int result) {
  217. tcp_connect_data_ = std::make_unique<net::MockConnect>(mode, result, ip_);
  218. }
  219. void SetupSslConnect(net::IoMode mode, int result) {
  220. ssl_connect_data_ = std::make_unique<net::MockConnect>(mode, result, ip_);
  221. }
  222. // Socket I/O helpers.
  223. void AddWriteResult(const net::MockWrite& write) { writes_.push_back(write); }
  224. void AddWriteResult(net::IoMode mode, int result) {
  225. AddWriteResult(net::MockWrite(mode, result));
  226. }
  227. void AddWriteResultForData(net::IoMode mode, const std::string& msg) {
  228. AddWriteResult(mode, msg.size());
  229. }
  230. void AddReadResult(const net::MockRead& read) { reads_.push_back(read); }
  231. void AddReadResult(net::IoMode mode, int result) {
  232. AddReadResult(net::MockRead(mode, result));
  233. }
  234. void AddReadResultForData(net::IoMode mode, const std::string& data) {
  235. AddReadResult(net::MockRead(mode, data.c_str(), data.size()));
  236. }
  237. // Helpers for modifying other connection-related behaviors.
  238. void SetupTcpConnectUnresponsive() { tcp_unresponsive_ = true; }
  239. void SetTcpSocket(
  240. std::unique_ptr<net::TransportClientSocket> tcp_client_socket) {
  241. tcp_client_socket_ = std::move(tcp_client_socket);
  242. }
  243. void SetTLSSocketCreatedClosure(base::OnceClosure closure) {
  244. tls_socket_created_ = std::move(closure);
  245. }
  246. void Pause() {
  247. if (socket_data_provider_)
  248. socket_data_provider_->Pause();
  249. else
  250. socket_data_provider_paused_ = true;
  251. }
  252. void Resume() { socket_data_provider_->Resume(); }
  253. private:
  254. std::unique_ptr<net::DatagramClientSocket> CreateDatagramClientSocket(
  255. net::DatagramSocket::BindType,
  256. net::NetLog*,
  257. const net::NetLogSource&) override {
  258. NOTIMPLEMENTED();
  259. return nullptr;
  260. }
  261. std::unique_ptr<net::TransportClientSocket> CreateTransportClientSocket(
  262. const net::AddressList&,
  263. std::unique_ptr<net::SocketPerformanceWatcher>,
  264. net::NetworkQualityEstimator*,
  265. net::NetLog*,
  266. const net::NetLogSource&) override {
  267. if (tcp_client_socket_)
  268. return std::move(tcp_client_socket_);
  269. if (tcp_unresponsive_) {
  270. socket_data_provider_ = std::make_unique<net::StaticSocketDataProvider>();
  271. return std::unique_ptr<net::TransportClientSocket>(
  272. new MockTCPSocket(true, socket_data_provider_.get()));
  273. } else {
  274. socket_data_provider_ =
  275. std::make_unique<net::StaticSocketDataProvider>(reads_, writes_);
  276. socket_data_provider_->set_connect_data(*tcp_connect_data_);
  277. if (socket_data_provider_paused_)
  278. socket_data_provider_->Pause();
  279. return std::unique_ptr<net::TransportClientSocket>(
  280. new MockTCPSocket(false, socket_data_provider_.get()));
  281. }
  282. }
  283. std::unique_ptr<net::SSLClientSocket> CreateSSLClientSocket(
  284. net::SSLClientContext* context,
  285. std::unique_ptr<net::StreamSocket> nested_socket,
  286. const net::HostPortPair& host_and_port,
  287. const net::SSLConfig& ssl_config) override {
  288. if (!ssl_connect_data_) {
  289. // Test isn't overriding SSL socket creation.
  290. return net::ClientSocketFactory::GetDefaultFactory()
  291. ->CreateSSLClientSocket(context, std::move(nested_socket),
  292. host_and_port, ssl_config);
  293. }
  294. ssl_socket_data_provider_ = std::make_unique<net::SSLSocketDataProvider>(
  295. ssl_connect_data_->mode, ssl_connect_data_->result);
  296. if (tls_socket_created_)
  297. std::move(tls_socket_created_).Run();
  298. return std::make_unique<net::MockSSLClientSocket>(
  299. std::move(nested_socket), net::HostPortPair(), net::SSLConfig(),
  300. ssl_socket_data_provider_.get());
  301. }
  302. net::IPEndPoint ip_;
  303. // Simulated connect data
  304. std::unique_ptr<net::MockConnect> tcp_connect_data_;
  305. std::unique_ptr<net::MockConnect> ssl_connect_data_;
  306. // Simulated read / write data
  307. std::vector<net::MockWrite> writes_;
  308. std::vector<net::MockRead> reads_;
  309. std::unique_ptr<net::StaticSocketDataProvider> socket_data_provider_;
  310. std::unique_ptr<net::SSLSocketDataProvider> ssl_socket_data_provider_;
  311. bool socket_data_provider_paused_ = false;
  312. // If true, makes TCP connection process stall. For timeout testing.
  313. bool tcp_unresponsive_ = false;
  314. std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
  315. base::OnceClosure tls_socket_created_;
  316. };
  317. class CastSocketTestBase : public testing::Test {
  318. protected:
  319. CastSocketTestBase()
  320. : task_environment_(content::BrowserTaskEnvironment::IO_MAINLOOP),
  321. logger_(new Logger()),
  322. observer_(new MockCastSocketObserver()),
  323. socket_open_params_(CreateIPEndPointForTest(),
  324. base::Milliseconds(kDistantTimeoutMillis)),
  325. client_socket_factory_(socket_open_params_.ip_endpoint) {}
  326. CastSocketTestBase(const CastSocketTestBase&) = delete;
  327. CastSocketTestBase& operator=(const CastSocketTestBase&) = delete;
  328. ~CastSocketTestBase() override {}
  329. void SetUp() override {
  330. EXPECT_CALL(*observer_, OnMessage(_, _)).Times(0);
  331. auto context_builder = net::CreateTestURLRequestContextBuilder();
  332. context_builder->set_client_socket_factory_for_testing(
  333. &client_socket_factory_);
  334. url_request_context_ = context_builder->Build();
  335. network_context_ = std::make_unique<network::NetworkContext>(
  336. nullptr, network_context_remote_.BindNewPipeAndPassReceiver(),
  337. url_request_context_.get(),
  338. /*cors_exempt_header_list=*/std::vector<std::string>());
  339. }
  340. // Runs all pending tasks in the message loop.
  341. void RunPendingTasks() {
  342. base::RunLoop run_loop;
  343. run_loop.RunUntilIdle();
  344. }
  345. TestSocketFactory* client_socket_factory() { return &client_socket_factory_; }
  346. content::BrowserTaskEnvironment task_environment_;
  347. std::unique_ptr<net::URLRequestContext> url_request_context_;
  348. std::unique_ptr<network::NetworkContext> network_context_;
  349. mojo::Remote<network::mojom::NetworkContext> network_context_remote_;
  350. raw_ptr<Logger> logger_;
  351. CompleteHandler handler_;
  352. std::unique_ptr<MockCastSocketObserver> observer_;
  353. CastSocketOpenParams socket_open_params_;
  354. TestSocketFactory client_socket_factory_;
  355. };
  356. class MockCastSocketTest : public CastSocketTestBase {
  357. public:
  358. MockCastSocketTest(const MockCastSocketTest&) = delete;
  359. MockCastSocketTest& operator=(const MockCastSocketTest&) = delete;
  360. protected:
  361. MockCastSocketTest() {}
  362. void TearDown() override {
  363. if (socket_) {
  364. EXPECT_CALL(handler_, OnCloseComplete(net::OK));
  365. socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
  366. base::Unretained(&handler_)));
  367. }
  368. }
  369. void CreateCastSocketSecure() {
  370. socket_ = MockTestCastSocket::CreateSecure(network_context_.get(),
  371. socket_open_params_, logger_);
  372. }
  373. void HandleAuthHandshake() {
  374. socket_->SetupMockTransport();
  375. CastMessage challenge_proto = CreateAuthChallenge();
  376. EXPECT_CALL(*socket_->GetMockTransport(),
  377. SendMessage_(EqualsProto(challenge_proto), _))
  378. .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  379. EXPECT_CALL(*socket_->GetMockTransport(), Start());
  380. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  381. socket_->AddObserver(observer_.get());
  382. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  383. base::Unretained(&handler_)));
  384. RunPendingTasks();
  385. socket_->GetMockTransport()->current_delegate()->OnMessage(
  386. CreateAuthReply());
  387. RunPendingTasks();
  388. }
  389. std::unique_ptr<MockTestCastSocket> socket_;
  390. };
  391. class SslCastSocketTest : public CastSocketTestBase {
  392. public:
  393. SslCastSocketTest(const SslCastSocketTest&) = delete;
  394. SslCastSocketTest& operator=(const SslCastSocketTest&) = delete;
  395. protected:
  396. SslCastSocketTest() {}
  397. void TearDown() override {
  398. if (socket_) {
  399. EXPECT_CALL(handler_, OnCloseComplete(net::OK));
  400. socket_->Close(base::BindOnce(&CompleteHandler::OnCloseComplete,
  401. base::Unretained(&handler_)));
  402. }
  403. }
  404. void CreateSockets() {
  405. socket_ = std::make_unique<TestCastSocketBase>(
  406. network_context_.get(), socket_open_params_, logger_);
  407. server_cert_ =
  408. net::ImportCertFromFile(GetTestCertsDirectory(), "self_signed.pem");
  409. ASSERT_TRUE(server_cert_);
  410. server_private_key_ = ReadTestKeyFromPEM("self_signed.pem");
  411. ASSERT_TRUE(server_private_key_);
  412. server_context_ = CreateSSLServerContext(
  413. server_cert_.get(), *server_private_key_, server_ssl_config_);
  414. tcp_server_socket_ =
  415. std::make_unique<net::TCPServerSocket>(nullptr, net::NetLogSource());
  416. ASSERT_EQ(net::OK,
  417. tcp_server_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 1));
  418. net::IPEndPoint server_address;
  419. ASSERT_EQ(net::OK, tcp_server_socket_->GetLocalAddress(&server_address));
  420. tcp_client_socket_ = std::make_unique<net::TCPClientSocket>(
  421. net::AddressList(server_address), nullptr, nullptr, nullptr,
  422. net::NetLogSource());
  423. std::unique_ptr<net::StreamSocket> accepted_socket;
  424. accept_result_ = tcp_server_socket_->Accept(
  425. &accepted_socket, base::BindOnce(&SslCastSocketTest::TcpAcceptCallback,
  426. base::Unretained(this)));
  427. connect_result_ = tcp_client_socket_->Connect(base::BindOnce(
  428. &SslCastSocketTest::TcpConnectCallback, base::Unretained(this)));
  429. while (accept_result_ == net::ERR_IO_PENDING ||
  430. connect_result_ == net::ERR_IO_PENDING) {
  431. RunPendingTasks();
  432. }
  433. ASSERT_EQ(net::OK, accept_result_);
  434. ASSERT_EQ(net::OK, connect_result_);
  435. ASSERT_TRUE(accepted_socket);
  436. ASSERT_TRUE(tcp_client_socket_->IsConnected());
  437. server_socket_ =
  438. server_context_->CreateSSLServerSocket(std::move(accepted_socket));
  439. ASSERT_TRUE(server_socket_);
  440. client_socket_factory()->SetTcpSocket(std::move(tcp_client_socket_));
  441. }
  442. void ConnectSockets() {
  443. socket_->AddObserver(observer_.get());
  444. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  445. base::Unretained(&handler_)));
  446. net::TestCompletionCallback handshake_callback;
  447. int server_ret = handshake_callback.GetResult(
  448. server_socket_->Handshake(handshake_callback.callback()));
  449. ASSERT_EQ(net::OK, server_ret);
  450. }
  451. void TcpAcceptCallback(int result) { accept_result_ = result; }
  452. void TcpConnectCallback(int result) { connect_result_ = result; }
  453. std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM(
  454. const base::StringPiece& name) {
  455. base::FilePath key_path = GetTestCertsDirectory().AppendASCII(name);
  456. std::string pem_data;
  457. if (!base::ReadFileToString(key_path, &pem_data)) {
  458. return nullptr;
  459. }
  460. const std::vector<std::string> headers({"PRIVATE KEY"});
  461. net::PEMTokenizer pem_tokenizer(pem_data, headers);
  462. if (!pem_tokenizer.GetNext()) {
  463. return nullptr;
  464. }
  465. std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(),
  466. pem_tokenizer.data().end());
  467. std::unique_ptr<crypto::RSAPrivateKey> key(
  468. crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
  469. return key;
  470. }
  471. int ReadExactLength(net::IOBuffer* buffer,
  472. int buffer_length,
  473. net::Socket* socket) {
  474. scoped_refptr<net::DrainableIOBuffer> draining_buffer =
  475. base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
  476. while (draining_buffer->BytesRemaining() > 0) {
  477. net::TestCompletionCallback read_callback;
  478. int read_result = read_callback.GetResult(server_socket_->Read(
  479. draining_buffer.get(), draining_buffer->BytesRemaining(),
  480. read_callback.callback()));
  481. EXPECT_GT(read_result, 0);
  482. draining_buffer->DidConsume(read_result);
  483. }
  484. return buffer_length;
  485. }
  486. int WriteExactLength(net::IOBuffer* buffer,
  487. int buffer_length,
  488. net::Socket* socket) {
  489. scoped_refptr<net::DrainableIOBuffer> draining_buffer =
  490. base::MakeRefCounted<net::DrainableIOBuffer>(buffer, buffer_length);
  491. while (draining_buffer->BytesRemaining() > 0) {
  492. net::TestCompletionCallback write_callback;
  493. int write_result = write_callback.GetResult(server_socket_->Write(
  494. draining_buffer.get(), draining_buffer->BytesRemaining(),
  495. write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS));
  496. EXPECT_GT(write_result, 0);
  497. draining_buffer->DidConsume(write_result);
  498. }
  499. return buffer_length;
  500. }
  501. // Result values used for TCP socket setup. These should contain values from
  502. // net::Error.
  503. int accept_result_;
  504. int connect_result_;
  505. // Underlying TCP sockets for |socket_| to communicate with |server_socket_|
  506. // when testing with the real SSL implementation.
  507. std::unique_ptr<net::TransportClientSocket> tcp_client_socket_;
  508. std::unique_ptr<net::TCPServerSocket> tcp_server_socket_;
  509. std::unique_ptr<TestCastSocketBase> socket_;
  510. // |server_socket_| is used for the *RealSSL tests in order to test the
  511. // CastSocket over a real SSL socket. The other members below are used to
  512. // initialize |server_socket_|.
  513. std::unique_ptr<net::SSLServerSocket> server_socket_;
  514. std::unique_ptr<net::SSLServerContext> server_context_;
  515. std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
  516. scoped_refptr<net::X509Certificate> server_cert_;
  517. net::SSLServerConfig server_ssl_config_;
  518. };
  519. } // namespace
  520. // Tests that the following connection flow works:
  521. // - TCP connection succeeds (async)
  522. // - SSL connection succeeds (async)
  523. // - Cert is extracted successfully
  524. // - Challenge request is sent (async)
  525. // - Challenge response is received (async)
  526. // - Credentials are verified successfuly
  527. TEST_F(MockCastSocketTest, TestConnectFullSecureFlowAsync) {
  528. CreateCastSocketSecure();
  529. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  530. client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  531. HandleAuthHandshake();
  532. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  533. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  534. }
  535. // Tests that the following connection flow works:
  536. // - TCP connection succeeds (sync)
  537. // - SSL connection succeeds (sync)
  538. // - Cert is extracted successfully
  539. // - Challenge request is sent (sync)
  540. // - Challenge response is received (sync)
  541. // - Credentials are verified successfuly
  542. TEST_F(MockCastSocketTest, TestConnectFullSecureFlowSync) {
  543. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  544. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
  545. CreateCastSocketSecure();
  546. HandleAuthHandshake();
  547. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  548. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  549. }
  550. // Test that an AuthMessage with a mangled namespace triggers cancelation
  551. // of the connection event loop.
  552. TEST_F(MockCastSocketTest, TestConnectAuthMessageCorrupted) {
  553. CreateCastSocketSecure();
  554. socket_->SetupMockTransport();
  555. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  556. client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  557. CastMessage challenge_proto = CreateAuthChallenge();
  558. EXPECT_CALL(*socket_->GetMockTransport(),
  559. SendMessage_(EqualsProto(challenge_proto), _))
  560. .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  561. EXPECT_CALL(*socket_->GetMockTransport(), Start());
  562. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  563. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  564. base::Unretained(&handler_)));
  565. RunPendingTasks();
  566. CastMessage mangled_auth_reply = CreateAuthReply();
  567. mangled_auth_reply.set_namespace_("BOGUS_NAMESPACE");
  568. socket_->GetMockTransport()->current_delegate()->OnMessage(
  569. mangled_auth_reply);
  570. RunPendingTasks();
  571. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  572. EXPECT_EQ(ChannelError::TRANSPORT_ERROR, socket_->error_state());
  573. // Verifies that the CastSocket's resources were torn down during channel
  574. // close. (see http://crbug.com/504078)
  575. EXPECT_EQ(nullptr, socket_->transport());
  576. }
  577. // Test connection error - TCP connect fails (async)
  578. TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorAsync) {
  579. CreateCastSocketSecure();
  580. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
  581. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  582. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  583. base::Unretained(&handler_)));
  584. RunPendingTasks();
  585. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  586. EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
  587. }
  588. // Test connection error - TCP connect fails (sync)
  589. TEST_F(MockCastSocketTest, TestConnectTcpConnectErrorSync) {
  590. CreateCastSocketSecure();
  591. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::ERR_FAILED);
  592. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  593. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  594. base::Unretained(&handler_)));
  595. RunPendingTasks();
  596. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  597. EXPECT_EQ(ChannelError::CONNECT_ERROR, socket_->error_state());
  598. }
  599. // Test connection error - timeout
  600. TEST_F(MockCastSocketTest, TestConnectTcpTimeoutError) {
  601. CreateCastSocketSecure();
  602. client_socket_factory()->SetupTcpConnectUnresponsive();
  603. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  604. EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
  605. socket_->AddObserver(observer_.get());
  606. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  607. base::Unretained(&handler_)));
  608. RunPendingTasks();
  609. EXPECT_EQ(ReadyState::CONNECTING, socket_->ready_state());
  610. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  611. socket_->TriggerTimeout();
  612. RunPendingTasks();
  613. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  614. EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  615. }
  616. // Test connection error - TCP socket returns timeout
  617. TEST_F(MockCastSocketTest, TestConnectTcpSocketTimeoutError) {
  618. CreateCastSocketSecure();
  619. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS,
  620. net::ERR_CONNECTION_TIMED_OUT);
  621. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  622. EXPECT_CALL(*observer_, OnError(_, ChannelError::CONNECT_TIMEOUT));
  623. socket_->AddObserver(observer_.get());
  624. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  625. base::Unretained(&handler_)));
  626. RunPendingTasks();
  627. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  628. EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  629. EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
  630. logger_->GetLastError(socket_->id()).net_return_value);
  631. }
  632. // Test connection error - SSL connect fails (async)
  633. TEST_F(MockCastSocketTest, TestConnectSslConnectErrorAsync) {
  634. CreateCastSocketSecure();
  635. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  636. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
  637. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  638. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  639. base::Unretained(&handler_)));
  640. RunPendingTasks();
  641. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  642. EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
  643. }
  644. // Test connection error - SSL connect fails (sync)
  645. TEST_F(MockCastSocketTest, TestConnectSslConnectErrorSync) {
  646. CreateCastSocketSecure();
  647. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  648. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::ERR_FAILED);
  649. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  650. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  651. base::Unretained(&handler_)));
  652. RunPendingTasks();
  653. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  654. EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
  655. EXPECT_EQ(net::ERR_FAILED,
  656. logger_->GetLastError(socket_->id()).net_return_value);
  657. }
  658. // Test connection error - SSL connect times out (sync)
  659. TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutSync) {
  660. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  661. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS,
  662. net::ERR_CONNECTION_TIMED_OUT);
  663. CreateCastSocketSecure();
  664. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  665. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  666. base::Unretained(&handler_)));
  667. RunPendingTasks();
  668. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  669. EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  670. EXPECT_EQ(net::ERR_CONNECTION_TIMED_OUT,
  671. logger_->GetLastError(socket_->id()).net_return_value);
  672. }
  673. // Test connection error - SSL connect times out (async)
  674. TEST_F(MockCastSocketTest, TestConnectSslConnectTimeoutAsync) {
  675. CreateCastSocketSecure();
  676. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  677. client_socket_factory()->SetupSslConnect(net::ASYNC,
  678. net::ERR_CONNECTION_TIMED_OUT);
  679. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  680. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  681. base::Unretained(&handler_)));
  682. RunPendingTasks();
  683. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  684. EXPECT_EQ(ChannelError::CONNECT_TIMEOUT, socket_->error_state());
  685. }
  686. // Test connection error - challenge send fails
  687. TEST_F(MockCastSocketTest, TestConnectChallengeSendError) {
  688. CreateCastSocketSecure();
  689. socket_->SetupMockTransport();
  690. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  691. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
  692. EXPECT_CALL(*socket_->GetMockTransport(),
  693. SendMessage_(EqualsProto(CreateAuthChallenge()), _))
  694. .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));
  695. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  696. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  697. base::Unretained(&handler_)));
  698. RunPendingTasks();
  699. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  700. EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
  701. }
  702. // Test connection error - connection is destroyed after the challenge is
  703. // sent, with the async result still lurking in the task queue.
  704. TEST_F(MockCastSocketTest, TestConnectDestroyedAfterChallengeSent) {
  705. CreateCastSocketSecure();
  706. socket_->SetupMockTransport();
  707. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  708. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
  709. EXPECT_CALL(*socket_->GetMockTransport(),
  710. SendMessage_(EqualsProto(CreateAuthChallenge()), _))
  711. .WillOnce(PostCompletionCallbackTask<1>(net::ERR_CONNECTION_RESET));
  712. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  713. base::Unretained(&handler_)));
  714. RunPendingTasks();
  715. socket_.reset();
  716. RunPendingTasks();
  717. }
  718. // Test connection error - challenge reply receive fails
  719. TEST_F(MockCastSocketTest, TestConnectChallengeReplyReceiveError) {
  720. CreateCastSocketSecure();
  721. socket_->SetupMockTransport();
  722. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  723. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
  724. EXPECT_CALL(*socket_->GetMockTransport(),
  725. SendMessage_(EqualsProto(CreateAuthChallenge()), _))
  726. .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  727. client_socket_factory()->AddReadResult(net::SYNCHRONOUS, net::ERR_FAILED);
  728. EXPECT_CALL(*observer_, OnError(_, ChannelError::CAST_SOCKET_ERROR));
  729. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  730. EXPECT_CALL(*socket_->GetMockTransport(), Start());
  731. socket_->AddObserver(observer_.get());
  732. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  733. base::Unretained(&handler_)));
  734. RunPendingTasks();
  735. socket_->GetMockTransport()->current_delegate()->OnError(
  736. ChannelError::CAST_SOCKET_ERROR);
  737. RunPendingTasks();
  738. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  739. EXPECT_EQ(ChannelError::CAST_SOCKET_ERROR, socket_->error_state());
  740. }
  741. TEST_F(MockCastSocketTest, TestConnectChallengeVerificationFails) {
  742. CreateCastSocketSecure();
  743. socket_->SetupMockTransport();
  744. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  745. client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  746. socket_->SetVerifyChallengeResult(false);
  747. EXPECT_CALL(*observer_, OnError(_, ChannelError::AUTHENTICATION_ERROR));
  748. CastMessage challenge_proto = CreateAuthChallenge();
  749. EXPECT_CALL(*socket_->GetMockTransport(),
  750. SendMessage_(EqualsProto(challenge_proto), _))
  751. .WillOnce(PostCompletionCallbackTask<1>(net::OK));
  752. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  753. EXPECT_CALL(*socket_->GetMockTransport(), Start());
  754. socket_->AddObserver(observer_.get());
  755. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  756. base::Unretained(&handler_)));
  757. RunPendingTasks();
  758. socket_->GetMockTransport()->current_delegate()->OnMessage(CreateAuthReply());
  759. RunPendingTasks();
  760. EXPECT_EQ(ReadyState::CLOSED, socket_->ready_state());
  761. EXPECT_EQ(ChannelError::AUTHENTICATION_ERROR, socket_->error_state());
  762. }
  763. // Sends message data through an actual non-mocked CastTransport object,
  764. // testing the two components in integration.
  765. TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportAsync) {
  766. CreateCastSocketSecure();
  767. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  768. client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  769. // Set low-level auth challenge expectations.
  770. CastMessage challenge = CreateAuthChallenge();
  771. std::string challenge_str;
  772. EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
  773. client_socket_factory()->AddWriteResultForData(net::ASYNC, challenge_str);
  774. // Set low-level auth reply expectations.
  775. CastMessage reply = CreateAuthReply();
  776. std::string reply_str;
  777. EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
  778. client_socket_factory()->AddReadResultForData(net::ASYNC, reply_str);
  779. client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);
  780. // Make sure the data is ready by the TLS socket and not the TCP socket.
  781. client_socket_factory()->Pause();
  782. client_socket_factory()->SetTLSSocketCreatedClosure(
  783. base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));
  784. CastMessage test_message = CreateTestMessage();
  785. std::string test_message_str;
  786. EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
  787. client_socket_factory()->AddWriteResultForData(net::ASYNC, test_message_str);
  788. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  789. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  790. base::Unretained(&handler_)));
  791. RunPendingTasks();
  792. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  793. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  794. // Send the test message through a real transport object.
  795. EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  796. socket_->transport()->SendMessage(
  797. test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
  798. base::Unretained(&handler_)));
  799. RunPendingTasks();
  800. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  801. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  802. }
  803. // Same as TestConnectEndToEndWithRealTransportAsync, except synchronous.
  804. TEST_F(MockCastSocketTest, TestConnectEndToEndWithRealTransportSync) {
  805. CreateCastSocketSecure();
  806. client_socket_factory()->SetupTcpConnect(net::SYNCHRONOUS, net::OK);
  807. client_socket_factory()->SetupSslConnect(net::SYNCHRONOUS, net::OK);
  808. // Set low-level auth challenge expectations.
  809. CastMessage challenge = CreateAuthChallenge();
  810. std::string challenge_str;
  811. EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
  812. client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
  813. challenge_str);
  814. // Set low-level auth reply expectations.
  815. CastMessage reply = CreateAuthReply();
  816. std::string reply_str;
  817. EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
  818. client_socket_factory()->AddReadResultForData(net::SYNCHRONOUS, reply_str);
  819. client_socket_factory()->AddReadResult(net::ASYNC, net::ERR_IO_PENDING);
  820. // Make sure the data is ready by the TLS socket and not the TCP socket.
  821. client_socket_factory()->Pause();
  822. client_socket_factory()->SetTLSSocketCreatedClosure(
  823. base::BindLambdaForTesting([&] { client_socket_factory()->Resume(); }));
  824. CastMessage test_message = CreateTestMessage();
  825. std::string test_message_str;
  826. EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
  827. client_socket_factory()->AddWriteResultForData(net::SYNCHRONOUS,
  828. test_message_str);
  829. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  830. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  831. base::Unretained(&handler_)));
  832. RunPendingTasks();
  833. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  834. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  835. // Send the test message through a real transport object.
  836. EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  837. socket_->transport()->SendMessage(
  838. test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
  839. base::Unretained(&handler_)));
  840. RunPendingTasks();
  841. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  842. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  843. }
  844. TEST_F(MockCastSocketTest, TestObservers) {
  845. CreateCastSocketSecure();
  846. // Test AddObserever
  847. MockCastSocketObserver observer1;
  848. MockCastSocketObserver observer2;
  849. socket_->AddObserver(&observer1);
  850. socket_->AddObserver(&observer1);
  851. socket_->AddObserver(&observer2);
  852. socket_->AddObserver(&observer2);
  853. // Test notify observers
  854. EXPECT_CALL(observer1, OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
  855. EXPECT_CALL(observer2, OnError(_, cast_channel::ChannelError::CONNECT_ERROR));
  856. CastSocketImpl::CastSocketMessageDelegate delegate(socket_.get());
  857. delegate.OnError(cast_channel::ChannelError::CONNECT_ERROR);
  858. }
  859. TEST_F(MockCastSocketTest, TestOpenChannelConnectingSocket) {
  860. CreateCastSocketSecure();
  861. client_socket_factory()->SetupTcpConnectUnresponsive();
  862. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  863. base::Unretained(&handler_)));
  864. RunPendingTasks();
  865. EXPECT_CALL(handler_, OnConnectComplete(socket_.get())).Times(2);
  866. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  867. base::Unretained(&handler_)));
  868. socket_->TriggerTimeout();
  869. RunPendingTasks();
  870. }
  871. TEST_F(MockCastSocketTest, TestOpenChannelConnectedSocket) {
  872. CreateCastSocketSecure();
  873. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::OK);
  874. client_socket_factory()->SetupSslConnect(net::ASYNC, net::OK);
  875. HandleAuthHandshake();
  876. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  877. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  878. base::Unretained(&handler_)));
  879. }
  880. TEST_F(MockCastSocketTest, TestOpenChannelClosedSocket) {
  881. CreateCastSocketSecure();
  882. client_socket_factory()->SetupTcpConnect(net::ASYNC, net::ERR_FAILED);
  883. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  884. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  885. base::Unretained(&handler_)));
  886. RunPendingTasks();
  887. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  888. socket_->Connect(base::BindOnce(&CompleteHandler::OnConnectComplete,
  889. base::Unretained(&handler_)));
  890. }
  891. // https://crbug.com/874491, flaky on Win and Mac
  892. #if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_APPLE) || BUILDFLAG(IS_FUCHSIA)
  893. #define MAYBE_TestConnectEndToEndWithRealSSL \
  894. DISABLED_TestConnectEndToEndWithRealSSL
  895. #else
  896. #define MAYBE_TestConnectEndToEndWithRealSSL TestConnectEndToEndWithRealSSL
  897. #endif
  898. // Tests connecting through an actual non-mocked CastTransport object and
  899. // non-mocked SSLClientSocket, testing the components in integration.
  900. TEST_F(SslCastSocketTest, MAYBE_TestConnectEndToEndWithRealSSL) {
  901. CreateSockets();
  902. ConnectSockets();
  903. // Set low-level auth challenge expectations.
  904. CastMessage challenge = CreateAuthChallenge();
  905. std::string challenge_str;
  906. EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
  907. int challenge_buffer_length = challenge_str.size();
  908. scoped_refptr<net::IOBuffer> challenge_buffer =
  909. base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
  910. int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
  911. server_socket_.get());
  912. EXPECT_EQ(challenge_buffer_length, read);
  913. EXPECT_EQ(challenge_str,
  914. std::string(challenge_buffer->data(), challenge_buffer_length));
  915. // Set low-level auth reply expectations.
  916. CastMessage reply = CreateAuthReply();
  917. std::string reply_str;
  918. EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
  919. scoped_refptr<net::StringIOBuffer> reply_buffer =
  920. base::MakeRefCounted<net::StringIOBuffer>(reply_str);
  921. int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
  922. server_socket_.get());
  923. EXPECT_EQ(reply_buffer->size(), written);
  924. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  925. RunPendingTasks();
  926. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  927. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  928. }
  929. // Sends message data through an actual non-mocked CastTransport object and
  930. // non-mocked SSLClientSocket, testing the components in integration.
  931. TEST_F(SslCastSocketTest, DISABLED_TestMessageEndToEndWithRealSSL) {
  932. CreateSockets();
  933. ConnectSockets();
  934. // Set low-level auth challenge expectations.
  935. CastMessage challenge = CreateAuthChallenge();
  936. std::string challenge_str;
  937. EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
  938. int challenge_buffer_length = challenge_str.size();
  939. scoped_refptr<net::IOBuffer> challenge_buffer =
  940. base::MakeRefCounted<net::IOBuffer>(challenge_buffer_length);
  941. int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
  942. server_socket_.get());
  943. EXPECT_EQ(challenge_buffer_length, read);
  944. EXPECT_EQ(challenge_str,
  945. std::string(challenge_buffer->data(), challenge_buffer_length));
  946. // Set low-level auth reply expectations.
  947. CastMessage reply = CreateAuthReply();
  948. std::string reply_str;
  949. EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
  950. scoped_refptr<net::StringIOBuffer> reply_buffer =
  951. base::MakeRefCounted<net::StringIOBuffer>(reply_str);
  952. int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
  953. server_socket_.get());
  954. EXPECT_EQ(reply_buffer->size(), written);
  955. EXPECT_CALL(handler_, OnConnectComplete(socket_.get()));
  956. RunPendingTasks();
  957. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  958. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  959. // Send a test message through the ssl socket.
  960. CastMessage test_message = CreateTestMessage();
  961. std::string test_message_str;
  962. EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
  963. int test_message_length = test_message_str.size();
  964. scoped_refptr<net::IOBuffer> test_message_buffer =
  965. base::MakeRefCounted<net::IOBuffer>(test_message_length);
  966. EXPECT_CALL(handler_, OnWriteComplete(net::OK));
  967. socket_->transport()->SendMessage(
  968. test_message, base::BindOnce(&CompleteHandler::OnWriteComplete,
  969. base::Unretained(&handler_)));
  970. RunPendingTasks();
  971. read = ReadExactLength(test_message_buffer.get(), test_message_length,
  972. server_socket_.get());
  973. EXPECT_EQ(test_message_length, read);
  974. EXPECT_EQ(test_message_str,
  975. std::string(test_message_buffer->data(), test_message_length));
  976. EXPECT_EQ(ReadyState::OPEN, socket_->ready_state());
  977. EXPECT_EQ(ChannelError::NONE, socket_->error_state());
  978. }
  979. } // namespace cast_channel