client.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (C) 2019 Garmin Ltd.
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. #
  5. import json
  6. import logging
  7. import socket
  8. import os
  9. from . import chunkify, DEFAULT_MAX_CHUNK
  10. logger = logging.getLogger('hashserv.client')
  11. class HashConnectionError(Exception):
  12. pass
  13. class Client(object):
  14. MODE_NORMAL = 0
  15. MODE_GET_STREAM = 1
  16. def __init__(self):
  17. self._socket = None
  18. self.reader = None
  19. self.writer = None
  20. self.mode = self.MODE_NORMAL
  21. self.max_chunk = DEFAULT_MAX_CHUNK
  22. def connect_tcp(self, address, port):
  23. def connect_sock():
  24. s = socket.create_connection((address, port))
  25. s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
  26. s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
  27. s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
  28. return s
  29. self._connect_sock = connect_sock
  30. def connect_unix(self, path):
  31. def connect_sock():
  32. s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  33. # AF_UNIX has path length issues so chdir here to workaround
  34. cwd = os.getcwd()
  35. try:
  36. os.chdir(os.path.dirname(path))
  37. s.connect(os.path.basename(path))
  38. finally:
  39. os.chdir(cwd)
  40. return s
  41. self._connect_sock = connect_sock
  42. def connect(self):
  43. if self._socket is None:
  44. self._socket = self._connect_sock()
  45. self.reader = self._socket.makefile('r', encoding='utf-8')
  46. self.writer = self._socket.makefile('w', encoding='utf-8')
  47. self.writer.write('OEHASHEQUIV 1.1\n\n')
  48. self.writer.flush()
  49. # Restore mode if the socket is being re-created
  50. cur_mode = self.mode
  51. self.mode = self.MODE_NORMAL
  52. self._set_mode(cur_mode)
  53. return self._socket
  54. def close(self):
  55. if self._socket is not None:
  56. self._socket.close()
  57. self._socket = None
  58. self.reader = None
  59. self.writer = None
  60. def _send_wrapper(self, proc):
  61. count = 0
  62. while True:
  63. try:
  64. self.connect()
  65. return proc()
  66. except (OSError, HashConnectionError, json.JSONDecodeError, UnicodeDecodeError) as e:
  67. logger.warning('Error talking to server: %s' % e)
  68. if count >= 3:
  69. if not isinstance(e, HashConnectionError):
  70. raise HashConnectionError(str(e))
  71. raise e
  72. self.close()
  73. count += 1
  74. def send_message(self, msg):
  75. def get_line():
  76. line = self.reader.readline()
  77. if not line:
  78. raise HashConnectionError('Connection closed')
  79. if not line.endswith('\n'):
  80. raise HashConnectionError('Bad message %r' % message)
  81. return line
  82. def proc():
  83. for c in chunkify(json.dumps(msg), self.max_chunk):
  84. self.writer.write(c)
  85. self.writer.flush()
  86. l = get_line()
  87. m = json.loads(l)
  88. if 'chunk-stream' in m:
  89. lines = []
  90. while True:
  91. l = get_line().rstrip('\n')
  92. if not l:
  93. break
  94. lines.append(l)
  95. m = json.loads(''.join(lines))
  96. return m
  97. return self._send_wrapper(proc)
  98. def send_stream(self, msg):
  99. def proc():
  100. self.writer.write("%s\n" % msg)
  101. self.writer.flush()
  102. l = self.reader.readline()
  103. if not l:
  104. raise HashConnectionError('Connection closed')
  105. return l.rstrip()
  106. return self._send_wrapper(proc)
  107. def _set_mode(self, new_mode):
  108. if new_mode == self.MODE_NORMAL and self.mode == self.MODE_GET_STREAM:
  109. r = self.send_stream('END')
  110. if r != 'ok':
  111. raise HashConnectionError('Bad response from server %r' % r)
  112. elif new_mode == self.MODE_GET_STREAM and self.mode == self.MODE_NORMAL:
  113. r = self.send_message({'get-stream': None})
  114. if r != 'ok':
  115. raise HashConnectionError('Bad response from server %r' % r)
  116. elif new_mode != self.mode:
  117. raise Exception('Undefined mode transition %r -> %r' % (self.mode, new_mode))
  118. self.mode = new_mode
  119. def get_unihash(self, method, taskhash):
  120. self._set_mode(self.MODE_GET_STREAM)
  121. r = self.send_stream('%s %s' % (method, taskhash))
  122. if not r:
  123. return None
  124. return r
  125. def report_unihash(self, taskhash, method, outhash, unihash, extra={}):
  126. self._set_mode(self.MODE_NORMAL)
  127. m = extra.copy()
  128. m['taskhash'] = taskhash
  129. m['method'] = method
  130. m['outhash'] = outhash
  131. m['unihash'] = unihash
  132. return self.send_message({'report': m})
  133. def report_unihash_equiv(self, taskhash, method, unihash, extra={}):
  134. self._set_mode(self.MODE_NORMAL)
  135. m = extra.copy()
  136. m['taskhash'] = taskhash
  137. m['method'] = method
  138. m['unihash'] = unihash
  139. return self.send_message({'report-equiv': m})
  140. def get_taskhash(self, method, taskhash, all_properties=False):
  141. self._set_mode(self.MODE_NORMAL)
  142. return self.send_message({'get': {
  143. 'taskhash': taskhash,
  144. 'method': method,
  145. 'all': all_properties
  146. }})
  147. def get_stats(self):
  148. self._set_mode(self.MODE_NORMAL)
  149. return self.send_message({'get-stats': None})
  150. def reset_stats(self):
  151. self._set_mode(self.MODE_NORMAL)
  152. return self.send_message({'reset-stats': None})