SkSLSPIRVCodeGenerator.cpp 146 KB


  1. /*
  2. * Copyright 2016 Google Inc.
  3. *
  4. * Use of this source code is governed by a BSD-style license that can be
  5. * found in the LICENSE file.
  6. */
  7. #include "src/sksl/SkSLSPIRVCodeGenerator.h"
  8. #include "src/sksl/GLSL.std.450.h"
  9. #include "src/sksl/SkSLCompiler.h"
  10. #include "src/sksl/ir/SkSLExpressionStatement.h"
  11. #include "src/sksl/ir/SkSLExtension.h"
  12. #include "src/sksl/ir/SkSLIndexExpression.h"
  13. #include "src/sksl/ir/SkSLVariableReference.h"
  14. namespace SkSL {
  // Value this generator uses to identify itself in its output.
  // FIXME: we should probably register a real magic number instead of using zero.
  static constexpr int32_t SKSL_MAGIC = 0x0;
  16. void SPIRVCodeGenerator::setupIntrinsics() {
  17. #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
  18. GLSLstd450 ## x, GLSLstd450 ## x)
  19. #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
  20. GLSLstd450 ## ifFloat, \
  21. GLSLstd450 ## ifInt, \
  22. GLSLstd450 ## ifUInt, \
  23. SpvOpUndef)
  24. #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
  25. SpvOp ## x)
  26. #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
  27. k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
  28. k ## x ## _SpecialIntrinsic)
  29. fIntrinsicMap[String("round")] = ALL_GLSL(Round);
  30. fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven);
  31. fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc);
  32. fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
  33. fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign);
  34. fIntrinsicMap[String("floor")] = ALL_GLSL(Floor);
  35. fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil);
  36. fIntrinsicMap[String("fract")] = ALL_GLSL(Fract);
  37. fIntrinsicMap[String("radians")] = ALL_GLSL(Radians);
  38. fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees);
  39. fIntrinsicMap[String("sin")] = ALL_GLSL(Sin);
  40. fIntrinsicMap[String("cos")] = ALL_GLSL(Cos);
  41. fIntrinsicMap[String("tan")] = ALL_GLSL(Tan);
  42. fIntrinsicMap[String("asin")] = ALL_GLSL(Asin);
  43. fIntrinsicMap[String("acos")] = ALL_GLSL(Acos);
  44. fIntrinsicMap[String("atan")] = SPECIAL(Atan);
  45. fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh);
  46. fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh);
  47. fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh);
  48. fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh);
  49. fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh);
  50. fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh);
  51. fIntrinsicMap[String("pow")] = ALL_GLSL(Pow);
  52. fIntrinsicMap[String("exp")] = ALL_GLSL(Exp);
  53. fIntrinsicMap[String("log")] = ALL_GLSL(Log);
  54. fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2);
  55. fIntrinsicMap[String("log2")] = ALL_GLSL(Log2);
  56. fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt);
  57. fIntrinsicMap[String("inverse")] = ALL_GLSL(MatrixInverse);
  58. fIntrinsicMap[String("transpose")] = ALL_SPIRV(Transpose);
  59. fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt);
  60. fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant);
  61. fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
  62. fIntrinsicMap[String("mod")] = SPECIAL(Mod);
  63. fIntrinsicMap[String("min")] = SPECIAL(Min);
  64. fIntrinsicMap[String("max")] = SPECIAL(Max);
  65. fIntrinsicMap[String("clamp")] = SPECIAL(Clamp);
  66. fIntrinsicMap[String("saturate")] = SPECIAL(Saturate);
  67. fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
  68. SpvOpUndef, SpvOpUndef, SpvOpUndef);
  69. fIntrinsicMap[String("mix")] = SPECIAL(Mix);
  70. fIntrinsicMap[String("step")] = ALL_GLSL(Step);
  71. fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep);
  72. fIntrinsicMap[String("fma")] = ALL_GLSL(Fma);
  73. fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp);
  74. fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp);
  75. #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
  76. fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
  77. PACK(Snorm4x8);
  78. PACK(Unorm4x8);
  79. PACK(Snorm2x16);
  80. PACK(Unorm2x16);
  81. PACK(Half2x16);
  82. PACK(Double2x32);
  83. fIntrinsicMap[String("length")] = ALL_GLSL(Length);
  84. fIntrinsicMap[String("distance")] = ALL_GLSL(Distance);
  85. fIntrinsicMap[String("cross")] = ALL_GLSL(Cross);
  86. fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize);
  87. fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
  88. fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect);
  89. fIntrinsicMap[String("refract")] = ALL_GLSL(Refract);
  90. fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb);
  91. fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
  92. fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
  93. SpvOpUndef, SpvOpUndef, SpvOpUndef);
  94. fIntrinsicMap[String("dFdy")] = SPECIAL(DFdy);
  95. fIntrinsicMap[String("fwidth")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFwidth,
  96. SpvOpUndef, SpvOpUndef, SpvOpUndef);
  97. fIntrinsicMap[String("texture")] = SPECIAL(Texture);
  98. fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
  99. fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
  100. SpvOpUndef, SpvOpUndef, SpvOpAny);
  101. fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
  102. SpvOpUndef, SpvOpUndef, SpvOpAll);
  103. fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind,
  104. SpvOpFOrdEqual, SpvOpIEqual,
  105. SpvOpIEqual, SpvOpLogicalEqual);
  106. fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
  107. SpvOpFOrdNotEqual, SpvOpINotEqual,
  108. SpvOpINotEqual,
  109. SpvOpLogicalNotEqual);
  110. fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
  111. SpvOpFOrdLessThan, SpvOpSLessThan,
  112. SpvOpULessThan, SpvOpUndef);
  113. fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
  114. SpvOpFOrdLessThanEqual,
  115. SpvOpSLessThanEqual,
  116. SpvOpULessThanEqual,
  117. SpvOpUndef);
  118. fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind,
  119. SpvOpFOrdGreaterThan,
  120. SpvOpSGreaterThan,
  121. SpvOpUGreaterThan,
  122. SpvOpUndef);
  123. fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
  124. SpvOpFOrdGreaterThanEqual,
  125. SpvOpSGreaterThanEqual,
  126. SpvOpUGreaterThanEqual,
  127. SpvOpUndef);
  128. fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex);
  129. fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive);
  130. // interpolateAt* not yet supported...
  131. }
  132. void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
  133. out.write((const char*) &word, sizeof(word));
  134. }
  135. static bool is_float(const Context& context, const Type& type) {
  136. if (type.columns() > 1) {
  137. return is_float(context, type.componentType());
  138. }
  139. return type == *context.fFloat_Type || type == *context.fHalf_Type ||
  140. type == *context.fDouble_Type;
  141. }
  142. static bool is_signed(const Context& context, const Type& type) {
  143. if (type.kind() == Type::kVector_Kind) {
  144. return is_signed(context, type.componentType());
  145. }
  146. return type == *context.fInt_Type || type == *context.fShort_Type ||
  147. type == *context.fByte_Type;
  148. }
  149. static bool is_unsigned(const Context& context, const Type& type) {
  150. if (type.kind() == Type::kVector_Kind) {
  151. return is_unsigned(context, type.componentType());
  152. }
  153. return type == *context.fUInt_Type || type == *context.fUShort_Type ||
  154. type == *context.fUByte_Type;
  155. }
  156. static bool is_bool(const Context& context, const Type& type) {
  157. if (type.kind() == Type::kVector_Kind) {
  158. return is_bool(context, type.componentType());
  159. }
  160. return type == *context.fBool_Type;
  161. }
  162. static bool is_out(const Variable& var) {
  163. return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
  164. }
  165. void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
  166. SkASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
  167. SkASSERT(opCode != SpvOpUndef);
  168. switch (opCode) {
  169. case SpvOpReturn: // fall through
  170. case SpvOpReturnValue: // fall through
  171. case SpvOpKill: // fall through
  172. case SpvOpBranch: // fall through
  173. case SpvOpBranchConditional:
  174. SkASSERT(fCurrentBlock);
  175. fCurrentBlock = 0;
  176. break;
  177. case SpvOpConstant: // fall through
  178. case SpvOpConstantTrue: // fall through
  179. case SpvOpConstantFalse: // fall through
  180. case SpvOpConstantComposite: // fall through
  181. case SpvOpTypeVoid: // fall through
  182. case SpvOpTypeInt: // fall through
  183. case SpvOpTypeFloat: // fall through
  184. case SpvOpTypeBool: // fall through
  185. case SpvOpTypeVector: // fall through
  186. case SpvOpTypeMatrix: // fall through
  187. case SpvOpTypeArray: // fall through
  188. case SpvOpTypePointer: // fall through
  189. case SpvOpTypeFunction: // fall through
  190. case SpvOpTypeRuntimeArray: // fall through
  191. case SpvOpTypeStruct: // fall through
  192. case SpvOpTypeImage: // fall through
  193. case SpvOpTypeSampledImage: // fall through
  194. case SpvOpVariable: // fall through
  195. case SpvOpFunction: // fall through
  196. case SpvOpFunctionParameter: // fall through
  197. case SpvOpFunctionEnd: // fall through
  198. case SpvOpExecutionMode: // fall through
  199. case SpvOpMemoryModel: // fall through
  200. case SpvOpCapability: // fall through
  201. case SpvOpExtInstImport: // fall through
  202. case SpvOpEntryPoint: // fall through
  203. case SpvOpSource: // fall through
  204. case SpvOpSourceExtension: // fall through
  205. case SpvOpName: // fall through
  206. case SpvOpMemberName: // fall through
  207. case SpvOpDecorate: // fall through
  208. case SpvOpMemberDecorate:
  209. break;
  210. default:
  211. SkASSERT(fCurrentBlock);
  212. }
  213. this->writeWord((length << 16) | opCode, out);
  214. }
  215. void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
  216. fCurrentBlock = label;
  217. this->writeInstruction(SpvOpLabel, label, out);
  218. }
  219. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
  220. this->writeOpCode(opCode, 1, out);
  221. }
  222. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
  223. this->writeOpCode(opCode, 2, out);
  224. this->writeWord(word1, out);
  225. }
  226. void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
  227. out.write(string, length);
  228. switch (length % 4) {
  229. case 1:
  230. out.write8(0);
  231. // fall through
  232. case 2:
  233. out.write8(0);
  234. // fall through
  235. case 3:
  236. out.write8(0);
  237. break;
  238. default:
  239. this->writeWord(0, out);
  240. }
  241. }
  242. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
  243. this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
  244. this->writeString(string.fChars, string.fLength, out);
  245. }
  246. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
  247. OutputStream& out) {
  248. this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
  249. this->writeWord(word1, out);
  250. this->writeString(string.fChars, string.fLength, out);
  251. }
  252. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  253. StringFragment string, OutputStream& out) {
  254. this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
  255. this->writeWord(word1, out);
  256. this->writeWord(word2, out);
  257. this->writeString(string.fChars, string.fLength, out);
  258. }
  259. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  260. OutputStream& out) {
  261. this->writeOpCode(opCode, 3, out);
  262. this->writeWord(word1, out);
  263. this->writeWord(word2, out);
  264. }
  265. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  266. int32_t word3, OutputStream& out) {
  267. this->writeOpCode(opCode, 4, out);
  268. this->writeWord(word1, out);
  269. this->writeWord(word2, out);
  270. this->writeWord(word3, out);
  271. }
  272. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  273. int32_t word3, int32_t word4, OutputStream& out) {
  274. this->writeOpCode(opCode, 5, out);
  275. this->writeWord(word1, out);
  276. this->writeWord(word2, out);
  277. this->writeWord(word3, out);
  278. this->writeWord(word4, out);
  279. }
  280. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  281. int32_t word3, int32_t word4, int32_t word5,
  282. OutputStream& out) {
  283. this->writeOpCode(opCode, 6, out);
  284. this->writeWord(word1, out);
  285. this->writeWord(word2, out);
  286. this->writeWord(word3, out);
  287. this->writeWord(word4, out);
  288. this->writeWord(word5, out);
  289. }
  290. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  291. int32_t word3, int32_t word4, int32_t word5,
  292. int32_t word6, OutputStream& out) {
  293. this->writeOpCode(opCode, 7, out);
  294. this->writeWord(word1, out);
  295. this->writeWord(word2, out);
  296. this->writeWord(word3, out);
  297. this->writeWord(word4, out);
  298. this->writeWord(word5, out);
  299. this->writeWord(word6, out);
  300. }
  301. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  302. int32_t word3, int32_t word4, int32_t word5,
  303. int32_t word6, int32_t word7, OutputStream& out) {
  304. this->writeOpCode(opCode, 8, out);
  305. this->writeWord(word1, out);
  306. this->writeWord(word2, out);
  307. this->writeWord(word3, out);
  308. this->writeWord(word4, out);
  309. this->writeWord(word5, out);
  310. this->writeWord(word6, out);
  311. this->writeWord(word7, out);
  312. }
  313. void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
  314. int32_t word3, int32_t word4, int32_t word5,
  315. int32_t word6, int32_t word7, int32_t word8,
  316. OutputStream& out) {
  317. this->writeOpCode(opCode, 9, out);
  318. this->writeWord(word1, out);
  319. this->writeWord(word2, out);
  320. this->writeWord(word3, out);
  321. this->writeWord(word4, out);
  322. this->writeWord(word5, out);
  323. this->writeWord(word6, out);
  324. this->writeWord(word7, out);
  325. this->writeWord(word8, out);
  326. }
  327. void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
  328. for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
  329. if (fCapabilities & bit) {
  330. this->writeInstruction(SpvOpCapability, (SpvId) i, out);
  331. }
  332. }
  333. if (fProgram.fKind == Program::kGeometry_Kind) {
  334. this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
  335. }
  336. else {
  337. this->writeInstruction(SpvOpCapability, SpvCapabilityShader, out);
  338. }
  339. }
  340. SpvId SPIRVCodeGenerator::nextId() {
  341. return fIdCount++;
  342. }
  /**
   * Emits the OpTypeStruct declaration for `type` (id `resultId`) into fConstantBuffer, with
   * member offsets computed according to `memoryLayout`.
   *
   * Also writes the struct and member names into fNameBuffer, and the per-member decorations
   * into fDecorationBuffer: Offset for non-builtin members, ColMajor + MatrixStride for matrices,
   * RelaxedPrecision for fields that are not high precision, plus whatever writeLayout() adds.
   * Explicit `offset` layout values are checked against the running offset and the field's
   * alignment; violations are reported through fErrors.
   */
  void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
                                       SpvId resultId) {
      this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
      // go ahead and write all of the field types, so we don't inadvertently write them while we're
      // in the middle of writing the struct instruction
      std::vector<SpvId> types;
      for (const auto& f : type.fields()) {
          types.push_back(this->getType(*f.fType, memoryLayout));
      }
      this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
      this->writeWord(resultId, fConstantBuffer);
      for (SpvId id : types) {
          this->writeWord(id, fConstantBuffer);
      }
      // Running byte offset at which the next field is placed.
      size_t offset = 0;
      for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
          const Type::Field& field = type.fields()[i];
          size_t size = memoryLayout.size(*field.fType);
          size_t alignment = memoryLayout.alignment(*field.fType);
          const Layout& fieldLayout = field.fModifiers.fLayout;
          if (fieldLayout.fOffset >= 0) {
              // An explicit offset must not overlap earlier fields and must honor alignment.
              if (fieldLayout.fOffset < (int) offset) {
                  fErrors.error(type.fOffset,
                                "offset of field '" + field.fName + "' must be at "
                                "least " + to_string((int) offset));
              }
              if (fieldLayout.fOffset % alignment) {
                  fErrors.error(type.fOffset,
                                "offset of field '" + field.fName + "' must be a multiple"
                                " of " + to_string((int) alignment));
              }
              offset = fieldLayout.fOffset;
          } else {
              // No explicit offset: round the running offset up to the field's alignment.
              size_t mod = offset % alignment;
              if (mod) {
                  offset += alignment - mod;
              }
          }
          this->writeInstruction(SpvOpMemberName, resultId, i, field.fName, fNameBuffer);
          this->writeLayout(fieldLayout, resultId, i);
          // Only members without a builtin (fBuiltin < 0) receive an Offset decoration.
          if (field.fModifiers.fLayout.fBuiltin < 0) {
              this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                     (SpvId) offset, fDecorationBuffer);
          }
          if (field.fType->kind() == Type::kMatrix_Kind) {
              this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                     fDecorationBuffer);
              this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                     (SpvId) memoryLayout.stride(*field.fType),
                                     fDecorationBuffer);
          }
          if (!field.fType->highPrecision()) {
              this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i,
                                     SpvDecorationRelaxedPrecision, fDecorationBuffer);
          }
          offset += size;
          // After an array or struct member, pad the running offset up to that member's
          // alignment so the next member starts on that boundary.
          Type::Kind kind = field.fType->kind();
          if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
              offset += alignment - offset % alignment;
          }
      }
  }
  405. Type SPIRVCodeGenerator::getActualType(const Type& type) {
  406. if (type.isFloat()) {
  407. return *fContext.fFloat_Type;
  408. }
  409. if (type.isSigned()) {
  410. return *fContext.fInt_Type;
  411. }
  412. if (type.isUnsigned()) {
  413. return *fContext.fUInt_Type;
  414. }
  415. if (type.kind() == Type::kMatrix_Kind || type.kind() == Type::kVector_Kind) {
  416. if (type.componentType() == *fContext.fHalf_Type) {
  417. return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
  418. }
  419. if (type.componentType() == *fContext.fShort_Type ||
  420. type.componentType() == *fContext.fByte_Type) {
  421. return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
  422. }
  423. if (type.componentType() == *fContext.fUShort_Type ||
  424. type.componentType() == *fContext.fUByte_Type) {
  425. return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
  426. }
  427. }
  428. return type;
  429. }
  430. SpvId SPIRVCodeGenerator::getType(const Type& type) {
  431. return this->getType(type, fDefaultLayout);
  432. }
  433. SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
  434. Type type = this->getActualType(rawType);
  435. String key = type.name() + to_string((int) layout.fStd);
  436. auto entry = fTypeMap.find(key);
  437. if (entry == fTypeMap.end()) {
  438. SpvId result = this->nextId();
  439. switch (type.kind()) {
  440. case Type::kScalar_Kind:
  441. if (type == *fContext.fBool_Type) {
  442. this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
  443. } else if (type == *fContext.fInt_Type || type == *fContext.fShort_Type ||
  444. type == *fContext.fIntLiteral_Type) {
  445. this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
  446. } else if (type == *fContext.fUInt_Type || type == *fContext.fUShort_Type) {
  447. this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
  448. } else if (type == *fContext.fFloat_Type || type == *fContext.fHalf_Type ||
  449. type == *fContext.fFloatLiteral_Type) {
  450. this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
  451. } else if (type == *fContext.fDouble_Type) {
  452. this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
  453. } else {
  454. SkASSERT(false);
  455. }
  456. break;
  457. case Type::kVector_Kind:
  458. this->writeInstruction(SpvOpTypeVector, result,
  459. this->getType(type.componentType(), layout),
  460. type.columns(), fConstantBuffer);
  461. break;
  462. case Type::kMatrix_Kind:
  463. this->writeInstruction(SpvOpTypeMatrix, result,
  464. this->getType(index_type(fContext, type), layout),
  465. type.columns(), fConstantBuffer);
  466. break;
  467. case Type::kStruct_Kind:
  468. this->writeStruct(type, layout, result);
  469. break;
  470. case Type::kArray_Kind: {
  471. if (type.columns() > 0) {
  472. IntLiteral count(fContext, -1, type.columns());
  473. this->writeInstruction(SpvOpTypeArray, result,
  474. this->getType(type.componentType(), layout),
  475. this->writeIntLiteral(count), fConstantBuffer);
  476. this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
  477. (int32_t) layout.stride(type),
  478. fDecorationBuffer);
  479. } else {
  480. SkASSERT(false); // we shouldn't have any runtime-sized arrays right now
  481. this->writeInstruction(SpvOpTypeRuntimeArray, result,
  482. this->getType(type.componentType(), layout),
  483. fConstantBuffer);
  484. this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
  485. (int32_t) layout.stride(type),
  486. fDecorationBuffer);
  487. }
  488. break;
  489. }
  490. case Type::kSampler_Kind: {
  491. SpvId image = result;
  492. if (SpvDimSubpassData != type.dimensions()) {
  493. image = this->nextId();
  494. }
  495. if (SpvDimBuffer == type.dimensions()) {
  496. fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
  497. }
  498. this->writeInstruction(SpvOpTypeImage, image,
  499. this->getType(*fContext.fFloat_Type, layout),
  500. type.dimensions(), type.isDepth(), type.isArrayed(),
  501. type.isMultisampled(), type.isSampled() ? 1 : 2,
  502. SpvImageFormatUnknown, fConstantBuffer);
  503. fImageTypeMap[key] = image;
  504. if (SpvDimSubpassData != type.dimensions()) {
  505. this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
  506. }
  507. break;
  508. }
  509. default:
  510. if (type == *fContext.fVoid_Type) {
  511. this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
  512. } else {
  513. ABORT("invalid type: %s", type.description().c_str());
  514. }
  515. }
  516. fTypeMap[key] = result;
  517. return result;
  518. }
  519. return entry->second;
  520. }
  521. SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
  522. SkASSERT(type.kind() == Type::kSampler_Kind);
  523. this->getType(type);
  524. String key = type.name() + to_string((int) fDefaultLayout.fStd);
  525. SkASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
  526. return fImageTypeMap[key];
  527. }
  528. SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
  529. String key = function.fReturnType.description() + "(";
  530. String separator;
  531. for (size_t i = 0; i < function.fParameters.size(); i++) {
  532. key += separator;
  533. separator = ", ";
  534. key += function.fParameters[i]->fType.description();
  535. }
  536. key += ")";
  537. auto entry = fTypeMap.find(key);
  538. if (entry == fTypeMap.end()) {
  539. SpvId result = this->nextId();
  540. int32_t length = 3 + (int32_t) function.fParameters.size();
  541. SpvId returnType = this->getType(function.fReturnType);
  542. std::vector<SpvId> parameterTypes;
  543. for (size_t i = 0; i < function.fParameters.size(); i++) {
  544. // glslang seems to treat all function arguments as pointers whether they need to be or
  545. // not. I was initially puzzled by this until I ran bizarre failures with certain
  546. // patterns of function calls and control constructs, as exemplified by this minimal
  547. // failure case:
  548. //
  549. // void sphere(float x) {
  550. // }
  551. //
  552. // void map() {
  553. // sphere(1.0);
  554. // }
  555. //
  556. // void main() {
  557. // for (int i = 0; i < 1; i++) {
  558. // map();
  559. // }
  560. // }
  561. //
  562. // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
  563. // crashes. Making it take a float* and storing the argument in a temporary variable,
  564. // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
  565. // the spec makes this make sense.
  566. // if (is_out(function->fParameters[i])) {
  567. parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
  568. SpvStorageClassFunction));
  569. // } else {
  570. // parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
  571. // }
  572. }
  573. this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
  574. this->writeWord(result, fConstantBuffer);
  575. this->writeWord(returnType, fConstantBuffer);
  576. for (SpvId id : parameterTypes) {
  577. this->writeWord(id, fConstantBuffer);
  578. }
  579. fTypeMap[key] = result;
  580. return result;
  581. }
  582. return entry->second;
  583. }
  584. SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
  585. return this->getPointerType(type, fDefaultLayout, storageClass);
  586. }
  587. SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
  588. SpvStorageClass_ storageClass) {
  589. Type type = this->getActualType(rawType);
  590. String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
  591. auto entry = fTypeMap.find(key);
  592. if (entry == fTypeMap.end()) {
  593. SpvId result = this->nextId();
  594. this->writeInstruction(SpvOpTypePointer, result, storageClass,
  595. this->getType(type), fConstantBuffer);
  596. fTypeMap[key] = result;
  597. return result;
  598. }
  599. return entry->second;
  600. }
  601. SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
  602. switch (expr.fKind) {
  603. case Expression::kBinary_Kind:
  604. return this->writeBinaryExpression((BinaryExpression&) expr, out);
  605. case Expression::kBoolLiteral_Kind:
  606. return this->writeBoolLiteral((BoolLiteral&) expr);
  607. case Expression::kConstructor_Kind:
  608. return this->writeConstructor((Constructor&) expr, out);
  609. case Expression::kIntLiteral_Kind:
  610. return this->writeIntLiteral((IntLiteral&) expr);
  611. case Expression::kFieldAccess_Kind:
  612. return this->writeFieldAccess(((FieldAccess&) expr), out);
  613. case Expression::kFloatLiteral_Kind:
  614. return this->writeFloatLiteral(((FloatLiteral&) expr));
  615. case Expression::kFunctionCall_Kind:
  616. return this->writeFunctionCall((FunctionCall&) expr, out);
  617. case Expression::kPrefix_Kind:
  618. return this->writePrefixExpression((PrefixExpression&) expr, out);
  619. case Expression::kPostfix_Kind:
  620. return this->writePostfixExpression((PostfixExpression&) expr, out);
  621. case Expression::kSwizzle_Kind:
  622. return this->writeSwizzle((Swizzle&) expr, out);
  623. case Expression::kVariableReference_Kind:
  624. return this->writeVariableReference((VariableReference&) expr, out);
  625. case Expression::kTernary_Kind:
  626. return this->writeTernaryExpression((TernaryExpression&) expr, out);
  627. case Expression::kIndex_Kind:
  628. return this->writeIndexExpression((IndexExpression&) expr, out);
  629. default:
  630. ABORT("unsupported expression: %s", expr.description().c_str());
  631. }
  632. return -1;
  633. }
  634. SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
  635. auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
  636. SkASSERT(intrinsic != fIntrinsicMap.end());
  637. int32_t intrinsicId;
  638. if (c.fArguments.size() > 0) {
  639. const Type& type = c.fArguments[0]->fType;
  640. if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
  641. intrinsicId = std::get<1>(intrinsic->second);
  642. } else if (is_signed(fContext, type)) {
  643. intrinsicId = std::get<2>(intrinsic->second);
  644. } else if (is_unsigned(fContext, type)) {
  645. intrinsicId = std::get<3>(intrinsic->second);
  646. } else if (is_bool(fContext, type)) {
  647. intrinsicId = std::get<4>(intrinsic->second);
  648. } else {
  649. intrinsicId = std::get<1>(intrinsic->second);
  650. }
  651. } else {
  652. intrinsicId = std::get<1>(intrinsic->second);
  653. }
  654. switch (std::get<0>(intrinsic->second)) {
  655. case kGLSL_STD_450_IntrinsicKind: {
  656. SpvId result = this->nextId();
  657. std::vector<SpvId> arguments;
  658. for (size_t i = 0; i < c.fArguments.size(); i++) {
  659. if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
  660. arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
  661. } else {
  662. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  663. }
  664. }
  665. this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
  666. this->writeWord(this->getType(c.fType), out);
  667. this->writeWord(result, out);
  668. this->writeWord(fGLSLExtendedInstructions, out);
  669. this->writeWord(intrinsicId, out);
  670. for (SpvId id : arguments) {
  671. this->writeWord(id, out);
  672. }
  673. return result;
  674. }
  675. case kSPIRV_IntrinsicKind: {
  676. SpvId result = this->nextId();
  677. std::vector<SpvId> arguments;
  678. for (size_t i = 0; i < c.fArguments.size(); i++) {
  679. if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
  680. arguments.push_back(this->getLValue(*c.fArguments[i], out)->getPointer());
  681. } else {
  682. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  683. }
  684. }
  685. if (c.fType != *fContext.fVoid_Type) {
  686. this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
  687. this->writeWord(this->getType(c.fType), out);
  688. this->writeWord(result, out);
  689. } else {
  690. this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
  691. }
  692. for (SpvId id : arguments) {
  693. this->writeWord(id, out);
  694. }
  695. return result;
  696. }
  697. case kSpecial_IntrinsicKind:
  698. return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
  699. default:
  700. ABORT("unsupported intrinsic kind");
  701. }
  702. }
  703. std::vector<SpvId> SPIRVCodeGenerator::vectorize(
  704. const std::vector<std::unique_ptr<Expression>>& args,
  705. OutputStream& out) {
  706. int vectorSize = 0;
  707. for (const auto& a : args) {
  708. if (a->fType.kind() == Type::kVector_Kind) {
  709. if (vectorSize) {
  710. SkASSERT(a->fType.columns() == vectorSize);
  711. }
  712. else {
  713. vectorSize = a->fType.columns();
  714. }
  715. }
  716. }
  717. std::vector<SpvId> result;
  718. for (const auto& a : args) {
  719. SpvId raw = this->writeExpression(*a, out);
  720. if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
  721. SpvId vector = this->nextId();
  722. this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
  723. this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
  724. this->writeWord(vector, out);
  725. for (int i = 0; i < vectorSize; i++) {
  726. this->writeWord(raw, out);
  727. }
  728. this->writePrecisionModifier(a->fType, vector);
  729. result.push_back(vector);
  730. } else {
  731. result.push_back(raw);
  732. }
  733. }
  734. return result;
  735. }
  736. void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
  737. SpvId signedInst, SpvId unsignedInst,
  738. const std::vector<SpvId>& args,
  739. OutputStream& out) {
  740. this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
  741. this->writeWord(this->getType(type), out);
  742. this->writeWord(id, out);
  743. this->writeWord(fGLSLExtendedInstructions, out);
  744. if (is_float(fContext, type)) {
  745. this->writeWord(floatInst, out);
  746. } else if (is_signed(fContext, type)) {
  747. this->writeWord(signedInst, out);
  748. } else if (is_unsigned(fContext, type)) {
  749. this->writeWord(unsignedInst, out);
  750. } else {
  751. SkASSERT(false);
  752. }
  753. for (SpvId a : args) {
  754. this->writeWord(a, out);
  755. }
  756. }
  757. SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
  758. OutputStream& out) {
  759. SpvId result = this->nextId();
  760. switch (kind) {
  761. case kAtan_SpecialIntrinsic: {
  762. std::vector<SpvId> arguments;
  763. for (size_t i = 0; i < c.fArguments.size(); i++) {
  764. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  765. }
  766. this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
  767. this->writeWord(this->getType(c.fType), out);
  768. this->writeWord(result, out);
  769. this->writeWord(fGLSLExtendedInstructions, out);
  770. this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
  771. for (SpvId id : arguments) {
  772. this->writeWord(id, out);
  773. }
  774. break;
  775. }
  776. case kSubpassLoad_SpecialIntrinsic: {
  777. SpvId img = this->writeExpression(*c.fArguments[0], out);
  778. std::vector<std::unique_ptr<Expression>> args;
  779. args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
  780. args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
  781. Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
  782. SpvId coords = this->writeConstantVector(ctor);
  783. if (1 == c.fArguments.size()) {
  784. this->writeInstruction(SpvOpImageRead,
  785. this->getType(c.fType),
  786. result,
  787. img,
  788. coords,
  789. out);
  790. } else {
  791. SkASSERT(2 == c.fArguments.size());
  792. SpvId sample = this->writeExpression(*c.fArguments[1], out);
  793. this->writeInstruction(SpvOpImageRead,
  794. this->getType(c.fType),
  795. result,
  796. img,
  797. coords,
  798. SpvImageOperandsSampleMask,
  799. sample,
  800. out);
  801. }
  802. break;
  803. }
  804. case kTexture_SpecialIntrinsic: {
  805. SpvOp_ op = SpvOpImageSampleImplicitLod;
  806. switch (c.fArguments[0]->fType.dimensions()) {
  807. case SpvDim1D:
  808. if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
  809. op = SpvOpImageSampleProjImplicitLod;
  810. } else {
  811. SkASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
  812. }
  813. break;
  814. case SpvDim2D:
  815. if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
  816. op = SpvOpImageSampleProjImplicitLod;
  817. } else {
  818. SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
  819. }
  820. break;
  821. case SpvDim3D:
  822. if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
  823. op = SpvOpImageSampleProjImplicitLod;
  824. } else {
  825. SkASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
  826. }
  827. break;
  828. case SpvDimCube: // fall through
  829. case SpvDimRect: // fall through
  830. case SpvDimBuffer: // fall through
  831. case SpvDimSubpassData:
  832. break;
  833. }
  834. SpvId type = this->getType(c.fType);
  835. SpvId sampler = this->writeExpression(*c.fArguments[0], out);
  836. SpvId uv = this->writeExpression(*c.fArguments[1], out);
  837. if (c.fArguments.size() == 3) {
  838. this->writeInstruction(op, type, result, sampler, uv,
  839. SpvImageOperandsBiasMask,
  840. this->writeExpression(*c.fArguments[2], out),
  841. out);
  842. } else {
  843. SkASSERT(c.fArguments.size() == 2);
  844. if (fProgram.fSettings.fSharpenTextures) {
  845. FloatLiteral lodBias(fContext, -1, -0.5);
  846. this->writeInstruction(op, type, result, sampler, uv,
  847. SpvImageOperandsBiasMask,
  848. this->writeFloatLiteral(lodBias),
  849. out);
  850. } else {
  851. this->writeInstruction(op, type, result, sampler, uv,
  852. out);
  853. }
  854. }
  855. break;
  856. }
  857. case kMod_SpecialIntrinsic: {
  858. std::vector<SpvId> args = this->vectorize(c.fArguments, out);
  859. SkASSERT(args.size() == 2);
  860. const Type& operandType = c.fArguments[0]->fType;
  861. SpvOp_ op;
  862. if (is_float(fContext, operandType)) {
  863. op = SpvOpFMod;
  864. } else if (is_signed(fContext, operandType)) {
  865. op = SpvOpSMod;
  866. } else if (is_unsigned(fContext, operandType)) {
  867. op = SpvOpUMod;
  868. } else {
  869. SkASSERT(false);
  870. return 0;
  871. }
  872. this->writeOpCode(op, 5, out);
  873. this->writeWord(this->getType(operandType), out);
  874. this->writeWord(result, out);
  875. this->writeWord(args[0], out);
  876. this->writeWord(args[1], out);
  877. break;
  878. }
  879. case kDFdy_SpecialIntrinsic: {
  880. SpvId fn = this->writeExpression(*c.fArguments[0], out);
  881. this->writeOpCode(SpvOpDPdy, 4, out);
  882. this->writeWord(this->getType(c.fType), out);
  883. this->writeWord(result, out);
  884. this->writeWord(fn, out);
  885. if (fProgram.fSettings.fFlipY) {
  886. // Flipping Y also negates the Y derivatives.
  887. SpvId flipped = this->nextId();
  888. this->writeInstruction(SpvOpFNegate, this->getType(c.fType), flipped, result, out);
  889. this->writePrecisionModifier(c.fType, flipped);
  890. return flipped;
  891. }
  892. break;
  893. }
  894. case kClamp_SpecialIntrinsic: {
  895. std::vector<SpvId> args = this->vectorize(c.fArguments, out);
  896. SkASSERT(args.size() == 3);
  897. this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
  898. GLSLstd450UClamp, args, out);
  899. break;
  900. }
  901. case kMax_SpecialIntrinsic: {
  902. std::vector<SpvId> args = this->vectorize(c.fArguments, out);
  903. SkASSERT(args.size() == 2);
  904. this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
  905. GLSLstd450UMax, args, out);
  906. break;
  907. }
  908. case kMin_SpecialIntrinsic: {
  909. std::vector<SpvId> args = this->vectorize(c.fArguments, out);
  910. SkASSERT(args.size() == 2);
  911. this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
  912. GLSLstd450UMin, args, out);
  913. break;
  914. }
  915. case kMix_SpecialIntrinsic: {
  916. std::vector<SpvId> args = this->vectorize(c.fArguments, out);
  917. SkASSERT(args.size() == 3);
  918. this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
  919. SpvOpUndef, args, out);
  920. break;
  921. }
  922. case kSaturate_SpecialIntrinsic: {
  923. SkASSERT(c.fArguments.size() == 1);
  924. std::vector<std::unique_ptr<Expression>> finalArgs;
  925. finalArgs.push_back(c.fArguments[0]->clone());
  926. finalArgs.emplace_back(new FloatLiteral(fContext, -1, 0));
  927. finalArgs.emplace_back(new FloatLiteral(fContext, -1, 1));
  928. std::vector<SpvId> spvArgs = this->vectorize(finalArgs, out);
  929. this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
  930. GLSLstd450UClamp, spvArgs, out);
  931. break;
  932. }
  933. }
  934. return result;
  935. }
  936. SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
  937. const auto& entry = fFunctionMap.find(&c.fFunction);
  938. if (entry == fFunctionMap.end()) {
  939. return this->writeIntrinsicCall(c, out);
  940. }
  941. // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
  942. std::vector<std::tuple<SpvId, const Type*, std::unique_ptr<LValue>>> lvalues;
  943. std::vector<SpvId> arguments;
  944. for (size_t i = 0; i < c.fArguments.size(); i++) {
  945. // id of temporary variable that we will use to hold this argument, or 0 if it is being
  946. // passed directly
  947. SpvId tmpVar;
  948. // if we need a temporary var to store this argument, this is the value to store in the var
  949. SpvId tmpValueId;
  950. if (is_out(*c.fFunction.fParameters[i])) {
  951. std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
  952. SpvId ptr = lv->getPointer();
  953. if (ptr) {
  954. arguments.push_back(ptr);
  955. continue;
  956. } else {
  957. // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
  958. // copy it into a temp, call the function, read the value out of the temp, and then
  959. // update the lvalue.
  960. tmpValueId = lv->load(out);
  961. tmpVar = this->nextId();
  962. lvalues.push_back(std::make_tuple(tmpVar, &c.fArguments[i]->fType, std::move(lv)));
  963. }
  964. } else {
  965. // see getFunctionType for an explanation of why we're always using pointer parameters
  966. tmpValueId = this->writeExpression(*c.fArguments[i], out);
  967. tmpVar = this->nextId();
  968. }
  969. this->writeInstruction(SpvOpVariable,
  970. this->getPointerType(c.fArguments[i]->fType,
  971. SpvStorageClassFunction),
  972. tmpVar,
  973. SpvStorageClassFunction,
  974. fVariableBuffer);
  975. this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
  976. arguments.push_back(tmpVar);
  977. }
  978. SpvId result = this->nextId();
  979. this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
  980. this->writeWord(this->getType(c.fType), out);
  981. this->writeWord(result, out);
  982. this->writeWord(entry->second, out);
  983. for (SpvId id : arguments) {
  984. this->writeWord(id, out);
  985. }
  986. // now that the call is complete, we may need to update some lvalues with the new values of out
  987. // arguments
  988. for (const auto& tuple : lvalues) {
  989. SpvId load = this->nextId();
  990. this->writeInstruction(SpvOpLoad, getType(*std::get<1>(tuple)), load, std::get<0>(tuple),
  991. out);
  992. this->writePrecisionModifier(*std::get<1>(tuple), load);
  993. std::get<2>(tuple)->store(load, out);
  994. }
  995. return result;
  996. }
  997. SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
  998. SkASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
  999. SpvId result = this->nextId();
  1000. std::vector<SpvId> arguments;
  1001. for (size_t i = 0; i < c.fArguments.size(); i++) {
  1002. arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
  1003. }
  1004. SpvId type = this->getType(c.fType);
  1005. if (c.fArguments.size() == 1) {
  1006. // with a single argument, a vector will have all of its entries equal to the argument
  1007. this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
  1008. this->writeWord(type, fConstantBuffer);
  1009. this->writeWord(result, fConstantBuffer);
  1010. for (int i = 0; i < c.fType.columns(); i++) {
  1011. this->writeWord(arguments[0], fConstantBuffer);
  1012. }
  1013. } else {
  1014. this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
  1015. fConstantBuffer);
  1016. this->writeWord(type, fConstantBuffer);
  1017. this->writeWord(result, fConstantBuffer);
  1018. for (SpvId id : arguments) {
  1019. this->writeWord(id, fConstantBuffer);
  1020. }
  1021. }
  1022. return result;
  1023. }
  1024. SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
  1025. SkASSERT(c.fType.isFloat());
  1026. SkASSERT(c.fArguments.size() == 1);
  1027. SkASSERT(c.fArguments[0]->fType.isNumber());
  1028. SpvId result = this->nextId();
  1029. SpvId parameter = this->writeExpression(*c.fArguments[0], out);
  1030. if (c.fArguments[0]->fType.isSigned()) {
  1031. this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
  1032. out);
  1033. } else {
  1034. SkASSERT(c.fArguments[0]->fType.isUnsigned());
  1035. this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
  1036. out);
  1037. }
  1038. return result;
  1039. }
  1040. SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
  1041. SkASSERT(c.fType.isSigned());
  1042. SkASSERT(c.fArguments.size() == 1);
  1043. SkASSERT(c.fArguments[0]->fType.isNumber());
  1044. SpvId result = this->nextId();
  1045. SpvId parameter = this->writeExpression(*c.fArguments[0], out);
  1046. if (c.fArguments[0]->fType.isFloat()) {
  1047. this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
  1048. out);
  1049. }
  1050. else {
  1051. SkASSERT(c.fArguments[0]->fType.isUnsigned());
  1052. this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
  1053. out);
  1054. }
  1055. return result;
  1056. }
  1057. SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
  1058. SkASSERT(c.fType.isUnsigned());
  1059. SkASSERT(c.fArguments.size() == 1);
  1060. SkASSERT(c.fArguments[0]->fType.isNumber());
  1061. SpvId result = this->nextId();
  1062. SpvId parameter = this->writeExpression(*c.fArguments[0], out);
  1063. if (c.fArguments[0]->fType.isFloat()) {
  1064. this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
  1065. out);
  1066. } else {
  1067. SkASSERT(c.fArguments[0]->fType.isSigned());
  1068. this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
  1069. out);
  1070. }
  1071. return result;
  1072. }
  1073. void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
  1074. OutputStream& out) {
  1075. FloatLiteral zero(fContext, -1, 0);
  1076. SpvId zeroId = this->writeFloatLiteral(zero);
  1077. std::vector<SpvId> columnIds;
  1078. for (int column = 0; column < type.columns(); column++) {
  1079. this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
  1080. out);
  1081. this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
  1082. out);
  1083. SpvId columnId = this->nextId();
  1084. this->writeWord(columnId, out);
  1085. columnIds.push_back(columnId);
  1086. for (int row = 0; row < type.columns(); row++) {
  1087. this->writeWord(row == column ? diagonal : zeroId, out);
  1088. }
  1089. this->writePrecisionModifier(type, columnId);
  1090. }
  1091. this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
  1092. out);
  1093. this->writeWord(this->getType(type), out);
  1094. this->writeWord(id, out);
  1095. for (SpvId id : columnIds) {
  1096. this->writeWord(id, out);
  1097. }
  1098. this->writePrecisionModifier(type, id);
  1099. }
  1100. void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
  1101. const Type& dstType, OutputStream& out) {
  1102. SkASSERT(srcType.kind() == Type::kMatrix_Kind);
  1103. SkASSERT(dstType.kind() == Type::kMatrix_Kind);
  1104. SkASSERT(srcType.componentType() == dstType.componentType());
  1105. SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
  1106. srcType.rows(),
  1107. 1));
  1108. SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
  1109. dstType.rows(),
  1110. 1));
  1111. SpvId zeroId;
  1112. if (dstType.componentType() == *fContext.fFloat_Type) {
  1113. FloatLiteral zero(fContext, -1, 0.0);
  1114. zeroId = this->writeFloatLiteral(zero);
  1115. } else if (dstType.componentType() == *fContext.fInt_Type) {
  1116. IntLiteral zero(fContext, -1, 0);
  1117. zeroId = this->writeIntLiteral(zero);
  1118. } else {
  1119. ABORT("unsupported matrix component type");
  1120. }
  1121. SpvId zeroColumn = 0;
  1122. SpvId columns[4];
  1123. for (int i = 0; i < dstType.columns(); i++) {
  1124. if (i < srcType.columns()) {
  1125. // we're still inside the src matrix, copy the column
  1126. SpvId srcColumn = this->nextId();
  1127. this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
  1128. this->writePrecisionModifier(dstType, srcColumn);
  1129. SpvId dstColumn;
  1130. if (srcType.rows() == dstType.rows()) {
  1131. // columns are equal size, don't need to do anything
  1132. dstColumn = srcColumn;
  1133. }
  1134. else if (dstType.rows() > srcType.rows()) {
  1135. // dst column is bigger, need to zero-pad it
  1136. dstColumn = this->nextId();
  1137. int delta = dstType.rows() - srcType.rows();
  1138. this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
  1139. this->writeWord(dstColumnType, out);
  1140. this->writeWord(dstColumn, out);
  1141. this->writeWord(srcColumn, out);
  1142. for (int i = 0; i < delta; ++i) {
  1143. this->writeWord(zeroId, out);
  1144. }
  1145. this->writePrecisionModifier(dstType, dstColumn);
  1146. }
  1147. else {
  1148. // dst column is smaller, need to swizzle the src column
  1149. dstColumn = this->nextId();
  1150. int count = dstType.rows();
  1151. this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
  1152. this->writeWord(dstColumnType, out);
  1153. this->writeWord(dstColumn, out);
  1154. this->writeWord(srcColumn, out);
  1155. this->writeWord(srcColumn, out);
  1156. for (int i = 0; i < count; i++) {
  1157. this->writeWord(i, out);
  1158. }
  1159. this->writePrecisionModifier(dstType, dstColumn);
  1160. }
  1161. columns[i] = dstColumn;
  1162. } else {
  1163. // we're past the end of the src matrix, need a vector of zeroes
  1164. if (!zeroColumn) {
  1165. zeroColumn = this->nextId();
  1166. this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
  1167. this->writeWord(dstColumnType, out);
  1168. this->writeWord(zeroColumn, out);
  1169. for (int i = 0; i < dstType.rows(); ++i) {
  1170. this->writeWord(zeroId, out);
  1171. }
  1172. this->writePrecisionModifier(dstType, zeroColumn);
  1173. }
  1174. columns[i] = zeroColumn;
  1175. }
  1176. }
  1177. this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
  1178. this->writeWord(this->getType(dstType), out);
  1179. this->writeWord(id, out);
  1180. for (int i = 0; i < dstType.columns(); i++) {
  1181. this->writeWord(columns[i], out);
  1182. }
  1183. this->writePrecisionModifier(dstType, id);
  1184. }
  1185. void SPIRVCodeGenerator::addColumnEntry(SpvId columnType, Precision precision,
  1186. std::vector<SpvId>* currentColumn,
  1187. std::vector<SpvId>* columnIds,
  1188. int* currentCount, int rows, SpvId entry,
  1189. OutputStream& out) {
  1190. SkASSERT(*currentCount < rows);
  1191. ++(*currentCount);
  1192. currentColumn->push_back(entry);
  1193. if (*currentCount == rows) {
  1194. *currentCount = 0;
  1195. this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn->size(), out);
  1196. this->writeWord(columnType, out);
  1197. SpvId columnId = this->nextId();
  1198. this->writeWord(columnId, out);
  1199. columnIds->push_back(columnId);
  1200. for (SpvId id : *currentColumn) {
  1201. this->writeWord(id, out);
  1202. }
  1203. currentColumn->clear();
  1204. this->writePrecisionModifier(precision, columnId);
  1205. }
  1206. }
  1207. SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
  1208. SkASSERT(c.fType.kind() == Type::kMatrix_Kind);
  1209. // go ahead and write the arguments so we don't try to write new instructions in the middle of
  1210. // an instruction
  1211. std::vector<SpvId> arguments;
  1212. for (size_t i = 0; i < c.fArguments.size(); i++) {
  1213. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  1214. }
  1215. SpvId result = this->nextId();
  1216. int rows = c.fType.rows();
  1217. int columns = c.fType.columns();
  1218. if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
  1219. this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
  1220. } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
  1221. this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
  1222. } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
  1223. SkASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
  1224. SkASSERT(c.fArguments[0]->fType.columns() == 4);
  1225. SpvId componentType = this->getType(c.fType.componentType());
  1226. SpvId v[4];
  1227. for (int i = 0; i < 4; ++i) {
  1228. v[i] = this->nextId();
  1229. this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
  1230. }
  1231. SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
  1232. SpvId column1 = this->nextId();
  1233. this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
  1234. SpvId column2 = this->nextId();
  1235. this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
  1236. this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
  1237. column2, out);
  1238. } else {
  1239. SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, rows, 1));
  1240. std::vector<SpvId> columnIds;
  1241. // ids of vectors and scalars we have written to the current column so far
  1242. std::vector<SpvId> currentColumn;
  1243. // the total number of scalars represented by currentColumn's entries
  1244. int currentCount = 0;
  1245. Precision precision = c.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
  1246. for (size_t i = 0; i < arguments.size(); i++) {
  1247. if (currentCount == 0 && c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
  1248. c.fArguments[i]->fType.columns() == c.fType.rows()) {
  1249. // this is a complete column by itself
  1250. columnIds.push_back(arguments[i]);
  1251. } else {
  1252. if (c.fArguments[i]->fType.columns() == 1) {
  1253. this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
  1254. &currentCount, rows, arguments[i], out);
  1255. } else {
  1256. SpvId componentType = this->getType(c.fArguments[i]->fType.componentType());
  1257. for (int j = 0; j < c.fArguments[i]->fType.columns(); ++j) {
  1258. SpvId swizzle = this->nextId();
  1259. this->writeInstruction(SpvOpCompositeExtract, componentType, swizzle,
  1260. arguments[i], j, out);
  1261. this->addColumnEntry(columnType, precision, &currentColumn, &columnIds,
  1262. &currentCount, rows, swizzle, out);
  1263. }
  1264. }
  1265. }
  1266. }
  1267. SkASSERT(columnIds.size() == (size_t) columns);
  1268. this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
  1269. this->writeWord(this->getType(c.fType), out);
  1270. this->writeWord(result, out);
  1271. for (SpvId id : columnIds) {
  1272. this->writeWord(id, out);
  1273. }
  1274. }
  1275. this->writePrecisionModifier(c.fType, result);
  1276. return result;
  1277. }
  1278. SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
  1279. SkASSERT(c.fType.kind() == Type::kVector_Kind);
  1280. if (c.isConstant()) {
  1281. return this->writeConstantVector(c);
  1282. }
  1283. // go ahead and write the arguments so we don't try to write new instructions in the middle of
  1284. // an instruction
  1285. std::vector<SpvId> arguments;
  1286. for (size_t i = 0; i < c.fArguments.size(); i++) {
  1287. if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
  1288. // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
  1289. // extract the components and convert them in that case manually. On top of that,
  1290. // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
  1291. // doesn't handle vector arguments at all, so we always extract vector components and
  1292. // pass them into OpCreateComposite individually.
  1293. SpvId vec = this->writeExpression(*c.fArguments[i], out);
  1294. SpvOp_ op = SpvOpUndef;
  1295. const Type& src = c.fArguments[i]->fType.componentType();
  1296. const Type& dst = c.fType.componentType();
  1297. if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
  1298. if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
  1299. if (c.fArguments.size() == 1) {
  1300. return vec;
  1301. }
  1302. } else if (src == *fContext.fInt_Type ||
  1303. src == *fContext.fShort_Type ||
  1304. src == *fContext.fByte_Type) {
  1305. op = SpvOpConvertSToF;
  1306. } else if (src == *fContext.fUInt_Type ||
  1307. src == *fContext.fUShort_Type ||
  1308. src == *fContext.fUByte_Type) {
  1309. op = SpvOpConvertUToF;
  1310. } else {
  1311. SkASSERT(false);
  1312. }
  1313. } else if (dst == *fContext.fInt_Type ||
  1314. dst == *fContext.fShort_Type ||
  1315. dst == *fContext.fByte_Type) {
  1316. if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
  1317. op = SpvOpConvertFToS;
  1318. } else if (src == *fContext.fInt_Type ||
  1319. src == *fContext.fShort_Type ||
  1320. src == *fContext.fByte_Type) {
  1321. if (c.fArguments.size() == 1) {
  1322. return vec;
  1323. }
  1324. } else if (src == *fContext.fUInt_Type ||
  1325. src == *fContext.fUShort_Type ||
  1326. src == *fContext.fUByte_Type) {
  1327. op = SpvOpBitcast;
  1328. } else {
  1329. SkASSERT(false);
  1330. }
  1331. } else if (dst == *fContext.fUInt_Type ||
  1332. dst == *fContext.fUShort_Type ||
  1333. dst == *fContext.fUByte_Type) {
  1334. if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
  1335. op = SpvOpConvertFToS;
  1336. } else if (src == *fContext.fInt_Type ||
  1337. src == *fContext.fShort_Type ||
  1338. src == *fContext.fByte_Type) {
  1339. op = SpvOpBitcast;
  1340. } else if (src == *fContext.fUInt_Type ||
  1341. src == *fContext.fUShort_Type ||
  1342. src == *fContext.fUByte_Type) {
  1343. if (c.fArguments.size() == 1) {
  1344. return vec;
  1345. }
  1346. } else {
  1347. SkASSERT(false);
  1348. }
  1349. }
  1350. for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
  1351. SpvId swizzle = this->nextId();
  1352. this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
  1353. out);
  1354. if (op != SpvOpUndef) {
  1355. SpvId cast = this->nextId();
  1356. this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
  1357. arguments.push_back(cast);
  1358. } else {
  1359. arguments.push_back(swizzle);
  1360. }
  1361. }
  1362. } else {
  1363. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  1364. }
  1365. }
  1366. SpvId result = this->nextId();
  1367. if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
  1368. this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
  1369. this->writeWord(this->getType(c.fType), out);
  1370. this->writeWord(result, out);
  1371. for (int i = 0; i < c.fType.columns(); i++) {
  1372. this->writeWord(arguments[0], out);
  1373. }
  1374. } else {
  1375. SkASSERT(arguments.size() > 1);
  1376. this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
  1377. this->writeWord(this->getType(c.fType), out);
  1378. this->writeWord(result, out);
  1379. for (SpvId id : arguments) {
  1380. this->writeWord(id, out);
  1381. }
  1382. }
  1383. return result;
  1384. }
  1385. SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
  1386. SkASSERT(c.fType.kind() == Type::kArray_Kind);
  1387. // go ahead and write the arguments so we don't try to write new instructions in the middle of
  1388. // an instruction
  1389. std::vector<SpvId> arguments;
  1390. for (size_t i = 0; i < c.fArguments.size(); i++) {
  1391. arguments.push_back(this->writeExpression(*c.fArguments[i], out));
  1392. }
  1393. SpvId result = this->nextId();
  1394. this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
  1395. this->writeWord(this->getType(c.fType), out);
  1396. this->writeWord(result, out);
  1397. for (SpvId id : arguments) {
  1398. this->writeWord(id, out);
  1399. }
  1400. return result;
  1401. }
  1402. SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
  1403. if (c.fArguments.size() == 1 &&
  1404. this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
  1405. return this->writeExpression(*c.fArguments[0], out);
  1406. }
  1407. if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
  1408. return this->writeFloatConstructor(c, out);
  1409. } else if (c.fType == *fContext.fInt_Type ||
  1410. c.fType == *fContext.fShort_Type ||
  1411. c.fType == *fContext.fByte_Type) {
  1412. return this->writeIntConstructor(c, out);
  1413. } else if (c.fType == *fContext.fUInt_Type ||
  1414. c.fType == *fContext.fUShort_Type ||
  1415. c.fType == *fContext.fUByte_Type) {
  1416. return this->writeUIntConstructor(c, out);
  1417. }
  1418. switch (c.fType.kind()) {
  1419. case Type::kVector_Kind:
  1420. return this->writeVectorConstructor(c, out);
  1421. case Type::kMatrix_Kind:
  1422. return this->writeMatrixConstructor(c, out);
  1423. case Type::kArray_Kind:
  1424. return this->writeArrayConstructor(c, out);
  1425. default:
  1426. ABORT("unsupported constructor: %s", c.description().c_str());
  1427. }
  1428. }
  1429. SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
  1430. if (modifiers.fFlags & Modifiers::kIn_Flag) {
  1431. SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
  1432. return SpvStorageClassInput;
  1433. } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
  1434. SkASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
  1435. return SpvStorageClassOutput;
  1436. } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
  1437. if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
  1438. return SpvStorageClassPushConstant;
  1439. }
  1440. return SpvStorageClassUniform;
  1441. } else {
  1442. return SpvStorageClassFunction;
  1443. }
  1444. }
  1445. SpvStorageClass_ get_storage_class(const Expression& expr) {
  1446. switch (expr.fKind) {
  1447. case Expression::kVariableReference_Kind: {
  1448. const Variable& var = ((VariableReference&) expr).fVariable;
  1449. if (var.fStorage != Variable::kGlobal_Storage) {
  1450. return SpvStorageClassFunction;
  1451. }
  1452. SpvStorageClass_ result = get_storage_class(var.fModifiers);
  1453. if (result == SpvStorageClassFunction) {
  1454. result = SpvStorageClassPrivate;
  1455. }
  1456. return result;
  1457. }
  1458. case Expression::kFieldAccess_Kind:
  1459. return get_storage_class(*((FieldAccess&) expr).fBase);
  1460. case Expression::kIndex_Kind:
  1461. return get_storage_class(*((IndexExpression&) expr).fBase);
  1462. default:
  1463. return SpvStorageClassFunction;
  1464. }
  1465. }
  1466. std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
  1467. std::vector<SpvId> chain;
  1468. switch (expr.fKind) {
  1469. case Expression::kIndex_Kind: {
  1470. IndexExpression& indexExpr = (IndexExpression&) expr;
  1471. chain = this->getAccessChain(*indexExpr.fBase, out);
  1472. chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
  1473. break;
  1474. }
  1475. case Expression::kFieldAccess_Kind: {
  1476. FieldAccess& fieldExpr = (FieldAccess&) expr;
  1477. chain = this->getAccessChain(*fieldExpr.fBase, out);
  1478. IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
  1479. chain.push_back(this->writeIntLiteral(index));
  1480. break;
  1481. }
  1482. default: {
  1483. SpvId id = this->getLValue(expr, out)->getPointer();
  1484. SkASSERT(id != 0);
  1485. chain.push_back(id);
  1486. }
  1487. }
  1488. return chain;
  1489. }
  1490. class PointerLValue : public SPIRVCodeGenerator::LValue {
  1491. public:
  1492. PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type,
  1493. SPIRVCodeGenerator::Precision precision)
  1494. : fGen(gen)
  1495. , fPointer(pointer)
  1496. , fType(type)
  1497. , fPrecision(precision) {}
  1498. virtual SpvId getPointer() override {
  1499. return fPointer;
  1500. }
  1501. virtual SpvId load(OutputStream& out) override {
  1502. SpvId result = fGen.nextId();
  1503. fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
  1504. fGen.writePrecisionModifier(fPrecision, result);
  1505. return result;
  1506. }
  1507. virtual void store(SpvId value, OutputStream& out) override {
  1508. fGen.writeInstruction(SpvOpStore, fPointer, value, out);
  1509. }
  1510. private:
  1511. SPIRVCodeGenerator& fGen;
  1512. const SpvId fPointer;
  1513. const SpvId fType;
  1514. const SPIRVCodeGenerator::Precision fPrecision;
  1515. };
  1516. class SwizzleLValue : public SPIRVCodeGenerator::LValue {
  1517. public:
  1518. SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
  1519. const Type& baseType, const Type& swizzleType,
  1520. SPIRVCodeGenerator::Precision precision)
  1521. : fGen(gen)
  1522. , fVecPointer(vecPointer)
  1523. , fComponents(components)
  1524. , fBaseType(baseType)
  1525. , fSwizzleType(swizzleType)
  1526. , fPrecision(precision) {}
  1527. virtual SpvId getPointer() override {
  1528. return 0;
  1529. }
  1530. virtual SpvId load(OutputStream& out) override {
  1531. SpvId base = fGen.nextId();
  1532. fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
  1533. fGen.writePrecisionModifier(fPrecision, base);
  1534. SpvId result = fGen.nextId();
  1535. fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
  1536. fGen.writeWord(fGen.getType(fSwizzleType), out);
  1537. fGen.writeWord(result, out);
  1538. fGen.writeWord(base, out);
  1539. fGen.writeWord(base, out);
  1540. for (int component : fComponents) {
  1541. fGen.writeWord(component, out);
  1542. }
  1543. fGen.writePrecisionModifier(fPrecision, result);
  1544. return result;
  1545. }
  1546. virtual void store(SpvId value, OutputStream& out) override {
  1547. // use OpVectorShuffle to mix and match the vector components. We effectively create
  1548. // a virtual vector out of the concatenation of the left and right vectors, and then
  1549. // select components from this virtual vector to make the result vector. For
  1550. // instance, given:
  1551. // float3L = ...;
  1552. // float3R = ...;
  1553. // L.xz = R.xy;
  1554. // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
  1555. // our result vector to look like (R.x, L.y, R.y), so we need to select indices
  1556. // (3, 1, 4).
  1557. SpvId base = fGen.nextId();
  1558. fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
  1559. SpvId shuffle = fGen.nextId();
  1560. fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
  1561. fGen.writeWord(fGen.getType(fBaseType), out);
  1562. fGen.writeWord(shuffle, out);
  1563. fGen.writeWord(base, out);
  1564. fGen.writeWord(value, out);
  1565. for (int i = 0; i < fBaseType.columns(); i++) {
  1566. // current offset into the virtual vector, defaults to pulling the unmodified
  1567. // value from the left side
  1568. int offset = i;
  1569. // check to see if we are writing this component
  1570. for (size_t j = 0; j < fComponents.size(); j++) {
  1571. if (fComponents[j] == i) {
  1572. // we're writing to this component, so adjust the offset to pull from
  1573. // the correct component of the right side instead of preserving the
  1574. // value from the left
  1575. offset = (int) (j + fBaseType.columns());
  1576. break;
  1577. }
  1578. }
  1579. fGen.writeWord(offset, out);
  1580. }
  1581. fGen.writePrecisionModifier(fPrecision, shuffle);
  1582. fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
  1583. }
  1584. private:
  1585. SPIRVCodeGenerator& fGen;
  1586. const SpvId fVecPointer;
  1587. const std::vector<int>& fComponents;
  1588. const Type& fBaseType;
  1589. const Type& fSwizzleType;
  1590. const SPIRVCodeGenerator::Precision fPrecision;
  1591. };
  1592. std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
  1593. OutputStream& out) {
  1594. Precision precision = expr.fType.highPrecision() ? Precision::kHigh : Precision::kLow;
  1595. switch (expr.fKind) {
  1596. case Expression::kVariableReference_Kind: {
  1597. SpvId type;
  1598. const Variable& var = ((VariableReference&) expr).fVariable;
  1599. if (var.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
  1600. type = this->getType(Type("sk_in", Type::kArray_Kind, var.fType.componentType(),
  1601. fSkInCount));
  1602. } else {
  1603. type = this->getType(expr.fType);
  1604. }
  1605. auto entry = fVariableMap.find(&var);
  1606. SkASSERT(entry != fVariableMap.end());
  1607. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(*this,
  1608. entry->second,
  1609. type,
  1610. precision));
  1611. }
  1612. case Expression::kIndex_Kind: // fall through
  1613. case Expression::kFieldAccess_Kind: {
  1614. std::vector<SpvId> chain = this->getAccessChain(expr, out);
  1615. SpvId member = this->nextId();
  1616. this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
  1617. this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
  1618. this->writeWord(member, out);
  1619. for (SpvId idx : chain) {
  1620. this->writeWord(idx, out);
  1621. }
  1622. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
  1623. *this,
  1624. member,
  1625. this->getType(expr.fType),
  1626. precision));
  1627. }
  1628. case Expression::kSwizzle_Kind: {
  1629. Swizzle& swizzle = (Swizzle&) expr;
  1630. size_t count = swizzle.fComponents.size();
  1631. SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
  1632. SkASSERT(base);
  1633. if (count == 1) {
  1634. IntLiteral index(fContext, -1, swizzle.fComponents[0]);
  1635. SpvId member = this->nextId();
  1636. this->writeInstruction(SpvOpAccessChain,
  1637. this->getPointerType(swizzle.fType,
  1638. get_storage_class(*swizzle.fBase)),
  1639. member,
  1640. base,
  1641. this->writeIntLiteral(index),
  1642. out);
  1643. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
  1644. *this,
  1645. member,
  1646. this->getType(expr.fType),
  1647. precision));
  1648. } else {
  1649. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
  1650. *this,
  1651. base,
  1652. swizzle.fComponents,
  1653. swizzle.fBase->fType,
  1654. expr.fType,
  1655. precision));
  1656. }
  1657. }
  1658. case Expression::kTernary_Kind: {
  1659. TernaryExpression& t = (TernaryExpression&) expr;
  1660. SpvId test = this->writeExpression(*t.fTest, out);
  1661. SpvId end = this->nextId();
  1662. SpvId ifTrueLabel = this->nextId();
  1663. SpvId ifFalseLabel = this->nextId();
  1664. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  1665. this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
  1666. this->writeLabel(ifTrueLabel, out);
  1667. SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
  1668. SkASSERT(ifTrue);
  1669. this->writeInstruction(SpvOpBranch, end, out);
  1670. ifTrueLabel = fCurrentBlock;
  1671. SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
  1672. SkASSERT(ifFalse);
  1673. ifFalseLabel = fCurrentBlock;
  1674. this->writeInstruction(SpvOpBranch, end, out);
  1675. SpvId result = this->nextId();
  1676. this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
  1677. ifTrueLabel, ifFalse, ifFalseLabel, out);
  1678. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
  1679. *this,
  1680. result,
  1681. this->getType(expr.fType),
  1682. precision));
  1683. }
  1684. default:
  1685. // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
  1686. // to the need to store values in temporary variables during function calls (see
  1687. // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
  1688. // caught by IRGenerator
  1689. SpvId result = this->nextId();
  1690. SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
  1691. this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
  1692. fVariableBuffer);
  1693. this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
  1694. return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
  1695. *this,
  1696. result,
  1697. this->getType(expr.fType),
  1698. precision));
  1699. }
  1700. }
  1701. SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
  1702. SpvId result = this->nextId();
  1703. auto entry = fVariableMap.find(&ref.fVariable);
  1704. SkASSERT(entry != fVariableMap.end());
  1705. SpvId var = entry->second;
  1706. this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
  1707. this->writePrecisionModifier(ref.fVariable.fType, result);
  1708. if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
  1709. fProgram.fSettings.fFlipY) {
  1710. // need to remap to a top-left coordinate system
  1711. if (fRTHeightStructId == (SpvId) -1) {
  1712. // height variable hasn't been written yet
  1713. std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
  1714. SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
  1715. std::vector<Type::Field> fields;
  1716. fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
  1717. StringFragment name("sksl_synthetic_uniforms");
  1718. Type intfStruct(-1, name, fields);
  1719. Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
  1720. Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
  1721. Layout::CType::kDefault);
  1722. Variable* intfVar = (Variable*) fSynthetics.takeOwnership(std::unique_ptr<Symbol>(
  1723. new Variable(-1,
  1724. Modifiers(layout, Modifiers::kUniform_Flag),
  1725. name,
  1726. intfStruct,
  1727. Variable::kGlobal_Storage)));
  1728. InterfaceBlock intf(-1, intfVar, name, String(""),
  1729. std::vector<std::unique_ptr<Expression>>(), st);
  1730. fRTHeightStructId = this->writeInterfaceBlock(intf);
  1731. fRTHeightFieldIndex = 0;
  1732. }
  1733. SkASSERT(fRTHeightFieldIndex != (SpvId) -1);
  1734. // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, gl_FragCoord.w)
  1735. SpvId xId = this->nextId();
  1736. this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
  1737. result, 0, out);
  1738. IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
  1739. SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
  1740. SpvId heightPtr = this->nextId();
  1741. this->writeOpCode(SpvOpAccessChain, 5, out);
  1742. this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
  1743. this->writeWord(heightPtr, out);
  1744. this->writeWord(fRTHeightStructId, out);
  1745. this->writeWord(fieldIndexId, out);
  1746. SpvId heightRead = this->nextId();
  1747. this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
  1748. heightPtr, out);
  1749. SpvId rawYId = this->nextId();
  1750. this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
  1751. result, 1, out);
  1752. SpvId flippedYId = this->nextId();
  1753. this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
  1754. heightRead, rawYId, out);
  1755. FloatLiteral zero(fContext, -1, 0.0);
  1756. SpvId zeroId = writeFloatLiteral(zero);
  1757. FloatLiteral one(fContext, -1, 1.0);
  1758. SpvId wId = this->nextId();
  1759. this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), wId,
  1760. result, 3, out);
  1761. SpvId flipped = this->nextId();
  1762. this->writeOpCode(SpvOpCompositeConstruct, 7, out);
  1763. this->writeWord(this->getType(*fContext.fFloat4_Type), out);
  1764. this->writeWord(flipped, out);
  1765. this->writeWord(xId, out);
  1766. this->writeWord(flippedYId, out);
  1767. this->writeWord(zeroId, out);
  1768. this->writeWord(wId, out);
  1769. return flipped;
  1770. }
  1771. if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_CLOCKWISE_BUILTIN &&
  1772. !fProgram.fSettings.fFlipY) {
  1773. // FrontFacing in Vulkan is defined in terms of a top-down render target. In skia, we use
  1774. // the default convention of "counter-clockwise face is front".
  1775. SpvId inverse = this->nextId();
  1776. this->writeInstruction(SpvOpLogicalNot, this->getType(*fContext.fBool_Type), inverse,
  1777. result, out);
  1778. return inverse;
  1779. }
  1780. return result;
  1781. }
  1782. SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
  1783. if (expr.fBase->fType.kind() == Type::Kind::kVector_Kind) {
  1784. SpvId base = this->writeExpression(*expr.fBase, out);
  1785. SpvId index = this->writeExpression(*expr.fIndex, out);
  1786. SpvId result = this->nextId();
  1787. this->writeInstruction(SpvOpVectorExtractDynamic, this->getType(expr.fType), result, base,
  1788. index, out);
  1789. return result;
  1790. }
  1791. return getLValue(expr, out)->load(out);
  1792. }
  1793. SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
  1794. return getLValue(f, out)->load(out);
  1795. }
  1796. SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
  1797. SpvId base = this->writeExpression(*swizzle.fBase, out);
  1798. SpvId result = this->nextId();
  1799. size_t count = swizzle.fComponents.size();
  1800. if (count == 1) {
  1801. this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
  1802. swizzle.fComponents[0], out);
  1803. } else {
  1804. this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
  1805. this->writeWord(this->getType(swizzle.fType), out);
  1806. this->writeWord(result, out);
  1807. this->writeWord(base, out);
  1808. SpvId other;
  1809. int last = swizzle.fComponents.back();
  1810. if (last < 0) {
  1811. if (!fConstantZeroOneVector) {
  1812. FloatLiteral zero(fContext, -1, 0);
  1813. SpvId zeroId = this->writeFloatLiteral(zero);
  1814. FloatLiteral one(fContext, -1, 1);
  1815. SpvId oneId = this->writeFloatLiteral(one);
  1816. SpvId type = this->getType(*fContext.fFloat2_Type);
  1817. fConstantZeroOneVector = this->nextId();
  1818. this->writeOpCode(SpvOpConstantComposite, 5, fConstantBuffer);
  1819. this->writeWord(type, fConstantBuffer);
  1820. this->writeWord(fConstantZeroOneVector, fConstantBuffer);
  1821. this->writeWord(zeroId, fConstantBuffer);
  1822. this->writeWord(oneId, fConstantBuffer);
  1823. }
  1824. other = fConstantZeroOneVector;
  1825. } else {
  1826. other = base;
  1827. }
  1828. this->writeWord(other, out);
  1829. for (int component : swizzle.fComponents) {
  1830. if (component == SKSL_SWIZZLE_0) {
  1831. this->writeWord(swizzle.fBase->fType.columns(), out);
  1832. } else if (component == SKSL_SWIZZLE_1) {
  1833. this->writeWord(swizzle.fBase->fType.columns() + 1, out);
  1834. } else {
  1835. this->writeWord(component, out);
  1836. }
  1837. }
  1838. }
  1839. return result;
  1840. }
  1841. SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
  1842. const Type& operandType, SpvId lhs,
  1843. SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
  1844. SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
  1845. SpvId result = this->nextId();
  1846. if (is_float(fContext, operandType)) {
  1847. this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
  1848. } else if (is_signed(fContext, operandType)) {
  1849. this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
  1850. } else if (is_unsigned(fContext, operandType)) {
  1851. this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
  1852. } else if (operandType == *fContext.fBool_Type) {
  1853. this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
  1854. return result; // skip RelaxedPrecision check
  1855. } else {
  1856. ABORT("invalid operandType: %s", operandType.description().c_str());
  1857. }
  1858. if (getActualType(resultType) == operandType && !resultType.highPrecision()) {
  1859. this->writeInstruction(SpvOpDecorate, result, SpvDecorationRelaxedPrecision,
  1860. fDecorationBuffer);
  1861. }
  1862. return result;
  1863. }
  1864. SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SpvOp op,
  1865. OutputStream& out) {
  1866. if (operandType.kind() == Type::kVector_Kind) {
  1867. SpvId result = this->nextId();
  1868. this->writeInstruction(op, this->getType(*fContext.fBool_Type), result, id, out);
  1869. return result;
  1870. }
  1871. return id;
  1872. }
  1873. SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
  1874. SpvOp_ floatOperator, SpvOp_ intOperator,
  1875. SpvOp_ vectorMergeOperator, SpvOp_ mergeOperator,
  1876. OutputStream& out) {
  1877. SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
  1878. SkASSERT(operandType.kind() == Type::kMatrix_Kind);
  1879. SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
  1880. operandType.rows(),
  1881. 1));
  1882. SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
  1883. operandType.rows(),
  1884. 1));
  1885. SpvId boolType = this->getType(*fContext.fBool_Type);
  1886. SpvId result = 0;
  1887. for (int i = 0; i < operandType.columns(); i++) {
  1888. SpvId columnL = this->nextId();
  1889. this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
  1890. SpvId columnR = this->nextId();
  1891. this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
  1892. SpvId compare = this->nextId();
  1893. this->writeInstruction(compareOp, bvecType, compare, columnL, columnR, out);
  1894. SpvId merge = this->nextId();
  1895. this->writeInstruction(vectorMergeOperator, boolType, merge, compare, out);
  1896. if (result != 0) {
  1897. SpvId next = this->nextId();
  1898. this->writeInstruction(mergeOperator, boolType, next, result, merge, out);
  1899. result = next;
  1900. }
  1901. else {
  1902. result = merge;
  1903. }
  1904. }
  1905. return result;
  1906. }
  1907. SpvId SPIRVCodeGenerator::writeComponentwiseMatrixBinary(const Type& operandType, SpvId lhs,
  1908. SpvId rhs, SpvOp_ floatOperator,
  1909. SpvOp_ intOperator,
  1910. OutputStream& out) {
  1911. SpvOp_ op = is_float(fContext, operandType) ? floatOperator : intOperator;
  1912. SkASSERT(operandType.kind() == Type::kMatrix_Kind);
  1913. SpvId columnType = this->getType(operandType.componentType().toCompound(fContext,
  1914. operandType.rows(),
  1915. 1));
  1916. SpvId columns[4];
  1917. for (int i = 0; i < operandType.columns(); i++) {
  1918. SpvId columnL = this->nextId();
  1919. this->writeInstruction(SpvOpCompositeExtract, columnType, columnL, lhs, i, out);
  1920. SpvId columnR = this->nextId();
  1921. this->writeInstruction(SpvOpCompositeExtract, columnType, columnR, rhs, i, out);
  1922. columns[i] = this->nextId();
  1923. this->writeInstruction(op, columnType, columns[i], columnL, columnR, out);
  1924. }
  1925. SpvId result = this->nextId();
  1926. this->writeOpCode(SpvOpCompositeConstruct, 3 + operandType.columns(), out);
  1927. this->writeWord(this->getType(operandType), out);
  1928. this->writeWord(result, out);
  1929. for (int i = 0; i < operandType.columns(); i++) {
  1930. this->writeWord(columns[i], out);
  1931. }
  1932. return result;
  1933. }
  1934. std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
  1935. if (type.isInteger()) {
  1936. return std::unique_ptr<Expression>(new IntLiteral(-1, 1, &type));
  1937. }
  1938. else if (type.isFloat()) {
  1939. return std::unique_ptr<Expression>(new FloatLiteral(-1, 1.0, &type));
  1940. } else {
  1941. ABORT("math is unsupported on type '%s'", type.name().c_str());
  1942. }
  1943. }
  1944. SpvId SPIRVCodeGenerator::writeBinaryExpression(const Type& leftType, SpvId lhs, Token::Kind op,
  1945. const Type& rightType, SpvId rhs,
  1946. const Type& resultType, OutputStream& out) {
  1947. Type tmp("<invalid>");
  1948. // overall type we are operating on: float2, int, uint4...
  1949. const Type* operandType;
  1950. // IR allows mismatched types in expressions (e.g. float2 * float), but they need special
  1951. // handling in SPIR-V
  1952. if (this->getActualType(leftType) != this->getActualType(rightType)) {
  1953. if (leftType.kind() == Type::kVector_Kind && rightType.isNumber()) {
  1954. if (op == Token::SLASH) {
  1955. SpvId one = this->writeExpression(*create_literal_1(fContext, rightType), out);
  1956. SpvId inverse = this->nextId();
  1957. this->writeInstruction(SpvOpFDiv, this->getType(rightType), inverse, one, rhs, out);
  1958. rhs = inverse;
  1959. op = Token::STAR;
  1960. }
  1961. if (op == Token::STAR) {
  1962. SpvId result = this->nextId();
  1963. this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
  1964. result, lhs, rhs, out);
  1965. return result;
  1966. }
  1967. // promote number to vector
  1968. SpvId vec = this->nextId();
  1969. const Type& vecType = leftType;
  1970. this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
  1971. this->writeWord(this->getType(vecType), out);
  1972. this->writeWord(vec, out);
  1973. for (int i = 0; i < vecType.columns(); i++) {
  1974. this->writeWord(rhs, out);
  1975. }
  1976. rhs = vec;
  1977. operandType = &leftType;
  1978. } else if (rightType.kind() == Type::kVector_Kind && leftType.isNumber()) {
  1979. if (op == Token::STAR) {
  1980. SpvId result = this->nextId();
  1981. this->writeInstruction(SpvOpVectorTimesScalar, this->getType(resultType),
  1982. result, rhs, lhs, out);
  1983. return result;
  1984. }
  1985. // promote number to vector
  1986. SpvId vec = this->nextId();
  1987. const Type& vecType = rightType;
  1988. this->writeOpCode(SpvOpCompositeConstruct, 3 + vecType.columns(), out);
  1989. this->writeWord(this->getType(vecType), out);
  1990. this->writeWord(vec, out);
  1991. for (int i = 0; i < vecType.columns(); i++) {
  1992. this->writeWord(lhs, out);
  1993. }
  1994. lhs = vec;
  1995. operandType = &rightType;
  1996. } else if (leftType.kind() == Type::kMatrix_Kind) {
  1997. SpvOp_ spvop;
  1998. if (rightType.kind() == Type::kMatrix_Kind) {
  1999. spvop = SpvOpMatrixTimesMatrix;
  2000. } else if (rightType.kind() == Type::kVector_Kind) {
  2001. spvop = SpvOpMatrixTimesVector;
  2002. } else {
  2003. SkASSERT(rightType.kind() == Type::kScalar_Kind);
  2004. spvop = SpvOpMatrixTimesScalar;
  2005. }
  2006. SpvId result = this->nextId();
  2007. this->writeInstruction(spvop, this->getType(resultType), result, lhs, rhs, out);
  2008. return result;
  2009. } else if (rightType.kind() == Type::kMatrix_Kind) {
  2010. SpvId result = this->nextId();
  2011. if (leftType.kind() == Type::kVector_Kind) {
  2012. this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(resultType), result,
  2013. lhs, rhs, out);
  2014. } else {
  2015. SkASSERT(leftType.kind() == Type::kScalar_Kind);
  2016. this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(resultType), result,
  2017. rhs, lhs, out);
  2018. }
  2019. return result;
  2020. } else {
  2021. SkASSERT(false);
  2022. return -1;
  2023. }
  2024. } else {
  2025. tmp = this->getActualType(leftType);
  2026. operandType = &tmp;
  2027. SkASSERT(*operandType == this->getActualType(rightType));
  2028. }
  2029. switch (op) {
  2030. case Token::EQEQ: {
  2031. if (operandType->kind() == Type::kMatrix_Kind) {
  2032. return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
  2033. SpvOpIEqual, SpvOpAll, SpvOpLogicalAnd, out);
  2034. }
  2035. SkASSERT(resultType == *fContext.fBool_Type);
  2036. const Type* tmpType;
  2037. if (operandType->kind() == Type::kVector_Kind) {
  2038. tmpType = &fContext.fBool_Type->toCompound(fContext,
  2039. operandType->columns(),
  2040. operandType->rows());
  2041. } else {
  2042. tmpType = &resultType;
  2043. }
  2044. return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
  2045. SpvOpFOrdEqual, SpvOpIEqual,
  2046. SpvOpIEqual, SpvOpLogicalEqual, out),
  2047. *operandType, SpvOpAll, out);
  2048. }
  2049. case Token::NEQ:
  2050. if (operandType->kind() == Type::kMatrix_Kind) {
  2051. return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
  2052. SpvOpINotEqual, SpvOpAny, SpvOpLogicalOr, out);
  2053. }
  2054. SkASSERT(resultType == *fContext.fBool_Type);
  2055. const Type* tmpType;
  2056. if (operandType->kind() == Type::kVector_Kind) {
  2057. tmpType = &fContext.fBool_Type->toCompound(fContext,
  2058. operandType->columns(),
  2059. operandType->rows());
  2060. } else {
  2061. tmpType = &resultType;
  2062. }
  2063. return this->foldToBool(this->writeBinaryOperation(*tmpType, *operandType, lhs, rhs,
  2064. SpvOpFOrdNotEqual, SpvOpINotEqual,
  2065. SpvOpINotEqual, SpvOpLogicalNotEqual,
  2066. out),
  2067. *operandType, SpvOpAny, out);
  2068. case Token::GT:
  2069. SkASSERT(resultType == *fContext.fBool_Type);
  2070. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
  2071. SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
  2072. SpvOpUGreaterThan, SpvOpUndef, out);
  2073. case Token::LT:
  2074. SkASSERT(resultType == *fContext.fBool_Type);
  2075. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
  2076. SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
  2077. case Token::GTEQ:
  2078. SkASSERT(resultType == *fContext.fBool_Type);
  2079. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
  2080. SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
  2081. SpvOpUGreaterThanEqual, SpvOpUndef, out);
  2082. case Token::LTEQ:
  2083. SkASSERT(resultType == *fContext.fBool_Type);
  2084. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
  2085. SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
  2086. SpvOpULessThanEqual, SpvOpUndef, out);
  2087. case Token::PLUS:
  2088. if (leftType.kind() == Type::kMatrix_Kind &&
  2089. rightType.kind() == Type::kMatrix_Kind) {
  2090. SkASSERT(leftType == rightType);
  2091. return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
  2092. SpvOpFAdd, SpvOpIAdd, out);
  2093. }
  2094. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
  2095. SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
  2096. case Token::MINUS:
  2097. if (leftType.kind() == Type::kMatrix_Kind &&
  2098. rightType.kind() == Type::kMatrix_Kind) {
  2099. SkASSERT(leftType == rightType);
  2100. return this->writeComponentwiseMatrixBinary(leftType, lhs, rhs,
  2101. SpvOpFSub, SpvOpISub, out);
  2102. }
  2103. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
  2104. SpvOpISub, SpvOpISub, SpvOpUndef, out);
  2105. case Token::STAR:
  2106. if (leftType.kind() == Type::kMatrix_Kind &&
  2107. rightType.kind() == Type::kMatrix_Kind) {
  2108. // matrix multiply
  2109. SpvId result = this->nextId();
  2110. this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
  2111. lhs, rhs, out);
  2112. return result;
  2113. }
  2114. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
  2115. SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
  2116. case Token::SLASH:
  2117. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
  2118. SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
  2119. case Token::PERCENT:
  2120. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
  2121. SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
  2122. case Token::SHL:
  2123. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
  2124. SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
  2125. SpvOpUndef, out);
  2126. case Token::SHR:
  2127. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
  2128. SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
  2129. SpvOpUndef, out);
  2130. case Token::BITWISEAND:
  2131. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
  2132. SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
  2133. case Token::BITWISEOR:
  2134. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
  2135. SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
  2136. case Token::BITWISEXOR:
  2137. return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
  2138. SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
  2139. case Token::COMMA:
  2140. return rhs;
  2141. default:
  2142. SkASSERT(false);
  2143. return -1;
  2144. }
  2145. }
  2146. SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
  2147. // handle cases where we don't necessarily evaluate both LHS and RHS
  2148. switch (b.fOperator) {
  2149. case Token::EQ: {
  2150. SpvId rhs = this->writeExpression(*b.fRight, out);
  2151. this->getLValue(*b.fLeft, out)->store(rhs, out);
  2152. return rhs;
  2153. }
  2154. case Token::LOGICALAND:
  2155. return this->writeLogicalAnd(b, out);
  2156. case Token::LOGICALOR:
  2157. return this->writeLogicalOr(b, out);
  2158. default:
  2159. break;
  2160. }
  2161. std::unique_ptr<LValue> lvalue;
  2162. SpvId lhs;
  2163. if (is_assignment(b.fOperator)) {
  2164. lvalue = this->getLValue(*b.fLeft, out);
  2165. lhs = lvalue->load(out);
  2166. } else {
  2167. lvalue = nullptr;
  2168. lhs = this->writeExpression(*b.fLeft, out);
  2169. }
  2170. SpvId rhs = this->writeExpression(*b.fRight, out);
  2171. SpvId result = this->writeBinaryExpression(b.fLeft->fType, lhs, remove_assignment(b.fOperator),
  2172. b.fRight->fType, rhs, b.fType, out);
  2173. if (lvalue) {
  2174. lvalue->store(result, out);
  2175. }
  2176. return result;
  2177. }
  2178. SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
  2179. SkASSERT(a.fOperator == Token::LOGICALAND);
  2180. BoolLiteral falseLiteral(fContext, -1, false);
  2181. SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
  2182. SpvId lhs = this->writeExpression(*a.fLeft, out);
  2183. SpvId rhsLabel = this->nextId();
  2184. SpvId end = this->nextId();
  2185. SpvId lhsBlock = fCurrentBlock;
  2186. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  2187. this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
  2188. this->writeLabel(rhsLabel, out);
  2189. SpvId rhs = this->writeExpression(*a.fRight, out);
  2190. SpvId rhsBlock = fCurrentBlock;
  2191. this->writeInstruction(SpvOpBranch, end, out);
  2192. this->writeLabel(end, out);
  2193. SpvId result = this->nextId();
  2194. this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
  2195. lhsBlock, rhs, rhsBlock, out);
  2196. return result;
  2197. }
  2198. SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
  2199. SkASSERT(o.fOperator == Token::LOGICALOR);
  2200. BoolLiteral trueLiteral(fContext, -1, true);
  2201. SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
  2202. SpvId lhs = this->writeExpression(*o.fLeft, out);
  2203. SpvId rhsLabel = this->nextId();
  2204. SpvId end = this->nextId();
  2205. SpvId lhsBlock = fCurrentBlock;
  2206. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  2207. this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
  2208. this->writeLabel(rhsLabel, out);
  2209. SpvId rhs = this->writeExpression(*o.fRight, out);
  2210. SpvId rhsBlock = fCurrentBlock;
  2211. this->writeInstruction(SpvOpBranch, end, out);
  2212. this->writeLabel(end, out);
  2213. SpvId result = this->nextId();
  2214. this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
  2215. lhsBlock, rhs, rhsBlock, out);
  2216. return result;
  2217. }
  2218. SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
  2219. SpvId test = this->writeExpression(*t.fTest, out);
  2220. if (t.fIfTrue->fType.columns() == 1 && t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
  2221. // both true and false are constants, can just use OpSelect
  2222. SpvId result = this->nextId();
  2223. SpvId trueId = this->writeExpression(*t.fIfTrue, out);
  2224. SpvId falseId = this->writeExpression(*t.fIfFalse, out);
  2225. this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
  2226. out);
  2227. return result;
  2228. }
  2229. // was originally using OpPhi to choose the result, but for some reason that is crashing on
  2230. // Adreno. Switched to storing the result in a temp variable as glslang does.
  2231. SpvId var = this->nextId();
  2232. this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
  2233. var, SpvStorageClassFunction, fVariableBuffer);
  2234. SpvId trueLabel = this->nextId();
  2235. SpvId falseLabel = this->nextId();
  2236. SpvId end = this->nextId();
  2237. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  2238. this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
  2239. this->writeLabel(trueLabel, out);
  2240. this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
  2241. this->writeInstruction(SpvOpBranch, end, out);
  2242. this->writeLabel(falseLabel, out);
  2243. this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
  2244. this->writeInstruction(SpvOpBranch, end, out);
  2245. this->writeLabel(end, out);
  2246. SpvId result = this->nextId();
  2247. this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
  2248. this->writePrecisionModifier(t.fType, result);
  2249. return result;
  2250. }
  2251. SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
  2252. if (p.fOperator == Token::MINUS) {
  2253. SpvId result = this->nextId();
  2254. SpvId typeId = this->getType(p.fType);
  2255. SpvId expr = this->writeExpression(*p.fOperand, out);
  2256. if (is_float(fContext, p.fType)) {
  2257. this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
  2258. } else if (is_signed(fContext, p.fType)) {
  2259. this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
  2260. } else {
  2261. ABORT("unsupported prefix expression %s", p.description().c_str());
  2262. }
  2263. this->writePrecisionModifier(p.fType, result);
  2264. return result;
  2265. }
  2266. switch (p.fOperator) {
  2267. case Token::PLUS:
  2268. return this->writeExpression(*p.fOperand, out);
  2269. case Token::PLUSPLUS: {
  2270. std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
  2271. SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
  2272. SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
  2273. SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
  2274. out);
  2275. lv->store(result, out);
  2276. return result;
  2277. }
  2278. case Token::MINUSMINUS: {
  2279. std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
  2280. SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
  2281. SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
  2282. SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
  2283. out);
  2284. lv->store(result, out);
  2285. return result;
  2286. }
  2287. case Token::LOGICALNOT: {
  2288. SkASSERT(p.fOperand->fType == *fContext.fBool_Type);
  2289. SpvId result = this->nextId();
  2290. this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
  2291. this->writeExpression(*p.fOperand, out), out);
  2292. return result;
  2293. }
  2294. case Token::BITWISENOT: {
  2295. SpvId result = this->nextId();
  2296. this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
  2297. this->writeExpression(*p.fOperand, out), out);
  2298. return result;
  2299. }
  2300. default:
  2301. ABORT("unsupported prefix expression: %s", p.description().c_str());
  2302. }
  2303. }
  2304. SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
  2305. std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
  2306. SpvId result = lv->load(out);
  2307. SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
  2308. switch (p.fOperator) {
  2309. case Token::PLUSPLUS: {
  2310. SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
  2311. SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
  2312. lv->store(temp, out);
  2313. return result;
  2314. }
  2315. case Token::MINUSMINUS: {
  2316. SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
  2317. SpvOpISub, SpvOpISub, SpvOpUndef, out);
  2318. lv->store(temp, out);
  2319. return result;
  2320. }
  2321. default:
  2322. ABORT("unsupported postfix expression %s", p.description().c_str());
  2323. }
  2324. }
  2325. SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
  2326. if (b.fValue) {
  2327. if (fBoolTrue == 0) {
  2328. fBoolTrue = this->nextId();
  2329. this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
  2330. fConstantBuffer);
  2331. }
  2332. return fBoolTrue;
  2333. } else {
  2334. if (fBoolFalse == 0) {
  2335. fBoolFalse = this->nextId();
  2336. this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
  2337. fConstantBuffer);
  2338. }
  2339. return fBoolFalse;
  2340. }
  2341. }
  2342. SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
  2343. ConstantType type;
  2344. if (i.fType == *fContext.fInt_Type) {
  2345. type = ConstantType::kInt;
  2346. } else if (i.fType == *fContext.fUInt_Type) {
  2347. type = ConstantType::kUInt;
  2348. } else if (i.fType == *fContext.fShort_Type) {
  2349. type = ConstantType::kShort;
  2350. } else if (i.fType == *fContext.fUShort_Type) {
  2351. type = ConstantType::kUShort;
  2352. }
  2353. std::pair<ConstantValue, ConstantType> key(i.fValue, type);
  2354. auto entry = fNumberConstants.find(key);
  2355. if (entry == fNumberConstants.end()) {
  2356. SpvId result = this->nextId();
  2357. this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
  2358. fConstantBuffer);
  2359. fNumberConstants[key] = result;
  2360. return result;
  2361. }
  2362. return entry->second;
  2363. }
  2364. SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
  2365. if (f.fType != *fContext.fDouble_Type) {
  2366. ConstantType type;
  2367. if (f.fType == *fContext.fHalf_Type) {
  2368. type = ConstantType::kHalf;
  2369. } else {
  2370. type = ConstantType::kFloat;
  2371. }
  2372. float value = (float) f.fValue;
  2373. std::pair<ConstantValue, ConstantType> key(f.fValue, type);
  2374. auto entry = fNumberConstants.find(key);
  2375. if (entry == fNumberConstants.end()) {
  2376. SpvId result = this->nextId();
  2377. uint32_t bits;
  2378. SkASSERT(sizeof(bits) == sizeof(value));
  2379. memcpy(&bits, &value, sizeof(bits));
  2380. this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
  2381. fConstantBuffer);
  2382. fNumberConstants[key] = result;
  2383. return result;
  2384. }
  2385. return entry->second;
  2386. } else {
  2387. std::pair<ConstantValue, ConstantType> key(f.fValue, ConstantType::kDouble);
  2388. auto entry = fNumberConstants.find(key);
  2389. if (entry == fNumberConstants.end()) {
  2390. SpvId result = this->nextId();
  2391. uint64_t bits;
  2392. SkASSERT(sizeof(bits) == sizeof(f.fValue));
  2393. memcpy(&bits, &f.fValue, sizeof(bits));
  2394. this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
  2395. bits & 0xffffffff, bits >> 32, fConstantBuffer);
  2396. fNumberConstants[key] = result;
  2397. return result;
  2398. }
  2399. return entry->second;
  2400. }
  2401. }
  2402. SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
  2403. SpvId result = fFunctionMap[&f];
  2404. this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
  2405. SpvFunctionControlMaskNone, this->getFunctionType(f), out);
  2406. this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
  2407. for (size_t i = 0; i < f.fParameters.size(); i++) {
  2408. SpvId id = this->nextId();
  2409. fVariableMap[f.fParameters[i]] = id;
  2410. SpvId type;
  2411. type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
  2412. this->writeInstruction(SpvOpFunctionParameter, type, id, out);
  2413. }
  2414. return result;
  2415. }
  2416. SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
  2417. fVariableBuffer.reset();
  2418. SpvId result = this->writeFunctionStart(f.fDeclaration, out);
  2419. this->writeLabel(this->nextId(), out);
  2420. StringStream bodyBuffer;
  2421. this->writeBlock((Block&) *f.fBody, bodyBuffer);
  2422. write_stringstream(fVariableBuffer, out);
  2423. if (f.fDeclaration.fName == "main") {
  2424. write_stringstream(fGlobalInitializersBuffer, out);
  2425. }
  2426. write_stringstream(bodyBuffer, out);
  2427. if (fCurrentBlock) {
  2428. if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
  2429. this->writeInstruction(SpvOpReturn, out);
  2430. } else {
  2431. this->writeInstruction(SpvOpUnreachable, out);
  2432. }
  2433. }
  2434. this->writeInstruction(SpvOpFunctionEnd, out);
  2435. return result;
  2436. }
  2437. void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
  2438. if (layout.fLocation >= 0) {
  2439. this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
  2440. fDecorationBuffer);
  2441. }
  2442. if (layout.fBinding >= 0) {
  2443. this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
  2444. fDecorationBuffer);
  2445. }
  2446. if (layout.fIndex >= 0) {
  2447. this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
  2448. fDecorationBuffer);
  2449. }
  2450. if (layout.fSet >= 0) {
  2451. this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
  2452. fDecorationBuffer);
  2453. }
  2454. if (layout.fInputAttachmentIndex >= 0) {
  2455. this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
  2456. layout.fInputAttachmentIndex, fDecorationBuffer);
  2457. fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
  2458. }
  2459. if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
  2460. layout.fBuiltin != SK_IN_BUILTIN && layout.fBuiltin != SK_OUT_BUILTIN) {
  2461. this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
  2462. fDecorationBuffer);
  2463. }
  2464. }
  2465. void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
  2466. if (layout.fLocation >= 0) {
  2467. this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
  2468. layout.fLocation, fDecorationBuffer);
  2469. }
  2470. if (layout.fBinding >= 0) {
  2471. this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
  2472. layout.fBinding, fDecorationBuffer);
  2473. }
  2474. if (layout.fIndex >= 0) {
  2475. this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
  2476. layout.fIndex, fDecorationBuffer);
  2477. }
  2478. if (layout.fSet >= 0) {
  2479. this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
  2480. layout.fSet, fDecorationBuffer);
  2481. }
  2482. if (layout.fInputAttachmentIndex >= 0) {
  2483. this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
  2484. layout.fInputAttachmentIndex, fDecorationBuffer);
  2485. }
  2486. if (layout.fBuiltin >= 0) {
  2487. this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
  2488. layout.fBuiltin, fDecorationBuffer);
  2489. }
  2490. }
  2491. static void update_sk_in_count(const Modifiers& m, int* outSkInCount) {
  2492. switch (m.fLayout.fPrimitive) {
  2493. case Layout::kPoints_Primitive:
  2494. *outSkInCount = 1;
  2495. break;
  2496. case Layout::kLines_Primitive:
  2497. *outSkInCount = 2;
  2498. break;
  2499. case Layout::kLinesAdjacency_Primitive:
  2500. *outSkInCount = 4;
  2501. break;
  2502. case Layout::kTriangles_Primitive:
  2503. *outSkInCount = 3;
  2504. break;
  2505. case Layout::kTrianglesAdjacency_Primitive:
  2506. *outSkInCount = 6;
  2507. break;
  2508. default:
  2509. return;
  2510. }
  2511. }
  2512. SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
  2513. bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
  2514. bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
  2515. Layout::kPushConstant_Flag));
  2516. MemoryLayout memoryLayout = (pushConstant || isBuffer) ?
  2517. MemoryLayout(MemoryLayout::k430_Standard) :
  2518. fDefaultLayout;
  2519. SpvId result = this->nextId();
  2520. const Type* type = &intf.fVariable.fType;
  2521. if (fProgram.fInputs.fRTHeight) {
  2522. SkASSERT(fRTHeightStructId == (SpvId) -1);
  2523. SkASSERT(fRTHeightFieldIndex == (SpvId) -1);
  2524. std::vector<Type::Field> fields = type->fields();
  2525. fRTHeightStructId = result;
  2526. fRTHeightFieldIndex = fields.size();
  2527. fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
  2528. type = new Type(type->fOffset, type->name(), fields);
  2529. }
  2530. SpvId typeId;
  2531. if (intf.fVariable.fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
  2532. for (const auto& e : fProgram) {
  2533. if (e.fKind == ProgramElement::kModifiers_Kind) {
  2534. const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
  2535. update_sk_in_count(m, &fSkInCount);
  2536. }
  2537. }
  2538. typeId = this->getType(Type("sk_in", Type::kArray_Kind, intf.fVariable.fType.componentType(),
  2539. fSkInCount), memoryLayout);
  2540. } else {
  2541. typeId = this->getType(*type, memoryLayout);
  2542. }
  2543. if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
  2544. this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
  2545. } else if (intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
  2546. this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
  2547. }
  2548. SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
  2549. SpvId ptrType = this->nextId();
  2550. this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
  2551. this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
  2552. Layout layout = intf.fVariable.fModifiers.fLayout;
  2553. if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) {
  2554. layout.fSet = 0;
  2555. }
  2556. this->writeLayout(layout, result);
  2557. fVariableMap[&intf.fVariable] = result;
  2558. if (fProgram.fInputs.fRTHeight) {
  2559. delete type;
  2560. }
  2561. return result;
  2562. }
  2563. void SPIRVCodeGenerator::writePrecisionModifier(const Type& type, SpvId id) {
  2564. this->writePrecisionModifier(type.highPrecision() ? Precision::kHigh : Precision::kLow, id);
  2565. }
  2566. void SPIRVCodeGenerator::writePrecisionModifier(Precision precision, SpvId id) {
  2567. if (precision == Precision::kLow) {
  2568. this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
  2569. }
  2570. }
  2571. #define BUILTIN_IGNORE 9999
  2572. void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
  2573. OutputStream& out) {
  2574. for (size_t i = 0; i < decl.fVars.size(); i++) {
  2575. if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
  2576. continue;
  2577. }
  2578. const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
  2579. const Variable* var = varDecl.fVar;
  2580. // These haven't been implemented in our SPIR-V generator yet and we only currently use them
  2581. // in the OpenGL backend.
  2582. SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
  2583. Modifiers::kWriteOnly_Flag |
  2584. Modifiers::kCoherent_Flag |
  2585. Modifiers::kVolatile_Flag |
  2586. Modifiers::kRestrict_Flag)));
  2587. if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
  2588. continue;
  2589. }
  2590. if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
  2591. kind != Program::kFragment_Kind) {
  2592. SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
  2593. continue;
  2594. }
  2595. if (!var->fReadCount && !var->fWriteCount &&
  2596. !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
  2597. Modifiers::kOut_Flag |
  2598. Modifiers::kUniform_Flag |
  2599. Modifiers::kBuffer_Flag))) {
  2600. // variable is dead and not an input / output var (the Vulkan debug layers complain if
  2601. // we elide an interface var, even if it's dead)
  2602. continue;
  2603. }
  2604. SpvStorageClass_ storageClass;
  2605. if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
  2606. storageClass = SpvStorageClassInput;
  2607. } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
  2608. storageClass = SpvStorageClassOutput;
  2609. } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
  2610. if (var->fType.kind() == Type::kSampler_Kind) {
  2611. storageClass = SpvStorageClassUniformConstant;
  2612. } else {
  2613. storageClass = SpvStorageClassUniform;
  2614. }
  2615. } else {
  2616. storageClass = SpvStorageClassPrivate;
  2617. }
  2618. SpvId id = this->nextId();
  2619. fVariableMap[var] = id;
  2620. SpvId type;
  2621. if (var->fModifiers.fLayout.fBuiltin == SK_IN_BUILTIN) {
  2622. type = this->getPointerType(Type("sk_in", Type::kArray_Kind,
  2623. var->fType.componentType(), fSkInCount),
  2624. storageClass);
  2625. } else {
  2626. type = this->getPointerType(var->fType, storageClass);
  2627. }
  2628. this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
  2629. this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
  2630. this->writePrecisionModifier(var->fType, id);
  2631. if (varDecl.fValue) {
  2632. SkASSERT(!fCurrentBlock);
  2633. fCurrentBlock = -1;
  2634. SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
  2635. this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
  2636. fCurrentBlock = 0;
  2637. }
  2638. this->writeLayout(var->fModifiers.fLayout, id);
  2639. if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
  2640. this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
  2641. }
  2642. if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
  2643. this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
  2644. fDecorationBuffer);
  2645. }
  2646. }
  2647. }
  2648. void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
  2649. for (const auto& stmt : decl.fVars) {
  2650. SkASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
  2651. VarDeclaration& varDecl = (VarDeclaration&) *stmt;
  2652. const Variable* var = varDecl.fVar;
  2653. // These haven't been implemented in our SPIR-V generator yet and we only currently use them
  2654. // in the OpenGL backend.
  2655. SkASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
  2656. Modifiers::kWriteOnly_Flag |
  2657. Modifiers::kCoherent_Flag |
  2658. Modifiers::kVolatile_Flag |
  2659. Modifiers::kRestrict_Flag)));
  2660. SpvId id = this->nextId();
  2661. fVariableMap[var] = id;
  2662. SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
  2663. this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
  2664. this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
  2665. if (varDecl.fValue) {
  2666. SpvId value = this->writeExpression(*varDecl.fValue, out);
  2667. this->writeInstruction(SpvOpStore, id, value, out);
  2668. }
  2669. }
  2670. }
  2671. void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
  2672. switch (s.fKind) {
  2673. case Statement::kNop_Kind:
  2674. break;
  2675. case Statement::kBlock_Kind:
  2676. this->writeBlock((Block&) s, out);
  2677. break;
  2678. case Statement::kExpression_Kind:
  2679. this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
  2680. break;
  2681. case Statement::kReturn_Kind:
  2682. this->writeReturnStatement((ReturnStatement&) s, out);
  2683. break;
  2684. case Statement::kVarDeclarations_Kind:
  2685. this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
  2686. break;
  2687. case Statement::kIf_Kind:
  2688. this->writeIfStatement((IfStatement&) s, out);
  2689. break;
  2690. case Statement::kFor_Kind:
  2691. this->writeForStatement((ForStatement&) s, out);
  2692. break;
  2693. case Statement::kWhile_Kind:
  2694. this->writeWhileStatement((WhileStatement&) s, out);
  2695. break;
  2696. case Statement::kDo_Kind:
  2697. this->writeDoStatement((DoStatement&) s, out);
  2698. break;
  2699. case Statement::kSwitch_Kind:
  2700. this->writeSwitchStatement((SwitchStatement&) s, out);
  2701. break;
  2702. case Statement::kBreak_Kind:
  2703. this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
  2704. break;
  2705. case Statement::kContinue_Kind:
  2706. this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
  2707. break;
  2708. case Statement::kDiscard_Kind:
  2709. this->writeInstruction(SpvOpKill, out);
  2710. break;
  2711. default:
  2712. ABORT("unsupported statement: %s", s.description().c_str());
  2713. }
  2714. }
  2715. void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
  2716. for (size_t i = 0; i < b.fStatements.size(); i++) {
  2717. this->writeStatement(*b.fStatements[i], out);
  2718. }
  2719. }
  2720. void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
  2721. SpvId test = this->writeExpression(*stmt.fTest, out);
  2722. SpvId ifTrue = this->nextId();
  2723. SpvId ifFalse = this->nextId();
  2724. if (stmt.fIfFalse) {
  2725. SpvId end = this->nextId();
  2726. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  2727. this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
  2728. this->writeLabel(ifTrue, out);
  2729. this->writeStatement(*stmt.fIfTrue, out);
  2730. if (fCurrentBlock) {
  2731. this->writeInstruction(SpvOpBranch, end, out);
  2732. }
  2733. this->writeLabel(ifFalse, out);
  2734. this->writeStatement(*stmt.fIfFalse, out);
  2735. if (fCurrentBlock) {
  2736. this->writeInstruction(SpvOpBranch, end, out);
  2737. }
  2738. this->writeLabel(end, out);
  2739. } else {
  2740. this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
  2741. this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
  2742. this->writeLabel(ifTrue, out);
  2743. this->writeStatement(*stmt.fIfTrue, out);
  2744. if (fCurrentBlock) {
  2745. this->writeInstruction(SpvOpBranch, ifFalse, out);
  2746. }
  2747. this->writeLabel(ifFalse, out);
  2748. }
  2749. }
  2750. void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
  2751. if (f.fInitializer) {
  2752. this->writeStatement(*f.fInitializer, out);
  2753. }
  2754. SpvId header = this->nextId();
  2755. SpvId start = this->nextId();
  2756. SpvId body = this->nextId();
  2757. SpvId next = this->nextId();
  2758. fContinueTarget.push(next);
  2759. SpvId end = this->nextId();
  2760. fBreakTarget.push(end);
  2761. this->writeInstruction(SpvOpBranch, header, out);
  2762. this->writeLabel(header, out);
  2763. this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
  2764. this->writeInstruction(SpvOpBranch, start, out);
  2765. this->writeLabel(start, out);
  2766. if (f.fTest) {
  2767. SpvId test = this->writeExpression(*f.fTest, out);
  2768. this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
  2769. }
  2770. this->writeLabel(body, out);
  2771. this->writeStatement(*f.fStatement, out);
  2772. if (fCurrentBlock) {
  2773. this->writeInstruction(SpvOpBranch, next, out);
  2774. }
  2775. this->writeLabel(next, out);
  2776. if (f.fNext) {
  2777. this->writeExpression(*f.fNext, out);
  2778. }
  2779. this->writeInstruction(SpvOpBranch, header, out);
  2780. this->writeLabel(end, out);
  2781. fBreakTarget.pop();
  2782. fContinueTarget.pop();
  2783. }
  2784. void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
  2785. SpvId header = this->nextId();
  2786. SpvId start = this->nextId();
  2787. SpvId body = this->nextId();
  2788. SpvId continueTarget = this->nextId();
  2789. fContinueTarget.push(continueTarget);
  2790. SpvId end = this->nextId();
  2791. fBreakTarget.push(end);
  2792. this->writeInstruction(SpvOpBranch, header, out);
  2793. this->writeLabel(header, out);
  2794. this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
  2795. this->writeInstruction(SpvOpBranch, start, out);
  2796. this->writeLabel(start, out);
  2797. SpvId test = this->writeExpression(*w.fTest, out);
  2798. this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
  2799. this->writeLabel(body, out);
  2800. this->writeStatement(*w.fStatement, out);
  2801. if (fCurrentBlock) {
  2802. this->writeInstruction(SpvOpBranch, continueTarget, out);
  2803. }
  2804. this->writeLabel(continueTarget, out);
  2805. this->writeInstruction(SpvOpBranch, header, out);
  2806. this->writeLabel(end, out);
  2807. fBreakTarget.pop();
  2808. fContinueTarget.pop();
  2809. }
  2810. void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
  2811. // We believe the do loop code below will work, but Skia doesn't actually use them and
  2812. // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
  2813. // the time being, we just fail with an error due to the lack of testing. If you encounter this
  2814. // message, simply remove the error call below to see whether our do loop support actually
  2815. // works.
  2816. fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
  2817. "SkSLSPIRVCodeGenerator.cpp for details");
  2818. SpvId header = this->nextId();
  2819. SpvId start = this->nextId();
  2820. SpvId next = this->nextId();
  2821. SpvId continueTarget = this->nextId();
  2822. fContinueTarget.push(continueTarget);
  2823. SpvId end = this->nextId();
  2824. fBreakTarget.push(end);
  2825. this->writeInstruction(SpvOpBranch, header, out);
  2826. this->writeLabel(header, out);
  2827. this->writeInstruction(SpvOpLoopMerge, end, continueTarget, SpvLoopControlMaskNone, out);
  2828. this->writeInstruction(SpvOpBranch, start, out);
  2829. this->writeLabel(start, out);
  2830. this->writeStatement(*d.fStatement, out);
  2831. if (fCurrentBlock) {
  2832. this->writeInstruction(SpvOpBranch, next, out);
  2833. }
  2834. this->writeLabel(next, out);
  2835. SpvId test = this->writeExpression(*d.fTest, out);
  2836. this->writeInstruction(SpvOpBranchConditional, test, continueTarget, end, out);
  2837. this->writeLabel(continueTarget, out);
  2838. this->writeInstruction(SpvOpBranch, header, out);
  2839. this->writeLabel(end, out);
  2840. fBreakTarget.pop();
  2841. fContinueTarget.pop();
  2842. }
  2843. void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
  2844. SpvId value = this->writeExpression(*s.fValue, out);
  2845. std::vector<SpvId> labels;
  2846. SpvId end = this->nextId();
  2847. SpvId defaultLabel = end;
  2848. fBreakTarget.push(end);
  2849. int size = 3;
  2850. for (const auto& c : s.fCases) {
  2851. SpvId label = this->nextId();
  2852. labels.push_back(label);
  2853. if (c->fValue) {
  2854. size += 2;
  2855. } else {
  2856. defaultLabel = label;
  2857. }
  2858. }
  2859. labels.push_back(end);
  2860. this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
  2861. this->writeOpCode(SpvOpSwitch, size, out);
  2862. this->writeWord(value, out);
  2863. this->writeWord(defaultLabel, out);
  2864. for (size_t i = 0; i < s.fCases.size(); ++i) {
  2865. if (!s.fCases[i]->fValue) {
  2866. continue;
  2867. }
  2868. SkASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
  2869. this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
  2870. this->writeWord(labels[i], out);
  2871. }
  2872. for (size_t i = 0; i < s.fCases.size(); ++i) {
  2873. this->writeLabel(labels[i], out);
  2874. for (const auto& stmt : s.fCases[i]->fStatements) {
  2875. this->writeStatement(*stmt, out);
  2876. }
  2877. if (fCurrentBlock) {
  2878. this->writeInstruction(SpvOpBranch, labels[i + 1], out);
  2879. }
  2880. }
  2881. this->writeLabel(end, out);
  2882. fBreakTarget.pop();
  2883. }
  2884. void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
  2885. if (r.fExpression) {
  2886. this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
  2887. out);
  2888. } else {
  2889. this->writeInstruction(SpvOpReturn, out);
  2890. }
  2891. }
  2892. void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
  2893. SkASSERT(fProgram.fKind == Program::kGeometry_Kind);
  2894. int invocations = 1;
  2895. for (const auto& e : fProgram) {
  2896. if (e.fKind == ProgramElement::kModifiers_Kind) {
  2897. const Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
  2898. if (m.fFlags & Modifiers::kIn_Flag) {
  2899. if (m.fLayout.fInvocations != -1) {
  2900. invocations = m.fLayout.fInvocations;
  2901. }
  2902. SpvId input;
  2903. switch (m.fLayout.fPrimitive) {
  2904. case Layout::kPoints_Primitive:
  2905. input = SpvExecutionModeInputPoints;
  2906. break;
  2907. case Layout::kLines_Primitive:
  2908. input = SpvExecutionModeInputLines;
  2909. break;
  2910. case Layout::kLinesAdjacency_Primitive:
  2911. input = SpvExecutionModeInputLinesAdjacency;
  2912. break;
  2913. case Layout::kTriangles_Primitive:
  2914. input = SpvExecutionModeTriangles;
  2915. break;
  2916. case Layout::kTrianglesAdjacency_Primitive:
  2917. input = SpvExecutionModeInputTrianglesAdjacency;
  2918. break;
  2919. default:
  2920. input = 0;
  2921. break;
  2922. }
  2923. update_sk_in_count(m, &fSkInCount);
  2924. if (input) {
  2925. this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
  2926. }
  2927. } else if (m.fFlags & Modifiers::kOut_Flag) {
  2928. SpvId output;
  2929. switch (m.fLayout.fPrimitive) {
  2930. case Layout::kPoints_Primitive:
  2931. output = SpvExecutionModeOutputPoints;
  2932. break;
  2933. case Layout::kLineStrip_Primitive:
  2934. output = SpvExecutionModeOutputLineStrip;
  2935. break;
  2936. case Layout::kTriangleStrip_Primitive:
  2937. output = SpvExecutionModeOutputTriangleStrip;
  2938. break;
  2939. default:
  2940. output = 0;
  2941. break;
  2942. }
  2943. if (output) {
  2944. this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
  2945. }
  2946. if (m.fLayout.fMaxVertices != -1) {
  2947. this->writeInstruction(SpvOpExecutionMode, entryPoint,
  2948. SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
  2949. out);
  2950. }
  2951. }
  2952. }
  2953. }
  2954. this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
  2955. invocations, out);
  2956. }
  2957. void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
  2958. fGLSLExtendedInstructions = this->nextId();
  2959. StringStream body;
  2960. std::set<SpvId> interfaceVars;
  2961. // assign IDs to functions, determine sk_in size
  2962. int skInSize = -1;
  2963. for (const auto& e : program) {
  2964. switch (e.fKind) {
  2965. case ProgramElement::kFunction_Kind: {
  2966. FunctionDefinition& f = (FunctionDefinition&) e;
  2967. fFunctionMap[&f.fDeclaration] = this->nextId();
  2968. break;
  2969. }
  2970. case ProgramElement::kModifiers_Kind: {
  2971. Modifiers& m = ((ModifiersDeclaration&) e).fModifiers;
  2972. if (m.fFlags & Modifiers::kIn_Flag) {
  2973. switch (m.fLayout.fPrimitive) {
  2974. case Layout::kPoints_Primitive: // break
  2975. case Layout::kLines_Primitive:
  2976. skInSize = 1;
  2977. break;
  2978. case Layout::kLinesAdjacency_Primitive: // break
  2979. skInSize = 2;
  2980. break;
  2981. case Layout::kTriangles_Primitive: // break
  2982. case Layout::kTrianglesAdjacency_Primitive:
  2983. skInSize = 3;
  2984. break;
  2985. default:
  2986. break;
  2987. }
  2988. }
  2989. break;
  2990. }
  2991. default:
  2992. break;
  2993. }
  2994. }
  2995. for (const auto& e : program) {
  2996. if (e.fKind == ProgramElement::kInterfaceBlock_Kind) {
  2997. InterfaceBlock& intf = (InterfaceBlock&) e;
  2998. if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
  2999. SkASSERT(skInSize != -1);
  3000. intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
  3001. }
  3002. SpvId id = this->writeInterfaceBlock(intf);
  3003. if (((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
  3004. (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) &&
  3005. intf.fVariable.fModifiers.fLayout.fBuiltin == -1) {
  3006. interfaceVars.insert(id);
  3007. }
  3008. }
  3009. }
  3010. for (const auto& e : program) {
  3011. if (e.fKind == ProgramElement::kVar_Kind) {
  3012. this->writeGlobalVars(program.fKind, ((VarDeclarations&) e), body);
  3013. }
  3014. }
  3015. for (const auto& e : program) {
  3016. if (e.fKind == ProgramElement::kFunction_Kind) {
  3017. this->writeFunction(((FunctionDefinition&) e), body);
  3018. }
  3019. }
  3020. const FunctionDeclaration* main = nullptr;
  3021. for (auto entry : fFunctionMap) {
  3022. if (entry.first->fName == "main") {
  3023. main = entry.first;
  3024. }
  3025. }
  3026. if (!main) {
  3027. fErrors.error(0, "program does not contain a main() function");
  3028. return;
  3029. }
  3030. for (auto entry : fVariableMap) {
  3031. const Variable* var = entry.first;
  3032. if (var->fStorage == Variable::kGlobal_Storage &&
  3033. ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
  3034. (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
  3035. interfaceVars.insert(entry.second);
  3036. }
  3037. }
  3038. this->writeCapabilities(out);
  3039. this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
  3040. this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
  3041. this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
  3042. (int32_t) interfaceVars.size(), out);
  3043. switch (program.fKind) {
  3044. case Program::kVertex_Kind:
  3045. this->writeWord(SpvExecutionModelVertex, out);
  3046. break;
  3047. case Program::kFragment_Kind:
  3048. this->writeWord(SpvExecutionModelFragment, out);
  3049. break;
  3050. case Program::kGeometry_Kind:
  3051. this->writeWord(SpvExecutionModelGeometry, out);
  3052. break;
  3053. default:
  3054. ABORT("cannot write this kind of program to SPIR-V\n");
  3055. }
  3056. SpvId entryPoint = fFunctionMap[main];
  3057. this->writeWord(entryPoint, out);
  3058. this->writeString(main->fName.fChars, main->fName.fLength, out);
  3059. for (int var : interfaceVars) {
  3060. this->writeWord(var, out);
  3061. }
  3062. if (program.fKind == Program::kGeometry_Kind) {
  3063. this->writeGeometryShaderExecutionMode(entryPoint, out);
  3064. }
  3065. if (program.fKind == Program::kFragment_Kind) {
  3066. this->writeInstruction(SpvOpExecutionMode,
  3067. fFunctionMap[main],
  3068. SpvExecutionModeOriginUpperLeft,
  3069. out);
  3070. }
  3071. for (const auto& e : program) {
  3072. if (e.fKind == ProgramElement::kExtension_Kind) {
  3073. this->writeInstruction(SpvOpSourceExtension, ((Extension&) e).fName.c_str(), out);
  3074. }
  3075. }
  3076. write_stringstream(fExtraGlobalsBuffer, out);
  3077. write_stringstream(fNameBuffer, out);
  3078. write_stringstream(fDecorationBuffer, out);
  3079. write_stringstream(fConstantBuffer, out);
  3080. write_stringstream(fExternalFunctionsBuffer, out);
  3081. write_stringstream(body, out);
  3082. }
  3083. bool SPIRVCodeGenerator::generateCode() {
  3084. SkASSERT(!fErrors.errorCount());
  3085. this->writeWord(SpvMagicNumber, *fOut);
  3086. this->writeWord(SpvVersion, *fOut);
  3087. this->writeWord(SKSL_MAGIC, *fOut);
  3088. StringStream buffer;
  3089. this->writeInstructions(fProgram, buffer);
  3090. this->writeWord(fIdCount, *fOut);
  3091. this->writeWord(0, *fOut); // reserved, always zero
  3092. write_stringstream(buffer, *fOut);
  3093. return 0 == fErrors.errorCount();
  3094. }
  3095. }