|
@@ -13,6 +13,7 @@ import os
|
|
|
import signal
|
|
|
import socket
|
|
|
import time
|
|
|
+from . import chunkify, DEFAULT_MAX_CHUNK
|
|
|
|
|
|
logger = logging.getLogger('hashserv.server')
|
|
|
|
|
@@ -107,12 +108,29 @@ class Stats(object):
|
|
|
return {k: getattr(self, k) for k in ('num', 'total_time', 'max_time', 'average', 'stdev')}
|
|
|
|
|
|
|
|
|
+class ClientError(Exception):
|
|
|
+ pass
|
|
|
+
|
|
|
class ServerClient(object):
|
|
|
+ FAST_QUERY = 'SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
|
|
|
+ ALL_QUERY = 'SELECT * FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1'
|
|
|
+
|
|
|
def __init__(self, reader, writer, db, request_stats):
|
|
|
self.reader = reader
|
|
|
self.writer = writer
|
|
|
self.db = db
|
|
|
self.request_stats = request_stats
|
|
|
+ self.max_chunk = DEFAULT_MAX_CHUNK
|
|
|
+
|
|
|
+ self.handlers = {
|
|
|
+ 'get': self.handle_get,
|
|
|
+ 'report': self.handle_report,
|
|
|
+ 'report-equiv': self.handle_equivreport,
|
|
|
+ 'get-stream': self.handle_get_stream,
|
|
|
+ 'get-stats': self.handle_get_stats,
|
|
|
+ 'reset-stats': self.handle_reset_stats,
|
|
|
+ 'chunk-stream': self.handle_chunk,
|
|
|
+ }
|
|
|
|
|
|
async def process_requests(self):
|
|
|
try:
|
|
@@ -125,7 +143,11 @@ class ServerClient(object):
|
|
|
return
|
|
|
|
|
|
(proto_name, proto_version) = protocol.decode('utf-8').rstrip().split()
|
|
|
- if proto_name != 'OEHASHEQUIV' or proto_version != '1.0':
|
|
|
+ if proto_name != 'OEHASHEQUIV':
|
|
|
+ return
|
|
|
+
|
|
|
+ proto_version = tuple(int(v) for v in proto_version.split('.'))
|
|
|
+ if proto_version < (1, 0) or proto_version > (1, 1):
|
|
|
return
|
|
|
|
|
|
# Read headers. Currently, no headers are implemented, so look for
|
|
@@ -140,40 +162,34 @@ class ServerClient(object):
|
|
|
break
|
|
|
|
|
|
# Handle messages
|
|
|
- handlers = {
|
|
|
- 'get': self.handle_get,
|
|
|
- 'report': self.handle_report,
|
|
|
- 'report-equiv': self.handle_equivreport,
|
|
|
- 'get-stream': self.handle_get_stream,
|
|
|
- 'get-stats': self.handle_get_stats,
|
|
|
- 'reset-stats': self.handle_reset_stats,
|
|
|
- }
|
|
|
-
|
|
|
while True:
|
|
|
d = await self.read_message()
|
|
|
if d is None:
|
|
|
break
|
|
|
-
|
|
|
- for k in handlers.keys():
|
|
|
- if k in d:
|
|
|
- logger.debug('Handling %s' % k)
|
|
|
- if 'stream' in k:
|
|
|
- await handlers[k](d[k])
|
|
|
- else:
|
|
|
- with self.request_stats.start_sample() as self.request_sample, \
|
|
|
- self.request_sample.measure():
|
|
|
- await handlers[k](d[k])
|
|
|
- break
|
|
|
- else:
|
|
|
- logger.warning("Unrecognized command %r" % d)
|
|
|
- break
|
|
|
-
|
|
|
+ await self.dispatch_message(d)
|
|
|
await self.writer.drain()
|
|
|
+ except ClientError as e:
|
|
|
+ logger.error(str(e))
|
|
|
finally:
|
|
|
self.writer.close()
|
|
|
|
|
|
+ async def dispatch_message(self, msg):
|
|
|
+ for k in self.handlers.keys():
|
|
|
+ if k in msg:
|
|
|
+ logger.debug('Handling %s' % k)
|
|
|
+ if 'stream' in k:
|
|
|
+ await self.handlers[k](msg[k])
|
|
|
+ else:
|
|
|
+ with self.request_stats.start_sample() as self.request_sample, \
|
|
|
+ self.request_sample.measure():
|
|
|
+ await self.handlers[k](msg[k])
|
|
|
+ return
|
|
|
+
|
|
|
+ raise ClientError("Unrecognized command %r" % msg)
|
|
|
+
|
|
|
def write_message(self, msg):
|
|
|
- self.writer.write(('%s\n' % json.dumps(msg)).encode('utf-8'))
|
|
|
+ for c in chunkify(json.dumps(msg), self.max_chunk):
|
|
|
+ self.writer.write(c.encode('utf-8'))
|
|
|
|
|
|
async def read_message(self):
|
|
|
l = await self.reader.readline()
|
|
@@ -191,14 +207,38 @@ class ServerClient(object):
|
|
|
logger.error('Bad message from client: %r' % message)
|
|
|
raise e
|
|
|
|
|
|
+ async def handle_chunk(self, request):
|
|
|
+ lines = []
|
|
|
+ try:
|
|
|
+ while True:
|
|
|
+ l = await self.reader.readline()
|
|
|
+ l = l.rstrip(b"\n").decode("utf-8")
|
|
|
+ if not l:
|
|
|
+ break
|
|
|
+ lines.append(l)
|
|
|
+
|
|
|
+ msg = json.loads(''.join(lines))
|
|
|
+ except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
|
+ logger.error('Bad message from client: %r' % message)
|
|
|
+ raise e
|
|
|
+
|
|
|
+ if 'chunk-stream' in msg:
|
|
|
+ raise ClientError("Nested chunks are not allowed")
|
|
|
+
|
|
|
+ await self.dispatch_message(msg)
|
|
|
+
|
|
|
async def handle_get(self, request):
|
|
|
method = request['method']
|
|
|
taskhash = request['taskhash']
|
|
|
|
|
|
- row = self.query_equivalent(method, taskhash)
|
|
|
+ if request.get('all', False):
|
|
|
+ row = self.query_equivalent(method, taskhash, self.ALL_QUERY)
|
|
|
+ else:
|
|
|
+ row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
|
|
|
+
|
|
|
if row is not None:
|
|
|
logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
|
|
|
- d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
|
|
|
+ d = {k: row[k] for k in row.keys()}
|
|
|
|
|
|
self.write_message(d)
|
|
|
else:
|
|
@@ -228,7 +268,7 @@ class ServerClient(object):
|
|
|
|
|
|
(method, taskhash) = l.split()
|
|
|
#logger.debug('Looking up %s %s' % (method, taskhash))
|
|
|
- row = self.query_equivalent(method, taskhash)
|
|
|
+ row = self.query_equivalent(method, taskhash, self.FAST_QUERY)
|
|
|
if row is not None:
|
|
|
msg = ('%s\n' % row['unihash']).encode('utf-8')
|
|
|
#logger.debug('Found equivalent task %s -> %s', (row['taskhash'], row['unihash']))
|
|
@@ -328,7 +368,7 @@ class ServerClient(object):
|
|
|
# Fetch the unihash that will be reported for the taskhash. If the
|
|
|
# unihash matches, it means this row was inserted (or the mapping
|
|
|
# was already valid)
|
|
|
- row = self.query_equivalent(data['method'], data['taskhash'])
|
|
|
+ row = self.query_equivalent(data['method'], data['taskhash'], self.FAST_QUERY)
|
|
|
|
|
|
if row['unihash'] == data['unihash']:
|
|
|
logger.info('Adding taskhash equivalence for %s with unihash %s',
|
|
@@ -354,12 +394,11 @@ class ServerClient(object):
|
|
|
self.request_stats.reset()
|
|
|
self.write_message(d)
|
|
|
|
|
|
- def query_equivalent(self, method, taskhash):
|
|
|
+ def query_equivalent(self, method, taskhash, query):
|
|
|
# This is part of the inner loop and must be as fast as possible
|
|
|
try:
|
|
|
cursor = self.db.cursor()
|
|
|
- cursor.execute('SELECT taskhash, method, unihash FROM tasks_v2 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
|
|
|
- {'method': method, 'taskhash': taskhash})
|
|
|
+ cursor.execute(query, {'method': method, 'taskhash': taskhash})
|
|
|
return cursor.fetchone()
|
|
|
except:
|
|
|
cursor.close()
|