channel_unittest.cc 26 KB


  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "mojo/core/channel.h"
  5. #include <atomic>
  6. #include "base/bind.h"
  7. #include "base/memory/page_size.h"
  8. #include "base/memory/ptr_util.h"
  9. #include "base/message_loop/message_pump_type.h"
  10. #include "base/process/process_handle.h"
  11. #include "base/run_loop.h"
  12. #include "base/strings/stringprintf.h"
  13. #include "base/test/bind.h"
  14. #include "base/test/task_environment.h"
  15. #include "base/threading/thread.h"
  16. #include "base/threading/thread_task_runner_handle.h"
  17. #include "build/build_config.h"
  18. #include "mojo/core/platform_handle_utils.h"
  19. #include "mojo/public/cpp/platform/platform_channel.h"
  20. #include "testing/gmock/include/gmock/gmock.h"
  21. #include "testing/gtest/include/gtest/gtest.h"
  22. #include "third_party/abseil-cpp/absl/types/optional.h"
  23. namespace mojo {
  24. namespace core {
  25. namespace {
  26. class TestChannel : public Channel {
  27. public:
  28. TestChannel(Channel::Delegate* delegate)
  29. : Channel(delegate, Channel::HandlePolicy::kAcceptHandles) {}
  30. char* GetReadBufferTest(size_t* buffer_capacity) {
  31. return GetReadBuffer(buffer_capacity);
  32. }
  33. bool OnReadCompleteTest(size_t bytes_read, size_t* next_read_size_hint) {
  34. return OnReadComplete(bytes_read, next_read_size_hint);
  35. }
  36. MOCK_METHOD7(GetReadPlatformHandles,
  37. bool(const void* payload,
  38. size_t payload_size,
  39. size_t num_handles,
  40. const void* extra_header,
  41. size_t extra_header_size,
  42. std::vector<PlatformHandle>* handles,
  43. bool* deferred));
  44. MOCK_METHOD2(GetReadPlatformHandlesForIpcz,
  45. bool(size_t, std::vector<PlatformHandle>&));
  46. MOCK_METHOD0(Start, void());
  47. MOCK_METHOD0(ShutDownImpl, void());
  48. MOCK_METHOD0(LeakHandle, void());
  49. void Write(MessagePtr message) override {}
  50. protected:
  51. ~TestChannel() override = default;
  52. };
  53. // Not using GMock as I don't think it supports movable types.
  54. class MockChannelDelegate : public Channel::Delegate {
  55. public:
  56. MockChannelDelegate() = default;
  57. size_t GetReceivedPayloadSize() const { return payload_size_; }
  58. const void* GetReceivedPayload() const { return payload_.get(); }
  59. protected:
  60. void OnChannelMessage(const void* payload,
  61. size_t payload_size,
  62. std::vector<PlatformHandle> handles) override {
  63. payload_.reset(new char[payload_size]);
  64. memcpy(payload_.get(), payload, payload_size);
  65. payload_size_ = payload_size;
  66. }
  67. // Notify that an error has occured and the Channel will cease operation.
  68. void OnChannelError(Channel::Error error) override {}
  69. private:
  70. size_t payload_size_ = 0;
  71. std::unique_ptr<char[]> payload_;
  72. };
  73. Channel::MessagePtr CreateDefaultMessage(bool legacy_message) {
  74. const size_t payload_size = 100;
  75. Channel::MessagePtr message = Channel::Message::CreateMessage(
  76. payload_size, 0,
  77. legacy_message ? Channel::Message::MessageType::NORMAL_LEGACY
  78. : Channel::Message::MessageType::NORMAL);
  79. char* payload = static_cast<char*>(message->mutable_payload());
  80. for (size_t i = 0; i < payload_size; i++) {
  81. payload[i] = static_cast<char>(i);
  82. }
  83. return message;
  84. }
  85. void TestMemoryEqual(const void* data1,
  86. size_t data1_size,
  87. const void* data2,
  88. size_t data2_size) {
  89. ASSERT_EQ(data1_size, data2_size);
  90. const unsigned char* data1_char = static_cast<const unsigned char*>(data1);
  91. const unsigned char* data2_char = static_cast<const unsigned char*>(data2);
  92. for (size_t i = 0; i < data1_size; i++) {
  93. // ASSERT so we don't log tons of errors if the data is different.
  94. ASSERT_EQ(data1_char[i], data2_char[i]);
  95. }
  96. }
  97. void TestMessagesAreEqual(Channel::Message* message1,
  98. Channel::Message* message2,
  99. bool legacy_messages) {
  100. // If any of the message is null, this is probably not what you wanted to
  101. // test.
  102. ASSERT_NE(nullptr, message1);
  103. ASSERT_NE(nullptr, message2);
  104. ASSERT_EQ(message1->payload_size(), message2->payload_size());
  105. EXPECT_EQ(message1->has_handles(), message2->has_handles());
  106. TestMemoryEqual(message1->payload(), message1->payload_size(),
  107. message2->payload(), message2->payload_size());
  108. if (legacy_messages)
  109. return;
  110. ASSERT_EQ(message1->extra_header_size(), message2->extra_header_size());
  111. TestMemoryEqual(message1->extra_header(), message1->extra_header_size(),
  112. message2->extra_header(), message2->extra_header_size());
  113. }
  114. TEST(ChannelTest, LegacyMessageDeserialization) {
  115. Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */);
  116. Channel::MessagePtr deserialized_message =
  117. Channel::Message::Deserialize(message->data(), message->data_num_bytes(),
  118. Channel::HandlePolicy::kAcceptHandles);
  119. TestMessagesAreEqual(message.get(), deserialized_message.get(),
  120. true /* legacy_message */);
  121. }
  122. TEST(ChannelTest, NonLegacyMessageDeserialization) {
  123. Channel::MessagePtr message =
  124. CreateDefaultMessage(false /* legacy_message */);
  125. Channel::MessagePtr deserialized_message =
  126. Channel::Message::Deserialize(message->data(), message->data_num_bytes(),
  127. Channel::HandlePolicy::kAcceptHandles);
  128. TestMessagesAreEqual(message.get(), deserialized_message.get(),
  129. false /* legacy_message */);
  130. }
  131. TEST(ChannelTest, OnReadLegacyMessage) {
  132. size_t buffer_size = 100 * 1024;
  133. Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */);
  134. MockChannelDelegate channel_delegate;
  135. scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate);
  136. char* read_buffer = channel->GetReadBufferTest(&buffer_size);
  137. ASSERT_LT(message->data_num_bytes(),
  138. buffer_size); // Bad test. Increase buffer
  139. // size.
  140. memcpy(read_buffer, message->data(), message->data_num_bytes());
  141. size_t next_read_size_hint = 0;
  142. EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(),
  143. &next_read_size_hint));
  144. TestMemoryEqual(message->payload(), message->payload_size(),
  145. channel_delegate.GetReceivedPayload(),
  146. channel_delegate.GetReceivedPayloadSize());
  147. }
  148. TEST(ChannelTest, OnReadNonLegacyMessage) {
  149. size_t buffer_size = 100 * 1024;
  150. Channel::MessagePtr message =
  151. CreateDefaultMessage(false /* legacy_message */);
  152. MockChannelDelegate channel_delegate;
  153. scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate);
  154. char* read_buffer = channel->GetReadBufferTest(&buffer_size);
  155. ASSERT_LT(message->data_num_bytes(),
  156. buffer_size); // Bad test. Increase buffer
  157. // size.
  158. memcpy(read_buffer, message->data(), message->data_num_bytes());
  159. size_t next_read_size_hint = 0;
  160. EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(),
  161. &next_read_size_hint));
  162. TestMemoryEqual(message->payload(), message->payload_size(),
  163. channel_delegate.GetReceivedPayload(),
  164. channel_delegate.GetReceivedPayloadSize());
  165. }
  166. class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate {
  167. public:
  168. ChannelTestShutdownAndWriteDelegate(
  169. PlatformChannelEndpoint endpoint,
  170. scoped_refptr<base::SingleThreadTaskRunner> task_runner,
  171. scoped_refptr<Channel> client_channel,
  172. std::unique_ptr<base::Thread> client_thread,
  173. base::RepeatingClosure quit_closure)
  174. : quit_closure_(std::move(quit_closure)),
  175. client_channel_(std::move(client_channel)),
  176. client_thread_(std::move(client_thread)) {
  177. channel_ = Channel::Create(this, ConnectionParams(std::move(endpoint)),
  178. Channel::HandlePolicy::kAcceptHandles,
  179. std::move(task_runner));
  180. channel_->Start();
  181. }
  182. ~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); }
  183. // Channel::Delegate implementation
  184. void OnChannelMessage(const void* payload,
  185. size_t payload_size,
  186. std::vector<PlatformHandle> handles) override {
  187. ++message_count_;
  188. // If |client_channel_| exists then close it and its thread.
  189. if (client_channel_) {
  190. // Write a fresh message, making our channel readable again.
  191. Channel::MessagePtr message = CreateDefaultMessage(false);
  192. client_thread_->task_runner()->PostTask(
  193. FROM_HERE,
  194. base::BindOnce(&Channel::Write, client_channel_, std::move(message)));
  195. // Close the channel and wait for it to shutdown.
  196. client_channel_->ShutDown();
  197. client_channel_ = nullptr;
  198. client_thread_->Stop();
  199. client_thread_ = nullptr;
  200. }
  201. // Write a message to the channel, to verify whether this triggers an
  202. // OnChannelError callback before all messages were read.
  203. Channel::MessagePtr message = CreateDefaultMessage(false);
  204. channel_->Write(std::move(message));
  205. }
  206. void OnChannelError(Channel::Error error) override {
  207. EXPECT_EQ(2, message_count_);
  208. quit_closure_.Run();
  209. }
  210. base::RepeatingClosure quit_closure_;
  211. int message_count_ = 0;
  212. scoped_refptr<Channel> channel_;
  213. scoped_refptr<Channel> client_channel_;
  214. std::unique_ptr<base::Thread> client_thread_;
  215. };
  // Sets up the following sequence:
  // 1. A client Channel, on its own IO thread, writes one message to a server
  //    Channel on the main thread.
  // 2. In its first OnChannelMessage(), the server delegate
  //    (ChannelTestShutdownAndWriteDelegate) queues a second client write and
  //    immediately shuts the client down.
  // 3. The delegate then expects the server to have read both messages by the
  //    time OnChannelError() is reported.
  TEST(ChannelTest, PeerShutdownDuringRead) {
    base::test::SingleThreadTaskEnvironment task_environment(
        base::test::TaskEnvironment::MainThreadType::IO);
    PlatformChannel channel;

    // Create a "client" Channel with one end of the pipe, and Start() it.
    // It has no delegate (nullptr) and runs on a dedicated IO thread.
    std::unique_ptr<base::Thread> client_thread =
        std::make_unique<base::Thread>("clientio_thread");
    client_thread->StartWithOptions(
        base::Thread::Options(base::MessagePumpType::IO, 0));
    scoped_refptr<Channel> client_channel = Channel::Create(
        nullptr, ConnectionParams(channel.TakeRemoteEndpoint()),
        Channel::HandlePolicy::kAcceptHandles, client_thread->task_runner());
    client_channel->Start();

    // On the "client" IO thread, create and write a message.
    Channel::MessagePtr message = CreateDefaultMessage(false);
    client_thread->task_runner()->PostTask(
        FROM_HERE,
        base::BindOnce(&Channel::Write, client_channel, std::move(message)));

    // Create a "server" Channel with the other end of the pipe, and process the
    // messages from it. The |server_delegate| takes ownership of the client
    // channel and thread, shuts the client end of the pipe down after the
    // first message, and quits the RunLoop when OnChannelError is received.
    base::RunLoop run_loop;
    ChannelTestShutdownAndWriteDelegate server_delegate(
        channel.TakeLocalEndpoint(), base::ThreadTaskRunnerHandle::Get(),
        std::move(client_channel), std::move(client_thread),
        run_loop.QuitClosure());
    run_loop.Run();
  }
  245. class RejectHandlesDelegate : public Channel::Delegate {
  246. public:
  247. RejectHandlesDelegate() = default;
  248. RejectHandlesDelegate(const RejectHandlesDelegate&) = delete;
  249. RejectHandlesDelegate& operator=(const RejectHandlesDelegate&) = delete;
  250. size_t num_messages() const { return num_messages_; }
  251. // Channel::Delegate:
  252. void OnChannelMessage(const void* payload,
  253. size_t payload_size,
  254. std::vector<PlatformHandle> handles) override {
  255. ++num_messages_;
  256. }
  257. void OnChannelError(Channel::Error error) override {
  258. if (wait_for_error_loop_)
  259. wait_for_error_loop_->Quit();
  260. }
  261. void WaitForError() {
  262. wait_for_error_loop_.emplace();
  263. wait_for_error_loop_->Run();
  264. }
  265. private:
  266. size_t num_messages_ = 0;
  267. absl::optional<base::RunLoop> wait_for_error_loop_;
  268. };
  269. TEST(ChannelTest, RejectHandles) {
  270. base::test::SingleThreadTaskEnvironment task_environment(
  271. base::test::TaskEnvironment::MainThreadType::IO);
  272. PlatformChannel platform_channel;
  273. RejectHandlesDelegate receiver_delegate;
  274. scoped_refptr<Channel> receiver =
  275. Channel::Create(&receiver_delegate,
  276. ConnectionParams(platform_channel.TakeLocalEndpoint()),
  277. Channel::HandlePolicy::kRejectHandles,
  278. base::ThreadTaskRunnerHandle::Get());
  279. receiver->Start();
  280. RejectHandlesDelegate sender_delegate;
  281. scoped_refptr<Channel> sender = Channel::Create(
  282. &sender_delegate, ConnectionParams(platform_channel.TakeRemoteEndpoint()),
  283. Channel::HandlePolicy::kRejectHandles,
  284. base::ThreadTaskRunnerHandle::Get());
  285. sender->Start();
  286. // Create another platform channel just to stuff one of its endpoint handles
  287. // into a message. Sending this message to the receiver should cause the
  288. // receiver to reject it and close the Channel without ever dispatching the
  289. // message.
  290. PlatformChannel dummy_channel;
  291. std::vector<mojo::PlatformHandle> handles;
  292. handles.push_back(dummy_channel.TakeLocalEndpoint().TakePlatformHandle());
  293. auto message = Channel::Message::CreateMessage(0 /* payload_size */,
  294. 1 /* max_handles */);
  295. message->SetHandles(std::move(handles));
  296. sender->Write(std::move(message));
  297. receiver_delegate.WaitForError();
  298. EXPECT_EQ(0u, receiver_delegate.num_messages());
  299. }
  300. TEST(ChannelTest, DeserializeMessage_BadExtraHeaderSize) {
  301. // Verifies that a message payload is rejected when the extra header chunk
  302. // size not properly aligned.
  303. constexpr uint16_t kBadAlignment = kChannelMessageAlignment + 1;
  304. constexpr uint16_t kTotalHeaderSize =
  305. sizeof(Channel::Message::Header) + kBadAlignment;
  306. constexpr uint32_t kEmptyPayloadSize = 8;
  307. constexpr uint32_t kMessageSize = kTotalHeaderSize + kEmptyPayloadSize;
  308. char message[kMessageSize];
  309. memset(message, 0, kMessageSize);
  310. Channel::Message::Header* header =
  311. reinterpret_cast<Channel::Message::Header*>(&message[0]);
  312. header->num_bytes = kMessageSize;
  313. header->num_header_bytes = kTotalHeaderSize;
  314. header->message_type = Channel::Message::MessageType::NORMAL;
  315. header->num_handles = 0;
  316. EXPECT_EQ(nullptr,
  317. Channel::Message::Deserialize(&message[0], kMessageSize,
  318. Channel::HandlePolicy::kAcceptHandles,
  319. base::kNullProcessHandle));
  320. }
  321. // This test is only enabled for Linux-based platforms.
  322. #if !BUILDFLAG(IS_WIN) && !BUILDFLAG(IS_APPLE) && !BUILDFLAG(IS_FUCHSIA)
  323. TEST(ChannelTest, DeserializeMessage_NonZeroExtraHeaderSize) {
  324. // Verifies that a message payload is rejected when the extra header chunk
  325. // size anything but zero on Linux, even if it's aligned.
  326. constexpr uint16_t kTotalHeaderSize =
  327. sizeof(Channel::Message::Header) + kChannelMessageAlignment;
  328. constexpr uint32_t kEmptyPayloadSize = 8;
  329. constexpr uint32_t kMessageSize = kTotalHeaderSize + kEmptyPayloadSize;
  330. char message[kMessageSize];
  331. memset(message, 0, kMessageSize);
  332. Channel::Message::Header* header =
  333. reinterpret_cast<Channel::Message::Header*>(&message[0]);
  334. header->num_bytes = kMessageSize;
  335. header->num_header_bytes = kTotalHeaderSize;
  336. header->message_type = Channel::Message::MessageType::NORMAL;
  337. header->num_handles = 0;
  338. EXPECT_EQ(nullptr,
  339. Channel::Message::Deserialize(&message[0], kMessageSize,
  340. Channel::HandlePolicy::kAcceptHandles,
  341. base::kNullProcessHandle));
  342. }
  343. #endif
  344. class CountingChannelDelegate : public Channel::Delegate {
  345. public:
  346. explicit CountingChannelDelegate(base::OnceClosure on_final_message)
  347. : on_final_message_(std::move(on_final_message)) {}
  348. ~CountingChannelDelegate() override = default;
  349. void OnChannelMessage(const void* payload,
  350. size_t payload_size,
  351. std::vector<PlatformHandle> handles) override {
  352. // If this is the special "final message", run the closure.
  353. if (payload_size == 1) {
  354. auto* payload_str = reinterpret_cast<const char*>(payload);
  355. if (payload_str[0] == '!') {
  356. std::move(on_final_message_).Run();
  357. return;
  358. }
  359. }
  360. ++message_count_;
  361. }
  362. void OnChannelError(Channel::Error error) override { ++error_count_; }
  363. size_t message_count_ = 0;
  364. size_t error_count_ = 0;
  365. private:
  366. base::OnceClosure on_final_message_;
  367. };
  368. TEST(ChannelTest, PeerStressTest) {
  369. constexpr size_t kLotsOfMessages = 1024;
  370. base::test::SingleThreadTaskEnvironment task_environment(
  371. base::test::TaskEnvironment::MainThreadType::IO);
  372. base::RunLoop run_loop;
  373. // Both channels should receive all the messages that each is sent. When
  374. // the count becomes 2 (indicating both channels have received the final
  375. // message), quit the main test thread's run loop.
  376. std::atomic_int count_channels_received_final_message(0);
  377. auto quit_when_both_channels_received_final_message = base::BindRepeating(
  378. [](std::atomic_int* count_channels_received_final_message,
  379. base::OnceClosure quit_closure) {
  380. if (++(*count_channels_received_final_message) == 2) {
  381. std::move(quit_closure).Run();
  382. }
  383. },
  384. base::Unretained(&count_channels_received_final_message),
  385. run_loop.QuitClosure());
  386. // Create a second IO thread for the peer channel.
  387. base::Thread::Options thread_options;
  388. thread_options.message_pump_type = base::MessagePumpType::IO;
  389. base::Thread peer_thread("peer_b_io");
  390. peer_thread.StartWithOptions(std::move(thread_options));
  391. // Create two channels that run on separate threads.
  392. PlatformChannel platform_channel;
  393. CountingChannelDelegate delegate_a(
  394. quit_when_both_channels_received_final_message);
  395. scoped_refptr<Channel> channel_a = Channel::Create(
  396. &delegate_a, ConnectionParams(platform_channel.TakeLocalEndpoint()),
  397. Channel::HandlePolicy::kRejectHandles,
  398. base::ThreadTaskRunnerHandle::Get());
  399. CountingChannelDelegate delegate_b(
  400. quit_when_both_channels_received_final_message);
  401. scoped_refptr<Channel> channel_b = Channel::Create(
  402. &delegate_b, ConnectionParams(platform_channel.TakeRemoteEndpoint()),
  403. Channel::HandlePolicy::kRejectHandles, peer_thread.task_runner());
  404. // Send a lot of messages, followed by a final terminating message.
  405. auto send_lots_of_messages = [](scoped_refptr<Channel> channel) {
  406. for (size_t i = 0; i < kLotsOfMessages; ++i) {
  407. channel->Write(Channel::Message::CreateMessage(0, 0));
  408. }
  409. };
  410. auto send_final_message = [](scoped_refptr<Channel> channel) {
  411. auto message = Channel::Message::CreateMessage(1, 0);
  412. auto* payload = static_cast<char*>(message->mutable_payload());
  413. payload[0] = '!';
  414. channel->Write(std::move(message));
  415. };
  416. channel_a->Start();
  417. channel_b->Start();
  418. send_lots_of_messages(channel_a);
  419. send_lots_of_messages(channel_b);
  420. base::ThreadTaskRunnerHandle::Get()->PostTask(
  421. FROM_HERE, base::BindOnce(send_lots_of_messages, channel_a));
  422. base::ThreadTaskRunnerHandle::Get()->PostTask(
  423. FROM_HERE, base::BindOnce(send_lots_of_messages, channel_a));
  424. base::ThreadTaskRunnerHandle::Get()->PostTask(
  425. FROM_HERE, base::BindOnce(send_final_message, channel_a));
  426. peer_thread.task_runner()->PostTask(
  427. FROM_HERE, base::BindOnce(send_lots_of_messages, channel_b));
  428. peer_thread.task_runner()->PostTask(
  429. FROM_HERE, base::BindOnce(send_lots_of_messages, channel_b));
  430. peer_thread.task_runner()->PostTask(
  431. FROM_HERE, base::BindOnce(send_final_message, channel_b));
  432. // Run until quit_when_both_channels_received_final_message quits the loop.
  433. run_loop.Run();
  434. channel_a->ShutDown();
  435. channel_b->ShutDown();
  436. peer_thread.StopSoon();
  437. base::RunLoop().RunUntilIdle();
  438. EXPECT_EQ(kLotsOfMessages * 3, delegate_a.message_count_);
  439. EXPECT_EQ(kLotsOfMessages * 3, delegate_b.message_count_);
  440. EXPECT_EQ(0u, delegate_a.error_count_);
  441. EXPECT_EQ(0u, delegate_b.error_count_);
  442. }
  443. class CallbackChannelDelegate : public Channel::Delegate {
  444. public:
  445. CallbackChannelDelegate() = default;
  446. CallbackChannelDelegate(const CallbackChannelDelegate&) = delete;
  447. CallbackChannelDelegate& operator=(const CallbackChannelDelegate&) = delete;
  448. void OnChannelMessage(const void* payload,
  449. size_t payload_size,
  450. std::vector<PlatformHandle> handles) override {
  451. if (on_message_)
  452. std::move(on_message_).Run();
  453. }
  454. void OnChannelError(Channel::Error error) override {
  455. if (on_error_)
  456. std::move(on_error_).Run();
  457. }
  458. void set_on_message(base::OnceClosure on_message) {
  459. on_message_ = std::move(on_message);
  460. }
  461. void set_on_error(base::OnceClosure on_error) {
  462. on_error_ = std::move(on_error);
  463. }
  464. private:
  465. base::OnceClosure on_message_;
  466. base::OnceClosure on_error_;
  467. };
  468. TEST(ChannelTest, MessageSizeTest) {
  469. base::test::SingleThreadTaskEnvironment task_environment(
  470. base::test::TaskEnvironment::MainThreadType::IO);
  471. PlatformChannel platform_channel;
  472. CallbackChannelDelegate receiver_delegate;
  473. scoped_refptr<Channel> receiver =
  474. Channel::Create(&receiver_delegate,
  475. ConnectionParams(platform_channel.TakeLocalEndpoint()),
  476. Channel::HandlePolicy::kAcceptHandles,
  477. base::ThreadTaskRunnerHandle::Get());
  478. receiver->Start();
  479. MockChannelDelegate sender_delegate;
  480. scoped_refptr<Channel> sender = Channel::Create(
  481. &sender_delegate, ConnectionParams(platform_channel.TakeRemoteEndpoint()),
  482. Channel::HandlePolicy::kAcceptHandles,
  483. base::ThreadTaskRunnerHandle::Get());
  484. sender->Start();
  485. for (uint32_t i = 0; i < base::GetPageSize() * 4; ++i) {
  486. SCOPED_TRACE(base::StringPrintf("message size %d", i));
  487. auto message = Channel::Message::CreateMessage(i, 0);
  488. memset(message->mutable_payload(), 0xAB, i);
  489. sender->Write(std::move(message));
  490. bool got_message = false, got_error = false;
  491. base::RunLoop loop;
  492. receiver_delegate.set_on_message(
  493. base::BindLambdaForTesting([&got_message, &loop]() {
  494. got_message = true;
  495. loop.Quit();
  496. }));
  497. receiver_delegate.set_on_error(
  498. base::BindLambdaForTesting([&got_error, &loop]() {
  499. got_error = true;
  500. loop.Quit();
  501. }));
  502. loop.Run();
  503. EXPECT_TRUE(got_message);
  504. EXPECT_FALSE(got_error);
  505. }
  506. }
  #if BUILDFLAG(IS_MAC)
  // Channel A holds a Mach send right whose receive right belongs to Channel B.
  // After B shuts down (turning A's send right into a dead name), a write on A
  // must surface OnChannelError(). After A shuts down, the send-right count for
  // the name is expected to be 0 and the dead-name count 1.
  // NOTE(review): base::WaitableEvent is used below but
  // "base/synchronization/waitable_event.h" is not among the visible includes;
  // presumably it arrives transitively. Consider including it directly.
  TEST(ChannelTest, SendToDeadMachPortName) {
    base::test::SingleThreadTaskEnvironment task_environment(
        base::test::TaskEnvironment::MainThreadType::IO);

    // Create a second IO thread for the B channel. It needs to process tasks
    // separately from channel A.
    base::Thread::Options thread_options;
    thread_options.message_pump_type = base::MessagePumpType::IO;
    base::Thread peer_thread("channel_b_io");
    peer_thread.StartWithOptions(std::move(thread_options));

    // Create a PlatformChannel send/receive right pair.
    PlatformChannel platform_channel;

    mach_port_urefs_t send = 0, dead = 0;
    mach_port_t send_name = platform_channel.local_endpoint()
                                .platform_handle()
                                .GetMachSendRight()
                                .get();
    // Refreshes |send| and |dead| with this task's current user-reference
    // counts for the send right and the dead-name right of |send_name|.
    auto get_send_name_refs = [&send, &dead, send_name]() {
      kern_return_t kr = mach_port_get_refs(mach_task_self(), send_name,
                                            MACH_PORT_RIGHT_SEND, &send);
      ASSERT_EQ(kr, KERN_SUCCESS);
      kr = mach_port_get_refs(mach_task_self(), send_name,
                              MACH_PORT_RIGHT_DEAD_NAME, &dead);
      ASSERT_EQ(kr, KERN_SUCCESS);
    };
    get_send_name_refs();
    EXPECT_EQ(1u, send);
    EXPECT_EQ(0u, dead);

    // Add an extra send right.
    ASSERT_EQ(KERN_SUCCESS, mach_port_mod_refs(mach_task_self(), send_name,
                                               MACH_PORT_RIGHT_SEND, 1));
    get_send_name_refs();
    EXPECT_EQ(2u, send);
    EXPECT_EQ(0u, dead);
    // Hand the extra reference to a scoped owner so the test does not leak it
    // (presumably released when |extra_send| goes out of scope).
    base::mac::ScopedMachSendRight extra_send(send_name);

    // Channel A gets created with the Mach send right from |platform_channel|.
    CallbackChannelDelegate delegate_a;
    scoped_refptr<Channel> channel_a = Channel::Create(
        &delegate_a, ConnectionParams(platform_channel.TakeLocalEndpoint()),
        Channel::HandlePolicy::kAcceptHandles,
        base::ThreadTaskRunnerHandle::Get());
    channel_a->Start();

    // Channel B gets the receive right.
    MockChannelDelegate delegate_b;
    scoped_refptr<Channel> channel_b = Channel::Create(
        &delegate_b, ConnectionParams(platform_channel.TakeRemoteEndpoint()),
        Channel::HandlePolicy::kAcceptHandles, peer_thread.task_runner());
    channel_b->Start();

    // Ensure the channels have started and are talking: wait until A receives
    // one message from B.
    channel_b->Write(Channel::Message::CreateMessage(0, 0));
    {
      base::RunLoop loop;
      delegate_a.set_on_message(loop.QuitClosure());
      loop.Run();
    }

    // Queue two messages from B to A. Two are required so that channel A does
    // not immediately process the dead-name notification when channel B shuts
    // down.
    channel_b->Write(Channel::Message::CreateMessage(0, 0));
    channel_b->Write(Channel::Message::CreateMessage(0, 0));

    // Turn Channel A's send right into a dead name.
    channel_b->ShutDown();
    channel_b = nullptr;

    // ShutDown() posts a task on the channel's TaskRunner, so wait for that
    // to run. The signal task is queued behind it on |peer_thread|.
    base::WaitableEvent event;
    peer_thread.task_runner()->PostTask(
        FROM_HERE,
        base::BindOnce(&base::WaitableEvent::Signal, base::Unretained(&event)));
    event.Wait();

    // Force a send-to-dead-name on Channel A. The resulting error shuts A down
    // and quits the loop once idle.
    channel_a->Write(Channel::Message::CreateMessage(0, 0));
    {
      base::RunLoop loop;
      delegate_a.set_on_error(base::BindOnce(
          [](scoped_refptr<Channel> channel, base::RunLoop* loop) {
            channel->ShutDown();
            channel = nullptr;
            loop->QuitWhenIdle();
          },
          channel_a, base::Unretained(&loop)));
      loop.Run();
    }

    // The only remaining ref should be the extra one that was added in the
    // test, which is now a dead name rather than a send right.
    get_send_name_refs();
    EXPECT_EQ(0u, send);
    EXPECT_EQ(1u, dead);
  }
  #endif  // BUILDFLAG(IS_MAC)
  596. } // namespace
  597. } // namespace core
  598. } // namespace mojo