transport_client_socket_test_util.cc 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. // Copyright 2022 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include <memory>
  5. #include <string>
  6. #include "net/socket/transport_client_socket_test_util.h"
  7. #include "base/memory/ref_counted.h"
  8. #include "net/base/io_buffer.h"
  9. #include "net/base/net_errors.h"
  10. #include "net/test/gtest_util.h"
  11. #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
  12. #include "testing/gtest/include/gtest/gtest.h"
  13. namespace net {
  14. void SendRequestAndResponse(StreamSocket* socket,
  15. StreamSocket* connected_socket) {
  16. // Send client request.
  17. const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
  18. int request_len = strlen(request_text);
  19. scoped_refptr<DrainableIOBuffer> request_buffer =
  20. base::MakeRefCounted<DrainableIOBuffer>(
  21. base::MakeRefCounted<IOBuffer>(request_len), request_len);
  22. memcpy(request_buffer->data(), request_text, request_len);
  23. int bytes_written = 0;
  24. while (request_buffer->BytesRemaining() > 0) {
  25. TestCompletionCallback write_callback;
  26. int write_result =
  27. socket->Write(request_buffer.get(), request_buffer->BytesRemaining(),
  28. write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  29. write_result = write_callback.GetResult(write_result);
  30. ASSERT_GT(write_result, 0);
  31. ASSERT_LE(bytes_written + write_result, request_len);
  32. request_buffer->DidConsume(write_result);
  33. bytes_written += write_result;
  34. }
  35. ASSERT_EQ(request_len, bytes_written);
  36. // Confirm that the server receives what client sent.
  37. std::string data_received =
  38. ReadDataOfExpectedLength(connected_socket, bytes_written);
  39. ASSERT_TRUE(connected_socket->IsConnectedAndIdle());
  40. ASSERT_EQ(request_text, data_received);
  41. // Write server response.
  42. SendServerResponse(connected_socket);
  43. }
  44. std::string ReadDataOfExpectedLength(StreamSocket* socket,
  45. int expected_bytes_read) {
  46. int bytes_read = 0;
  47. scoped_refptr<IOBufferWithSize> read_buffer =
  48. base::MakeRefCounted<IOBufferWithSize>(expected_bytes_read);
  49. while (bytes_read < expected_bytes_read) {
  50. TestCompletionCallback read_callback;
  51. int rv = socket->Read(read_buffer.get(), expected_bytes_read - bytes_read,
  52. read_callback.callback());
  53. EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  54. rv = read_callback.GetResult(rv);
  55. EXPECT_GE(rv, 0);
  56. bytes_read += rv;
  57. }
  58. EXPECT_EQ(expected_bytes_read, bytes_read);
  59. return std::string(read_buffer->data(), bytes_read);
  60. }
  61. void SendServerResponse(StreamSocket* socket) {
  62. const char kServerReply[] = "HTTP/1.1 404 Not Found";
  63. int reply_len = strlen(kServerReply);
  64. scoped_refptr<DrainableIOBuffer> write_buffer =
  65. base::MakeRefCounted<DrainableIOBuffer>(
  66. base::MakeRefCounted<IOBuffer>(reply_len), reply_len);
  67. memcpy(write_buffer->data(), kServerReply, reply_len);
  68. int bytes_written = 0;
  69. while (write_buffer->BytesRemaining() > 0) {
  70. TestCompletionCallback write_callback;
  71. int write_result =
  72. socket->Write(write_buffer.get(), write_buffer->BytesRemaining(),
  73. write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
  74. write_result = write_callback.GetResult(write_result);
  75. ASSERT_GE(write_result, 0);
  76. ASSERT_LE(bytes_written + write_result, reply_len);
  77. write_buffer->DidConsume(write_result);
  78. bytes_written += write_result;
  79. }
  80. }
  81. int DrainStreamSocket(StreamSocket* socket,
  82. IOBuffer* buf,
  83. uint32_t buf_len,
  84. uint32_t bytes_to_read,
  85. TestCompletionCallback* callback) {
  86. int rv = OK;
  87. uint32_t bytes_read = 0;
  88. while (bytes_read < bytes_to_read) {
  89. rv = socket->Read(buf, buf_len, callback->callback());
  90. EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
  91. rv = callback->GetResult(rv);
  92. EXPECT_GT(rv, 0);
  93. bytes_read += rv;
  94. }
  95. return static_cast<int>(bytes_read);
  96. }
  97. } // namespace net