SkSLJIT.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. /*
  2. * Copyright 2018 Google Inc.
  3. *
  4. * Use of this source code is governed by a BSD-style license that can be
  5. * found in the LICENSE file.
  6. */
  7. #ifndef SKSL_JIT
  8. #define SKSL_JIT
  9. #ifdef SK_LLVM_AVAILABLE
  10. #include "src/sksl/ir/SkSLBinaryExpression.h"
  11. #include "src/sksl/ir/SkSLBreakStatement.h"
  12. #include "src/sksl/ir/SkSLContinueStatement.h"
  13. #include "src/sksl/ir/SkSLDoStatement.h"
  14. #include "src/sksl/ir/SkSLExpression.h"
  15. #include "src/sksl/ir/SkSLForStatement.h"
  16. #include "src/sksl/ir/SkSLFunctionCall.h"
  17. #include "src/sksl/ir/SkSLFunctionDefinition.h"
  18. #include "src/sksl/ir/SkSLIfStatement.h"
  19. #include "src/sksl/ir/SkSLIndexExpression.h"
  20. #include "src/sksl/ir/SkSLPostfixExpression.h"
  21. #include "src/sksl/ir/SkSLPrefixExpression.h"
  22. #include "src/sksl/ir/SkSLProgram.h"
  23. #include "src/sksl/ir/SkSLReturnStatement.h"
  24. #include "src/sksl/ir/SkSLStatement.h"
  25. #include "src/sksl/ir/SkSLSwizzle.h"
  26. #include "src/sksl/ir/SkSLTernaryExpression.h"
  27. #include "src/sksl/ir/SkSLVarDeclarationsStatement.h"
  28. #include "src/sksl/ir/SkSLVariableReference.h"
  29. #include "src/sksl/ir/SkSLWhileStatement.h"
  30. #include "llvm-c/Analysis.h"
  31. #include "llvm-c/Core.h"
  32. #include "llvm-c/OrcBindings.h"
  33. #include "llvm-c/Support.h"
  34. #include "llvm-c/Target.h"
  35. #include "llvm-c/Transforms/PassManagerBuilder.h"
  36. #include "llvm-c/Types.h"
  37. #include <stack>
  38. class SkRasterPipeline;
  39. namespace SkSL {
  40. struct AppendStage;
  41. /**
  42. * A just-in-time compiler for SkSL code which uses an LLVM backend. Only available when the
  43. * skia_llvm_path gn arg is set.
  44. *
  45. * Example of using SkSLJIT to set up an SkJumper pipeline stage:
  46. *
  47. * #ifdef SK_LLVM_AVAILABLE
  48. * SkSL::Compiler compiler;
  49. * SkSL::Program::Settings settings;
  50. * std::unique_ptr<SkSL::Program> program = compiler.convertProgram(
  51. SkSL::Program::kPipelineStage_Kind,
  52. * "void swap(int x, int y, inout float4 color) {"
  53. * " color.rb = color.br;"
  54. * "}",
  55. * settings);
  56. * if (!program) {
  57. * printf("%s\n", compiler.errorText().c_str());
  58. * abort();
  59. * }
  60. * SkSL::JIT& jit = *scratch->make<SkSL::JIT>(&compiler);
  61. * std::unique_ptr<SkSL::JIT::Module> module = jit.compile(std::move(program));
  62. * void* func = module->getJumperStage("swap");
  63. * p->append(func, nullptr);
  64. * #endif
  65. */
  66. class JIT {
  67. typedef int StackIndex;
  68. public:
  69. class Module {
  70. public:
  71. /**
  72. * Returns the address of a symbol in the module.
  73. */
  74. void* getSymbol(const char* name);
  75. /**
  76. * Returns the address of a function as an SkJumper pipeline stage. The function must have
  77. * the signature void <name>(int x, int y, inout float4 color). The returned function will
  78. * have the correct signature to function as an SkJumper stage (meaning it will actually
  79. * have a different signature at runtime, accepting vector parameters and operating on
  80. * multiple pixels simultaneously as is normal for SkJumper stages).
  81. */
  82. void* getJumperStage(const char* name);
  83. ~Module() {
  84. LLVMOrcDisposeSharedModuleRef(fSharedModule);
  85. }
  86. private:
  87. Module(std::unique_ptr<Program> program,
  88. LLVMSharedModuleRef sharedModule,
  89. LLVMOrcJITStackRef jitStack)
  90. : fProgram(std::move(program))
  91. , fSharedModule(sharedModule)
  92. , fJITStack(jitStack) {}
  93. std::unique_ptr<Program> fProgram;
  94. LLVMSharedModuleRef fSharedModule;
  95. LLVMOrcJITStackRef fJITStack;
  96. friend class JIT;
  97. };
  98. JIT(Compiler* compiler);
  99. ~JIT();
  100. /**
  101. * Just-in-time compiles an SkSL program and returns the resulting Module. The JIT must not be
  102. * destroyed before all of its Modules are destroyed.
  103. */
  104. std::unique_ptr<Module> compile(std::unique_ptr<Program> program);
  105. private:
  106. static constexpr int CHANNELS = 4;
  107. enum TypeKind {
  108. kFloat_TypeKind,
  109. kInt_TypeKind,
  110. kUInt_TypeKind,
  111. kBool_TypeKind
  112. };
  113. class LValue {
  114. public:
  115. virtual ~LValue() {}
  116. virtual LLVMValueRef load(LLVMBuilderRef builder) = 0;
  117. virtual void store(LLVMBuilderRef builder, LLVMValueRef value) = 0;
  118. };
  119. void addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
  120. std::vector<LLVMTypeRef> parameters);
  121. void loadBuiltinFunctions();
  122. void setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block);
  123. LLVMTypeRef getType(const Type& type);
  124. TypeKind typeKind(const Type& type);
  125. std::unique_ptr<LValue> getLValue(LLVMBuilderRef builder, const Expression& expr);
  126. void vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns);
  127. void vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
  128. LLVMValueRef* right);
  129. LLVMValueRef compileBinary(LLVMBuilderRef builder, const BinaryExpression& b);
  130. LLVMValueRef compileConstructor(LLVMBuilderRef builder, const Constructor& c);
  131. LLVMValueRef compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc);
  132. LLVMValueRef compileIndex(LLVMBuilderRef builder, const IndexExpression& v);
  133. LLVMValueRef compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p);
  134. LLVMValueRef compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p);
  135. LLVMValueRef compileSwizzle(LLVMBuilderRef builder, const Swizzle& s);
  136. LLVMValueRef compileVariableReference(LLVMBuilderRef builder, const VariableReference& v);
  137. LLVMValueRef compileTernary(LLVMBuilderRef builder, const TernaryExpression& t);
  138. LLVMValueRef compileExpression(LLVMBuilderRef builder, const Expression& expr);
  139. void appendStage(LLVMBuilderRef builder, const AppendStage& a);
  140. void compileBlock(LLVMBuilderRef builder, const Block& block);
  141. void compileBreak(LLVMBuilderRef builder, const BreakStatement& b);
  142. void compileContinue(LLVMBuilderRef builder, const ContinueStatement& c);
  143. void compileDo(LLVMBuilderRef builder, const DoStatement& d);
  144. void compileFor(LLVMBuilderRef builder, const ForStatement& f);
  145. void compileIf(LLVMBuilderRef builder, const IfStatement& i);
  146. void compileReturn(LLVMBuilderRef builder, const ReturnStatement& r);
  147. void compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls);
  148. void compileWhile(LLVMBuilderRef builder, const WhileStatement& w);
  149. void compileStatement(LLVMBuilderRef builder, const Statement& stmt);
  150. // The "Vector" variants of functions attempt to compile a given expression or statement as part
  151. // of a vectorized SkJumper stage function - that is, with r, g, b, and a each being vectors of
  152. // fVectorCount floats. So a statement like "color.r = 0;" looks like it modifies a single
  153. // channel of a single pixel, but the compiled code will actually modify the red channel of
  154. // fVectorCount pixels at once.
  155. //
  156. // As not everything can be vectorized, these calls return a bool to indicate whether they were
  157. // successful. If anything anywhere in the function cannot be vectorized, the JIT will fall back
  158. // to looping over the pixels instead.
  159. //
  160. // Since we process multiple pixels at once, and each pixel consists of multiple color channels,
  161. // expressions may effectively result in a vector-of-vectors. We produce zero to four outputs
  162. // when compiling expression, each of which is a vector, so that e.g. float2(1, 0) actually
  163. // produces two vectors, one containing all 1s, the other all 0s. The out parameter always
  164. // allows for 4 channels, but the functions produce 0 to 4 channels depending on the type they
  165. // are operating on. Thus evaluating "color.rgb" actually fills in out[0] through out[2],
  166. // leaving out[3] uninitialized.
  167. // As the number of outputs can be inferred from the type of the expression, it is not
  168. // explicitly signalled anywhere.
  169. bool compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
  170. LLVMValueRef out[CHANNELS]);
  171. bool compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
  172. LLVMValueRef out[CHANNELS]);
  173. bool compileVectorFloatLiteral(LLVMBuilderRef builder, const FloatLiteral& f,
  174. LLVMValueRef out[CHANNELS]);
  175. bool compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
  176. LLVMValueRef out[CHANNELS]);
  177. bool compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
  178. LLVMValueRef out[CHANNELS]);
  179. bool compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
  180. LLVMValueRef out[CHANNELS]);
  181. bool getVectorLValue(LLVMBuilderRef builder, const Expression& e, LLVMValueRef out[CHANNELS]);
  182. /**
  183. * Evaluates the left and right operands of a binary operation, promoting one of them to a
  184. * vector if necessary to make the types match.
  185. */
  186. bool getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
  187. LLVMValueRef outLeft[CHANNELS], const Expression& right,
  188. LLVMValueRef outRight[CHANNELS]);
  189. bool compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt);
  190. /**
  191. * Returns true if this function has the signature void(int, int, inout float4) and thus can be
  192. * used as an SkJumper stage.
  193. */
  194. bool hasStageSignature(const FunctionDeclaration& f);
  195. /**
  196. * Attempts to compile a vectorized stage function, returning true on success. A stage function
  197. * of e.g. "color.r = 0;" will produce code which sets the entire red vector to zeros in a
  198. * single instruction, thus calculating several pixels at once.
  199. */
  200. bool compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc);
  201. /**
  202. * Fallback function which loops over the pixels, for when vectorization fails. A stage function
  203. * of e.g. "color.r = 0;" will produce a loop which iterates over the entries in the red vector,
  204. * setting each one to zero individually.
  205. */
  206. void compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc);
  207. /**
  208. * Called when compiling a function which has the signature of an SkJumper stage. Produces a
  209. * version of the function which can be plugged into SkJumper (thus having a signature which
  210. * accepts four vectors, one for each color channel, containing the color data of multiple
  211. * pixels at once). To go from SkSL code which operates on a single pixel at a time to CPU code
  212. * which operates on multiple pixels at once, the code is either vectorized using
  213. * compileStageFunctionVector or wrapped in a loop using compileStageFunctionLoop.
  214. */
  215. LLVMValueRef compileStageFunction(const FunctionDefinition& f);
  216. /**
  217. * Compiles an SkSL function to an LLVM function. If the function has the signature of an
  218. * SkJumper stage, it will *also* be compiled by compileStageFunction, resulting in both a stage
  219. * and non-stage version of the function.
  220. */
  221. LLVMValueRef compileFunction(const FunctionDefinition& f);
  222. void createModule();
  223. void optimize();
  224. bool isColorRef(const Expression& expr);
  225. static uint64_t resolveSymbol(const char* name, JIT* jit);
  226. const char* fCPU;
  227. int fVectorCount;
  228. Compiler& fCompiler;
  229. std::unique_ptr<Program> fProgram;
  230. LLVMContextRef fContext;
  231. LLVMModuleRef fModule;
  232. LLVMSharedModuleRef fSharedModule;
  233. LLVMOrcJITStackRef fJITStack;
  234. LLVMValueRef fCurrentFunction;
  235. LLVMBasicBlockRef fAllocaBlock;
  236. LLVMBasicBlockRef fCurrentBlock;
  237. LLVMTypeRef fVoidType;
  238. LLVMTypeRef fInt1Type;
  239. LLVMTypeRef fInt1VectorType;
  240. LLVMTypeRef fInt1Vector2Type;
  241. LLVMTypeRef fInt1Vector3Type;
  242. LLVMTypeRef fInt1Vector4Type;
  243. LLVMTypeRef fInt8Type;
  244. LLVMTypeRef fInt8PtrType;
  245. LLVMTypeRef fInt32Type;
  246. LLVMTypeRef fInt32VectorType;
  247. LLVMTypeRef fInt32Vector2Type;
  248. LLVMTypeRef fInt32Vector3Type;
  249. LLVMTypeRef fInt32Vector4Type;
  250. LLVMTypeRef fInt64Type;
  251. LLVMTypeRef fSizeTType;
  252. LLVMTypeRef fFloat32Type;
  253. LLVMTypeRef fFloat32VectorType;
  254. LLVMTypeRef fFloat32Vector2Type;
  255. LLVMTypeRef fFloat32Vector3Type;
  256. LLVMTypeRef fFloat32Vector4Type;
  257. // Our SkSL stage functions have a single float4 for color, but the actual SkJumper stage
  258. // function has four separate vectors, one for each channel. These four values are references to
  259. // the red, green, blue, and alpha vectors respectively.
  260. LLVMValueRef fChannels[CHANNELS];
  261. // when processing a stage function, this points to the SkSL color parameter (an inout float4)
  262. const Variable* fColorParam;
  263. std::unordered_map<const FunctionDeclaration*, LLVMValueRef> fFunctions;
  264. std::unordered_map<const Variable*, LLVMValueRef> fVariables;
  265. // LLVM function parameters are read-only, so when modifying function parameters we need to
  266. // first promote them to variables. This keeps track of which parameters have been promoted.
  267. std::set<const Variable*> fPromotedParameters;
  268. std::vector<LLVMBasicBlockRef> fBreakTarget;
  269. std::vector<LLVMBasicBlockRef> fContinueTarget;
  270. LLVMValueRef fFoldAnd2Func;
  271. LLVMValueRef fFoldOr2Func;
  272. LLVMValueRef fFoldAnd3Func;
  273. LLVMValueRef fFoldOr3Func;
  274. LLVMValueRef fFoldAnd4Func;
  275. LLVMValueRef fFoldOr4Func;
  276. LLVMValueRef fAppendFunc;
  277. LLVMValueRef fAppendCallbackFunc;
  278. LLVMValueRef fDebugFunc;
  279. };
  280. } // namespace
  281. #endif // SK_LLVM_AVAILABLE
  282. #endif // SKSL_JIT