12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931 |
- /*
- * Copyright 2018 Google Inc.
- *
- * Use of this source code is governed by a BSD-style license that can be
- * found in the LICENSE file.
- */
- #ifndef SKSL_STANDALONE
- #ifdef SK_LLVM_AVAILABLE
- #include "src/sksl/SkSLJIT.h"
- #include "src/core/SkCpu.h"
- #include "src/core/SkRasterPipeline.h"
- #include "src/sksl/ir/SkSLAppendStage.h"
- #include "src/sksl/ir/SkSLExpressionStatement.h"
- #include "src/sksl/ir/SkSLFunctionCall.h"
- #include "src/sksl/ir/SkSLFunctionReference.h"
- #include "src/sksl/ir/SkSLIndexExpression.h"
- #include "src/sksl/ir/SkSLProgram.h"
- #include "src/sksl/ir/SkSLUnresolvedFunction.h"
- #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
- static constexpr int MAX_VECTOR_COUNT = 16; // NOTE(review): not referenced in the visible part of this file; presumably an upper bound on SIMD lanes per vector - confirm against its users
- extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
- p->append((SkRasterPipeline::StockStage) stage, ctx);
- }
- #define PTR_SIZE sizeof(void*) // size in bytes of a native data pointer on the build target
- extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
- p->append(fn, nullptr);
- }
/**
 * C-linkage debugging aid: writes "Debug: <value>" followed by a newline to stdout.
 */
extern "C" void sksl_debug_print(float f) {
    // A float is promoted to double when passed through printf's varargs; make that explicit.
    const double promoted = f;
    printf("Debug: %f\n", promoted);
}
- extern "C" float sksl_clamp1(float f, float min, float max) {
- return SkTPin(f, min, max);
- }
- using float2 = __attribute__((vector_size(8))) float; // 8-byte compiler vector of floats (argument/return type of sksl_clamp2)
- using float3 = __attribute__((vector_size(16))) float; // 16-byte vector, the same size as float4; sksl_clamp3 only fills the first three elements
- using float4 = __attribute__((vector_size(16))) float; // 16-byte compiler vector of floats (argument/return type of sksl_clamp4)
- extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
- return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
- }
- extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
- return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
- }
- extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
- return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
- SkTPin(f[3], min, max) };
- }
- namespace SkSL {
- static constexpr int STAGE_PARAM_COUNT = 12; // NOTE(review): not referenced in the visible part of this file; presumably the number of parameters of a raster pipeline stage function - confirm
- static bool ends_with_branch(const Statement& stmt) {
- switch (stmt.fKind) {
- case Statement::kBlock_Kind: {
- const Block& b = (const Block&) stmt;
- if (b.fStatements.size()) {
- return ends_with_branch(*b.fStatements.back());
- }
- return false;
- }
- case Statement::kBreak_Kind: // fall through
- case Statement::kContinue_Kind: // fall through
- case Statement::kReturn_Kind: // fall through
- return true;
- default:
- return false;
- }
- }
- JIT::JIT(Compiler* compiler)
- : fCompiler(*compiler) {
- LLVMInitializeNativeTarget();
- LLVMInitializeNativeAsmPrinter();
- LLVMLinkInMCJIT();
- SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
- if (SkCpu::Supports(SkCpu::HSW)) {
- fVectorCount = 8;
- fCPU = "haswell";
- } else if (SkCpu::Supports(SkCpu::AVX)) {
- fVectorCount = 8;
- fCPU = "ivybridge";
- } else {
- fVectorCount = 4;
- fCPU = nullptr;
- }
- fContext = LLVMContextCreate();
- fVoidType = LLVMVoidTypeInContext(fContext);
- fInt1Type = LLVMInt1TypeInContext(fContext);
- fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
- fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
- fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
- fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
- fInt8Type = LLVMInt8TypeInContext(fContext);
- fInt8PtrType = LLVMPointerType(fInt8Type, 0);
- fInt32Type = LLVMInt32TypeInContext(fContext);
- fInt64Type = LLVMInt64TypeInContext(fContext);
- fSizeTType = LLVMInt64TypeInContext(fContext);
- fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
- fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
- fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
- fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
- fFloat32Type = LLVMFloatTypeInContext(fContext);
- fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
- fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
- fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
- fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
- }
- JIT::~JIT() {
- LLVMOrcDisposeInstance(fJITStack);
- LLVMContextDispose(fContext);
- }
- void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
- std::vector<LLVMTypeRef> parameters) {
- bool found = false;
- for (const auto& pair : *fProgram->fSymbols) {
- if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
- const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
- if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
- parameters.size() != f.fParameters.size()) {
- continue;
- }
- for (size_t i = 0; i < parameters.size(); ++i) {
- if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
- goto next;
- }
- }
- fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
- parameters.data(),
- parameters.size(),
- false));
- found = true;
- }
- if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
- // FIXME consolidate this with the code above
- for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
- if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
- parameters.size() != f->fParameters.size()) {
- continue;
- }
- for (size_t i = 0; i < parameters.size(); ++i) {
- if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
- goto next;
- }
- }
- fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
- returnType,
- parameters.data(),
- parameters.size(),
- false));
- found = true;
- }
- }
- next:;
- }
- SkASSERT(found);
- }
- void JIT::loadBuiltinFunctions() {
- this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
- fFloat32Type,
- fFloat32Type });
- this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
- }
- uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
- LLVMOrcTargetAddress result;
- if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
- if (!strcmp(name, "_sksl_pipeline_append")) {
- result = (uint64_t) &sksl_pipeline_append;
- } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
- result = (uint64_t) &sksl_pipeline_append_callback;
- } else if (!strcmp(name, "_sksl_clamp1")) {
- result = (uint64_t) &sksl_clamp1;
- } else if (!strcmp(name, "_sksl_clamp2")) {
- result = (uint64_t) &sksl_clamp2;
- } else if (!strcmp(name, "_sksl_clamp3")) {
- result = (uint64_t) &sksl_clamp3;
- } else if (!strcmp(name, "_sksl_clamp4")) {
- result = (uint64_t) &sksl_clamp4;
- } else if (!strcmp(name, "_sksl_debug_print")) {
- result = (uint64_t) &sksl_debug_print;
- } else {
- result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
- }
- }
- SkASSERT(result);
- return result;
- }
- LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
- LLVMValueRef func = fFunctions[&fc.fFunction];
- SkASSERT(func);
- std::vector<LLVMValueRef> parameters;
- for (const auto& a : fc.fArguments) {
- parameters.push_back(this->compileExpression(builder, *a));
- }
- return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
- }
- LLVMTypeRef JIT::getType(const Type& type) {
- switch (type.kind()) {
- case Type::kOther_Kind:
- if (type.name() == "void") {
- return fVoidType;
- }
- SkASSERT(type.name() == "SkRasterPipeline");
- return fInt8PtrType;
- case Type::kScalar_Kind:
- if (type.isSigned() || type.isUnsigned()) {
- return fInt32Type;
- }
- if (type.isUnsigned()) {
- return fInt32Type;
- }
- if (type.isFloat()) {
- return fFloat32Type;
- }
- SkASSERT(type.name() == "bool");
- return fInt1Type;
- case Type::kArray_Kind:
- return LLVMPointerType(this->getType(type.componentType()), 0);
- case Type::kVector_Kind:
- if (type.name() == "float2" || type.name() == "half2") {
- return fFloat32Vector2Type;
- }
- if (type.name() == "float3" || type.name() == "half3") {
- return fFloat32Vector3Type;
- }
- if (type.name() == "float4" || type.name() == "half4") {
- return fFloat32Vector4Type;
- }
- if (type.name() == "int2" || type.name() == "short2" || type.name == "byte2") {
- return fInt32Vector2Type;
- }
- if (type.name() == "int3" || type.name() == "short3" || type.name == "byte3") {
- return fInt32Vector3Type;
- }
- if (type.name() == "int4" || type.name() == "short4" || type.name == "byte3") {
- return fInt32Vector4Type;
- }
- // fall through
- default:
- ABORT("unsupported type");
- }
- }
- void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
- fCurrentBlock = block;
- LLVMPositionBuilderAtEnd(builder, block);
- }
- std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
- switch (expr.fKind) {
- case Expression::kVariableReference_Kind: {
- class PointerLValue : public LValue {
- public:
- PointerLValue(LLVMValueRef ptr)
- : fPointer(ptr) {}
- LLVMValueRef load(LLVMBuilderRef builder) override {
- return LLVMBuildLoad(builder, fPointer, "lvalue load");
- }
- void store(LLVMBuilderRef builder, LLVMValueRef value) override {
- LLVMBuildStore(builder, value, fPointer);
- }
- private:
- LLVMValueRef fPointer;
- };
- const Variable* var = &((VariableReference&) expr).fVariable;
- if (var->fStorage == Variable::kParameter_Storage &&
- !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
- fPromotedParameters.find(var) == fPromotedParameters.end()) {
- // promote parameter to variable
- fPromotedParameters.insert(var);
- LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
- LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
- String(var->fName).c_str());
- LLVMBuildStore(builder, fVariables[var], alloca);
- LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
- fVariables[var] = alloca;
- }
- LLVMValueRef ptr = fVariables[var];
- return std::unique_ptr<LValue>(new PointerLValue(ptr));
- }
- case Expression::kTernary_Kind: {
- class TernaryLValue : public LValue {
- public:
- TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
- std::unique_ptr<LValue> ifFalse)
- : fJIT(*jit)
- , fTest(test)
- , fIfTrue(std::move(ifTrue))
- , fIfFalse(std::move(ifFalse)) {}
- LLVMValueRef load(LLVMBuilderRef builder) override {
- LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
- fJIT.fContext,
- fJIT.fCurrentFunction,
- "true ? ...");
- LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
- fJIT.fContext,
- fJIT.fCurrentFunction,
- "false ? ...");
- LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
- fJIT.fCurrentFunction,
- "ternary merge");
- LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
- fJIT.setBlock(builder, trueBlock);
- LLVMValueRef ifTrue = fIfTrue->load(builder);
- LLVMBuildBr(builder, merge);
- fJIT.setBlock(builder, falseBlock);
- LLVMValueRef ifFalse = fIfTrue->load(builder);
- LLVMBuildBr(builder, merge);
- fJIT.setBlock(builder, merge);
- LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
- LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
- LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
- LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
- LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
- return phi;
- }
- void store(LLVMBuilderRef builder, LLVMValueRef value) override {
- LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
- fJIT.fContext,
- fJIT.fCurrentFunction,
- "true ? ...");
- LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
- fJIT.fContext,
- fJIT.fCurrentFunction,
- "false ? ...");
- LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
- fJIT.fCurrentFunction,
- "ternary merge");
- LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
- fJIT.setBlock(builder, trueBlock);
- fIfTrue->store(builder, value);
- LLVMBuildBr(builder, merge);
- fJIT.setBlock(builder, falseBlock);
- fIfTrue->store(builder, value);
- LLVMBuildBr(builder, merge);
- fJIT.setBlock(builder, merge);
- }
- private:
- JIT& fJIT;
- LLVMValueRef fTest;
- std::unique_ptr<LValue> fIfTrue;
- std::unique_ptr<LValue> fIfFalse;
- };
- const TernaryExpression& t = (const TernaryExpression&) expr;
- LLVMValueRef test = this->compileExpression(builder, *t.fTest);
- return std::unique_ptr<LValue>(new TernaryLValue(this,
- test,
- this->getLValue(builder,
- *t.fIfTrue),
- this->getLValue(builder,
- *t.fIfFalse)));
- }
- case Expression::kSwizzle_Kind: {
- class SwizzleLValue : public LValue {
- public:
- SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
- std::vector<int> components)
- : fJIT(*jit)
- , fType(type)
- , fBase(std::move(base))
- , fComponents(components) {}
- LLVMValueRef load(LLVMBuilderRef builder) override {
- LLVMValueRef base = fBase->load(builder);
- if (fComponents.size() > 1) {
- LLVMValueRef result = LLVMGetUndef(fType);
- for (size_t i = 0; i < fComponents.size(); ++i) {
- LLVMValueRef element = LLVMBuildExtractElement(
- builder,
- base,
- LLVMConstInt(fJIT.fInt32Type,
- fComponents[i],
- false),
- "swizzle extract");
- result = LLVMBuildInsertElement(builder, result, element,
- LLVMConstInt(fJIT.fInt32Type, i, false),
- "swizzle insert");
- }
- return result;
- }
- SkASSERT(fComponents.size() == 1);
- return LLVMBuildExtractElement(builder, base,
- LLVMConstInt(fJIT.fInt32Type,
- fComponents[0],
- false),
- "swizzle extract");
- }
- void store(LLVMBuilderRef builder, LLVMValueRef value) override {
- LLVMValueRef result = fBase->load(builder);
- if (fComponents.size() > 1) {
- for (size_t i = 0; i < fComponents.size(); ++i) {
- LLVMValueRef element = LLVMBuildExtractElement(builder, value,
- LLVMConstInt(
- fJIT.fInt32Type,
- i,
- false),
- "swizzle extract");
- result = LLVMBuildInsertElement(builder, result, element,
- LLVMConstInt(fJIT.fInt32Type,
- fComponents[i],
- false),
- "swizzle insert");
- }
- } else {
- result = LLVMBuildInsertElement(builder, result, value,
- LLVMConstInt(fJIT.fInt32Type,
- fComponents[0],
- false),
- "swizzle insert");
- }
- fBase->store(builder, result);
- }
- private:
- JIT& fJIT;
- LLVMTypeRef fType;
- std::unique_ptr<LValue> fBase;
- std::vector<int> fComponents;
- };
- const Swizzle& s = (const Swizzle&) expr;
- return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
- this->getLValue(builder, *s.fBase),
- s.fComponents));
- }
- default:
- ABORT("unsupported lvalue");
- }
- }
- JIT::TypeKind JIT::typeKind(const Type& type) {
- if (type.kind() == Type::kVector_Kind) {
- return this->typeKind(type.componentType());
- }
- if (type.fName == "int" || type.fName == "short" || type.fName == "byte") {
- return JIT::kInt_TypeKind;
- } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
- return JIT::kUInt_TypeKind;
- } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
- return JIT::kFloat_TypeKind;
- }
- ABORT("unsupported type: %s\n", type.description().c_str());
- }
- void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
- LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
- for (int i = 0; i < columns; ++i) {
- result = LLVMBuildInsertElement(builder,
- result,
- *value,
- LLVMConstInt(fInt32Type, i, false),
- "vectorize");
- }
- *value = result;
- }
- void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
- LLVMValueRef* right) {
- if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
- b.fRight->fType.kind() == Type::kVector_Kind) {
- this->vectorize(builder, left, b.fRight->fType.columns());
- } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
- b.fRight->fType.kind() == Type::kScalar_Kind) {
- this->vectorize(builder, right, b.fLeft->fType.columns());
- }
- }
- LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
- #define BINARY(SFunc, UFunc, FFunc) { \
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
- LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
- this->vectorize(builder, b, &left, &right); \
- switch (this->typeKind(b.fLeft->fType)) { \
- case kInt_TypeKind: \
- return SFunc(builder, left, right, "binary"); \
- case kUInt_TypeKind: \
- return UFunc(builder, left, right, "binary"); \
- case kFloat_TypeKind: \
- return FFunc(builder, left, right, "binary"); \
- default: \
- ABORT("unsupported typeKind"); \
- } \
- }
- #define COMPOUND(SFunc, UFunc, FFunc) { \
- std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
- LLVMValueRef left = lvalue->load(builder); \
- LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
- this->vectorize(builder, b, &left, &right); \
- LLVMValueRef result; \
- switch (this->typeKind(b.fLeft->fType)) { \
- case kInt_TypeKind: \
- result = SFunc(builder, left, right, "binary"); \
- break; \
- case kUInt_TypeKind: \
- result = UFunc(builder, left, right, "binary"); \
- break; \
- case kFloat_TypeKind: \
- result = FFunc(builder, left, right, "binary"); \
- break; \
- default: \
- ABORT("unsupported typeKind"); \
- } \
- lvalue->store(builder, result); \
- return result; \
- }
- #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
- LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
- this->vectorize(builder, b, &left, &right); \
- switch (this->typeKind(b.fLeft->fType)) { \
- case kInt_TypeKind: \
- return SFunc(builder, SOp, left, right, "binary"); \
- case kUInt_TypeKind: \
- return UFunc(builder, UOp, left, right, "binary"); \
- case kFloat_TypeKind: \
- return FFunc(builder, FOp, left, right, "binary"); \
- default: \
- ABORT("unsupported typeKind"); \
- } \
- }
- switch (b.fOperator) {
- case Token::EQ: {
- std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
- LLVMValueRef result = this->compileExpression(builder, *b.fRight);
- lvalue->store(builder, result);
- return result;
- }
- case Token::PLUS:
- BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
- case Token::MINUS:
- BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
- case Token::STAR:
- BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
- case Token::SLASH:
- BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
- case Token::PERCENT:
- BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
- case Token::BITWISEAND:
- BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
- case Token::BITWISEOR:
- BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
- case Token::SHL:
- BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
- case Token::SHR:
- BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
- case Token::PLUSEQ:
- COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
- case Token::MINUSEQ:
- COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
- case Token::STAREQ:
- COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
- case Token::SLASHEQ:
- COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
- case Token::BITWISEANDEQ:
- COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
- case Token::BITWISEOREQ:
- COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
- case Token::EQEQ:
- switch (b.fLeft->fType.kind()) {
- case Type::kScalar_Kind:
- COMPARE(LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildICmp, LLVMIntEQ,
- LLVMBuildFCmp, LLVMRealOEQ);
- case Type::kVector_Kind: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- this->vectorize(builder, b, &left, &right);
- LLVMValueRef value;
- switch (this->typeKind(b.fLeft->fType)) {
- case kInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
- break;
- case kUInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
- break;
- case kFloat_TypeKind:
- value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
- break;
- default:
- ABORT("unsupported typeKind");
- }
- LLVMValueRef args[1] = { value };
- LLVMValueRef func;
- switch (b.fLeft->fType.columns()) {
- case 2: func = fFoldAnd2Func; break;
- case 3: func = fFoldAnd3Func; break;
- case 4: func = fFoldAnd4Func; break;
- default:
- SkASSERT(false);
- func = fFoldAnd2Func;
- }
- return LLVMBuildCall(builder, func, args, 1, "all");
- }
- default:
- SkASSERT(false);
- }
- case Token::NEQ:
- switch (b.fLeft->fType.kind()) {
- case Type::kScalar_Kind:
- COMPARE(LLVMBuildICmp, LLVMIntNE,
- LLVMBuildICmp, LLVMIntNE,
- LLVMBuildFCmp, LLVMRealONE);
- case Type::kVector_Kind: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- this->vectorize(builder, b, &left, &right);
- LLVMValueRef value;
- switch (this->typeKind(b.fLeft->fType)) {
- case kInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
- break;
- case kUInt_TypeKind:
- value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
- break;
- case kFloat_TypeKind:
- value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
- break;
- default:
- ABORT("unsupported typeKind");
- }
- LLVMValueRef args[1] = { value };
- LLVMValueRef func;
- switch (b.fLeft->fType.columns()) {
- case 2: func = fFoldOr2Func; break;
- case 3: func = fFoldOr3Func; break;
- case 4: func = fFoldOr4Func; break;
- default:
- SkASSERT(false);
- func = fFoldOr2Func;
- }
- return LLVMBuildCall(builder, func, args, 1, "all");
- }
- default:
- SkASSERT(false);
- }
- case Token::LT:
- COMPARE(LLVMBuildICmp, LLVMIntSLT,
- LLVMBuildICmp, LLVMIntULT,
- LLVMBuildFCmp, LLVMRealOLT);
- case Token::LTEQ:
- COMPARE(LLVMBuildICmp, LLVMIntSLE,
- LLVMBuildICmp, LLVMIntULE,
- LLVMBuildFCmp, LLVMRealOLE);
- case Token::GT:
- COMPARE(LLVMBuildICmp, LLVMIntSGT,
- LLVMBuildICmp, LLVMIntUGT,
- LLVMBuildFCmp, LLVMRealOGT);
- case Token::GTEQ:
- COMPARE(LLVMBuildICmp, LLVMIntSGE,
- LLVMBuildICmp, LLVMIntUGE,
- LLVMBuildFCmp, LLVMRealOGE);
- case Token::LOGICALAND: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMBasicBlockRef ifFalse = fCurrentBlock;
- LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
- "true && ...");
- LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
- "&& merge");
- LLVMBuildCondBr(builder, left, ifTrue, merge);
- this->setBlock(builder, ifTrue);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- LLVMBuildBr(builder, merge);
- this->setBlock(builder, merge);
- LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
- LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
- LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
- LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
- return phi;
- }
- case Token::LOGICALOR: {
- LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
- LLVMBasicBlockRef ifTrue = fCurrentBlock;
- LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
- "false || ...");
- LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
- "|| merge");
- LLVMBuildCondBr(builder, left, merge, ifFalse);
- this->setBlock(builder, ifFalse);
- LLVMValueRef right = this->compileExpression(builder, *b.fRight);
- LLVMBuildBr(builder, merge);
- this->setBlock(builder, merge);
- LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
- LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
- LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
- LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
- return phi;
- }
- default:
- printf("%s\n", b.description().c_str());
- ABORT("unsupported binary operator");
- }
- }
- LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
- LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
- LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
- LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
- return LLVMBuildLoad(builder, ptr, "index load");
- }
- // Compiles a postfix increment or decrement. The operand is loaded through its lvalue, the
- // updated value is stored back, and the value loaded *before* the update is returned.
- //
- // The constant 1 is created to match the operand's type kind: LLVMConstInt only accepts integer
- // (or integer vector) types, so floating point operands get their constant from LLVMConstReal.
- LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
-     std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
-     LLVMValueRef result = lvalue->load(builder);
-     const bool isIncrement = Token::PLUSPLUS == p.fOperator;
-     if (!isIncrement && Token::MINUSMINUS != p.fOperator) {
-         ABORT("unsupported postfix op");
-     }
-     LLVMTypeRef type = this->getType(p.fType);
-     LLVMValueRef mod;
-     switch (this->typeKind(p.fType)) {
-         case kInt_TypeKind: // fall through
-         case kUInt_TypeKind: {
-             LLVMValueRef one = LLVMConstInt(type, 1, false);
-             mod = isIncrement ? LLVMBuildAdd(builder, result, one, "++")
-                               : LLVMBuildSub(builder, result, one, "--");
-             break;
-         }
-         case kFloat_TypeKind: {
-             LLVMValueRef one = LLVMConstReal(type, 1.0);
-             mod = isIncrement ? LLVMBuildFAdd(builder, result, one, "++")
-                               : LLVMBuildFSub(builder, result, one, "--");
-             break;
-         }
-         default:
-             ABORT("unsupported typeKind");
-     }
-     lvalue->store(builder, mod);
-     return result;
- }
- // Compiles a prefix expression: logical not, arithmetic negation, or pre-increment/decrement.
- // For ++ and -- the operand is loaded through its lvalue, updated, stored back, and the updated
- // value is returned.
- //
- // Constants and arithmetic are chosen by the operand's type kind: integer constants/ops for
- // int and uint, LLVMConstReal / FNeg / FAdd / FSub for floating point. (LLVMConstInt must not be
- // handed a floating point type.)
- LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
-     LLVMTypeRef type = this->getType(p.fType);
-     if (Token::LOGICALNOT == p.fOperator) {
-         // xor against 1 flips the boolean value
-         LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
-         return LLVMBuildXor(builder, base, LLVMConstInt(type, 1, false), "!");
-     }
-     if (Token::MINUS == p.fOperator) {
-         LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
-         if (kFloat_TypeKind == this->typeKind(p.fType)) {
-             return LLVMBuildFNeg(builder, base, "-");
-         }
-         return LLVMBuildSub(builder, LLVMConstInt(type, 0, false), base, "-");
-     }
-     std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
-     LLVMValueRef raw = lvalue->load(builder);
-     const bool isIncrement = Token::PLUSPLUS == p.fOperator;
-     if (!isIncrement && Token::MINUSMINUS != p.fOperator) {
-         ABORT("unsupported prefix op");
-     }
-     LLVMValueRef result;
-     switch (this->typeKind(p.fType)) {
-         case kInt_TypeKind: // fall through
-         case kUInt_TypeKind: {
-             LLVMValueRef one = LLVMConstInt(type, 1, false);
-             result = isIncrement ? LLVMBuildAdd(builder, raw, one, "++")
-                                  : LLVMBuildSub(builder, raw, one, "--");
-             break;
-         }
-         case kFloat_TypeKind: {
-             LLVMValueRef one = LLVMConstReal(type, 1.0);
-             result = isIncrement ? LLVMBuildFAdd(builder, raw, one, "++")
-                                  : LLVMBuildFSub(builder, raw, one, "--");
-             break;
-         }
-         default:
-             ABORT("unsupported typeKind");
-     }
-     lvalue->store(builder, result);
-     return result;
- }
- // Produces the value of a variable. Parameters that are neither `out` nor listed in
- // fPromotedParameters are used directly from fVariables; every other variable's fVariables
- // entry is treated as an address and loaded from.
- LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
-     const Variable& variable = v.fVariable;
-     const bool needsLoad =
-             Variable::kParameter_Storage != variable.fStorage ||
-             (variable.fModifiers.fFlags & Modifiers::kOut_Flag) ||
-             fPromotedParameters.find(&variable) != fPromotedParameters.end();
-     if (needsLoad) {
-         return LLVMBuildLoad(builder, fVariables[&variable], String(variable.fName).c_str());
-     }
-     return fVariables[&variable];
- }
- // Emits the call that appends a stage to a raster pipeline. The first argument is the pipeline.
- // Ordinary stages call fAppendFunc with (pipeline, stage id, context pointer or null). The
- // `callback` stage instead takes a function reference: the matching function definition is
- // compiled as a stage function and handed to fAppendCallbackFunc.
- void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
-     SkASSERT(a.fArguments.size() >= 1);
-     SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
-     LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
-     LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
-     if (SkRasterPipeline::callback != a.fStage) {
-         // ordinary stage: optional second argument is the stage context
-         LLVMValueRef ctx;
-         if (a.fArguments.size() == 2) {
-             ctx = this->compileExpression(builder, *a.fArguments[1]);
-             ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
-         } else {
-             SkASSERT(a.fArguments.size() == 1);
-             ctx = LLVMConstNull(fInt8PtrType);
-         }
-         LLVMValueRef callArgs[3] = { pipeline, stage, ctx };
-         LLVMBuildCall(builder, fAppendFunc, callArgs, 3, "");
-         return;
-     }
-     // callback stage: second argument names the function to run
-     SkASSERT(a.fArguments.size() == 2);
-     SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
-     const FunctionDeclaration& functionDecl =
-             *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
-     bool found = false;
-     for (const auto& pe : *fProgram) {
-         if (ProgramElement::kFunction_Kind != pe.fKind) {
-             continue;
-         }
-         const FunctionDefinition& def = (const FunctionDefinition&) pe;
-         if (&def.fDeclaration != &functionDecl) {
-             continue;
-         }
-         LLVMValueRef fn = this->compileStageFunction(def);
-         LLVMValueRef callArgs[2] = {
-             pipeline,
-             LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
-         };
-         LLVMBuildCall(builder, fAppendCallbackFunc, callArgs, 2, "");
-         found = true;
-         break;
-     }
-     SkASSERT(found);
- }
- // Compiles a scalar or vector constructor expression.
- //
- // Scalar constructors are numeric casts between int, uint and float; int<->uint and same-kind
- // casts return the argument unchanged. Conversions involving bool are not supported: they assert
- // and then reach the final ABORT rather than falling through into an unrelated case.
- //
- // Vector constructors either splat a single scalar argument across every column, or fill the
- // columns left to right from the arguments, expanding vector arguments element by element.
- LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
-     switch (c.fType.kind()) {
-         case Type::kScalar_Kind: {
-             SkASSERT(c.fArguments.size() == 1);
-             TypeKind from = this->typeKind(c.fArguments[0]->fType);
-             TypeKind to = this->typeKind(c.fType);
-             LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
-             switch (to) {
-                 case kFloat_TypeKind:
-                     switch (from) {
-                         case kInt_TypeKind:
-                             return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
-                         case kUInt_TypeKind:
-                             return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
-                         case kFloat_TypeKind:
-                             return base;
-                         case kBool_TypeKind:
-                             break;
-                     }
-                     break;
-                 case kInt_TypeKind:
-                     switch (from) {
-                         case kInt_TypeKind:
-                             return base;
-                         case kUInt_TypeKind:
-                             return base;
-                         case kFloat_TypeKind:
-                             return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
-                         case kBool_TypeKind:
-                             break;
-                     }
-                     break;
-                 case kUInt_TypeKind:
-                     switch (from) {
-                         case kInt_TypeKind:
-                             return base;
-                         case kUInt_TypeKind:
-                             return base;
-                         case kFloat_TypeKind:
-                             return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
-                         case kBool_TypeKind:
-                             break;
-                     }
-                     break;
-                 case kBool_TypeKind:
-                     break;
-             }
-             // only conversions to or from bool get here
-             SkASSERT(false);
-             break;
-         }
-         case Type::kVector_Kind: {
-             LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
-             if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
-                 // single scalar: replicate it into every column
-                 LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
-                 for (int i = 0; i < c.fType.columns(); ++i) {
-                     vec = LLVMBuildInsertElement(builder, vec, value,
-                                                  LLVMConstInt(fInt32Type, i, false),
-                                                  "vec build 1");
-                 }
-             } else {
-                 int index = 0;
-                 for (const auto& arg : c.fArguments) {
-                     LLVMValueRef value = this->compileExpression(builder, *arg);
-                     if (arg->fType.kind() == Type::kVector_Kind) {
-                         for (int i = 0; i < arg->fType.columns(); ++i) {
-                             // pull the column out of the *argument*, not out of the result
-                             // vector that is still being assembled
-                             LLVMValueRef column = LLVMBuildExtractElement(
-                                     builder,
-                                     value,
-                                     LLVMConstInt(fInt32Type, i, false),
-                                     "construct extract");
-                             vec = LLVMBuildInsertElement(builder, vec, column,
-                                                          LLVMConstInt(fInt32Type, index++, false),
-                                                          "vec build 2");
-                         }
-                     } else {
-                         vec = LLVMBuildInsertElement(builder, vec, value,
-                                                      LLVMConstInt(fInt32Type, index++, false),
-                                                      "vec build 3");
-                     }
-                 }
-             }
-             return vec;
-         }
-         default:
-             break;
-     }
-     ABORT("unsupported constructor");
- }
- // Compiles a swizzle. A single-component swizzle is one element extract and yields a scalar;
- // a multi-component swizzle assembles a new vector by extracting each selected component from
- // the base and inserting it at its position in the result.
- LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
-     LLVMValueRef source = this->compileExpression(builder, *s.fBase);
-     if (s.fComponents.size() <= 1) {
-         SkASSERT(s.fComponents.size() == 1);
-         LLVMValueRef lane = LLVMConstInt(fInt32Type, s.fComponents[0], false);
-         return LLVMBuildExtractElement(builder, source, lane, "swizzle extract");
-     }
-     LLVMValueRef swizzled = LLVMGetUndef(this->getType(s.fType));
-     for (size_t dst = 0; dst < s.fComponents.size(); ++dst) {
-         LLVMValueRef srcLane = LLVMConstInt(fInt32Type, s.fComponents[dst], false);
-         LLVMValueRef picked = LLVMBuildExtractElement(builder, source, srcLane,
-                                                       "swizzle extract");
-         LLVMValueRef dstLane = LLVMConstInt(fInt32Type, dst, false);
-         swizzled = LLVMBuildInsertElement(builder, swizzled, picked, dstLane, "swizzle insert");
-     }
-     return swizzled;
- }
- // Compiles `test ? ifTrue : ifFalse`. Each arm is emitted into its own basic block and the two
- // results are joined by a phi node in the merge block.
- LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
-     LLVMValueRef condition = this->compileExpression(builder, *t.fTest);
-     LLVMBasicBlockRef thenBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "if true");
-     LLVMBasicBlockRef joinBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "if merge");
-     LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "if false");
-     LLVMBuildCondBr(builder, condition, thenBlock, elseBlock);
-     LLVMValueRef armValues[2];
-     LLVMBasicBlockRef armExits[2];
-     // true arm; the phi is given fCurrentBlock as it stands after the arm has been compiled,
-     // not the block the arm started in
-     this->setBlock(builder, thenBlock);
-     armValues[0] = this->compileExpression(builder, *t.fIfTrue);
-     armExits[0] = fCurrentBlock;
-     LLVMBuildBr(builder, joinBlock);
-     // false arm, handled the same way
-     this->setBlock(builder, elseBlock);
-     armValues[1] = this->compileExpression(builder, *t.fIfFalse);
-     armExits[1] = fCurrentBlock;
-     LLVMBuildBr(builder, joinBlock);
-     this->setBlock(builder, joinBlock);
-     LLVMValueRef joined = LLVMBuildPhi(builder, this->getType(t.fType), "?");
-     LLVMAddIncoming(joined, armValues, armExits, 2);
-     return joined;
- }
- // Dispatches on the expression kind and returns the LLVM value of the expression. Append-stage
- // expressions produce no value and return a null LLVMValueRef. Field accesses, settings, type
- // references and unknown kinds are not supported and abort.
- LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
-     switch (expr.fKind) {
-         // literals become constants
-         case Expression::kBoolLiteral_Kind:
-             return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
-         case Expression::kIntLiteral_Kind:
-             return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
-         case Expression::kFloatLiteral_Kind:
-             return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
-         // everything else has a dedicated compile method
-         case Expression::kAppendStage_Kind:
-             this->appendStage(builder, (const AppendStage&) expr);
-             return LLVMValueRef();
-         case Expression::kBinary_Kind:
-             return this->compileBinary(builder, (BinaryExpression&) expr);
-         case Expression::kConstructor_Kind:
-             return this->compileConstructor(builder, (Constructor&) expr);
-         case Expression::kFunctionCall_Kind:
-             return this->compileFunctionCall(builder, (FunctionCall&) expr);
-         case Expression::kIndex_Kind:
-             return this->compileIndex(builder, (IndexExpression&) expr);
-         case Expression::kPostfix_Kind:
-             return this->compilePostfix(builder, (PostfixExpression&) expr);
-         case Expression::kPrefix_Kind:
-             return this->compilePrefix(builder, (PrefixExpression&) expr);
-         case Expression::kSwizzle_Kind:
-             return this->compileSwizzle(builder, (Swizzle&) expr);
-         case Expression::kTernary_Kind:
-             return this->compileTernary(builder, (TernaryExpression&) expr);
-         case Expression::kVariableReference_Kind:
-             return this->compileVariableReference(builder, (VariableReference&) expr);
-         // unsupported kinds
-         case Expression::kFieldAccess_Kind: // fall through
-         case Expression::kSetting_Kind:     // fall through
-         case Expression::kTypeReference_Kind:
-             abort();
-         default:
-             abort();
-     }
-     ABORT("unsupported expression: %s\n", expr.description().c_str());
- }
- // Compiles every statement of a block, in order.
- void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
-     for (auto it = block.fStatements.begin(); it != block.fStatements.end(); ++it) {
-         this->compileStatement(builder, **it);
-     }
- }
- // Compiles local variable declarations. Each variable gets an alloca emitted at the end of
- // fAllocaBlock and recorded in fVariables; the builder is then moved back to fCurrentBlock,
- // where the initializer (if there is one) is evaluated and stored.
- void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
-     for (const auto& entry : decls.fDeclaration->fVars) {
-         const VarDeclaration& varDecl = (VarDeclaration&) *entry;
-         LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
-         LLVMValueRef slot = LLVMBuildAlloca(builder, this->getType(varDecl.fVar->fType),
-                                             String(varDecl.fVar->fName).c_str());
-         fVariables[varDecl.fVar] = slot;
-         LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
-         if (!varDecl.fValue) {
-             continue;
-         }
-         LLVMBuildStore(builder, this->compileExpression(builder, *varDecl.fValue), slot);
-     }
- }
- // Compiles an if statement. Without an else clause the false edge of the conditional branch
- // goes straight to the merge block. An arm only gets a branch to the merge block when it does
- // not already end with a branch of its own.
- void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
-     LLVMValueRef condition = this->compileExpression(builder, *i.fTest);
-     LLVMBasicBlockRef thenBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "if true");
-     LLVMBasicBlockRef joinBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "if merge");
-     LLVMBasicBlockRef elseBlock =
-             i.fIfFalse ? LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false")
-                        : joinBlock;
-     LLVMBuildCondBr(builder, condition, thenBlock, elseBlock);
-     this->setBlock(builder, thenBlock);
-     this->compileStatement(builder, *i.fIfTrue);
-     if (!ends_with_branch(*i.fIfTrue)) {
-         LLVMBuildBr(builder, joinBlock);
-     }
-     if (i.fIfFalse) {
-         this->setBlock(builder, elseBlock);
-         this->compileStatement(builder, *i.fIfFalse);
-         if (!ends_with_branch(*i.fIfFalse)) {
-             LLVMBuildBr(builder, joinBlock);
-         }
-     }
-     this->setBlock(builder, joinBlock);
- }
- // Compiles a for loop. Layout: initializer, optional "for test" block, "for body", "for next"
- // (the increment expression), "for end". Without a test the loop head is the body itself.
- // `break` targets the end block and `continue` targets the next block.
- void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
-     if (f.fInitializer) {
-         this->compileStatement(builder, *f.fInitializer);
-     }
-     LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "for body");
-     LLVMBasicBlockRef nextBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "for next");
-     LLVMBasicBlockRef endBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                "for end");
-     // the block each iteration restarts from
-     LLVMBasicBlockRef loopHead = bodyBlock;
-     if (f.fTest) {
-         loopHead = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
-         LLVMBuildBr(builder, loopHead);
-         this->setBlock(builder, loopHead);
-         LLVMValueRef keepGoing = this->compileExpression(builder, *f.fTest);
-         LLVMBuildCondBr(builder, keepGoing, bodyBlock, endBlock);
-     } else {
-         LLVMBuildBr(builder, bodyBlock);
-     }
-     this->setBlock(builder, bodyBlock);
-     fBreakTarget.push_back(endBlock);
-     fContinueTarget.push_back(nextBlock);
-     this->compileStatement(builder, *f.fStatement);
-     fBreakTarget.pop_back();
-     fContinueTarget.pop_back();
-     if (!ends_with_branch(*f.fStatement)) {
-         LLVMBuildBr(builder, nextBlock);
-     }
-     this->setBlock(builder, nextBlock);
-     if (f.fNext) {
-         this->compileExpression(builder, *f.fNext);
-     }
-     LLVMBuildBr(builder, loopHead);
-     this->setBlock(builder, endBlock);
- }
- // Compiles a do/while loop. Control enters the body first; the test block runs after the body
- // and either loops back to the body or leaves through the end block.
- //
- // `break` targets the end block. `continue` targets the *test* block, so that continuing still
- // evaluates the loop condition instead of unconditionally re-running the body.
- void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
-     LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "do test");
-     LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                            "do body");
-     LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                           "do end");
-     LLVMBuildBr(builder, body);
-     this->setBlock(builder, testBlock);
-     LLVMValueRef test = this->compileExpression(builder, *d.fTest);
-     LLVMBuildCondBr(builder, test, body, end);
-     this->setBlock(builder, body);
-     fBreakTarget.push_back(end);
-     fContinueTarget.push_back(testBlock);
-     this->compileStatement(builder, *d.fStatement);
-     fBreakTarget.pop_back();
-     fContinueTarget.pop_back();
-     if (!ends_with_branch(*d.fStatement)) {
-         LLVMBuildBr(builder, testBlock);
-     }
-     this->setBlock(builder, end);
- }
- // Compiles a while loop: the test block decides between the body and the end block, and the
- // body branches back to the test unless it already ends with a branch. `break` targets the
- // end block, `continue` targets the test block.
- void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
-     LLVMBasicBlockRef condBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "while test");
-     LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "while body");
-     LLVMBasicBlockRef exitBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
-                                                                 "while end");
-     LLVMBuildBr(builder, condBlock);
-     this->setBlock(builder, condBlock);
-     LLVMValueRef keepGoing = this->compileExpression(builder, *w.fTest);
-     LLVMBuildCondBr(builder, keepGoing, bodyBlock, exitBlock);
-     this->setBlock(builder, bodyBlock);
-     fBreakTarget.push_back(exitBlock);
-     fContinueTarget.push_back(condBlock);
-     this->compileStatement(builder, *w.fStatement);
-     fBreakTarget.pop_back();
-     fContinueTarget.pop_back();
-     const bool bodyAlreadyBranches = ends_with_branch(*w.fStatement);
-     if (!bodyAlreadyBranches) {
-         LLVMBuildBr(builder, condBlock);
-     }
-     this->setBlock(builder, exitBlock);
- }
- // Emits a branch to the innermost break target (the last entry of fBreakTarget).
- void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
-     LLVMBasicBlockRef target = fBreakTarget.back();
-     LLVMBuildBr(builder, target);
- }
- // Emits a branch to the innermost continue target (the last entry of fContinueTarget).
- void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
-     LLVMBasicBlockRef target = fContinueTarget.back();
-     LLVMBuildBr(builder, target);
- }
- // Emits `ret void` for a bare return, otherwise evaluates the expression and returns its value.
- void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
-     if (!r.fExpression) {
-         LLVMBuildRetVoid(builder);
-         return;
-     }
-     LLVMValueRef value = this->compileExpression(builder, *r.fExpression);
-     LLVMBuildRet(builder, value);
- }
- // Dispatches on the statement kind to the matching compile method. Nop statements emit
- // nothing. Discard, group and switch statements, as well as unknown kinds, abort.
- void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
-     switch (stmt.fKind) {
-         case Statement::kNop_Kind:
-             return;
-         case Statement::kBlock_Kind:
-             this->compileBlock(builder, (Block&) stmt);
-             return;
-         case Statement::kExpression_Kind:
-             this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
-             return;
-         case Statement::kVarDeclarations_Kind:
-             this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
-             return;
-         // control flow
-         case Statement::kIf_Kind:
-             this->compileIf(builder, (IfStatement&) stmt);
-             return;
-         case Statement::kFor_Kind:
-             this->compileFor(builder, (ForStatement&) stmt);
-             return;
-         case Statement::kWhile_Kind:
-             this->compileWhile(builder, (WhileStatement&) stmt);
-             return;
-         case Statement::kDo_Kind:
-             this->compileDo(builder, (DoStatement&) stmt);
-             return;
-         case Statement::kBreak_Kind:
-             this->compileBreak(builder, (BreakStatement&) stmt);
-             return;
-         case Statement::kContinue_Kind:
-             this->compileContinue(builder, (ContinueStatement&) stmt);
-             return;
-         case Statement::kReturn_Kind:
-             this->compileReturn(builder, (ReturnStatement&) stmt);
-             return;
-         // unsupported kinds
-         case Statement::kDiscard_Kind: // fall through
-         case Statement::kGroup_Kind:   // fall through
-         case Statement::kSwitch_Kind:
-             abort();
-         default:
-             abort();
-     }
- }
- void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) { // scalar fallback: fills newFunc with a per-pixel loop around f's body
- // Loop over fVectorCount pixels, running the body of the stage function once for each of them.
- LLVMValueRef oldFunction = fCurrentFunction; // saved here, restored at the end of this function
- fCurrentFunction = newFunc;
- std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
- LLVMGetParams(fCurrentFunction, params.get());
- LLVMValueRef programParam = params.get()[1]; // loaded below to obtain the next stage's function
- LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
- LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
- LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
- fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
- this->setBlock(builder, fAllocaBlock);
- // Temporaries holding the color channel vectors, initialized from parameters 4 through 7.
- LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
- LLVMBuildStore(builder, params.get()[4], rVec);
- LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
- LLVMBuildStore(builder, params.get()[5], gVec);
- LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
- LLVMBuildStore(builder, params.get()[6], bVec);
- LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
- LLVMBuildStore(builder, params.get()[7], aVec);
- LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color"); // per-pixel 4-component color slot
- fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type, // second SkSL parameter: stage parameter 3 truncated to 32 bits
- "y->Int32");
- fVariables[f.fDeclaration.fParameters[2]] = color; // third SkSL parameter is backed by the color slot
- LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i"); // loop counter
- LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
- LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start"); // loop header; the alloca block is only terminated in the "finish" step below
- this->setBlock(builder, start);
- LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
- fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder, // first SkSL parameter: truncated stage parameter 2 plus the loop counter
- LLVMBuildTrunc(builder,
- params.get()[2],
- fInt32Type,
- "x->Int32"),
- iload,
- "x");
- LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
- LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
- LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
- LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
- LLVMBuildCondBr(builder, test, loopBody, loopEnd);
- this->setBlock(builder, loopBody);
- LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
- // Extract element i of the incoming r, g, b, and a parameters and store the four values into "color".
- for (int i = 0; i < 4; ++i) {
- vec = LLVMBuildInsertElement(builder, vec,
- LLVMBuildExtractElement(builder,
- params.get()[4 + i], // reads the original parameter, not the rVec..aVec temporaries
- iload, "initial"),
- LLVMConstInt(fInt32Type, i, false),
- "vec build");
- }
- LLVMBuildStore(builder, vec, color);
- // Compile the function body through the regular statement path, once, inside the loop body.
- this->compileStatement(builder, *f.fBody);
- // Extract the r, g, b, and a values from "color" and insert them back into element i of the
- // color channel temporaries.
- LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
- LLVMBuildStore(builder,
- LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
- LLVMBuildExtractElement(builder, colorLoad,
- LLVMConstInt(fInt32Type, 0,
- false),
- "rExtract"),
- iload, "rInsert"),
- rVec);
- LLVMBuildStore(builder,
- LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
- LLVMBuildExtractElement(builder, colorLoad,
- LLVMConstInt(fInt32Type, 1,
- false),
- "gExtract"),
- iload, "gInsert"),
- gVec);
- LLVMBuildStore(builder,
- LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
- LLVMBuildExtractElement(builder, colorLoad,
- LLVMConstInt(fInt32Type, 2,
- false),
- "bExtract"),
- iload, "bInsert"),
- bVec);
- LLVMBuildStore(builder,
- LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
- LLVMBuildExtractElement(builder, colorLoad,
- LLVMConstInt(fInt32Type, 3,
- false),
- "aExtract"),
- iload, "aInsert"),
- aVec);
- LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
- LLVMBuildStore(builder, inc, ivar);
- LLVMBuildBr(builder, start); // back edge to the loop header
- this->setBlock(builder, loopEnd);
- // Load the next stage's function from the program pointer, advance the pointer, and call it.
- LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
- LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
- LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
- LLVMValueRef nextInc = LLVMBuildIntToPtr(builder, // programParam + PTR_SIZE, computed through 64-bit integer arithmetic
- LLVMBuildAdd(builder,
- LLVMBuildPtrToInt(builder,
- programParam,
- fInt64Type,
- "cast 1"),
- LLVMConstInt(fInt64Type, PTR_SIZE, false),
- "add"),
- LLVMPointerType(fInt8PtrType, 0), "cast 2");
- LLVMValueRef args[STAGE_PARAM_COUNT] = { // same parameters as received, except the program pointer and the four color channels
- params.get()[0],
- nextInc,
- params.get()[2],
- params.get()[3],
- LLVMBuildLoad(builder, rVec, "rVec"),
- LLVMBuildLoad(builder, gVec, "gVec"),
- LLVMBuildLoad(builder, bVec, "bVec"),
- LLVMBuildLoad(builder, aVec, "aVec"),
- params.get()[8],
- params.get()[9],
- params.get()[10],
- params.get()[11]
- };
- LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
- LLVMBuildRetVoid(builder);
- // Finish: terminate the alloca block with a branch to the loop header, then verify.
- LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
- LLVMBuildBr(builder, start);
- LLVMDisposeBuilder(builder);
- if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
- ABORT("verify failed\n");
- }
- fAllocaBlock = oldAllocaBlock; // restore the state saved at the top
- fCurrentBlock = oldCurrentBlock;
- fCurrentFunction = oldFunction;
- }
- // FIXME maybe pluggable code generators? Need to do something to separate all
- // of the normal codegen from the vector codegen and break this up into multiple
- // classes.
- // Resolves an assignable vector expression to per-channel storage. Only the color parameter,
- // and swizzles whose base ultimately resolves to it, are supported: `out` receives the
- // fChannels entries selected by the expression. Returns false for anything else.
- bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
-                           LLVMValueRef out[CHANNELS]) {
-     if (Expression::kVariableReference_Kind == e.fKind) {
-         if (fColorParam != &((VariableReference&) e).fVariable) {
-             return false;
-         }
-         memcpy(out, fChannels, sizeof(fChannels));
-         return true;
-     }
-     if (Expression::kSwizzle_Kind != e.fKind) {
-         return false;
-     }
-     const Swizzle& swizzle = (const Swizzle&) e;
-     LLVMValueRef baseChannels[CHANNELS];
-     if (!this->getVectorLValue(builder, *swizzle.fBase, baseChannels)) {
-         return false;
-     }
-     size_t dst = 0;
-     while (dst < swizzle.fComponents.size()) {
-         out[dst] = baseChannels[swizzle.fComponents[dst]];
-         ++dst;
-     }
-     return true;
- }
- // Compiles both operands of a vector binary expression. When exactly one side has a single
- // column and the other has more, the single-column side's first entry is copied into the
- // remaining entries so both arrays cover the wider column count. Returns false as soon as
- // either operand fails to compile.
- bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
-                                   LLVMValueRef outLeft[CHANNELS], const Expression& right,
-                                   LLVMValueRef outRight[CHANNELS]) {
-     const int lhsColumns = left.fType.columns();
-     const int rhsColumns = right.fType.columns();
-     if (!this->compileVectorExpression(builder, left, outLeft)) {
-         return false;
-     }
-     if (1 == lhsColumns) {
-         // loop body only runs when the right side is wider than one column
-         for (int col = 1; col < rhsColumns; ++col) {
-             outLeft[col] = outLeft[0];
-         }
-     }
-     if (!this->compileVectorExpression(builder, right, outRight)) {
-         return false;
-     }
-     if (1 == rhsColumns) {
-         for (int col = 1; col < lhsColumns; ++col) {
-             outRight[col] = outRight[0];
-         }
-     }
-     return true;
- }
- // Compiles a binary expression on the vectorized path, writing one result value per column
- // into `out`. Returns false when the expression cannot be vectorized, in which case the caller
- // is expected to fall back to another strategy.
- //
- // Assignment stores the right side's columns through the left side's channel pointers and does
- // not write `out`. Arithmetic operators pick the signed, unsigned or floating point instruction
- // from the left operand's type kind and produce as many columns as the wider operand has
- // (getVectorBinaryOperands broadcasts a single-column operand to that width).
- bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
-                               LLVMValueRef out[CHANNELS]) {
-     LLVMValueRef left[CHANNELS];
-     LLVMValueRef right[CHANNELS];
- #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) {                                  \
-         if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
-             return false;                                                                \
-         }                                                                                \
-         const int leftColumns = b.fLeft->fType.columns();                                \
-         const int rightColumns = b.fRight->fType.columns();                              \
-         const int columns = leftColumns > rightColumns ? leftColumns : rightColumns;     \
-         for (int i = 0; i < columns; ++i) {                                              \
-             switch (this->typeKind(b.fLeft->fType)) {                                    \
-                 case kInt_TypeKind:                                                      \
-                     out[i] = signedOp(builder, left[i], right[i], "binary");             \
-                     break;                                                               \
-                 case kUInt_TypeKind:                                                     \
-                     out[i] = unsignedOp(builder, left[i], right[i], "binary");           \
-                     break;                                                               \
-                 case kFloat_TypeKind:                                                    \
-                     out[i] = floatOp(builder, left[i], right[i], "binary");              \
-                     break;                                                               \
-                 case kBool_TypeKind:                                                     \
-                     SkASSERT(false);                                                     \
-                     break;                                                               \
-             }                                                                            \
-         }                                                                                \
-         return true;                                                                     \
-     }
-     switch (b.fOperator) {
-         case Token::EQ: {
-             if (!this->getVectorLValue(builder, *b.fLeft, left)) {
-                 return false;
-             }
-             if (!this->compileVectorExpression(builder, *b.fRight, right)) {
-                 return false;
-             }
-             int columns = b.fRight->fType.columns();
-             for (int i = 0; i < columns; ++i) {
-                 LLVMBuildStore(builder, right[i], left[i]);
-             }
-             return true;
-         }
-         case Token::PLUS:
-             VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
-         case Token::MINUS:
-             VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
-         case Token::STAR:
-             VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
-         case Token::SLASH:
-             VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
-         case Token::PERCENT:
-             // floating point remainder needs frem; srem is only valid on integers
-             VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildFRem);
-         case Token::BITWISEAND:
-             VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
-         case Token::BITWISEOR:
-             VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
-         default:
-             printf("unsupported operator: %s\n", b.description().c_str());
-             return false;
-     }
- }
- // Compiles a constructor on the vectorized path. Returns false when an argument cannot be
- // vectorized; aborts on constructor forms that are not supported at all.
- //
- // Scalar constructors are numeric casts applied lane by lane to the argument's first channel,
- // with the result written to out[0]. Casts that need no conversion instruction (same kind, or
- // int <-> uint) pass the argument's first channel through to out[0].
- //
- // Vector constructors take either a single argument, whose first channel is copied into every
- // column, or exactly one argument per column.
- bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
-                                    LLVMValueRef out[CHANNELS]) {
-     switch (c.fType.kind()) {
-         case Type::kScalar_Kind: {
-             SkASSERT(c.fArguments.size() == 1);
-             TypeKind from = this->typeKind(c.fArguments[0]->fType);
-             TypeKind to = this->typeKind(c.fType);
-             LLVMValueRef base[CHANNELS];
-             if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
-                 return false;
-             }
- #define CONSTRUCT(fn)                                                                    \
-             out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
-             for (int i = 0; i < fVectorCount; ++i) {                                     \
-                 LLVMValueRef index = LLVMConstInt(fInt32Type, i, false);                 \
-                 LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index,  \
-                                                                "construct extract");     \
-                 out[0] = LLVMBuildInsertElement(builder, out[0],                         \
-                                                 fn(builder, baseVal,                     \
-                                                    this->getType(c.fType), "cast"),      \
-                                                 index, "construct insert");              \
-             }                                                                            \
-             return true;
-             // casts that need no instruction: hand the operand through so that the caller
-             // actually receives a value in out[0]
-             const bool intToUInt = kUInt_TypeKind == to && kInt_TypeKind == from;
-             const bool uintToInt = kInt_TypeKind == to && kUInt_TypeKind == from;
-             if (from == to || intToUInt || uintToInt) {
-                 out[0] = base[0];
-                 return true;
-             }
-             if (kFloat_TypeKind == to) {
-                 if (kInt_TypeKind == from) {
-                     CONSTRUCT(LLVMBuildSIToFP);
-                 }
-                 if (kUInt_TypeKind == from) {
-                     CONSTRUCT(LLVMBuildUIToFP);
-                 }
-             }
-             if (kInt_TypeKind == to) {
-                 if (kFloat_TypeKind == from) {
-                     CONSTRUCT(LLVMBuildFPToSI);
-                 }
-             }
-             if (kUInt_TypeKind == to) {
-                 if (kFloat_TypeKind == from) {
-                     CONSTRUCT(LLVMBuildFPToUI);
-                 }
-             }
-             printf("%s\n", c.description().c_str());
-             ABORT("unsupported constructor");
-         }
-         case Type::kVector_Kind: {
-             if (c.fArguments.size() == 1) {
-                 LLVMValueRef base[CHANNELS];
-                 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
-                     return false;
-                 }
-                 for (int i = 0; i < c.fType.columns(); ++i) {
-                     out[i] = base[0];
-                 }
-             } else {
-                 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
-                 for (int i = 0; i < c.fType.columns(); ++i) {
-                     LLVMValueRef base[CHANNELS];
-                     if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
-                         return false;
-                     }
-                     out[i] = base[0];
-                 }
-             }
-             return true;
-         }
-         default:
-             break;
-     }
-     ABORT("unsupported constructor");
- }
- // Compiles a float literal on the vectorized path: out[0] becomes a constant vector holding
- // fVectorCount copies of the literal. Always succeeds.
- bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
-                                     const FloatLiteral& f,
-                                     LLVMValueRef out[CHANNELS]) {
-     LLVMValueRef lanes[MAX_VECTOR_COUNT];
-     LLVMValueRef scalar = LLVMConstReal(this->getType(f.fType), f.fValue);
-     int lane = 0;
-     while (lane < fVectorCount) {
-         lanes[lane++] = scalar;
-     }
-     out[0] = LLVMConstVector(lanes, fVectorCount);
-     return true;
- }
- // Compiles a swizzle on the vectorized path by compiling the base and then picking the
- // selected channels, in swizzle order, into `out`. Returns false if the base cannot be
- // vectorized.
- bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
-                                LLVMValueRef out[CHANNELS]) {
-     LLVMValueRef baseChannels[CHANNELS];
-     const bool ok = this->compileVectorExpression(builder, *s.fBase, baseChannels);
-     if (ok) {
-         size_t dst = 0;
-         for (const auto& component : s.fComponents) {
-             out[dst++] = baseChannels[component];
-         }
-     }
-     return ok;
- }
- // Compiles a variable reference on the vectorized path. Only the color parameter is supported:
- // each entry of `out` is loaded from the matching fChannels pointer. Any other variable makes
- // vectorization fail (returns false).
- bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
-                                          LLVMValueRef out[CHANNELS]) {
-     if (fColorParam != &v.fVariable) {
-         return false;
-     }
-     for (int channel = 0; channel < CHANNELS; ++channel) {
-         out[channel] = LLVMBuildLoad(builder, fChannels[channel], "variable reference");
-     }
-     return true;
- }
- // Dispatches an expression to the matching vectorized compile method. Only binary
- // expressions, constructors, float literals, swizzles and variable references can be
- // vectorized; every other kind returns false.
- bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
-                                   LLVMValueRef out[CHANNELS]) {
-     switch (expr.fKind) {
-         case Expression::kFloatLiteral_Kind:
-             return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
-         case Expression::kVariableReference_Kind:
-             return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
-                                                         out);
-         case Expression::kSwizzle_Kind:
-             return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
-         case Expression::kConstructor_Kind:
-             return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
-         case Expression::kBinary_Kind:
-             return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
-         default:
-             break;
-     }
-     return false;
- }
- // Compiles a statement on the vectorized path. Only blocks and expression statements are
- // supported; returns false for any other statement kind, or as soon as a nested statement or
- // expression cannot be vectorized.
- bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
-     switch (stmt.fKind) {
-         case Statement::kBlock_Kind:
-             for (const auto& s : ((const Block&) stmt).fStatements) {
-                 if (!this->compileVectorStatement(builder, *s)) {
-                     return false;
-                 }
-             }
-             return true;
-         case Statement::kExpression_Kind: {
-             // The value is discarded, but the callee may write up to CHANNELS entries, so it
-             // must be given a full-sized array rather than a single value.
-             LLVMValueRef result[CHANNELS];
-             return this->compileVectorExpression(builder,
-                                                  *((const ExpressionStatement&) stmt).fExpression,
-                                                  result);
-         }
-         default:
-             return false;
-     }
- }
- // Attempts to compile the stage function `f` into `newFunc` by operating on whole color channel
- // vectors at once. Returns true on success. On failure the blocks added to `newFunc` are
- // removed again and false is returned, so the caller can fill the function another way.
- //
- // fCurrentFunction, fAllocaBlock and fCurrentBlock are saved on entry and restored on exit.
- // The builder is disposed on both the success and the failure path.
- bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
-     LLVMValueRef oldFunction = fCurrentFunction;
-     fCurrentFunction = newFunc;
-     std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
-     LLVMGetParams(fCurrentFunction, params.get());
-     LLVMValueRef programParam = params.get()[1];
-     LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
-     LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
-     LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
-     fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
-     this->setBlock(builder, fAllocaBlock);
-     // one stack slot per color channel, initialized from parameters 4 through 7
-     fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
-     LLVMBuildStore(builder, params.get()[4], fChannels[0]);
-     fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
-     LLVMBuildStore(builder, params.get()[5], fChannels[1]);
-     fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
-     LLVMBuildStore(builder, params.get()[6], fChannels[2]);
-     fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
-     LLVMBuildStore(builder, params.get()[7], fChannels[3]);
-     LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
-     this->setBlock(builder, start);
-     bool success = this->compileVectorStatement(builder, *f.fBody);
-     if (success) {
-         // increment program pointer, call next
-         LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
-         LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
-         LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
-                                                 "cast next->func");
-         LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
-                                                  LLVMBuildAdd(builder,
-                                                               LLVMBuildPtrToInt(builder,
-                                                                                 programParam,
-                                                                                 fInt64Type,
-                                                                                 "cast 1"),
-                                                               LLVMConstInt(fInt64Type, PTR_SIZE,
-                                                                            false),
-                                                               "add"),
-                                                  LLVMPointerType(fInt8PtrType, 0), "cast 2");
-         LLVMValueRef args[STAGE_PARAM_COUNT] = {
-             params.get()[0],
-             nextInc,
-             params.get()[2],
-             params.get()[3],
-             LLVMBuildLoad(builder, fChannels[0], "rVec"),
-             LLVMBuildLoad(builder, fChannels[1], "gVec"),
-             LLVMBuildLoad(builder, fChannels[2], "bVec"),
-             LLVMBuildLoad(builder, fChannels[3], "aVec"),
-             params.get()[8],
-             params.get()[9],
-             params.get()[10],
-             params.get()[11]
-         };
-         LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
-         LLVMBuildRetVoid(builder);
-         // finish: terminate the alloca block with a branch into the body
-         LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
-         LLVMBuildBr(builder, start);
-         LLVMDisposeBuilder(builder);
-         if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
-             ABORT("verify failed\n");
-         }
-     } else {
-         // Release the builder here too; it was previously leaked on this path.
-         LLVMDisposeBuilder(builder);
-         // Delete the block holding the users first: "start" may contain loads and stores that
-         // reference the allocas living in the alloca block.
-         LLVMDeleteBasicBlock(start);
-         LLVMDeleteBasicBlock(fAllocaBlock);
-     }
-     fAllocaBlock = oldAllocaBlock;
-     fCurrentBlock = oldCurrentBlock;
-     fCurrentFunction = oldFunction;
-     return success;
- }
- LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
- LLVMTypeRef returnType = fVoidType;
- LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
- fSizeTType, fFloat32VectorType, fFloat32VectorType,
- fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
- fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
- LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
- LLVMValueRef result = LLVMAddFunction(fModule,
- (String(f.fDeclaration.fName) + "$stage").c_str(),
- stageFuncType);
- fColorParam = f.fDeclaration.fParameters[2];
- if (!this->compileStageFunctionVector(f, result)) {
- // vectorization failed, fall back to looping over the pixels
- this->compileStageFunctionLoop(f, result);
- }
- return result;
- }
- bool JIT::hasStageSignature(const FunctionDeclaration& f) {
- return f.fReturnType == *fProgram->fContext->fVoid_Type &&
- f.fParameters.size() == 3 &&
- f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
- f.fParameters[0]->fModifiers.fFlags == 0 &&
- f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
- f.fParameters[1]->fModifiers.fFlags == 0 &&
- f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
- f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
- }
- LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
- if (this->hasStageSignature(f.fDeclaration)) {
- this->compileStageFunction(f);
- // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
- // was to produce an SkJumper stage just because the signature matched or that the function
- // is not otherwise called. May need a better way to handle this.
- }
- LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
- std::vector<LLVMTypeRef> parameterTypes;
- for (const auto& p : f.fDeclaration.fParameters) {
- LLVMTypeRef type = this->getType(p->fType);
- if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
- type = LLVMPointerType(type, 0);
- }
- parameterTypes.push_back(type);
- }
- fCurrentFunction = LLVMAddFunction(fModule,
- String(f.fDeclaration.fName).c_str(),
- LLVMFunctionType(returnType, parameterTypes.data(),
- parameterTypes.size(), false));
- fFunctions[&f.fDeclaration] = fCurrentFunction;
- std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
- LLVMGetParams(fCurrentFunction, params.get());
- for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
- fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
- }
- LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
- fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
- LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
- fCurrentBlock = start;
- LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
- this->compileStatement(builder, *f.fBody);
- if (!ends_with_branch(*f.fBody)) {
- if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
- LLVMBuildRetVoid(builder);
- } else {
- LLVMBuildUnreachable(builder);
- }
- }
- LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
- LLVMBuildBr(builder, start);
- LLVMDisposeBuilder(builder);
- if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
- ABORT("verify failed\n");
- }
- return fCurrentFunction;
- }
- void JIT::createModule() {
- fPromotedParameters.clear();
- fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
- this->loadBuiltinFunctions();
- LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
- fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
- LLVMFunctionType(fInt1Type, fold2Params, 1, false));
- fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
- LLVMFunctionType(fInt1Type, fold2Params, 1, false));
- LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
- fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
- LLVMFunctionType(fInt1Type, fold3Params, 1, false));
- fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
- LLVMFunctionType(fInt1Type, fold3Params, 1, false));
- LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
- fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
- LLVMFunctionType(fInt1Type, fold4Params, 1, false));
- fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
- LLVMFunctionType(fInt1Type, fold4Params, 1, false));
- // LLVM doesn't do void*, have to declare it as int8*
- LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
- fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
- appendParams,
- 3,
- false));
- LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
- fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
- LLVMFunctionType(fVoidType, appendCallbackParams, 2,
- false));
- LLVMTypeRef debugParams[3] = { fFloat32Type };
- fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
- debugParams,
- 1,
- false));
- for (const auto& e : *fProgram) {
- if (e.fKind == ProgramElement::kFunction_Kind) {
- this->compileFunction((FunctionDefinition&) e);
- }
- }
- }
- std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
- fCompiler.optimize(*program);
- fProgram = std::move(program);
- this->createModule();
- this->optimize();
- return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
- }
- void JIT::optimize() {
- LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
- LLVMPassManagerBuilderSetOptLevel(pmb, 3);
- LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
- LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
- LLVMPassManagerRef modulePM = LLVMCreatePassManager();
- LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
- LLVMInitializeFunctionPassManager(functionPM);
- LLVMValueRef func = LLVMGetFirstFunction(fModule);
- for (;;) {
- if (!func) {
- break;
- }
- LLVMRunFunctionPassManager(functionPM, func);
- func = LLVMGetNextFunction(func);
- }
- LLVMRunPassManager(modulePM, fModule);
- LLVMDisposePassManager(functionPM);
- LLVMDisposePassManager(modulePM);
- LLVMPassManagerBuilderDispose(pmb);
- std::string error_string;
- if (LLVMLoadLibraryPermanently(nullptr)) {
- ABORT("LLVMLoadLibraryPermanently failed");
- }
- char* defaultTriple = LLVMGetDefaultTargetTriple();
- char* error;
- LLVMTargetRef target;
- if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
- ABORT("LLVMGetTargetFromTriple failed");
- }
- if (!LLVMTargetHasJIT(target)) {
- ABORT("!LLVMTargetHasJIT");
- }
- LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
- defaultTriple,
- fCPU,
- nullptr,
- LLVMCodeGenLevelDefault,
- LLVMRelocDefault,
- LLVMCodeModelJITDefault);
- LLVMDisposeMessage(defaultTriple);
- LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
- LLVMSetModuleDataLayout(fModule, dataLayout);
- LLVMDisposeTargetData(dataLayout);
- fJITStack = LLVMOrcCreateInstance(targetMachine);
- fSharedModule = LLVMOrcMakeSharedModule(fModule);
- LLVMOrcModuleHandle orcModule;
- LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
- (LLVMOrcSymbolResolverFn) resolveSymbol, this);
- LLVMDisposeTargetMachine(targetMachine);
- }
- void* JIT::Module::getSymbol(const char* name) {
- LLVMOrcTargetAddress result;
- if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
- ABORT("GetSymbolAddress error");
- }
- if (!result) {
- ABORT("symbol not found");
- }
- return (void*) result;
- }
- void* JIT::Module::getJumperStage(const char* name) {
- return this->getSymbol((String(name) + "$stage").c_str());
- }
- } // namespace
- #endif // SK_LLVM_AVAILABLE
- #endif // SKSL_STANDALONE
|