NTest.lua 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  --- Default output handler: prints test progress with print().
  -- @param e        event name: 'start', 'pass', 'fail', 'except', 'finish';
  --                 anything else (e.g. 'begin', 'end') prints "<e> <test>"
  -- @param test     test name (or test run name for 'start'/'finish')
  -- @param msg      location or message text for 'pass'/'fail'/'except'
  -- @param errormsg optional detail, appended as ": <errormsg>" for 'fail'/'except'
  local function TERMINAL_HANDLER(e, test, msg, errormsg)
    local detail = errormsg and (": " .. errormsg) or ""
    if e == 'start' or e == 'finish' then
      -- "started"/"finished" are built from the event name itself
      print("######## " .. e .. "ed " .. test .. " tests")
    elseif e == 'pass' then
      print(" " .. e .. " " .. test .. ': ' .. msg)
    elseif e == 'fail' or e == 'except' then
      print(" ==> " .. e .. " " .. test .. ': ' .. msg .. detail)
    else
      print(e .. " " .. test)
    end
  end
  -- Pseudo task handling for on-host testing. When `node` already exists
  -- (running on the MCU) the shim is skipped and drain_post_queue stays a no-op.
  local drain_post_queue = function() end
  if not node then -- assume we run on host, not on MCU
    -- one FIFO queue per priority; the index is the priority value (1 = low .. 3 = high)
    local post_queue = {{}, {}, {}}
    -- Run queued tasks until all queues are empty. Each round runs the oldest
    -- task of the highest non-empty priority, so tasks posted while draining
    -- are picked up as well.
    drain_post_queue = function()
      while #post_queue[1] + #post_queue[2] + #post_queue[3] > 0 do
        for i = 3, 1, -1 do
          if #post_queue[i] > 0 then
            local f = table.remove(post_queue[i], 1)
            if f then
              f()
            end
            break
          end
        end
      end
    end
    -- luacheck: push ignore 121 122 (setting read-only global variable)
    node = {}
    node.task = {LOW_PRIORITY = 1, MEDIUM_PRIORITY = 2, HIGH_PRIORITY = 3}
    -- Mirrors node.task.post([priority], fn): the priority may be omitted and
    -- then defaults to MEDIUM_PRIORITY, as with the firmware API.
    node.task.post = function (p, f)
      if f == nil and type(p) == "function" then
        p, f = node.task.MEDIUM_PRIORITY, p
      end
      local queue = post_queue[p]
      if not queue then
        error("invalid task priority: " .. tostring(p), 2)
      end
      table.insert(queue, f)
    end
    node.setonerror = function(fn) node.Host_Error_Func = fn end -- luacheck: ignore 142
    -- luacheck: pop
  end
  --[[
  Deep equality, exposed to tests as eq(a, b).
  If equal, returns true.
  If different, returns {msg = "<reason>"};
  this table is handled specially by ok and nok.
  --]]
  local function deepeq(a, b)
    local function notEqual(m)
      return { msg = m }
    end
    -- Different types: not equal
    if type(a) ~= type(b) then return notEqual("type 1 is "..type(a)..", type 2 is "..type(b)) end
    -- Primitives and identical references (this also covers identical
    -- C functions, which string.dump cannot serialize)
    if a == b then return true end
    -- Distinct functions: compare their bytecode. string.dump raises for
    -- C functions, so those are reported as different.
    if type(a) == 'function' then
      local okA, dumpA = pcall(string.dump, a)
      local okB, dumpB = pcall(string.dump, b)
      if okA and okB and dumpA == dumpB then
        return true
      end
      return notEqual("functions differ")
    end
    -- Only tables can still be equal. tostring() keeps booleans, userdata
    -- etc. from breaking the message concatenation.
    if type(a) ~= 'table' then
      return notEqual("different "..type(a).."s expected "..tostring(a).." vs. "..tostring(b))
    end
    -- Compare tables field by field
    for k, v in pairs(a) do
      if b[k] == nil then return notEqual("key "..tostring(k).." only contained in left part") end
      local result = deepeq(v, b[k])
      if type(result) == 'table' then return result end
    end
    -- Keys present in both tables were already compared above; only look
    -- for keys that exist in b alone.
    for k in pairs(b) do
      if a[k] == nil then return notEqual("key "..tostring(k).." only contained in right part") end
    end
    return true
  end
  -- Compatibility for Lua 5.1 and Lua 5.2: stand-in for table.pack.
  -- Collects all arguments and records their exact count in field `n`,
  -- so trailing nils survive a later unpack(t, 1, t.n).
  local function args(...)
    local packed = { ... }
    packed.n = select('#', ...)
    return packed
  end
  -- `unpack` is a global in Lua 5.1 only (5.2 keeps it just under compat);
  -- newer versions provide table.unpack. Resolve it once so spy works on
  -- any host interpreter.
  local unpack = table.unpack or unpack
  --- Create a spy: a callable table that records its calls.
  -- Every call appends its arguments (as a plain {...} array, so trailing
  -- nils are not kept) to `spy.called`. If `f` is given, it is called with
  -- the same arguments under pcall: its results are returned, or, when it
  -- raises, the error is stored in `spy.errors[<call index>]` and nothing
  -- is returned.
  -- @param f optional function to wrap
  -- @return the spy table (callable through its metatable's __call)
  local function spy(f)
    local spyobj = {}
    setmetatable(spyobj, {__call = function(s, ...)
      s.called = s.called or {}
      local a = args(...)
      table.insert(s.called, {...})
      if f then
        local r = args(pcall(f, unpack(a, 1, a.n)))
        if not r[1] then
          s.errors = s.errors or {}
          s.errors[#s.called] = r[2]
        else
          return unpack(r, 2, r.n)
        end
      end
    end})
    return spyobj
  end
  -- Return "<file>:<line>" (directory stripped) of stack level 5, i.e. the
  -- test code for the call chain getstackframe <- assertok/fail <- wrap <-
  -- env.ok/nok/fail <- test. It must therefore be called at exactly that depth.
  local function getstackframe()
    local getinfo = debug.getinfo
    if not getinfo then
      -- debug.getinfo() does not exist in NodeMCU Lua 5.1: parse the 5th
      -- stack frame out of the traceback text instead
      local frame = debug.traceback():match("\t[^\t]*\t[^\t]*\t[^\t]*\t[^\t]*\t([^\t]*): in")
      return frame:match(".-([^\\/]*)$") -- cut off path of filename
    end
    local file = getinfo(5, 'S').short_src:match("([^\\/]*)$")
    return file .. ":" .. getinfo(5, 'l').currentline
  end
  --- Core of ok()/nok(): report the outcome of one assertion.
  -- @param handler output handler receiving 'pass'/'fail' events
  -- @param name    name of the running test
  -- @param invert  true for nok (expects a falsy condition)
  -- @param cond    value to check; a table with a `msg` field (what eq
  --                returns on mismatch) counts as false and supplies the reason
  -- @param msg     optional description; defaults to the caller's file:line
  -- After reporting a failure it raises '_*_TestAbort_*_' to stop the test.
  local function assertok(handler, name, invert, cond, msg)
    local passed, reason = cond, nil
    if type(cond) == 'table' and cond.msg then
      passed, reason = false, cond.msg
    end
    msg = msg or getstackframe()
    if invert then
      passed = not passed
    end
    if not passed then
      handler('fail', name, msg, reason)
      error('_*_TestAbort_*_')
    end
    handler('pass', name, msg)
  end
  -- Does error text `err` contain `expected`? Plain substring search first,
  -- so characters like "(" or "." are taken literally; then a Lua pattern
  -- search for callers that pass patterns. A malformed pattern counts as
  -- "no match" instead of raising.
  local function errorContains(err, expected)
    if string.find(err, expected, 1, true) then
      return true
    end
    local ok, pos = pcall(string.find, err, expected)
    return ok and pos ~= nil
  end
  --- Core of fail(): check that `func` raises an error.
  -- @param handler  output handler receiving 'pass'/'fail' events
  -- @param name     name of the running test
  -- @param func     function expected to raise
  -- @param expected optional text (or Lua pattern) the error message must contain
  -- @param msg      optional description; defaults to the caller's file:line
  -- After reporting a failure it raises '_*_TestAbort_*_' to stop the test.
  local function fail(handler, name, func, expected, msg)
    local status, err = pcall(func)
    if not msg then
      msg = getstackframe()
    end
    if status then
      local messageParts = {"Expected to fail with Error"}
      if expected then
        messageParts[2] = " containing \"" .. expected .. "\""
      end
      handler('fail', name, msg, table.concat(messageParts, ""))
      error('_*_TestAbort_*_')
    end
    if expected then
      err = tostring(err) -- error values are not necessarily strings
      if not errorContains(err, expected) then
        err = err:match(".-([^\\/]*)$") -- cut off path of filename
        handler('fail', name, msg, "expected errormessage \"" .. err .. "\" to contain \"" .. expected .. "\"")
        error('_*_TestAbort_*_')
      end
    end
    handler('pass', name, msg)
  end
  -- Shared defaults for every test run object: tests run against the global
  -- environment and report through TERMINAL_HANDLER unless a run overrides
  -- `env` or `outputhandler`.
  local nmt = {}
  nmt.env = _G
  nmt.outputhandler = TERMINAL_HANDLER
  nmt.__index = nmt
  --- Create a test run.
  -- @param testrunname name reported in the 'start' and 'finish' events
  -- @return N, a run object offering test/testasync/testco. Its `env` and
  --         `outputhandler` fall back to the nmt defaults via the metatable.
  -- Tests run one after another: each is queued in `pendingtests` and started
  -- from a LOW_PRIORITY node task after its predecessor has finished.
  return function(testrunname)
    local pendingtests = {}
    local started
    local N = setmetatable({}, nmt)

    -- Start the first queued test through the task queue, or report
    -- 'finish' when none is left.
    local function runpending()
      if pendingtests[1] ~= nil then
        node.task.post(node.task.LOW_PRIORITY, function()
          pendingtests[1](runpending)
        end)
      else
        N.outputhandler('finish', testrunname)
      end
    end

    -- Copy the assertion helpers that a test installs into its environment.
    local function copyenv(dest, src)
      dest.eq = src.eq
      dest.spy = src.spy
      dest.ok = src.ok
      dest.nok = src.nok
      dest.fail = src.fail
    end

    -- Error value -> report text: tostring() because error values need not
    -- be strings, then cut off the path of the filename.
    local function cleanerror(err)
      return (tostring(err):match(".-([^\\/]*)$"))
    end

    -- True for the marker raised by ok/nok/fail after they have already
    -- reported the failure. Plain find: the marker contains pattern magic chars.
    local function isabort(err)
      return err:find('_*_TestAbort_*_', 1, true) ~= nil
    end

    local function testimpl(name, f, async)
      local testfn = function(next)
        local prev = {}
        copyenv(prev, N.env)
        local handler = N.outputhandler
        local finished = false

        -- Report an unexpected error as 'except' (aborts were already reported).
        local function report(err)
          err = cleanerror(err)
          if not isabort(err) then
            handler('except', name, err)
          end
        end

        -- End this test: report `err` if given, restore the environment and
        -- continue with the next test. Only the first call does the cleanup.
        -- An async test can get here twice (e.g. it calls next() and then
        -- raises), and a second table.remove would drop the next queued test.
        local restore = function(err)
          if err then
            report(err)
          end
          if finished then return end
          finished = true
          if node then node.setonerror() end
          copyenv(N.env, prev)
          handler('end', name)
          table.remove(pendingtests, 1)
          collectgarbage()
          if next then next() end
        end

        -- Must not be a tail call: getstackframe relies on this frame.
        local function wrap(method, ...)
          method(handler, name, ...)
        end

        -- Installed via node.setonerror for errors raised outside the pcall below.
        local function cbError(err)
          report(err)
          restore()
        end

        local env = N.env
        env.eq = deepeq
        env.spy = spy
        env.ok = function (cond, msg) wrap(assertok, false, cond, msg) end
        env.nok = function(cond, msg) wrap(assertok, true, cond, msg) end
        env.fail = function (func, expected, msg) wrap(fail, func, expected, msg) end
        handler('begin', name)
        node.setonerror(cbError)
        -- async tests receive `restore` and must call it when done
        local ok, err = pcall(f, async and restore)
        if not ok then
          report(err)
          restore()
        elseif not async then
          restore()
        end
      end
      if not started then
        N.outputhandler('start', testrunname)
        started = true
      end
      table.insert(pendingtests, testfn)
      if #pendingtests == 1 then
        runpending()
        drain_post_queue()
      end
    end

    function N.test(name, f)
      testimpl(name, f)
    end

    function N.testasync(name, f)
      testimpl(name, f, true)
    end

    -- name of the testco test whose coroutine is currently active
    local currentCoName
    --- Run `func(getCB, waitCb)` inside a coroutine as an async test.
    -- getCB(cbName) returns a callback that resumes the coroutine with
    -- (cbName, ...); waitCb() yields until such a callback fires. The test
    -- ends when the coroutine finishes or raises.
    function N.testco(name, func)
      local co
      N.testasync(name, function(Next)
        currentCoName = name
        local function getCB(cbName)
          return function(...) -- upval: co, cbName
            local result, err = coroutine.resume(co, cbName, ...)
            if (not result) then
              if (name == currentCoName) then
                currentCoName = nil
                Next(err)
              else
                N.outputhandler('fail', name, "Found stray Callback '"..tostring(cbName).."' from test '"..name.."'")
              end
            elseif coroutine.status(co) == "dead" then
              currentCoName = nil
              Next()
            end
          end
        end
        local function waitCb()
          return coroutine.yield()
        end
        co = coroutine.create(function(wr, wa)
          func(wr, wa)
        end)
        local result, err = coroutine.resume(co, getCB, waitCb)
        if (not result) then
          currentCoName = nil
          Next(err)
        elseif coroutine.status(co) == "dead" then
          currentCoName = nil
          Next()
        end
      end)
    end

    return N
  end