ssl_hmac_channel_authenticator.cc 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. // Copyright (c) 2012 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "remoting/protocol/ssl_hmac_channel_authenticator.h"
  5. #include <stdint.h>
  6. #include <utility>
  7. #include "base/bind.h"
  8. #include "base/callback_helpers.h"
  9. #include "base/logging.h"
  10. #include "build/build_config.h"
  11. #include "crypto/secure_util.h"
  12. #include "net/base/host_port_pair.h"
  13. #include "net/base/io_buffer.h"
  14. #include "net/base/ip_address.h"
  15. #include "net/base/net_errors.h"
  16. #include "net/cert/cert_status_flags.h"
  17. #include "net/cert/cert_verifier.h"
  18. #include "net/cert/cert_verify_result.h"
  19. #include "net/cert/ct_policy_enforcer.h"
  20. #include "net/cert/ct_policy_status.h"
  21. #include "net/cert/signed_certificate_timestamp_and_status.h"
  22. #include "net/cert/x509_certificate.h"
  23. #include "net/http/transport_security_state.h"
  24. #include "net/log/net_log_with_source.h"
  25. #include "net/socket/client_socket_factory.h"
  26. #include "net/socket/ssl_client_socket.h"
  27. #include "net/socket/ssl_server_socket.h"
  28. #include "net/socket/stream_socket.h"
  29. #include "net/ssl/ssl_config_service.h"
  30. #include "net/ssl/ssl_server_config.h"
  31. #include "net/traffic_annotation/network_traffic_annotation.h"
  32. #include "remoting/base/rsa_key_pair.h"
  33. #include "remoting/protocol/auth_util.h"
  34. #include "remoting/protocol/p2p_stream_socket.h"
  35. namespace remoting {
  36. namespace protocol {
  37. namespace {
  // Network traffic annotation passed to socket_->Write() when the
  // authentication digest is sent (see WriteAuthenticationBytes()). The raw
  // string records the sender, purpose, trigger, data, destination and policy
  // controls for this traffic.
  constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
      net::DefineNetworkTrafficAnnotation("ssl_hmac_channel_authenticator",
                                          R"(
semantics {
sender: "Chrome Remote Desktop"
description:
"Performs the required authentication to start a Chrome Remote "
"Desktop connection."
trigger:
"Initiating a Chrome Remote Desktop connection."
data: "No user data."
destination: OTHER
destination_other:
"The Chrome Remote Desktop client/host that user is connecting to."
}
policy {
cookies_allowed: NO
setting:
"This request cannot be stopped in settings, but will not be sent "
"if user does not use Chrome Remote Desktop."
policy_exception_justification:
"Not implemented. 'RemoteAccessHostClientDomainList' and "
"'RemoteAccessHostDomainList' policies can limit the domains to "
"which a connection can be made, but they cannot be used to block "
"the request to all domains. Please refer to help desk for other "
"approaches to manage this feature."
})");
  65. // A CertVerifier which rejects every certificate.
  66. class FailingCertVerifier : public net::CertVerifier {
  67. public:
  68. FailingCertVerifier() = default;
  69. ~FailingCertVerifier() override = default;
  70. int Verify(const RequestParams& params,
  71. net::CertVerifyResult* verify_result,
  72. net::CompletionOnceCallback callback,
  73. std::unique_ptr<Request>* out_req,
  74. const net::NetLogWithSource& net_log) override {
  75. verify_result->verified_cert = params.certificate();
  76. verify_result->cert_status = net::CERT_STATUS_INVALID;
  77. return net::ERR_CERT_INVALID;
  78. }
  79. void SetConfig(const Config& config) override {}
  80. };
  81. // Implements net::StreamSocket interface on top of P2PStreamSocket to be passed
  82. // to net::SSLClientSocket and net::SSLServerSocket.
  83. class NetStreamSocketAdapter : public net::StreamSocket {
  84. public:
  85. NetStreamSocketAdapter(std::unique_ptr<P2PStreamSocket> socket)
  86. : socket_(std::move(socket)) {}
  87. ~NetStreamSocketAdapter() override = default;
  88. int Read(net::IOBuffer* buf,
  89. int buf_len,
  90. net::CompletionOnceCallback callback) override {
  91. return socket_->Read(buf, buf_len, std::move(callback));
  92. }
  93. int Write(
  94. net::IOBuffer* buf,
  95. int buf_len,
  96. net::CompletionOnceCallback callback,
  97. const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
  98. return socket_->Write(buf, buf_len, std::move(callback),
  99. traffic_annotation);
  100. }
  101. int SetReceiveBufferSize(int32_t size) override {
  102. NOTREACHED();
  103. return net::ERR_FAILED;
  104. }
  105. int SetSendBufferSize(int32_t size) override {
  106. NOTREACHED();
  107. return net::ERR_FAILED;
  108. }
  109. int Connect(net::CompletionOnceCallback callback) override {
  110. NOTREACHED();
  111. return net::ERR_FAILED;
  112. }
  113. void Disconnect() override { socket_.reset(); }
  114. bool IsConnected() const override { return true; }
  115. bool IsConnectedAndIdle() const override { return true; }
  116. int GetPeerAddress(net::IPEndPoint* address) const override {
  117. // SSL sockets call this function so it must return some result.
  118. *address = net::IPEndPoint(net::IPAddress::IPv4AllZeros(), 0);
  119. return net::OK;
  120. }
  121. int GetLocalAddress(net::IPEndPoint* address) const override {
  122. NOTREACHED();
  123. return net::ERR_FAILED;
  124. }
  125. const net::NetLogWithSource& NetLog() const override { return net_log_; }
  126. bool WasEverUsed() const override {
  127. NOTREACHED();
  128. return true;
  129. }
  130. bool WasAlpnNegotiated() const override {
  131. NOTREACHED();
  132. return false;
  133. }
  134. net::NextProto GetNegotiatedProtocol() const override {
  135. NOTREACHED();
  136. return net::kProtoUnknown;
  137. }
  138. bool GetSSLInfo(net::SSLInfo* ssl_info) override {
  139. NOTREACHED();
  140. return false;
  141. }
  142. int64_t GetTotalReceivedBytes() const override {
  143. NOTIMPLEMENTED();
  144. return 0;
  145. }
  146. void ApplySocketTag(const net::SocketTag& tag) override { NOTIMPLEMENTED(); }
  147. private:
  148. std::unique_ptr<P2PStreamSocket> socket_;
  149. net::NetLogWithSource net_log_;
  150. };
  151. } // namespace
  152. // Implements P2PStreamSocket interface on top of net::StreamSocket.
  153. class SslHmacChannelAuthenticator::P2PStreamSocketAdapter
  154. : public P2PStreamSocket {
  155. public:
  156. P2PStreamSocketAdapter(SslSocketContext socket_context,
  157. std::unique_ptr<net::StreamSocket> socket)
  158. : socket_context_(std::move(socket_context)),
  159. socket_(std::move(socket)) {}
  160. ~P2PStreamSocketAdapter() override = default;
  161. int Read(const scoped_refptr<net::IOBuffer>& buf,
  162. int buf_len,
  163. net::CompletionOnceCallback callback) override {
  164. return socket_->Read(buf.get(), buf_len, std::move(callback));
  165. }
  166. int Write(
  167. const scoped_refptr<net::IOBuffer>& buf,
  168. int buf_len,
  169. net::CompletionOnceCallback callback,
  170. const net::NetworkTrafficAnnotationTag& traffic_annotation) override {
  171. return socket_->Write(buf.get(), buf_len, std::move(callback),
  172. traffic_annotation);
  173. }
  174. private:
  175. // The socket_context_ must outlive any associated sockets.
  176. SslSocketContext socket_context_;
  177. std::unique_ptr<net::StreamSocket> socket_;
  178. };
  179. SslHmacChannelAuthenticator::SslSocketContext::SslSocketContext() = default;
  180. SslHmacChannelAuthenticator::SslSocketContext::SslSocketContext(
  181. SslSocketContext&&) = default;
  182. SslHmacChannelAuthenticator::SslSocketContext::~SslSocketContext() = default;
  183. SslHmacChannelAuthenticator::SslSocketContext&
  184. SslHmacChannelAuthenticator::SslSocketContext::operator=(SslSocketContext&&) =
  185. default;
  186. // static
  187. std::unique_ptr<SslHmacChannelAuthenticator>
  188. SslHmacChannelAuthenticator::CreateForClient(const std::string& remote_cert,
  189. const std::string& auth_key) {
  190. std::unique_ptr<SslHmacChannelAuthenticator> result(
  191. new SslHmacChannelAuthenticator(auth_key));
  192. result->remote_cert_ = remote_cert;
  193. return result;
  194. }
  195. std::unique_ptr<SslHmacChannelAuthenticator>
  196. SslHmacChannelAuthenticator::CreateForHost(const std::string& local_cert,
  197. scoped_refptr<RsaKeyPair> key_pair,
  198. const std::string& auth_key) {
  199. std::unique_ptr<SslHmacChannelAuthenticator> result(
  200. new SslHmacChannelAuthenticator(auth_key));
  201. result->local_cert_ = local_cert;
  202. result->local_key_pair_ = key_pair;
  203. return result;
  204. }
  205. SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
  206. const std::string& auth_key)
  207. : auth_key_(auth_key) {
  208. }
  209. SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
  210. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  211. }
  212. void SslHmacChannelAuthenticator::SecureAndAuthenticate(
  213. std::unique_ptr<P2PStreamSocket> socket,
  214. DoneCallback done_callback) {
  215. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  216. done_callback_ = std::move(done_callback);
  217. int result;
  218. if (is_ssl_server()) {
  219. scoped_refptr<net::X509Certificate> cert =
  220. net::X509Certificate::CreateFromBytes(
  221. base::as_bytes(base::make_span(local_cert_)));
  222. if (!cert) {
  223. LOG(ERROR) << "Failed to parse X509Certificate";
  224. NotifyError(net::ERR_FAILED);
  225. return;
  226. }
  227. net::SSLServerConfig ssl_config;
  228. ssl_config.require_ecdhe = true;
  229. socket_context_.server_context = net::CreateSSLServerContext(
  230. cert.get(), *local_key_pair_->private_key(), ssl_config);
  231. std::unique_ptr<net::SSLServerSocket> server_socket =
  232. socket_context_.server_context->CreateSSLServerSocket(
  233. std::make_unique<NetStreamSocketAdapter>(std::move(socket)));
  234. net::SSLServerSocket* raw_server_socket = server_socket.get();
  235. socket_ = std::move(server_socket);
  236. result = raw_server_socket->Handshake(base::BindOnce(
  237. &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this)));
  238. } else {
  239. socket_context_.transport_security_state =
  240. std::make_unique<net::TransportSecurityState>();
  241. socket_context_.cert_verifier = std::make_unique<FailingCertVerifier>();
  242. socket_context_.ct_policy_enforcer =
  243. std::make_unique<net::DefaultCTPolicyEnforcer>();
  244. socket_context_.client_context = std::make_unique<net::SSLClientContext>(
  245. nullptr /* default config */, socket_context_.cert_verifier.get(),
  246. socket_context_.transport_security_state.get(),
  247. socket_context_.ct_policy_enforcer.get(),
  248. nullptr /* no session caching */, nullptr /* no sct auditing */);
  249. net::SSLConfig ssl_config;
  250. ssl_config.require_ecdhe = true;
  251. scoped_refptr<net::X509Certificate> cert =
  252. net::X509Certificate::CreateFromBytes(
  253. base::as_bytes(base::make_span(remote_cert_)));
  254. if (!cert) {
  255. LOG(ERROR) << "Failed to parse X509Certificate";
  256. NotifyError(net::ERR_FAILED);
  257. return;
  258. }
  259. ssl_config.allowed_bad_certs.emplace_back(
  260. std::move(cert), net::CERT_STATUS_AUTHORITY_INVALID);
  261. net::HostPortPair host_and_port(kSslFakeHostName, 0);
  262. std::unique_ptr<net::StreamSocket> stream_socket =
  263. std::make_unique<NetStreamSocketAdapter>(std::move(socket));
  264. socket_ =
  265. net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
  266. socket_context_.client_context.get(), std::move(stream_socket),
  267. host_and_port, ssl_config);
  268. result = socket_->Connect(base::BindOnce(
  269. &SslHmacChannelAuthenticator::OnConnected, base::Unretained(this)));
  270. }
  271. if (result == net::ERR_IO_PENDING)
  272. return;
  273. OnConnected(result);
  274. }
  275. bool SslHmacChannelAuthenticator::is_ssl_server() {
  276. return local_key_pair_.get() != nullptr;
  277. }
  278. void SslHmacChannelAuthenticator::OnConnected(int result) {
  279. if (result != net::OK) {
  280. LOG(WARNING) << "Failed to establish SSL connection. Error: "
  281. << net::ErrorToString(result);
  282. NotifyError(result);
  283. return;
  284. }
  285. // Generate authentication digest to write to the socket.
  286. std::string auth_bytes = GetAuthBytes(
  287. socket_.get(), is_ssl_server() ?
  288. kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
  289. if (auth_bytes.empty()) {
  290. NotifyError(net::ERR_FAILED);
  291. return;
  292. }
  293. // Allocate a buffer to write the digest.
  294. auth_write_buf_ = base::MakeRefCounted<net::DrainableIOBuffer>(
  295. base::MakeRefCounted<net::StringIOBuffer>(auth_bytes), auth_bytes.size());
  296. // Read an incoming token.
  297. auth_read_buf_ = base::MakeRefCounted<net::GrowableIOBuffer>();
  298. auth_read_buf_->SetCapacity(kAuthDigestLength);
  299. // If WriteAuthenticationBytes() results in |done_callback_| being
  300. // called then we must not do anything else because this object may
  301. // be destroyed at that point.
  302. bool callback_called = false;
  303. WriteAuthenticationBytes(&callback_called);
  304. if (!callback_called)
  305. ReadAuthenticationBytes();
  306. }
  307. void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
  308. bool* callback_called) {
  309. while (true) {
  310. int result = socket_->Write(
  311. auth_write_buf_.get(), auth_write_buf_->BytesRemaining(),
  312. base::BindOnce(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
  313. base::Unretained(this)),
  314. kTrafficAnnotation);
  315. if (result == net::ERR_IO_PENDING)
  316. break;
  317. if (!HandleAuthBytesWritten(result, callback_called))
  318. break;
  319. }
  320. }
  321. void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
  322. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  323. if (HandleAuthBytesWritten(result, nullptr))
  324. WriteAuthenticationBytes(nullptr);
  325. }
  326. bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
  327. int result, bool* callback_called) {
  328. if (result <= 0) {
  329. LOG(ERROR) << "Error writing authentication: " << result;
  330. if (callback_called)
  331. *callback_called = false;
  332. NotifyError(result);
  333. return false;
  334. }
  335. auth_write_buf_->DidConsume(result);
  336. if (auth_write_buf_->BytesRemaining() > 0)
  337. return true;
  338. auth_write_buf_ = nullptr;
  339. CheckDone(callback_called);
  340. return false;
  341. }
  342. void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
  343. while (true) {
  344. int result = socket_->Read(
  345. auth_read_buf_.get(), auth_read_buf_->RemainingCapacity(),
  346. base::BindOnce(&SslHmacChannelAuthenticator::OnAuthBytesRead,
  347. base::Unretained(this)));
  348. if (result == net::ERR_IO_PENDING)
  349. break;
  350. if (!HandleAuthBytesRead(result))
  351. break;
  352. }
  353. }
  354. void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
  355. DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  356. if (HandleAuthBytesRead(result))
  357. ReadAuthenticationBytes();
  358. }
  359. bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
  360. if (read_result <= 0) {
  361. NotifyError(read_result);
  362. return false;
  363. }
  364. auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
  365. if (auth_read_buf_->RemainingCapacity() > 0)
  366. return true;
  367. if (!VerifyAuthBytes(std::string(
  368. auth_read_buf_->StartOfBuffer(),
  369. auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
  370. LOG(WARNING) << "Mismatched authentication";
  371. NotifyError(net::ERR_FAILED);
  372. return false;
  373. }
  374. auth_read_buf_ = nullptr;
  375. CheckDone(nullptr);
  376. return false;
  377. }
  378. bool SslHmacChannelAuthenticator::VerifyAuthBytes(
  379. const std::string& received_auth_bytes) {
  380. DCHECK(received_auth_bytes.length() == kAuthDigestLength);
  381. // Compute expected auth bytes.
  382. std::string auth_bytes = GetAuthBytes(
  383. socket_.get(), is_ssl_server() ?
  384. kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
  385. if (auth_bytes.empty())
  386. return false;
  387. return crypto::SecureMemEqual(received_auth_bytes.data(),
  388. &(auth_bytes[0]), kAuthDigestLength);
  389. }
  390. void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
  391. if (auth_write_buf_.get() == nullptr && auth_read_buf_.get() == nullptr) {
  392. DCHECK(socket_.get() != nullptr);
  393. if (callback_called)
  394. *callback_called = true;
  395. std::move(done_callback_)
  396. .Run(net::OK, std::make_unique<P2PStreamSocketAdapter>(
  397. std::move(socket_context_), std::move(socket_)));
  398. }
  399. }
  400. void SslHmacChannelAuthenticator::NotifyError(int error) {
  401. std::move(done_callback_).Run(error, nullptr);
  402. }
  403. } // namespace protocol
  404. } // namespace remoting