__init__.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
# Copyright (C) 2018 Garmin Ltd.
#
# SPDX-License-Identifier: GPL-2.0-only
#
  5. from http.server import BaseHTTPRequestHandler, HTTPServer
  6. import contextlib
  7. import urllib.parse
  8. import sqlite3
  9. import json
  10. import traceback
  11. import logging
  12. from datetime import datetime
  13. logger = logging.getLogger('hashserv')
  14. class HashEquivalenceServer(BaseHTTPRequestHandler):
  15. def log_message(self, f, *args):
  16. logger.debug(f, *args)
  17. def do_GET(self):
  18. try:
  19. p = urllib.parse.urlparse(self.path)
  20. if p.path != self.prefix + '/v1/equivalent':
  21. self.send_error(404)
  22. return
  23. query = urllib.parse.parse_qs(p.query, strict_parsing=True)
  24. method = query['method'][0]
  25. taskhash = query['taskhash'][0]
  26. d = None
  27. with contextlib.closing(self.db.cursor()) as cursor:
  28. cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND taskhash=:taskhash ORDER BY created ASC LIMIT 1',
  29. {'method': method, 'taskhash': taskhash})
  30. row = cursor.fetchone()
  31. if row is not None:
  32. logger.debug('Found equivalent task %s', row['taskhash'])
  33. d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
  34. self.send_response(200)
  35. self.send_header('Content-Type', 'application/json; charset=utf-8')
  36. self.end_headers()
  37. self.wfile.write(json.dumps(d).encode('utf-8'))
  38. except:
  39. logger.exception('Error in GET')
  40. self.send_error(400, explain=traceback.format_exc())
  41. return
  42. def do_POST(self):
  43. try:
  44. p = urllib.parse.urlparse(self.path)
  45. if p.path != self.prefix + '/v1/equivalent':
  46. self.send_error(404)
  47. return
  48. length = int(self.headers['content-length'])
  49. data = json.loads(self.rfile.read(length).decode('utf-8'))
  50. with contextlib.closing(self.db.cursor()) as cursor:
  51. cursor.execute('''
  52. SELECT taskhash, method, unihash FROM tasks_v1 WHERE method=:method AND outhash=:outhash
  53. ORDER BY CASE WHEN taskhash=:taskhash THEN 1 ELSE 2 END,
  54. created ASC
  55. LIMIT 1
  56. ''', {k: data[k] for k in ('method', 'outhash', 'taskhash')})
  57. row = cursor.fetchone()
  58. if row is None or row['taskhash'] != data['taskhash']:
  59. unihash = data['unihash']
  60. if row is not None:
  61. unihash = row['unihash']
  62. insert_data = {
  63. 'method': data['method'],
  64. 'outhash': data['outhash'],
  65. 'taskhash': data['taskhash'],
  66. 'unihash': unihash,
  67. 'created': datetime.now()
  68. }
  69. for k in ('owner', 'PN', 'PV', 'PR', 'task', 'outhash_siginfo'):
  70. if k in data:
  71. insert_data[k] = data[k]
  72. cursor.execute('''INSERT INTO tasks_v1 (%s) VALUES (%s)''' % (
  73. ', '.join(sorted(insert_data.keys())),
  74. ', '.join(':' + k for k in sorted(insert_data.keys()))),
  75. insert_data)
  76. logger.info('Adding taskhash %s with unihash %s', data['taskhash'], unihash)
  77. cursor.execute('SELECT taskhash, method, unihash FROM tasks_v1 WHERE id=:id', {'id': cursor.lastrowid})
  78. row = cursor.fetchone()
  79. self.db.commit()
  80. d = {k: row[k] for k in ('taskhash', 'method', 'unihash')}
  81. self.send_response(200)
  82. self.send_header('Content-Type', 'application/json; charset=utf-8')
  83. self.end_headers()
  84. self.wfile.write(json.dumps(d).encode('utf-8'))
  85. except:
  86. logger.exception('Error in POST')
  87. self.send_error(400, explain=traceback.format_exc())
  88. return
  89. def create_server(addr, db, prefix=''):
  90. class Handler(HashEquivalenceServer):
  91. pass
  92. Handler.prefix = prefix
  93. Handler.db = db
  94. db.row_factory = sqlite3.Row
  95. with contextlib.closing(db.cursor()) as cursor:
  96. cursor.execute('''
  97. CREATE TABLE IF NOT EXISTS tasks_v1 (
  98. id INTEGER PRIMARY KEY AUTOINCREMENT,
  99. method TEXT NOT NULL,
  100. outhash TEXT NOT NULL,
  101. taskhash TEXT NOT NULL,
  102. unihash TEXT NOT NULL,
  103. created DATETIME,
  104. -- Optional fields
  105. owner TEXT,
  106. PN TEXT,
  107. PV TEXT,
  108. PR TEXT,
  109. task TEXT,
  110. outhash_siginfo TEXT
  111. )
  112. ''')
  113. logger.info('Starting server on %s', addr)
  114. return HTTPServer(addr, Handler)