server.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. # Copyright (C) 2019 Garmin Ltd.
  2. #
  3. # SPDX-License-Identifier: GPL-2.0-only
  4. #
  5. from contextlib import closing
  6. from datetime import datetime
  7. import asyncio
  8. import json
  9. import logging
  10. import math
  11. import os
  12. import signal
  13. import socket
  14. import time
  15. logger = logging.getLogger('hashserv.server')
  16. class Measurement(object):
  17. def __init__(self, sample):
  18. self.sample = sample
  19. def start(self):
  20. self.start_time = time.perf_counter()
  21. def end(self):
  22. self.sample.add(time.perf_counter() - self.start_time)
  23. def __enter__(self):
  24. self.start()
  25. return self
  26. def __exit__(self, *args, **kwargs):
  27. self.end()
  28. class Sample(object):
  29. def __init__(self, stats):
  30. self.stats = stats
  31. self.num_samples = 0
  32. self.elapsed = 0
  33. def measure(self):
  34. return Measurement(self)
  35. def __enter__(self):
  36. return self
  37. def __exit__(self, *args, **kwargs):
  38. self.end()
  39. def add(self, elapsed):
  40. self.num_samples += 1
  41. self.elapsed += elapsed
  42. def end(self):
  43. if self.num_samples:
  44. self.stats.add(self.elapsed)
  45. self.num_samples = 0
  46. self.elapsed = 0
  47. class Stats(object):
  48. def __init__(self):
  49. self.reset()
  50. def reset(self):
  51. self.num = 0
  52. self.total_time = 0
  53. self.max_time = 0
  54. self.m = 0
  55. self.s = 0
  56. self.current_elapsed = None
  57. def add(self, elapsed):
  58. self.num += 1
  59. if self.num == 1:
  60. self.m = elapsed
  61. self.s = 0
  62. else:
  63. last_m = self.m
  64. self.m = last_m + (elapsed - last_m) / self.num
  65. self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
  66. self.total_time += elapsed
  67. if self.max_time < elapsed:
  68. self.max_time = elapsed
  69. def start_sample(self):
  70. return Sample(self)
  71. @property
  72. def average(self):
  73. if self.num == 0:
  74. return 0
  75. return self.total_time / self.num
  76. @property
  77. def stdev(self):
  78. if self.num <= 1:
  79. return 0
  80. return math.sqrt(self.s / (self.num - 1))
  81. def todict(self):
  82. return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
  83. class ServerClient(object):
  84. def __init__(self, reader, writer, db, request_stats):
  85. self.reader = reader
  86. self.writer = writer
  87. self.db = db
  88. self.request_stats = request_stats
  89. async def process_requests(self):
  90. try:
  91. self.addr = self.writer.get_extra_info('peername')
  92. logger.debug('Client %r connected' % (self.addr,))
  93. # Read protocol and version
  94. protocol = await self.reader.readline()
  95. if protocol is None:
  96. return
  97. (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
  98. if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
  99. return
  100. # Read headers. Currently, no headers are implemented, so look for
  101. # an empty line to signal the end of the headers
  102. while True:
  103. line = await self.reader.readline()
  104. if line is None:
  105. return
  106. line = line.decode('utf-8').rstrip()
  107. if not line:
  108. break
  109. # Handle messages
  110. handlers = {
  111. 'get': self.handle_get,
  112. 'report': self.handle_report,
  113. 'get-stream': self.handle_get_stream,
  114. 'get-stats': self.handle_get_stats,
  115. 'reset-stats': self.handle_reset_stats,
  116. }
  117. while True:
  118. d = await self.read_message()
  119. if d is None:
  120. break
  121. for k in handlers.keys():
  122. if k in d:
  123. logger.debug('Handling %s' % k)
  124. if 'stream' in k:
  125. await handlers[k](d[k])
  126. else:
  127. with self.request_stats.start_sample() as self.request_sample, \
  128. self.request_sample.measure():
  129. await handlers[k](d[k])
  130. break
  131. else:
  132. logger.warning("Unrecognized command %r" % d)
  133. break
  134. await self.writer.drain()
  135. finally:
  136. self.writer.close()
  137. def write_message(self, msg):
  138. self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
  139. async def read_message(self):
  140. l = await self.reader.readline()
  141. if not l:
  142. return None
  143. try:
  144. message = l.decode('utf-8')
  145. if not message.endswith('\n'):
  146. return None
  147. return json.loads(message)
  148. except (json.JSONDecodeError, UnicodeDecodeError) as e:
  149. logger.error('Bad message from client: %r' % message)
  150. raise e
  151. async def handle_get(self, request):
  152. method = request['method']
  153. taskhash = request['taskhash']
  154. row = self.query_equivalent(method, taskhash)
  155. if row is not None:
  156. logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
  157. d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
  158. self.write_message(d)
  159. else:
  160. self.write_message(None)
  161. async def handle_get_stream(self, request):
  162. self.write_message('ok')
  163. while True:
  164. l = await self.reader.readline()
  165. if not l:
  166. return
  167. try:
  168. # This inner loop is very sensitive and must be as fast as
  169. # possible (which is why the request sample is handled manually
  170. # instead of using 'with', and also why logging statements are
  171. # commented out.
  172. self.request_sample = self.request_stats.start_sample()
  173. request_measure = self.request_sample.measure()
  174. request_measure.start()
  175. l = l.decode('utf-8').rstrip()
  176. if l == 'END':
  177. self.writer.write('ok\n'.encode('utf-8'))
  178. return
  179. (method, taskhash) = l.split()
  180. #logger.debug('Looking up %s %s' % (method, taskhash))
  181. row = self.query_equivalent(method, taskhash)
  182. if row is not None:
  183. msg = ('%s\n' % row['unihash']).encode('utf-8')
  184. #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
  185. else:
  186. msg = '\n'.encode('utf-8')
  187. self.writer.write(msg)
  188. finally:
  189. request_measure.end()
  190. self.request_sample.end()
  191. await self.writer.drain()
  192. async def handle_report(self, data):
  193. with closing(self.db.cursor()) as cursor:
  194. cursor.execute('''
  195. -- Find tasks with a matching outhash (that is, tasks that
  196. -- are equivalent)
  197. SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
  198. -- If there is an exact match on the taskhash, return it.
  199. -- Otherwise return the oldest matching outhash of any
  200. -- taskhash
  201. ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
  202. created ASC
  203. -- Only return one row
  204. LIMIT 1
  205. ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
  206. row = cursor.fetchone()
  207. # If no matching outhash was found, or one *was* found but it
  208. # wasn't an exact match on the taskhash, a new entry for this
  209. # taskhash should be added
  210. if row is None or row['taskhash'] != data['taskhash']:
  211. # If a row matching the outhash was found, the unihash for
  212. # the new taskhash should be the same as that one.
  213. # Otherwise the caller provided unihash is used.
  214. unihash = data['unihash']
  215. if row is not None:
  216. unihash = row['unihash']
  217. insert_data = {
  218. 'method': data['method'],
  219. 'outhash': data['outhash'],
  220. 'taskhash': data['taskhash'],
  221. 'unihash': unihash,
  222. 'created': datetime.now()
  223. }
  224. for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
  225. if k in data:
  226. insert_data[k] = data[k]
  227. cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
  228. ', '.join(sorted(insert_data.keys())),
  229. ', '.join(':' + k for k in sorted(insert_data.keys()))),
  230. insert_data)
  231. self.db.commit()
  232. logger.info('Adding taskhash %s with unihash %s',
  233. data['taskhash'], unihash)
  234. d = {
  235. 'taskhash': data['taskhash'],
  236. 'method': data['method'],
  237. 'unihash': unihash
  238. }
  239. else:
  240. d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
  241. self.write_message(d)
  242. async def handle_get_stats(self, request):
  243. d = {
  244. 'requests': self.request_stats.todict(),
  245. }
  246. self.write_message(d)
  247. async def handle_reset_stats(self, request):
  248. d = {
  249. 'requests': self.request_stats.todict(),
  250. }
  251. self.request_stats.reset()
  252. self.write_message(d)
  253. def query_equivalent(self, method, taskhash):
  254. # This is part of the inner loop and must be as fast as possible
  255. try:
  256. cursor = self.db.cursor()
  257. cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
  258. {'method': method, 'taskhash': taskhash})
  259. return cursor.fetchone()
  260. except:
  261. cursor.close()
  262. class Server(object):
  263. def __init__(self, db, loop=None):
  264. self.request_stats = Stats()
  265. self.db = db
  266. if loop is None:
  267. self.loop = asyncio.new_event_loop()
  268. self.close_loop = True
  269. else:
  270. self.loop = loop
  271. self.close_loop = False
  272. self._cleanup_socket = None
  273. def start_tcp_server(self, host, port):
  274. self.server = self.loop.run_until_complete(
  275. asyncio.start_server(self.handle_client, host, port, loop=self.loop)
  276. )
  277. for s in self.server.sockets:
  278. logger.info('Listening on %r' % (s.getsockname(),))
  279. # Newer python does this automatically. Do it manually here for
  280. # maximum compatibility
  281. s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
  282. s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
  283. name = self.server.sockets[0].getsockname()
  284. if self.server.sockets[0].family == socket.AF_INET6:
  285. self.address = "[%s]:%d" % (name[0], name[1])
  286. else:
  287. self.address = "%s:%d" % (name[0], name[1])
  288. def start_unix_server(self, path):
  289. def cleanup():
  290. os.unlink(path)
  291. cwd = os.getcwd()
  292. try:
  293. # Work around path length limits in AF_UNIX
  294. os.chdir(os.path.dirname(path))
  295. self.server = self.loop.run_until_complete(
  296. asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
  297. )
  298. finally:
  299. os.chdir(cwd)
  300. logger.info('Listening on %r' % path)
  301. self._cleanup_socket = cleanup
  302. self.address = "unix://%s" % os.path.abspath(path)
  303. async def handle_client(self, reader, writer):
  304. # writer.transport.set_write_buffer_limits(0)
  305. try:
  306. client = ServerClient(reader, writer, self.db, self.request_stats)
  307. await client.process_requests()
  308. except Exception as e:
  309. import traceback
  310. logger.error('Error from client: %s' % str(e), exc_info=True)
  311. traceback.print_exc()
  312. writer.close()
  313. logger.info('Client disconnected')
  314. def serve_forever(self):
  315. def signal_handler():
  316. self.loop.stop()
  317. self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
  318. try:
  319. self.loop.run_forever()
  320. except KeyboardInterrupt:
  321. pass
  322. self.server.close()
  323. self.loop.run_until_complete(self.server.wait_closed())
  324. logger.info('Server shutting down')
  325. if self.close_loop:
  326. self.loop.close()
  327. if self._cleanup_socket is not None:
  328. self._cleanup_socket()