media_foundation_cdm_session_unittest.cc 9.2 KB


  1. // Copyright 2020 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "media/cdm/win/media_foundation_cdm_session.h"
  5. #include <wchar.h>
  6. #include "base/bind.h"
  7. #include "base/test/mock_callback.h"
  8. #include "base/test/task_environment.h"
  9. #include "media/base/mock_filters.h"
  10. #include "media/base/test_helpers.h"
  11. #include "media/base/win/mf_helpers.h"
  12. #include "media/base/win/mf_mocks.h"
  13. #include "testing/gmock/include/gmock/gmock.h"
  14. #include "testing/gtest/include/gtest/gtest.h"
  15. using ::testing::_;
  16. using ::testing::DoAll;
  17. using ::testing::InSequence;
  18. using ::testing::IsEmpty;
  19. using ::testing::NotNull;
  20. using ::testing::Return;
  21. using ::testing::SetArgPointee;
  22. using ::testing::StrictMock;
  23. using ::testing::WithoutArgs;
  24. namespace media {
  25. namespace {
  26. const double kExpirationMs = 123456789.0;
  27. const auto kExpirationTime = base::Time::FromJsTime(kExpirationMs);
  28. const char kTestUmaPrefix[] = "Media.EME.TestUmaPrefix.";
  // Returns the characters of |str| as a byte vector, one byte per character
  // and without any terminating NUL.
  std::vector<uint8_t> StringToVector(const std::string& str) {
    std::vector<uint8_t> bytes;
    bytes.assign(str.cbegin(), str.cend());
    return bytes;
  }
  32. } // namespace
  33. using Microsoft::WRL::ComPtr;
  34. class MediaFoundationCdmSessionTest : public testing::Test {
  35. public:
  36. MediaFoundationCdmSessionTest()
  37. : mf_cdm_(MakeComPtr<MockMFCdm>()),
  38. mf_cdm_session_(MakeComPtr<MockMFCdmSession>()),
  39. cdm_session_(
  40. kTestUmaPrefix,
  41. base::BindRepeating(&MockCdmClient::OnSessionMessage,
  42. base::Unretained(&cdm_client_)),
  43. base::BindRepeating(&MockCdmClient::OnSessionKeysChange,
  44. base::Unretained(&cdm_client_)),
  45. base::BindRepeating(&MockCdmClient::OnSessionExpirationUpdate,
  46. base::Unretained(&cdm_client_))) {}
  47. ~MediaFoundationCdmSessionTest() override = default;
  48. void Initialize() {
  49. COM_EXPECT_CALL(mf_cdm_,
  50. CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
  51. .WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_),
  52. SetComPointee<2>(mf_cdm_session_.Get()), Return(S_OK)));
  53. ASSERT_SUCCESS(
  54. cdm_session_.Initialize(mf_cdm_.Get(), CdmSessionType::kTemporary));
  55. }
  56. void GenerateRequest() {
  57. std::vector<uint8_t> init_data = StringToVector("init_data");
  58. std::vector<uint8_t> license_request = StringToVector("request");
  59. base::MockCallback<MediaFoundationCdmSession::SessionIdCB> session_id_cb;
  60. // Session ID to return. Will be released by |mf_cdm_session_|.
  61. LPWSTR session_id = nullptr;
  62. ASSERT_SUCCESS(CopyCoTaskMemWideString(L"session_id", &session_id));
  63. {
  64. // Use InSequence here because the order of events matter. |session_id_cb|
  65. // must be called before OnSessionMessage().
  66. InSequence seq;
  67. COM_EXPECT_CALL(mf_cdm_session_, GenerateRequest(_, _, init_data.size()))
  68. .WillOnce(WithoutArgs([&] {
  69. mf_cdm_session_callbacks_->KeyMessage(
  70. MF_MEDIAKEYSESSION_MESSAGETYPE_LICENSE_REQUEST,
  71. license_request.data(), license_request.size(), nullptr);
  72. return S_OK;
  73. }));
  74. COM_EXPECT_CALL(mf_cdm_session_, GetSessionId(_))
  75. .WillOnce(DoAll(SetArgPointee<0>(session_id), Return(S_OK)));
  76. EXPECT_CALL(session_id_cb, Run(_)).WillOnce(Return(true));
  77. EXPECT_CALL(cdm_client_,
  78. OnSessionMessage(_, CdmMessageType::LICENSE_REQUEST,
  79. license_request));
  80. }
  81. EXPECT_SUCCESS(cdm_session_.GenerateRequest(
  82. EmeInitDataType::WEBM, init_data, session_id_cb.Get()));
  83. task_environment_.RunUntilIdle();
  84. }
  85. protected:
  86. base::test::TaskEnvironment task_environment_;
  87. StrictMock<MockCdmClient> cdm_client_;
  88. ComPtr<MockMFCdm> mf_cdm_;
  89. ComPtr<MockMFCdmSession> mf_cdm_session_;
  90. MediaFoundationCdmSession cdm_session_;
  91. ComPtr<IMFContentDecryptionModuleSessionCallbacks> mf_cdm_session_callbacks_;
  92. };
  93. TEST_F(MediaFoundationCdmSessionTest, Initialize) {
  94. Initialize();
  95. }
  96. TEST_F(MediaFoundationCdmSessionTest, Initialize_Failure) {
  97. COM_EXPECT_CALL(mf_cdm_,
  98. CreateSession(MF_MEDIAKEYSESSION_TYPE_TEMPORARY, _, _))
  99. .WillOnce(DoAll(SaveComPtr<1>(&mf_cdm_session_callbacks_),
  100. SetComPointee<2>(mf_cdm_session_.Get()), Return(E_FAIL)));
  101. EXPECT_FAILED(
  102. cdm_session_.Initialize(mf_cdm_.Get(), CdmSessionType::kTemporary));
  103. }
  104. TEST_F(MediaFoundationCdmSessionTest, GenerateRequest) {
  105. Initialize();
  106. GenerateRequest();
  107. }
  108. TEST_F(MediaFoundationCdmSessionTest, GenerateRequest_Failure) {
  109. Initialize();
  110. std::vector<uint8_t> init_data = StringToVector("init_data");
  111. base::MockCallback<MediaFoundationCdmSession::SessionIdCB> session_id_cb;
  112. COM_EXPECT_CALL(mf_cdm_session_, GenerateRequest(_, _, init_data.size()))
  113. .WillOnce(Return(E_FAIL));
  114. EXPECT_FAILED(cdm_session_.GenerateRequest(EmeInitDataType::WEBM, init_data,
  115. session_id_cb.Get()));
  116. task_environment_.RunUntilIdle();
  117. }
  118. TEST_F(MediaFoundationCdmSessionTest, GetSessionId_Failure) {
  119. Initialize();
  120. std::vector<uint8_t> init_data = StringToVector("init_data");
  121. std::vector<uint8_t> license_request = StringToVector("request");
  122. base::MockCallback<MediaFoundationCdmSession::SessionIdCB> session_id_cb;
  123. COM_EXPECT_CALL(mf_cdm_session_, GenerateRequest(_, _, init_data.size()))
  124. .WillOnce(WithoutArgs([&] {
  125. mf_cdm_session_callbacks_->KeyMessage(
  126. MF_MEDIAKEYSESSION_MESSAGETYPE_LICENSE_REQUEST,
  127. license_request.data(), license_request.size(), nullptr);
  128. return S_OK;
  129. }));
  130. COM_EXPECT_CALL(mf_cdm_session_, GetSessionId(_)).WillOnce(Return(E_FAIL));
  131. EXPECT_CALL(session_id_cb, Run(IsEmpty()));
  132. // OnSessionMessage() will not be called.
  133. EXPECT_SUCCESS(cdm_session_.GenerateRequest(EmeInitDataType::WEBM, init_data,
  134. session_id_cb.Get()));
  135. task_environment_.RunUntilIdle();
  136. }
  137. TEST_F(MediaFoundationCdmSessionTest, GetSessionId_Empty) {
  138. Initialize();
  139. std::vector<uint8_t> init_data = StringToVector("init_data");
  140. std::vector<uint8_t> license_request = StringToVector("request");
  141. base::MockCallback<MediaFoundationCdmSession::SessionIdCB> session_id_cb;
  142. // Session ID to return. Will be released by |mf_cdm_session_|.
  143. LPWSTR empty_session_id = nullptr;
  144. ASSERT_SUCCESS(CopyCoTaskMemWideString(L"", &empty_session_id));
  145. COM_EXPECT_CALL(mf_cdm_session_, GenerateRequest(_, _, init_data.size()))
  146. .WillOnce(WithoutArgs([&] {
  147. mf_cdm_session_callbacks_->KeyMessage(
  148. MF_MEDIAKEYSESSION_MESSAGETYPE_LICENSE_REQUEST,
  149. license_request.data(), license_request.size(), nullptr);
  150. return S_OK;
  151. }));
  152. COM_EXPECT_CALL(mf_cdm_session_, GetSessionId(_))
  153. .WillOnce(DoAll(SetArgPointee<0>(empty_session_id), Return(S_OK)));
  154. EXPECT_CALL(session_id_cb, Run(IsEmpty()));
  155. // OnSessionMessage() will not be called since session ID is empty.
  156. EXPECT_SUCCESS(cdm_session_.GenerateRequest(EmeInitDataType::WEBM, init_data,
  157. session_id_cb.Get()));
  158. task_environment_.RunUntilIdle();
  159. }
  160. TEST_F(MediaFoundationCdmSessionTest, Update) {
  161. Initialize();
  162. GenerateRequest();
  163. std::vector<uint8_t> response = StringToVector("response");
  164. COM_EXPECT_CALL(mf_cdm_session_, Update(NotNull(), response.size()))
  165. .WillOnce(DoAll([&] { mf_cdm_session_callbacks_->KeyStatusChanged(); },
  166. Return(S_OK)));
  167. COM_EXPECT_CALL(mf_cdm_session_, GetKeyStatuses(_, _)).WillOnce(Return(S_OK));
  168. COM_EXPECT_CALL(mf_cdm_session_, GetExpiration(_))
  169. .WillOnce(DoAll(SetArgPointee<0>(kExpirationMs), Return(S_OK)));
  170. EXPECT_CALL(cdm_client_, OnSessionKeysChangeCalled(_, true));
  171. EXPECT_CALL(cdm_client_, OnSessionExpirationUpdate(_, kExpirationTime));
  172. EXPECT_SUCCESS(cdm_session_.Update(response));
  173. task_environment_.RunUntilIdle();
  174. }
  175. TEST_F(MediaFoundationCdmSessionTest, Update_Failure) {
  176. Initialize();
  177. GenerateRequest();
  178. std::vector<uint8_t> response = StringToVector("response");
  179. COM_EXPECT_CALL(mf_cdm_session_, Update(NotNull(), response.size()))
  180. .WillOnce(Return(E_FAIL));
  181. EXPECT_FAILED(cdm_session_.Update(response));
  182. }
  183. TEST_F(MediaFoundationCdmSessionTest, Close) {
  184. Initialize();
  185. GenerateRequest();
  186. COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(S_OK));
  187. EXPECT_SUCCESS(cdm_session_.Close());
  188. }
  189. TEST_F(MediaFoundationCdmSessionTest, Close_Failure) {
  190. Initialize();
  191. GenerateRequest();
  192. COM_EXPECT_CALL(mf_cdm_session_, Close()).WillOnce(Return(E_FAIL));
  193. EXPECT_FAILED(cdm_session_.Close());
  194. }
  195. TEST_F(MediaFoundationCdmSessionTest, Remove) {
  196. Initialize();
  197. GenerateRequest();
  198. COM_EXPECT_CALL(mf_cdm_session_, Remove()).WillOnce(Return(S_OK));
  199. COM_EXPECT_CALL(mf_cdm_session_, GetExpiration(_))
  200. .WillOnce(DoAll(SetArgPointee<0>(kExpirationMs), Return(S_OK)));
  201. EXPECT_CALL(cdm_client_, OnSessionExpirationUpdate(_, kExpirationTime));
  202. EXPECT_SUCCESS(cdm_session_.Remove());
  203. }
  204. TEST_F(MediaFoundationCdmSessionTest, Remove_Failure) {
  205. Initialize();
  206. GenerateRequest();
  207. COM_EXPECT_CALL(mf_cdm_session_, Remove()).WillOnce(Return(E_FAIL));
  208. EXPECT_FAILED(cdm_session_.Remove());
  209. }
  210. } // namespace media