SkSLMetalCodeGenerator.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688
  1. /*
  2. * Copyright 2016 Google Inc.
  3. *
  4. * Use of this source code is governed by a BSD-style license that can be
  5. * found in the LICENSE file.
  6. */
  7. #include "src/sksl/SkSLMetalCodeGenerator.h"
  8. #include "src/sksl/SkSLCompiler.h"
  9. #include "src/sksl/ir/SkSLExpressionStatement.h"
  10. #include "src/sksl/ir/SkSLExtension.h"
  11. #include "src/sksl/ir/SkSLIndexExpression.h"
  12. #include "src/sksl/ir/SkSLModifiersDeclaration.h"
  13. #include "src/sksl/ir/SkSLNop.h"
  14. #include "src/sksl/ir/SkSLVariableReference.h"
  #ifdef SK_MOLTENVK
  // NOTE(review): magic constant only compiled in for MoltenVK builds; its consumer and meaning
  // are not visible in this section of the file — confirm before changing the value.
  static constexpr uint32_t MVKMagicNum = 0x19960412;
  #endif
  18. namespace SkSL {
  19. void MetalCodeGenerator::setupIntrinsics() {
  20. #define METAL(x) std::make_pair(kMetal_IntrinsicKind, k ## x ## _MetalIntrinsic)
  21. #define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic)
  22. fIntrinsicMap[String("texture")] = SPECIAL(Texture);
  23. fIntrinsicMap[String("mod")] = SPECIAL(Mod);
  24. fIntrinsicMap[String("equal")] = METAL(Equal);
  25. fIntrinsicMap[String("notEqual")] = METAL(NotEqual);
  26. fIntrinsicMap[String("lessThan")] = METAL(LessThan);
  27. fIntrinsicMap[String("lessThanEqual")] = METAL(LessThanEqual);
  28. fIntrinsicMap[String("greaterThan")] = METAL(GreaterThan);
  29. fIntrinsicMap[String("greaterThanEqual")] = METAL(GreaterThanEqual);
  30. }
  31. void MetalCodeGenerator::write(const char* s) {
  32. if (!s[0]) {
  33. return;
  34. }
  35. if (fAtLineStart) {
  36. for (int i = 0; i < fIndentation; i++) {
  37. fOut->writeText(" ");
  38. }
  39. }
  40. fOut->writeText(s);
  41. fAtLineStart = false;
  42. }
  43. void MetalCodeGenerator::writeLine(const char* s) {
  44. this->write(s);
  45. fOut->writeText(fLineEnding);
  46. fAtLineStart = true;
  47. }
  48. void MetalCodeGenerator::write(const String& s) {
  49. this->write(s.c_str());
  50. }
  51. void MetalCodeGenerator::writeLine(const String& s) {
  52. this->writeLine(s.c_str());
  53. }
  54. void MetalCodeGenerator::writeLine() {
  55. this->writeLine("");
  56. }
  57. void MetalCodeGenerator::writeExtension(const Extension& ext) {
  58. this->writeLine("#extension " + ext.fName + " : enable");
  59. }
  60. void MetalCodeGenerator::writeType(const Type& type) {
  61. switch (type.kind()) {
  62. case Type::kStruct_Kind:
  63. for (const Type* search : fWrittenStructs) {
  64. if (*search == type) {
  65. // already written
  66. this->write(type.name());
  67. return;
  68. }
  69. }
  70. fWrittenStructs.push_back(&type);
  71. this->writeLine("struct " + type.name() + " {");
  72. fIndentation++;
  73. this->writeFields(type.fields(), type.fOffset);
  74. fIndentation--;
  75. this->write("}");
  76. break;
  77. case Type::kVector_Kind:
  78. this->writeType(type.componentType());
  79. this->write(to_string(type.columns()));
  80. break;
  81. case Type::kMatrix_Kind:
  82. this->writeType(type.componentType());
  83. this->write(to_string(type.columns()));
  84. this->write("x");
  85. this->write(to_string(type.rows()));
  86. break;
  87. case Type::kSampler_Kind:
  88. this->write("texture2d<float> "); // FIXME - support other texture types;
  89. break;
  90. default:
  91. if (type == *fContext.fHalf_Type) {
  92. // FIXME - Currently only supporting floats in MSL to avoid type coercion issues.
  93. this->write(fContext.fFloat_Type->name());
  94. } else if (type == *fContext.fByte_Type) {
  95. this->write("char");
  96. } else if (type == *fContext.fUByte_Type) {
  97. this->write("uchar");
  98. } else {
  99. this->write(type.name());
  100. }
  101. }
  102. }
  103. void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
  104. switch (expr.fKind) {
  105. case Expression::kBinary_Kind:
  106. this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
  107. break;
  108. case Expression::kBoolLiteral_Kind:
  109. this->writeBoolLiteral((BoolLiteral&) expr);
  110. break;
  111. case Expression::kConstructor_Kind:
  112. this->writeConstructor((Constructor&) expr, parentPrecedence);
  113. break;
  114. case Expression::kIntLiteral_Kind:
  115. this->writeIntLiteral((IntLiteral&) expr);
  116. break;
  117. case Expression::kFieldAccess_Kind:
  118. this->writeFieldAccess(((FieldAccess&) expr));
  119. break;
  120. case Expression::kFloatLiteral_Kind:
  121. this->writeFloatLiteral(((FloatLiteral&) expr));
  122. break;
  123. case Expression::kFunctionCall_Kind:
  124. this->writeFunctionCall((FunctionCall&) expr);
  125. break;
  126. case Expression::kPrefix_Kind:
  127. this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
  128. break;
  129. case Expression::kPostfix_Kind:
  130. this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
  131. break;
  132. case Expression::kSetting_Kind:
  133. this->writeSetting((Setting&) expr);
  134. break;
  135. case Expression::kSwizzle_Kind:
  136. this->writeSwizzle((Swizzle&) expr);
  137. break;
  138. case Expression::kVariableReference_Kind:
  139. this->writeVariableReference((VariableReference&) expr);
  140. break;
  141. case Expression::kTernary_Kind:
  142. this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
  143. break;
  144. case Expression::kIndex_Kind:
  145. this->writeIndexExpression((IndexExpression&) expr);
  146. break;
  147. default:
  148. ABORT("unsupported expression: %s", expr.description().c_str());
  149. }
  150. }
  151. void MetalCodeGenerator::writeIntrinsicCall(const FunctionCall& c) {
  152. auto i = fIntrinsicMap.find(c.fFunction.fName);
  153. SkASSERT(i != fIntrinsicMap.end());
  154. Intrinsic intrinsic = i->second;
  155. int32_t intrinsicId = intrinsic.second;
  156. switch (intrinsic.first) {
  157. case kSpecial_IntrinsicKind:
  158. return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId);
  159. break;
  160. case kMetal_IntrinsicKind:
  161. this->writeExpression(*c.fArguments[0], kSequence_Precedence);
  162. switch ((MetalIntrinsic) intrinsicId) {
  163. case kEqual_MetalIntrinsic:
  164. this->write(" == ");
  165. break;
  166. case kNotEqual_MetalIntrinsic:
  167. this->write(" != ");
  168. break;
  169. case kLessThan_MetalIntrinsic:
  170. this->write(" < ");
  171. break;
  172. case kLessThanEqual_MetalIntrinsic:
  173. this->write(" <= ");
  174. break;
  175. case kGreaterThan_MetalIntrinsic:
  176. this->write(" > ");
  177. break;
  178. case kGreaterThanEqual_MetalIntrinsic:
  179. this->write(" >= ");
  180. break;
  181. default:
  182. ABORT("unsupported metal intrinsic kind");
  183. }
  184. this->writeExpression(*c.fArguments[1], kSequence_Precedence);
  185. break;
  186. default:
  187. ABORT("unsupported intrinsic kind");
  188. }
  189. }
  190. void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
  191. const auto& entry = fIntrinsicMap.find(c.fFunction.fName);
  192. if (entry != fIntrinsicMap.end()) {
  193. this->writeIntrinsicCall(c);
  194. return;
  195. }
  196. if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
  197. this->write("atan2");
  198. } else if (c.fFunction.fBuiltin && "inversesqrt" == c.fFunction.fName) {
  199. this->write("rsqrt");
  200. } else if (c.fFunction.fBuiltin && "inverse" == c.fFunction.fName) {
  201. SkASSERT(c.fArguments.size() == 1);
  202. this->writeInverseHack(*c.fArguments[0]);
  203. } else if (c.fFunction.fBuiltin && "dFdx" == c.fFunction.fName) {
  204. this->write("dfdx");
  205. } else if (c.fFunction.fBuiltin && "dFdy" == c.fFunction.fName) {
  206. // Flipping Y also negates the Y derivatives.
  207. this->write((fProgram.fSettings.fFlipY) ? "-dfdy" : "dfdy");
  208. } else {
  209. this->writeName(c.fFunction.fName);
  210. }
  211. this->write("(");
  212. const char* separator = "";
  213. if (this->requirements(c.fFunction) & kInputs_Requirement) {
  214. this->write("_in");
  215. separator = ", ";
  216. }
  217. if (this->requirements(c.fFunction) & kOutputs_Requirement) {
  218. this->write(separator);
  219. this->write("_out");
  220. separator = ", ";
  221. }
  222. if (this->requirements(c.fFunction) & kUniforms_Requirement) {
  223. this->write(separator);
  224. this->write("_uniforms");
  225. separator = ", ";
  226. }
  227. if (this->requirements(c.fFunction) & kGlobals_Requirement) {
  228. this->write(separator);
  229. this->write("_globals");
  230. separator = ", ";
  231. }
  232. for (size_t i = 0; i < c.fArguments.size(); ++i) {
  233. const Expression& arg = *c.fArguments[i];
  234. this->write(separator);
  235. separator = ", ";
  236. if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
  237. this->write("&");
  238. }
  239. this->writeExpression(arg, kSequence_Precedence);
  240. }
  241. this->write(")");
  242. }
  243. void MetalCodeGenerator::writeInverseHack(const Expression& mat) {
  244. String typeName = mat.fType.name();
  245. String name = typeName + "_inverse";
  246. if (mat.fType == *fContext.fFloat2x2_Type || mat.fType == *fContext.fHalf2x2_Type) {
  247. if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
  248. fWrittenIntrinsics.insert(name);
  249. fExtraFunctions.writeText((
  250. typeName + " " + name + "(" + typeName + " m) {"
  251. " return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
  252. "}"
  253. ).c_str());
  254. }
  255. }
  256. else if (mat.fType == *fContext.fFloat3x3_Type || mat.fType == *fContext.fHalf3x3_Type) {
  257. if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
  258. fWrittenIntrinsics.insert(name);
  259. fExtraFunctions.writeText((
  260. typeName + " " + name + "(" + typeName + " m) {"
  261. " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];"
  262. " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];"
  263. " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];"
  264. " float b01 = a22 * a11 - a12 * a21;"
  265. " float b11 = -a22 * a10 + a12 * a20;"
  266. " float b21 = a21 * a10 - a11 * a20;"
  267. " float det = a00 * b01 + a01 * b11 + a02 * b21;"
  268. " return " + typeName +
  269. " (b01, (-a22 * a01 + a02 * a21), (a12 * a01 - a02 * a11),"
  270. " b11, (a22 * a00 - a02 * a20), (-a12 * a00 + a02 * a10),"
  271. " b21, (-a21 * a00 + a01 * a20), (a11 * a00 - a01 * a10)) * "
  272. " (1/det);"
  273. "}"
  274. ).c_str());
  275. }
  276. }
  277. else if (mat.fType == *fContext.fFloat4x4_Type || mat.fType == *fContext.fHalf4x4_Type) {
  278. if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
  279. fWrittenIntrinsics.insert(name);
  280. fExtraFunctions.writeText((
  281. typeName + " " + name + "(" + typeName + " m) {"
  282. " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];"
  283. " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];"
  284. " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];"
  285. " float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];"
  286. " float b00 = a00 * a11 - a01 * a10;"
  287. " float b01 = a00 * a12 - a02 * a10;"
  288. " float b02 = a00 * a13 - a03 * a10;"
  289. " float b03 = a01 * a12 - a02 * a11;"
  290. " float b04 = a01 * a13 - a03 * a11;"
  291. " float b05 = a02 * a13 - a03 * a12;"
  292. " float b06 = a20 * a31 - a21 * a30;"
  293. " float b07 = a20 * a32 - a22 * a30;"
  294. " float b08 = a20 * a33 - a23 * a30;"
  295. " float b09 = a21 * a32 - a22 * a31;"
  296. " float b10 = a21 * a33 - a23 * a31;"
  297. " float b11 = a22 * a33 - a23 * a32;"
  298. " float det = b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - "
  299. " b04 * b07 + b05 * b06;"
  300. " return " + typeName + "(a11 * b11 - a12 * b10 + a13 * b09,"
  301. " a02 * b10 - a01 * b11 - a03 * b09,"
  302. " a31 * b05 - a32 * b04 + a33 * b03,"
  303. " a22 * b04 - a21 * b05 - a23 * b03,"
  304. " a12 * b08 - a10 * b11 - a13 * b07,"
  305. " a00 * b11 - a02 * b08 + a03 * b07,"
  306. " a32 * b02 - a30 * b05 - a33 * b01,"
  307. " a20 * b05 - a22 * b02 + a23 * b01,"
  308. " a10 * b10 - a11 * b08 + a13 * b06,"
  309. " a01 * b08 - a00 * b10 - a03 * b06,"
  310. " a30 * b04 - a31 * b02 + a33 * b00,"
  311. " a21 * b02 - a20 * b04 - a23 * b00,"
  312. " a11 * b07 - a10 * b09 - a12 * b06,"
  313. " a00 * b09 - a01 * b07 + a02 * b06,"
  314. " a31 * b01 - a30 * b03 - a32 * b00,"
  315. " a20 * b03 - a21 * b01 + a22 * b00) / det;"
  316. "}"
  317. ).c_str());
  318. }
  319. }
  320. this->write(name);
  321. }
  322. void MetalCodeGenerator::writeSpecialIntrinsic(const FunctionCall & c, SpecialIntrinsic kind) {
  323. switch (kind) {
  324. case kTexture_SpecialIntrinsic:
  325. this->writeExpression(*c.fArguments[0], kSequence_Precedence);
  326. this->write(".sample(");
  327. this->writeExpression(*c.fArguments[0], kSequence_Precedence);
  328. this->write(SAMPLER_SUFFIX);
  329. this->write(", ");
  330. this->writeExpression(*c.fArguments[1], kSequence_Precedence);
  331. if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
  332. this->write(".xy)"); // FIXME - add projection functionality
  333. } else {
  334. SkASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
  335. this->write(")");
  336. }
  337. break;
  338. case kMod_SpecialIntrinsic:
  339. // fmod(x, y) in metal calculates x - y * trunc(x / y) instead of x - y * floor(x / y)
  340. this->write("((");
  341. this->writeExpression(*c.fArguments[0], kSequence_Precedence);
  342. this->write(") - (");
  343. this->writeExpression(*c.fArguments[1], kSequence_Precedence);
  344. this->write(") * floor((");
  345. this->writeExpression(*c.fArguments[0], kSequence_Precedence);
  346. this->write(") / (");
  347. this->writeExpression(*c.fArguments[1], kSequence_Precedence);
  348. this->write(")))");
  349. break;
  350. default:
  351. ABORT("unsupported special intrinsic kind");
  352. }
  353. }
  354. // If it hasn't already been written, writes a constructor for 'matrix' which takes a single value
  355. // of type 'arg'.
  356. String MetalCodeGenerator::getMatrixConstructHelper(const Type& matrix, const Type& arg) {
  357. String key = matrix.name() + arg.name();
  358. auto found = fHelpers.find(key);
  359. if (found != fHelpers.end()) {
  360. return found->second;
  361. }
  362. String name;
  363. int columns = matrix.columns();
  364. int rows = matrix.rows();
  365. if (arg.isNumber()) {
  366. // creating a matrix from a single scalar value
  367. name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float";
  368. fExtraFunctions.printf("float%dx%d %s(float x) {\n",
  369. columns, rows, name.c_str());
  370. fExtraFunctions.printf(" return float%dx%d(", columns, rows);
  371. for (int i = 0; i < columns; ++i) {
  372. if (i > 0) {
  373. fExtraFunctions.writeText(", ");
  374. }
  375. fExtraFunctions.printf("float%d(", rows);
  376. for (int j = 0; j < rows; ++j) {
  377. if (j > 0) {
  378. fExtraFunctions.writeText(", ");
  379. }
  380. if (i == j) {
  381. fExtraFunctions.writeText("x");
  382. } else {
  383. fExtraFunctions.writeText("0");
  384. }
  385. }
  386. fExtraFunctions.writeText(")");
  387. }
  388. fExtraFunctions.writeText(");\n}\n");
  389. } else if (arg.kind() == Type::kMatrix_Kind) {
  390. // creating a matrix from another matrix
  391. int argColumns = arg.columns();
  392. int argRows = arg.rows();
  393. name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float" +
  394. to_string(argColumns) + "x" + to_string(argRows);
  395. fExtraFunctions.printf("float%dx%d %s(float%dx%d m) {\n",
  396. columns, rows, name.c_str(), argColumns, argRows);
  397. fExtraFunctions.printf(" return float%dx%d(", columns, rows);
  398. for (int i = 0; i < columns; ++i) {
  399. if (i > 0) {
  400. fExtraFunctions.writeText(", ");
  401. }
  402. fExtraFunctions.printf("float%d(", rows);
  403. for (int j = 0; j < rows; ++j) {
  404. if (j > 0) {
  405. fExtraFunctions.writeText(", ");
  406. }
  407. if (i < argColumns && j < argRows) {
  408. fExtraFunctions.printf("m[%d][%d]", i, j);
  409. } else {
  410. fExtraFunctions.writeText("0");
  411. }
  412. }
  413. fExtraFunctions.writeText(")");
  414. }
  415. fExtraFunctions.writeText(");\n}\n");
  416. } else if (matrix.rows() == 2 && matrix.columns() == 2 && arg == *fContext.fFloat4_Type) {
  417. // float2x2(float4) doesn't work, need to split it into float2x2(float2, float2)
  418. name = "float2x2_from_float4";
  419. fExtraFunctions.printf(
  420. "float2x2 %s(float4 v) {\n"
  421. " return float2x2(float2(v[0], v[1]), float2(v[2], v[3]));\n"
  422. "}\n",
  423. name.c_str()
  424. );
  425. } else {
  426. SkASSERT(false);
  427. name = "<error>";
  428. }
  429. fHelpers[key] = name;
  430. return name;
  431. }
  432. bool MetalCodeGenerator::canCoerce(const Type& t1, const Type& t2) {
  433. if (t1.columns() != t2.columns() || t1.rows() != t2.rows()) {
  434. return false;
  435. }
  436. if (t1.columns() > 1) {
  437. return this->canCoerce(t1.componentType(), t2.componentType());
  438. }
  439. return t1.isFloat() && t2.isFloat();
  440. }
  441. void MetalCodeGenerator::writeConstructor(const Constructor& c, Precedence parentPrecedence) {
  442. if (c.fArguments.size() == 1 && this->canCoerce(c.fType, c.fArguments[0]->fType)) {
  443. this->writeExpression(*c.fArguments[0], parentPrecedence);
  444. return;
  445. }
  446. if (c.fType.kind() == Type::kMatrix_Kind && c.fArguments.size() == 1) {
  447. const Expression& arg = *c.fArguments[0];
  448. String name = this->getMatrixConstructHelper(c.fType, arg.fType);
  449. this->write(name);
  450. this->write("(");
  451. this->writeExpression(arg, kSequence_Precedence);
  452. this->write(")");
  453. } else {
  454. this->writeType(c.fType);
  455. this->write("(");
  456. const char* separator = "";
  457. int scalarCount = 0;
  458. for (const auto& arg : c.fArguments) {
  459. this->write(separator);
  460. separator = ", ";
  461. if (Type::kMatrix_Kind == c.fType.kind() && arg->fType.columns() != c.fType.rows()) {
  462. // merge scalars and smaller vectors together
  463. if (!scalarCount) {
  464. this->writeType(c.fType.componentType());
  465. this->write(to_string(c.fType.rows()));
  466. this->write("(");
  467. }
  468. scalarCount += arg->fType.columns();
  469. }
  470. this->writeExpression(*arg, kSequence_Precedence);
  471. if (scalarCount && scalarCount == c.fType.rows()) {
  472. this->write(")");
  473. scalarCount = 0;
  474. }
  475. }
  476. this->write(")");
  477. }
  478. }
  479. void MetalCodeGenerator::writeFragCoord() {
  480. if (fProgram.fInputs.fRTHeight) {
  481. this->write("float4(_fragCoord.x, _anonInterface0.u_skRTHeight - _fragCoord.y, 0.0, "
  482. "_fragCoord.w)");
  483. } else {
  484. this->write("float4(_fragCoord.x, _fragCoord.y, 0.0, _fragCoord.w)");
  485. }
  486. }
  487. void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
  488. switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
  489. case SK_FRAGCOLOR_BUILTIN:
  490. this->write("_out->sk_FragColor");
  491. break;
  492. case SK_FRAGCOORD_BUILTIN:
  493. this->writeFragCoord();
  494. break;
  495. case SK_VERTEXID_BUILTIN:
  496. this->write("sk_VertexID");
  497. break;
  498. case SK_INSTANCEID_BUILTIN:
  499. this->write("sk_InstanceID");
  500. break;
  501. case SK_CLOCKWISE_BUILTIN:
  502. // We'd set the front facing winding in the MTLRenderCommandEncoder to be counter
  503. // clockwise to match Skia convention. This is also the default in MoltenVK.
  504. this->write(fProgram.fSettings.fFlipY ? "_frontFacing" : "(!_frontFacing)");
  505. break;
  506. default:
  507. if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
  508. if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
  509. this->write("_in.");
  510. } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
  511. this->write("_out->");
  512. } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag &&
  513. ref.fVariable.fType.kind() != Type::kSampler_Kind) {
  514. this->write("_uniforms.");
  515. } else {
  516. this->write("_globals->");
  517. }
  518. }
  519. this->writeName(ref.fVariable.fName);
  520. }
  521. }
  522. void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
  523. this->writeExpression(*expr.fBase, kPostfix_Precedence);
  524. this->write("[");
  525. this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
  526. this->write("]");
  527. }
  528. void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
  529. const Type::Field* field = &f.fBase->fType.fields()[f.fFieldIndex];
  530. if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
  531. this->writeExpression(*f.fBase, kPostfix_Precedence);
  532. this->write(".");
  533. }
  534. switch (field->fModifiers.fLayout.fBuiltin) {
  535. case SK_CLIPDISTANCE_BUILTIN:
  536. this->write("gl_ClipDistance");
  537. break;
  538. case SK_POSITION_BUILTIN:
  539. this->write("_out->sk_Position");
  540. break;
  541. default:
  542. if (field->fName == "sk_PointSize") {
  543. this->write("_out->sk_PointSize");
  544. } else {
  545. if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
  546. this->write("_globals->");
  547. this->write(fInterfaceBlockNameMap[fInterfaceBlockMap[field]]);
  548. this->write("->");
  549. }
  550. this->writeName(field->fName);
  551. }
  552. }
  553. }
  554. void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
  555. int last = swizzle.fComponents.back();
  556. if (last == SKSL_SWIZZLE_0 || last == SKSL_SWIZZLE_1) {
  557. this->writeType(swizzle.fType);
  558. this->write("(");
  559. }
  560. this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
  561. this->write(".");
  562. for (int c : swizzle.fComponents) {
  563. if (c >= 0) {
  564. this->write(&("x\0y\0z\0w\0"[c * 2]));
  565. }
  566. }
  567. if (last == SKSL_SWIZZLE_0) {
  568. this->write(", 0)");
  569. }
  570. else if (last == SKSL_SWIZZLE_1) {
  571. this->write(", 1)");
  572. }
  573. }
  574. MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
  575. switch (op) {
  576. case Token::STAR: // fall through
  577. case Token::SLASH: // fall through
  578. case Token::PERCENT: return MetalCodeGenerator::kMultiplicative_Precedence;
  579. case Token::PLUS: // fall through
  580. case Token::MINUS: return MetalCodeGenerator::kAdditive_Precedence;
  581. case Token::SHL: // fall through
  582. case Token::SHR: return MetalCodeGenerator::kShift_Precedence;
  583. case Token::LT: // fall through
  584. case Token::GT: // fall through
  585. case Token::LTEQ: // fall through
  586. case Token::GTEQ: return MetalCodeGenerator::kRelational_Precedence;
  587. case Token::EQEQ: // fall through
  588. case Token::NEQ: return MetalCodeGenerator::kEquality_Precedence;
  589. case Token::BITWISEAND: return MetalCodeGenerator::kBitwiseAnd_Precedence;
  590. case Token::BITWISEXOR: return MetalCodeGenerator::kBitwiseXor_Precedence;
  591. case Token::BITWISEOR: return MetalCodeGenerator::kBitwiseOr_Precedence;
  592. case Token::LOGICALAND: return MetalCodeGenerator::kLogicalAnd_Precedence;
  593. case Token::LOGICALXOR: return MetalCodeGenerator::kLogicalXor_Precedence;
  594. case Token::LOGICALOR: return MetalCodeGenerator::kLogicalOr_Precedence;
  595. case Token::EQ: // fall through
  596. case Token::PLUSEQ: // fall through
  597. case Token::MINUSEQ: // fall through
  598. case Token::STAREQ: // fall through
  599. case Token::SLASHEQ: // fall through
  600. case Token::PERCENTEQ: // fall through
  601. case Token::SHLEQ: // fall through
  602. case Token::SHREQ: // fall through
  603. case Token::LOGICALANDEQ: // fall through
  604. case Token::LOGICALXOREQ: // fall through
  605. case Token::LOGICALOREQ: // fall through
  606. case Token::BITWISEANDEQ: // fall through
  607. case Token::BITWISEXOREQ: // fall through
  608. case Token::BITWISEOREQ: return MetalCodeGenerator::kAssignment_Precedence;
  609. case Token::COMMA: return MetalCodeGenerator::kSequence_Precedence;
  610. default: ABORT("unsupported binary operator");
  611. }
  612. }
  613. void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
  614. const Type& result) {
  615. String key = "TimesEqual" + left.name() + right.name();
  616. if (fHelpers.find(key) == fHelpers.end()) {
  617. fExtraFunctions.printf("%s operator*=(thread %s& left, thread const %s& right) {\n"
  618. " left = left * right;\n"
  619. " return left;\n"
  620. "}", result.name().c_str(), left.name().c_str(),
  621. right.name().c_str());
  622. }
  623. }
  624. void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
  625. Precedence parentPrecedence) {
  626. Precedence precedence = GetBinaryPrecedence(b.fOperator);
  627. bool needParens = precedence >= parentPrecedence;
  628. switch (b.fOperator) {
  629. case Token::EQEQ:
  630. if (b.fLeft->fType.kind() == Type::kVector_Kind) {
  631. this->write("all");
  632. needParens = true;
  633. }
  634. break;
  635. case Token::NEQ:
  636. if (b.fLeft->fType.kind() == Type::kVector_Kind) {
  637. this->write("any");
  638. needParens = true;
  639. }
  640. break;
  641. default:
  642. break;
  643. }
  644. if (needParens) {
  645. this->write("(");
  646. }
  647. if (Compiler::IsAssignment(b.fOperator) &&
  648. Expression::kVariableReference_Kind == b.fLeft->fKind &&
  649. Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
  650. (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
  651. // writing to an out parameter. Since we have to turn those into pointers, we have to
  652. // dereference it here.
  653. this->write("*");
  654. }
  655. if (b.fOperator == Token::STAREQ && b.fLeft->fType.kind() == Type::kMatrix_Kind &&
  656. b.fRight->fType.kind() == Type::kMatrix_Kind) {
  657. this->writeMatrixTimesEqualHelper(b.fLeft->fType, b.fRight->fType, b.fType);
  658. }
  659. this->writeExpression(*b.fLeft, precedence);
  660. if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
  661. Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
  662. // This doesn't compile in Metal:
  663. // float4 x = float4(1);
  664. // x.xy *= float2x2(...);
  665. // with the error message "non-const reference cannot bind to vector element",
  666. // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
  667. // as long as the LHS has no side effects, and hope for the best otherwise.
  668. this->write(" = ");
  669. this->writeExpression(*b.fLeft, kAssignment_Precedence);
  670. this->write(" ");
  671. String op = Compiler::OperatorName(b.fOperator);
  672. SkASSERT(op.endsWith("="));
  673. this->write(op.substr(0, op.size() - 1).c_str());
  674. this->write(" ");
  675. } else {
  676. this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
  677. }
  678. this->writeExpression(*b.fRight, precedence);
  679. if (needParens) {
  680. this->write(")");
  681. }
  682. }
  683. void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
  684. Precedence parentPrecedence) {
  685. if (kTernary_Precedence >= parentPrecedence) {
  686. this->write("(");
  687. }
  688. this->writeExpression(*t.fTest, kTernary_Precedence);
  689. this->write(" ? ");
  690. this->writeExpression(*t.fIfTrue, kTernary_Precedence);
  691. this->write(" : ");
  692. this->writeExpression(*t.fIfFalse, kTernary_Precedence);
  693. if (kTernary_Precedence >= parentPrecedence) {
  694. this->write(")");
  695. }
  696. }
  697. void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
  698. Precedence parentPrecedence) {
  699. if (kPrefix_Precedence >= parentPrecedence) {
  700. this->write("(");
  701. }
  702. this->write(Compiler::OperatorName(p.fOperator));
  703. this->writeExpression(*p.fOperand, kPrefix_Precedence);
  704. if (kPrefix_Precedence >= parentPrecedence) {
  705. this->write(")");
  706. }
  707. }
  708. void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
  709. Precedence parentPrecedence) {
  710. if (kPostfix_Precedence >= parentPrecedence) {
  711. this->write("(");
  712. }
  713. this->writeExpression(*p.fOperand, kPostfix_Precedence);
  714. this->write(Compiler::OperatorName(p.fOperator));
  715. if (kPostfix_Precedence >= parentPrecedence) {
  716. this->write(")");
  717. }
  718. }
  719. void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
  720. this->write(b.fValue ? "true" : "false");
  721. }
  722. void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
  723. if (i.fType == *fContext.fUInt_Type) {
  724. this->write(to_string(i.fValue & 0xffffffff) + "u");
  725. } else {
  726. this->write(to_string((int32_t) i.fValue));
  727. }
  728. }
  729. void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
  730. this->write(to_string(f.fValue));
  731. }
  732. void MetalCodeGenerator::writeSetting(const Setting& s) {
  733. ABORT("internal error; setting was not folded to a constant during compilation\n");
  734. }
  735. void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
  736. const char* separator = "";
  737. if ("main" == f.fDeclaration.fName) {
  738. switch (fProgram.fKind) {
  739. case Program::kFragment_Kind:
  740. #ifdef SK_MOLTENVK
  741. this->write("fragment Outputs main0");
  742. #else
  743. this->write("fragment Outputs fragmentMain");
  744. #endif
  745. break;
  746. case Program::kVertex_Kind:
  747. #ifdef SK_MOLTENVK
  748. this->write("vertex Outputs main0");
  749. #else
  750. this->write("vertex Outputs vertexMain");
  751. #endif
  752. break;
  753. default:
  754. SkASSERT(false);
  755. }
  756. this->write("(Inputs _in [[stage_in]]");
  757. if (-1 != fUniformBuffer) {
  758. this->write(", constant Uniforms& _uniforms [[buffer(" +
  759. to_string(fUniformBuffer) + ")]]");
  760. }
  761. for (const auto& e : fProgram) {
  762. if (ProgramElement::kVar_Kind == e.fKind) {
  763. VarDeclarations& decls = (VarDeclarations&) e;
  764. if (!decls.fVars.size()) {
  765. continue;
  766. }
  767. for (const auto& stmt: decls.fVars) {
  768. VarDeclaration& var = (VarDeclaration&) *stmt;
  769. if (var.fVar->fType.kind() == Type::kSampler_Kind) {
  770. this->write(", texture2d<float> "); // FIXME - support other texture types
  771. this->writeName(var.fVar->fName);
  772. this->write("[[texture(");
  773. this->write(to_string(var.fVar->fModifiers.fLayout.fBinding));
  774. this->write(")]]");
  775. this->write(", sampler ");
  776. this->writeName(var.fVar->fName);
  777. this->write(SAMPLER_SUFFIX);
  778. this->write("[[sampler(");
  779. this->write(to_string(var.fVar->fModifiers.fLayout.fBinding));
  780. this->write(")]]");
  781. }
  782. }
  783. } else if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
  784. InterfaceBlock& intf = (InterfaceBlock&) e;
  785. if ("sk_PerVertex" == intf.fTypeName) {
  786. continue;
  787. }
  788. this->write(", constant ");
  789. this->writeType(intf.fVariable.fType);
  790. this->write("& " );
  791. this->write(fInterfaceBlockNameMap[&intf]);
  792. this->write(" [[buffer(");
  793. #ifdef SK_MOLTENVK
  794. this->write(to_string(intf.fVariable.fModifiers.fLayout.fSet));
  795. #else
  796. this->write(to_string(intf.fVariable.fModifiers.fLayout.fBinding));
  797. #endif
  798. this->write(")]]");
  799. }
  800. }
  801. if (fProgram.fKind == Program::kFragment_Kind) {
  802. if (fProgram.fInputs.fRTHeight && fInterfaceBlockNameMap.empty()) {
  803. #ifdef SK_MOLTENVK
  804. this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(0)]]");
  805. #else
  806. this->write(", constant sksl_synthetic_uniforms& _anonInterface0 [[buffer(1)]]");
  807. #endif
  808. }
  809. this->write(", bool _frontFacing [[front_facing]]");
  810. this->write(", float4 _fragCoord [[position]]");
  811. } else if (fProgram.fKind == Program::kVertex_Kind) {
  812. this->write(", uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]");
  813. }
  814. separator = ", ";
  815. } else {
  816. this->writeType(f.fDeclaration.fReturnType);
  817. this->write(" ");
  818. this->writeName(f.fDeclaration.fName);
  819. this->write("(");
  820. if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
  821. this->write("Inputs _in");
  822. separator = ", ";
  823. }
  824. if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
  825. this->write(separator);
  826. this->write("thread Outputs* _out");
  827. separator = ", ";
  828. }
  829. if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
  830. this->write(separator);
  831. this->write("Uniforms _uniforms");
  832. separator = ", ";
  833. }
  834. if (this->requirements(f.fDeclaration) & kGlobals_Requirement) {
  835. this->write(separator);
  836. this->write("thread Globals* _globals");
  837. separator = ", ";
  838. }
  839. }
  840. for (const auto& param : f.fDeclaration.fParameters) {
  841. this->write(separator);
  842. separator = ", ";
  843. this->writeModifiers(param->fModifiers, false);
  844. std::vector<int> sizes;
  845. const Type* type = &param->fType;
  846. while (Type::kArray_Kind == type->kind()) {
  847. sizes.push_back(type->columns());
  848. type = &type->componentType();
  849. }
  850. this->writeType(*type);
  851. if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
  852. this->write("*");
  853. }
  854. this->write(" ");
  855. this->writeName(param->fName);
  856. for (int s : sizes) {
  857. if (s <= 0) {
  858. this->write("[]");
  859. } else {
  860. this->write("[" + to_string(s) + "]");
  861. }
  862. }
  863. }
  864. this->writeLine(") {");
  865. SkASSERT(!fProgram.fSettings.fFragColorIsInOut);
  866. if ("main" == f.fDeclaration.fName) {
  867. if (fNeedsGlobalStructInit) {
  868. this->writeLine(" Globals globalStruct;");
  869. this->writeLine(" thread Globals* _globals = &globalStruct;");
  870. for (const auto& intf: fInterfaceBlockNameMap) {
  871. const auto& intfName = intf.second;
  872. this->write(" _globals->");
  873. this->writeName(intfName);
  874. this->write(" = &");
  875. this->writeName(intfName);
  876. this->write(";\n");
  877. }
  878. for (const auto& var: fInitNonConstGlobalVars) {
  879. this->write(" _globals->");
  880. this->writeName(var->fVar->fName);
  881. this->write(" = ");
  882. this->writeVarInitializer(*var->fVar, *var->fValue);
  883. this->writeLine(";");
  884. }
  885. for (const auto& texture: fTextures) {
  886. this->write(" _globals->");
  887. this->writeName(texture->fName);
  888. this->write(" = ");
  889. this->writeName(texture->fName);
  890. this->write(";\n");
  891. this->write(" _globals->");
  892. this->writeName(texture->fName);
  893. this->write(SAMPLER_SUFFIX);
  894. this->write(" = ");
  895. this->writeName(texture->fName);
  896. this->write(SAMPLER_SUFFIX);
  897. this->write(";\n");
  898. }
  899. }
  900. this->writeLine(" Outputs _outputStruct;");
  901. this->writeLine(" thread Outputs* _out = &_outputStruct;");
  902. }
  903. fFunctionHeader = "";
  904. OutputStream* oldOut = fOut;
  905. StringStream buffer;
  906. fOut = &buffer;
  907. fIndentation++;
  908. this->writeStatements(((Block&) *f.fBody).fStatements);
  909. if ("main" == f.fDeclaration.fName) {
  910. switch (fProgram.fKind) {
  911. case Program::kFragment_Kind:
  912. this->writeLine("return *_out;");
  913. break;
  914. case Program::kVertex_Kind:
  915. this->writeLine("_out->sk_Position.y = -_out->sk_Position.y;");
  916. this->writeLine("return *_out;"); // FIXME - detect if function already has return
  917. break;
  918. default:
  919. SkASSERT(false);
  920. }
  921. }
  922. fIndentation--;
  923. this->writeLine("}");
  924. fOut = oldOut;
  925. this->write(fFunctionHeader);
  926. this->write(buffer.str());
  927. }
  928. void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
  929. bool globalContext) {
  930. if (modifiers.fFlags & Modifiers::kOut_Flag) {
  931. this->write("thread ");
  932. }
  933. if (modifiers.fFlags & Modifiers::kConst_Flag) {
  934. this->write("constant ");
  935. }
  936. }
  937. void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
  938. if ("sk_PerVertex" == intf.fTypeName) {
  939. return;
  940. }
  941. this->writeModifiers(intf.fVariable.fModifiers, true);
  942. this->write("struct ");
  943. this->writeLine(intf.fTypeName + " {");
  944. const Type* structType = &intf.fVariable.fType;
  945. fWrittenStructs.push_back(structType);
  946. while (Type::kArray_Kind == structType->kind()) {
  947. structType = &structType->componentType();
  948. }
  949. fIndentation++;
  950. writeFields(structType->fields(), structType->fOffset, &intf);
  951. if (fProgram.fInputs.fRTHeight) {
  952. this->writeLine("float u_skRTHeight;");
  953. }
  954. fIndentation--;
  955. this->write("}");
  956. if (intf.fInstanceName.size()) {
  957. this->write(" ");
  958. this->write(intf.fInstanceName);
  959. for (const auto& size : intf.fSizes) {
  960. this->write("[");
  961. if (size) {
  962. this->writeExpression(*size, kTopLevel_Precedence);
  963. }
  964. this->write("]");
  965. }
  966. fInterfaceBlockNameMap[&intf] = intf.fInstanceName;
  967. } else {
  968. fInterfaceBlockNameMap[&intf] = "_anonInterface" + to_string(fAnonInterfaceCount++);
  969. }
  970. this->writeLine(";");
  971. }
  972. void MetalCodeGenerator::writeFields(const std::vector<Type::Field>& fields, int parentOffset,
  973. const InterfaceBlock* parentIntf) {
  974. #ifdef SK_MOLTENVK
  975. MemoryLayout memoryLayout(MemoryLayout::k140_Standard);
  976. #else
  977. MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
  978. #endif
  979. int currentOffset = 0;
  980. for (const auto& field: fields) {
  981. int fieldOffset = field.fModifiers.fLayout.fOffset;
  982. const Type* fieldType = field.fType;
  983. if (fieldOffset != -1) {
  984. if (currentOffset > fieldOffset) {
  985. fErrors.error(parentOffset,
  986. "offset of field '" + field.fName + "' must be at least " +
  987. to_string((int) currentOffset));
  988. } else if (currentOffset < fieldOffset) {
  989. this->write("char pad");
  990. this->write(to_string(fPaddingCount++));
  991. this->write("[");
  992. this->write(to_string(fieldOffset - currentOffset));
  993. this->writeLine("];");
  994. currentOffset = fieldOffset;
  995. }
  996. int alignment = memoryLayout.alignment(*fieldType);
  997. if (fieldOffset % alignment) {
  998. fErrors.error(parentOffset,
  999. "offset of field '" + field.fName + "' must be a multiple of " +
  1000. to_string((int) alignment));
  1001. }
  1002. }
  1003. #ifdef SK_MOLTENVK
  1004. if (fieldType->kind() == Type::kVector_Kind &&
  1005. fieldType->columns() == 3) {
  1006. SkASSERT(memoryLayout.size(*fieldType) == 3);
  1007. // Pack all vec3 types so that their size in bytes will match what was expected in the
  1008. // original SkSL code since MSL has vec3 sizes equal to 4 * component type, while SkSL
  1009. // has vec3 equal to 3 * component type.
  1010. // FIXME - Packed vectors can't be accessed by swizzles, but can be indexed into. A
  1011. // combination of this being a problem which only occurs when using MoltenVK and the
  1012. // fact that we haven't swizzled a vec3 yet means that this problem hasn't been
  1013. // addressed.
  1014. this->write(PACKED_PREFIX);
  1015. }
  1016. #endif
  1017. currentOffset += memoryLayout.size(*fieldType);
  1018. std::vector<int> sizes;
  1019. while (fieldType->kind() == Type::kArray_Kind) {
  1020. sizes.push_back(fieldType->columns());
  1021. fieldType = &fieldType->componentType();
  1022. }
  1023. this->writeModifiers(field.fModifiers, false);
  1024. this->writeType(*fieldType);
  1025. this->write(" ");
  1026. this->writeName(field.fName);
  1027. for (int s : sizes) {
  1028. if (s <= 0) {
  1029. this->write("[]");
  1030. } else {
  1031. this->write("[" + to_string(s) + "]");
  1032. }
  1033. }
  1034. this->writeLine(";");
  1035. if (parentIntf) {
  1036. fInterfaceBlockMap[&field] = parentIntf;
  1037. }
  1038. }
  1039. }
  1040. void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
  1041. this->writeExpression(value, kTopLevel_Precedence);
  1042. }
  1043. void MetalCodeGenerator::writeName(const String& name) {
  1044. if (fReservedWords.find(name) != fReservedWords.end()) {
  1045. this->write("_"); // adding underscore before name to avoid conflict with reserved words
  1046. }
  1047. this->write(name);
  1048. }
  1049. void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
  1050. SkASSERT(decl.fVars.size() > 0);
  1051. bool wroteType = false;
  1052. for (const auto& stmt : decl.fVars) {
  1053. VarDeclaration& var = (VarDeclaration&) *stmt;
  1054. if (global && !(var.fVar->fModifiers.fFlags & Modifiers::kConst_Flag)) {
  1055. continue;
  1056. }
  1057. if (wroteType) {
  1058. this->write(", ");
  1059. } else {
  1060. this->writeModifiers(var.fVar->fModifiers, global);
  1061. this->writeType(decl.fBaseType);
  1062. this->write(" ");
  1063. wroteType = true;
  1064. }
  1065. this->writeName(var.fVar->fName);
  1066. for (const auto& size : var.fSizes) {
  1067. this->write("[");
  1068. if (size) {
  1069. this->writeExpression(*size, kTopLevel_Precedence);
  1070. }
  1071. this->write("]");
  1072. }
  1073. if (var.fValue) {
  1074. this->write(" = ");
  1075. this->writeVarInitializer(*var.fVar, *var.fValue);
  1076. }
  1077. }
  1078. if (wroteType) {
  1079. this->write(";");
  1080. }
  1081. }
  1082. void MetalCodeGenerator::writeStatement(const Statement& s) {
  1083. switch (s.fKind) {
  1084. case Statement::kBlock_Kind:
  1085. this->writeBlock((Block&) s);
  1086. break;
  1087. case Statement::kExpression_Kind:
  1088. this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
  1089. this->write(";");
  1090. break;
  1091. case Statement::kReturn_Kind:
  1092. this->writeReturnStatement((ReturnStatement&) s);
  1093. break;
  1094. case Statement::kVarDeclarations_Kind:
  1095. this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
  1096. break;
  1097. case Statement::kIf_Kind:
  1098. this->writeIfStatement((IfStatement&) s);
  1099. break;
  1100. case Statement::kFor_Kind:
  1101. this->writeForStatement((ForStatement&) s);
  1102. break;
  1103. case Statement::kWhile_Kind:
  1104. this->writeWhileStatement((WhileStatement&) s);
  1105. break;
  1106. case Statement::kDo_Kind:
  1107. this->writeDoStatement((DoStatement&) s);
  1108. break;
  1109. case Statement::kSwitch_Kind:
  1110. this->writeSwitchStatement((SwitchStatement&) s);
  1111. break;
  1112. case Statement::kBreak_Kind:
  1113. this->write("break;");
  1114. break;
  1115. case Statement::kContinue_Kind:
  1116. this->write("continue;");
  1117. break;
  1118. case Statement::kDiscard_Kind:
  1119. this->write("discard_fragment();");
  1120. break;
  1121. case Statement::kNop_Kind:
  1122. this->write(";");
  1123. break;
  1124. default:
  1125. ABORT("unsupported statement: %s", s.description().c_str());
  1126. }
  1127. }
  1128. void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
  1129. for (const auto& s : statements) {
  1130. if (!s->isEmpty()) {
  1131. this->writeStatement(*s);
  1132. this->writeLine();
  1133. }
  1134. }
  1135. }
  1136. void MetalCodeGenerator::writeBlock(const Block& b) {
  1137. this->writeLine("{");
  1138. fIndentation++;
  1139. this->writeStatements(b.fStatements);
  1140. fIndentation--;
  1141. this->write("}");
  1142. }
  1143. void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
  1144. this->write("if (");
  1145. this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
  1146. this->write(") ");
  1147. this->writeStatement(*stmt.fIfTrue);
  1148. if (stmt.fIfFalse) {
  1149. this->write(" else ");
  1150. this->writeStatement(*stmt.fIfFalse);
  1151. }
  1152. }
  1153. void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
  1154. this->write("for (");
  1155. if (f.fInitializer && !f.fInitializer->isEmpty()) {
  1156. this->writeStatement(*f.fInitializer);
  1157. } else {
  1158. this->write("; ");
  1159. }
  1160. if (f.fTest) {
  1161. this->writeExpression(*f.fTest, kTopLevel_Precedence);
  1162. }
  1163. this->write("; ");
  1164. if (f.fNext) {
  1165. this->writeExpression(*f.fNext, kTopLevel_Precedence);
  1166. }
  1167. this->write(") ");
  1168. this->writeStatement(*f.fStatement);
  1169. }
  1170. void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
  1171. this->write("while (");
  1172. this->writeExpression(*w.fTest, kTopLevel_Precedence);
  1173. this->write(") ");
  1174. this->writeStatement(*w.fStatement);
  1175. }
  1176. void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
  1177. this->write("do ");
  1178. this->writeStatement(*d.fStatement);
  1179. this->write(" while (");
  1180. this->writeExpression(*d.fTest, kTopLevel_Precedence);
  1181. this->write(");");
  1182. }
  1183. void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
  1184. this->write("switch (");
  1185. this->writeExpression(*s.fValue, kTopLevel_Precedence);
  1186. this->writeLine(") {");
  1187. fIndentation++;
  1188. for (const auto& c : s.fCases) {
  1189. if (c->fValue) {
  1190. this->write("case ");
  1191. this->writeExpression(*c->fValue, kTopLevel_Precedence);
  1192. this->writeLine(":");
  1193. } else {
  1194. this->writeLine("default:");
  1195. }
  1196. fIndentation++;
  1197. for (const auto& stmt : c->fStatements) {
  1198. this->writeStatement(*stmt);
  1199. this->writeLine();
  1200. }
  1201. fIndentation--;
  1202. }
  1203. fIndentation--;
  1204. this->write("}");
  1205. }
  1206. void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
  1207. this->write("return");
  1208. if (r.fExpression) {
  1209. this->write(" ");
  1210. this->writeExpression(*r.fExpression, kTopLevel_Precedence);
  1211. }
  1212. this->write(";");
  1213. }
  1214. void MetalCodeGenerator::writeHeader() {
  1215. this->write("#include <metal_stdlib>\n");
  1216. this->write("#include <simd/simd.h>\n");
  1217. this->write("using namespace metal;\n");
  1218. }
  1219. void MetalCodeGenerator::writeUniformStruct() {
  1220. for (const auto& e : fProgram) {
  1221. if (ProgramElement::kVar_Kind == e.fKind) {
  1222. VarDeclarations& decls = (VarDeclarations&) e;
  1223. if (!decls.fVars.size()) {
  1224. continue;
  1225. }
  1226. const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
  1227. if (first.fModifiers.fFlags & Modifiers::kUniform_Flag &&
  1228. first.fType.kind() != Type::kSampler_Kind) {
  1229. if (-1 == fUniformBuffer) {
  1230. this->write("struct Uniforms {\n");
  1231. fUniformBuffer = first.fModifiers.fLayout.fSet;
  1232. if (-1 == fUniformBuffer) {
  1233. fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
  1234. }
  1235. } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
  1236. if (-1 == fUniformBuffer) {
  1237. fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
  1238. "the same 'layout(set=...)'");
  1239. }
  1240. }
  1241. this->write(" ");
  1242. this->writeType(first.fType);
  1243. this->write(" ");
  1244. for (const auto& stmt : decls.fVars) {
  1245. VarDeclaration& var = (VarDeclaration&) *stmt;
  1246. this->writeName(var.fVar->fName);
  1247. }
  1248. this->write(";\n");
  1249. }
  1250. }
  1251. }
  1252. if (-1 != fUniformBuffer) {
  1253. this->write("};\n");
  1254. }
  1255. }
  1256. void MetalCodeGenerator::writeInputStruct() {
  1257. this->write("struct Inputs {\n");
  1258. for (const auto& e : fProgram) {
  1259. if (ProgramElement::kVar_Kind == e.fKind) {
  1260. VarDeclarations& decls = (VarDeclarations&) e;
  1261. if (!decls.fVars.size()) {
  1262. continue;
  1263. }
  1264. const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
  1265. if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
  1266. -1 == first.fModifiers.fLayout.fBuiltin) {
  1267. this->write(" ");
  1268. this->writeType(first.fType);
  1269. this->write(" ");
  1270. for (const auto& stmt : decls.fVars) {
  1271. VarDeclaration& var = (VarDeclaration&) *stmt;
  1272. this->writeName(var.fVar->fName);
  1273. if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
  1274. if (fProgram.fKind == Program::kVertex_Kind) {
  1275. this->write(" [[attribute(" +
  1276. to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
  1277. } else if (fProgram.fKind == Program::kFragment_Kind) {
  1278. this->write(" [[user(locn" +
  1279. to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
  1280. }
  1281. }
  1282. }
  1283. this->write(";\n");
  1284. }
  1285. }
  1286. }
  1287. this->write("};\n");
  1288. }
  1289. void MetalCodeGenerator::writeOutputStruct() {
  1290. this->write("struct Outputs {\n");
  1291. if (fProgram.fKind == Program::kVertex_Kind) {
  1292. this->write(" float4 sk_Position [[position]];\n");
  1293. } else if (fProgram.fKind == Program::kFragment_Kind) {
  1294. this->write(" float4 sk_FragColor [[color(0)]];\n");
  1295. }
  1296. for (const auto& e : fProgram) {
  1297. if (ProgramElement::kVar_Kind == e.fKind) {
  1298. VarDeclarations& decls = (VarDeclarations&) e;
  1299. if (!decls.fVars.size()) {
  1300. continue;
  1301. }
  1302. const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
  1303. if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
  1304. -1 == first.fModifiers.fLayout.fBuiltin) {
  1305. this->write(" ");
  1306. this->writeType(first.fType);
  1307. this->write(" ");
  1308. for (const auto& stmt : decls.fVars) {
  1309. VarDeclaration& var = (VarDeclaration&) *stmt;
  1310. this->writeName(var.fVar->fName);
  1311. if (fProgram.fKind == Program::kVertex_Kind) {
  1312. this->write(" [[user(locn" +
  1313. to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
  1314. } else if (fProgram.fKind == Program::kFragment_Kind) {
  1315. this->write(" [[color(" +
  1316. to_string(var.fVar->fModifiers.fLayout.fLocation) +")");
  1317. int colorIndex = var.fVar->fModifiers.fLayout.fIndex;
  1318. if (colorIndex) {
  1319. this->write(", index(" + to_string(colorIndex) + ")");
  1320. }
  1321. this->write("]]");
  1322. }
  1323. }
  1324. this->write(";\n");
  1325. }
  1326. }
  1327. }
  1328. if (fProgram.fKind == Program::kVertex_Kind) {
  1329. this->write(" float sk_PointSize;\n");
  1330. }
  1331. this->write("};\n");
  1332. }
  1333. void MetalCodeGenerator::writeInterfaceBlocks() {
  1334. bool wroteInterfaceBlock = false;
  1335. for (const auto& e : fProgram) {
  1336. if (ProgramElement::kInterfaceBlock_Kind == e.fKind) {
  1337. this->writeInterfaceBlock((InterfaceBlock&) e);
  1338. wroteInterfaceBlock = true;
  1339. }
  1340. }
  1341. if (!wroteInterfaceBlock && fProgram.fInputs.fRTHeight) {
  1342. this->writeLine("struct sksl_synthetic_uniforms {");
  1343. this->writeLine(" float u_skRTHeight;");
  1344. this->writeLine("};");
  1345. }
  1346. }
  1347. void MetalCodeGenerator::writeGlobalStruct() {
  1348. bool wroteStructDecl = false;
  1349. for (const auto& intf : fInterfaceBlockNameMap) {
  1350. if (!wroteStructDecl) {
  1351. this->write("struct Globals {\n");
  1352. wroteStructDecl = true;
  1353. }
  1354. fNeedsGlobalStructInit = true;
  1355. const auto& intfType = intf.first;
  1356. const auto& intfName = intf.second;
  1357. this->write(" constant ");
  1358. this->write(intfType->fTypeName);
  1359. this->write("* ");
  1360. this->writeName(intfName);
  1361. this->write(";\n");
  1362. }
  1363. for (const auto& e : fProgram) {
  1364. if (ProgramElement::kVar_Kind == e.fKind) {
  1365. VarDeclarations& decls = (VarDeclarations&) e;
  1366. if (!decls.fVars.size()) {
  1367. continue;
  1368. }
  1369. const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
  1370. if ((!first.fModifiers.fFlags && -1 == first.fModifiers.fLayout.fBuiltin) ||
  1371. first.fType.kind() == Type::kSampler_Kind) {
  1372. if (!wroteStructDecl) {
  1373. this->write("struct Globals {\n");
  1374. wroteStructDecl = true;
  1375. }
  1376. fNeedsGlobalStructInit = true;
  1377. this->write(" ");
  1378. this->writeType(first.fType);
  1379. this->write(" ");
  1380. for (const auto& stmt : decls.fVars) {
  1381. VarDeclaration& var = (VarDeclaration&) *stmt;
  1382. this->writeName(var.fVar->fName);
  1383. if (var.fVar->fType.kind() == Type::kSampler_Kind) {
  1384. fTextures.push_back(var.fVar);
  1385. this->write(";\n");
  1386. this->write(" sampler ");
  1387. this->writeName(var.fVar->fName);
  1388. this->write(SAMPLER_SUFFIX);
  1389. }
  1390. if (var.fValue) {
  1391. fInitNonConstGlobalVars.push_back(&var);
  1392. }
  1393. }
  1394. this->write(";\n");
  1395. }
  1396. }
  1397. }
  1398. if (wroteStructDecl) {
  1399. this->write("};\n");
  1400. }
  1401. }
  1402. void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
  1403. switch (e.fKind) {
  1404. case ProgramElement::kExtension_Kind:
  1405. break;
  1406. case ProgramElement::kVar_Kind: {
  1407. VarDeclarations& decl = (VarDeclarations&) e;
  1408. if (decl.fVars.size() > 0) {
  1409. int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
  1410. if (-1 == builtin) {
  1411. // normal var
  1412. this->writeVarDeclarations(decl, true);
  1413. this->writeLine();
  1414. } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
  1415. // ignore
  1416. }
  1417. }
  1418. break;
  1419. }
  1420. case ProgramElement::kInterfaceBlock_Kind:
  1421. // handled in writeInterfaceBlocks, do nothing
  1422. break;
  1423. case ProgramElement::kFunction_Kind:
  1424. this->writeFunction((FunctionDefinition&) e);
  1425. break;
  1426. case ProgramElement::kModifiers_Kind:
  1427. this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
  1428. this->writeLine(";");
  1429. break;
  1430. default:
  1431. printf("%s\n", e.description().c_str());
  1432. ABORT("unsupported program element");
  1433. }
  1434. }
  1435. MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
  1436. switch (e.fKind) {
  1437. case Expression::kFunctionCall_Kind: {
  1438. const FunctionCall& f = (const FunctionCall&) e;
  1439. Requirements result = this->requirements(f.fFunction);
  1440. for (const auto& e : f.fArguments) {
  1441. result |= this->requirements(*e);
  1442. }
  1443. return result;
  1444. }
  1445. case Expression::kConstructor_Kind: {
  1446. const Constructor& c = (const Constructor&) e;
  1447. Requirements result = kNo_Requirements;
  1448. for (const auto& e : c.fArguments) {
  1449. result |= this->requirements(*e);
  1450. }
  1451. return result;
  1452. }
  1453. case Expression::kFieldAccess_Kind: {
  1454. const FieldAccess& f = (const FieldAccess&) e;
  1455. if (FieldAccess::kAnonymousInterfaceBlock_OwnerKind == f.fOwnerKind) {
  1456. return kGlobals_Requirement;
  1457. }
  1458. return this->requirements(*((const FieldAccess&) e).fBase);
  1459. }
  1460. case Expression::kSwizzle_Kind:
  1461. return this->requirements(*((const Swizzle&) e).fBase);
  1462. case Expression::kBinary_Kind: {
  1463. const BinaryExpression& b = (const BinaryExpression&) e;
  1464. return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
  1465. }
  1466. case Expression::kIndex_Kind: {
  1467. const IndexExpression& idx = (const IndexExpression&) e;
  1468. return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
  1469. }
  1470. case Expression::kPrefix_Kind:
  1471. return this->requirements(*((const PrefixExpression&) e).fOperand);
  1472. case Expression::kPostfix_Kind:
  1473. return this->requirements(*((const PostfixExpression&) e).fOperand);
  1474. case Expression::kTernary_Kind: {
  1475. const TernaryExpression& t = (const TernaryExpression&) e;
  1476. return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
  1477. this->requirements(*t.fIfFalse);
  1478. }
  1479. case Expression::kVariableReference_Kind: {
  1480. const VariableReference& v = (const VariableReference&) e;
  1481. Requirements result = kNo_Requirements;
  1482. if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
  1483. result = kInputs_Requirement;
  1484. } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
  1485. if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
  1486. result = kInputs_Requirement;
  1487. } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
  1488. result = kOutputs_Requirement;
  1489. } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag &&
  1490. v.fVariable.fType.kind() != Type::kSampler_Kind) {
  1491. result = kUniforms_Requirement;
  1492. } else {
  1493. result = kGlobals_Requirement;
  1494. }
  1495. }
  1496. return result;
  1497. }
  1498. default:
  1499. return kNo_Requirements;
  1500. }
  1501. }
  1502. MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
  1503. switch (s.fKind) {
  1504. case Statement::kBlock_Kind: {
  1505. Requirements result = kNo_Requirements;
  1506. for (const auto& child : ((const Block&) s).fStatements) {
  1507. result |= this->requirements(*child);
  1508. }
  1509. return result;
  1510. }
  1511. case Statement::kVarDeclaration_Kind: {
  1512. Requirements result = kNo_Requirements;
  1513. const VarDeclaration& var = (const VarDeclaration&) s;
  1514. if (var.fValue) {
  1515. result = this->requirements(*var.fValue);
  1516. }
  1517. return result;
  1518. }
  1519. case Statement::kVarDeclarations_Kind: {
  1520. Requirements result = kNo_Requirements;
  1521. const VarDeclarations& decls = *((const VarDeclarationsStatement&) s).fDeclaration;
  1522. for (const auto& stmt : decls.fVars) {
  1523. result |= this->requirements(*stmt);
  1524. }
  1525. return result;
  1526. }
  1527. case Statement::kExpression_Kind:
  1528. return this->requirements(*((const ExpressionStatement&) s).fExpression);
  1529. case Statement::kReturn_Kind: {
  1530. const ReturnStatement& r = (const ReturnStatement&) s;
  1531. if (r.fExpression) {
  1532. return this->requirements(*r.fExpression);
  1533. }
  1534. return kNo_Requirements;
  1535. }
  1536. case Statement::kIf_Kind: {
  1537. const IfStatement& i = (const IfStatement&) s;
  1538. return this->requirements(*i.fTest) |
  1539. this->requirements(*i.fIfTrue) |
  1540. (i.fIfFalse && this->requirements(*i.fIfFalse));
  1541. }
  1542. case Statement::kFor_Kind: {
  1543. const ForStatement& f = (const ForStatement&) s;
  1544. return this->requirements(*f.fInitializer) |
  1545. this->requirements(*f.fTest) |
  1546. this->requirements(*f.fNext) |
  1547. this->requirements(*f.fStatement);
  1548. }
  1549. case Statement::kWhile_Kind: {
  1550. const WhileStatement& w = (const WhileStatement&) s;
  1551. return this->requirements(*w.fTest) |
  1552. this->requirements(*w.fStatement);
  1553. }
  1554. case Statement::kDo_Kind: {
  1555. const DoStatement& d = (const DoStatement&) s;
  1556. return this->requirements(*d.fTest) |
  1557. this->requirements(*d.fStatement);
  1558. }
  1559. case Statement::kSwitch_Kind: {
  1560. const SwitchStatement& sw = (const SwitchStatement&) s;
  1561. Requirements result = this->requirements(*sw.fValue);
  1562. for (const auto& c : sw.fCases) {
  1563. for (const auto& st : c->fStatements) {
  1564. result |= this->requirements(*st);
  1565. }
  1566. }
  1567. return result;
  1568. }
  1569. default:
  1570. return kNo_Requirements;
  1571. }
  1572. }
  1573. MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
  1574. if (f.fBuiltin) {
  1575. return kNo_Requirements;
  1576. }
  1577. auto found = fRequirements.find(&f);
  1578. if (found == fRequirements.end()) {
  1579. fRequirements[&f] = kNo_Requirements;
  1580. for (const auto& e : fProgram) {
  1581. if (ProgramElement::kFunction_Kind == e.fKind) {
  1582. const FunctionDefinition& def = (const FunctionDefinition&) e;
  1583. if (&def.fDeclaration == &f) {
  1584. Requirements reqs = this->requirements(*def.fBody);
  1585. fRequirements[&f] = reqs;
  1586. return reqs;
  1587. }
  1588. }
  1589. }
  1590. }
  1591. return found->second;
  1592. }
  1593. bool MetalCodeGenerator::generateCode() {
  1594. OutputStream* rawOut = fOut;
  1595. fOut = &fHeader;
  1596. #ifdef SK_MOLTENVK
  1597. fOut->write((const char*) &MVKMagicNum, sizeof(MVKMagicNum));
  1598. #endif
  1599. fProgramKind = fProgram.fKind;
  1600. this->writeHeader();
  1601. this->writeUniformStruct();
  1602. this->writeInputStruct();
  1603. this->writeOutputStruct();
  1604. this->writeInterfaceBlocks();
  1605. this->writeGlobalStruct();
  1606. StringStream body;
  1607. fOut = &body;
  1608. for (const auto& e : fProgram) {
  1609. this->writeProgramElement(e);
  1610. }
  1611. fOut = rawOut;
  1612. write_stringstream(fHeader, *rawOut);
  1613. write_stringstream(fExtraFunctions, *rawOut);
  1614. write_stringstream(body, *rawOut);
  1615. #ifdef SK_MOLTENVK
  1616. this->write("\0");
  1617. #endif
  1618. return true;
  1619. }
  1620. }