hid_connection_impl_unittest.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. // Copyright 2019 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "services/device/hid/hid_connection_impl.h"
  5. #include "base/bind.h"
  6. #include "base/callback_helpers.h"
  7. #include "base/memory/raw_ptr.h"
  8. #include "base/memory/ref_counted_memory.h"
  9. #include "build/build_config.h"
  10. #include "mojo/public/cpp/bindings/pending_receiver.h"
  11. #include "mojo/public/cpp/bindings/receiver.h"
  12. #include "mojo/public/cpp/bindings/self_owned_receiver.h"
  13. #include "services/device/device_service_test_base.h"
  14. #include "testing/gtest/include/gtest/gtest.h"
  15. namespace device {
  16. namespace {
  17. #if BUILDFLAG(IS_MAC)
  18. const uint64_t kTestDeviceId = 123;
  19. #elif BUILDFLAG(IS_WIN)
  20. const wchar_t* kTestDeviceId = L"123";
  21. #else
  22. const char* kTestDeviceId = "123";
  23. #endif
  // Report ID carried by every report sent to or received from the test
  // device.
  constexpr uint8_t kTestReportId{0x42};

  // Upper bound, in bytes, for both input and output reports of the test
  // device. Feature reports are not used in this test.
  constexpr uint64_t kMaxReportSizeBytes{10};
  29. // A fake HidConnection implementation that allows the test to simulate an
  30. // input report.
  31. class FakeHidConnection : public HidConnection {
  32. public:
  33. explicit FakeHidConnection(scoped_refptr<HidDeviceInfo> device)
  34. : HidConnection(device,
  35. /*allow_protected_reports=*/false,
  36. /*allow_fido_reports=*/false) {}
  37. FakeHidConnection(const FakeHidConnection&) = delete;
  38. FakeHidConnection& operator=(const FakeHidConnection&) = delete;
  39. // HidConnection implementation.
  40. void PlatformClose() override {}
  41. void PlatformWrite(scoped_refptr<base::RefCountedBytes> buffer,
  42. WriteCallback callback) override {
  43. std::move(callback).Run(true);
  44. }
  45. void PlatformGetFeatureReport(uint8_t report_id,
  46. ReadCallback callback) override {
  47. NOTIMPLEMENTED();
  48. }
  49. void PlatformSendFeatureReport(scoped_refptr<base::RefCountedBytes> buffer,
  50. WriteCallback callback) override {
  51. NOTIMPLEMENTED();
  52. }
  53. void SimulateInputReport(scoped_refptr<base::RefCountedBytes> buffer) {
  54. ProcessInputReport(buffer, buffer->size());
  55. }
  56. private:
  57. ~FakeHidConnection() override = default;
  58. };
  59. // A test implementation of HidConnectionClient that signals once an input
  60. // report has been received. The contents of the input report are saved.
  61. class TestHidConnectionClient : public mojom::HidConnectionClient {
  62. public:
  63. TestHidConnectionClient() = default;
  64. TestHidConnectionClient(const TestHidConnectionClient&) = delete;
  65. TestHidConnectionClient& operator=(const TestHidConnectionClient&) = delete;
  66. ~TestHidConnectionClient() override = default;
  67. void Bind(mojo::PendingReceiver<mojom::HidConnectionClient> receiver) {
  68. receiver_.Bind(std::move(receiver));
  69. }
  70. // mojom::HidConnectionClient implementation.
  71. void OnInputReport(uint8_t report_id,
  72. const std::vector<uint8_t>& buffer) override {
  73. report_id_ = report_id;
  74. buffer_ = buffer;
  75. run_loop_.Quit();
  76. }
  77. void WaitForInputReport() { run_loop_.Run(); }
  78. uint8_t report_id() { return report_id_; }
  79. const std::vector<uint8_t>& buffer() { return buffer_; }
  80. private:
  81. base::RunLoop run_loop_;
  82. mojo::Receiver<mojom::HidConnectionClient> receiver_{this};
  83. uint8_t report_id_ = 0;
  84. std::vector<uint8_t> buffer_;
  85. };
  86. // A utility for capturing the state returned by mojom::HidConnection I/O
  87. // callbacks.
  88. class TestIoCallback {
  89. public:
  90. TestIoCallback() = default;
  91. TestIoCallback(const TestIoCallback&) = delete;
  92. TestIoCallback& operator=(const TestIoCallback&) = delete;
  93. ~TestIoCallback() = default;
  94. void SetReadResult(bool result,
  95. uint8_t report_id,
  96. const absl::optional<std::vector<uint8_t>>& buffer) {
  97. result_ = result;
  98. report_id_ = report_id;
  99. has_buffer_ = buffer.has_value();
  100. if (has_buffer_)
  101. buffer_ = *buffer;
  102. run_loop_.Quit();
  103. }
  104. void SetWriteResult(bool result) {
  105. result_ = result;
  106. run_loop_.Quit();
  107. }
  108. bool WaitForResult() {
  109. run_loop_.Run();
  110. return result_;
  111. }
  112. mojom::HidConnection::ReadCallback GetReadCallback() {
  113. return base::BindOnce(&TestIoCallback::SetReadResult,
  114. base::Unretained(this));
  115. }
  116. mojom::HidConnection::WriteCallback GetWriteCallback() {
  117. return base::BindOnce(&TestIoCallback::SetWriteResult,
  118. base::Unretained(this));
  119. }
  120. uint8_t report_id() { return report_id_; }
  121. bool has_buffer() { return has_buffer_; }
  122. const std::vector<uint8_t>& buffer() { return buffer_; }
  123. private:
  124. base::RunLoop run_loop_;
  125. bool result_ = false;
  126. uint8_t report_id_ = 0;
  127. bool has_buffer_ = false;
  128. std::vector<uint8_t> buffer_;
  129. };
  130. } // namespace
  131. class HidConnectionImplTest : public DeviceServiceTestBase {
  132. public:
  133. HidConnectionImplTest() = default;
  134. HidConnectionImplTest(HidConnectionImplTest&) = delete;
  135. HidConnectionImplTest& operator=(HidConnectionImplTest&) = delete;
  136. protected:
  137. void SetUp() override {
  138. DeviceServiceTestBase::SetUp();
  139. base::RunLoop().RunUntilIdle();
  140. }
  141. void CreateHidConnection(bool with_connection_client) {
  142. mojo::PendingRemote<mojom::HidConnectionClient> hid_connection_client;
  143. if (with_connection_client) {
  144. connection_client_ = std::make_unique<TestHidConnectionClient>();
  145. connection_client_->Bind(
  146. hid_connection_client.InitWithNewPipeAndPassReceiver());
  147. }
  148. fake_connection_ = new FakeHidConnection(CreateTestDevice());
  149. hid_connection_impl_ = new HidConnectionImpl(
  150. fake_connection_, hid_connection_.InitWithNewPipeAndPassReceiver(),
  151. std::move(hid_connection_client),
  152. /*watcher=*/mojo::NullRemote());
  153. }
  154. scoped_refptr<HidDeviceInfo> CreateTestDevice() {
  155. auto collection = mojom::HidCollectionInfo::New();
  156. collection->usage = mojom::HidUsageAndPage::New(0, 0);
  157. collection->report_ids.push_back(kTestReportId);
  158. return base::MakeRefCounted<HidDeviceInfo>(
  159. kTestDeviceId, /*physical_device_id=*/"1", "interface id",
  160. /*vendor_id=*/0x1234, /*product_id=*/0xabcd, "product name",
  161. "serial number", mojom::HidBusType::kHIDBusTypeUSB,
  162. std::move(collection), kMaxReportSizeBytes, kMaxReportSizeBytes,
  163. /*max_feature_report_size=*/0);
  164. }
  165. std::vector<uint8_t> CreateTestReportBuffer(uint8_t report_id, size_t size) {
  166. std::vector<uint8_t> buffer(size);
  167. buffer[0] = report_id;
  168. for (size_t i = 1; i < size; ++i)
  169. buffer[i] = i;
  170. return buffer;
  171. }
  172. mojo::PendingRemote<mojom::HidConnection> hid_connection_;
  173. raw_ptr<HidConnectionImpl>
  174. hid_connection_impl_; // Owned by |hid_connection_|.
  175. scoped_refptr<FakeHidConnection> fake_connection_;
  176. std::unique_ptr<TestHidConnectionClient> connection_client_;
  177. };
  178. TEST_F(HidConnectionImplTest, ReadWrite) {
  179. CreateHidConnection(/*with_connection_client=*/false);
  180. const size_t kTestBufferSize = kMaxReportSizeBytes;
  181. std::vector<uint8_t> buffer_vec =
  182. CreateTestReportBuffer(kTestReportId, kTestBufferSize);
  183. // Simulate an output report (host to device).
  184. TestIoCallback write_callback;
  185. hid_connection_impl_->Write(kTestReportId, buffer_vec,
  186. write_callback.GetWriteCallback());
  187. ASSERT_TRUE(write_callback.WaitForResult());
  188. // Simulate an input report (device to host).
  189. auto buffer = base::MakeRefCounted<base::RefCountedBytes>(buffer_vec);
  190. ASSERT_EQ(buffer->size(), kTestBufferSize);
  191. fake_connection_->SimulateInputReport(buffer);
  192. // Simulate reading the input report.
  193. TestIoCallback read_callback;
  194. hid_connection_impl_->Read(read_callback.GetReadCallback());
  195. ASSERT_TRUE(read_callback.WaitForResult());
  196. EXPECT_EQ(read_callback.report_id(), kTestReportId);
  197. ASSERT_TRUE(read_callback.has_buffer());
  198. const auto& read_buffer = read_callback.buffer();
  199. ASSERT_EQ(read_buffer.size(), kTestBufferSize - 1);
  200. for (size_t i = 1; i < kTestBufferSize; ++i) {
  201. EXPECT_EQ(read_buffer[i - 1], buffer_vec[i])
  202. << "Mismatch at index " << i << ".";
  203. }
  204. }
  205. TEST_F(HidConnectionImplTest, ReadWriteWithConnectionClient) {
  206. CreateHidConnection(/*with_connection_client=*/true);
  207. const size_t kTestBufferSize = kMaxReportSizeBytes;
  208. std::vector<uint8_t> buffer_vec =
  209. CreateTestReportBuffer(kTestReportId, kTestBufferSize);
  210. // Simulate an output report (host to device).
  211. TestIoCallback write_callback;
  212. hid_connection_impl_->Write(kTestReportId, buffer_vec,
  213. write_callback.GetWriteCallback());
  214. ASSERT_TRUE(write_callback.WaitForResult());
  215. // Simulate an input report (device to host).
  216. auto buffer = base::MakeRefCounted<base::RefCountedBytes>(buffer_vec);
  217. ASSERT_EQ(buffer->size(), kTestBufferSize);
  218. fake_connection_->SimulateInputReport(buffer);
  219. connection_client_->WaitForInputReport();
  220. // The connection client should have been notified.
  221. EXPECT_EQ(connection_client_->report_id(), kTestReportId);
  222. const std::vector<uint8_t>& in_buffer = connection_client_->buffer();
  223. ASSERT_EQ(in_buffer.size(), kTestBufferSize - 1);
  224. for (size_t i = 1; i < kTestBufferSize; ++i) {
  225. EXPECT_EQ(in_buffer[i - 1], buffer_vec[i])
  226. << "Mismatch at index " << i << ".";
  227. }
  228. }
  229. TEST_F(HidConnectionImplTest, DestroyWithPendingInputReport) {
  230. CreateHidConnection(/*with_connection_client=*/false);
  231. const size_t kTestBufferSize = kMaxReportSizeBytes;
  232. std::vector<uint8_t> buffer_vec =
  233. CreateTestReportBuffer(kTestReportId, kTestBufferSize);
  234. // Simulate an input report (device to host).
  235. auto buffer = base::MakeRefCounted<base::RefCountedBytes>(buffer_vec);
  236. ASSERT_EQ(buffer->size(), kTestBufferSize);
  237. fake_connection_->SimulateInputReport(buffer);
  238. // Destroy the connection without reading the report.
  239. hid_connection_.reset();
  240. }
  241. TEST_F(HidConnectionImplTest, DestroyWithPendingRead) {
  242. CreateHidConnection(/*with_connection_client=*/false);
  243. // Simulate reading an input report.
  244. TestIoCallback read_callback;
  245. hid_connection_impl_->Read(read_callback.GetReadCallback());
  246. // Destroy the connection without receiving an input report.
  247. hid_connection_.reset();
  248. }
  249. TEST_F(HidConnectionImplTest, WatcherClosedWhenHidConnectionClosed) {
  250. mojo::PendingRemote<mojom::HidConnectionWatcher> watcher;
  251. auto watcher_receiver = mojo::MakeSelfOwnedReceiver(
  252. std::make_unique<mojom::HidConnectionWatcher>(),
  253. watcher.InitWithNewPipeAndPassReceiver());
  254. mojo::Remote<mojom::HidConnection> hid_connection;
  255. HidConnectionImpl::Create(
  256. base::MakeRefCounted<FakeHidConnection>(CreateTestDevice()),
  257. hid_connection.BindNewPipeAndPassReceiver(),
  258. /*connection_client=*/mojo::NullRemote(), std::move(watcher));
  259. // To start with both the HID connection and the connection watcher connection
  260. // should remain open.
  261. hid_connection.FlushForTesting();
  262. EXPECT_TRUE(hid_connection.is_connected());
  263. watcher_receiver->FlushForTesting();
  264. EXPECT_TRUE(watcher_receiver);
  265. // When the HID connection is closed the watcher connection should be closed.
  266. hid_connection.reset();
  267. watcher_receiver->FlushForTesting();
  268. EXPECT_FALSE(watcher_receiver);
  269. }
  270. TEST_F(HidConnectionImplTest, HidConnectionClosedWhenWatcherClosed) {
  271. mojo::PendingRemote<mojom::HidConnectionWatcher> watcher;
  272. auto watcher_receiver = mojo::MakeSelfOwnedReceiver(
  273. std::make_unique<mojom::HidConnectionWatcher>(),
  274. watcher.InitWithNewPipeAndPassReceiver());
  275. mojo::Remote<mojom::HidConnection> hid_connection;
  276. HidConnectionImpl::Create(
  277. base::MakeRefCounted<FakeHidConnection>(CreateTestDevice()),
  278. hid_connection.BindNewPipeAndPassReceiver(),
  279. /*connection_client=*/mojo::NullRemote(), std::move(watcher));
  280. // To start with both the HID connection and the connection watcher connection
  281. // should remain open.
  282. hid_connection.FlushForTesting();
  283. EXPECT_TRUE(hid_connection.is_connected());
  284. watcher_receiver->FlushForTesting();
  285. EXPECT_TRUE(watcher_receiver);
  286. // When the watcher connection is closed, for safety, the HID connection
  287. // should also be closed.
  288. watcher_receiver->Close();
  289. hid_connection.FlushForTesting();
  290. EXPECT_FALSE(hid_connection.is_connected());
  291. }
  292. } // namespace device