# _handshake.py (6.1 KB)
  1. """
  2. websocket - WebSocket client library for Python
  3. Copyright (C) 2010 Hiroki Ohtani(liris)
  4. This library is free software; you can redistribute it and/or
  5. modify it under the terms of the GNU Lesser General Public
  6. License as published by the Free Software Foundation; either
  7. version 2.1 of the License, or (at your option) any later version.
  8. This library is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  11. Lesser General Public License for more details.
  12. You should have received a copy of the GNU Lesser General Public
  13. License along with this library; if not, write to the Free Software
  14. Foundation, Inc., 51 Franklin Street, Fifth Floor,
  15. Boston, MA 02110-1301 USA
  16. """
  17. import hashlib
  18. import hmac
  19. import os
  20. import six
  21. from ._cookiejar import SimpleCookieJar
  22. from ._exceptions import *
  23. from ._http import *
  24. from ._logging import *
  25. from ._socket import *
  26. if six.PY3:
  27. from base64 import encodebytes as base64encode
  28. else:
  29. from base64 import encodestring as base64encode
  30. if six.PY3:
  31. if six.PY34:
  32. from http import client as HTTPStatus
  33. else:
  34. from http import HTTPStatus
  35. else:
  36. import httplib as HTTPStatus
  37. __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
  38. if hasattr(hmac, "compare_digest"):
  39. compare_digest = hmac.compare_digest
  40. else:
  41. def compare_digest(s1, s2):
  42. return s1 == s2
  43. # websocket supported version.
  44. VERSION = 13
  45. SUPPORTED_REDIRECT_STATUSES = [HTTPStatus.MOVED_PERMANENTLY, HTTPStatus.FOUND, HTTPStatus.SEE_OTHER]
  46. CookieJar = SimpleCookieJar()
  47. class handshake_response(object):
  48. def __init__(self, status, headers, subprotocol):
  49. self.status = status
  50. self.headers = headers
  51. self.subprotocol = subprotocol
  52. CookieJar.add(headers.get("set-cookie"))
  53. def handshake(sock, hostname, port, resource, **options):
  54. headers, key = _get_handshake_headers(resource, hostname, port, options)
  55. header_str = "\r\n".join(headers)
  56. send(sock, header_str)
  57. dump("request header", header_str)
  58. status, resp = _get_resp_headers(sock)
  59. if status in SUPPORTED_REDIRECT_STATUSES:
  60. return handshake_response(status, resp, None)
  61. success, subproto = _validate(resp, key, options.get("subprotocols"))
  62. if not success:
  63. raise WebSocketException("Invalid WebSocket Header")
  64. return handshake_response(status, resp, subproto)
  65. def _pack_hostname(hostname):
  66. # IPv6 address
  67. if ':' in hostname:
  68. return '[' + hostname + ']'
  69. return hostname
  70. def _get_handshake_headers(resource, host, port, options):
  71. headers = [
  72. "GET %s HTTP/1.1" % resource,
  73. "Upgrade: websocket",
  74. "Connection: Upgrade"
  75. ]
  76. if port == 80 or port == 443:
  77. hostport = _pack_hostname(host)
  78. else:
  79. hostport = "%s:%d" % (_pack_hostname(host), port)
  80. if "host" in options and options["host"] is not None:
  81. headers.append("Host: %s" % options["host"])
  82. else:
  83. headers.append("Host: %s" % hostport)
  84. if "suppress_origin" not in options or not options["suppress_origin"]:
  85. if "origin" in options and options["origin"] is not None:
  86. headers.append("Origin: %s" % options["origin"])
  87. else:
  88. headers.append("Origin: http://%s" % hostport)
  89. key = _create_sec_websocket_key()
  90. # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
  91. if not 'header' in options or 'Sec-WebSocket-Key' not in options['header']:
  92. key = _create_sec_websocket_key()
  93. headers.append("Sec-WebSocket-Key: %s" % key)
  94. else:
  95. key = options['header']['Sec-WebSocket-Key']
  96. if not 'header' in options or 'Sec-WebSocket-Version' not in options['header']:
  97. headers.append("Sec-WebSocket-Version: %s" % VERSION)
  98. subprotocols = options.get("subprotocols")
  99. if subprotocols:
  100. headers.append("Sec-WebSocket-Protocol: %s" % ",".join(subprotocols))
  101. if "header" in options:
  102. header = options["header"]
  103. if isinstance(header, dict):
  104. header = [
  105. ": ".join([k, v])
  106. for k, v in header.items()
  107. if v is not None
  108. ]
  109. headers.extend(header)
  110. server_cookie = CookieJar.get(host)
  111. client_cookie = options.get("cookie", None)
  112. cookie = "; ".join(filter(None, [server_cookie, client_cookie]))
  113. if cookie:
  114. headers.append("Cookie: %s" % cookie)
  115. headers.append("")
  116. headers.append("")
  117. return headers, key
  118. def _get_resp_headers(sock, success_statuses=(101, 301, 302, 303)):
  119. status, resp_headers, status_message = read_headers(sock)
  120. if status not in success_statuses:
  121. raise WebSocketBadStatusException("Handshake status %d %s", status, status_message, resp_headers)
  122. return status, resp_headers
  123. _HEADERS_TO_CHECK = {
  124. "upgrade": "websocket",
  125. "connection": "upgrade",
  126. }
  127. def _validate(headers, key, subprotocols):
  128. subproto = None
  129. for k, v in _HEADERS_TO_CHECK.items():
  130. r = headers.get(k, None)
  131. if not r:
  132. return False, None
  133. r = r.lower()
  134. if v != r:
  135. return False, None
  136. if subprotocols:
  137. subproto = headers.get("sec-websocket-protocol", None).lower()
  138. if not subproto or subproto not in [s.lower() for s in subprotocols]:
  139. error("Invalid subprotocol: " + str(subprotocols))
  140. return False, None
  141. result = headers.get("sec-websocket-accept", None)
  142. if not result:
  143. return False, None
  144. result = result.lower()
  145. if isinstance(result, six.text_type):
  146. result = result.encode('utf-8')
  147. value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
  148. hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
  149. success = compare_digest(hashed, result)
  150. if success:
  151. return True, subproto
  152. else:
  153. return False, None
  154. def _create_sec_websocket_key():
  155. randomness = os.urandom(16)
  156. return base64encode(randomness).decode('utf-8').strip()