mock_sspi_library_win.cc 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. // Copyright (c) 2010 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/http/mock_sspi_library_win.h"
  5. #include <algorithm>
  6. #include <cstring>
  7. #include <memory>
  8. #include <string>
  9. #include "base/check_op.h"
  10. #include "base/memory/raw_ptr.h"
  11. #include "base/strings/string_util_win.h"
  12. #include "base/strings/stringprintf.h"
  13. #include "base/strings/utf_string_conversions.h"
  14. #include "base/time/time.h"
  15. #include "testing/gtest/include/gtest/gtest.h"
  16. // Comparator so we can use CredHandle and CtxtHandle with std::set. Both of
  17. // those classes are typedefs for _SecHandle.
  18. bool operator<(const _SecHandle left, const _SecHandle right) {
  19. return left.dwUpper < right.dwUpper || left.dwLower < right.dwLower;
  20. }
  21. namespace net {
  22. namespace {
  23. int uniquifier_ = 0;
  24. struct MockCredential {
  25. std::u16string source_principal;
  26. std::u16string package;
  27. bool has_explicit_credentials = false;
  28. int uniquifier = ++uniquifier_;
  29. // CredHandle and CtxtHandle both shared the following definition:
  30. //
  31. // typedef struct _SecHandle {
  32. // ULONG_PTR dwLower;
  33. // ULONG_PTR dwUpper;
  34. // } SecHandle, * PSecHandle;
  35. //
  36. // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper
  37. // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued
  38. // by this method during the lifetime of this process is unique.
  39. void StoreInHandle(PCredHandle handle) {
  40. DCHECK(uniquifier > 0);
  41. EXPECT_FALSE(SecIsValidHandle(handle));
  42. handle->dwLower = uniquifier;
  43. handle->dwUpper = reinterpret_cast<ULONG_PTR>(this);
  44. DCHECK(SecIsValidHandle(handle));
  45. }
  46. static MockCredential* FromHandle(PCredHandle handle) {
  47. return reinterpret_cast<MockCredential*>(handle->dwUpper);
  48. }
  49. };
  50. struct MockContext {
  51. raw_ptr<MockCredential> credential = nullptr;
  52. std::u16string target_principal;
  53. int uniquifier = ++uniquifier_;
  54. int rounds = 0;
  55. // CredHandle and CtxtHandle both shared the following definition:
  56. //
  57. // typedef struct _SecHandle {
  58. // ULONG_PTR dwLower;
  59. // ULONG_PTR dwUpper;
  60. // } SecHandle, * PSecHandle;
  61. //
  62. // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper
  63. // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued
  64. // by this method during the lifetime of this process is unique.
  65. void StoreInHandle(PCtxtHandle handle) {
  66. EXPECT_FALSE(SecIsValidHandle(handle));
  67. DCHECK(uniquifier > 0);
  68. handle->dwLower = uniquifier;
  69. handle->dwUpper = reinterpret_cast<ULONG_PTR>(this);
  70. DCHECK(SecIsValidHandle(handle));
  71. }
  72. std::string ToString() const {
  73. return base::StringPrintf(
  74. "%s's token #%d for %S",
  75. base::UTF16ToUTF8(credential->source_principal).c_str(), rounds + 1,
  76. base::as_wcstr(target_principal));
  77. }
  78. static MockContext* FromHandle(PCtxtHandle handle) {
  79. return reinterpret_cast<MockContext*>(handle->dwUpper);
  80. }
  81. };
  82. } // namespace
  83. MockSSPILibrary::MockSSPILibrary(const wchar_t* package)
  84. : SSPILibrary(package) {}
  85. MockSSPILibrary::~MockSSPILibrary() {
  86. EXPECT_TRUE(expected_package_queries_.empty());
  87. EXPECT_TRUE(expected_freed_packages_.empty());
  88. EXPECT_TRUE(active_credentials_.empty());
  89. EXPECT_TRUE(active_contexts_.empty());
  90. }
  91. SECURITY_STATUS MockSSPILibrary::AcquireCredentialsHandle(
  92. LPWSTR pszPrincipal,
  93. unsigned long fCredentialUse,
  94. void* pvLogonId,
  95. void* pvAuthData,
  96. SEC_GET_KEY_FN pGetKeyFn,
  97. void* pvGetKeyArgument,
  98. PCredHandle phCredential,
  99. PTimeStamp ptsExpiry) {
  100. DCHECK(!SecIsValidHandle(phCredential));
  101. auto* credential = new MockCredential;
  102. credential->source_principal =
  103. pszPrincipal ? base::as_u16cstr(pszPrincipal) : u"<Default>";
  104. credential->package = base::as_u16cstr(package_name_.c_str());
  105. credential->has_explicit_credentials = !!pvAuthData;
  106. credential->StoreInHandle(phCredential);
  107. if (ptsExpiry) {
  108. ptsExpiry->LowPart = 0xBAA5B780;
  109. ptsExpiry->HighPart = 0x01D54E17;
  110. }
  111. active_credentials_.insert(*phCredential);
  112. return SEC_E_OK;
  113. }
  114. SECURITY_STATUS MockSSPILibrary::InitializeSecurityContext(
  115. PCredHandle phCredential,
  116. PCtxtHandle phContext,
  117. SEC_WCHAR* pszTargetName,
  118. unsigned long fContextReq,
  119. unsigned long Reserved1,
  120. unsigned long TargetDataRep,
  121. PSecBufferDesc pInput,
  122. unsigned long Reserved2,
  123. PCtxtHandle phNewContext,
  124. PSecBufferDesc pOutput,
  125. unsigned long* contextAttr,
  126. PTimeStamp ptsExpiry) {
  127. MockContext* new_context = new MockContext;
  128. new_context->credential = MockCredential::FromHandle(phCredential);
  129. new_context->target_principal = base::as_u16cstr(pszTargetName);
  130. new_context->rounds = 0;
  131. // Always rotate contexts. That way tests will fail if the caller's context
  132. // management is broken.
  133. if (phContext && SecIsValidHandle(phContext)) {
  134. std::unique_ptr<MockContext> old_context{
  135. MockContext::FromHandle(phContext)};
  136. EXPECT_EQ(old_context->credential, new_context->credential);
  137. EXPECT_EQ(1u, active_contexts_.erase(*phContext));
  138. new_context->rounds = old_context->rounds + 1;
  139. SecInvalidateHandle(phContext);
  140. }
  141. new_context->StoreInHandle(phNewContext);
  142. active_contexts_.insert(*phNewContext);
  143. auto token = new_context->ToString();
  144. PSecBuffer out_buffer = pOutput->pBuffers;
  145. out_buffer->cbBuffer = std::min<ULONG>(out_buffer->cbBuffer, token.size());
  146. std::memcpy(out_buffer->pvBuffer, token.data(), out_buffer->cbBuffer);
  147. if (ptsExpiry) {
  148. ptsExpiry->LowPart = 0xBAA5B780;
  149. ptsExpiry->HighPart = 0x01D54E15;
  150. }
  151. return SEC_E_OK;
  152. }
  153. SECURITY_STATUS MockSSPILibrary::QueryContextAttributesEx(PCtxtHandle phContext,
  154. ULONG ulAttribute,
  155. PVOID pBuffer,
  156. ULONG cbBuffer) {
  157. static const SecPkgInfoW kNegotiatedPackage = {
  158. 0,
  159. 0,
  160. 0,
  161. 0,
  162. const_cast<SEC_WCHAR*>(L"Itsa me Kerberos!!"),
  163. const_cast<SEC_WCHAR*>(L"I like turtles")};
  164. auto* context = MockContext::FromHandle(phContext);
  165. switch (ulAttribute) {
  166. case SECPKG_ATTR_NATIVE_NAMES: {
  167. auto* native_names =
  168. reinterpret_cast<SecPkgContext_NativeNames*>(pBuffer);
  169. DCHECK_EQ(sizeof(*native_names), cbBuffer);
  170. native_names->sClientName =
  171. base::as_writable_wcstr(context->credential->source_principal);
  172. native_names->sServerName =
  173. base::as_writable_wcstr(context->target_principal);
  174. return SEC_E_OK;
  175. }
  176. case SECPKG_ATTR_NEGOTIATION_INFO: {
  177. auto* negotiation_info =
  178. reinterpret_cast<SecPkgContext_NegotiationInfo*>(pBuffer);
  179. DCHECK_EQ(sizeof(*negotiation_info), cbBuffer);
  180. negotiation_info->PackageInfo =
  181. const_cast<SecPkgInfoW*>(&kNegotiatedPackage);
  182. negotiation_info->NegotiationState = (context->rounds == 1)
  183. ? SECPKG_NEGOTIATION_COMPLETE
  184. : SECPKG_NEGOTIATION_IN_PROGRESS;
  185. return SEC_E_OK;
  186. }
  187. case SECPKG_ATTR_AUTHORITY: {
  188. auto* authority = reinterpret_cast<SecPkgContext_Authority*>(pBuffer);
  189. DCHECK_EQ(sizeof(*authority), cbBuffer);
  190. authority->sAuthorityName = const_cast<SEC_WCHAR*>(L"Dodgy Server");
  191. return SEC_E_OK;
  192. }
  193. default:
  194. return SEC_E_UNSUPPORTED_FUNCTION;
  195. }
  196. }
  197. SECURITY_STATUS MockSSPILibrary::QuerySecurityPackageInfo(
  198. PSecPkgInfoW* pkgInfo) {
  199. if (expected_package_queries_.empty()) {
  200. static SecPkgInfoW kDefaultPkgInfo{
  201. 0, 0, 0, kDefaultMaxTokenLength, nullptr, nullptr};
  202. *pkgInfo = &kDefaultPkgInfo;
  203. expected_freed_packages_.insert(&kDefaultPkgInfo);
  204. return SEC_E_OK;
  205. }
  206. PackageQuery package_query = expected_package_queries_.front();
  207. expected_package_queries_.pop_front();
  208. *pkgInfo = package_query.package_info;
  209. if (package_query.response_code == SEC_E_OK)
  210. expected_freed_packages_.insert(package_query.package_info);
  211. return package_query.response_code;
  212. }
  213. SECURITY_STATUS MockSSPILibrary::FreeCredentialsHandle(
  214. PCredHandle phCredential) {
  215. DCHECK(SecIsValidHandle(phCredential));
  216. EXPECT_EQ(1u, active_credentials_.erase(*phCredential));
  217. std::unique_ptr<MockCredential> owned{
  218. MockCredential::FromHandle(phCredential)};
  219. SecInvalidateHandle(phCredential);
  220. return SEC_E_OK;
  221. }
  222. SECURITY_STATUS MockSSPILibrary::DeleteSecurityContext(PCtxtHandle phContext) {
  223. std::unique_ptr<MockContext> context{MockContext::FromHandle(phContext)};
  224. EXPECT_EQ(1u, active_contexts_.erase(*phContext));
  225. SecInvalidateHandle(phContext);
  226. return SEC_E_OK;
  227. }
  228. SECURITY_STATUS MockSSPILibrary::FreeContextBuffer(PVOID pvContextBuffer) {
  229. PSecPkgInfoW package_info = static_cast<PSecPkgInfoW>(pvContextBuffer);
  230. std::set<PSecPkgInfoW>::iterator it = expected_freed_packages_.find(
  231. package_info);
  232. EXPECT_TRUE(it != expected_freed_packages_.end());
  233. expected_freed_packages_.erase(it);
  234. return SEC_E_OK;
  235. }
  236. void MockSSPILibrary::ExpectQuerySecurityPackageInfo(
  237. SECURITY_STATUS response_code,
  238. PSecPkgInfoW package_info) {
  239. expected_package_queries_.emplace_back(
  240. PackageQuery{response_code, package_info});
  241. }
  242. } // namespace net