// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef COMPONENTS_CAST_CHANNEL_LIBCAST_SOCKET_SERVICE_H_
#define COMPONENTS_CAST_CHANNEL_LIBCAST_SOCKET_SERVICE_H_

// NOTE(review): the header names of the next two includes are missing in this
// copy of the file. They are presumably angle-bracket standard headers such as
// <map> and <memory>, which the declarations below use. Restore them from
// version control rather than guessing.
#include
#include

#include "base/observer_list.h"
#include "base/sequence_checker.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "components/cast_channel/cast_socket.h"
#include "components/cast_channel/cast_socket_service.h"
#include "components/openscreen_platform/task_runner.h"
#include "third_party/openscreen/src/cast/common/public/cast_socket.h"
#include "third_party/openscreen/src/cast/sender/public/sender_socket_factory.h"
#include "third_party/openscreen/src/platform/api/tls_connection_factory.h"

namespace cast_channel {

// Short name for the Open Screen cast socket type used throughout this class.
using LibcastSocket = openscreen::cast::CastSocket;

class CastSocketWrapper;

// A CastSocketService implementation built on Open Screen's cast socket stack.
//
// The class is the client of two Open Screen interfaces:
//  - openscreen::cast::CastSocket::Client, which receives per-socket errors
//    and messages (OnError / OnMessage).
//  - openscreen::cast::SenderSocketFactory::Client, which receives connection
//    results from the factory (OnConnected / OnError).
//
// The class is final, non-copyable, and its constructor is private. Only the
// friends CastSocketService and LibcastSocketServiceTest can construct it.
//
// NOTE(review): in this copy of the file, the template arguments of
// std::unique_ptr, std::map and base::ObserverList have been stripped, so the
// declarations below do not compile as written. Restore them from version
// control; do not reconstruct them by hand.
class LibcastSocketService final
    : public CastSocketService,
      public openscreen::cast::CastSocket::Client,
      public openscreen::cast::SenderSocketFactory::Client {
 public:
  using CastSocketService::NetworkContextGetter;

  LibcastSocketService(const LibcastSocketService&) = delete;
  LibcastSocketService& operator=(const LibcastSocketService&) = delete;
  ~LibcastSocketService() override;

  // CastSocketService overrides.
  std::unique_ptr RemoveSocket(int channel_id) override;
  CastSocket* GetSocket(int channel_id) const override;
  CastSocket* GetSocket(const net::IPEndPoint& ip_endpoint) const override;
  void OpenSocket(NetworkContextGetter network_context_getter,
                  const CastSocketOpenParams& open_params,
                  CastSocket::OnOpenCallback open_cb) override;
  void AddObserver(CastSocket::Observer* observer) override;
  void RemoveObserver(CastSocket::Observer* observer) override;

  // openscreen::cast::CastSocket::Client overrides.
  void OnError(LibcastSocket* socket, openscreen::Error error) override;
  void OnMessage(LibcastSocket* socket,
                 ::cast::channel::CastMessage message) override;

  // openscreen::cast::SenderSocketFactory::Client overrides.
  void OnConnected(openscreen::cast::SenderSocketFactory* factory,
                   const openscreen::IPEndpoint& endpoint,
                   std::unique_ptr socket) override;
  void OnError(openscreen::cast::SenderSocketFactory* factory,
               const openscreen::IPEndpoint& endpoint,
               openscreen::Error error) override;

  // Test-only hook. Moves |socket_for_test| into |libcast_socket_for_test_|
  // and releases any socket previously stored there.
  // NOTE(review): how the stored socket is used is defined in the .cc file,
  // which is not visible here.
  void SetLibcastSocketForTest(std::unique_ptr socket_for_test) {
    libcast_socket_for_test_ = std::move(socket_for_test);
  }

 private:
  // CastSocketService can construct this class through the private
  // constructor declared below.
  friend class CastSocketService;
  friend class LibcastSocketServiceTest;

  // Bundles a callback with a timer. The type is move-only: declaring the move
  // operations suppresses the implicit copy operations.
  // NOTE(review): presumably used to time out pending connections; the
  // element types of |callback| and |timer| were stripped in this copy, so
  // confirm against the .cc file.
  struct ConnectTimer {
    ConnectTimer(std::unique_ptr callback, std::unique_ptr timer);
    ConnectTimer(ConnectTimer&&);
    ~ConnectTimer();
    ConnectTimer& operator=(ConnectTimer&&);

    std::unique_ptr callback;
    std::unique_ptr timer;
  };

  // Holds a ping interval and a liveness timeout.
  // NOTE(review): presumably saved from the CastSocketOpenParams passed to
  // OpenSocket() for use once the connection completes. Confirm.
  struct SavedOpenParams {
    base::TimeDelta ping_interval;
    base::TimeDelta liveness_timeout;
  };

  LibcastSocketService();

  // Reports whether |ip_endpoint| is considered pending.
  // NOTE(review): presumably a lookup in |pending_endpoints_|. Confirm.
  bool EndpointPending(const net::IPEndPoint& ip_endpoint) const;

  // NOTE(review): the *IOThread methods below mirror the public Open Screen
  // client overrides one-to-one. Presumably the overrides forward to them on
  // an IO thread. Confirm in the .cc file.
  void OnErrorSocketIOThread(LibcastSocket* socket, openscreen::Error error);
  void OnMessageIOThread(LibcastSocket* socket,
                         ::cast::channel::CastMessage message);
  void OnConnectedIOThread(openscreen::cast::SenderSocketFactory* factory,
                           const openscreen::IPEndpoint& endpoint,
                           std::unique_ptr socket);
  void OnErrorIOThread(openscreen::cast::SenderSocketFactory* factory,
                       const openscreen::IPEndpoint& endpoint,
                       openscreen::Error error);

  // Handles an error for |socket| expressed as a cast_channel ChannelError
  // rather than an openscreen::Error.
  void OnErrorBounce(LibcastSocket* socket, ChannelError error);

  // Used to generate CastSocket IDs on error, since the socket factory doesn't
  // provide us one in that case.
  static int last_channel_id_;

  // List of socket observers.
  base::ObserverList::Unchecked observers_;

  // NOTE(review): members are constructed in declaration order and destroyed
  // in reverse order. If |socket_factory_| depends on
  // |openscreen_task_runner_|, keep the runner declared first. Confirm.
  openscreen_platform::TaskRunner openscreen_task_runner_;
  openscreen::cast::SenderSocketFactory socket_factory_;
  std::unique_ptr tls_factory_;

  std::map> sockets_;
  std::map socket_endpoints_;

  // Data for pending connections.
  std::map pending_endpoints_;
  std::map> open_callbacks_;
  std::map open_params_;

  // Socket installed by SetLibcastSocketForTest().
  std::unique_ptr libcast_socket_for_test_;
};

}  // namespace cast_channel

#endif  // COMPONENTS_CAST_CHANNEL_LIBCAST_SOCKET_SERVICE_H_