ntlm_buffer_reader.cc 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. // Copyright 2017 The Chromium Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style license that can be
  3. // found in the LICENSE file.
  4. #include "net/ntlm/ntlm_buffer_reader.h"
  5. #include <string.h>
  6. #include "base/check_op.h"
  7. namespace net::ntlm {
  8. NtlmBufferReader::NtlmBufferReader() = default;
  9. NtlmBufferReader::NtlmBufferReader(base::span<const uint8_t> buffer)
  10. : buffer_(buffer) {}
  11. NtlmBufferReader::~NtlmBufferReader() = default;
  12. bool NtlmBufferReader::CanRead(size_t len) const {
  13. return CanReadFrom(GetCursor(), len);
  14. }
  15. bool NtlmBufferReader::CanReadFrom(size_t offset, size_t len) const {
  16. if (len == 0)
  17. return true;
  18. return (len <= GetLength() && offset <= GetLength() - len);
  19. }
  20. bool NtlmBufferReader::ReadUInt16(uint16_t* value) {
  21. return ReadUInt<uint16_t>(value);
  22. }
  23. bool NtlmBufferReader::ReadUInt32(uint32_t* value) {
  24. return ReadUInt<uint32_t>(value);
  25. }
  26. bool NtlmBufferReader::ReadUInt64(uint64_t* value) {
  27. return ReadUInt<uint64_t>(value);
  28. }
  29. bool NtlmBufferReader::ReadFlags(NegotiateFlags* flags) {
  30. uint32_t raw;
  31. if (!ReadUInt32(&raw))
  32. return false;
  33. *flags = static_cast<NegotiateFlags>(raw);
  34. return true;
  35. }
  36. bool NtlmBufferReader::ReadBytes(base::span<uint8_t> buffer) {
  37. if (!CanRead(buffer.size()))
  38. return false;
  39. if (buffer.empty())
  40. return true;
  41. memcpy(buffer.data(), GetBufferAtCursor(), buffer.size());
  42. AdvanceCursor(buffer.size());
  43. return true;
  44. }
  45. bool NtlmBufferReader::ReadBytesFrom(const SecurityBuffer& sec_buf,
  46. base::span<uint8_t> buffer) {
  47. if (!CanReadFrom(sec_buf) || buffer.size() < sec_buf.length)
  48. return false;
  49. if (buffer.empty())
  50. return true;
  51. memcpy(buffer.data(), GetBufferPtr() + sec_buf.offset, sec_buf.length);
  52. return true;
  53. }
  54. bool NtlmBufferReader::ReadPayloadAsBufferReader(const SecurityBuffer& sec_buf,
  55. NtlmBufferReader* reader) {
  56. if (!CanReadFrom(sec_buf))
  57. return false;
  58. *reader = NtlmBufferReader(
  59. base::make_span(GetBufferPtr() + sec_buf.offset, sec_buf.length));
  60. return true;
  61. }
  62. bool NtlmBufferReader::ReadSecurityBuffer(SecurityBuffer* sec_buf) {
  63. return ReadUInt16(&sec_buf->length) && SkipBytes(sizeof(uint16_t)) &&
  64. ReadUInt32(&sec_buf->offset);
  65. }
  66. bool NtlmBufferReader::ReadAvPairHeader(TargetInfoAvId* avid, uint16_t* avlen) {
  67. if (!CanRead(kAvPairHeaderLen))
  68. return false;
  69. uint16_t raw_avid;
  70. bool result = ReadUInt16(&raw_avid) && ReadUInt16(avlen);
  71. DCHECK(result);
  72. // Don't try and validate the avid because the code only cares about a few
  73. // specific ones and it is likely a future version might extend this field.
  74. // The implementation can ignore and skip over AV Pairs it doesn't
  75. // understand.
  76. *avid = static_cast<TargetInfoAvId>(raw_avid);
  77. return true;
  78. }
  79. bool NtlmBufferReader::ReadTargetInfo(size_t target_info_len,
  80. std::vector<AvPair>* av_pairs) {
  81. DCHECK(av_pairs->empty());
  82. // A completely empty target info is allowed.
  83. if (target_info_len == 0)
  84. return true;
  85. // If there is any content there has to be at least one terminating header.
  86. if (!CanRead(target_info_len) || target_info_len < kAvPairHeaderLen) {
  87. return false;
  88. }
  89. size_t target_info_end = GetCursor() + target_info_len;
  90. bool saw_eol = false;
  91. while ((GetCursor() < target_info_end)) {
  92. AvPair pair;
  93. if (!ReadAvPairHeader(&pair.avid, &pair.avlen))
  94. break;
  95. // Make sure the length wouldn't read outside the buffer.
  96. if (!CanRead(pair.avlen))
  97. return false;
  98. // Take a copy of the payload in the AVPair.
  99. pair.buffer.assign(GetBufferAtCursor(), GetBufferAtCursor() + pair.avlen);
  100. if (pair.avid == TargetInfoAvId::kEol) {
  101. // Terminator must have zero length.
  102. if (pair.avlen != 0)
  103. return false;
  104. // Break out of the loop once a valid terminator is found. After the
  105. // loop it will be validated that the whole target info was consumed.
  106. saw_eol = true;
  107. break;
  108. }
  109. switch (pair.avid) {
  110. case TargetInfoAvId::kFlags:
  111. // For flags also populate the flags field so it doesn't
  112. // have to be modified through the raw buffer later.
  113. if (pair.avlen != sizeof(uint32_t) ||
  114. !ReadUInt32(reinterpret_cast<uint32_t*>(&pair.flags)))
  115. return false;
  116. break;
  117. case TargetInfoAvId::kTimestamp:
  118. // Populate timestamp so it doesn't need to be read through the
  119. // raw buffer later.
  120. if (pair.avlen != sizeof(uint64_t) || !ReadUInt64(&pair.timestamp))
  121. return false;
  122. break;
  123. case TargetInfoAvId::kChannelBindings:
  124. case TargetInfoAvId::kTargetName:
  125. // The server should never send these, and with EPA enabled the client
  126. // will add these to the authenticate message. To avoid issues with
  127. // duplicates or only one being read, just don't allow them.
  128. return false;
  129. default:
  130. // For all other types, just jump over the payload to the next pair.
  131. // If there aren't enough bytes left, then fail.
  132. if (!SkipBytes(pair.avlen))
  133. return false;
  134. break;
  135. }
  136. av_pairs->push_back(std::move(pair));
  137. }
  138. // Fail if the buffer wasn't properly formed. The entire payload should have
  139. // been consumed and a terminator found.
  140. if ((GetCursor() != target_info_end) || !saw_eol)
  141. return false;
  142. return true;
  143. }
  144. bool NtlmBufferReader::ReadTargetInfoPayload(std::vector<AvPair>* av_pairs) {
  145. DCHECK(av_pairs->empty());
  146. SecurityBuffer sec_buf;
  147. // First read the security buffer.
  148. if (!ReadSecurityBuffer(&sec_buf))
  149. return false;
  150. NtlmBufferReader payload_reader;
  151. if (!ReadPayloadAsBufferReader(sec_buf, &payload_reader))
  152. return false;
  153. if (!payload_reader.ReadTargetInfo(sec_buf.length, av_pairs))
  154. return false;
  155. // |ReadTargetInfo| should have consumed the entire contents.
  156. return payload_reader.IsEndOfBuffer();
  157. }
  158. bool NtlmBufferReader::ReadMessageType(MessageType* message_type) {
  159. uint32_t raw_message_type;
  160. if (!ReadUInt32(&raw_message_type))
  161. return false;
  162. *message_type = static_cast<MessageType>(raw_message_type);
  163. if (*message_type != MessageType::kNegotiate &&
  164. *message_type != MessageType::kChallenge &&
  165. *message_type != MessageType::kAuthenticate)
  166. return false;
  167. return true;
  168. }
  169. bool NtlmBufferReader::SkipSecurityBuffer() {
  170. return SkipBytes(kSecurityBufferLen);
  171. }
  172. bool NtlmBufferReader::SkipSecurityBufferWithValidation() {
  173. SecurityBuffer sec_buf;
  174. return ReadSecurityBuffer(&sec_buf) && CanReadFrom(sec_buf);
  175. }
  176. bool NtlmBufferReader::SkipBytes(size_t count) {
  177. if (!CanRead(count))
  178. return false;
  179. AdvanceCursor(count);
  180. return true;
  181. }
  182. bool NtlmBufferReader::MatchSignature() {
  183. if (!CanRead(kSignatureLen))
  184. return false;
  185. if (memcmp(kSignature, GetBufferAtCursor(), kSignatureLen) != 0)
  186. return false;
  187. AdvanceCursor(kSignatureLen);
  188. return true;
  189. }
  190. bool NtlmBufferReader::MatchMessageType(MessageType message_type) {
  191. MessageType actual_message_type;
  192. return ReadMessageType(&actual_message_type) &&
  193. (actual_message_type == message_type);
  194. }
  195. bool NtlmBufferReader::MatchMessageHeader(MessageType message_type) {
  196. return MatchSignature() && MatchMessageType(message_type);
  197. }
  198. bool NtlmBufferReader::MatchZeros(size_t count) {
  199. if (!CanRead(count))
  200. return false;
  201. for (size_t i = 0; i < count; i++) {
  202. if (GetBufferAtCursor()[i] != 0)
  203. return false;
  204. }
  205. AdvanceCursor(count);
  206. return true;
  207. }
  208. bool NtlmBufferReader::MatchEmptySecurityBuffer() {
  209. SecurityBuffer sec_buf;
  210. return ReadSecurityBuffer(&sec_buf) && (sec_buf.offset <= GetLength()) &&
  211. (sec_buf.length == 0);
  212. }
  213. template <typename T>
  214. bool NtlmBufferReader::ReadUInt(T* value) {
  215. size_t int_size = sizeof(T);
  216. if (!CanRead(int_size))
  217. return false;
  218. *value = 0;
  219. for (size_t i = 0; i < int_size; i++) {
  220. *value += static_cast<T>(GetByteAtCursor()) << (i * 8);
  221. AdvanceCursor(1);
  222. }
  223. return true;
  224. }
  225. void NtlmBufferReader::SetCursor(size_t cursor) {
  226. DCHECK_LE(cursor, GetLength());
  227. cursor_ = cursor;
  228. }
  229. } // namespace net::ntlm