dispatch.cc 19 KB

  1. // Copyright 2020 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "dispatch.h"
  5. #include <cassert>
  6. #include "cbor.h"
  7. #include "error_support.h"
  8. #include "find_by_first.h"
  9. #include "frontend_channel.h"
  10. #include "protocol_core.h"
  11. namespace crdtp {
  12. // =============================================================================
  13. // DispatchResponse - Error status and chaining / fall through
  14. // =============================================================================
  15. // static
  16. DispatchResponse DispatchResponse::Success() {
  17. DispatchResponse result;
  18. result.code_ = DispatchCode::SUCCESS;
  19. return result;
  20. }
  21. // static
  22. DispatchResponse DispatchResponse::FallThrough() {
  23. DispatchResponse result;
  24. result.code_ = DispatchCode::FALL_THROUGH;
  25. return result;
  26. }
  27. // static
  28. DispatchResponse DispatchResponse::ParseError(std::string message) {
  29. DispatchResponse result;
  30. result.code_ = DispatchCode::PARSE_ERROR;
  31. result.message_ = std::move(message);
  32. return result;
  33. }
  34. // static
  35. DispatchResponse DispatchResponse::InvalidRequest(std::string message) {
  36. DispatchResponse result;
  37. result.code_ = DispatchCode::INVALID_REQUEST;
  38. result.message_ = std::move(message);
  39. return result;
  40. }
  41. // static
  42. DispatchResponse DispatchResponse::MethodNotFound(std::string message) {
  43. DispatchResponse result;
  44. result.code_ = DispatchCode::METHOD_NOT_FOUND;
  45. result.message_ = std::move(message);
  46. return result;
  47. }
  48. // static
  49. DispatchResponse DispatchResponse::InvalidParams(std::string message) {
  50. DispatchResponse result;
  51. result.code_ = DispatchCode::INVALID_PARAMS;
  52. result.message_ = std::move(message);
  53. return result;
  54. }
  55. // static
  56. DispatchResponse DispatchResponse::InternalError() {
  57. DispatchResponse result;
  58. result.code_ = DispatchCode::INTERNAL_ERROR;
  59. result.message_ = "Internal error";
  60. return result;
  61. }
  62. // static
  63. DispatchResponse DispatchResponse::ServerError(std::string message) {
  64. DispatchResponse result;
  65. result.code_ = DispatchCode::SERVER_ERROR;
  66. result.message_ = std::move(message);
  67. return result;
  68. }
  69. // static
  70. DispatchResponse DispatchResponse::SessionNotFound(std::string message) {
  71. DispatchResponse result;
  72. result.code_ = DispatchCode::SESSION_NOT_FOUND;
  73. result.message_ = std::move(message);
  74. return result;
  75. }
  76. // =============================================================================
  77. // Dispatchable - a shallow parser for CBOR encoded DevTools messages
  78. // =============================================================================
  79. Dispatchable::Dispatchable(span<uint8_t> serialized) : serialized_(serialized) {
  80. Status s = cbor::CheckCBORMessage(serialized);
  81. if (!s.ok()) {
  82. status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, s.pos};
  83. return;
  84. }
  85. cbor::CBORTokenizer tokenizer(serialized);
  86. if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
  87. status_ = tokenizer.Status();
  88. return;
  89. }
  90. // We checked for the envelope start byte above, so the tokenizer
  91. // must agree here, since it's not an error.
  92. assert(tokenizer.TokenTag() == cbor::CBORTokenTag::ENVELOPE);
  93. // Before we enter the envelope, we save the position that we
  94. // expect to see after we're done parsing the envelope contents.
  95. // This way we can compare and produce an error if the contents
  96. // didn't fit exactly into the envelope length.
  97. const size_t pos_past_envelope =
  98. tokenizer.Status().pos + tokenizer.GetEnvelopeHeader().outer_size();
  99. tokenizer.EnterEnvelope();
  100. if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
  101. status_ = tokenizer.Status();
  102. return;
  103. }
  104. if (tokenizer.TokenTag() != cbor::CBORTokenTag::MAP_START) {
  105. status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, tokenizer.Status().pos};
  106. return;
  107. }
  108. assert(tokenizer.TokenTag() == cbor::CBORTokenTag::MAP_START);
  109. tokenizer.Next(); // Now we should be pointed at the map key.
  110. while (tokenizer.TokenTag() != cbor::CBORTokenTag::STOP) {
  111. switch (tokenizer.TokenTag()) {
  112. case cbor::CBORTokenTag::DONE:
  113. status_ =
  114. Status{Error::CBOR_UNEXPECTED_EOF_IN_MAP, tokenizer.Status().pos};
  115. return;
  116. case cbor::CBORTokenTag::ERROR_VALUE:
  117. status_ = tokenizer.Status();
  118. return;
  119. case cbor::CBORTokenTag::STRING8:
  120. if (!MaybeParseProperty(&tokenizer))
  121. return;
  122. break;
  123. default:
  124. // We require the top-level keys to be UTF8 (US-ASCII in practice).
  125. status_ = Status{Error::CBOR_INVALID_MAP_KEY, tokenizer.Status().pos};
  126. return;
  127. }
  128. }
  129. tokenizer.Next();
  130. if (!has_call_id_) {
  131. status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
  132. tokenizer.Status().pos};
  133. return;
  134. }
  135. if (method_.empty()) {
  137. tokenizer.Status().pos};
  138. return;
  139. }
  140. // The contents of the envelope parsed OK, now check that we're at
  141. // the expected position.
  142. if (pos_past_envelope != tokenizer.Status().pos) {
  144. tokenizer.Status().pos};
  145. return;
  146. }
  147. if (tokenizer.TokenTag() != cbor::CBORTokenTag::DONE) {
  148. status_ = Status{Error::CBOR_TRAILING_JUNK, tokenizer.Status().pos};
  149. return;
  150. }
  151. }
  152. bool Dispatchable::ok() const {
  153. return status_.ok();
  154. }
  155. DispatchResponse Dispatchable::DispatchError() const {
  156. // TODO(johannes): Replace with DCHECK / similar?
  157. if (status_.ok())
  158. return DispatchResponse::Success();
  159. if (status_.IsMessageError())
  160. return DispatchResponse::InvalidRequest(status_.Message());
  161. return DispatchResponse::ParseError(status_.ToASCIIString());
  162. }
  163. bool Dispatchable::MaybeParseProperty(cbor::CBORTokenizer* tokenizer) {
  164. span<uint8_t> property_name = tokenizer->GetString8();
  165. if (SpanEquals(SpanFrom("id"), property_name))
  166. return MaybeParseCallId(tokenizer);
  167. if (SpanEquals(SpanFrom("method"), property_name))
  168. return MaybeParseMethod(tokenizer);
  169. if (SpanEquals(SpanFrom("params"), property_name))
  170. return MaybeParseParams(tokenizer);
  171. if (SpanEquals(SpanFrom("sessionId"), property_name))
  172. return MaybeParseSessionId(tokenizer);
  173. status_ =
  174. Status{Error::MESSAGE_HAS_UNKNOWN_PROPERTY, tokenizer->Status().pos};
  175. return false;
  176. }
  177. bool Dispatchable::MaybeParseCallId(cbor::CBORTokenizer* tokenizer) {
  178. if (has_call_id_) {
  179. status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
  180. return false;
  181. }
  182. tokenizer->Next();
  183. if (tokenizer->TokenTag() != cbor::CBORTokenTag::INT32) {
  184. status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
  185. tokenizer->Status().pos};
  186. return false;
  187. }
  188. call_id_ = tokenizer->GetInt32();
  189. has_call_id_ = true;
  190. tokenizer->Next();
  191. return true;
  192. }
  193. bool Dispatchable::MaybeParseMethod(cbor::CBORTokenizer* tokenizer) {
  194. if (!method_.empty()) {
  195. status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
  196. return false;
  197. }
  198. tokenizer->Next();
  199. if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
  201. tokenizer->Status().pos};
  202. return false;
  203. }
  204. method_ = tokenizer->GetString8();
  205. tokenizer->Next();
  206. return true;
  207. }
  208. bool Dispatchable::MaybeParseParams(cbor::CBORTokenizer* tokenizer) {
  209. if (params_seen_) {
  210. status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
  211. return false;
  212. }
  213. params_seen_ = true;
  214. tokenizer->Next();
  215. if (tokenizer->TokenTag() == cbor::CBORTokenTag::NULL_VALUE) {
  216. tokenizer->Next();
  217. return true;
  218. }
  219. if (tokenizer->TokenTag() != cbor::CBORTokenTag::ENVELOPE) {
  221. tokenizer->Status().pos};
  222. return false;
  223. }
  224. params_ = tokenizer->GetEnvelope();
  225. tokenizer->Next();
  226. return true;
  227. }
  228. bool Dispatchable::MaybeParseSessionId(cbor::CBORTokenizer* tokenizer) {
  229. if (!session_id_.empty()) {
  230. status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
  231. return false;
  232. }
  233. tokenizer->Next();
  234. if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
  236. tokenizer->Status().pos};
  237. return false;
  238. }
  239. session_id_ = tokenizer->GetString8();
  240. tokenizer->Next();
  241. return true;
  242. }
  243. namespace {
  244. class ProtocolError : public Serializable {
  245. public:
  246. explicit ProtocolError(DispatchResponse dispatch_response)
  247. : dispatch_response_(std::move(dispatch_response)) {}
  248. void AppendSerialized(std::vector<uint8_t>* out) const override {
  249. Status status;
  250. std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
  251. encoder->HandleMapBegin();
  252. if (has_call_id_) {
  253. encoder->HandleString8(SpanFrom("id"));
  254. encoder->HandleInt32(call_id_);
  255. }
  256. encoder->HandleString8(SpanFrom("error"));
  257. encoder->HandleMapBegin();
  258. encoder->HandleString8(SpanFrom("code"));
  259. encoder->HandleInt32(static_cast<int32_t>(dispatch_response_.Code()));
  260. encoder->HandleString8(SpanFrom("message"));
  261. encoder->HandleString8(SpanFrom(dispatch_response_.Message()));
  262. if (!data_.empty()) {
  263. encoder->HandleString8(SpanFrom("data"));
  264. encoder->HandleString8(SpanFrom(data_));
  265. }
  266. encoder->HandleMapEnd();
  267. encoder->HandleMapEnd();
  268. assert(status.ok());
  269. }
  270. void SetCallId(int call_id) {
  271. has_call_id_ = true;
  272. call_id_ = call_id;
  273. }
  274. void SetData(std::string data) { data_ = std::move(data); }
  275. private:
  276. const DispatchResponse dispatch_response_;
  277. std::string data_;
  278. int call_id_ = 0;
  279. bool has_call_id_ = false;
  280. };
  281. } // namespace
  282. // =============================================================================
  // Helpers for creating protocol responses and notifications.
  284. // =============================================================================
  285. std::unique_ptr<Serializable> CreateErrorResponse(
  286. int call_id,
  287. DispatchResponse dispatch_response) {
  288. auto protocol_error =
  289. std::make_unique<ProtocolError>(std::move(dispatch_response));
  290. protocol_error->SetCallId(call_id);
  291. return protocol_error;
  292. }
  293. std::unique_ptr<Serializable> CreateErrorResponse(
  294. int call_id,
  295. DispatchResponse dispatch_response,
  296. const DeserializerState& state) {
  297. auto protocol_error =
  298. std::make_unique<ProtocolError>(std::move(dispatch_response));
  299. protocol_error->SetCallId(call_id);
  300. // TODO(caseq): should we plumb the call name here?
  301. protocol_error->SetData(state.ErrorMessage(MakeSpan("params")));
  302. return protocol_error;
  303. }
  304. std::unique_ptr<Serializable> CreateErrorNotification(
  305. DispatchResponse dispatch_response) {
  306. return std::make_unique<ProtocolError>(std::move(dispatch_response));
  307. }
  308. namespace {
  309. class Response : public Serializable {
  310. public:
  311. Response(int call_id, std::unique_ptr<Serializable> params)
  312. : call_id_(call_id), params_(std::move(params)) {}
  313. void AppendSerialized(std::vector<uint8_t>* out) const override {
  314. Status status;
  315. std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
  316. encoder->HandleMapBegin();
  317. encoder->HandleString8(SpanFrom("id"));
  318. encoder->HandleInt32(call_id_);
  319. encoder->HandleString8(SpanFrom("result"));
  320. if (params_) {
  321. params_->AppendSerialized(out);
  322. } else {
  323. encoder->HandleMapBegin();
  324. encoder->HandleMapEnd();
  325. }
  326. encoder->HandleMapEnd();
  327. assert(status.ok());
  328. }
  329. private:
  330. const int call_id_;
  331. std::unique_ptr<Serializable> params_;
  332. };
  333. class Notification : public Serializable {
  334. public:
  335. Notification(const char* method, std::unique_ptr<Serializable> params)
  336. : method_(method), params_(std::move(params)) {}
  337. void AppendSerialized(std::vector<uint8_t>* out) const override {
  338. Status status;
  339. std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
  340. encoder->HandleMapBegin();
  341. encoder->HandleString8(SpanFrom("method"));
  342. encoder->HandleString8(SpanFrom(method_));
  343. encoder->HandleString8(SpanFrom("params"));
  344. if (params_) {
  345. params_->AppendSerialized(out);
  346. } else {
  347. encoder->HandleMapBegin();
  348. encoder->HandleMapEnd();
  349. }
  350. encoder->HandleMapEnd();
  351. assert(status.ok());
  352. }
  353. private:
  354. const char* method_;
  355. std::unique_ptr<Serializable> params_;
  356. };
  357. } // namespace
  358. std::unique_ptr<Serializable> CreateResponse(
  359. int call_id,
  360. std::unique_ptr<Serializable> params) {
  361. return std::make_unique<Response>(call_id, std::move(params));
  362. }
  363. std::unique_ptr<Serializable> CreateNotification(
  364. const char* method,
  365. std::unique_ptr<Serializable> params) {
  366. return std::make_unique<Notification>(method, std::move(params));
  367. }
  368. // =============================================================================
  // DomainDispatcher - Dispatching between protocol methods within a domain.
  370. // =============================================================================
  371. DomainDispatcher::WeakPtr::WeakPtr(DomainDispatcher* dispatcher)
  372. : dispatcher_(dispatcher) {}
  373. DomainDispatcher::WeakPtr::~WeakPtr() {
  374. if (dispatcher_)
  375. dispatcher_->weak_ptrs_.erase(this);
  376. }
  377. DomainDispatcher::Callback::~Callback() = default;
  378. void DomainDispatcher::Callback::dispose() {
  379. backend_impl_ = nullptr;
  380. }
  381. DomainDispatcher::Callback::Callback(
  382. std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,
  383. int call_id,
  384. span<uint8_t> method,
  385. span<uint8_t> message)
  386. : backend_impl_(std::move(backend_impl)),
  387. call_id_(call_id),
  388. method_(method),
  389. message_(message.begin(), message.end()) {}
  390. void DomainDispatcher::Callback::sendIfActive(
  391. std::unique_ptr<Serializable> partialMessage,
  392. const DispatchResponse& response) {
  393. if (!backend_impl_ || !backend_impl_->get())
  394. return;
  395. backend_impl_->get()->sendResponse(call_id_, response,
  396. std::move(partialMessage));
  397. backend_impl_ = nullptr;
  398. }
  399. void DomainDispatcher::Callback::fallThroughIfActive() {
  400. if (!backend_impl_ || !backend_impl_->get())
  401. return;
  402. backend_impl_->get()->channel()->FallThrough(call_id_, method_,
  403. SpanFrom(message_));
  404. backend_impl_ = nullptr;
  405. }
  406. DomainDispatcher::DomainDispatcher(FrontendChannel* frontendChannel)
  407. : frontend_channel_(frontendChannel) {}
  408. DomainDispatcher::~DomainDispatcher() {
  409. clearFrontend();
  410. }
  411. void DomainDispatcher::sendResponse(int call_id,
  412. const DispatchResponse& response,
  413. std::unique_ptr<Serializable> result) {
  414. if (!frontend_channel_)
  415. return;
  416. std::unique_ptr<Serializable> serializable;
  417. if (response.IsError()) {
  418. serializable = CreateErrorResponse(call_id, response);
  419. } else {
  420. serializable = CreateResponse(call_id, std::move(result));
  421. }
  422. frontend_channel_->SendProtocolResponse(call_id, std::move(serializable));
  423. }
  424. void DomainDispatcher::ReportInvalidParams(const Dispatchable& dispatchable,
  425. const DeserializerState& state) {
  426. assert(!state.status().ok());
  427. if (frontend_channel_) {
  428. frontend_channel_->SendProtocolResponse(
  429. dispatchable.CallId(),
  430. CreateErrorResponse(
  431. dispatchable.CallId(),
  432. DispatchResponse::InvalidParams("Invalid parameters"), state));
  433. }
  434. }
  435. void DomainDispatcher::clearFrontend() {
  436. frontend_channel_ = nullptr;
  437. for (auto& weak : weak_ptrs_)
  438. weak->dispose();
  439. weak_ptrs_.clear();
  440. }
  441. std::unique_ptr<DomainDispatcher::WeakPtr> DomainDispatcher::weakPtr() {
  442. auto weak = std::make_unique<DomainDispatcher::WeakPtr>(this);
  443. weak_ptrs_.insert(weak.get());
  444. return weak;
  445. }
  446. // =============================================================================
  447. // UberDispatcher - dispatches between domains (backends).
  448. // =============================================================================
  449. UberDispatcher::DispatchResult::DispatchResult(bool method_found,
  450. std::function<void()> runnable)
  451. : method_found_(method_found), runnable_(runnable) {}
  452. void UberDispatcher::DispatchResult::Run() {
  453. if (!runnable_)
  454. return;
  455. runnable_();
  456. runnable_ = nullptr;
  457. }
  458. UberDispatcher::UberDispatcher(FrontendChannel* frontend_channel)
  459. : frontend_channel_(frontend_channel) {
  460. assert(frontend_channel);
  461. }
  462. UberDispatcher::~UberDispatcher() = default;
  463. constexpr size_t kNotFound = std::numeric_limits<size_t>::max();
  464. namespace {
  465. size_t DotIdx(span<uint8_t> method) {
  466. const void* p = memchr(method.data(), '.', method.size());
  467. return p ? reinterpret_cast<const uint8_t*>(p) - method.data() : kNotFound;
  468. }
  469. } // namespace
  470. UberDispatcher::DispatchResult UberDispatcher::Dispatch(
  471. const Dispatchable& dispatchable) const {
  472. span<uint8_t> method = FindByFirst(redirects_, dispatchable.Method(),
  473. /*default_value=*/dispatchable.Method());
  474. size_t dot_idx = DotIdx(method);
  475. if (dot_idx != kNotFound) {
  476. span<uint8_t> domain = method.subspan(0, dot_idx);
  477. span<uint8_t> command = method.subspan(dot_idx + 1);
  478. DomainDispatcher* dispatcher = FindByFirst(dispatchers_, domain);
  479. if (dispatcher) {
  480. std::function<void(const Dispatchable&)> dispatched =
  481. dispatcher->Dispatch(command);
  482. if (dispatched) {
  483. return DispatchResult(
  484. true, [dispatchable, dispatched = std::move(dispatched)]() {
  485. dispatched(dispatchable);
  486. });
  487. }
  488. }
  489. }
  490. return DispatchResult(false, [this, dispatchable]() {
  491. frontend_channel_->SendProtocolResponse(
  492. dispatchable.CallId(),
  493. CreateErrorResponse(dispatchable.CallId(),
  494. DispatchResponse::MethodNotFound(
  495. "'" +
  496. std::string(dispatchable.Method().begin(),
  497. dispatchable.Method().end()) +
  498. "' wasn't found")));
  499. });
  500. }
  501. template <typename T>
  502. struct FirstLessThan {
  503. bool operator()(const std::pair<span<uint8_t>, T>& left,
  504. const std::pair<span<uint8_t>, T>& right) {
  505. return SpanLessThan(left.first, right.first);
  506. }
  507. };
  508. void UberDispatcher::WireBackend(
  509. span<uint8_t> domain,
  510. const std::vector<std::pair<span<uint8_t>, span<uint8_t>>>&
  511. sorted_redirects,
  512. std::unique_ptr<DomainDispatcher> dispatcher) {
  513. auto it = redirects_.insert(redirects_.end(), sorted_redirects.begin(),
  514. sorted_redirects.end());
  515. std::inplace_merge(redirects_.begin(), it, redirects_.end(),
  516. FirstLessThan<span<uint8_t>>());
  517. auto jt = dispatchers_.insert(dispatchers_.end(),
  518. std::make_pair(domain, std::move(dispatcher)));
  519. std::inplace_merge(dispatchers_.begin(), jt, dispatchers_.end(),
  520. FirstLessThan<std::unique_ptr<DomainDispatcher>>());
  521. }
  522. } // namespace crdtp