mdns_client_impl.cc 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. // Copyright 2013 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/dns/mdns_client_impl.h"
  5. #include <algorithm>
  6. #include <memory>
  7. #include <utility>
  8. #include "base/bind.h"
  9. #include "base/location.h"
  10. #include "base/observer_list.h"
  11. #include "base/strings/string_util.h"
  12. #include "base/task/single_thread_task_runner.h"
  13. #include "base/threading/thread_task_runner_handle.h"
  14. #include "base/time/clock.h"
  15. #include "base/time/default_clock.h"
  16. #include "base/time/time.h"
  17. #include "base/timer/timer.h"
  18. #include "net/base/net_errors.h"
  19. #include "net/base/rand_callback.h"
  20. #include "net/dns/dns_util.h"
  21. #include "net/dns/public/dns_protocol.h"
  22. #include "net/dns/public/util.h"
  23. #include "net/dns/record_rdata.h"
  24. #include "net/socket/datagram_socket.h"
  25. // TODO(gene): Remove this temporary method of disabling NSEC support once it
  26. // becomes clear whether this feature should be
  27. // supported. http://crbug.com/255232
  28. #define ENABLE_NSEC
  29. namespace net {
  30. namespace {
  31. // The fractions of the record's original TTL after which an active listener
  32. // (one that had |SetActiveRefresh(true)| called) will send a query to refresh
  33. // its cache. This happens both at 85% of the original TTL and again at 95% of
  34. // the original TTL.
  35. const double kListenerRefreshRatio1 = 0.85;
  36. const double kListenerRefreshRatio2 = 0.95;
  37. } // namespace
  38. void MDnsSocketFactoryImpl::CreateSockets(
  39. std::vector<std::unique_ptr<DatagramServerSocket>>* sockets) {
  40. InterfaceIndexFamilyList interfaces(GetMDnsInterfacesToBind());
  41. for (const auto& interface : interfaces) {
  42. DCHECK(interface.second == ADDRESS_FAMILY_IPV4 ||
  43. interface.second == ADDRESS_FAMILY_IPV6);
  44. std::unique_ptr<DatagramServerSocket> socket(
  45. CreateAndBindMDnsSocket(interface.second, interface.first, net_log_));
  46. if (socket)
  47. sockets->push_back(std::move(socket));
  48. }
  49. }
  50. MDnsConnection::SocketHandler::SocketHandler(
  51. std::unique_ptr<DatagramServerSocket> socket,
  52. MDnsConnection* connection)
  53. : socket_(std::move(socket)),
  54. connection_(connection),
  55. response_(dns_protocol::kMaxMulticastSize) {}
  56. MDnsConnection::SocketHandler::~SocketHandler() = default;
  57. int MDnsConnection::SocketHandler::Start() {
  58. IPEndPoint end_point;
  59. int rv = socket_->GetLocalAddress(&end_point);
  60. if (rv != OK)
  61. return rv;
  62. DCHECK(end_point.GetFamily() == ADDRESS_FAMILY_IPV4 ||
  63. end_point.GetFamily() == ADDRESS_FAMILY_IPV6);
  64. multicast_addr_ = dns_util::GetMdnsGroupEndPoint(end_point.GetFamily());
  65. return DoLoop(0);
  66. }
  67. int MDnsConnection::SocketHandler::DoLoop(int rv) {
  68. do {
  69. if (rv > 0)
  70. connection_->OnDatagramReceived(&response_, recv_addr_, rv);
  71. rv = socket_->RecvFrom(
  72. response_.io_buffer(), response_.io_buffer_size(), &recv_addr_,
  73. base::BindOnce(&MDnsConnection::SocketHandler::OnDatagramReceived,
  74. base::Unretained(this)));
  75. } while (rv > 0);
  76. if (rv != ERR_IO_PENDING)
  77. return rv;
  78. return OK;
  79. }
  80. void MDnsConnection::SocketHandler::OnDatagramReceived(int rv) {
  81. if (rv >= OK)
  82. rv = DoLoop(rv);
  83. if (rv != OK)
  84. connection_->PostOnError(this, rv);
  85. }
  86. void MDnsConnection::SocketHandler::Send(const scoped_refptr<IOBuffer>& buffer,
  87. unsigned size) {
  88. if (send_in_progress_) {
  89. send_queue_.push(std::make_pair(buffer, size));
  90. return;
  91. }
  92. int rv =
  93. socket_->SendTo(buffer.get(), size, multicast_addr_,
  94. base::BindOnce(&MDnsConnection::SocketHandler::SendDone,
  95. base::Unretained(this)));
  96. if (rv == ERR_IO_PENDING) {
  97. send_in_progress_ = true;
  98. } else if (rv < OK) {
  99. connection_->PostOnError(this, rv);
  100. }
  101. }
  102. void MDnsConnection::SocketHandler::SendDone(int rv) {
  103. DCHECK(send_in_progress_);
  104. send_in_progress_ = false;
  105. if (rv != OK)
  106. connection_->PostOnError(this, rv);
  107. while (!send_in_progress_ && !send_queue_.empty()) {
  108. std::pair<scoped_refptr<IOBuffer>, unsigned> buffer = send_queue_.front();
  109. send_queue_.pop();
  110. Send(buffer.first, buffer.second);
  111. }
  112. }
  113. MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate)
  114. : delegate_(delegate) {}
  115. MDnsConnection::~MDnsConnection() = default;
  116. int MDnsConnection::Init(MDnsSocketFactory* socket_factory) {
  117. std::vector<std::unique_ptr<DatagramServerSocket>> sockets;
  118. socket_factory->CreateSockets(&sockets);
  119. for (std::unique_ptr<DatagramServerSocket>& socket : sockets) {
  120. socket_handlers_.push_back(std::make_unique<MDnsConnection::SocketHandler>(
  121. std::move(socket), this));
  122. }
  123. // All unbound sockets need to be bound before processing untrusted input.
  124. // This is done for security reasons, so that an attacker can't get an unbound
  125. // socket.
  126. int last_failure = ERR_FAILED;
  127. for (size_t i = 0; i < socket_handlers_.size();) {
  128. int rv = socket_handlers_[i]->Start();
  129. if (rv != OK) {
  130. last_failure = rv;
  131. socket_handlers_.erase(socket_handlers_.begin() + i);
  132. VLOG(1) << "Start failed, socket=" << i << ", error=" << rv;
  133. } else {
  134. ++i;
  135. }
  136. }
  137. VLOG(1) << "Sockets ready:" << socket_handlers_.size();
  138. DCHECK_NE(ERR_IO_PENDING, last_failure);
  139. return socket_handlers_.empty() ? last_failure : OK;
  140. }
  141. void MDnsConnection::Send(const scoped_refptr<IOBuffer>& buffer,
  142. unsigned size) {
  143. for (std::unique_ptr<SocketHandler>& handler : socket_handlers_)
  144. handler->Send(buffer, size);
  145. }
  146. void MDnsConnection::PostOnError(SocketHandler* loop, int rv) {
  147. int id = 0;
  148. for (const auto& it : socket_handlers_) {
  149. if (it.get() == loop)
  150. break;
  151. id++;
  152. }
  153. VLOG(1) << "Socket error. id=" << id << ", error=" << rv;
  154. // Post to allow deletion of this object by delegate.
  155. base::ThreadTaskRunnerHandle::Get()->PostTask(
  156. FROM_HERE, base::BindOnce(&MDnsConnection::OnError,
  157. weak_ptr_factory_.GetWeakPtr(), rv));
  158. }
  159. void MDnsConnection::OnError(int rv) {
  160. // TODO(noamsml): Specific handling of intermittent errors that can be handled
  161. // in the connection.
  162. delegate_->OnConnectionError(rv);
  163. }
  164. void MDnsConnection::OnDatagramReceived(
  165. DnsResponse* response,
  166. const IPEndPoint& recv_addr,
  167. int bytes_read) {
  168. // TODO(noamsml): More sophisticated error handling.
  169. DCHECK_GT(bytes_read, 0);
  170. delegate_->HandlePacket(response, bytes_read);
  171. }
  172. MDnsClientImpl::Core::Core(base::Clock* clock, base::OneShotTimer* timer)
  173. : clock_(clock),
  174. cleanup_timer_(timer),
  175. connection_(
  176. std::make_unique<MDnsConnection>((MDnsConnection::Delegate*)this)) {
  177. DCHECK(cleanup_timer_);
  178. DCHECK(!cleanup_timer_->IsRunning());
  179. }
  180. MDnsClientImpl::Core::~Core() {
  181. cleanup_timer_->Stop();
  182. }
  183. int MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) {
  184. CHECK(!cleanup_timer_->IsRunning());
  185. return connection_->Init(socket_factory);
  186. }
  187. bool MDnsClientImpl::Core::SendQuery(uint16_t rrtype, const std::string& name) {
  188. std::string name_dns;
  189. if (!DNSDomainFromUnrestrictedDot(name, &name_dns))
  190. return false;
  191. DnsQuery query(0, name_dns, rrtype);
  192. query.set_flags(0); // Remove the RD flag from the query. It is unneeded.
  193. connection_->Send(query.io_buffer(), query.io_buffer()->size());
  194. return true;
  195. }
  196. void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
  197. int bytes_read) {
  198. unsigned offset;
  199. // Note: We store cache keys rather than record pointers to avoid
  200. // erroneous behavior in case a packet contains multiple exclusive
  201. // records with the same type and name.
  202. std::map<MDnsCache::Key, MDnsCache::UpdateType> update_keys;
  203. DCHECK_GT(bytes_read, 0);
  204. if (!response->InitParseWithoutQuery(bytes_read)) {
  205. DVLOG(1) << "Could not understand an mDNS packet.";
  206. return; // Message is unreadable.
  207. }
  208. // TODO(noamsml): duplicate query suppression.
  209. if (!(response->flags() & dns_protocol::kFlagResponse))
  210. return; // Message is a query. ignore it.
  211. DnsRecordParser parser = response->Parser();
  212. unsigned answer_count = response->answer_count() +
  213. response->additional_answer_count();
  214. for (unsigned i = 0; i < answer_count; i++) {
  215. offset = parser.GetOffset();
  216. std::unique_ptr<const RecordParsed> record =
  217. RecordParsed::CreateFrom(&parser, clock_->Now());
  218. if (!record) {
  219. DVLOG(1) << "Could not understand an mDNS record.";
  220. if (offset == parser.GetOffset()) {
  221. DVLOG(1) << "Abandoned parsing the rest of the packet.";
  222. return; // The parser did not advance, abort reading the packet.
  223. } else {
  224. continue; // We may be able to extract other records from the packet.
  225. }
  226. }
  227. if ((record->klass() & dns_protocol::kMDnsClassMask) !=
  228. dns_protocol::kClassIN) {
  229. DVLOG(1) << "Received an mDNS record with non-IN class. Ignoring.";
  230. continue; // Ignore all records not in the IN class.
  231. }
  232. MDnsCache::Key update_key = MDnsCache::Key::CreateFor(record.get());
  233. MDnsCache::UpdateType update = cache_.UpdateDnsRecord(std::move(record));
  234. // Cleanup time may have changed.
  235. ScheduleCleanup(cache_.next_expiration());
  236. update_keys.insert(std::make_pair(update_key, update));
  237. }
  238. for (const auto& update_key : update_keys) {
  239. const RecordParsed* record = cache_.LookupKey(update_key.first);
  240. if (!record)
  241. continue;
  242. if (record->type() == dns_protocol::kTypeNSEC) {
  243. #if defined(ENABLE_NSEC)
  244. NotifyNsecRecord(record);
  245. #endif
  246. } else {
  247. AlertListeners(update_key.second,
  248. ListenerKey(record->name(), record->type()), record);
  249. }
  250. }
  251. }
  252. void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) {
  253. DCHECK_EQ(dns_protocol::kTypeNSEC, record->type());
  254. const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>();
  255. DCHECK(rdata);
  256. // Remove all cached records matching the nonexistent RR types.
  257. std::vector<const RecordParsed*> records_to_remove;
  258. cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now());
  259. for (const auto* record_to_remove : records_to_remove) {
  260. if (record_to_remove->type() == dns_protocol::kTypeNSEC)
  261. continue;
  262. if (!rdata->GetBit(record_to_remove->type())) {
  263. std::unique_ptr<const RecordParsed> record_removed =
  264. cache_.RemoveRecord(record_to_remove);
  265. DCHECK(record_removed);
  266. OnRecordRemoved(record_removed.get());
  267. }
  268. }
  269. // Alert all listeners waiting for the nonexistent RR types.
  270. ListenerKey key(record->name(), 0);
  271. auto i = listeners_.upper_bound(key);
  272. for (; i != listeners_.end() &&
  273. i->first.name_lowercase() == key.name_lowercase();
  274. i++) {
  275. if (!rdata->GetBit(i->first.type())) {
  276. for (auto& observer : *i->second)
  277. observer.AlertNsecRecord();
  278. }
  279. }
  280. }
  281. void MDnsClientImpl::Core::OnConnectionError(int error) {
  282. // TODO(noamsml): On connection error, recreate connection and flush cache.
  283. VLOG(1) << "MDNS OnConnectionError (code: " << error << ")";
  284. }
  285. MDnsClientImpl::Core::ListenerKey::ListenerKey(const std::string& name,
  286. uint16_t type)
  287. : name_lowercase_(base::ToLowerASCII(name)), type_(type) {}
  288. bool MDnsClientImpl::Core::ListenerKey::operator<(
  289. const MDnsClientImpl::Core::ListenerKey& key) const {
  290. if (name_lowercase_ == key.name_lowercase_)
  291. return type_ < key.type_;
  292. return name_lowercase_ < key.name_lowercase_;
  293. }
  294. void MDnsClientImpl::Core::AlertListeners(
  295. MDnsCache::UpdateType update_type,
  296. const ListenerKey& key,
  297. const RecordParsed* record) {
  298. auto listener_map_iterator = listeners_.find(key);
  299. if (listener_map_iterator == listeners_.end()) return;
  300. for (auto& observer : *listener_map_iterator->second)
  301. observer.HandleRecordUpdate(update_type, record);
  302. }
  303. void MDnsClientImpl::Core::AddListener(
  304. MDnsListenerImpl* listener) {
  305. ListenerKey key(listener->GetName(), listener->GetType());
  306. auto& observer_list = listeners_[key];
  307. if (!observer_list)
  308. observer_list = std::make_unique<ObserverListType>();
  309. observer_list->AddObserver(listener);
  310. }
  311. void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
  312. ListenerKey key(listener->GetName(), listener->GetType());
  313. auto observer_list_iterator = listeners_.find(key);
  314. DCHECK(observer_list_iterator != listeners_.end());
  315. DCHECK(observer_list_iterator->second->HasObserver(listener));
  316. observer_list_iterator->second->RemoveObserver(listener);
  317. // Remove the observer list from the map if it is empty
  318. if (observer_list_iterator->second->empty()) {
  319. // Schedule the actual removal for later in case the listener removal
  320. // happens while iterating over the observer list.
  321. base::ThreadTaskRunnerHandle::Get()->PostTask(
  322. FROM_HERE, base::BindOnce(&MDnsClientImpl::Core::CleanupObserverList,
  323. AsWeakPtr(), key));
  324. }
  325. }
  326. void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) {
  327. auto found = listeners_.find(key);
  328. if (found != listeners_.end() && found->second->empty()) {
  329. listeners_.erase(found);
  330. }
  331. }
  332. void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
  333. // If cache is overfilled. Force an immediate cleanup.
  334. if (cache_.IsCacheOverfilled())
  335. cleanup = clock_->Now();
  336. // Cleanup is already scheduled, no need to do anything.
  337. if (cleanup == scheduled_cleanup_) {
  338. return;
  339. }
  340. scheduled_cleanup_ = cleanup;
  341. // This cancels the previously scheduled cleanup.
  342. cleanup_timer_->Stop();
  343. // If |cleanup| is empty, then no cleanup necessary.
  344. if (cleanup != base::Time()) {
  345. cleanup_timer_->Start(FROM_HERE,
  346. std::max(base::TimeDelta(), cleanup - clock_->Now()),
  347. base::BindOnce(&MDnsClientImpl::Core::DoCleanup,
  348. base::Unretained(this)));
  349. }
  350. }
  351. void MDnsClientImpl::Core::DoCleanup() {
  352. cache_.CleanupRecords(
  353. clock_->Now(), base::BindRepeating(&MDnsClientImpl::Core::OnRecordRemoved,
  354. base::Unretained(this)));
  355. ScheduleCleanup(cache_.next_expiration());
  356. }
  357. void MDnsClientImpl::Core::OnRecordRemoved(
  358. const RecordParsed* record) {
  359. AlertListeners(MDnsCache::RecordRemoved,
  360. ListenerKey(record->name(), record->type()), record);
  361. }
  362. void MDnsClientImpl::Core::QueryCache(
  363. uint16_t rrtype,
  364. const std::string& name,
  365. std::vector<const RecordParsed*>* records) const {
  366. cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
  367. }
  368. MDnsClientImpl::MDnsClientImpl()
  369. : clock_(base::DefaultClock::GetInstance()),
  370. cleanup_timer_(std::make_unique<base::OneShotTimer>()) {}
  371. MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
  372. std::unique_ptr<base::OneShotTimer> timer)
  373. : clock_(clock), cleanup_timer_(std::move(timer)) {}
  374. MDnsClientImpl::~MDnsClientImpl() {
  375. StopListening();
  376. }
  377. int MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) {
  378. DCHECK(!core_.get());
  379. core_ = std::make_unique<Core>(clock_, cleanup_timer_.get());
  380. int rv = core_->Init(socket_factory);
  381. if (rv != OK) {
  382. DCHECK_NE(ERR_IO_PENDING, rv);
  383. core_.reset();
  384. }
  385. return rv;
  386. }
  387. void MDnsClientImpl::StopListening() {
  388. core_.reset();
  389. }
  390. bool MDnsClientImpl::IsListening() const {
  391. return core_.get() != nullptr;
  392. }
  393. std::unique_ptr<MDnsListener> MDnsClientImpl::CreateListener(
  394. uint16_t rrtype,
  395. const std::string& name,
  396. MDnsListener::Delegate* delegate) {
  397. return std::make_unique<MDnsListenerImpl>(rrtype, name, clock_, delegate,
  398. this);
  399. }
  400. std::unique_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
  401. uint16_t rrtype,
  402. const std::string& name,
  403. int flags,
  404. const MDnsTransaction::ResultCallback& callback) {
  405. return std::make_unique<MDnsTransactionImpl>(rrtype, name, flags, callback,
  406. this);
  407. }
  408. MDnsListenerImpl::MDnsListenerImpl(uint16_t rrtype,
  409. const std::string& name,
  410. base::Clock* clock,
  411. MDnsListener::Delegate* delegate,
  412. MDnsClientImpl* client)
  413. : rrtype_(rrtype),
  414. name_(name),
  415. clock_(clock),
  416. client_(client),
  417. delegate_(delegate) {}
  418. MDnsListenerImpl::~MDnsListenerImpl() {
  419. if (started_) {
  420. DCHECK(client_->core());
  421. client_->core()->RemoveListener(this);
  422. }
  423. }
  424. bool MDnsListenerImpl::Start() {
  425. DCHECK(!started_);
  426. started_ = true;
  427. DCHECK(client_->core());
  428. client_->core()->AddListener(this);
  429. return true;
  430. }
  431. void MDnsListenerImpl::SetActiveRefresh(bool active_refresh) {
  432. active_refresh_ = active_refresh;
  433. if (started_) {
  434. if (!active_refresh_) {
  435. next_refresh_.Cancel();
  436. } else if (last_update_ != base::Time()) {
  437. ScheduleNextRefresh();
  438. }
  439. }
  440. }
  441. const std::string& MDnsListenerImpl::GetName() const {
  442. return name_;
  443. }
  444. uint16_t MDnsListenerImpl::GetType() const {
  445. return rrtype_;
  446. }
  447. void MDnsListenerImpl::HandleRecordUpdate(MDnsCache::UpdateType update_type,
  448. const RecordParsed* record) {
  449. DCHECK(started_);
  450. if (update_type != MDnsCache::RecordRemoved) {
  451. ttl_ = record->ttl();
  452. last_update_ = record->time_created();
  453. ScheduleNextRefresh();
  454. }
  455. if (update_type != MDnsCache::NoChange) {
  456. MDnsListener::UpdateType update_external;
  457. switch (update_type) {
  458. case MDnsCache::RecordAdded:
  459. update_external = MDnsListener::RECORD_ADDED;
  460. break;
  461. case MDnsCache::RecordChanged:
  462. update_external = MDnsListener::RECORD_CHANGED;
  463. break;
  464. case MDnsCache::RecordRemoved:
  465. update_external = MDnsListener::RECORD_REMOVED;
  466. break;
  467. case MDnsCache::NoChange:
  468. default:
  469. NOTREACHED();
  470. // Dummy assignment to suppress compiler warning.
  471. update_external = MDnsListener::RECORD_CHANGED;
  472. break;
  473. }
  474. delegate_->OnRecordUpdate(update_external, record);
  475. }
  476. }
  477. void MDnsListenerImpl::AlertNsecRecord() {
  478. DCHECK(started_);
  479. delegate_->OnNsecRecord(name_, rrtype_);
  480. }
  481. void MDnsListenerImpl::ScheduleNextRefresh() {
  482. DCHECK(last_update_ != base::Time());
  483. if (!active_refresh_)
  484. return;
  485. // A zero TTL is a goodbye packet and should not be refreshed.
  486. if (ttl_ == 0) {
  487. next_refresh_.Cancel();
  488. return;
  489. }
  490. next_refresh_.Reset(
  491. base::BindRepeating(&MDnsListenerImpl::DoRefresh, AsWeakPtr()));
  492. // Schedule refreshes at both 85% and 95% of the original TTL. These will both
  493. // be canceled and rescheduled if the record's TTL is updated due to a
  494. // response being received.
  495. base::Time next_refresh1 =
  496. last_update_ +
  497. base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
  498. kListenerRefreshRatio1 * ttl_));
  499. base::Time next_refresh2 =
  500. last_update_ +
  501. base::Milliseconds(static_cast<int>(base::Time::kMillisecondsPerSecond *
  502. kListenerRefreshRatio2 * ttl_));
  503. base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
  504. FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now());
  505. base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
  506. FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now());
  507. }
  508. void MDnsListenerImpl::DoRefresh() {
  509. client_->core()->SendQuery(rrtype_, name_);
  510. }
  511. MDnsTransactionImpl::MDnsTransactionImpl(
  512. uint16_t rrtype,
  513. const std::string& name,
  514. int flags,
  515. const MDnsTransaction::ResultCallback& callback,
  516. MDnsClientImpl* client)
  517. : rrtype_(rrtype),
  518. name_(name),
  519. callback_(callback),
  520. client_(client),
  521. flags_(flags) {
  522. DCHECK((flags_ & MDnsTransaction::FLAG_MASK) == flags_);
  523. DCHECK(flags_ & MDnsTransaction::QUERY_CACHE ||
  524. flags_ & MDnsTransaction::QUERY_NETWORK);
  525. }
  526. MDnsTransactionImpl::~MDnsTransactionImpl() {
  527. timeout_.Cancel();
  528. }
  529. bool MDnsTransactionImpl::Start() {
  530. DCHECK(!started_);
  531. started_ = true;
  532. base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
  533. if (flags_ & MDnsTransaction::QUERY_CACHE) {
  534. ServeRecordsFromCache();
  535. if (!weak_this || !is_active()) return true;
  536. }
  537. if (flags_ & MDnsTransaction::QUERY_NETWORK) {
  538. return QueryAndListen();
  539. }
  540. // If this is a cache only query, signal that the transaction is over
  541. // immediately.
  542. SignalTransactionOver();
  543. return true;
  544. }
  545. const std::string& MDnsTransactionImpl::GetName() const {
  546. return name_;
  547. }
  548. uint16_t MDnsTransactionImpl::GetType() const {
  549. return rrtype_;
  550. }
  551. void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
  552. DCHECK(started_);
  553. OnRecordUpdate(MDnsListener::RECORD_ADDED, record);
  554. }
  555. void MDnsTransactionImpl::TriggerCallback(MDnsTransaction::Result result,
  556. const RecordParsed* record) {
  557. DCHECK(started_);
  558. if (!is_active()) return;
  559. // Ensure callback is run after touching all class state, so that
  560. // the callback can delete the transaction.
  561. MDnsTransaction::ResultCallback callback = callback_;
  562. // Reset the transaction if it expects a single result, or if the result
  563. // is a final one (everything except for a record).
  564. if (flags_ & MDnsTransaction::SINGLE_RESULT ||
  565. result != MDnsTransaction::RESULT_RECORD) {
  566. Reset();
  567. }
  568. callback.Run(result, record);
  569. }
  570. void MDnsTransactionImpl::Reset() {
  571. callback_.Reset();
  572. listener_.reset();
  573. timeout_.Cancel();
  574. }
  575. void MDnsTransactionImpl::OnRecordUpdate(MDnsListener::UpdateType update,
  576. const RecordParsed* record) {
  577. DCHECK(started_);
  578. if (update == MDnsListener::RECORD_ADDED ||
  579. update == MDnsListener::RECORD_CHANGED)
  580. TriggerCallback(MDnsTransaction::RESULT_RECORD, record);
  581. }
  582. void MDnsTransactionImpl::SignalTransactionOver() {
  583. DCHECK(started_);
  584. if (flags_ & MDnsTransaction::SINGLE_RESULT) {
  585. TriggerCallback(MDnsTransaction::RESULT_NO_RESULTS, nullptr);
  586. } else {
  587. TriggerCallback(MDnsTransaction::RESULT_DONE, nullptr);
  588. }
  589. }
  590. void MDnsTransactionImpl::ServeRecordsFromCache() {
  591. std::vector<const RecordParsed*> records;
  592. base::WeakPtr<MDnsTransactionImpl> weak_this = AsWeakPtr();
  593. if (client_->core()) {
  594. client_->core()->QueryCache(rrtype_, name_, &records);
  595. for (auto i = records.begin(); i != records.end() && weak_this; ++i) {
  596. weak_this->TriggerCallback(MDnsTransaction::RESULT_RECORD, *i);
  597. }
  598. #if defined(ENABLE_NSEC)
  599. if (records.empty()) {
  600. DCHECK(weak_this);
  601. client_->core()->QueryCache(dns_protocol::kTypeNSEC, name_, &records);
  602. if (!records.empty()) {
  603. const NsecRecordRdata* rdata =
  604. records.front()->rdata<NsecRecordRdata>();
  605. DCHECK(rdata);
  606. if (!rdata->GetBit(rrtype_))
  607. weak_this->TriggerCallback(MDnsTransaction::RESULT_NSEC, nullptr);
  608. }
  609. }
  610. #endif
  611. }
  612. }
  613. bool MDnsTransactionImpl::QueryAndListen() {
  614. listener_ = client_->CreateListener(rrtype_, name_, this);
  615. if (!listener_->Start())
  616. return false;
  617. DCHECK(client_->core());
  618. if (!client_->core()->SendQuery(rrtype_, name_))
  619. return false;
  620. timeout_.Reset(
  621. base::BindOnce(&MDnsTransactionImpl::SignalTransactionOver, AsWeakPtr()));
  622. base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
  623. FROM_HERE, timeout_.callback(), kTransactionTimeout);
  624. return true;
  625. }
  626. void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
  627. TriggerCallback(RESULT_NSEC, nullptr);
  628. }
  629. void MDnsTransactionImpl::OnCachePurged() {
  630. // TODO(noamsml): Cache purge situations not yet implemented
  631. }
  632. } // namespace net