uploader.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright (C) 2015-2016 Peter Magnusson <peter@birchroad.net>
  4. import time
  5. import logging
  6. import hashlib
  7. import os
  8. import serial
  9. from .utils import default_port, system
  10. from .luacode import DOWNLOAD_FILE, SAVE_LUA, LUA_FUNCTIONS, LIST_FILES, UART_SETUP, PRINT_FILE
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = ['Uploader', 'default_port']
# Result of system(), computed once at import time; Uploader.expect()
# compares it with 'Windows' to decide whether to adjust the port timeout.
SYSTEM = system()
  14. class Uploader(object):
  15. """Uploader is the class for communicating with the nodemcu and
  16. that will allow various tasks like uploading files, formating the filesystem etc.
  17. """
  18. BAUD = 9600
  19. TIMEOUT = 5
  20. PORT = default_port()
  21. def __init__(self, port=PORT, baud=BAUD):
  22. log.info('opening port %s with %s baud', port, baud)
  23. if port == 'loop://':
  24. self._port = serial.serial_for_url(port, baud, timeout=Uploader.TIMEOUT)
  25. else:
  26. self._port = serial.Serial(port, baud, timeout=Uploader.TIMEOUT)
  27. # Keeps things working, if following conections are made:
  28. ## RTS = CH_PD (i.e reset)
  29. ## DTR = GPIO0
  30. self._port.setRTS(False)
  31. self._port.setDTR(False)
  32. def sync():
  33. # Get in sync with LUA (this assumes that NodeMCU gets reset by the previous two lines)
  34. log.debug('getting in sync with LUA');
  35. self.clear_buffers()
  36. self.exchange(';') # Get a defined state
  37. self.writeln('print("%sync%");')
  38. self.expect('%sync%\r\n> ')
  39. sync()
  40. if baud != Uploader.BAUD:
  41. log.info('Changing communication to %s baud', baud)
  42. self.writeln(UART_SETUP.format(baud=baud))
  43. # Wait for the string to be sent before switching baud
  44. time.sleep(0.1)
  45. self.set_baudrate(baud)
  46. # Get in sync again
  47. sync()
  48. self.line_number = 0
  49. def set_baudrate(self, baud):
  50. try:
  51. self._port.setBaudrate(baud)
  52. except AttributeError:
  53. #pySerial 2.7
  54. self._port.baudrate = baud
  55. def clear_buffers(self):
  56. try:
  57. self._port.reset_input_buffer()
  58. self._port.reset_output_buffer()
  59. except AttributeError:
  60. #pySerial 2.7
  61. self._port.flushInput()
  62. self._port.flushOutput()
  63. def expect(self, exp='> ', timeout=TIMEOUT):
  64. """will wait for exp to be returned from nodemcu or timeout"""
  65. if SYSTEM != 'Windows':
  66. timer = self._port.timeout
  67. # Checking for new data every 100us is fast enough
  68. lt = 0.0001
  69. if self._port.timeout != lt:
  70. self._port.timeout = lt
  71. end = time.time() + timeout
  72. # Finish as soon as either exp matches or we run out of time (work like dump, but faster on success)
  73. data = ''
  74. while not data.endswith(exp) and time.time() <= end:
  75. data += self._port.read()
  76. if time.time() > end and not data.endswith(exp) and len(exp) > 0:
  77. raise Exception('Timeout expecting ' + exp)
  78. if SYSTEM != 'Windows':
  79. self._port.timeout = timer
  80. log.debug('expect returned: `{0}`'.format(data))
  81. return data
  82. def write(self, output, binary=False):
  83. """write data on the nodemcu port. If 'binary' is True the debug log
  84. will show the intended output as hex, otherwise as string"""
  85. if not binary:
  86. log.debug('write: %s', output)
  87. else:
  88. log.debug('write binary: %s', ':'.join(x.encode('hex') for x in output))
  89. self._port.write(output)
  90. self._port.flush()
  91. def writeln(self, output):
  92. """write, with linefeed"""
  93. self.write(output + '\n')
  94. def exchange(self, output):
  95. self.writeln(output)
  96. self._port.flush()
  97. return self.expect()
  98. def close(self):
  99. """restores the nodemcu to default baudrate and then closes the port"""
  100. try:
  101. self.writeln(UART_SETUP.format(baud=Uploader.BAUD))
  102. self._port.flush()
  103. self.clear_buffers()
  104. except serial.serialutil.SerialException:
  105. pass
  106. log.debug('closing port')
  107. self._port.close()
  108. def prepare(self):
  109. """
  110. This uploads the protocol functions nessecary to do binary
  111. chunked transfer
  112. """
  113. log.info('Preparing esp for transfer.')
  114. for fn in LUA_FUNCTIONS:
  115. d = self.exchange('print({0})'.format(fn))
  116. if d.find('function:') == -1:
  117. break
  118. else:
  119. log.debug('Found all required lua functions, no need to upload them')
  120. return True
  121. data = SAVE_LUA.format(baud=self._port.baudrate)
  122. ##change any \r\n to just \n and split on that
  123. lines = data.replace('\r', '').split('\n')
  124. #remove some unneccesary spaces to conserve some bytes
  125. for line in lines:
  126. line = line.strip().replace(', ', ',').replace(' = ', '=')
  127. if len(line) == 0:
  128. continue
  129. d = self.exchange(line)
  130. #do some basic test of the result
  131. if ('unexpected' in d) or ('stdin' in d) or len(d) > len(SAVE_LUA)+10:
  132. log.error('error in save_lua "%s"', d)
  133. return False
  134. return True
  135. def download_file(self, filename):
  136. chunk_size = 256
  137. bytes_read = 0
  138. data = ""
  139. while True:
  140. d = self.exchange(DOWNLOAD_FILE.format(filename=filename, bytes_read=bytes_read, chunk_size=chunk_size))
  141. cmd, size, tmp_data = d.split('\n', 2)
  142. data = data + tmp_data[0:chunk_size]
  143. bytes_read = bytes_read + chunk_size
  144. if bytes_read > int(size):
  145. break
  146. data = data[0:int(size)]
  147. return data
  148. def read_file(self, filename, destination=''):
  149. if not destination:
  150. destination = filename
  151. log.info('Transfering %s to %s', filename, destination)
  152. data = self.download_file(filename)
  153. with open(destination, 'w') as f:
  154. f.write(data)
  155. def write_file(self, path, destination='', verify='none'):
  156. filename = os.path.basename(path)
  157. if not destination:
  158. destination = filename
  159. log.info('Transfering %s as %s', path, destination)
  160. self.writeln("recv()")
  161. res = self.expect('C> ')
  162. if not res.endswith('C> '):
  163. log.error('Error waiting for esp "%s"', res)
  164. return
  165. log.debug('sending destination filename "%s"', destination)
  166. self.write(destination + '\x00', True)
  167. if not self.got_ack():
  168. log.error('did not ack destination filename')
  169. return
  170. f = open(path, 'rb')
  171. content = f.read()
  172. f.close()
  173. log.debug('sending %d bytes in %s', len(content), filename)
  174. pos = 0
  175. chunk_size = 128
  176. while pos < len(content):
  177. rest = len(content) - pos
  178. if rest > chunk_size:
  179. rest = chunk_size
  180. data = content[pos:pos+rest]
  181. if not self.write_chunk(data):
  182. d = self.expect()
  183. log.error('Bad chunk response "%s" %s', d, ':'.join(x.encode('hex') for x in d))
  184. return
  185. pos += chunk_size
  186. log.debug('sending zero block')
  187. #zero size block
  188. self.write_chunk('')
  189. if verify == 'text':
  190. log.info('Verifying...')
  191. data = self.download_file(destination)
  192. if content != data:
  193. log.error('Verification failed.')
  194. elif verify == 'sha1':
  195. #Calculate SHA1 on remote file. Extract just hash from result
  196. data = self.exchange('shafile("'+destination+'")').splitlines()[1]
  197. log.info('Remote SHA1: %s', data)
  198. #Calculate hash of local data
  199. filehashhex = hashlib.sha1(content.encode()).hexdigest()
  200. log.info('Local SHA1: %s', filehashhex)
  201. if data != filehashhex:
  202. log.error('Verification failed.')
  203. def exec_file(self, path):
  204. filename = os.path.basename(path)
  205. log.info('Execute %s', filename)
  206. f = open(path, 'rt')
  207. res = '> '
  208. for line in f:
  209. line = line.rstrip('\r\n')
  210. retlines = (res + self.exchange(line)).splitlines()
  211. # Log all but the last line
  212. res = retlines.pop()
  213. for lin in retlines:
  214. log.info(lin)
  215. # last line
  216. log.info(res)
  217. f.close()
  218. def got_ack(self):
  219. log.debug('waiting for ack')
  220. res = self._port.read(1)
  221. log.debug('ack read %s', res.encode('hex'))
  222. return res == '\x06' #ACK
  223. def write_lines(self, data):
  224. lines = data.replace('\r', '').split('\n')
  225. for line in lines:
  226. self.exchange(line)
  227. return
  228. def write_chunk(self, chunk):
  229. log.debug('writing %d bytes chunk', len(chunk))
  230. data = '\x01' + chr(len(chunk)) + chunk
  231. if len(chunk) < 128:
  232. padding = 128 - len(chunk)
  233. log.debug('pad with %d characters', padding)
  234. data = data + (' ' * padding)
  235. log.debug("packet size %d", len(data))
  236. self.write(data)
  237. self._port.flush()
  238. return self.got_ack()
  239. def file_list(self):
  240. log.info('Listing files')
  241. res = self.exchange(LIST_FILES)
  242. log.info(res)
  243. return res
  244. def file_do(self, f):
  245. log.info('Executing '+f)
  246. res = self.exchange('dofile("'+f+'")')
  247. log.info(res)
  248. return res
  249. def file_format(self):
  250. log.info('Formating...')
  251. res = self.exchange('file.format()')
  252. if 'format done' not in res:
  253. log.error(res)
  254. else:
  255. log.info(res)
  256. return res
  257. def file_print(self, f):
  258. log.info('Printing ' + f)
  259. res = self.exchange(PRINT_FILE.format(filename=f))
  260. log.info(res)
  261. return res
  262. def node_heap(self):
  263. log.info('Heap')
  264. res = self.exchange('print(node.heap())')
  265. log.info(res)
  266. return res
  267. def node_restart(self):
  268. log.info('Restart')
  269. res = self.exchange('node.restart()')
  270. log.info(res)
  271. return res
  272. def file_compile(self, path):
  273. log.info('Compile '+path)
  274. cmd = 'node.compile("%s")' % path
  275. res = self.exchange(cmd)
  276. log.info(res)
  277. return res
  278. def file_remove(self, path):
  279. log.info('Remove '+path)
  280. cmd = 'file.remove("%s")' % path
  281. res = self.exchange(cmd)
  282. log.info(res)
  283. return res