| /* |
| * Copyright 2016, The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "RSSPIRVWriter.h" |
| |
| #include "SPIRVModule.h" |
| #include "bcinfo/MetadataExtractor.h" |
| |
| #include "llvm/ADT/StringMap.h" |
| #include "llvm/ADT/Triple.h" |
| #include "llvm/IR/LegacyPassManager.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/SPIRV.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/Transforms/IPO.h" |
| |
| #include "GlobalMergePass.h" |
| #include "InlinePreparationPass.h" |
| #include "LinkerModule.h" |
| #include "ReflectionPass.h" |
| #include "RemoveNonkernelsPass.h" |
| |
| #include <fstream> |
| #include <sstream> |
| |
| #define DEBUG_TYPE "rs2spirv-writer" |
| |
| using namespace llvm; |
| using namespace SPIRV; |
| |
| namespace llvm { |
| FunctionPass *createPromoteMemoryToRegisterPass(); |
| } |
| |
| namespace rs2spirv { |
| |
| static cl::opt<std::string> WrapperOutputFile("wo", |
| cl::desc("Wrapper output file"), |
| cl::value_desc("filename.spt")); |
| |
| static bool FixMain(LinkerModule &LM, MainFunBlock &MB, StringRef KernelName); |
| static bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB); |
| static bool FuseTypesAndConstants(LinkerModule &LM); |
| static bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L); |
| static bool FixVectorShuffles(MainFunBlock &MB); |
| static void FixModuleStorageClass(LinkerModule &M); |
| |
| static void HandleTargetTriple(Module &M) { |
| Triple TT(M.getTargetTriple()); |
| auto Arch = TT.getArch(); |
| |
| StringRef NewTriple; |
| switch (Arch) { |
| default: |
| llvm_unreachable("Unrecognized architecture"); |
| break; |
| case Triple::arm: |
| NewTriple = "spir-unknown-unknown"; |
| break; |
| case Triple::aarch64: |
| NewTriple = "spir64-unknown-unknown"; |
| break; |
| case Triple::spir: |
| case Triple::spir64: |
| DEBUG(dbgs() << "!!! Already a spir triple !!!\n"); |
| } |
| |
| DEBUG(dbgs() << "New triple:\t" << NewTriple << "\n"); |
| M.setTargetTriple(NewTriple); |
| } |
| |
| void addPassesForRS2SPIRV(llvm::legacy::PassManager &PassMgr, |
| bcinfo::MetadataExtractor &Extractor) { |
| PassMgr.add(createInlinePreparationPass(Extractor)); |
| PassMgr.add(createAlwaysInlinerPass()); |
| PassMgr.add(createRemoveNonkernelsPass(Extractor)); |
| // Delete unreachable globals. |
| PassMgr.add(createGlobalDCEPass()); |
| // Remove dead debug info. |
| PassMgr.add(createStripDeadDebugInfoPass()); |
| // Remove dead func decls. |
| PassMgr.add(createStripDeadPrototypesPass()); |
| PassMgr.add(createGlobalMergePass()); |
| PassMgr.add(createPromoteMemoryToRegisterPass()); |
| PassMgr.add(createTransOCLMD()); |
| // TODO: investigate removal of OCLTypeToSPIRV pass. |
| PassMgr.add(createOCLTypeToSPIRV()); |
| PassMgr.add(createSPIRVRegularizeLLVM()); |
| PassMgr.add(createSPIRVLowerConstExpr()); |
| PassMgr.add(createSPIRVLowerBool()); |
| } |
| |
| bool WriteSPIRV(Module *M, llvm::raw_ostream &OS, std::string &ErrMsg) { |
| std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule()); |
| |
| HandleTargetTriple(*M); |
| |
| bcinfo::MetadataExtractor ME(M); |
| if (!ME.extract()) { |
| errs() << "Could not extract metadata\n"; |
| return false; |
| } |
| DEBUG(dbgs() << "Metadata extracted\n"); |
| |
| llvm::legacy::PassManager PassMgr; |
| addPassesForRS2SPIRV(PassMgr, ME); |
| |
| std::ofstream WrapperF; |
| if (!WrapperOutputFile.empty()) { |
| WrapperF.open(WrapperOutputFile, std::ios::trunc); |
| if (!WrapperF.good()) { |
| errs() << "Could not create/open file:\t" << WrapperOutputFile << "\n"; |
| return false; |
| } |
| DEBUG(dbgs() << "Wrapper output:\t" << WrapperOutputFile << "\n"); |
| PassMgr.add(createReflectionPass(WrapperF, ME)); |
| } |
| |
| PassMgr.add(createLLVMToSPIRV(BM.get())); |
| PassMgr.run(*M); |
| DEBUG(M->dump()); |
| |
| if (BM->getError(ErrMsg) != SPIRVEC_Success) |
| return false; |
| |
| OS << *BM; |
| |
| return true; |
| } |
| |
| bool Link(llvm::StringRef KernelFilename, llvm::StringRef WrapperFilename, |
| llvm::StringRef OutputFilename) { |
| DEBUG(dbgs() << "Linking...\n"); |
| |
| std::ifstream WrapperF(WrapperFilename); |
| if (!WrapperF.good()) { |
| errs() << "Cannot open file: " << WrapperFilename << "\n"; |
| } |
| std::ifstream KernelF(KernelFilename); |
| if (!KernelF.good()) { |
| errs() << "Cannot open file: " << KernelFilename << "\n"; |
| } |
| |
| LinkerModule WrapperM(WrapperF); |
| LinkerModule KernelM(KernelF); |
| |
| WrapperF.close(); |
| KernelF.close(); |
| |
| DEBUG(dbgs() << "WrapperF:\n"); |
| DEBUG(WrapperM.dump()); |
| DEBUG(dbgs() << "\n~~~~~~~~~~~~~~~~~~~~~~\n\nKernelF:\n"); |
| DEBUG(KernelM.dump()); |
| DEBUG(dbgs() << "\n======================\n\n"); |
| |
| const char Prefix[] = "%rs_linker_"; |
| |
| for (auto *LPtr : KernelM.lines()) { |
| assert(LPtr); |
| auto &L = *LPtr; |
| size_t Pos = 0; |
| while ((Pos = L.str().find("%", Pos)) != std::string::npos) { |
| L.str().replace(Pos, 1, Prefix); |
| Pos += strlen(Prefix); |
| } |
| } |
| |
| FixModuleStorageClass(KernelM); |
| DEBUG(KernelM.dump()); |
| |
| auto WBlocks = WrapperM.blocks(); |
| auto WIt = WBlocks.begin(); |
| const auto WEnd = WBlocks.end(); |
| |
| auto KBlocks = KernelM.blocks(); |
| auto KIt = KBlocks.begin(); |
| const auto KEnd = KBlocks.end(); |
| |
| LinkerModule OutM; |
| |
| if (WIt == WEnd || KIt == KEnd) |
| return false; |
| |
| const auto *HeaderB = dyn_cast<HeaderBlock>(WIt->get()); |
| if (!HeaderB || !isa<HeaderBlock>(KIt->get())) |
| return false; |
| |
| SmallVector<std::string, 2> KernelNames; |
| const bool KernelsFound = HeaderB->getRSKernelNames(KernelNames); |
| |
| if (!KernelsFound) { |
| errs() << "RS kernel names not found in wrapper\n"; |
| return false; |
| } |
| |
| // KernelM module's HeaderBlock is skipped - it has OpenCL-specific code that |
| // is replaced here with compute shader code. |
| |
| OutM.addBlock<HeaderBlock>(*HeaderB); |
| |
| if (++WIt == WEnd || ++KIt == KEnd) |
| return false; |
| |
| const auto *DecorBW = dyn_cast<DecorBlock>(WIt->get()); |
| if (!DecorBW || !isa<DecorBlock>(KIt->get())) |
| return false; |
| |
| // KernelM module's DecorBlock is skipped, because it contains OpenCL-specific |
| // code that is not needed (eg. linkage type information). |
| |
| OutM.addBlock<DecorBlock>(*DecorBW); |
| |
| if (++WIt == WEnd || ++KIt == KEnd) |
| return false; |
| |
| const auto *TypeAndConstBW = dyn_cast<TypeAndConstBlock>(WIt->get()); |
| auto *TypeAndConstBK = dyn_cast<TypeAndConstBlock>(KIt->get()); |
| if (!TypeAndConstBW || !TypeAndConstBK) |
| return false; |
| |
| OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBW); |
| OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBK); |
| |
| if (++WIt == WEnd || ++KIt == KEnd) |
| return false; |
| |
| const auto *VarBW = dyn_cast<VarBlock>(WIt->get()); |
| auto *VarBK = dyn_cast<VarBlock>(KIt->get()); |
| if (!VarBW) |
| return false; |
| |
| OutM.addBlock<VarBlock>(*VarBW); |
| |
| if (VarBK) |
| OutM.addBlock<VarBlock>(*VarBK); |
| else |
| --KIt; |
| |
| SmallVector<MainFunBlock *, 2> MainBs; |
| |
| while (++WIt != WEnd) { |
| auto *FunB = dyn_cast<FunctionBlock>(WIt->get()); |
| if (!FunB) |
| return false; |
| |
| if (auto *MB = dyn_cast<MainFunBlock>(WIt->get())) { |
| MainBs.push_back(&OutM.addBlock<MainFunBlock>(*MB)); |
| } else { |
| OutM.addBlock<FunctionBlock>(*FunB); |
| } |
| } |
| |
| if (!MainBs.size()) { |
| errs() << "Wrapper module has no main function\n"; |
| return false; |
| } |
| |
| while (++KIt != KEnd) { |
| // TODO: Check if FunDecl is a known runtime function. |
| if (isa<FunDeclBlock>(KIt->get())) |
| continue; |
| |
| auto *FunB = dyn_cast<FunctionBlock>(KIt->get()); |
| if (!FunB) |
| return false; |
| |
| // TODO: Detect also indirect recurion. |
| if (FunB->isDirectlyRecursive()) { |
| errs() << "Function: " << FunB->getFunctionName().str() |
| << " is recursive\n"; |
| return false; |
| } |
| |
| OutM.addBlock<FunctionBlock>(*FunB); |
| } |
| |
| OutM.fixBlockOrder(); |
| |
| auto KernelName = KernelNames.begin(); |
| const auto KE = KernelNames.end(); |
| auto MainB = MainBs.begin(); |
| const auto ME = MainBs.end(); |
| |
| for (; KernelName != KE && MainB != ME; ++KernelName, ++MainB) { |
| // Remove the leading "%" character in kernel names |
| const std::string KernelNameStr = Prefix + KernelName->substr(1); |
| DEBUG(dbgs() << "Kernel name: " << KernelNameStr << '\n'); |
| if (!FixMain(OutM, **MainB, KernelNameStr)) |
| return false; |
| if (!FixVectorShuffles(**MainB)) |
| return false; |
| } |
| |
| if (KernelName != KE || MainB != ME) { |
| errs() << "Inconsistent kernel metadata and definitions\n"; |
| return false; |
| } |
| |
| OutM.removeUnusedFunctions(); |
| |
| DEBUG(dbgs() << ">>>>>>>>>>>> Output module after prelink:\n\n"); |
| DEBUG(OutM.dump()); |
| |
| if (!FuseTypesAndConstants(OutM)) { |
| errs() << "Type fusion failed\n"; |
| return false; |
| } |
| |
| DEBUG(dbgs() << ">>>>>>>>>>>> Output module after value fusion:\n\n"); |
| DEBUG(OutM.dump()); |
| |
| if (!OutM.saveToFile(OutputFilename)) { |
| errs() << "Could not save to file: " << OutputFilename << "\n"; |
| return false; |
| } |
| |
| return true; |
| } |
| |
| bool FixMain(LinkerModule &LM, MainFunBlock &MainB, StringRef KernelName) { |
| MainB.replaceAllIds("%RS_SPIRV_DUMMY_", KernelName); |
| |
| while (MainB.hasFunctionCalls()) |
| if (!InlineFunctionCalls(LM, MainB)) { |
| errs() << "Could not inline function calls in main\n"; |
| return false; |
| } |
| |
| for (auto &L : MainB.lines()) { |
| if (!L.contains("OpInBoundsPtrAccessChain")) |
| continue; |
| |
| if (!TranslateInBoundsPtrAccessToAccess(L)) |
| return false; |
| } |
| |
| return true; |
| } |
| |
| struct FunctionCallInfo { |
| StringRef RetValName; |
| StringRef RetTy; |
| StringRef FName; |
| SmallVector<StringRef, 4> ArgNames; |
| }; |
| |
| static FunctionCallInfo GetFunctionCallInfo(const SPIRVLine &L) { |
| assert(L.contains("OpFunctionCall")); |
| |
| const Optional<StringRef> Ret = L.getLHSIdentifier(); |
| assert(Ret); |
| |
| SmallVector<StringRef, 6> Ids; |
| L.getRHSIdentifiers(Ids); |
| assert(Ids.size() >= 2 && "No return type and function name"); |
| |
| const StringRef RetTy = Ids[0]; |
| const StringRef FName = Ids[1]; |
| SmallVector<StringRef, 4> Args(Ids.begin() + 2, Ids.end()); |
| |
| return {*Ret, RetTy, FName, std::move(Args)}; |
| } |
| |
| bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB) { |
| DEBUG(dbgs() << "InlineFunctionCalls\n"); |
| MainFunBlock NewMB; |
| |
| auto MLines = MB.lines(); |
| auto MIt = MLines.begin(); |
| const auto MEnd = MLines.end(); |
| using iter_ty = decltype(MIt); |
| |
| auto SkipToFunctionCall = [&MEnd, &NewMB](iter_ty &It) { |
| while (++It != MEnd && !It->contains("OpFunctionCall")) |
| NewMB.addLine(*It); |
| |
| return It != MEnd; |
| }; |
| |
| NewMB.addLine(*MIt); |
| |
| std::vector<std::pair<std::string, std::string>> NameMapping; |
| |
| while (SkipToFunctionCall(MIt)) { |
| assert(MIt->contains("OpFunctionCall")); |
| const auto FInfo = GetFunctionCallInfo(*MIt); |
| DEBUG(dbgs() << "Found function call:\t" << MIt->str() << '\n'); |
| |
| SmallVector<Block *, 1> Callee; |
| LM.getBlocksIf(Callee, [&FInfo](Block &B) { |
| auto *FB = dyn_cast<FunctionBlock>(&B); |
| if (!FB) |
| return false; |
| |
| return FB->getFunctionName() == FInfo.FName; |
| }); |
| |
| if (Callee.size() != 1) { |
| errs() << "Callee not found\n"; |
| return false; |
| } |
| |
| auto *FB = cast<FunctionBlock>(Callee.front()); |
| |
| if (FB->getArity() != FInfo.ArgNames.size()) { |
| errs() << "Arity mismatch (caller: " << FInfo.ArgNames.size() |
| << ", callee: " << FB->getArity() << ")\n"; |
| return false; |
| } |
| |
| Optional<StringRef> RetValName = FB->getRetValName(); |
| if (!RetValName && !FB->isReturnTypeVoid()) { |
| errs() << "Return value not found for a function with non-void " |
| "return type.\n"; |
| return false; |
| } |
| |
| SmallVector<StringRef, 4> Params; |
| FB->getArgNames(Params); |
| |
| if (Params.size() != FInfo.ArgNames.size()) { |
| errs() << "Params size mismatch\n"; |
| return false; |
| } |
| |
| for (size_t i = 0, e = FInfo.ArgNames.size(); i < e; ++i) { |
| DEBUG(dbgs() << "New param mapping: " << Params[i] << " -> " |
| << FInfo.ArgNames[i] << "\n"); |
| NameMapping.emplace_back(Params[i].str(), FInfo.ArgNames[i].str()); |
| } |
| |
| if (RetValName) { |
| DEBUG(dbgs() << "New ret-val mapping: " << FInfo.RetValName << " -> " |
| << *RetValName << "\n"); |
| NameMapping.emplace_back(FInfo.RetValName.str(), RetValName->str()); |
| } |
| |
| const auto Body = FB->body(); |
| for (const auto &L : Body) |
| NewMB.addLine(L); |
| } |
| |
| while (MIt != MEnd) { |
| NewMB.addLine(*MIt); |
| ++MIt; |
| } |
| |
| std::reverse(NameMapping.begin(), NameMapping.end()); |
| for (const auto &P : NameMapping) { |
| DEBUG(dbgs() << "Replace " << P.first << " with " << P.second << "\n"); |
| NewMB.replaceAllIds(P.first, P.second); |
| } |
| |
| MB = NewMB; |
| |
| return true; |
| } |
| |
| bool FuseTypesAndConstants(LinkerModule &LM) { |
| StringMap<std::string> TypesAndConstDefs; |
| StringMap<std::string> NameReps; |
| |
| for (auto *LPtr : LM.lines()) { |
| assert(LPtr); |
| auto &L = *LPtr; |
| if (!L.contains("=")) |
| continue; |
| |
| SmallVector<StringRef, 4> IdsRefs; |
| L.getRHSIdentifiers(IdsRefs); |
| |
| SmallVector<std::string, 4> Ids; |
| Ids.reserve(IdsRefs.size()); |
| for (const auto &I : IdsRefs) |
| Ids.push_back(I.str()); |
| |
| for (auto &I : Ids) |
| if (NameReps.count(I) != 0) { |
| const bool Res = L.replaceId(I, NameReps[I]); |
| (void)Res; |
| assert(Res); |
| } |
| |
| if (L.contains("OpType") || L.contains("OpConstant")) { |
| const auto LHS = L.getLHSIdentifier(); |
| const auto RHS = L.getRHS(); |
| assert(LHS); |
| assert(RHS); |
| |
| if (!RHS->startswith("OpTypeStruct") && |
| !RHS->startswith("OpTypeRuntimeArray") && |
| TypesAndConstDefs.count(*RHS) != 0) { |
| NameReps.insert( |
| std::make_pair(LHS->str(), TypesAndConstDefs[RHS->str()])); |
| DEBUG(dbgs() << "New mapping: [" << LHS->str() << ", " |
| << TypesAndConstDefs[RHS->str()] << "]\n"); |
| L.markAsEmpty(); |
| } else { |
| TypesAndConstDefs.insert(std::make_pair(RHS->str(), LHS->str())); |
| DEBUG(dbgs() << "New val:\t" << RHS->str() << " : " << LHS->str() |
| << '\n'); |
| } |
| }; |
| } |
| |
| LM.removeNonCode(); |
| |
| return true; |
| } |
| |
| bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L) { |
| assert(L.contains(" OpInBoundsPtrAccessChain ")); |
| |
| SmallVector<StringRef, 4> Ids; |
| L.getRHSIdentifiers(Ids); |
| |
| if (Ids.size() < 4) { |
| errs() << "OpInBoundsPtrAccessChain has not enough parameters:\n\t" |
| << L.str(); |
| return false; |
| } |
| |
| std::istringstream SS(L.str()); |
| std::string LHS, Eq, Op; |
| SS >> LHS >> Eq >> Op; |
| |
| if (LHS.empty() || Eq != "=" || Op != "OpInBoundsPtrAccessChain") { |
| errs() << "Could not decompose OpInBoundsPtrAccessChain:\n\t" << L.str(); |
| return false; |
| } |
| |
| constexpr size_t ElementArgPosition = 2; |
| |
| std::ostringstream NewLine; |
| NewLine << LHS << " " << Eq << " OpAccessChain "; |
| for (size_t i = 0, e = Ids.size(); i != e; ++i) |
| if (i != ElementArgPosition) |
| NewLine << Ids[i].str() << " "; |
| |
| L.str() = NewLine.str(); |
| L.trim(); |
| |
| return true; |
| } |
| |
| // Replaces UndefValues in VectorShuffles with zeros, which is always |
| // safe, as the result for components marked as Undef is unused. |
| // Ex. 1) OpVectorShuffle %v4uchar %a %b 0 1 2 4294967295 --> |
| // OpVectorShuffle %v4uchar %a %b 0 1 2 0. |
| // |
| // Ex. 2) OpVectorShuffle %v4uchar %a %b 0 4294967295 3 4294967295 --> |
| // OpVectorShuffle %v4uchar %a %b 0 0 3 0. |
| // |
| // Fix needed for the current Vulkan driver, which crashed during |
| // backend compilation when case is not handled. |
| bool FixVectorShuffles(MainFunBlock &MB) { |
| const StringRef UndefStr = " 4294967295 "; |
| |
| for (auto &L : MB.lines()) { |
| if (!L.contains("OpVectorShuffle")) |
| continue; |
| |
| L.str().push_back(' '); |
| while (L.contains(UndefStr)) |
| L.replaceStr(UndefStr, " 0 "); |
| |
| L.trim(); |
| } |
| |
| return true; |
| } |
| |
| // This function changes all Function StorageClass use into Uniform. |
| // It's needed, because llvm-spirv converter emits wrong StorageClass |
| // for globals. |
| // The transfromation, however, breaks legitimate uses of Function StorageClass |
| // inside functions. |
| // |
| // Ex. 1. %ptr_Function_uint = OpTypePointer Function %uint |
| // --> %ptr_Uniform_uint = OpTypePointer Uniform %uint |
| // |
| // Ex. 2. %gep = OpAccessChain %ptr_Function_uchar %G %uint_zero |
| // --> %gep = OpAccessChain %ptr_Uniform_uchar %G %uint_zero |
| // |
| // TODO: Consider a better way of fixing this. |
| void FixModuleStorageClass(LinkerModule &M) { |
| for (auto *LPtr : M.lines()) { |
| assert(LPtr); |
| auto &L = *LPtr; |
| |
| while (L.contains(" Function")) |
| L.replaceStr(" Function", " Uniform"); |
| |
| while (L.contains("_Function_")) |
| L.replaceStr("_Function_", "_Uniform_"); |
| } |
| } |
| |
| } // namespace rs2spirv |