123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414 |
- # Copyright (C) 2019 Garmin Ltd.
- #
- # SPDX-License-Identifier: GPL-2.0-only
- #
- from contextlib import closing
- from datetime import datetime
- import asyncio
- import json
- import logging
- import math
- import os
- import signal
- import socket
- import time
# Module-level logger used by the server and client classes in this module.
logger = logging.getLogger("hashserv.server")
class Measurement(object):
    """Time a single interval and report it to the owning sample.

    Use either explicitly via start()/end(), or as a context manager that
    calls them on entry and exit.

    Args:
        sample: object with an add(elapsed_seconds) method that receives the
            measured interval.
    """

    def __init__(self, sample):
        self.sample = sample

    def start(self):
        """Remember the current high-resolution timestamp."""
        self.start_time = time.perf_counter()

    def end(self):
        """Report the seconds elapsed since start() to the sample."""
        now = time.perf_counter()
        self.sample.add(now - self.start_time)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args, **kwargs):
        # The interval is reported even when the body raised.
        self.end()
class Sample(object):
    """Accumulate several timed intervals into one aggregate data point.

    Intervals are summed by add(); end() (also called on context exit)
    reports the total to the owning stats object as a single observation,
    but only if at least one interval was added, and then resets.

    Args:
        stats: object with an add(elapsed_seconds) method.
    """

    def __init__(self, stats):
        self.stats = stats
        self.num_samples, self.elapsed = 0, 0

    def measure(self):
        """Return a new Measurement that reports into this sample."""
        return Measurement(self)

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        self.end()

    def add(self, elapsed):
        """Add one interval of *elapsed* seconds to the running total."""
        self.elapsed += elapsed
        self.num_samples += 1

    def end(self):
        """Flush the accumulated total to stats (if non-empty) and reset."""
        if not self.num_samples:
            return
        self.stats.add(self.elapsed)
        self.num_samples, self.elapsed = 0, 0
class Stats(object):
    """Running summary statistics over a stream of elapsed-time values.

    Tracks the count, total, maximum, mean and sample standard deviation.
    Mean and variance are updated incrementally (Welford's online
    algorithm), so individual observations are never stored.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated observations."""
        self.num = 0
        self.total_time = 0
        self.max_time = 0
        self.m = 0  # running mean
        self.s = 0  # running sum of squared deviations from the mean
        self.current_elapsed = None

    def add(self, elapsed):
        """Incorporate one observation of *elapsed* seconds."""
        self.num += 1
        if self.num > 1:
            prev_mean = self.m
            self.m = prev_mean + (elapsed - prev_mean) / self.num
            self.s += (elapsed - prev_mean) * (elapsed - self.m)
        else:
            self.m = elapsed
            self.s = 0

        self.total_time += elapsed
        self.max_time = max(self.max_time, elapsed)

    def start_sample(self):
        """Return a new Sample that reports its total into these stats."""
        return Sample(self)

    @property
    def average(self):
        """Mean of all observations, or 0 when there are none."""
        return self.total_time / self.num if self.num else 0

    @property
    def stdev(self):
        """Sample standard deviation; 0 with fewer than two observations."""
        return math.sqrt(self.s / (self.num - 1)) if self.num > 1 else 0

    def todict(self):
        """Return the summary as a plain dict (suitable for JSON encoding)."""
        keys = ('num', 'total_time', 'max_time', 'average', 'stdev')
        return dict((k, getattr(self, k)) for k in keys)
class ServerClient(object):
    """Handle one client connection of the hash equivalence server.

    The connection opens with an "OEHASHEQUIV 1.0" protocol line, followed
    by header lines terminated by an empty line (no headers are currently
    interpreted).  After that the client sends one JSON object per line,
    keyed by a command name; each command is answered with one JSON line,
    except 'get-stream', which switches to a plain-text line protocol.

    Args:
        reader: asyncio stream reader for the connection.
        writer: asyncio stream writer for the connection.
        db: database connection.  Rows are indexed by column name, so it
            presumably uses a name-indexable row factory such as sqlite3.Row
            (TODO confirm against the caller).
        request_stats: Stats-like object used to time each request.
    """

    def __init__(self, reader, writer, db, request_stats):
        self.reader = reader
        self.writer = writer
        self.db = db
        self.request_stats = request_stats

    async def process_requests(self):
        """Perform the protocol handshake, then dispatch messages until EOF.

        Returns quietly on EOF, an unknown protocol/version, a malformed
        protocol line, or an unrecognized command.  The writer is always
        closed on exit.
        """
        try:
            self.addr = self.writer.get_extra_info('peername')
            logger.debug('Client %r connected', self.addr)

            # Read protocol and version.  StreamReader.readline() signals EOF
            # with b'' (never None), so test for falsiness.
            protocol = await self.reader.readline()
            if not protocol:
                return

            fields = protocol.decode('utf-8').rstrip().split()
            if len(fields) != 2:
                return
            (proto_name, proto_version) = fields
            if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
                return

            # Read headers. Currently, no headers are implemented, so look for
            # an empty line to signal the end of the headers.
            while True:
                line = await self.reader.readline()
                if not line:
                    # EOF before any message was sent
                    return
                line = line.decode('utf-8').rstrip()
                if not line:
                    break

            # Handle messages
            handlers = {
                'get': self.handle_get,
                'report': self.handle_report,
                'get-stream': self.handle_get_stream,
                'get-stats': self.handle_get_stats,
                'reset-stats': self.handle_reset_stats,
            }

            while True:
                d = await self.read_message()
                if d is None:
                    break

                # Any valid JSON value can arrive here; only objects carry a
                # command.  Without this check a list or string containing a
                # command name would pass the 'in' test and fail on d[k].
                if not isinstance(d, dict):
                    logger.warning("Unrecognized command %r", d)
                    break

                for k in handlers:
                    if k in d:
                        logger.debug('Handling %s', k)
                        if 'stream' in k:
                            # Stream mode times each streamed request itself
                            await handlers[k](d[k])
                        else:
                            with self.request_stats.start_sample() as self.request_sample, \
                                    self.request_sample.measure():
                                await handlers[k](d[k])
                        break
                else:
                    logger.warning("Unrecognized command %r", d)
                    break

                await self.writer.drain()
        finally:
            self.writer.close()

    def write_message(self, msg):
        """Send *msg* as one line of UTF-8 encoded JSON."""
        self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))

    async def read_message(self):
        """Read one newline-terminated JSON message from the client.

        Returns:
            The decoded JSON value, or None at EOF or when the final line is
            missing its newline terminator.

        Raises:
            json.JSONDecodeError, UnicodeDecodeError: on a malformed line
            (the raw bytes are logged first).
        """
        l = await self.reader.readline()
        if not l:
            return None

        try:
            message = l.decode('utf-8')
            if not message.endswith('\n'):
                return None

            return json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Log the raw bytes: the decoded text does not exist if the
            # UTF-8 decode itself failed.
            logger.error('Bad message from client: %r', l)
            raise

    async def handle_get(self, request):
        """Reply with the stored row for request['method'/'taskhash'], or null."""
        method = request['method']
        taskhash = request['taskhash']

        row = self.query_equivalent(method, taskhash)
        if row is not None:
            logger.debug('Found equivalent task %s -> %s', row['taskhash'], row['unihash'])
            d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

            self.write_message(d)
        else:
            self.write_message(None)

    async def handle_get_stream(self, request):
        """Serve a stream of lookups using a plain-text line protocol.

        After replying 'ok', each request line is "<method> <taskhash>".  The
        reply is the unihash followed by a newline, or just a newline when
        there is no match.  A line containing "END" is answered with "ok" and
        ends stream mode.
        """
        self.write_message('ok')

        while True:
            l = await self.reader.readline()
            if not l:
                return

            # This inner loop is very sensitive and must be as fast as
            # possible, which is why the request sample is handled manually
            # instead of using 'with', and why no logging is done here.
            # Set up the measurement before 'try' so the 'finally' below
            # always has a started measurement to close.
            self.request_sample = self.request_stats.start_sample()
            request_measure = self.request_sample.measure()
            request_measure.start()
            try:
                l = l.decode('utf-8').rstrip()
                if l == 'END':
                    self.writer.write('ok\n'.encode('utf-8'))
                    return

                (method, taskhash) = l.split()
                row = self.query_equivalent(method, taskhash)
                if row is not None:
                    msg = ('%s\n' % row['unihash']).encode('utf-8')
                else:
                    msg = '\n'.encode('utf-8')

                self.writer.write(msg)
            finally:
                request_measure.end()
                self.request_sample.end()

            await self.writer.drain()

    async def handle_report(self, data):
        """Record a task's output hash and reply with the unihash to use.

        If a row with the same method and outhash exists for a *different*
        taskhash, a new row is inserted for this taskhash reusing the oldest
        such row's unihash.  With no matching outhash, a new row is inserted
        with the caller-provided unihash.  An exact taskhash match is
        returned unchanged without inserting.

        Args:
            data: dict with 'method', 'outhash', 'taskhash' and 'unihash'.
                Optional keys 'owner', 'PN', 'PV', 'PR', 'task' and
                'outhash_siginfo' are stored when present.
        """
        with closing(self.db.cursor()) as cursor:
            cursor.execute('''
                -- Find tasks with a matching outhash (that is, tasks that
                -- are equivalent)
                SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND outhash=:outhash

                -- If there is an exact match on the taskhash, return it.
                -- Otherwise return the oldest matching outhash of any
                -- taskhash
                ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
                    created ASC

                -- Only return one row
                LIMIT 1
                ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})

            row = cursor.fetchone()

            # If no matching outhash was found, or one *was* found but it
            # wasn't an exact match on the taskhash, a new entry for this
            # taskhash should be added
            if row is None or row['taskhash'] != data['taskhash']:
                # If a row matching the outhash was found, the unihash for
                # the new taskhash should be the same as that one.
                # Otherwise the caller provided unihash is used.
                unihash = data['unihash']
                if row is not None:
                    unihash = row['unihash']

                insert_data = {
                    'method': data['method'],
                    'outhash': data['outhash'],
                    'taskhash': data['taskhash'],
                    'unihash': unihash,
                    'created': datetime.now()
                }

                for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
                    if k in data:
                        insert_data[k] = data[k]

                # Column names come only from the fixed key set above, never
                # from client-chosen keys; all values are bound parameters.
                columns = sorted(insert_data.keys())
                cursor.execute('''INSERT INTO tasks_v2 (%s) VALUES (%s)''' % (
                    ', '.join(columns),
                    ', '.join(':' + k for k in columns)),
                    insert_data)

                self.db.commit()

                logger.info('Adding taskhash %s with unihash %s',
                            data['taskhash'], unihash)

                d = {
                    'taskhash': data['taskhash'],
                    'method': data['method'],
                    'unihash': unihash
                }
            else:
                d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}

        self.write_message(d)

    async def handle_get_stats(self, request):
        """Reply with the current request statistics."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.write_message(d)

    async def handle_reset_stats(self, request):
        """Reply with the request statistics, then reset them."""
        d = {
            'requests': self.request_stats.todict(),
        }

        self.request_stats.reset()
        self.write_message(d)

    def query_equivalent(self, method, taskhash):
        """Return the oldest tasks_v2 row for (method, taskhash), or None.

        Best effort: a database error is logged and treated as "no match",
        so one failed lookup does not abort the client connection.  The
        cursor is always closed.
        """
        # This is part of the inner loop and must be as fast as possible
        try:
            with closing(self.db.cursor()) as cursor:
                cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
                               {'method': method, 'taskhash': taskhash})
                return cursor.fetchone()
        except Exception:
            logger.exception('Error looking up %s %s', method, taskhash)
            return None
class Server(object):
    """Hash equivalence server.

    Listens on a TCP or AF_UNIX socket and serves every connection with a
    ServerClient, passing each one the same database handle and the same
    request Stats object.

    Args:
        db: database connection handed to every ServerClient.
        loop: event loop to run on.  If None, a private loop is created and
            closed again when serve_forever() returns.
    """

    def __init__(self, db, loop=None):
        self.request_stats = Stats()
        self.db = db

        if loop is None:
            self.loop = asyncio.new_event_loop()
            self.close_loop = True
        else:
            self.loop = loop
            self.close_loop = False

        self._cleanup_socket = None

    def start_tcp_server(self, host, port):
        """Listen on TCP *host*:*port* and set self.address.

        self.address is "host:port" (IPv6 hosts bracketed), built from the
        first listening socket's getsockname().
        """
        # Do not pass loop=: asyncio deprecated that parameter in Python 3.8
        # and removed it in 3.10.  The coroutine runs inside self.loop via
        # run_until_complete(), so the server is created on that loop.
        self.server = self.loop.run_until_complete(
            asyncio.start_server(self.handle_client, host, port)
        )

        for s in self.server.sockets:
            logger.info('Listening on %r', s.getsockname())
            # Newer python does this automatically. Do it manually here for
            # maximum compatibility.  IPPROTO_TCP is the portable spelling of
            # the TCP option level (same value as SOL_TCP on Linux).
            s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # TCP_QUICKACK is Linux-specific; skip it where it doesn't exist
            if hasattr(socket, 'TCP_QUICKACK'):
                s.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, 1)

        name = self.server.sockets[0].getsockname()
        if self.server.sockets[0].family == socket.AF_INET6:
            self.address = "[%s]:%d" % (name[0], name[1])
        else:
            self.address = "%s:%d" % (name[0], name[1])

    def start_unix_server(self, path):
        """Listen on the AF_UNIX socket *path* and set self.address.

        The socket file is removed by serve_forever() on shutdown.
        """
        def cleanup():
            # asyncio may already have removed the socket on close (Python
            # 3.13+ cleans up unix server sockets by default).
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass

        cwd = os.getcwd()
        try:
            # Work around path length limits in AF_UNIX by binding to the
            # bare file name from inside the socket's directory.  A path with
            # no directory part is already relative to cwd, and os.chdir('')
            # would raise.
            dirname = os.path.dirname(path)
            if dirname:
                os.chdir(dirname)
            self.server = self.loop.run_until_complete(
                asyncio.start_unix_server(self.handle_client, os.path.basename(path))
            )
        finally:
            os.chdir(cwd)

        logger.info('Listening on %r', path)

        self._cleanup_socket = cleanup
        self.address = "unix://%s" % os.path.abspath(path)

    async def handle_client(self, reader, writer):
        """Serve one connection; log (don't propagate) any client error."""
        try:
            client = ServerClient(reader, writer, self.db, self.request_stats)
            await client.process_requests()
        except Exception as e:
            import traceback
            logger.error('Error from client: %s', e, exc_info=True)
            traceback.print_exc()
            writer.close()
        logger.info('Client disconnected')

    def serve_forever(self):
        """Run the event loop until SIGTERM or Ctrl-C, then shut down.

        Closes the listening server, closes the loop if this object created
        it, and removes the unix socket file if one was created.
        """
        def signal_handler():
            self.loop.stop()

        self.loop.add_signal_handler(signal.SIGTERM, signal_handler)

        try:
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass

        self.server.close()
        self.loop.run_until_complete(self.server.wait_closed())
        logger.info('Server shutting down')

        if self.close_loop:
            self.loop.close()

        if self._cleanup_socket is not None:
            self._cleanup_socket()
|