address_sorter_win.cc 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/dns/address_sorter.h"
  5. #include <winsock2.h>
  6. #include <algorithm>
  7. #include <utility>
  8. #include <vector>
  9. #include "base/bind.h"
  10. #include "base/location.h"
  11. #include "base/logging.h"
  12. #include "base/memory/free_deleter.h"
  13. #include "base/task/thread_pool.h"
  14. #include "net/base/ip_address.h"
  15. #include "net/base/ip_endpoint.h"
  16. #include "net/base/winsock_init.h"
  17. namespace net {
  18. namespace {
  19. class AddressSorterWin : public AddressSorter {
  20. public:
  21. AddressSorterWin() {
  22. EnsureWinsockInit();
  23. }
  24. AddressSorterWin(const AddressSorterWin&) = delete;
  25. AddressSorterWin& operator=(const AddressSorterWin&) = delete;
  26. ~AddressSorterWin() override {}
  27. // AddressSorter:
  28. void Sort(const std::vector<IPEndPoint>& endpoints,
  29. CallbackType callback) const override {
  30. DCHECK(!endpoints.empty());
  31. Job::Start(endpoints, std::move(callback));
  32. }
  33. private:
  34. // Executes the SIO_ADDRESS_LIST_SORT ioctl asynchronously, and
  35. // performs the necessary conversions to/from `std::vector<IPEndPoint>`.
  36. class Job : public base::RefCountedThreadSafe<Job> {
  37. public:
  38. static void Start(const std::vector<IPEndPoint>& endpoints,
  39. CallbackType callback) {
  40. auto job = base::WrapRefCounted(new Job(endpoints, std::move(callback)));
  41. base::ThreadPool::PostTaskAndReply(
  42. FROM_HERE,
  43. {base::MayBlock(), base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN},
  44. base::BindOnce(&Job::Run, job),
  45. base::BindOnce(&Job::OnComplete, job));
  46. }
  47. Job(const Job&) = delete;
  48. Job& operator=(const Job&) = delete;
  49. private:
  50. friend class base::RefCountedThreadSafe<Job>;
  51. Job(const std::vector<IPEndPoint>& endpoints, CallbackType callback)
  52. : callback_(std::move(callback)),
  53. buffer_size_((sizeof(SOCKET_ADDRESS_LIST) +
  54. base::CheckedNumeric<DWORD>(endpoints.size()) *
  55. (sizeof(SOCKET_ADDRESS) + sizeof(SOCKADDR_STORAGE)))
  56. .ValueOrDie<DWORD>()),
  57. input_buffer_(
  58. reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))),
  59. output_buffer_(
  60. reinterpret_cast<SOCKET_ADDRESS_LIST*>(malloc(buffer_size_))) {
  61. input_buffer_->iAddressCount = base::checked_cast<INT>(endpoints.size());
  62. SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
  63. input_buffer_->Address + input_buffer_->iAddressCount);
  64. for (size_t i = 0; i < endpoints.size(); ++i) {
  65. IPEndPoint ipe = endpoints[i];
  66. // Addresses must be sockaddr_in6.
  67. if (ipe.address().IsIPv4()) {
  68. ipe = IPEndPoint(ConvertIPv4ToIPv4MappedIPv6(ipe.address()),
  69. ipe.port());
  70. }
  71. struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
  72. socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
  73. bool result = ipe.ToSockAddr(addr, &addr_len);
  74. DCHECK(result);
  75. input_buffer_->Address[i].lpSockaddr = addr;
  76. input_buffer_->Address[i].iSockaddrLength = addr_len;
  77. }
  78. }
  79. ~Job() {}
  80. // Executed asynchronously in ThreadPool.
  81. void Run() {
  82. SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
  83. if (sock == INVALID_SOCKET)
  84. return;
  85. DWORD result_size = 0;
  86. int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
  87. buffer_size_, output_buffer_.get(), buffer_size_,
  88. &result_size, nullptr, nullptr);
  89. if (result == SOCKET_ERROR) {
  90. LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
  91. } else {
  92. success_ = true;
  93. }
  94. closesocket(sock);
  95. }
  96. // Executed on the calling thread.
  97. void OnComplete() {
  98. std::vector<IPEndPoint> sorted;
  99. if (success_) {
  100. sorted.reserve(output_buffer_->iAddressCount);
  101. for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
  102. IPEndPoint ipe;
  103. bool result =
  104. ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
  105. output_buffer_->Address[i].iSockaddrLength);
  106. DCHECK(result) << "Unable to roundtrip between IPEndPoint and "
  107. << "SOCKET_ADDRESS!";
  108. // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
  109. if (ipe.address().IsIPv4MappedIPv6()) {
  110. ipe = IPEndPoint(ConvertIPv4MappedIPv6ToIPv4(ipe.address()),
  111. ipe.port());
  112. }
  113. sorted.push_back(ipe);
  114. }
  115. }
  116. std::move(callback_).Run(success_, std::move(sorted));
  117. }
  118. CallbackType callback_;
  119. const DWORD buffer_size_;
  120. std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> input_buffer_;
  121. std::unique_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> output_buffer_;
  122. bool success_ = false;
  123. };
  124. };
  125. } // namespace
  126. // static
  127. std::unique_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
  128. return std::make_unique<AddressSorterWin>();
  129. }
  130. } // namespace net