|
@@ -0,0 +1,414 @@
|
|
|
|
+# Copyright (C) 2019 Garmin Ltd.
|
|
|
|
+#
|
|
|
|
+# SPDX-License-Identifier: GPL-2.0-only
|
|
|
|
+#
|
|
|
|
+
|
|
|
|
+from contextlib import closing
|
|
|
|
+from datetime import datetime
|
|
|
|
+import asyncio
|
|
|
|
+import json
|
|
|
|
+import logging
|
|
|
|
+import math
|
|
|
|
+import os
|
|
|
|
+import signal
|
|
|
|
+import socket
|
|
|
|
+import time
|
|
|
|
+
|
|
|
|
+logger = logging.getLogger('hashserv.server')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Measurement(object):
|
|
|
|
+ def __init__(self, sample):
|
|
|
|
+ self.sample = sample
|
|
|
|
+
|
|
|
|
+ def start(self):
|
|
|
|
+ self.start_time = time.perf_counter()
|
|
|
|
+
|
|
|
|
+ def end(self):
|
|
|
|
+ self.sample.add(time.perf_counter() - self.start_time)
|
|
|
|
+
|
|
|
|
+ def __enter__(self):
|
|
|
|
+ self.start()
|
|
|
|
+ return self
|
|
|
|
+
|
|
|
|
+ def __exit__(self, *args, **kwargs):
|
|
|
|
+ self.end()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Sample(object):
|
|
|
|
+ def __init__(self, stats):
|
|
|
|
+ self.stats = stats
|
|
|
|
+ self.num_samples = 0
|
|
|
|
+ self.elapsed = 0
|
|
|
|
+
|
|
|
|
+ def measure(self):
|
|
|
|
+ return Measurement(self)
|
|
|
|
+
|
|
|
|
+ def __enter__(self):
|
|
|
|
+ return self
|
|
|
|
+
|
|
|
|
+ def __exit__(self, *args, **kwargs):
|
|
|
|
+ self.end()
|
|
|
|
+
|
|
|
|
+ def add(self, elapsed):
|
|
|
|
+ self.num_samples += 1
|
|
|
|
+ self.elapsed += elapsed
|
|
|
|
+
|
|
|
|
+ def end(self):
|
|
|
|
+ if self.num_samples:
|
|
|
|
+ self.stats.add(self.elapsed)
|
|
|
|
+ self.num_samples = 0
|
|
|
|
+ self.elapsed = 0
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Stats(object):
|
|
|
|
+ def __init__(self):
|
|
|
|
+ self.reset()
|
|
|
|
+
|
|
|
|
+ def reset(self):
|
|
|
|
+ self.num = 0
|
|
|
|
+ self.total_time = 0
|
|
|
|
+ self.max_time = 0
|
|
|
|
+ self.m = 0
|
|
|
|
+ self.s = 0
|
|
|
|
+ self.current_elapsed = None
|
|
|
|
+
|
|
|
|
+ def add(self, elapsed):
|
|
|
|
+ self.num += 1
|
|
|
|
+ if self.num == 1:
|
|
|
|
+ self.m = elapsed
|
|
|
|
+ self.s = 0
|
|
|
|
+ else:
|
|
|
|
+ last_m = self.m
|
|
|
|
+ self.m = last_m + (elapsed - last_m) / self.num
|
|
|
|
+ self.s = self.s + (elapsed - last_m) * (elapsed - self.m)
|
|
|
|
+
|
|
|
|
+ self.total_time += elapsed
|
|
|
|
+
|
|
|
|
+ if self.max_time < elapsed:
|
|
|
|
+ self.max_time = elapsed
|
|
|
|
+
|
|
|
|
+ def start_sample(self):
|
|
|
|
+ return Sample(self)
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def average(self):
|
|
|
|
+ if self.num == 0:
|
|
|
|
+ return 0
|
|
|
|
+ return self.total_time / self.num
|
|
|
|
+
|
|
|
|
+ @property
|
|
|
|
+ def stdev(self):
|
|
|
|
+ if self.num <= 1:
|
|
|
|
+ return 0
|
|
|
|
+ return math.sqrt(self.s / (self.num - 1))
|
|
|
|
+
|
|
|
|
+ def todict(self):
|
|
|
|
+ return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class ServerClient(object):
|
|
|
|
+ def __init__(self, reader, writer, db, request_stats):
|
|
|
|
+ self.reader = reader
|
|
|
|
+ self.writer = writer
|
|
|
|
+ self.db = db
|
|
|
|
+ self.request_stats = request_stats
|
|
|
|
+
|
|
|
|
+ async def process_requests(self):
|
|
|
|
+ try:
|
|
|
|
+ self.addr = self.writer.get_extra_info('peername')
|
|
|
|
+ logger.debug('Client %r connected' % (self.addr,))
|
|
|
|
+
|
|
|
|
+ # Read protocol and version
|
|
|
|
+ protocol = await self.reader.readline()
|
|
|
|
+ if protocol is None:
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ (proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
|
|
|
|
+ if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ # Read headers. Currently, no headers are implemented, so look for
|
|
|
|
+ # an empty line to signal the end of the headers
|
|
|
|
+ while True:
|
|
|
|
+ line = await self.reader.readline()
|
|
|
|
+ if line is None:
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ line = line.decode('utf-8').rstrip()
|
|
|
|
+ if not line:
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ # Handle messages
|
|
|
|
+ handlers = {
|
|
|
|
+ 'get': self.handle_get,
|
|
|
|
+ 'report': self.handle_report,
|
|
|
|
+ 'get-stream': self.handle_get_stream,
|
|
|
|
+ 'get-stats': self.handle_get_stats,
|
|
|
|
+ 'reset-stats': self.handle_reset_stats,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ while True:
|
|
|
|
+ d = await self.read_message()
|
|
|
|
+ if d is None:
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ for k in handlers.keys():
|
|
|
|
+ if k in d:
|
|
|
|
+ logger.debug('Handling %s' % k)
|
|
|
|
+ if 'stream' in k:
|
|
|
|
+ await handlers[k](d[k])
|
|
|
|
+ else:
|
|
|
|
+ with self.request_stats.start_sample() as self.request_sample, \
|
|
|
|
+ self.request_sample.measure():
|
|
|
|
+ await handlers[k](d[k])
|
|
|
|
+ break
|
|
|
|
+ else:
|
|
|
|
+ logger.warning("Unrecognized command %r" % d)
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ await self.writer.drain()
|
|
|
|
+ finally:
|
|
|
|
+ self.writer.close()
|
|
|
|
+
|
|
|
|
+ def write_message(self, msg):
|
|
|
|
+ self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
|
|
|
|
+
|
|
|
|
+ async def read_message(self):
|
|
|
|
+ l = await self.reader.readline()
|
|
|
|
+ if not l:
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ message = l.decode('utf-8')
|
|
|
|
+
|
|
|
|
+ if not message.endswith('\n'):
|
|
|
|
+ return None
|
|
|
|
+
|
|
|
|
+ return json.loads(message)
|
|
|
|
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
|
|
+ logger.error('Bad message from client: %r' % message)
|
|
|
|
+ raise e
|
|
|
|
+
|
|
|
|
+ async def handle_get(self, request):
|
|
|
|
+ method = request['method']
|
|
|
|
+ taskhash = request['taskhash']
|
|
|
|
+
|
|
|
|
+ row = self.query_equivalent(method, taskhash)
|
|
|
|
+ if row is not None:
|
|
|
|
+ logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
|
|
|
|
+ d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
|
|
|
|
+
|
|
|
|
+ self.write_message(d)
|
|
|
|
+ else:
|
|
|
|
+ self.write_message(None)
|
|
|
|
+
|
|
|
|
+ async def handle_get_stream(self, request):
|
|
|
|
+ self.write_message('ok')
|
|
|
|
+
|
|
|
|
+ while True:
|
|
|
|
+ l = await self.reader.readline()
|
|
|
|
+ if not l:
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # This inner loop is very sensitive and must be as fast as
|
|
|
|
+ # possible (which is why the request sample is handled manually
|
|
|
|
+ # instead of using 'with', and also why logging statements are
|
|
|
|
+ # commented out.
|
|
|
|
+ self.request_sample = self.request_stats.start_sample()
|
|
|
|
+ request_measure = self.request_sample.measure()
|
|
|
|
+ request_measure.start()
|
|
|
|
+
|
|
|
|
+ l = l.decode('utf-8').rstrip()
|
|
|
|
+ if l == 'END':
|
|
|
|
+ self.writer.write('ok\n'.encode('utf-8'))
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ (method, taskhash) = l.split()
|
|
|
|
+ #logger.debug('Looking up %s %s' % (method, taskhash))
|
|
|
|
+ row = self.query_equivalent(method, taskhash)
|
|
|
|
+ if row is not None:
|
|
|
|
+ msg = ('%s\n' % row['unihash']).encode('utf-8')
|
|
|
|
+ #logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
|
|
|
|
+ else:
|
|
|
|
+ msg = '\n'.encode('utf-8')
|
|
|
|
+
|
|
|
|
+ self.writer.write(msg)
|
|
|
|
+ finally:
|
|
|
|
+ request_measure.end()
|
|
|
|
+ self.request_sample.end()
|
|
|
|
+
|
|
|
|
+ await self.writer.drain()
|
|
|
|
+
|
|
|
|
+ async def handle_report(self, data):
|
|
|
|
+ with closing(self.db.cursor()) as cursor:
|
|
|
|
+ cursor.execute('''
|
|
|
|
+ -- Find tasks with a matching outhash (that is, tasks that
|
|
|
|
+ -- are equivalent)
|
|
|
|
+ SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash
|
|
|
|
+
|
|
|
|
+ -- If there is an exact match on the taskhash, return it.
|
|
|
|
+ -- Otherwise return the oldest matching outhash of any
|
|
|
|
+ -- taskhash
|
|
|
|
+ ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
|
|
|
|
+ created ASC
|
|
|
|
+
|
|
|
|
+ -- Only return one row
|
|
|
|
+ LIMIT 1
|
|
|
|
+ ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
|
|
|
|
+
|
|
|
|
+ row = cursor.fetchone()
|
|
|
|
+
|
|
|
|
+ # If no matching outhash was found, or one *was* found but it
|
|
|
|
+ # wasn't an exact match on the taskhash, a new entry for this
|
|
|
|
+ # taskhash should be added
|
|
|
|
+ if row is None or row['taskhash'] != data['taskhash']:
|
|
|
|
+ # If a row matching the outhash was found, the unihash for
|
|
|
|
+ # the new taskhash should be the same as that one.
|
|
|
|
+ # Otherwise the caller provided unihash is used.
|
|
|
|
+ unihash = data['unihash']
|
|
|
|
+ if row is not None:
|
|
|
|
+ unihash = row['unihash']
|
|
|
|
+
|
|
|
|
+ insert_data = {
|
|
|
|
+ 'method': data['method'],
|
|
|
|
+ 'outhash': data['outhash'],
|
|
|
|
+ 'taskhash': data['taskhash'],
|
|
|
|
+ 'unihash': unihash,
|
|
|
|
+ 'created': datetime.now()
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
|
|
|
|
+ if k in data:
|
|
|
|
+ insert_data[k] = data[k]
|
|
|
|
+
|
|
|
|
+ cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
|
|
|
|
+ ', '.join(sorted(insert_data.keys())),
|
|
|
|
+ ', '.join(':' + k for k in sorted(insert_data.keys()))),
|
|
|
|
+ insert_data)
|
|
|
|
+
|
|
|
|
+ self.db.commit()
|
|
|
|
+
|
|
|
|
+ logger.info('Adding taskhash %s with unihash %s',
|
|
|
|
+ data['taskhash'], unihash)
|
|
|
|
+
|
|
|
|
+ d = {
|
|
|
|
+ 'taskhash': data['taskhash'],
|
|
|
|
+ 'method': data['method'],
|
|
|
|
+ 'unihash': unihash
|
|
|
|
+ }
|
|
|
|
+ else:
|
|
|
|
+ d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
|
|
|
|
+
|
|
|
|
+ self.write_message(d)
|
|
|
|
+
|
|
|
|
+ async def handle_get_stats(self, request):
|
|
|
|
+ d = {
|
|
|
|
+ 'requests': self.request_stats.todict(),
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ self.write_message(d)
|
|
|
|
+
|
|
|
|
+ async def handle_reset_stats(self, request):
|
|
|
|
+ d = {
|
|
|
|
+ 'requests': self.request_stats.todict(),
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ self.request_stats.reset()
|
|
|
|
+ self.write_message(d)
|
|
|
|
+
|
|
|
|
+ def query_equivalent(self, method, taskhash):
|
|
|
|
+ # This is part of the inner loop and must be as fast as possible
|
|
|
|
+ try:
|
|
|
|
+ cursor = self.db.cursor()
|
|
|
|
+ cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
|
|
|
|
+ {'method': method, 'taskhash': taskhash})
|
|
|
|
+ return cursor.fetchone()
|
|
|
|
+ except:
|
|
|
|
+ cursor.close()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Server(object):
|
|
|
|
+ def __init__(self, db, loop=None):
|
|
|
|
+ self.request_stats = Stats()
|
|
|
|
+ self.db = db
|
|
|
|
+
|
|
|
|
+ if loop is None:
|
|
|
|
+ self.loop = asyncio.new_event_loop()
|
|
|
|
+ self.close_loop = True
|
|
|
|
+ else:
|
|
|
|
+ self.loop = loop
|
|
|
|
+ self.close_loop = False
|
|
|
|
+
|
|
|
|
+ self._cleanup_socket = None
|
|
|
|
+
|
|
|
|
+ def start_tcp_server(self, host, port):
|
|
|
|
+ self.server = self.loop.run_until_complete(
|
|
|
|
+ asyncio.start_server(self.handle_client, host, port, loop=self.loop)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ for s in self.server.sockets:
|
|
|
|
+ logger.info('Listening on %r' % (s.getsockname(),))
|
|
|
|
+ # Newer python does this automatically. Do it manually here for
|
|
|
|
+ # maximum compatibility
|
|
|
|
+ s.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
|
|
|
|
+ s.setsockopt(socket.SOL_TCP, socket.TCP_QUICKACK, 1)
|
|
|
|
+
|
|
|
|
+ name = self.server.sockets[0].getsockname()
|
|
|
|
+ if self.server.sockets[0].family == socket.AF_INET6:
|
|
|
|
+ self.address = "[%s]:%d" % (name[0], name[1])
|
|
|
|
+ else:
|
|
|
|
+ self.address = "%s:%d" % (name[0], name[1])
|
|
|
|
+
|
|
|
|
+ def start_unix_server(self, path):
|
|
|
|
+ def cleanup():
|
|
|
|
+ os.unlink(path)
|
|
|
|
+
|
|
|
|
+ cwd = os.getcwd()
|
|
|
|
+ try:
|
|
|
|
+ # Work around path length limits in AF_UNIX
|
|
|
|
+ os.chdir(os.path.dirname(path))
|
|
|
|
+ self.server = self.loop.run_until_complete(
|
|
|
|
+ asyncio.start_unix_server(self.handle_client, os.path.basename(path), loop=self.loop)
|
|
|
|
+ )
|
|
|
|
+ finally:
|
|
|
|
+ os.chdir(cwd)
|
|
|
|
+
|
|
|
|
+ logger.info('Listening on %r' % path)
|
|
|
|
+
|
|
|
|
+ self._cleanup_socket = cleanup
|
|
|
|
+ self.address = "unix://%s" % os.path.abspath(path)
|
|
|
|
+
|
|
|
|
+ async def handle_client(self, reader, writer):
|
|
|
|
+ # writer.transport.set_write_buffer_limits(0)
|
|
|
|
+ try:
|
|
|
|
+ client = ServerClient(reader, writer, self.db, self.request_stats)
|
|
|
|
+ await client.process_requests()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ import traceback
|
|
|
|
+ logger.error('Error from client: %s' % str(e), exc_info=True)
|
|
|
|
+ traceback.print_exc()
|
|
|
|
+ writer.close()
|
|
|
|
+ logger.info('Client disconnected')
|
|
|
|
+
|
|
|
|
+ def serve_forever(self):
|
|
|
|
+ def signal_handler():
|
|
|
|
+ self.loop.stop()
|
|
|
|
+
|
|
|
|
+ self.loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ self.loop.run_forever()
|
|
|
|
+ except KeyboardInterrupt:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ self.server.close()
|
|
|
|
+ self.loop.run_until_complete(self.server.wait_closed())
|
|
|
|
+ logger.info('Server shutting down')
|
|
|
|
+
|
|
|
|
+ if self.close_loop:
|
|
|
|
+ self.loop.close()
|
|
|
|
+
|
|
|
|
+ if self._cleanup_socket is not None:
|
|
|
|
+ self._cleanup_socket()
|