blob: a23e97e930081c59a045a2377ceda1ded743f601 [file] [log] [blame]
/*
* Copyright 2016, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RSSPIRVWriter.h"
#include "SPIRVModule.h"
#include "bcinfo/MetadataExtractor.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/SPIRV.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "GlobalMergePass.h"
#include "InlinePreparationPass.h"
#include "LinkerModule.h"
#include "ReflectionPass.h"
#include "RemoveNonkernelsPass.h"
#include <fstream>
#include <sstream>
#define DEBUG_TYPE "rs2spirv-writer"
using namespace llvm;
using namespace SPIRV;
namespace llvm {
// Local forward declaration of the LLVM pass factory that
// addPassesForRS2SPIRV() adds to the pipeline; declared here instead of being
// taken from an LLVM header.
// NOTE(review): presumably done to avoid including the transforms header that
// normally declares it - confirm before replacing with an #include.
FunctionPass *createPromoteMemoryToRegisterPass();
}
namespace rs2spirv {
// Command-line option "-wo": path of the wrapper output file. When it is
// non-empty, WriteSPIRV() opens (and truncates) that file and adds the
// reflection pass, which is handed the opened stream.
static cl::opt<std::string> WrapperOutputFile("wo",
cl::desc("Wrapper output file"),
cl::value_desc("filename.spt"));
// Forward declarations of the file-local helpers used by Link(); all of them
// are defined further down in this file.

// Patches one wrapper main block for the kernel named KernelName: substitutes
// the dummy id prefix, inlines calls, and rewrites OpInBoundsPtrAccessChain.
static bool FixMain(LinkerModule &LM, MainFunBlock &MB, StringRef KernelName);
// Replaces each OpFunctionCall line in MB with the body of its callee.
static bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB);
// Drops duplicated OpType*/OpConstant* definitions and remaps their ids.
static bool FuseTypesAndConstants(LinkerModule &LM);
// Rewrites a single OpInBoundsPtrAccessChain line into OpAccessChain.
static bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L);
// Replaces undef (4294967295) component literals in OpVectorShuffle with 0.
static bool FixVectorShuffles(MainFunBlock &MB);
// Rewrites the Function storage class to Uniform on every line of M.
static void FixModuleStorageClass(LinkerModule &M);
// Rewrites the target triple of M to the matching SPIR triple.
//
//   arm      -> "spir-unknown-unknown"
//   aarch64  -> "spir64-unknown-unknown"
//
// A module whose architecture is already spir/spir64 keeps its triple
// unchanged. Any other architecture is a programming error and hits
// llvm_unreachable.
static void HandleTargetTriple(Module &M) {
  const Triple TT(M.getTargetTriple());

  StringRef NewTriple;
  switch (TT.getArch()) {
  default:
    llvm_unreachable("Unrecognized architecture");
    break;
  case Triple::arm:
    NewTriple = "spir-unknown-unknown";
    break;
  case Triple::aarch64:
    NewTriple = "spir64-unknown-unknown";
    break;
  case Triple::spir:
  case Triple::spir64:
    // Nothing to translate. Return here: falling out of the switch would
    // overwrite the existing triple with the still-empty NewTriple.
    DEBUG(dbgs() << "!!! Already a spir triple !!!\n");
    return;
  }

  DEBUG(dbgs() << "New triple:\t" << NewTriple << "\n");
  M.setTargetTriple(NewTriple);
}
// Appends the RenderScript -> SPIR-V preparation pipeline to PassMgr.
//
// Passes are created and added one at a time, in the order listed; Extractor
// is handed to the passes that take script metadata.
void addPassesForRS2SPIRV(llvm::legacy::PassManager &PassMgr,
                          bcinfo::MetadataExtractor &Extractor) {
  // Small helper so the pipeline below reads as a plain list.
  const auto Schedule = [&PassMgr](llvm::Pass *P) { PassMgr.add(P); };

  // Inlining and removal of non-kernel functions.
  Schedule(createInlinePreparationPass(Extractor));
  Schedule(createAlwaysInlinerPass());
  Schedule(createRemoveNonkernelsPass(Extractor));

  // Clean-up: unreachable globals, dead debug info, dead function
  // declarations.
  Schedule(createGlobalDCEPass());
  Schedule(createStripDeadDebugInfoPass());
  Schedule(createStripDeadPrototypesPass());

  Schedule(createGlobalMergePass());
  Schedule(createPromoteMemoryToRegisterPass());

  // Passes coming from the SPIR-V translator.
  Schedule(createTransOCLMD());
  // TODO: investigate removal of OCLTypeToSPIRV pass.
  Schedule(createOCLTypeToSPIRV());
  Schedule(createSPIRVRegularizeLLVM());
  Schedule(createSPIRVLowerConstExpr());
  Schedule(createSPIRVLowerBool());
}
// Translates the LLVM module M to SPIR-V and writes the result to OS.
//
// The module's target triple is rewritten, RenderScript metadata is
// extracted, and the pass pipeline from addPassesForRS2SPIRV() is run,
// followed by the LLVM -> SPIR-V translation pass. When the "-wo" option is
// set, the reflection pass is added as well and is given a stream opened on
// that file.
//
// M must not be null. Returns true on success. On failure returns false and
// ErrMsg carries a description of the problem (the early failures are also
// reported on errs(), as before).
bool WriteSPIRV(Module *M, llvm::raw_ostream &OS, std::string &ErrMsg) {
  assert(M && "WriteSPIRV requires a non-null module");

  std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule());

  HandleTargetTriple(*M);

  bcinfo::MetadataExtractor ME(M);
  if (!ME.extract()) {
    // Fill ErrMsg too, so callers that only report ErrMsg are not left with
    // an empty message.
    ErrMsg = "Could not extract metadata";
    errs() << "Could not extract metadata\n";
    return false;
  }
  DEBUG(dbgs() << "Metadata extracted\n");

  // Declared before the pass manager on purpose: the reflection pass is
  // handed this stream, so the stream has to be destroyed after the pass
  // manager (locals are destroyed in reverse declaration order).
  std::ofstream WrapperF;

  llvm::legacy::PassManager PassMgr;
  addPassesForRS2SPIRV(PassMgr, ME);

  if (!WrapperOutputFile.empty()) {
    const std::string &WrapperPath = WrapperOutputFile;
    WrapperF.open(WrapperPath, std::ios::trunc);
    if (!WrapperF.good()) {
      ErrMsg = "Could not create/open file:\t" + WrapperPath;
      errs() << "Could not create/open file:\t" << WrapperPath << "\n";
      return false;
    }
    DEBUG(dbgs() << "Wrapper output:\t" << WrapperPath << "\n");
    PassMgr.add(createReflectionPass(WrapperF, ME));
  }

  PassMgr.add(createLLVMToSPIRV(BM.get()));
  PassMgr.run(*M);
  DEBUG(M->dump());

  if (BM->getError(ErrMsg) != SPIRVEC_Success)
    return false;

  OS << *BM;
  return true;
}
// Links the textual SPIR-V kernel module with the textual wrapper module and
// saves the result to OutputFilename.
//
// Steps, in order:
//   1. Both files are parsed into LinkerModules.
//   2. Every '%' in the kernel module is replaced with "%rs_linker_", and the
//      kernel module's Function storage class is rewritten to Uniform.
//   3. The output module is assembled block by block. Both inputs are
//      expected to start with Header, Decor and TypeAndConst blocks, then a
//      Var block (optional in the kernel module), then function blocks. The
//      wrapper's Header and Decor blocks replace the kernel's.
//   4. Each wrapper main block is paired with a kernel name from the wrapper
//      header and patched by FixMain() and FixVectorShuffles().
//   5. Unused functions are removed and duplicate types/constants are fused.
//
// Returns true on success. Returns false, usually with a message on errs(),
// when a file cannot be opened or saved, the block layout is not the expected
// one, a kernel function is directly recursive, or kernel names and main
// blocks do not match up.
bool Link(llvm::StringRef KernelFilename, llvm::StringRef WrapperFilename,
          llvm::StringRef OutputFilename) {
  DEBUG(dbgs() << "Linking...\n");

  // Bail out as soon as an input cannot be opened; parsing a failed stream
  // would only produce a misleading error further down.
  std::ifstream WrapperF(WrapperFilename);
  if (!WrapperF.good()) {
    errs() << "Cannot open file: " << WrapperFilename << "\n";
    return false;
  }
  std::ifstream KernelF(KernelFilename);
  if (!KernelF.good()) {
    errs() << "Cannot open file: " << KernelFilename << "\n";
    return false;
  }

  LinkerModule WrapperM(WrapperF);
  LinkerModule KernelM(KernelF);
  WrapperF.close();
  KernelF.close();

  DEBUG(dbgs() << "WrapperF:\n");
  DEBUG(WrapperM.dump());
  DEBUG(dbgs() << "\n~~~~~~~~~~~~~~~~~~~~~~\n\nKernelF:\n");
  DEBUG(KernelM.dump());
  DEBUG(dbgs() << "\n======================\n\n");

  // Give every kernel-side id a distinct prefix.
  const char Prefix[] = "%rs_linker_";
  const size_t PrefixLen = sizeof(Prefix) - 1; // Without the trailing NUL.
  for (auto *LPtr : KernelM.lines()) {
    assert(LPtr);
    auto &L = *LPtr;
    size_t Pos = 0;
    while ((Pos = L.str().find('%', Pos)) != std::string::npos) {
      L.str().replace(Pos, 1, Prefix);
      // Skip past the inserted prefix, which itself starts with '%'.
      Pos += PrefixLen;
    }
  }

  FixModuleStorageClass(KernelM);
  DEBUG(KernelM.dump());

  auto WBlocks = WrapperM.blocks();
  auto WIt = WBlocks.begin();
  const auto WEnd = WBlocks.end();

  auto KBlocks = KernelM.blocks();
  auto KIt = KBlocks.begin();
  const auto KEnd = KBlocks.end();

  LinkerModule OutM;

  // Header blocks.
  if (WIt == WEnd || KIt == KEnd)
    return false;

  const auto *HeaderB = dyn_cast<HeaderBlock>(WIt->get());
  if (!HeaderB || !isa<HeaderBlock>(KIt->get()))
    return false;

  SmallVector<std::string, 2> KernelNames;
  const bool KernelsFound = HeaderB->getRSKernelNames(KernelNames);
  if (!KernelsFound) {
    errs() << "RS kernel names not found in wrapper\n";
    return false;
  }

  // KernelM module's HeaderBlock is skipped - it has OpenCL-specific code that
  // is replaced here with compute shader code.
  OutM.addBlock<HeaderBlock>(*HeaderB);

  // Decoration blocks.
  if (++WIt == WEnd || ++KIt == KEnd)
    return false;

  const auto *DecorBW = dyn_cast<DecorBlock>(WIt->get());
  if (!DecorBW || !isa<DecorBlock>(KIt->get()))
    return false;

  // KernelM module's DecorBlock is skipped, because it contains OpenCL-specific
  // code that is not needed (eg. linkage type information).
  OutM.addBlock<DecorBlock>(*DecorBW);

  // Type and constant blocks: both are kept, wrapper first.
  if (++WIt == WEnd || ++KIt == KEnd)
    return false;

  const auto *TypeAndConstBW = dyn_cast<TypeAndConstBlock>(WIt->get());
  auto *TypeAndConstBK = dyn_cast<TypeAndConstBlock>(KIt->get());
  if (!TypeAndConstBW || !TypeAndConstBK)
    return false;

  OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBW);
  OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBK);

  // Variable blocks: required in the wrapper, optional in the kernel module.
  if (++WIt == WEnd || ++KIt == KEnd)
    return false;

  const auto *VarBW = dyn_cast<VarBlock>(WIt->get());
  auto *VarBK = dyn_cast<VarBlock>(KIt->get());
  if (!VarBW)
    return false;

  OutM.addBlock<VarBlock>(*VarBW);
  if (VarBK)
    OutM.addBlock<VarBlock>(*VarBK);
  else
    --KIt; // Not a VarBlock: step back so the function loop below sees it.

  // Wrapper functions; main blocks are remembered for later patching.
  SmallVector<MainFunBlock *, 2> MainBs;
  while (++WIt != WEnd) {
    auto *FunB = dyn_cast<FunctionBlock>(WIt->get());
    if (!FunB)
      return false;

    if (auto *MB = dyn_cast<MainFunBlock>(WIt->get())) {
      MainBs.push_back(&OutM.addBlock<MainFunBlock>(*MB));
    } else {
      OutM.addBlock<FunctionBlock>(*FunB);
    }
  }

  if (MainBs.empty()) {
    errs() << "Wrapper module has no main function\n";
    return false;
  }

  // Kernel functions; bare function declarations are dropped.
  while (++KIt != KEnd) {
    // TODO: Check if FunDecl is a known runtime function.
    if (isa<FunDeclBlock>(KIt->get()))
      continue;

    auto *FunB = dyn_cast<FunctionBlock>(KIt->get());
    if (!FunB)
      return false;

    // TODO: Detect also indirect recursion.
    if (FunB->isDirectlyRecursive()) {
      errs() << "Function: " << FunB->getFunctionName().str()
             << " is recursive\n";
      return false;
    }

    OutM.addBlock<FunctionBlock>(*FunB);
  }

  OutM.fixBlockOrder();

  // Pair kernel names with main blocks, in order.
  auto KernelName = KernelNames.begin();
  const auto KE = KernelNames.end();
  auto MainB = MainBs.begin();
  const auto ME = MainBs.end();

  for (; KernelName != KE && MainB != ME; ++KernelName, ++MainB) {
    // substr(1) below needs at least one character to strip.
    if (KernelName->empty()) {
      errs() << "Empty RS kernel name in wrapper\n";
      return false;
    }

    // Remove the leading "%" character in kernel names
    const std::string KernelNameStr = Prefix + KernelName->substr(1);
    DEBUG(dbgs() << "Kernel name: " << KernelNameStr << '\n');

    if (!FixMain(OutM, **MainB, KernelNameStr))
      return false;

    if (!FixVectorShuffles(**MainB))
      return false;
  }

  if (KernelName != KE || MainB != ME) {
    errs() << "Inconsistent kernel metadata and definitions\n";
    return false;
  }

  OutM.removeUnusedFunctions();

  DEBUG(dbgs() << ">>>>>>>>>>>> Output module after prelink:\n\n");
  DEBUG(OutM.dump());

  if (!FuseTypesAndConstants(OutM)) {
    errs() << "Type fusion failed\n";
    return false;
  }

  DEBUG(dbgs() << ">>>>>>>>>>>> Output module after value fusion:\n\n");
  DEBUG(OutM.dump());

  if (!OutM.saveToFile(OutputFilename)) {
    errs() << "Could not save to file: " << OutputFilename << "\n";
    return false;
  }

  return true;
}
// Patches one main block of the linked module for the kernel KernelName.
//
// The "%RS_SPIRV_DUMMY_" id is replaced with KernelName, function calls are
// inlined until the block reports none left, and every
// OpInBoundsPtrAccessChain line is rewritten to OpAccessChain.
// Returns false as soon as inlining or a rewrite fails.
bool FixMain(LinkerModule &LM, MainFunBlock &MainB, StringRef KernelName) {
  MainB.replaceAllIds("%RS_SPIRV_DUMMY_", KernelName);

  // One round handles the calls currently present; keep going for as long as
  // the block still contains calls afterwards.
  while (MainB.hasFunctionCalls()) {
    if (InlineFunctionCalls(LM, MainB))
      continue;
    errs() << "Could not inline function calls in main\n";
    return false;
  }

  for (auto &Line : MainB.lines()) {
    const bool NeedsRewrite = Line.contains("OpInBoundsPtrAccessChain");
    if (NeedsRewrite && !TranslateInBoundsPtrAccessToAccess(Line))
      return false;
  }

  return true;
}
// Pieces of one OpFunctionCall line, as produced by GetFunctionCallInfo().
// All members are StringRefs, i.e. non-owning views.
// NOTE(review): they appear to point into the parsed line's text, so they
// would be valid only while that line is alive and unmodified - confirm
// against SPIRVLine.
struct FunctionCallInfo {
// Identifier on the left-hand side of the call (the call's result id).
StringRef RetValName;
// First right-hand-side identifier: the return type.
StringRef RetTy;
// Second right-hand-side identifier: the called function.
StringRef FName;
// Remaining right-hand-side identifiers: the call arguments, in order.
SmallVector<StringRef, 4> ArgNames;
};
// Splits an OpFunctionCall line into result id, return type, callee name and
// argument ids.
//
// L must contain "OpFunctionCall", have a left-hand-side identifier, and have
// at least two right-hand-side identifiers; these preconditions are checked
// with asserts only.
static FunctionCallInfo GetFunctionCallInfo(const SPIRVLine &L) {
  assert(L.contains("OpFunctionCall"));

  const Optional<StringRef> ResultId = L.getLHSIdentifier();
  assert(ResultId);

  SmallVector<StringRef, 6> Operands;
  L.getRHSIdentifiers(Operands);
  assert(Operands.size() >= 2 && "No return type and function name");

  FunctionCallInfo Info;
  Info.RetValName = *ResultId;
  Info.RetTy = Operands[0];
  Info.FName = Operands[1];
  // Everything after the type and the callee is an argument.
  Info.ArgNames.append(Operands.begin() + 2, Operands.end());
  return Info;
}
// Performs one round of inlining on MB: every OpFunctionCall line is replaced
// by the body of the function it calls.
//
// A new block is built from MB's lines. The first line is copied as is; each
// call line is dropped and the callee's body lines are appended in its place.
// Renames are collected on the way (callee parameter -> caller argument, and
// call result id -> callee return value name) and applied to the new block
// afterwards, newest first. MB is overwritten with the new block on success.
//
// Returns false, after printing to errs(), if a callee cannot be identified
// uniquely in LM, if argument and parameter counts differ, or if a non-void
// callee has no return value name.
bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB) {
  DEBUG(dbgs() << "InlineFunctionCalls\n");

  MainFunBlock Rebuilt;

  auto SrcLines = MB.lines();
  auto Cur = SrcLines.begin();
  const auto Last = SrcLines.end();

  // The opening line of the block is carried over unconditionally.
  Rebuilt.addLine(*Cur);

  std::vector<std::pair<std::string, std::string>> Renames;

  for (;;) {
    // Copy lines until the next call line or the end of the block.
    for (++Cur; Cur != Last && !Cur->contains("OpFunctionCall"); ++Cur)
      Rebuilt.addLine(*Cur);

    // This is the only normal exit, so every line has been consumed here and
    // nothing is left to copy after the loop.
    if (Cur == Last)
      break;

    assert(Cur->contains("OpFunctionCall"));
    const auto Call = GetFunctionCallInfo(*Cur);
    DEBUG(dbgs() << "Found function call:\t" << Cur->str() << '\n');

    // Find the one function block whose name matches the callee.
    SmallVector<Block *, 1> Matches;
    LM.getBlocksIf(Matches, [&Call](Block &Candidate) {
      auto *Fn = dyn_cast<FunctionBlock>(&Candidate);
      return Fn && Fn->getFunctionName() == Call.FName;
    });

    if (Matches.size() != 1) {
      errs() << "Callee not found\n";
      return false;
    }

    auto *Target = cast<FunctionBlock>(Matches.front());

    const auto NumArgs = Call.ArgNames.size();
    if (Target->getArity() != NumArgs) {
      errs() << "Arity mismatch (caller: " << NumArgs
             << ", callee: " << Target->getArity() << ")\n";
      return false;
    }

    Optional<StringRef> CalleeRet = Target->getRetValName();
    if (!CalleeRet && !Target->isReturnTypeVoid()) {
      errs() << "Return value not found for a function with non-void "
                "return type.\n";
      return false;
    }

    SmallVector<StringRef, 4> Formals;
    Target->getArgNames(Formals);
    if (Formals.size() != NumArgs) {
      errs() << "Params size mismatch\n";
      return false;
    }

    // Callee parameter ids become the caller's argument ids.
    for (size_t Idx = 0; Idx < NumArgs; ++Idx) {
      DEBUG(dbgs() << "New param mapping: " << Formals[Idx] << " -> "
                   << Call.ArgNames[Idx] << "\n");
      Renames.emplace_back(Formals[Idx].str(), Call.ArgNames[Idx].str());
    }

    // The call's result id becomes the callee's return value id.
    if (CalleeRet) {
      DEBUG(dbgs() << "New ret-val mapping: " << Call.RetValName << " -> "
                   << *CalleeRet << "\n");
      Renames.emplace_back(Call.RetValName.str(), CalleeRet->str());
    }

    // Splice the callee's body in where the call line used to be.
    const auto CalleeBody = Target->body();
    for (const auto &BodyLine : CalleeBody)
      Rebuilt.addLine(BodyLine);
  }

  // Apply the renames starting with the most recently recorded one.
  for (auto It = Renames.rbegin(), End = Renames.rend(); It != End; ++It) {
    DEBUG(dbgs() << "Replace " << It->first << " with " << It->second << "\n");
    Rebuilt.replaceAllIds(It->first, It->second);
  }

  MB = Rebuilt;
  return true;
}
// Removes duplicated type and constant definitions from LM.
//
// Lines are visited in order. On every line containing "=", right-hand-side
// ids that belong to an already dropped duplicate are first replaced with the
// surviving id. If the line then contains "OpType" or "OpConstant" and its
// right-hand side matches an earlier definition, the line is marked empty and
// its id is recorded as an alias of the earlier one; otherwise it is recorded
// as a definition. Right-hand sides starting with "OpTypeStruct" or
// "OpTypeRuntimeArray" are never treated as duplicates. Emptied lines are
// dropped by removeNonCode() at the end. Always returns true.
bool FuseTypesAndConstants(LinkerModule &LM) {
  // Right-hand-side text of a definition -> id that first defined it.
  StringMap<std::string> FirstDefOf;
  // Id of a dropped duplicate -> id of the definition that replaces it.
  StringMap<std::string> Aliases;

  for (auto *LinePtr : LM.lines()) {
    assert(LinePtr);
    auto &Line = *LinePtr;

    if (!Line.contains("="))
      continue;

    // Work on owned copies of the ids rather than on the StringRefs
    // (presumably because replaceId() edits the text they refer to).
    SmallVector<StringRef, 4> IdViews;
    Line.getRHSIdentifiers(IdViews);

    SmallVector<std::string, 4> OwnedIds;
    OwnedIds.reserve(IdViews.size());
    for (const auto &View : IdViews)
      OwnedIds.push_back(View.str());

    for (auto &Id : OwnedIds) {
      auto AliasIt = Aliases.find(Id);
      if (AliasIt == Aliases.end())
        continue;
      const bool Replaced = Line.replaceId(Id, AliasIt->getValue());
      (void)Replaced;
      assert(Replaced);
    }

    if (!Line.contains("OpType") && !Line.contains("OpConstant"))
      continue;

    const auto LHS = Line.getLHSIdentifier();
    const auto RHS = Line.getRHS();
    assert(LHS);
    assert(RHS);

    const bool NeverFused = RHS->startswith("OpTypeStruct") ||
                            RHS->startswith("OpTypeRuntimeArray");
    auto DefIt = NeverFused ? FirstDefOf.end() : FirstDefOf.find(*RHS);

    if (DefIt != FirstDefOf.end()) {
      // Duplicate: remember the alias, then blank the line.
      Aliases.insert(std::make_pair(LHS->str(), DefIt->getValue()));
      DEBUG(dbgs() << "New mapping: [" << LHS->str() << ", "
                   << DefIt->getValue() << "]\n");
      Line.markAsEmpty();
    } else {
      // insert() keeps an existing entry, so the first definition wins.
      FirstDefOf.insert(std::make_pair(RHS->str(), LHS->str()));
      DEBUG(dbgs() << "New val:\t" << RHS->str() << " : " << LHS->str()
                   << '\n');
    }
  }

  LM.removeNonCode();
  return true;
}
// Rewrites "<id> = OpInBoundsPtrAccessChain <ids...>" in place as
// "<id> = OpAccessChain <ids...>", leaving out the right-hand-side identifier
// at position 2 (the Element operand).
//
// Returns false, after printing to errs(), when the line has fewer than four
// right-hand-side identifiers or does not split into "<id> = <opcode>".
bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L) {
  assert(L.contains(" OpInBoundsPtrAccessChain "));

  SmallVector<StringRef, 4> Operands;
  L.getRHSIdentifiers(Operands);
  if (Operands.size() < 4) {
    errs() << "OpInBoundsPtrAccessChain has not enough parameters:\n\t"
           << L.str();
    return false;
  }

  // Take the first three whitespace-separated tokens: result id, '=', opcode.
  std::string Result, Assign, Opcode;
  std::istringstream Tokens(L.str());
  Tokens >> Result >> Assign >> Opcode;

  const bool WellFormed = !Result.empty() && Assign == "=" &&
                          Opcode == "OpInBoundsPtrAccessChain";
  if (!WellFormed) {
    errs() << "Could not decompose OpInBoundsPtrAccessChain:\n\t" << L.str();
    return false;
  }

  constexpr size_t ElementArgPosition = 2;

  std::ostringstream Rewritten;
  Rewritten << Result << " " << Assign << " OpAccessChain ";

  size_t Index = 0;
  for (const StringRef &Operand : Operands) {
    const bool IsElement = (Index++ == ElementArgPosition);
    if (IsElement)
      continue;
    Rewritten << Operand.str() << " ";
  }

  L.str() = Rewritten.str();
  L.trim(); // Drops the separator left after the last operand.
  return true;
}
// Replaces undef components (literal 4294967295) in OpVectorShuffle lines
// with zeros. This is always safe, because the result for a component marked
// as undef is unused.
//
// Ex. 1) OpVectorShuffle %v4uchar %a %b 0 1 2 4294967295 -->
//        OpVectorShuffle %v4uchar %a %b 0 1 2 0.
//
// Ex. 2) OpVectorShuffle %v4uchar %a %b 0 4294967295 3 4294967295 -->
//        OpVectorShuffle %v4uchar %a %b 0 0 3 0.
//
// The fix is needed for the current Vulkan driver, which crashed during
// backend compilation when this case was not handled.
//
// Always returns true.
bool FixVectorShuffles(MainFunBlock &MB) {
  const StringRef UndefStr = " 4294967295 ";

  for (auto &Line : MB.lines()) {
    if (Line.contains("OpVectorShuffle")) {
      // Pad with a space so an undef literal in last position is also
      // surrounded by spaces and matches the pattern.
      Line.str().push_back(' ');

      while (Line.contains(UndefStr)) {
        Line.replaceStr(UndefStr, " 0 ");
      }

      // Strip the padding again.
      Line.trim();
    }
  }

  return true;
}
// Changes every use of the Function StorageClass in M into Uniform.
// This is needed because the llvm-spirv converter emits the wrong
// StorageClass for globals.
// The transformation, however, also breaks legitimate uses of the Function
// StorageClass inside functions.
//
// Ex. 1. %ptr_Function_uint = OpTypePointer Function %uint
//    --> %ptr_Uniform_uint = OpTypePointer Uniform %uint
//
// Ex. 2. %gep = OpAccessChain %ptr_Function_uchar %G %uint_zero
//    --> %gep = OpAccessChain %ptr_Uniform_uchar %G %uint_zero
//
// TODO: Consider a better way of fixing this.
void FixModuleStorageClass(LinkerModule &M) {
  // {text to look for, replacement}, applied to each line in this order.
  static const char *const Rewrites[][2] = {
      {" Function", " Uniform"},
      {"_Function_", "_Uniform_"},
  };

  for (auto *LinePtr : M.lines()) {
    assert(LinePtr);
    auto &Line = *LinePtr;

    for (const auto &Rule : Rewrites) {
      // Repeat until the line no longer contains the pattern.
      while (Line.contains(Rule[0]))
        Line.replaceStr(Rule[0], Rule[1]);
    }
  }
}
} // namespace rs2spirv