123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584 |
- // Copyright 2020 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "dispatch.h"
- #include <cassert>
- #include "cbor.h"
- #include "error_support.h"
- #include "find_by_first.h"
- #include "frontend_channel.h"
- #include "protocol_core.h"
- namespace crdtp {
- // =============================================================================
- // DispatchResponse - Error status and chaining / fall through
- // =============================================================================
- // static
- DispatchResponse DispatchResponse::Success() {
- DispatchResponse result;
- result.code_ = DispatchCode::SUCCESS;
- return result;
- }
- // static
- DispatchResponse DispatchResponse::FallThrough() {
- DispatchResponse result;
- result.code_ = DispatchCode::FALL_THROUGH;
- return result;
- }
- // static
- DispatchResponse DispatchResponse::ParseError(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::PARSE_ERROR;
- result.message_ = std::move(message);
- return result;
- }
- // static
- DispatchResponse DispatchResponse::InvalidRequest(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::INVALID_REQUEST;
- result.message_ = std::move(message);
- return result;
- }
- // static
- DispatchResponse DispatchResponse::MethodNotFound(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::METHOD_NOT_FOUND;
- result.message_ = std::move(message);
- return result;
- }
- // static
- DispatchResponse DispatchResponse::InvalidParams(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::INVALID_PARAMS;
- result.message_ = std::move(message);
- return result;
- }
- // static
- DispatchResponse DispatchResponse::InternalError() {
- DispatchResponse result;
- result.code_ = DispatchCode::INTERNAL_ERROR;
- result.message_ = "Internal error";
- return result;
- }
- // static
- DispatchResponse DispatchResponse::ServerError(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::SERVER_ERROR;
- result.message_ = std::move(message);
- return result;
- }
- // static
- DispatchResponse DispatchResponse::SessionNotFound(std::string message) {
- DispatchResponse result;
- result.code_ = DispatchCode::SESSION_NOT_FOUND;
- result.message_ = std::move(message);
- return result;
- }
- // =============================================================================
- // Dispatchable - a shallow parser for CBOR encoded DevTools messages
- // =============================================================================
- Dispatchable::Dispatchable(span<uint8_t> serialized) : serialized_(serialized) {
- Status s = cbor::CheckCBORMessage(serialized);
- if (!s.ok()) {
- status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, s.pos};
- return;
- }
- cbor::CBORTokenizer tokenizer(serialized);
- if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
- status_ = tokenizer.Status();
- return;
- }
- // We checked for the envelope start byte above, so the tokenizer
- // must agree here, since it's not an error.
- assert(tokenizer.TokenTag() == cbor::CBORTokenTag::ENVELOPE);
- // Before we enter the envelope, we save the position that we
- // expect to see after we're done parsing the envelope contents.
- // This way we can compare and produce an error if the contents
- // didn't fit exactly into the envelope length.
- const size_t pos_past_envelope =
- tokenizer.Status().pos + tokenizer.GetEnvelopeHeader().outer_size();
- tokenizer.EnterEnvelope();
- if (tokenizer.TokenTag() == cbor::CBORTokenTag::ERROR_VALUE) {
- status_ = tokenizer.Status();
- return;
- }
- if (tokenizer.TokenTag() != cbor::CBORTokenTag::MAP_START) {
- status_ = {Error::MESSAGE_MUST_BE_AN_OBJECT, tokenizer.Status().pos};
- return;
- }
- assert(tokenizer.TokenTag() == cbor::CBORTokenTag::MAP_START);
- tokenizer.Next(); // Now we should be pointed at the map key.
- while (tokenizer.TokenTag() != cbor::CBORTokenTag::STOP) {
- switch (tokenizer.TokenTag()) {
- case cbor::CBORTokenTag::DONE:
- status_ =
- Status{Error::CBOR_UNEXPECTED_EOF_IN_MAP, tokenizer.Status().pos};
- return;
- case cbor::CBORTokenTag::ERROR_VALUE:
- status_ = tokenizer.Status();
- return;
- case cbor::CBORTokenTag::STRING8:
- if (!MaybeParseProperty(&tokenizer))
- return;
- break;
- default:
- // We require the top-level keys to be UTF8 (US-ASCII in practice).
- status_ = Status{Error::CBOR_INVALID_MAP_KEY, tokenizer.Status().pos};
- return;
- }
- }
- tokenizer.Next();
- if (!has_call_id_) {
- status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
- tokenizer.Status().pos};
- return;
- }
- if (method_.empty()) {
- status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
- tokenizer.Status().pos};
- return;
- }
- // The contents of the envelope parsed OK, now check that we're at
- // the expected position.
- if (pos_past_envelope != tokenizer.Status().pos) {
- status_ = Status{Error::CBOR_ENVELOPE_CONTENTS_LENGTH_MISMATCH,
- tokenizer.Status().pos};
- return;
- }
- if (tokenizer.TokenTag() != cbor::CBORTokenTag::DONE) {
- status_ = Status{Error::CBOR_TRAILING_JUNK, tokenizer.Status().pos};
- return;
- }
- }
// Returns true when status_ is OK, i.e. when the constructor did not record a
// parse error.
- bool Dispatchable::ok() const {
- return status_.ok();
- }
- DispatchResponse Dispatchable::DispatchError() const {
- // TODO(johannes): Replace with DCHECK / similar?
- if (status_.ok())
- return DispatchResponse::Success();
- if (status_.IsMessageError())
- return DispatchResponse::InvalidRequest(status_.Message());
- return DispatchResponse::ParseError(status_.ToASCIIString());
- }
- bool Dispatchable::MaybeParseProperty(cbor::CBORTokenizer* tokenizer) {
- span<uint8_t> property_name = tokenizer->GetString8();
- if (SpanEquals(SpanFrom("id"), property_name))
- return MaybeParseCallId(tokenizer);
- if (SpanEquals(SpanFrom("method"), property_name))
- return MaybeParseMethod(tokenizer);
- if (SpanEquals(SpanFrom("params"), property_name))
- return MaybeParseParams(tokenizer);
- if (SpanEquals(SpanFrom("sessionId"), property_name))
- return MaybeParseSessionId(tokenizer);
- status_ =
- Status{Error::MESSAGE_HAS_UNKNOWN_PROPERTY, tokenizer->Status().pos};
- return false;
- }
- bool Dispatchable::MaybeParseCallId(cbor::CBORTokenizer* tokenizer) {
- if (has_call_id_) {
- status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
- return false;
- }
- tokenizer->Next();
- if (tokenizer->TokenTag() != cbor::CBORTokenTag::INT32) {
- status_ = Status{Error::MESSAGE_MUST_HAVE_INTEGER_ID_PROPERTY,
- tokenizer->Status().pos};
- return false;
- }
- call_id_ = tokenizer->GetInt32();
- has_call_id_ = true;
- tokenizer->Next();
- return true;
- }
- bool Dispatchable::MaybeParseMethod(cbor::CBORTokenizer* tokenizer) {
- if (!method_.empty()) {
- status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
- return false;
- }
- tokenizer->Next();
- if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
- status_ = Status{Error::MESSAGE_MUST_HAVE_STRING_METHOD_PROPERTY,
- tokenizer->Status().pos};
- return false;
- }
- method_ = tokenizer->GetString8();
- tokenizer->Next();
- return true;
- }
- bool Dispatchable::MaybeParseParams(cbor::CBORTokenizer* tokenizer) {
- if (params_seen_) {
- status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
- return false;
- }
- params_seen_ = true;
- tokenizer->Next();
- if (tokenizer->TokenTag() == cbor::CBORTokenTag::NULL_VALUE) {
- tokenizer->Next();
- return true;
- }
- if (tokenizer->TokenTag() != cbor::CBORTokenTag::ENVELOPE) {
- status_ = Status{Error::MESSAGE_MAY_HAVE_OBJECT_PARAMS_PROPERTY,
- tokenizer->Status().pos};
- return false;
- }
- params_ = tokenizer->GetEnvelope();
- tokenizer->Next();
- return true;
- }
- bool Dispatchable::MaybeParseSessionId(cbor::CBORTokenizer* tokenizer) {
- if (!session_id_.empty()) {
- status_ = Status{Error::CBOR_DUPLICATE_MAP_KEY, tokenizer->Status().pos};
- return false;
- }
- tokenizer->Next();
- if (tokenizer->TokenTag() != cbor::CBORTokenTag::STRING8) {
- status_ = Status{Error::MESSAGE_MAY_HAVE_STRING_SESSION_ID_PROPERTY,
- tokenizer->Status().pos};
- return false;
- }
- session_id_ = tokenizer->GetString8();
- tokenizer->Next();
- return true;
- }
- namespace {
- class ProtocolError : public Serializable {
- public:
- explicit ProtocolError(DispatchResponse dispatch_response)
- : dispatch_response_(std::move(dispatch_response)) {}
- void AppendSerialized(std::vector<uint8_t>* out) const override {
- Status status;
- std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
- encoder->HandleMapBegin();
- if (has_call_id_) {
- encoder->HandleString8(SpanFrom("id"));
- encoder->HandleInt32(call_id_);
- }
- encoder->HandleString8(SpanFrom("error"));
- encoder->HandleMapBegin();
- encoder->HandleString8(SpanFrom("code"));
- encoder->HandleInt32(static_cast<int32_t>(dispatch_response_.Code()));
- encoder->HandleString8(SpanFrom("message"));
- encoder->HandleString8(SpanFrom(dispatch_response_.Message()));
- if (!data_.empty()) {
- encoder->HandleString8(SpanFrom("data"));
- encoder->HandleString8(SpanFrom(data_));
- }
- encoder->HandleMapEnd();
- encoder->HandleMapEnd();
- assert(status.ok());
- }
- void SetCallId(int call_id) {
- has_call_id_ = true;
- call_id_ = call_id;
- }
- void SetData(std::string data) { data_ = std::move(data); }
- private:
- const DispatchResponse dispatch_response_;
- std::string data_;
- int call_id_ = 0;
- bool has_call_id_ = false;
- };
- } // namespace
- // =============================================================================
- // Helpers for creating protocol responses and notifications.
- // =============================================================================
- std::unique_ptr<Serializable> CreateErrorResponse(
- int call_id,
- DispatchResponse dispatch_response) {
- auto protocol_error =
- std::make_unique<ProtocolError>(std::move(dispatch_response));
- protocol_error->SetCallId(call_id);
- return protocol_error;
- }
- std::unique_ptr<Serializable> CreateErrorResponse(
- int call_id,
- DispatchResponse dispatch_response,
- const DeserializerState& state) {
- auto protocol_error =
- std::make_unique<ProtocolError>(std::move(dispatch_response));
- protocol_error->SetCallId(call_id);
- // TODO(caseq): should we plumb the call name here?
- protocol_error->SetData(state.ErrorMessage(MakeSpan("params")));
- return protocol_error;
- }
- std::unique_ptr<Serializable> CreateErrorNotification(
- DispatchResponse dispatch_response) {
- return std::make_unique<ProtocolError>(std::move(dispatch_response));
- }
- namespace {
- class Response : public Serializable {
- public:
- Response(int call_id, std::unique_ptr<Serializable> params)
- : call_id_(call_id), params_(std::move(params)) {}
- void AppendSerialized(std::vector<uint8_t>* out) const override {
- Status status;
- std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
- encoder->HandleMapBegin();
- encoder->HandleString8(SpanFrom("id"));
- encoder->HandleInt32(call_id_);
- encoder->HandleString8(SpanFrom("result"));
- if (params_) {
- params_->AppendSerialized(out);
- } else {
- encoder->HandleMapBegin();
- encoder->HandleMapEnd();
- }
- encoder->HandleMapEnd();
- assert(status.ok());
- }
- private:
- const int call_id_;
- std::unique_ptr<Serializable> params_;
- };
- class Notification : public Serializable {
- public:
- Notification(const char* method, std::unique_ptr<Serializable> params)
- : method_(method), params_(std::move(params)) {}
- void AppendSerialized(std::vector<uint8_t>* out) const override {
- Status status;
- std::unique_ptr<ParserHandler> encoder = cbor::NewCBOREncoder(out, &status);
- encoder->HandleMapBegin();
- encoder->HandleString8(SpanFrom("method"));
- encoder->HandleString8(SpanFrom(method_));
- encoder->HandleString8(SpanFrom("params"));
- if (params_) {
- params_->AppendSerialized(out);
- } else {
- encoder->HandleMapBegin();
- encoder->HandleMapEnd();
- }
- encoder->HandleMapEnd();
- assert(status.ok());
- }
- private:
- const char* method_;
- std::unique_ptr<Serializable> params_;
- };
- } // namespace
- std::unique_ptr<Serializable> CreateResponse(
- int call_id,
- std::unique_ptr<Serializable> params) {
- return std::make_unique<Response>(call_id, std::move(params));
- }
- std::unique_ptr<Serializable> CreateNotification(
- const char* method,
- std::unique_ptr<Serializable> params) {
- return std::make_unique<Notification>(method, std::move(params));
- }
- // =============================================================================
- // DomainDispatcher - Dispatching between protocol methods within a domain.
- // =============================================================================
- DomainDispatcher::WeakPtr::WeakPtr(DomainDispatcher* dispatcher)
- : dispatcher_(dispatcher) {}
- DomainDispatcher::WeakPtr::~WeakPtr() {
- if (dispatcher_)
- dispatcher_->weak_ptrs_.erase(this);
- }
// Destructor is defined out of line and defaulted; members are destroyed in
// the usual way.
- DomainDispatcher::Callback::~Callback() = default;
- void DomainDispatcher::Callback::dispose() {
- backend_impl_ = nullptr;
- }
- DomainDispatcher::Callback::Callback(
- std::unique_ptr<DomainDispatcher::WeakPtr> backend_impl,
- int call_id,
- span<uint8_t> method,
- span<uint8_t> message)
- : backend_impl_(std::move(backend_impl)),
- call_id_(call_id),
- method_(method),
- message_(message.begin(), message.end()) {}
- void DomainDispatcher::Callback::sendIfActive(
- std::unique_ptr<Serializable> partialMessage,
- const DispatchResponse& response) {
- if (!backend_impl_ || !backend_impl_->get())
- return;
- backend_impl_->get()->sendResponse(call_id_, response,
- std::move(partialMessage));
- backend_impl_ = nullptr;
- }
- void DomainDispatcher::Callback::fallThroughIfActive() {
- if (!backend_impl_ || !backend_impl_->get())
- return;
- backend_impl_->get()->channel()->FallThrough(call_id_, method_,
- SpanFrom(message_));
- backend_impl_ = nullptr;
- }
- DomainDispatcher::DomainDispatcher(FrontendChannel* frontendChannel)
- : frontend_channel_(frontendChannel) {}
// On destruction, clears the frontend channel and disposes every weak pointer
// still registered with this dispatcher (see clearFrontend()).
- DomainDispatcher::~DomainDispatcher() {
- clearFrontend();
- }
- void DomainDispatcher::sendResponse(int call_id,
- const DispatchResponse& response,
- std::unique_ptr<Serializable> result) {
- if (!frontend_channel_)
- return;
- std::unique_ptr<Serializable> serializable;
- if (response.IsError()) {
- serializable = CreateErrorResponse(call_id, response);
- } else {
- serializable = CreateResponse(call_id, std::move(result));
- }
- frontend_channel_->SendProtocolResponse(call_id, std::move(serializable));
- }
- void DomainDispatcher::ReportInvalidParams(const Dispatchable& dispatchable,
- const DeserializerState& state) {
- assert(!state.status().ok());
- if (frontend_channel_) {
- frontend_channel_->SendProtocolResponse(
- dispatchable.CallId(),
- CreateErrorResponse(
- dispatchable.CallId(),
- DispatchResponse::InvalidParams("Invalid parameters"), state));
- }
- }
- void DomainDispatcher::clearFrontend() {
- frontend_channel_ = nullptr;
- for (auto& weak : weak_ptrs_)
- weak->dispose();
- weak_ptrs_.clear();
- }
- std::unique_ptr<DomainDispatcher::WeakPtr> DomainDispatcher::weakPtr() {
- auto weak = std::make_unique<DomainDispatcher::WeakPtr>(this);
- weak_ptrs_.insert(weak.get());
- return weak;
- }
- // =============================================================================
- // UberDispatcher - dispatches between domains (backends).
- // =============================================================================
- UberDispatcher::DispatchResult::DispatchResult(bool method_found,
- std::function<void()> runnable)
- : method_found_(method_found), runnable_(runnable) {}
- void UberDispatcher::DispatchResult::Run() {
- if (!runnable_)
- return;
- runnable_();
- runnable_ = nullptr;
- }
- UberDispatcher::UberDispatcher(FrontendChannel* frontend_channel)
- : frontend_channel_(frontend_channel) {
- assert(frontend_channel);
- }
// Destructor is defined out of line and defaulted; members are destroyed in
// the usual way.
- UberDispatcher::~UberDispatcher() = default;
- constexpr size_t kNotFound = std::numeric_limits<size_t>::max();
- namespace {
- size_t DotIdx(span<uint8_t> method) {
- const void* p = memchr(method.data(), '.', method.size());
- return p ? reinterpret_cast<const uint8_t*>(p) - method.data() : kNotFound;
- }
- } // namespace
- UberDispatcher::DispatchResult UberDispatcher::Dispatch(
- const Dispatchable& dispatchable) const {
- span<uint8_t> method = FindByFirst(redirects_, dispatchable.Method(),
- /*default_value=*/dispatchable.Method());
- size_t dot_idx = DotIdx(method);
- if (dot_idx != kNotFound) {
- span<uint8_t> domain = method.subspan(0, dot_idx);
- span<uint8_t> command = method.subspan(dot_idx + 1);
- DomainDispatcher* dispatcher = FindByFirst(dispatchers_, domain);
- if (dispatcher) {
- std::function<void(const Dispatchable&)> dispatched =
- dispatcher->Dispatch(command);
- if (dispatched) {
- return DispatchResult(
- true, [dispatchable, dispatched = std::move(dispatched)]() {
- dispatched(dispatchable);
- });
- }
- }
- }
- return DispatchResult(false, [this, dispatchable]() {
- frontend_channel_->SendProtocolResponse(
- dispatchable.CallId(),
- CreateErrorResponse(dispatchable.CallId(),
- DispatchResponse::MethodNotFound(
- "'" +
- std::string(dispatchable.Method().begin(),
- dispatchable.Method().end()) +
- "' wasn't found")));
- });
- }
- template <typename T>
- struct FirstLessThan {
- bool operator()(const std::pair<span<uint8_t>, T>& left,
- const std::pair<span<uint8_t>, T>& right) {
- return SpanLessThan(left.first, right.first);
- }
- };
- void UberDispatcher::WireBackend(
- span<uint8_t> domain,
- const std::vector<std::pair<span<uint8_t>, span<uint8_t>>>&
- sorted_redirects,
- std::unique_ptr<DomainDispatcher> dispatcher) {
- auto it = redirects_.insert(redirects_.end(), sorted_redirects.begin(),
- sorted_redirects.end());
- std::inplace_merge(redirects_.begin(), it, redirects_.end(),
- FirstLessThan<span<uint8_t>>());
- auto jt = dispatchers_.insert(dispatchers_.end(),
- std::make_pair(domain, std::move(dispatcher)));
- std::inplace_merge(dispatchers_.begin(), jt, dispatchers_.end(),
- FirstLessThan<std::unique_ptr<DomainDispatcher>>());
- }
- } // namespace crdtp
|