Browse Source

Merge pull request #10 from lrozema/master

Speed and stability improvements
Peter Magnusson 9 years ago
parent
commit
5c4dd0ed36
1 changed files with 103 additions and 75 deletions
  1. 103 75
      nodemcu-uploader.py

+ 103 - 75
nodemcu-uploader.py

@@ -57,7 +57,41 @@ CHUNK_REPLY = '\v'
 class Uploader:
     BAUD = 9600
     PORT = '/dev/ttyUSB0'
-    TIMEOUT = 1
+    TIMEOUT = 5
+
+    def expect(self, exp='> ', timeout=TIMEOUT):
+        t = self._port.timeout
+
+        # Checking for new data every 100us is fast enough
+        lt = 0.0001
+        if self._port.timeout != lt:
+            self._port.timeout = lt
+
+        end = time.time() + timeout
+
+        # Finish as soon as either exp matches or we run out of time (work like dump, but faster on success)
+        data = ''
+        while not data.endswith(exp) and time.time() <= end:
+            data += self._port.read()
+
+        self._port.timeout = t
+        log.debug('expect return: %s', data)
+        return data
+
+    def write(self, output, binary=False):
+        if not binary:
+            log.debug('write: %s', output)
+        else:
+            log.debug('write binary: %s' % ':'.join(x.encode('hex') for x in output))
+        self._port.write(output)
+        self._port.flush()
+
+    def writeln(self, output):
+        self.write(output + '\n')
+
+    def exchange(self, output):
+        self.writeln(output)
+        return self.expect()
 
     def __init__(self, port = 0, baud = BAUD):
         self._port = serial.Serial(port, Uploader.BAUD, timeout=Uploader.TIMEOUT)
@@ -67,50 +101,46 @@ class Uploader:
         ## DTR = GPIO0
         self._port.setRTS(False)
         self._port.setDTR(False)
-        time.sleep(0.5)
-        self.dump()
+
+        # Get in sync with LUA (this assumes that NodeMCU gets reset by the previous two lines)
+        self.expect('NodeMCU ')
+        self.expect()
+        self.exchange('')
 
         if baud != Uploader.BAUD:
             log.info('Changing communication to %s baud', baud)
-            self._port.write('uart.setup(0,%s,8,0,1,1)\r\n' % baud)
-            log.info(self.dump())
-            self._port.close()
-            self._port = serial.Serial(port, baud, timeout=Uploader.TIMEOUT)
+            self.writeln('uart.setup(0,%s,8,0,1,1)' % baud)
+
+            # Wait for the string to be sent before switching baud
+            time.sleep(0.1)
+            self._port.setBaudrate(baud)
+
+            # Get in sync again
+            self.exchange('')
+            self.exchange('')
 
         self.line_number = 0
 
     def close(self):
-        self._port.write('uart.setup(0,%s,8,0,1,1)\r\n' % Uploader.BAUD)
+        self.writeln('uart.setup(0,%s,8,0,1,1)' % Uploader.BAUD)
         self._port.close()
 
-    def dump(self, timeout=TIMEOUT):
-        t = self._port.timeout
-        if self._port.timeout != timeout:
-            self._port.timeout = timeout
-        n = self._port.read()
-        data = ''
-        while n != '':
-            data += n
-            n = self._port.read()
-
-        self._port.timeout = t
-        return data
-
-
     def prepare(self):
         log.info('Preparing esp for transfer.')
-        self.write_lines(save_lua.replace('9600', '%d' % self._port.baudrate))
-        self._port.write('\r\n')
 
-        d = self.dump(0.1)
-        if 'unexpected' in d or len(d) > len(save_lua)+10:
-            log.error('error in save_lua "%s"' % d)
-            return
+        data = save_lua.replace('9600', '%d' % self._port.baudrate)
+        lines = data.replace('\r', '').split('\n')
+
+        for line in lines:
+            d = self.exchange(line)
+
+            if 'unexpected' in d or len(d) > len(save_lua)+10:
+                log.error('error in save_lua "%s"' % d)
+                return
 
     def download_file(self, filename):
-        self.dump()
-        self._port.write(r"file.open('" + filename + r"') print(file.seek('end', 0)) file.seek('set', 0) uart.write(0, file.read()) file.close()" + '\n')
-        cmd, size, data = self.dump().split('\n', 2)
+        d = self.exchange(r"file.open('" + filename + r"') print(file.seek('end', 0)) file.seek('set', 0) uart.write(0, file.read()) file.close()")
+        cmd, size, data = d.split('\n', 2)
         data = data[0:int(size)]
         return data
 
@@ -120,28 +150,23 @@ class Uploader:
         log.info('Transfering %s to %s' %(filename, destination))
         data = self.download_file(filename)
         with open(destination, 'w') as f:
-          f.write(data)
+            f.write(data)
 
     def write_file(self, path, destination = '', verify = False):
         filename = os.path.basename(path)
         if not destination:
             destination = filename
         log.info('Transfering %s as %s' %(filename, destination))
-        self.dump()
-        self._port.write(r"recv()" + '\n')
-
-        count = 0
-        while not 'C' in self.dump(0.2):
-            time.sleep(1)
-            count += 1
-            if count > 5:
-                log.error('Error waiting for esp "%s"' % self.dump())
-                return
-        self.dump(0.5)
+        self.writeln("recv()")
+
+        r = self.expect('C> ')
+        if not r.endswith('C> '):
+            log.error('Error waiting for esp "%s"' % r)
+            return
         log.debug('sending destination filename "%s"', destination)
-        self._port.write(destination + '\x00')
+        self.write(destination + '\x00', True)
         if not self.got_ack():
-            log.error('did not ack destination filename: "%s"' % self.dump())
+            log.error('did not ack destination filename')
             return
 
         f = open( path, 'rt' ); content = f.read(); f.close()
@@ -156,17 +181,15 @@ class Uploader:
 
             data = content[pos:pos+rest]
             if not self.write_chunk(data):
-                error = True
-                d = self.dump()
+                d = self.expect()
                 log.error('Bad chunk response "%s" %s' % (d, ':'.join(x.encode('hex') for x in d)))
-                break
+                return
 
             pos += chunk_size
 
         log.debug('sending zero block')
-        if not error:
-            #zero size block
-            self.write_chunk('')
+        #zero size block
+        self.write_chunk('')
 
         if verify:
             log.info('Verifying...')
@@ -177,6 +200,7 @@ class Uploader:
     def got_ack(self):
         log.debug('waiting for ack')
         r = self._port.read(1)
+        log.debug('ack read %s', r.encode('hex'))
         return r == '\x06' #ACK
 
 
@@ -184,9 +208,8 @@ class Uploader:
         lines = data.replace('\r', '').split('\n')
 
         for line in lines:
-            self._port.write(line + '\r\n')
-            d = self.dump(0.1)
-            log.debug(d)
+            self.exchange(line)
+
         return
 
 
@@ -198,55 +221,55 @@ class Uploader:
             log.debug('pad with %d characters' % padding)
             data = data + (' ' * padding)
         log.debug("packet size %d" % len(data))
-        self._port.write(data)
+        self.write(data)
 
         return self.got_ack()
 
 
     def file_list(self):
         log.info('Listing files')
-        self._port.write('for key,value in pairs(file.list()) do print(key,value) end' + '\r\n')
-        r = self.dump()
+        r = self.exchange('for key,value in pairs(file.list()) do print(key,value) end')
+        log.info(r)
+        return r
+
+    def file_do(self, f):
+        log.info('Executing '+f)
+        r = self.exchange('dofile("'+f+'")')
         log.info(r)
         return r
 
     def file_format(self):
         log.info('Formating...')
-        self._port.write('file.format()' + '\r\n')
-        r = self.dump()
-        while(r == '') or not ('format done' in r):
-            r = self.dump()
-            if r != '':
-                log.info(r)
+        r = self.exchange('file.format()')
+        if 'format done' not in r:
+            log.error(r)
+        else:
+            log.info(r)
         return r
 
     def node_heap(self):
         log.info('Heap')
-        self._port.write('print(node.heap())\r\n')
-        r = self.dump()
+        r = self.exchange('print(node.heap())')
         log.info(r)
         return r
 
     def node_restart(self):
         log.info('Restart')
-        self._port.write('node.restart()' +'\r\n')
-        r = self.dump()
+        r = self.exchange('node.restart()')
         log.info(r)
         return r
     
     def file_compile(self, path):
         log.info('Compile '+path)
         cmd = 'node.compile("%s")' % path
-        self._port.write(cmd + '\r\n')
-        r = self.dump()
+        r = self.exchange(cmd)
         log.info(r)
         return r
     
     def file_remove(self, path):
         log.info('Remove '+path)
         cmd = 'file.remove("%s")' % path
-        self._port.write(cmd + '\r\n')
-        r = self.dump()
+        r = self.exchange(cmd)
         log.info(r)
         return r
 
@@ -336,7 +359,8 @@ if __name__ == '__main__':
         'file',
         help = 'File functions')
 
-    file_parser.add_argument('cmd', choices=('list', 'format'))
+    file_parser.add_argument('cmd', choices=('list', 'do', 'format'))
+    file_parser.add_argument('filename', nargs='*', help = 'Lua file to run.')
 
     node_parse = subparsers.add_parser(
         'node', 
@@ -350,10 +374,11 @@ if __name__ == '__main__':
     formatter = logging.Formatter('%(message)s')
     logging.basicConfig(level=logging.INFO, format='%(message)s')
 
-    uploader = Uploader(args.port, args.baud)
     if args.verbose:
         log.setLevel(logging.DEBUG)
 
+    uploader = Uploader(args.port, args.baud)
+
     if args.operation == 'upload' or args.operation == 'download':
         sources = args.filename
         destinations = []
@@ -378,7 +403,7 @@ if __name__ == '__main__':
 
             if args.restart:
                 uploader.node_restart()
-            print 'All done!'
+            log.info('All done!')
 
         if args.operation == 'download':
             if len(destinations) == len(sources):
@@ -386,11 +411,14 @@ if __name__ == '__main__':
                     uploader.read_file(f, d)
             else:
                 raise Exception('You must specify a destination filename for each file you want to download.')
-            print 'All done!'
+            log.info('All done!')
 
     elif args.operation == 'file':
         if args.cmd == 'list':
             uploader.file_list()
+        if args.cmd == 'do':
+            for f in args.filename:
+                uploader.file_do(f)
         elif args.cmd == 'format':
             uploader.file_format()