broker_posix.cc 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. // Copyright 2016 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "mojo/core/broker.h"
  5. #include <fcntl.h>
  6. #include <unistd.h>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/logging.h"
  10. #include "base/memory/platform_shared_memory_region.h"
  11. #include "build/build_config.h"
  12. #include "mojo/core/broker_messages.h"
  13. #include "mojo/core/channel.h"
  14. #include "mojo/core/platform_handle_utils.h"
  15. #include "mojo/public/cpp/platform/socket_utils_posix.h"
  16. namespace mojo {
  17. namespace core {
  18. namespace {
  19. Channel::MessagePtr WaitForBrokerMessage(
  20. int socket_fd,
  21. BrokerMessageType expected_type,
  22. size_t expected_num_handles,
  23. size_t expected_data_size,
  24. std::vector<PlatformHandle>* incoming_handles) {
  25. Channel::MessagePtr message = Channel::Message::CreateMessage(
  26. sizeof(BrokerMessageHeader) + expected_data_size, expected_num_handles);
  27. std::vector<base::ScopedFD> incoming_fds;
  28. ssize_t read_result =
  29. SocketRecvmsg(socket_fd, const_cast<void*>(message->data()),
  30. message->data_num_bytes(), &incoming_fds, true /* block */);
  31. bool error = false;
  32. if (read_result < 0) {
  33. PLOG(ERROR) << "Recvmsg error";
  34. error = true;
  35. } else if (static_cast<size_t>(read_result) != message->data_num_bytes()) {
  36. LOG(ERROR) << "Invalid node channel message";
  37. error = true;
  38. } else if (incoming_fds.size() != expected_num_handles) {
  39. DLOG(ERROR) << "Received unexpected number of handles";
  40. error = true;
  41. }
  42. if (error)
  43. return nullptr;
  44. const BrokerMessageHeader* header =
  45. reinterpret_cast<const BrokerMessageHeader*>(message->payload());
  46. if (header->type != expected_type) {
  47. LOG(ERROR) << "Unexpected message";
  48. return nullptr;
  49. }
  50. incoming_handles->reserve(incoming_fds.size());
  51. for (size_t i = 0; i < incoming_fds.size(); ++i)
  52. incoming_handles->emplace_back(std::move(incoming_fds[i]));
  53. return message;
  54. }
  55. } // namespace
  56. Broker::Broker(PlatformHandle handle, bool wait_for_channel_handle)
  57. : sync_channel_(std::move(handle)) {
  58. CHECK(sync_channel_.is_valid());
  59. int fd = sync_channel_.GetFD().get();
  60. // Mark the channel as blocking.
  61. int flags = fcntl(fd, F_GETFL);
  62. PCHECK(flags != -1);
  63. flags = fcntl(fd, F_SETFL, flags & ~O_NONBLOCK);
  64. PCHECK(flags != -1);
  65. if (!wait_for_channel_handle)
  66. return;
  67. // Wait for the first message, which should contain a handle.
  68. std::vector<PlatformHandle> incoming_platform_handles;
  69. if (WaitForBrokerMessage(fd, BrokerMessageType::INIT, 1, 0,
  70. &incoming_platform_handles)) {
  71. inviter_endpoint_ =
  72. PlatformChannelEndpoint(std::move(incoming_platform_handles[0]));
  73. }
  74. }
  75. Broker::~Broker() = default;
  76. PlatformChannelEndpoint Broker::GetInviterEndpoint() {
  77. return std::move(inviter_endpoint_);
  78. }
  79. base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
  80. size_t num_bytes) {
  81. base::AutoLock lock(lock_);
  82. BufferRequestData* buffer_request;
  83. Channel::MessagePtr out_message = CreateBrokerMessage(
  84. BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
  85. buffer_request->size = num_bytes;
  86. ssize_t write_result =
  87. SocketWrite(sync_channel_.GetFD().get(), out_message->data(),
  88. out_message->data_num_bytes());
  89. if (write_result < 0) {
  90. PLOG(ERROR) << "Error sending sync broker message";
  91. return base::WritableSharedMemoryRegion();
  92. } else if (static_cast<size_t>(write_result) !=
  93. out_message->data_num_bytes()) {
  94. LOG(ERROR) << "Error sending complete broker message";
  95. return base::WritableSharedMemoryRegion();
  96. }
  97. #if !BUILDFLAG(IS_POSIX) || BUILDFLAG(IS_ANDROID) || BUILDFLAG(IS_MAC)
  98. // Non-POSIX systems, as well as Android and Mac, only use a single handle to
  99. // represent a writable region.
  100. constexpr size_t kNumExpectedHandles = 1;
  101. #else
  102. constexpr size_t kNumExpectedHandles = 2;
  103. #endif
  104. std::vector<PlatformHandle> handles;
  105. Channel::MessagePtr message = WaitForBrokerMessage(
  106. sync_channel_.GetFD().get(), BrokerMessageType::BUFFER_RESPONSE,
  107. kNumExpectedHandles, sizeof(BufferResponseData), &handles);
  108. if (message) {
  109. const BufferResponseData* data;
  110. if (!GetBrokerMessageData(message.get(), &data))
  111. return base::WritableSharedMemoryRegion();
  112. if (handles.size() == 1)
  113. handles.emplace_back();
  114. return base::WritableSharedMemoryRegion::Deserialize(
  115. base::subtle::PlatformSharedMemoryRegion::Take(
  116. CreateSharedMemoryRegionHandleFromPlatformHandles(
  117. std::move(handles[0]), std::move(handles[1])),
  118. base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
  119. num_bytes,
  120. base::UnguessableToken::Deserialize(data->guid_high,
  121. data->guid_low)));
  122. }
  123. return base::WritableSharedMemoryRegion();
  124. }
  125. } // namespace core
  126. } // namespace mojo