123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- // Copyright 2014 The Chromium Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style license that can be
- // found in the LICENSE file.
- #include "net/socket/transport_client_socket_pool_test_util.h"
- #include <stdint.h>
- #include <string>
- #include <utility>
- #include "base/bind.h"
- #include "base/check_op.h"
- #include "base/location.h"
- #include "base/memory/weak_ptr.h"
- #include "base/notreached.h"
- #include "base/run_loop.h"
- #include "base/task/single_thread_task_runner.h"
- #include "base/threading/thread_task_runner_handle.h"
- #include "net/base/ip_address.h"
- #include "net/base/ip_endpoint.h"
- #include "net/base/load_timing_info.h"
- #include "net/base/load_timing_info_test_util.h"
- #include "net/log/net_log_source.h"
- #include "net/log/net_log_source_type.h"
- #include "net/log/net_log_with_source.h"
- #include "net/socket/client_socket_handle.h"
- #include "net/socket/datagram_client_socket.h"
- #include "net/socket/ssl_client_socket.h"
- #include "net/socket/transport_client_socket.h"
- #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
- #include "testing/gtest/include/gtest/gtest.h"
- namespace net {
- namespace {
- IPAddress ParseIP(const std::string& ip) {
- IPAddress address;
- CHECK(address.AssignFromIPLiteral(ip));
- return address;
- }
- // A StreamSocket which connects synchronously and successfully.
- class MockConnectClientSocket : public TransportClientSocket {
- public:
- MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log)
- : addrlist_(addrlist),
- net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
- MockConnectClientSocket(const MockConnectClientSocket&) = delete;
- MockConnectClientSocket& operator=(const MockConnectClientSocket&) = delete;
- // TransportClientSocket implementation.
- int Bind(const net::IPEndPoint& local_addr) override {
- NOTREACHED();
- return ERR_FAILED;
- }
- // StreamSocket implementation.
- int Connect(CompletionOnceCallback callback) override {
- connected_ = true;
- return OK;
- }
- void Disconnect() override { connected_ = false; }
- bool IsConnected() const override { return connected_; }
- bool IsConnectedAndIdle() const override { return connected_; }
- int GetPeerAddress(IPEndPoint* address) const override {
- *address = addrlist_.front();
- return OK;
- }
- int GetLocalAddress(IPEndPoint* address) const override {
- if (!connected_)
- return ERR_SOCKET_NOT_CONNECTED;
- if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
- SetIPv4Address(address);
- else
- SetIPv6Address(address);
- return OK;
- }
- const NetLogWithSource& NetLog() const override { return net_log_; }
- bool WasEverUsed() const override { return false; }
- bool WasAlpnNegotiated() const override { return false; }
- NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
- bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
- int64_t GetTotalReceivedBytes() const override {
- NOTIMPLEMENTED();
- return 0;
- }
- void ApplySocketTag(const SocketTag& tag) override {}
- // Socket implementation.
- int Read(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback) override {
- return ERR_FAILED;
- }
- int Write(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback,
- const NetworkTrafficAnnotationTag& traffic_annotation) override {
- return ERR_FAILED;
- }
- int SetReceiveBufferSize(int32_t size) override { return OK; }
- int SetSendBufferSize(int32_t size) override { return OK; }
- private:
- bool connected_ = false;
- const AddressList addrlist_;
- NetLogWithSource net_log_;
- };
- class MockFailingClientSocket : public TransportClientSocket {
- public:
- MockFailingClientSocket(const AddressList& addrlist,
- Error connect_error,
- net::NetLog* net_log)
- : addrlist_(addrlist),
- connect_error_(connect_error),
- net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
- MockFailingClientSocket(const MockFailingClientSocket&) = delete;
- MockFailingClientSocket& operator=(const MockFailingClientSocket&) = delete;
- // TransportClientSocket implementation.
- int Bind(const net::IPEndPoint& local_addr) override {
- NOTREACHED();
- return ERR_FAILED;
- }
- // StreamSocket implementation.
- int Connect(CompletionOnceCallback callback) override {
- return connect_error_;
- }
- void Disconnect() override {}
- bool IsConnected() const override { return false; }
- bool IsConnectedAndIdle() const override { return false; }
- int GetPeerAddress(IPEndPoint* address) const override {
- return ERR_UNEXPECTED;
- }
- int GetLocalAddress(IPEndPoint* address) const override {
- return ERR_UNEXPECTED;
- }
- const NetLogWithSource& NetLog() const override { return net_log_; }
- bool WasEverUsed() const override { return false; }
- bool WasAlpnNegotiated() const override { return false; }
- NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
- bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
- int64_t GetTotalReceivedBytes() const override {
- NOTIMPLEMENTED();
- return 0;
- }
- void ApplySocketTag(const SocketTag& tag) override {}
- // Socket implementation.
- int Read(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback) override {
- return ERR_FAILED;
- }
- int Write(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback,
- const NetworkTrafficAnnotationTag& traffic_annotation) override {
- return ERR_FAILED;
- }
- int SetReceiveBufferSize(int32_t size) override { return OK; }
- int SetSendBufferSize(int32_t size) override { return OK; }
- private:
- const AddressList addrlist_;
- const Error connect_error_;
- NetLogWithSource net_log_;
- };
- class MockTriggerableClientSocket : public TransportClientSocket {
- public:
- // |connect_error| indicates whether the socket should successfully complete
- // or fail.
- MockTriggerableClientSocket(const AddressList& addrlist,
- Error connect_error,
- net::NetLog* net_log)
- : connect_error_(connect_error),
- addrlist_(addrlist),
- net_log_(NetLogWithSource::Make(net_log, NetLogSourceType::SOCKET)) {}
- MockTriggerableClientSocket(const MockTriggerableClientSocket&) = delete;
- MockTriggerableClientSocket& operator=(const MockTriggerableClientSocket&) =
- delete;
- // Call this method to get a closure which will trigger the connect callback
- // when called. The closure can be called even after the socket is deleted; it
- // will safely do nothing.
- base::OnceClosure GetConnectCallback() {
- return base::BindOnce(&MockTriggerableClientSocket::DoCallback,
- weak_factory_.GetWeakPtr());
- }
- static std::unique_ptr<TransportClientSocket> MakeMockPendingClientSocket(
- const AddressList& addrlist,
- Error connect_error,
- net::NetLog* net_log) {
- auto socket = std::make_unique<MockTriggerableClientSocket>(
- addrlist, connect_error, net_log);
- base::ThreadTaskRunnerHandle::Get()->PostTask(FROM_HERE,
- socket->GetConnectCallback());
- return std::move(socket);
- }
- static std::unique_ptr<TransportClientSocket> MakeMockDelayedClientSocket(
- const AddressList& addrlist,
- Error connect_error,
- const base::TimeDelta& delay,
- net::NetLog* net_log) {
- auto socket = std::make_unique<MockTriggerableClientSocket>(
- addrlist, connect_error, net_log);
- base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
- FROM_HERE, socket->GetConnectCallback(), delay);
- return std::move(socket);
- }
- static std::unique_ptr<TransportClientSocket> MakeMockStalledClientSocket(
- const AddressList& addrlist,
- net::NetLog* net_log) {
- // We never post `GetConnectCallback()`, so the value of `connect_error`
- // does not matter.
- return std::make_unique<MockTriggerableClientSocket>(
- addrlist, /*connect_error=*/OK, net_log);
- }
- // TransportClientSocket implementation.
- int Bind(const net::IPEndPoint& local_addr) override {
- NOTREACHED();
- return ERR_FAILED;
- }
- // StreamSocket implementation.
- int Connect(CompletionOnceCallback callback) override {
- DCHECK(callback_.is_null());
- callback_ = std::move(callback);
- return ERR_IO_PENDING;
- }
- void Disconnect() override {}
- bool IsConnected() const override { return is_connected_; }
- bool IsConnectedAndIdle() const override { return is_connected_; }
- int GetPeerAddress(IPEndPoint* address) const override {
- *address = addrlist_.front();
- return OK;
- }
- int GetLocalAddress(IPEndPoint* address) const override {
- if (!is_connected_)
- return ERR_SOCKET_NOT_CONNECTED;
- if (addrlist_.front().GetFamily() == ADDRESS_FAMILY_IPV4)
- SetIPv4Address(address);
- else
- SetIPv6Address(address);
- return OK;
- }
- const NetLogWithSource& NetLog() const override { return net_log_; }
- bool WasEverUsed() const override { return false; }
- bool WasAlpnNegotiated() const override { return false; }
- NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; }
- bool GetSSLInfo(SSLInfo* ssl_info) override { return false; }
- int64_t GetTotalReceivedBytes() const override {
- NOTIMPLEMENTED();
- return 0;
- }
- void ApplySocketTag(const SocketTag& tag) override {}
- // Socket implementation.
- int Read(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback) override {
- return ERR_FAILED;
- }
- int Write(IOBuffer* buf,
- int buf_len,
- CompletionOnceCallback callback,
- const NetworkTrafficAnnotationTag& traffic_annotation) override {
- return ERR_FAILED;
- }
- int SetReceiveBufferSize(int32_t size) override { return OK; }
- int SetSendBufferSize(int32_t size) override { return OK; }
- private:
- void DoCallback() {
- is_connected_ = connect_error_ == OK;
- std::move(callback_).Run(connect_error_);
- }
- Error connect_error_;
- bool is_connected_ = false;
- const AddressList addrlist_;
- NetLogWithSource net_log_;
- CompletionOnceCallback callback_;
- base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_{this};
- };
- } // namespace
- void TestLoadTimingInfoConnectedReused(const ClientSocketHandle& handle) {
- LoadTimingInfo load_timing_info;
- // Only pass true in as |is_reused|, as in general, HttpStream types should
- // have stricter concepts of reuse than socket pools.
- EXPECT_TRUE(handle.GetLoadTimingInfo(true, &load_timing_info));
- EXPECT_TRUE(load_timing_info.socket_reused);
- EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
- ExpectConnectTimingHasNoTimes(load_timing_info.connect_timing);
- ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
- }
- void TestLoadTimingInfoConnectedNotReused(const ClientSocketHandle& handle) {
- EXPECT_FALSE(handle.is_reused());
- LoadTimingInfo load_timing_info;
- EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
- EXPECT_FALSE(load_timing_info.socket_reused);
- EXPECT_NE(NetLogSource::kInvalidId, load_timing_info.socket_log_id);
- ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
- CONNECT_TIMING_HAS_DNS_TIMES);
- ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
- TestLoadTimingInfoConnectedReused(handle);
- }
- void SetIPv4Address(IPEndPoint* address) {
- *address = IPEndPoint(ParseIP("1.1.1.1"), 80);
- }
- void SetIPv6Address(IPEndPoint* address) {
- *address = IPEndPoint(ParseIP("1:abcd::3:4:ff"), 80);
- }
- MockTransportClientSocketFactory::Rule::Rule(
- Type type,
- absl::optional<std::vector<IPEndPoint>> expected_addresses,
- Error connect_error)
- : type(type),
- expected_addresses(std::move(expected_addresses)),
- connect_error(connect_error) {}
- MockTransportClientSocketFactory::Rule::~Rule() = default;
- MockTransportClientSocketFactory::Rule::Rule(const Rule&) = default;
- MockTransportClientSocketFactory::Rule&
- MockTransportClientSocketFactory::Rule::operator=(const Rule&) = default;
- MockTransportClientSocketFactory::MockTransportClientSocketFactory(
- NetLog* net_log)
- : net_log_(net_log),
- delay_(base::Milliseconds(ClientSocketPool::kMaxConnectRetryIntervalMs)) {
- }
- MockTransportClientSocketFactory::~MockTransportClientSocketFactory() = default;
- std::unique_ptr<DatagramClientSocket>
- MockTransportClientSocketFactory::CreateDatagramClientSocket(
- DatagramSocket::BindType bind_type,
- NetLog* net_log,
- const NetLogSource& source) {
- NOTREACHED();
- return nullptr;
- }
- std::unique_ptr<TransportClientSocket>
- MockTransportClientSocketFactory::CreateTransportClientSocket(
- const AddressList& addresses,
- std::unique_ptr<SocketPerformanceWatcher> /* socket_performance_watcher */,
- NetworkQualityEstimator* /* network_quality_estimator */,
- NetLog* /* net_log */,
- const NetLogSource& /* source */) {
- allocation_count_++;
- Rule rule(client_socket_type_);
- if (!rules_.empty()) {
- rule = rules_.front();
- rules_ = rules_.subspan(1);
- }
- if (rule.expected_addresses) {
- EXPECT_EQ(addresses.endpoints(), *rule.expected_addresses);
- }
- switch (rule.type) {
- case Type::kUnexpected:
- ADD_FAILURE() << "Unexpectedly created socket to "
- << addresses.endpoints().front();
- return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
- case Type::kSynchronous:
- return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
- case Type::kFailing:
- return std::make_unique<MockFailingClientSocket>(
- addresses, rule.connect_error, net_log_);
- case Type::kPending:
- return MockTriggerableClientSocket::MakeMockPendingClientSocket(
- addresses, OK, net_log_);
- case Type::kPendingFailing:
- return MockTriggerableClientSocket::MakeMockPendingClientSocket(
- addresses, rule.connect_error, net_log_);
- case Type::kDelayed:
- return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
- addresses, OK, delay_, net_log_);
- case Type::kDelayedFailing:
- return MockTriggerableClientSocket::MakeMockDelayedClientSocket(
- addresses, rule.connect_error, delay_, net_log_);
- case Type::kStalled:
- return MockTriggerableClientSocket::MakeMockStalledClientSocket(addresses,
- net_log_);
- case Type::kTriggerable: {
- auto rv = std::make_unique<MockTriggerableClientSocket>(addresses, OK,
- net_log_);
- triggerable_sockets_.push(rv->GetConnectCallback());
- // run_loop_quit_closure_ behaves like a condition variable. It will
- // wake up WaitForTriggerableSocketCreation() if it is sleeping. We
- // don't need to worry about atomicity because this code is
- // single-threaded.
- if (!run_loop_quit_closure_.is_null())
- std::move(run_loop_quit_closure_).Run();
- return std::move(rv);
- }
- default:
- NOTREACHED();
- return std::make_unique<MockConnectClientSocket>(addresses, net_log_);
- }
- }
- std::unique_ptr<SSLClientSocket>
- MockTransportClientSocketFactory::CreateSSLClientSocket(
- SSLClientContext* context,
- std::unique_ptr<StreamSocket> stream_socket,
- const HostPortPair& host_and_port,
- const SSLConfig& ssl_config) {
- NOTIMPLEMENTED();
- return nullptr;
- }
- void MockTransportClientSocketFactory::SetRules(base::span<const Rule> rules) {
- DCHECK(rules_.empty());
- client_socket_type_ = Type::kUnexpected;
- rules_ = rules;
- }
- base::OnceClosure
- MockTransportClientSocketFactory::WaitForTriggerableSocketCreation() {
- while (triggerable_sockets_.empty()) {
- base::RunLoop run_loop;
- run_loop_quit_closure_ = run_loop.QuitClosure();
- run_loop.Run();
- run_loop_quit_closure_.Reset();
- }
- base::OnceClosure trigger = std::move(triggerable_sockets_.front());
- triggerable_sockets_.pop();
- return trigger;
- }
- } // namespace net
|