// Copyright 2015 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_stream_create_test_base.h" #include "base/memory/raw_ptr.h" #include #include "base/callback.h" #include "net/base/ip_endpoint.h" #include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/log/net_log_with_source.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/websockets/websocket_basic_handshake_stream.h" #include "net/websockets/websocket_handshake_request_info.h" #include "net/websockets/websocket_handshake_response_info.h" #include "net/websockets/websocket_stream.h" #include "url/gurl.h" #include "url/origin.h" namespace net { using HeaderKeyValuePair = WebSocketStreamCreateTestBase::HeaderKeyValuePair; class WebSocketStreamCreateTestBase::TestConnectDelegate : public WebSocketStream::ConnectDelegate { public: TestConnectDelegate(WebSocketStreamCreateTestBase* owner, base::OnceClosure done_callback) : owner_(owner), done_callback_(std::move(done_callback)) {} TestConnectDelegate(const TestConnectDelegate&) = delete; TestConnectDelegate& operator=(const TestConnectDelegate&) = delete; void OnCreateRequest(URLRequest* request) override { owner_->url_request_ = request; } void OnSuccess( std::unique_ptr stream, std::unique_ptr response) override { if (owner_->response_info_) ADD_FAILURE(); owner_->response_info_ = std::move(response); stream.swap(owner_->stream_); std::move(done_callback_).Run(); } void OnFailure(const std::string& message, int net_error, absl::optional response_code) override { owner_->has_failed_ = true; owner_->failure_message_ = message; owner_->failure_response_code_ = response_code.value_or(-1); std::move(done_callback_).Run(); } void OnStartOpeningHandshake( std::unique_ptr request) override { // Can be called multiple times (in the case of HTTP auth). Last call // wins. owner_->request_info_ = std::move(request); } void OnSSLCertificateError( std::unique_ptr ssl_error_callbacks, int net_error, const SSLInfo& ssl_info, bool fatal) override { owner_->ssl_error_callbacks_ = std::move(ssl_error_callbacks); owner_->ssl_info_ = ssl_info; owner_->ssl_fatal_ = fatal; } int OnAuthRequired(const AuthChallengeInfo& auth_info, scoped_refptr response_headers, const IPEndPoint& remote_endpoint, base::OnceCallback callback, absl::optional* credentials) override { owner_->run_loop_waiting_for_on_auth_required_.Quit(); owner_->auth_challenge_info_ = auth_info; *credentials = owner_->auth_credentials_; owner_->on_auth_required_callback_ = std::move(callback); return owner_->on_auth_required_rv_; } private: raw_ptr owner_; base::OnceClosure done_callback_; }; WebSocketStreamCreateTestBase::WebSocketStreamCreateTestBase() = default; WebSocketStreamCreateTestBase::~WebSocketStreamCreateTestBase() = default; void WebSocketStreamCreateTestBase::CreateAndConnectStream( const GURL& socket_url, const std::vector& sub_protocols, const url::Origin& origin, const SiteForCookies& site_for_cookies, const IsolationInfo& isolation_info, const HttpRequestHeaders& additional_headers, std::unique_ptr timer) { auto connect_delegate = std::make_unique( this, connect_run_loop_.QuitClosure()); auto api_delegate = std::make_unique(); stream_request_ = WebSocketStream::CreateAndConnectStreamForTesting( socket_url, sub_protocols, origin, site_for_cookies, isolation_info, additional_headers, url_request_context_host_.GetURLRequestContext(), NetLogWithSource(), TRAFFIC_ANNOTATION_FOR_TESTS, std::move(connect_delegate), timer ? std::move(timer) : std::make_unique(), std::move(api_delegate)); } std::vector WebSocketStreamCreateTestBase::RequestHeadersToVector( const HttpRequestHeaders& headers) { HttpRequestHeaders::Iterator it(headers); std::vector result; while (it.GetNext()) result.emplace_back(it.name(), it.value()); return result; } std::vector WebSocketStreamCreateTestBase::ResponseHeadersToVector( const HttpResponseHeaders& headers) { size_t iter = 0; std::string name, value; std::vector result; while (headers.EnumerateHeaderLines(&iter, &name, &value)) result.emplace_back(name, value); return result; } void WebSocketStreamCreateTestBase::WaitUntilConnectDone() { connect_run_loop_.Run(); } void WebSocketStreamCreateTestBase::WaitUntilOnAuthRequired() { run_loop_waiting_for_on_auth_required_.Run(); } std::vector WebSocketStreamCreateTestBase::NoSubProtocols() { return std::vector(); } } // namespace net