// Copyright 2019 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "components/openscreen_platform/tls_client_connection.h" #include #include #include #include #include #include "base/bind.h" #include "base/run_loop.h" #include "base/task/sequenced_task_runner.h" #include "base/test/task_environment.h" #include "components/openscreen_platform/task_runner.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" using ::testing::_; using ::testing::Mock; using ::testing::StrictMock; namespace openscreen_platform { using openscreen::Error; using openscreen::TlsConnection; namespace { const openscreen::IPEndpoint kValidEndpointOne{ openscreen::IPAddress{192, 168, 0, 1}, 80}; const openscreen::IPEndpoint kValidEndpointTwo{ openscreen::IPAddress{10, 9, 8, 7}, 81}; constexpr int kDataPipeCapacity = 32; const uint8_t kTestMessage[] = "Hello world!"; // Creates two data pipes, one for inbound data and one for outbound data, and // provides test utilities for simulating socket stream events of interest. class FakeSocketStreams { public: FakeSocketStreams() : outbound_stream_watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL) { MojoCreateDataPipeOptions options{}; options.struct_size = sizeof(options); options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE; options.element_num_bytes = 1; options.capacity_num_bytes = kDataPipeCapacity; MojoResult result = CreateDataPipe(&options, inbound_stream_, receive_stream_); CHECK_EQ(result, MOJO_RESULT_OK); result = CreateDataPipe(&options, send_stream_, outbound_stream_); CHECK_EQ(result, MOJO_RESULT_OK); outbound_stream_watcher_.Watch( outbound_stream_.get(), MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED | MOJO_HANDLE_SIGNAL_NEW_DATA_READABLE, MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED, base::BindRepeating(&FakeSocketStreams::OnOutboundStreamActivity, base::Unretained(this))); outbound_stream_watcher_.ArmOrNotify(); } ~FakeSocketStreams() = default; // These should be passed to the TlsClientConnection constructor. mojo::ScopedDataPipeConsumerHandle TakeReceiveStream() { return std::move(receive_stream_); } mojo::ScopedDataPipeProducerHandle TakeSendStream() { return std::move(send_stream_); } // Writes data into the inbound data pipe, which should ultimately result in a // TlsClientConnection::Client's OnRead() method being called. void SimulateSocketReceive(const void* data, uint32_t num_bytes) { const MojoResult result = inbound_stream_->WriteData( data, &num_bytes, MOJO_WRITE_DATA_FLAG_ALL_OR_NONE); ASSERT_EQ(result, MOJO_RESULT_OK); } // Closes the inbound (or outbound) data pipe to allow the unit tests to check // the error handling of TlsClientConnection. void SimulateInboundClose() { inbound_stream_.reset(); } void SimulateOutboundClose() { outbound_stream_.reset(); } // Returns all outbound stream data accumulated so far, and clears the // internal buffer. std::vector TakeAccumulatedOutboundData() { std::vector result; result.swap(outbound_data_); return result; } private: // Mojo SimpleWatcher callback to save all data being sent from a connection. void OnOutboundStreamActivity(MojoResult result, const mojo::HandleSignalsState& state) { if (!outbound_stream_.is_valid()) { return; } ASSERT_EQ(result, MOJO_RESULT_OK); uint32_t num_bytes = 0; result = outbound_stream_->ReadData(nullptr, &num_bytes, MOJO_READ_DATA_FLAG_QUERY); ASSERT_EQ(result, MOJO_RESULT_OK); auto old_end_index = outbound_data_.size(); outbound_data_.resize(old_end_index + num_bytes); result = outbound_stream_->ReadData(outbound_data_.data() + old_end_index, &num_bytes, MOJO_READ_DATA_FLAG_NONE); ASSERT_EQ(result, MOJO_RESULT_OK); outbound_data_.resize(old_end_index + num_bytes); outbound_stream_watcher_.ArmOrNotify(); } mojo::ScopedDataPipeProducerHandle inbound_stream_; mojo::ScopedDataPipeConsumerHandle receive_stream_; mojo::ScopedDataPipeProducerHandle send_stream_; mojo::ScopedDataPipeConsumerHandle outbound_stream_; mojo::SimpleWatcher outbound_stream_watcher_; std::vector outbound_data_; }; class MockTlsConnectionClient : public TlsConnection::Client { public: MOCK_METHOD(void, OnError, (TlsConnection*, Error), (override)); MOCK_METHOD(void, OnRead, (TlsConnection*, std::vector), (override)); }; } // namespace class TlsClientConnectionTest : public ::testing::Test { public: TlsClientConnectionTest() = default; ~TlsClientConnectionTest() override = default; void SetUp() override { task_runner_ = std::make_unique( task_environment_.GetMainThreadTaskRunner()); socket_streams_ = std::make_unique(); connection_ = std::make_unique( task_runner_.get(), kValidEndpointOne, kValidEndpointTwo, socket_streams_->TakeReceiveStream(), socket_streams_->TakeSendStream(), mojo::Remote{}, mojo::Remote{}); } void TearDown() override { connection_.reset(); socket_streams_.reset(); base::RunLoop().RunUntilIdle(); } FakeSocketStreams* socket_streams() const { return socket_streams_.get(); } TlsClientConnection* connection() const { return connection_.get(); } private: base::test::TaskEnvironment task_environment_; std::unique_ptr task_runner_; std::unique_ptr socket_streams_; std::unique_ptr connection_; }; TEST_F(TlsClientConnectionTest, CallsClientOnReadForInboundData) { // Test multiple reads to confirm the data pipe watcher is being re-armed // correctly after each read. constexpr int kNumReads = 3; StrictMock client; connection()->SetClient(&client); for (int i = 0; i < kNumReads; ++i) { // Send a different message in each iteration. std::vector expected_data(std::begin(kTestMessage), std::end(kTestMessage)); for (uint8_t& byte : expected_data) { byte ^= i; } EXPECT_CALL(client, OnRead(connection(), expected_data)).Times(1); socket_streams()->SimulateSocketReceive(expected_data.data(), expected_data.size()); base::RunLoop().RunUntilIdle(); Mock::VerifyAndClearExpectations(&client); } } TEST_F(TlsClientConnectionTest, CallsClientOnErrorWhenSocketInboundCloses) { StrictMock client; EXPECT_CALL(client, OnError(connection(), _)).Times(1); connection()->SetClient(&client); socket_streams()->SimulateInboundClose(); base::RunLoop().RunUntilIdle(); } TEST_F(TlsClientConnectionTest, SendsUntilBlocked) { StrictMock client; // Note: Client::OnError() should not be called during this test since an // outbound-blocked socket is not a fatal error. EXPECT_CALL(client, OnError(connection(), _)).Times(0); connection()->SetClient(&client); std::vector message(kDataPipeCapacity / 2); for (int i = 0; i < kDataPipeCapacity / 2; ++i) { message[i] = static_cast(i); } // Send one message whose size is half the pipe's capacity. EXPECT_TRUE(connection()->Send(message.data(), message.size())); base::RunLoop().RunUntilIdle(); EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData()); // Send two messages whose sizes are half the pipe's capacity. EXPECT_TRUE(connection()->Send(message.data(), message.size())); EXPECT_TRUE(connection()->Send(message.data(), message.size())); base::RunLoop().RunUntilIdle(); std::vector accumulated_data = socket_streams()->TakeAccumulatedOutboundData(); ASSERT_EQ(message.size() * 2, accumulated_data.size()); EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size())); EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(), message.size())); // Attempt to send three messages, but expect the third to fail. EXPECT_TRUE(connection()->Send(message.data(), message.size())); EXPECT_TRUE(connection()->Send(message.data(), message.size())); EXPECT_FALSE(connection()->Send(message.data(), message.size())); base::RunLoop().RunUntilIdle(); accumulated_data = socket_streams()->TakeAccumulatedOutboundData(); ASSERT_EQ(message.size() * 2, accumulated_data.size()); EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size())); EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(), message.size())); // Sending should resume when there is capacity available again. EXPECT_TRUE(connection()->Send(message.data(), message.size())); base::RunLoop().RunUntilIdle(); EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData()); } TEST_F(TlsClientConnectionTest, CallsClientOnErrorWhenSendingToClosedOutboundStream) { StrictMock client; EXPECT_CALL(client, OnError(connection(), _)).Times(0); connection()->SetClient(&client); // Send a message and immediately close the outbound stream. EXPECT_TRUE(connection()->Send(kTestMessage, sizeof(kTestMessage))); socket_streams()->SimulateOutboundClose(); base::RunLoop().RunUntilIdle(); // The Client should not have encountered any fatal errors yet. Mock::VerifyAndClearExpectations(&client); // Now, call Send() again and this should trigger a fatal error. EXPECT_CALL(client, OnError(connection(), _)).Times(1); EXPECT_FALSE(connection()->Send(kTestMessage, sizeof(kTestMessage))); } TEST_F(TlsClientConnectionTest, CanRetrieveAddresses) { EXPECT_EQ(kValidEndpointTwo, connection()->GetRemoteEndpoint()); } } // namespace openscreen_platform