// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include #include #include "base/bind.h" #include "base/callback.h" #include "base/numerics/safe_conversions.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/filter/filter_source_stream.h" #include "net/filter/mock_source_stream.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { namespace { const size_t kDefaultBufferSize = 4096; const size_t kSmallBufferSize = 1; class TestFilterSourceStreamBase : public FilterSourceStream { public: explicit TestFilterSourceStreamBase(std::unique_ptr upstream) : FilterSourceStream(SourceStream::TYPE_NONE, std::move(upstream)) {} TestFilterSourceStreamBase(const TestFilterSourceStreamBase&) = delete; TestFilterSourceStreamBase& operator=(const TestFilterSourceStreamBase&) = delete; ~TestFilterSourceStreamBase() override { DCHECK(buffer_.empty()); } std::string GetTypeAsString() const override { return type_string_; } void set_type_string(const std::string& type_string) { type_string_ = type_string; } protected: // Writes contents of |buffer_| to |output_buffer| and returns the number of // bytes written or an error code. Additionally removes consumed data from // |buffer_|. size_t WriteBufferToOutput(IOBuffer* output_buffer, size_t output_buffer_size) { size_t bytes_to_filter = std::min(buffer_.length(), output_buffer_size); memcpy(output_buffer->data(), buffer_.data(), bytes_to_filter); buffer_.erase(0, bytes_to_filter); return bytes_to_filter; } // Buffer used by subclasses to hold data that is yet to be passed to the // caller. std::string buffer_; private: std::string type_string_; }; // A FilterSourceStream that needs all input data before it can return non-zero // bytes read. class NeedsAllInputFilterSourceStream : public TestFilterSourceStreamBase { public: NeedsAllInputFilterSourceStream(std::unique_ptr upstream, size_t expected_input_bytes) : TestFilterSourceStreamBase(std::move(upstream)), expected_input_bytes_(expected_input_bytes) {} NeedsAllInputFilterSourceStream(const NeedsAllInputFilterSourceStream&) = delete; NeedsAllInputFilterSourceStream& operator=( const NeedsAllInputFilterSourceStream&) = delete; base::expected FilterData(IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool upstream_eof_reached) override { buffer_.append(input_buffer->data(), input_buffer_size); EXPECT_GE(expected_input_bytes_, input_buffer_size); expected_input_bytes_ -= input_buffer_size; *consumed_bytes = input_buffer_size; if (!upstream_eof_reached) { // Keep returning 0 bytes read until all input has been consumed. return 0; } EXPECT_EQ(0u, expected_input_bytes_); return WriteBufferToOutput(output_buffer, output_buffer_size); } private: // Expected remaining bytes to be received from |upstream|. size_t expected_input_bytes_; }; // A FilterSourceStream that repeat every input byte by |multiplier| amount of // times. class MultiplySourceStream : public TestFilterSourceStreamBase { public: MultiplySourceStream(std::unique_ptr upstream, int multiplier) : TestFilterSourceStreamBase(std::move(upstream)), multiplier_(multiplier) {} MultiplySourceStream(const MultiplySourceStream&) = delete; MultiplySourceStream& operator=(const MultiplySourceStream&) = delete; base::expected FilterData( IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool /*upstream_eof_reached*/) override { for (size_t i = 0; i < input_buffer_size; i++) { for (int j = 0; j < multiplier_; j++) buffer_.append(input_buffer->data() + i, 1); } *consumed_bytes = input_buffer_size; return WriteBufferToOutput(output_buffer, output_buffer_size); } private: int multiplier_; }; // A FilterSourceStream passes through data unchanged to consumer. class PassThroughFilterSourceStream : public TestFilterSourceStreamBase { public: explicit PassThroughFilterSourceStream(std::unique_ptr upstream) : TestFilterSourceStreamBase(std::move(upstream)) {} PassThroughFilterSourceStream(const PassThroughFilterSourceStream&) = delete; PassThroughFilterSourceStream& operator=( const PassThroughFilterSourceStream&) = delete; base::expected FilterData( IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool /*upstream_eof_reached*/) override { buffer_.append(input_buffer->data(), input_buffer_size); *consumed_bytes = input_buffer_size; return WriteBufferToOutput(output_buffer, output_buffer_size); } }; // A FilterSourceStream passes throttle input data such that it returns them to // caller only one bytes at a time. class ThrottleSourceStream : public TestFilterSourceStreamBase { public: explicit ThrottleSourceStream(std::unique_ptr upstream) : TestFilterSourceStreamBase(std::move(upstream)) {} ThrottleSourceStream(const ThrottleSourceStream&) = delete; ThrottleSourceStream& operator=(const ThrottleSourceStream&) = delete; base::expected FilterData( IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool /*upstream_eof_reached*/) override { buffer_.append(input_buffer->data(), input_buffer_size); *consumed_bytes = input_buffer_size; size_t bytes_to_read = std::min(size_t{1}, buffer_.size()); memcpy(output_buffer->data(), buffer_.data(), bytes_to_read); buffer_.erase(0, bytes_to_read); return bytes_to_read; } }; // A FilterSourceStream that consumes all input data but return no output. class NoOutputSourceStream : public TestFilterSourceStreamBase { public: NoOutputSourceStream(std::unique_ptr upstream, size_t expected_input_size) : TestFilterSourceStreamBase(std::move(upstream)), expected_input_size_(expected_input_size) {} NoOutputSourceStream(const NoOutputSourceStream&) = delete; NoOutputSourceStream& operator=(const NoOutputSourceStream&) = delete; base::expected FilterData( IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool /*upstream_eof_reached*/) override { EXPECT_GE(expected_input_size_, input_buffer_size); expected_input_size_ -= input_buffer_size; *consumed_bytes = input_buffer_size; consumed_all_input_ = (expected_input_size_ == 0); return 0; } bool consumed_all_input() const { return consumed_all_input_; } private: // Expected remaining bytes to be received from |upstream|. size_t expected_input_size_; bool consumed_all_input_ = false; }; // A FilterSourceStream return an error code in FilterData(). class ErrorFilterSourceStream : public FilterSourceStream { public: explicit ErrorFilterSourceStream(std::unique_ptr upstream) : FilterSourceStream(SourceStream::TYPE_NONE, std::move(upstream)) {} ErrorFilterSourceStream(const ErrorFilterSourceStream&) = delete; ErrorFilterSourceStream& operator=(const ErrorFilterSourceStream&) = delete; base::expected FilterData( IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool /*upstream_eof_reached*/) override { return base::unexpected(ERR_CONTENT_DECODING_FAILED); } std::string GetTypeAsString() const override { return ""; } }; } // namespace class FilterSourceStreamTest : public ::testing::TestWithParam { protected: // If MockSourceStream::Mode is ASYNC, completes |num_reads| from // |mock_stream| and wait for |callback| to complete. If Mode is not ASYNC, // does nothing and returns |previous_result|. int CompleteReadIfAsync(int previous_result, TestCompletionCallback* callback, MockSourceStream* mock_stream, size_t num_reads) { if (GetParam() == MockSourceStream::ASYNC) { EXPECT_EQ(ERR_IO_PENDING, previous_result); while (num_reads > 0) { mock_stream->CompleteNextRead(); num_reads--; } return callback->WaitForResult(); } return previous_result; } }; INSTANTIATE_TEST_SUITE_P(FilterSourceStreamTests, FilterSourceStreamTest, ::testing::Values(MockSourceStream::SYNC, MockSourceStream::ASYNC)); // Tests that a FilterSourceStream subclass (NeedsAllInputFilterSourceStream) // can return 0 bytes for FilterData()s when it has not consumed EOF from the // upstream. In this case, FilterSourceStream should continue reading from // upstream to complete filtering. TEST_P(FilterSourceStreamTest, FilterDataReturnNoBytesExceptLast) { auto source = std::make_unique(); std::string input("hello, world!"); size_t read_size = 2; size_t num_reads = 0; // Add a sequence of small reads. for (size_t offset = 0; offset < input.length(); offset += read_size) { source->AddReadResult(input.data() + offset, std::min(read_size, input.length() - offset), OK, GetParam()); num_reads++; } source->AddReadResult(input.data(), 0, OK, GetParam()); // EOF num_reads++; MockSourceStream* mock_stream = source.get(); NeedsAllInputFilterSourceStream stream(std::move(source), input.length()); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); if (rv == ERR_IO_PENDING) rv = CompleteReadIfAsync(rv, &callback, mock_stream, num_reads); if (rv == OK) break; ASSERT_GT(rv, OK); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); } // Tests that FilterData() returns 0 byte read because the upstream gives an // EOF. TEST_P(FilterSourceStreamTest, FilterDataReturnNoByte) { auto source = std::make_unique(); std::string input; source->AddReadResult(input.data(), 0, OK, GetParam()); MockSourceStream* mock_stream = source.get(); PassThroughFilterSourceStream stream(std::move(source)); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); rv = CompleteReadIfAsync(rv, &callback, mock_stream, 1); EXPECT_EQ(OK, rv); } // Tests that FilterData() returns 0 byte filtered even though the upstream // produces data. TEST_P(FilterSourceStreamTest, FilterDataOutputNoData) { auto source = std::make_unique(); std::string input = "hello, world!"; size_t read_size = 2; size_t num_reads = 0; // Add a sequence of small reads. for (size_t offset = 0; offset < input.length(); offset += read_size) { source->AddReadResult(input.data() + offset, std::min(read_size, input.length() - offset), OK, GetParam()); num_reads++; } // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); num_reads++; MockSourceStream* mock_stream = source.get(); NoOutputSourceStream stream(std::move(source), input.length()); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); rv = CompleteReadIfAsync(rv, &callback, mock_stream, num_reads); EXPECT_EQ(OK, rv); EXPECT_TRUE(stream.consumed_all_input()); } // Tests that FilterData() returns non-zero bytes because the upstream // returns data. TEST_P(FilterSourceStreamTest, FilterDataReturnData) { auto source = std::make_unique(); std::string input = "hello, world!"; size_t read_size = 2; // Add a sequence of small reads. for (size_t offset = 0; offset < input.length(); offset += read_size) { source->AddReadResult(input.data() + offset, std::min(read_size, input.length() - offset), OK, GetParam()); } // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); MockSourceStream* mock_stream = source.get(); PassThroughFilterSourceStream stream(std::move(source)); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); if (rv == OK) break; ASSERT_GE(static_cast(read_size), rv); ASSERT_GT(rv, OK); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); } // Tests that FilterData() returns more data than what it consumed. TEST_P(FilterSourceStreamTest, FilterDataReturnMoreData) { auto source = std::make_unique(); std::string input = "hello, world!"; size_t read_size = 2; // Add a sequence of small reads. for (size_t offset = 0; offset < input.length(); offset += read_size) { source->AddReadResult(input.data() + offset, std::min(read_size, input.length() - offset), OK, GetParam()); } // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); MockSourceStream* mock_stream = source.get(); int multiplier = 2; MultiplySourceStream stream(std::move(source), multiplier); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); if (rv == OK) break; ASSERT_GE(static_cast(read_size) * multiplier, rv); ASSERT_GT(rv, OK); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ("hheelllloo,, wwoorrlldd!!", actual_output); } // Tests that FilterData() returns non-zero bytes and output buffer size is // smaller than the number of bytes read from the upstream. TEST_P(FilterSourceStreamTest, FilterDataOutputSpace) { auto source = std::make_unique(); std::string input = "hello, world!"; size_t read_size = 2; // Add a sequence of small reads. for (size_t offset = 0; offset < input.length(); offset += read_size) { source->AddReadResult(input.data() + offset, std::min(read_size, input.length() - offset), OK, GetParam()); } // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); // Use an extremely small buffer size, so FilterData will need more output // space. scoped_refptr output_buffer = base::MakeRefCounted(kSmallBufferSize); MockSourceStream* mock_stream = source.get(); PassThroughFilterSourceStream stream(std::move(source)); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); if (rv == ERR_IO_PENDING) rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); if (rv == OK) break; ASSERT_GT(rv, OK); ASSERT_GE(kSmallBufferSize, static_cast(rv)); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); } // Tests that FilterData() returns an error code, which is then surfaced as // the result of calling Read(). TEST_P(FilterSourceStreamTest, FilterDataReturnError) { auto source = std::make_unique(); std::string input; source->AddReadResult(input.data(), 0, OK, GetParam()); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); MockSourceStream* mock_stream = source.get(); ErrorFilterSourceStream stream(std::move(source)); TestCompletionCallback callback; int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); EXPECT_EQ(ERR_CONTENT_DECODING_FAILED, rv); // Reading from |stream| again should return the same error. rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); EXPECT_EQ(ERR_CONTENT_DECODING_FAILED, rv); } TEST_P(FilterSourceStreamTest, FilterChaining) { auto source = std::make_unique(); std::string input = "hello, world!"; source->AddReadResult(input.data(), input.length(), OK, GetParam()); source->AddReadResult(input.data(), 0, OK, GetParam()); // EOF MockSourceStream* mock_stream = source.get(); auto pass_through_source = std::make_unique(std::move(source)); pass_through_source->set_type_string("FIRST_PASS_THROUGH"); auto needs_all_input_source = std::make_unique( std::move(pass_through_source), input.length()); needs_all_input_source->set_type_string("NEEDS_ALL"); auto second_pass_through_source = std::make_unique( std::move(needs_all_input_source)); second_pass_through_source->set_type_string("SECOND_PASS_THROUGH"); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = second_pass_through_source->Read( output_buffer.get(), output_buffer->size(), callback.callback()); if (rv == ERR_IO_PENDING) rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/2); if (rv == OK) break; ASSERT_GT(rv, OK); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); // Type string (from left to right) should be the order of data flow. EXPECT_EQ("FIRST_PASS_THROUGH,NEEDS_ALL,SECOND_PASS_THROUGH", second_pass_through_source->Description()); } // Tests that FilterData() returns multiple times for a single MockStream // read, because there is not enough output space. TEST_P(FilterSourceStreamTest, OutputSpaceForOneRead) { auto source = std::make_unique(); std::string input = "hello, world!"; source->AddReadResult(input.data(), input.length(), OK, GetParam()); // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); // Use an extremely small buffer size (1 byte), so FilterData will need more // output space. scoped_refptr output_buffer = base::MakeRefCounted(kSmallBufferSize); MockSourceStream* mock_stream = source.get(); PassThroughFilterSourceStream stream(std::move(source)); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); if (rv == ERR_IO_PENDING) rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); if (rv == OK) break; ASSERT_GT(rv, OK); ASSERT_GE(kSmallBufferSize, static_cast(rv)); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); } // Tests that FilterData() returns multiple times for a single MockStream // read, because the filter returns one byte at a time. TEST_P(FilterSourceStreamTest, ThrottleSourceStream) { auto source = std::make_unique(); std::string input = "hello, world!"; source->AddReadResult(input.data(), input.length(), OK, GetParam()); // Add a 0 byte read to signal EOF. source->AddReadResult(input.data() + input.length(), 0, OK, GetParam()); scoped_refptr output_buffer = base::MakeRefCounted(kDefaultBufferSize); MockSourceStream* mock_stream = source.get(); ThrottleSourceStream stream(std::move(source)); TestCompletionCallback callback; std::string actual_output; while (true) { int rv = stream.Read(output_buffer.get(), output_buffer->size(), callback.callback()); if (rv == ERR_IO_PENDING) rv = CompleteReadIfAsync(rv, &callback, mock_stream, /*num_reads=*/1); if (rv == OK) break; ASSERT_GT(rv, OK); // ThrottleSourceStream returns 1 byte at a time. ASSERT_GE(1u, static_cast(rv)); actual_output.append(output_buffer->data(), rv); } EXPECT_EQ(input, actual_output); } } // namespace net