tls_client_connection_unittest.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
#include "components/openscreen_platform/tls_client_connection.h"

#include <cstring>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/run_loop.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/task_environment.h"
#include "components/openscreen_platform/task_runner.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
  17. using ::testing::_;
  18. using ::testing::Mock;
  19. using ::testing::StrictMock;
  20. namespace openscreen_platform {
  21. using openscreen::Error;
  22. using openscreen::TlsConnection;
  23. namespace {
  24. const openscreen::IPEndpoint kValidEndpointOne{
  25. openscreen::IPAddress{192, 168, 0, 1}, 80};
  26. const openscreen::IPEndpoint kValidEndpointTwo{
  27. openscreen::IPAddress{10, 9, 8, 7}, 81};
  28. constexpr int kDataPipeCapacity = 32;
  29. const uint8_t kTestMessage[] = "Hello world!";
  30. // Creates two data pipes, one for inbound data and one for outbound data, and
  31. // provides test utilities for simulating socket stream events of interest.
  32. class FakeSocketStreams {
  33. public:
  34. FakeSocketStreams()
  35. : outbound_stream_watcher_(FROM_HERE,
  36. mojo::SimpleWatcher::ArmingPolicy::MANUAL) {
  37. MojoCreateDataPipeOptions options{};
  38. options.struct_size = sizeof(options);
  39. options.flags = MOJO_CREATE_DATA_PIPE_FLAG_NONE;
  40. options.element_num_bytes = 1;
  41. options.capacity_num_bytes = kDataPipeCapacity;
  42. MojoResult result =
  43. CreateDataPipe(&options, inbound_stream_, receive_stream_);
  44. CHECK_EQ(result, MOJO_RESULT_OK);
  45. result = CreateDataPipe(&options, send_stream_, outbound_stream_);
  46. CHECK_EQ(result, MOJO_RESULT_OK);
  47. outbound_stream_watcher_.Watch(
  48. outbound_stream_.get(),
  49. MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED |
  50. MOJO_HANDLE_SIGNAL_NEW_DATA_READABLE,
  51. MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
  52. base::BindRepeating(&FakeSocketStreams::OnOutboundStreamActivity,
  53. base::Unretained(this)));
  54. outbound_stream_watcher_.ArmOrNotify();
  55. }
  56. ~FakeSocketStreams() = default;
  57. // These should be passed to the TlsClientConnection constructor.
  58. mojo::ScopedDataPipeConsumerHandle TakeReceiveStream() {
  59. return std::move(receive_stream_);
  60. }
  61. mojo::ScopedDataPipeProducerHandle TakeSendStream() {
  62. return std::move(send_stream_);
  63. }
  64. // Writes data into the inbound data pipe, which should ultimately result in a
  65. // TlsClientConnection::Client's OnRead() method being called.
  66. void SimulateSocketReceive(const void* data, uint32_t num_bytes) {
  67. const MojoResult result = inbound_stream_->WriteData(
  68. data, &num_bytes, MOJO_WRITE_DATA_FLAG_ALL_OR_NONE);
  69. ASSERT_EQ(result, MOJO_RESULT_OK);
  70. }
  71. // Closes the inbound (or outbound) data pipe to allow the unit tests to check
  72. // the error handling of TlsClientConnection.
  73. void SimulateInboundClose() { inbound_stream_.reset(); }
  74. void SimulateOutboundClose() { outbound_stream_.reset(); }
  75. // Returns all outbound stream data accumulated so far, and clears the
  76. // internal buffer.
  77. std::vector<uint8_t> TakeAccumulatedOutboundData() {
  78. std::vector<uint8_t> result;
  79. result.swap(outbound_data_);
  80. return result;
  81. }
  82. private:
  83. // Mojo SimpleWatcher callback to save all data being sent from a connection.
  84. void OnOutboundStreamActivity(MojoResult result,
  85. const mojo::HandleSignalsState& state) {
  86. if (!outbound_stream_.is_valid()) {
  87. return;
  88. }
  89. ASSERT_EQ(result, MOJO_RESULT_OK);
  90. uint32_t num_bytes = 0;
  91. result = outbound_stream_->ReadData(nullptr, &num_bytes,
  92. MOJO_READ_DATA_FLAG_QUERY);
  93. ASSERT_EQ(result, MOJO_RESULT_OK);
  94. auto old_end_index = outbound_data_.size();
  95. outbound_data_.resize(old_end_index + num_bytes);
  96. result = outbound_stream_->ReadData(outbound_data_.data() + old_end_index,
  97. &num_bytes, MOJO_READ_DATA_FLAG_NONE);
  98. ASSERT_EQ(result, MOJO_RESULT_OK);
  99. outbound_data_.resize(old_end_index + num_bytes);
  100. outbound_stream_watcher_.ArmOrNotify();
  101. }
  102. mojo::ScopedDataPipeProducerHandle inbound_stream_;
  103. mojo::ScopedDataPipeConsumerHandle receive_stream_;
  104. mojo::ScopedDataPipeProducerHandle send_stream_;
  105. mojo::ScopedDataPipeConsumerHandle outbound_stream_;
  106. mojo::SimpleWatcher outbound_stream_watcher_;
  107. std::vector<uint8_t> outbound_data_;
  108. };
  109. class MockTlsConnectionClient : public TlsConnection::Client {
  110. public:
  111. MOCK_METHOD(void, OnError, (TlsConnection*, Error), (override));
  112. MOCK_METHOD(void, OnRead, (TlsConnection*, std::vector<uint8_t>), (override));
  113. };
  114. } // namespace
  115. class TlsClientConnectionTest : public ::testing::Test {
  116. public:
  117. TlsClientConnectionTest() = default;
  118. ~TlsClientConnectionTest() override = default;
  119. void SetUp() override {
  120. task_runner_ = std::make_unique<openscreen_platform::TaskRunner>(
  121. task_environment_.GetMainThreadTaskRunner());
  122. socket_streams_ = std::make_unique<FakeSocketStreams>();
  123. connection_ = std::make_unique<TlsClientConnection>(
  124. task_runner_.get(), kValidEndpointOne, kValidEndpointTwo,
  125. socket_streams_->TakeReceiveStream(), socket_streams_->TakeSendStream(),
  126. mojo::Remote<network::mojom::TCPConnectedSocket>{},
  127. mojo::Remote<network::mojom::TLSClientSocket>{});
  128. }
  129. void TearDown() override {
  130. connection_.reset();
  131. socket_streams_.reset();
  132. base::RunLoop().RunUntilIdle();
  133. }
  134. FakeSocketStreams* socket_streams() const { return socket_streams_.get(); }
  135. TlsClientConnection* connection() const { return connection_.get(); }
  136. private:
  137. base::test::TaskEnvironment task_environment_;
  138. std::unique_ptr<openscreen_platform::TaskRunner> task_runner_;
  139. std::unique_ptr<FakeSocketStreams> socket_streams_;
  140. std::unique_ptr<TlsClientConnection> connection_;
  141. };
  142. TEST_F(TlsClientConnectionTest, CallsClientOnReadForInboundData) {
  143. // Test multiple reads to confirm the data pipe watcher is being re-armed
  144. // correctly after each read.
  145. constexpr int kNumReads = 3;
  146. StrictMock<MockTlsConnectionClient> client;
  147. connection()->SetClient(&client);
  148. for (int i = 0; i < kNumReads; ++i) {
  149. // Send a different message in each iteration.
  150. std::vector<uint8_t> expected_data(std::begin(kTestMessage),
  151. std::end(kTestMessage));
  152. for (uint8_t& byte : expected_data) {
  153. byte ^= i;
  154. }
  155. EXPECT_CALL(client, OnRead(connection(), expected_data)).Times(1);
  156. socket_streams()->SimulateSocketReceive(expected_data.data(),
  157. expected_data.size());
  158. base::RunLoop().RunUntilIdle();
  159. Mock::VerifyAndClearExpectations(&client);
  160. }
  161. }
  162. TEST_F(TlsClientConnectionTest, CallsClientOnErrorWhenSocketInboundCloses) {
  163. StrictMock<MockTlsConnectionClient> client;
  164. EXPECT_CALL(client, OnError(connection(), _)).Times(1);
  165. connection()->SetClient(&client);
  166. socket_streams()->SimulateInboundClose();
  167. base::RunLoop().RunUntilIdle();
  168. }
  169. TEST_F(TlsClientConnectionTest, SendsUntilBlocked) {
  170. StrictMock<MockTlsConnectionClient> client;
  171. // Note: Client::OnError() should not be called during this test since an
  172. // outbound-blocked socket is not a fatal error.
  173. EXPECT_CALL(client, OnError(connection(), _)).Times(0);
  174. connection()->SetClient(&client);
  175. std::vector<uint8_t> message(kDataPipeCapacity / 2);
  176. for (int i = 0; i < kDataPipeCapacity / 2; ++i) {
  177. message[i] = static_cast<uint8_t>(i);
  178. }
  179. // Send one message whose size is half the pipe's capacity.
  180. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  181. base::RunLoop().RunUntilIdle();
  182. EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData());
  183. // Send two messages whose sizes are half the pipe's capacity.
  184. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  185. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  186. base::RunLoop().RunUntilIdle();
  187. std::vector<uint8_t> accumulated_data =
  188. socket_streams()->TakeAccumulatedOutboundData();
  189. ASSERT_EQ(message.size() * 2, accumulated_data.size());
  190. EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size()));
  191. EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(),
  192. message.size()));
  193. // Attempt to send three messages, but expect the third to fail.
  194. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  195. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  196. EXPECT_FALSE(connection()->Send(message.data(), message.size()));
  197. base::RunLoop().RunUntilIdle();
  198. accumulated_data = socket_streams()->TakeAccumulatedOutboundData();
  199. ASSERT_EQ(message.size() * 2, accumulated_data.size());
  200. EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data(), message.size()));
  201. EXPECT_EQ(0, memcmp(message.data(), accumulated_data.data() + message.size(),
  202. message.size()));
  203. // Sending should resume when there is capacity available again.
  204. EXPECT_TRUE(connection()->Send(message.data(), message.size()));
  205. base::RunLoop().RunUntilIdle();
  206. EXPECT_EQ(message, socket_streams()->TakeAccumulatedOutboundData());
  207. }
  208. TEST_F(TlsClientConnectionTest,
  209. CallsClientOnErrorWhenSendingToClosedOutboundStream) {
  210. StrictMock<MockTlsConnectionClient> client;
  211. EXPECT_CALL(client, OnError(connection(), _)).Times(0);
  212. connection()->SetClient(&client);
  213. // Send a message and immediately close the outbound stream.
  214. EXPECT_TRUE(connection()->Send(kTestMessage, sizeof(kTestMessage)));
  215. socket_streams()->SimulateOutboundClose();
  216. base::RunLoop().RunUntilIdle();
  217. // The Client should not have encountered any fatal errors yet.
  218. Mock::VerifyAndClearExpectations(&client);
  219. // Now, call Send() again and this should trigger a fatal error.
  220. EXPECT_CALL(client, OnError(connection(), _)).Times(1);
  221. EXPECT_FALSE(connection()->Send(kTestMessage, sizeof(kTestMessage)));
  222. }
  223. TEST_F(TlsClientConnectionTest, CanRetrieveAddresses) {
  224. EXPECT_EQ(kValidEndpointTwo, connection()->GetRemoteEndpoint());
  225. }
  226. } // namespace openscreen_platform