123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273 |
- // Copyright 2019 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "components/openscreen_platform/tls_client_connection.h"
- #include <cstring>
- #include <iterator>
- #include <memory>
- #include <utility>
- #include <vector>
- #include "base/bind.h"
- #include "base/run_loop.h"
- #include "base/task/sequenced_task_runner.h"
- #include "base/test/task_environment.h"
- #include "components/openscreen_platform/task_runner.h"
- #include "testing/gmock/include/gmock/gmock.h"
- #include "testing/gtest/include/gtest/gtest.h"
- using ::testing::_;
- using ::testing::Mock;
- using ::testing::StrictMock;
- namespace openscreen_platform {
- using openscreen::Error;
- using openscreen::TlsConnection;
- namespace {
- const openscreen::IPEndpoint kValidEndpointOne{
- openscreen::IPAddress{192, 168, 0, 1}, 80};
- const openscreen::IPEndpoint kValidEndpointTwo{
- openscreen::IPAddress{10, 9, 8, 7}, 81};
- constexpr int kDataPipeCapacity = 32;
- const uint8_t kTestMessage[] = "Hello world!";
- // Creates two data pipes, one for inbound data and one for outbound data, and
- // provides test utilities for simulating socket stream events of interest.
- class FakeSocketStreams {
- public:
- FakeSocketStreams()
- : outbound_stream_watcher_(FROM_HERE,
- mojo::SimpleWatcher::ArmingPolicy::MANUAL) {
- MojoCreateDataPipeOptions options{};
- options.struct_size = sizeof(options);
- options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE;
- options.element_num_bytes = 1;
- options.capacity_num_bytes = kDataPipeCapacity;
- MojoResult result =
- CreateDataPipe(&options, inbound_stream_, receive_stream_);
- CHECK_EQ(result, MOJO_RESULT_OK);
- result = CreateDataPipe(&options, send_stream_, outbound_stream_);
- CHECK_EQ(result, MOJO_RESULT_OK);
- outbound_stream_watcher_.Watch(
- outbound_stream_.get(),
- MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED |
- MOJO_HANDLE_SIGNAL_NEW_DATA_READABLE,
- MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
- base::BindRepeating(&FakeSocketStreams::OnOutboundStreamActivity,
- base::Unretained(this)));
- outbound_stream_watcher_.ArmOrNotify();
- }
- ~FakeSocketStreams() = default;
- // These should be passed to the TlsClientConnection constructor.
- mojo::ScopedDataPipeConsumerHandle TakeReceiveStream() {
- return std::move(receive_stream_);
- }
- mojo::ScopedDataPipeProducerHandle TakeSendStream() {
- return std::move(send_stream_);
- }
- // Writes data into the inbound data pipe, which should ultimately result in a
- // TlsClientConnection::Client's OnRead() method being called.
- void SimulateSocketReceive(const void* data, uint32_t num_bytes) {
- const MojoResult result = inbound_stream_->WriteData(
- data, &num_bytes, MOJO_WRITE_DATA_FLAG_ALL_OR_NONE);
- ASSERT_EQ(result, MOJO_RESULT_OK);
- }
- // Closes the inbound (or outbound) data pipe to allow the unit tests to check
- // the error handling of TlsClientConnection.
- void SimulateInboundClose() { inbound_stream_.reset(); }
- void SimulateOutboundClose() { outbound_stream_.reset(); }
- // Returns all outbound stream data accumulated so far, and clears the
- // internal buffer.
- std::vector<uint8_t> TakeAccumulatedOutboundData() {
- std::vector<uint8_t> result;
- result.swap(outbound_data_);
- return result;
- }
- private:
- // Mojo SimpleWatcher callback to save all data being sent from a connection.
- void OnOutboundStreamActivity(MojoResult result,
- const mojo::HandleSignalsState& state) {
- if (!outbound_stream_.is_valid()) {
- return;
- }
- ASSERT_EQ(result, MOJO_RESULT_OK);
- uint32_t num_bytes = 0;
- result = outbound_stream_->ReadData(nullptr, &num_bytes,
- MOJO_READ_DATA_FLAG_QUERY);
- ASSERT_EQ(result, MOJO_RESULT_OK);
- auto old_end_index = outbound_data_.size();
- outbound_data_.resize(old_end_index + num_bytes);
- result = outbound_stream_->ReadData(outbound_data_.data() + old_end_index,
- &num_bytes, MOJO_READ_DATA_FLAG_NONE);
- ASSERT_EQ(result, MOJO_RESULT_OK);
- outbound_data_.resize(old_end_index + num_bytes);
- outbound_stream_watcher_.ArmOrNotify();
- }
- mojo::ScopedDataPipeProducerHandle inbound_stream_;
- mojo::ScopedDataPipeConsumerHandle receive_stream_;
- mojo::ScopedDataPipeProducerHandle send_stream_;
- mojo::ScopedDataPipeConsumerHandle outbound_stream_;
- mojo::SimpleWatcher outbound_stream_watcher_;
- std::vector<uint8_t> outbound_data_;
- };
- class MockTlsConnectionClient : public TlsConnection::Client {
- public:
- MOCK_METHOD(void, OnError, (TlsConnection*, Error), (override));
- MOCK_METHOD(void, OnRead, (TlsConnection*, std::vector<uint8_t>), (override));
- };
- } // namespace
- class TlsClientConnectionTest : public ::testing::Test {
- public:
- TlsClientConnectionTest() = default;
- ~TlsClientConnectionTest() override = default;
- void SetUp() override {
- task_runner_ = std::make_unique<openscreen_platform::TaskRunner>(
- task_environment_.GetMainThreadTaskRunner());
- socket_streams_ = std::make_unique<FakeSocketStreams>();
- connection_ = std::make_unique<TlsClientConnection>(
- task_runner_.get(), kValidEndpointOne, kValidEndpointTwo,
- socket_streams_->TakeReceiveStream(), socket_streams_->TakeSendStream(),
- mojo::Remote<network::mojom::TCPConnectedSocket>{},
- mojo::Remote<network::mojom::TLSClientSocket>{});
- }
- void TearDown() override {
- connection_.reset();
- socket_streams_.reset();
- base::RunLoop().RunUntilIdle();
- }
- FakeSocketStreams* socket_streams() const { return socket_streams_.get(); }
- TlsClientConnection* connection() const { return connection_.get(); }
- private:
- base::test::TaskEnvironment task_environment_;
- std::unique_ptr<openscreen_platform::TaskRunner> task_runner_;
- std::unique_ptr<FakeSocketStreams> socket_streams_;
- std::unique_ptr<TlsClientConnection> connection_;
- };
- TEST_F(TlsClientConnectionTest, CallsClientOnReadForInboundData) {
- // Test multiple reads to confirm the data pipe watcher is being re-armed
- // correctly after each read.
- constexpr int kNumReads = 3;
- StrictMock<MockTlsConnectionClient> client;
- connection()->SetClient(&client);
- for (int i = 0; i < kNumReads; ++i) {
- // Send a different message in each iteration.
- std::vector<uint8_t> expected_data(std::begin(kTestMessage),
- std::end(kTestMessage));
- for (uint8_t& byte : expected_data) {
- byte ^= i;
- }
- EXPECT_CALL(client, OnRead(connection(), expected_data)).Times(1);
- socket_streams()->SimulateSocketReceive(expected_data.data(),
- expected_data.size());
- base::RunLoop().RunUntilIdle();
- Mock::VerifyAndClearExpectations(&client);
- }
- }
- TEST_F(TlsClientConnectionTest, CallsClientOnErrorWhenSocketInboundCloses) {
- StrictMock<MockTlsConnectionClient> client;
- EXPECT_CALL(client, OnError(connection(), _)).Times(1);
- connection()->SetClient(&client);
- socket_streams()->SimulateInboundClose();
- base::RunLoop().RunUntilIdle();
- }
- TEST_F(TlsClientConnectionTest, SendsUntilBlocked) {
- StrictMock<MockTlsConnectionClient> client;
- // Note: Client::OnError() should not be called during this test since an
- // outbound-blocked socket is not a fatal error.
- EXPECT_CALL(client, OnError(connection(), _)).Times(0);
- connection()->SetClient(&client);
- std::vector<uint8_t> message(kDataPipeCapacity / 2);
- for (int i = 0; i < kDataPipeCapacity / 2; ++i) {
- message[i] = static_cast<uint8_t>(i);
- }
- // Send one message whose size is half the pipe's capacity.
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- base::RunLoop().RunUntilIdle();
- EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData());
- // Send two messages whose sizes are half the pipe's capacity.
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- base::RunLoop().RunUntilIdle();
- std::vector<uint8_t> accumulated_data =
- socket_streams()->TakeAccumulatedOutboundData();
- ASSERT_EQ(message.size() * 2, accumulated_data.size());
- EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size()));
- EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(),
- message.size()));
- // Attempt to send three messages, but expect the third to fail.
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- EXPECT_FALSE(connection()->Send(message.data(), message.size()));
- base::RunLoop().RunUntilIdle();
- accumulated_data = socket_streams()->TakeAccumulatedOutboundData();
- ASSERT_EQ(message.size() * 2, accumulated_data.size());
- EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size()));
- EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(),
- message.size()));
- // Sending should resume when there is capacity available again.
- EXPECT_TRUE(connection()->Send(message.data(), message.size()));
- base::RunLoop().RunUntilIdle();
- EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData());
- }
- TEST_F(TlsClientConnectionTest,
- CallsClientOnErrorWhenSendingToClosedOutboundStream) {
- StrictMock<MockTlsConnectionClient> client;
- EXPECT_CALL(client, OnError(connection(), _)).Times(0);
- connection()->SetClient(&client);
- // Send a message and immediately close the outbound stream.
- EXPECT_TRUE(connection()->Send(kTestMessage, sizeof(kTestMessage)));
- socket_streams()->SimulateOutboundClose();
- base::RunLoop().RunUntilIdle();
- // The Client should not have encountered any fatal errors yet.
- Mock::VerifyAndClearExpectations(&client);
- // Now, call Send() again and this should trigger a fatal error.
- EXPECT_CALL(client, OnError(connection(), _)).Times(1);
- EXPECT_FALSE(connection()->Send(kTestMessage, sizeof(kTestMessage)));
- }
- TEST_F(TlsClientConnectionTest, CanRetrieveAddresses) {
- EXPECT_EQ(kValidEndpointTwo, connection()->GetRemoteEndpoint());
- }
- } // namespace openscreen_platform
|