SkVM.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. /*
  2. * Copyright 2019 Google LLC
  3. *
  4. * Use of this source code is governed by a BSD-style license that can be
  5. * found in the LICENSE file.
  6. */
  7. #ifndef SkVM_DEFINED
  8. #define SkVM_DEFINED
  #include "include/core/SkTypes.h"
  #include "include/private/SkTHash.h"
  #include <functional>
  #include <vector>
  12. namespace skvm {
  // A small machine-code emitter with two instruction sets: x86-64 (VEX-encoded
  // ymm/xmm instructions) and aarch64 (NEON v-register and x-register instructions).
  // Each public method emits one instruction (or a short fixed sequence) into the
  // buffer handed to the constructor.
  class Assembler {
  public:
      // NOTE(review): buffer ownership and whether buf may be null are decided in
      // the .cpp; confirm there before relying on either.
      explicit Assembler(void* buf);

      // Presumably the number of bytes emitted so far — confirm in the .cpp.
      size_t size() const;

      // Order matters... GP64, Xmm, Ymm values match 4-bit register encoding for each.
      enum GP64 {
          rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi,
          r8 , r9 , r10, r11, r12, r13, r14, r15,
      };
      enum Xmm {
          xmm0, xmm1, xmm2 , xmm3 , xmm4 , xmm5 , xmm6 , xmm7 ,
          xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15,
      };
      enum Ymm {
          ymm0, ymm1, ymm2 , ymm3 , ymm4 , ymm5 , ymm6 , ymm7 ,
          ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15,
      };

      // X and V values match 5-bit encoding for each (nothing tricky).
      enum X {
          x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ,
          x8 , x9 , x10, x11, x12, x13, x14, x15,
          x16, x17, x18, x19, x20, x21, x22, x23,
          x24, x25, x26, x27, x28, x29, x30, xzr,
      };
      enum V {
          v0 , v1 , v2 , v3 , v4 , v5 , v6 , v7 ,
          v8 , v9 , v10, v11, v12, v13, v14, v15,
          v16, v17, v18, v19, v20, v21, v22, v23,
          v24, v25, v26, v27, v28, v29, v30, v31,
      };

      // Raw emission: an arbitrary byte run, one byte, one 32-bit word.
      void bytes(const void*, int);
      void byte(uint8_t);
      void word(uint32_t);

      // ---------------------------------------------------------------- x86-64

      void align(int mod);
      void vzeroupper();
      void ret();

      void add(GP64, int imm);
      void sub(GP64, int imm);

      // All dst = x op y.
      using DstEqXOpY = void(Ymm dst, Ymm x, Ymm y);
      DstEqXOpY vpand, vpor, vpxor, vpandn,
                vpaddd, vpsubd, vpmulld,
                vpsubw, vpmullw,
                vaddps, vsubps, vmulps, vdivps,
                vfmadd132ps, vfmadd213ps, vfmadd231ps,
                vpackusdw, vpackuswb;

      // All dst = x op imm.
      using DstEqXOpImm = void(Ymm dst, Ymm x, int imm);
      DstEqXOpImm vpslld, vpsrld, vpsrad,
                  vpsrlw,
                  vpermq;

      // All dst = op x.
      using DstEqOpX = void(Ymm dst, Ymm x);
      DstEqOpX vcvtdq2ps, vcvttps2dq;

      // A branch / PC-relative load target.  kind names the displacement encoding
      // (matching the private disp19()/disp32() helpers); presumably offset is the
      // label's code position and references lists code offsets to patch — confirm
      // in label()/disp*().
      struct Label {
          int offset = 0;
          enum { None, ARMDisp19, X86Disp32 } kind = None;
          std::vector<int> references;
      };

      Label here();
      void label(Label*);

      void jmp(Label*);
      void je (Label*);
      void jne(Label*);
      void jl (Label*);
      void cmp(GP64, int imm);

      void vbroadcastss(Ymm dst, Label*);
      void vpshufb(Ymm dst, Ymm x, Label*);

      void vmovups  (Ymm dst, GP64 ptr);   // dst = *ptr, 256-bit
      void vpmovzxbd(Ymm dst, GP64 ptr);   // dst = *ptr, 64-bit, each uint8_t expanded to int
      void vmovd    (Xmm dst, GP64 ptr);   // dst = *ptr, 32-bit

      void vmovups(GP64 ptr, Ymm src);     // *ptr = src, 256-bit
      void vmovq  (GP64 ptr, Xmm src);     // *ptr = src, 64-bit
      void vmovd  (GP64 ptr, Xmm src);     // *ptr = src, 32-bit

      void movzbl(GP64 dst, GP64 ptr);     // dst = *ptr, 8-bit, uint8_t expanded to int
      void movb  (GP64 ptr, GP64 src);     // *ptr = src, 8-bit

      void vmovd_direct(GP64 dst, Xmm src);  // dst = src, 32-bit
      void vmovd_direct(Xmm dst, GP64 src);  // dst = src, 32-bit

      void vpinsrb(Xmm dst, Xmm src, GP64 ptr, int imm);  // dst = src; dst[imm] = *ptr, 8-bit
      void vpextrb(GP64 ptr, Xmm src, int imm);           // *ptr = src[imm], 8-bit

      // --------------------------------------------------------------- aarch64

      // d = op(n,m)
      using DOpNM = void(V d, V n, V m);
      DOpNM and16b, orr16b, eor16b, bic16b,
            add4s, sub4s, mul4s,
            sub8h, mul8h,
            fadd4s, fsub4s, fmul4s, fdiv4s,
            tbl;

      // d += n*m
      void fmla4s(V d, V n, V m);

      // d = op(n,imm)
      using DOpNImm = void(V d, V n, int imm);
      DOpNImm sli4s,
              shl4s, sshr4s, ushr4s,
              ushr8h;

      // d = op(n)
      using DOpN = void(V d, V n);
      DOpN scvtf4s,    // int -> float
           fcvtzs4s,   // truncate float -> int
           xtns2h,     // u32 -> u16
           xtnh2b,     // u16 -> u8
           uxtlb2h,    // u8 -> u16
           uxtlh2s;    // u16 -> u32

      // TODO: both these platforms support rounding float->int (vcvtps2dq, fcvtns.4s)... use?

      void ret (X);
      void add (X d, X n, int imm12);
      void sub (X d, X n, int imm12);
      void subs(X d, X n, int imm12);  // subtract setting condition flags

      // There's another encoding for unconditional branches that can jump further,
      // but this one encoded as b.al is simple to implement and should be fine.
      void b  (Label* l) { this->b(Condition::al, l); }
      void bne(Label* l) { this->b(Condition::ne, l); }
      void blt(Label* l) { this->b(Condition::lt, l); }

      // "cmp ..." is just an assembler mnemonic for "subs xzr, ..."!
      void cmp(X n, int imm12) { this->subs(xzr, n, imm12); }

      // Compare and branch if zero/non-zero, as if
      //     cmp(t,0)
      //     beq/bne(l)
      // but without setting condition flags.
      void cbz (X t, Label* l);
      void cbnz(X t, Label* l);

      void ldrq(V dst, Label*);  // 128-bit PC-relative load
      void ldrq(V dst, X src);   // 128-bit dst = *src
      void ldrs(V dst, X src);   // 32-bit  dst = *src
      void ldrb(V dst, X src);   // 8-bit   dst = *src

      void strq(V src, X dst);   // 128-bit *dst = src
      void strs(V src, X dst);   // 32-bit  *dst = src
      void strb(V src, X dst);   // 8-bit   *dst = src

  private:
      // dst = op(dst, imm)
      void op(int opcode, int opcode_ext, GP64 dst, int imm);

      // dst = op(x,y) or op(x)
      void op(int prefix, int map, int opcode, Ymm dst, Ymm x, Ymm y, bool W=false);
      void op(int prefix, int map, int opcode, Ymm dst, Ymm x, bool W=false) {
          // Two arguments ops seem to pass them in dst and y, forcing x to 0 so VEX.vvvv == 1111.
          this->op(prefix, map, opcode, dst, static_cast<Ymm>(0), x, W);
      }

      // dst = op(x,imm)
      void op(int prefix, int map, int opcode, int opcode_ext, Ymm dst, Ymm x, int imm);

      // dst = op(x,label) or op(label)
      void op(int prefix, int map, int opcode, Ymm dst, Ymm x, Label* l);
      void op(int prefix, int map, int opcode, Ymm dst, Label* l) {
          this->op(prefix, map, opcode, dst, static_cast<Ymm>(0), l);
      }

      // *ptr = ymm or ymm = *ptr, depending on opcode.
      void load_store(int prefix, int map, int opcode, Ymm ymm, GP64 ptr);

      // Opcode for 3-arguments ops is split between hi and lo:
      //    [11 bits hi] [5 bits m] [6 bits lo] [5 bits n] [5 bits d]
      void op(uint32_t hi, V m, uint32_t lo, V n, V d);

      // 2-argument ops, with or without an immediate.
      void op(uint32_t op22, int imm, V n, V d);
      void op(uint32_t op22, V n, V d) { this->op(op22, 0, n, d); }
      void op(uint32_t op22, X x, V v) { this->op(op22, 0, static_cast<V>(x), v); }

      // Order matters... value is 4-bit encoding for condition code.
      enum class Condition { eq,ne,cs,cc,mi,pl,vs,vc,hi,ls,ge,lt,gt,le,al };
      void b(Condition, Label*);

      void jump(uint8_t condition, Label*);

      int disp19(Label*);
      int disp32(Label*);

      // Default-initialized so the object never holds indeterminate pointers; the
      // constructor's own initializers (in the .cpp) still take precedence.
      uint8_t* fCode = nullptr;
      uint8_t* fCurr = nullptr;
      size_t   fSize = 0;
  };
  // Every operation a Builder can record.  Enumerators take consecutive values in
  // declaration order.  NOTE(review): check whether anything persists or switches
  // on these numeric values before reordering.
  enum class Op : uint8_t {
      // memory
      store8, store32,
      load8,  load32,

      // constants
      splat,

      // 32-bit float arithmetic
      add_f32, sub_f32, mul_f32, div_f32, mad_f32,

      // 32-bit integer arithmetic
      add_i32, sub_i32, mul_i32,

      // two 16-bit lanes per 32-bit value
      sub_i16x2, mul_i16x2, shr_i16x2,

      // bitwise logic
      bit_and, bit_or, bit_xor, bit_clear,

      // shifts
      shl, shr, sra,

      // bit-field helpers
      extract,
      pack,

      // byte shuffle
      bytes,

      // int <-> float conversion
      to_f32, to_i32,
  };
  // Builder values are referred to by integer ID.
  using Val = int;

  // The last Val ID is reserved as a sentinel meaning none, n/a, null, nil, etc.
  // ~0 sets every bit, which for the signed Val is -1.
  static constexpr Val NA = ~0;

  // Thin typed handles: Arg carries an argument index; I32 and F32 each wrap the
  // Val of a 32-bit integer or float value respectively.
  struct Arg { int ix; };
  struct I32 { Val id; };
  struct F32 { Val id; };

  class Program;
  196. class Builder {
  197. public:
  198. struct Instruction {
  199. Op op; // v* = op(x,y,z,imm), where * == index of this Instruction.
  200. Val x,y,z; // Enough arguments for mad().
  201. int imm; // Immediate bit pattern, shift count, argument index, etc.
  202. // Not populated until done() has been called.
  203. int death; // Index of last live instruction taking this input; live if != 0.
  204. bool hoist; // Value independent of all loop variables?
  205. };
  206. Program done(const char* debug_name = nullptr);
  207. // Declare a varying argument with given stride.
  208. Arg arg(int stride);
  209. // Convenience arg() wrapper for most common stride, sizeof(T).
  210. template <typename T>
  211. Arg arg() { return this->arg(sizeof(T)); }
  212. void store8 (Arg ptr, I32 val);
  213. void store32(Arg ptr, I32 val);
  214. I32 load8 (Arg ptr);
  215. I32 load32(Arg ptr);
  216. I32 splat(int n);
  217. I32 splat(unsigned u) { return this->splat((int)u); }
  218. F32 splat(float f);
  219. F32 add(F32 x, F32 y);
  220. F32 sub(F32 x, F32 y);
  221. F32 mul(F32 x, F32 y);
  222. F32 div(F32 x, F32 y);
  223. F32 mad(F32 x, F32 y, F32 z);
  224. I32 add(I32 x, I32 y);
  225. I32 sub(I32 x, I32 y);
  226. I32 mul(I32 x, I32 y);
  227. I32 sub_16x2(I32 x, I32 y);
  228. I32 mul_16x2(I32 x, I32 y);
  229. I32 shr_16x2(I32 x, int bits);
  230. I32 bit_and (I32 x, I32 y);
  231. I32 bit_or (I32 x, I32 y);
  232. I32 bit_xor (I32 x, I32 y);
  233. I32 bit_clear(I32 x, I32 y); // x & ~y
  234. I32 shl(I32 x, int bits);
  235. I32 shr(I32 x, int bits);
  236. I32 sra(I32 x, int bits);
  237. I32 extract(I32 x, int bits, I32 y); // (x >> bits) & y
  238. I32 pack (I32 x, I32 y, int bits); // x | (y << bits), assuming (x & (y << bits)) == 0
  239. // Shuffle the bytes in x according to each nibble of control, as if
  240. //
  241. // uint8_t bytes[] = {
  242. // 0,
  243. // ((uint32_t)x ) & 0xff,
  244. // ((uint32_t)x >> 8) & 0xff,
  245. // ((uint32_t)x >> 16) & 0xff,
  246. // ((uint32_t)x >> 24) & 0xff,
  247. // };
  248. // return (uint32_t)bytes[(control >> 0) & 0xf] << 0
  249. // | (uint32_t)bytes[(control >> 4) & 0xf] << 8
  250. // | (uint32_t)bytes[(control >> 8) & 0xf] << 16
  251. // | (uint32_t)bytes[(control >> 12) & 0xf] << 24;
  252. //
  253. // So, e.g.,
  254. // - bytes(x, 0x1111) splats the low byte of x to all four bytes
  255. // - bytes(x, 0x4321) is x, an identity
  256. // - bytes(x, 0x0000) is 0
  257. // - bytes(x, 0x0404) transforms an RGBA pixel into an A0A0 bit pattern.
  258. //
  259. I32 bytes(I32 x, int control);
  260. F32 to_f32(I32 x);
  261. I32 to_i32(F32 x);
  262. std::vector<Instruction> program() const { return fProgram; }
  263. private:
  264. struct InstructionHash {
  265. template <typename T>
  266. static size_t Hash(T val) {
  267. return std::hash<T>{}(val);
  268. }
  269. size_t operator()(const Instruction& inst) const {
  270. return Hash((uint8_t)inst.op)
  271. ^ Hash(inst.x)
  272. ^ Hash(inst.y)
  273. ^ Hash(inst.z)
  274. ^ Hash(inst.imm)
  275. ^ Hash(inst.death)
  276. ^ Hash(inst.hoist);
  277. }
  278. };
  279. Val push(Op, Val x, Val y=NA, Val z=NA, int imm=0);
  280. bool isZero(Val) const;
  281. SkTHashMap<Instruction, Val, InstructionHash> fIndex;
  282. std::vector<Instruction> fProgram;
  283. std::vector<int> fStrides;
  284. };
  285. using Reg = int;
  286. class Program {
  287. public:
  288. struct Instruction { // d = op(x, y, z/imm)
  289. Op op;
  290. Reg d,x,y;
  291. union { Reg z; int imm; };
  292. };
  293. Program(const std::vector<Builder::Instruction>& instructions,
  294. const std::vector<int> & strides,
  295. const char* debug_name);
  296. Program() : Program({}, {}, nullptr) {}
  297. ~Program();
  298. Program(Program&&);
  299. Program& operator=(Program&&);
  300. Program(const Program&) = delete;
  301. Program& operator=(const Program&) = delete;
  302. template <typename... T>
  303. void eval(int n, T*... arg) const {
  304. void* args[] = { (void*)arg..., nullptr };
  305. this->eval(n, args);
  306. }
  307. std::vector<Instruction> instructions() const { return fInstructions; }
  308. int nregs() const { return fRegs; }
  309. int loop() const { return fLoop; }
  310. // If this Program has been JITted, drop it, forcing interpreter fallback.
  311. void dropJIT();
  312. private:
  313. void eval(int n, void* args[]) const;
  314. void setupInterpreter(const std::vector<Builder::Instruction>&);
  315. void setupJIT (const std::vector<Builder::Instruction>&, const char* debug_name);
  316. bool jit(const std::vector<Builder::Instruction>&,
  317. bool hoist,
  318. Assembler*) const;
  319. // Dump jit-*.dump files for perf inject.
  320. void dumpJIT(const char* debug_name, size_t size) const;
  321. std::vector<Instruction> fInstructions;
  322. int fRegs;
  323. int fLoop;
  324. std::vector<int> fStrides;
  325. void* fJITBuf = nullptr;
  326. size_t fJITSize = 0;
  327. };
  328. // TODO: comparison operations, if_then_else
  329. // TODO: learn how to do control flow
  330. // TODO: gather, load_uniform
  331. // TODO: 16- and 64-bit loads and stores
  332. // TODO: 16- and 64-bit values?
  333. // TODO: x86-64 SSE2 / SSE4.1 / AVX2 / AVX-512F JIT
  334. // TODO: ARMv8 JIT
  335. // TODO: ARMv8.2+FP16 JIT
  336. // TODO: ARMv7 NEON JIT?
  337. // TODO: lower to LLVM?
  338. // TODO: lower to WebASM?
  339. }
  340. #endif//SkVM_DEFINED