blob: e0c512e62d2b5f005cf9c3aaa51e5f6f3d539221 [file] [log] [blame]
//===- SPIRVModule.cpp - Class to represent SPIR-V module --------*- C++ -*-===//
//
// The LLVM/SPIRV Translator
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimers in the documentation
// and/or other materials provided with the distribution.
// Neither the names of Advanced Micro Devices, Inc., nor the names of its
// contributors may be used to endorse or promote products derived from this
// Software without specific prior written permission.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
// THE SOFTWARE.
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements Module class for SPIR-V.
///
//===----------------------------------------------------------------------===//
#include "SPIRVModule.h"
#include "SPIRVDebug.h"
#include "SPIRVEntry.h"
#include "SPIRVType.h"
#include "SPIRVValue.h"
#include "SPIRVExtInst.h"
#include "SPIRVFunction.h"
#include "SPIRVInstruction.h"
#include "SPIRVStream.h"
#include <set>
#include <unordered_map>
#include <unordered_set>
namespace SPIRV{
// Base module: capabilities required by added entries are recorded
// automatically, and capability validation is off by default.
SPIRVModule::SPIRVModule()
    : AutoAddCapability(true),
      ValidateCapability(false) {}
// Nothing to release at the base level; the destructor is defined out of line.
SPIRVModule::~SPIRVModule() = default;
/// In-memory implementation of SPIRVModule.
///
/// Holds the id allocator (NextId), the id -> entry map, and the
/// per-section collections (types, constants, module-scope variables,
/// functions, strings, member names, lines, decorations, entry points,
/// capabilities) that layoutEntry() and the add* functions fill in.
class SPIRVModuleImpl : public SPIRVModule {
public:
  SPIRVModuleImpl():SPIRVModule(), NextId(1),
    SPIRVVersion(SPIRV_1_0),
    GeneratorId(SPIRVGEN_KhronosLLVMSPIRVTranslator),
    GeneratorVer(0),
    InstSchema(SPIRVISCH_Default),
    SrcLang(SourceLanguageOpenCL_C),
    SrcLangVer(102000),
    // Default to the host pointer width; setAddressingModel() can override
    // it. sizeof() counts bytes, so a 32-bit size_t has sizeof == 4 (a
    // comparison against 32 would never hold and always picked Physical64).
    AddrModel(sizeof(size_t) == 4 ? AddressingModelPhysical32
                                  : AddressingModelPhysical64),
    MemoryModel(MemoryModelOpenCL) {}
  virtual ~SPIRVModuleImpl();
  // Object query functions
  bool exist(SPIRVId) const;
  bool exist(SPIRVId, SPIRVEntry **) const;
  SPIRVId getId(SPIRVId Id = SPIRVID_INVALID, unsigned Increment = 1);
  virtual SPIRVEntry *getEntry(SPIRVId Id) const;
  bool hasDebugInfo() const { return !LineVec.empty();}
  // Error handling functions
  SPIRVErrorLog &getErrorLog() { return ErrLog;}
  SPIRVErrorCode getError(std::string &ErrMsg) { return ErrLog.getError(ErrMsg);}
  // Module query functions
  SPIRVAddressingModelKind getAddressingModel() { return AddrModel;}
  SPIRVExtInstSetKind getBuiltinSet(SPIRVId SetId) const;
  const SPIRVCapMap &getCapability() const { return CapMap; }
  bool hasCapability(SPIRVCapabilityKind Cap) const {
    return CapMap.find(Cap) != CapMap.end();
  }
  std::set<std::string> &getExtension() { return SPIRVExt;}
  SPIRVFunction *getFunction(unsigned I) const { return FuncVec[I];}
  SPIRVVariable *getVariable(unsigned I) const { return VariableVec[I];}
  virtual SPIRVValue *getValue(SPIRVId TheId) const;
  virtual std::vector<SPIRVValue *> getValues(const std::vector<SPIRVId>&)const;
  virtual std::vector<SPIRVId> getIds(const std::vector<SPIRVEntry *>&)const;
  virtual std::vector<SPIRVId> getIds(const std::vector<SPIRVValue *>&)const;
  virtual SPIRVType *getValueType(SPIRVId TheId)const;
  virtual std::vector<SPIRVType *> getValueTypes(const std::vector<SPIRVId>&)
      const;
  SPIRVMemoryModelKind getMemoryModel() const { return MemoryModel;}
  virtual SPIRVConstant* getLiteralAsConstant(unsigned Literal);
  unsigned getNumEntryPoints(SPIRVExecutionModelKind EM) const {
    auto Loc = EntryPointVec.find(EM);
    if (Loc == EntryPointVec.end())
      return 0;
    return Loc->second.size();
  }
  SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind EM, unsigned I) const {
    auto Loc = EntryPointVec.find(EM);
    if (Loc == EntryPointVec.end())
      return nullptr;
    assert(I < Loc->second.size());
    return get<SPIRVFunction>(Loc->second[I]);
  }
  unsigned getNumFunctions() const { return FuncVec.size();}
  unsigned getNumVariables() const { return VariableVec.size();}
  SourceLanguage getSourceLanguage(SPIRVWord * Ver = nullptr) const {
    if (Ver)
      *Ver = SrcLangVer;
    return SrcLang;
  }
  std::set<std::string> &getSourceExtension() { return SrcExtension;}
  bool isEntryPoint(SPIRVExecutionModelKind, SPIRVId EP) const;
  unsigned short getGeneratorId() const { return GeneratorId; }
  unsigned short getGeneratorVer() const { return GeneratorVer; }
  SPIRVWord getSPIRVVersion() const { return SPIRVVersion; }
  // Module changing functions
  bool importBuiltinSet(const std::string &, SPIRVId *);
  bool importBuiltinSetWithId(const std::string &, SPIRVId);
  void optimizeDecorates();
  void setAddressingModel(SPIRVAddressingModelKind AM) { AddrModel = AM;}
  void setAlignment(SPIRVValue *, SPIRVWord);
  void setMemoryModel(SPIRVMemoryModelKind MM) {
    MemoryModel = MM;
    if (MemoryModel == spv::MemoryModelOpenCL)
      addCapability(CapabilityKernel);
  }
  void setName(SPIRVEntry *E, const std::string &Name);
  void setSourceLanguage(SourceLanguage Lang, SPIRVWord Ver) {
    SrcLang = Lang;
    SrcLangVer = Ver;
  }
  void setGeneratorId(unsigned short Id) { GeneratorId = Id; }
  void setGeneratorVer(unsigned short Ver) { GeneratorVer = Ver; }
  void resolveUnknownStructFields();
  void setSPIRVVersion(SPIRVWord Ver) override { SPIRVVersion = Ver; }
  // Object creation functions
  template<class T> void addTo(std::vector<T *> &V, SPIRVEntry *E);
  virtual SPIRVEntry *addEntry(SPIRVEntry *E);
  virtual SPIRVBasicBlock *addBasicBlock(SPIRVFunction *, SPIRVId);
  virtual SPIRVString *getString(const std::string &Str);
  virtual SPIRVMemberName *addMemberName(SPIRVTypeStruct *ST,
      SPIRVWord MemberNumber, const std::string &Name);
  virtual void addUnknownStructField(SPIRVTypeStruct *Struct, unsigned I,
                                     SPIRVId ID);
  virtual SPIRVLine *addLine(SPIRVEntry *E, SPIRVString *FileName, SPIRVWord Line,
      SPIRVWord Column);
  virtual void addCapability(SPIRVCapabilityKind);
  virtual void addCapabilityInternal(SPIRVCapabilityKind);
  virtual const SPIRVDecorateGeneric *addDecorate(const SPIRVDecorateGeneric *);
  virtual SPIRVDecorationGroup *addDecorationGroup();
  virtual SPIRVDecorationGroup *addDecorationGroup(SPIRVDecorationGroup *Group);
  virtual SPIRVGroupDecorate *addGroupDecorate(SPIRVDecorationGroup *Group,
      const std::vector<SPIRVEntry *> &Targets);
  virtual SPIRVGroupDecorateGeneric *addGroupDecorateGeneric(
      SPIRVGroupDecorateGeneric *GDec);
  virtual SPIRVGroupMemberDecorate *addGroupMemberDecorate(
      SPIRVDecorationGroup *Group, const std::vector<SPIRVEntry *> &Targets);
  virtual void addEntryPoint(SPIRVExecutionModelKind ExecModel,
      SPIRVId EntryPoint);
  virtual SPIRVForward *addForward(SPIRVType *Ty);
  virtual SPIRVForward *addForward(SPIRVId, SPIRVType *Ty);
  virtual SPIRVFunction *addFunction(SPIRVFunction *);
  virtual SPIRVFunction *addFunction(SPIRVTypeFunction *, SPIRVId);
  virtual SPIRVEntry *replaceForward(SPIRVForward *, SPIRVEntry *);
  // Type creation functions
  template<class T> T * addType(T *Ty);
  virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *);
  virtual SPIRVTypeBool *addBoolType();
  virtual SPIRVTypeFloat *addFloatType(unsigned BitWidth);
  virtual SPIRVTypeFunction *addFunctionType(SPIRVType *,
      const std::vector<SPIRVType *> &);
  virtual SPIRVTypeInt *addIntegerType(unsigned BitWidth);
  virtual SPIRVTypeOpaque *addOpaqueType(const std::string &);
  virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *);
  virtual SPIRVTypeImage *addImageType(SPIRVType *,
      const SPIRVTypeImageDescriptor &);
  virtual SPIRVTypeImage *addImageType(SPIRVType *,
      const SPIRVTypeImageDescriptor &, SPIRVAccessQualifierKind);
  virtual SPIRVTypeSampler *addSamplerType();
  virtual SPIRVTypePipeStorage *addPipeStorageType();
  virtual SPIRVTypeSampledImage *addSampledImageType(SPIRVTypeImage *T);
  virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &);
  virtual void closeStructType(SPIRVTypeStruct *T, bool);
  virtual SPIRVTypeVector *addVectorType(SPIRVType *, SPIRVWord);
  virtual SPIRVType *addOpaqueGenericType(Op);
  virtual SPIRVTypeDeviceEvent *addDeviceEventType();
  virtual SPIRVTypeQueue *addQueueType();
  virtual SPIRVTypePipe *addPipeType();
  virtual SPIRVTypeVoid *addVoidType();
  virtual void createForwardPointers();
  // Constant creation functions
  virtual SPIRVInstruction *addBranchInst(SPIRVLabel *, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addBranchConditionalInst(SPIRVValue *, SPIRVLabel *,
      SPIRVLabel *, SPIRVBasicBlock *);
  virtual SPIRVValue *addCompositeConstant(SPIRVType *,
      const std::vector<SPIRVValue*>&);
  virtual SPIRVValue *addConstant(SPIRVValue *);
  virtual SPIRVValue *addConstant(SPIRVType *, uint64_t);
  virtual SPIRVValue *addDoubleConstant(SPIRVTypeFloat *, double);
  virtual SPIRVValue *addFloatConstant(SPIRVTypeFloat *, float);
  virtual SPIRVValue *addIntegerConstant(SPIRVTypeInt *, uint64_t);
  virtual SPIRVValue *addNullConstant(SPIRVType *);
  virtual SPIRVValue *addUndef(SPIRVType *TheType);
  virtual SPIRVValue *addSamplerConstant(SPIRVType *TheType, SPIRVWord AddrMode,
      SPIRVWord ParametricMode, SPIRVWord FilterMode);
  virtual SPIRVValue* addPipeStorageConstant(SPIRVType* TheType,
      SPIRVWord PacketSize, SPIRVWord PacketAlign, SPIRVWord Capacity);
  // Instruction creation functions
  virtual SPIRVInstruction *addPtrAccessChainInst(SPIRVType *, SPIRVValue *,
      std::vector<SPIRVValue *>, SPIRVBasicBlock *, bool);
  virtual SPIRVInstruction *addAsyncGroupCopy(SPIRVValue *Scope,
      SPIRVValue *Dest, SPIRVValue *Src, SPIRVValue *NumElems, SPIRVValue *Stride,
      SPIRVValue *Event, SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addExtInst(SPIRVType *,
      SPIRVWord, SPIRVWord, const std::vector<SPIRVWord> &,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addExtInst(SPIRVType *,
      SPIRVWord, SPIRVWord, const std::vector<SPIRVValue *> &,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addBinaryInst(Op, SPIRVType *, SPIRVValue *,
      SPIRVValue *, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addCallInst(SPIRVFunction*,
      const std::vector<SPIRVWord> &, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addCmpInst(Op, SPIRVType *, SPIRVValue *,
      SPIRVValue *, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addLoadInst(SPIRVValue *,
      const std::vector<SPIRVWord>&, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addPhiInst(SPIRVType *, std::vector<SPIRVValue *>,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addCompositeExtractInst(SPIRVType *, SPIRVValue *,
      const std::vector<SPIRVWord>&, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addCompositeInsertInst(SPIRVValue *Object,
      SPIRVValue *Composite, const std::vector<SPIRVWord>& Indices,
      SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addCopyObjectInst(SPIRVType *TheType,
      SPIRVValue *Operand, SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addCopyMemoryInst(SPIRVValue *, SPIRVValue *,
      const std::vector<SPIRVWord>&, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addCopyMemorySizedInst(SPIRVValue *, SPIRVValue *,
      SPIRVValue *, const std::vector<SPIRVWord>&, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addControlBarrierInst(
      SPIRVValue *ExecKind, SPIRVValue *MemKind,
      SPIRVValue *MemSema, SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addGroupInst(Op OpCode, SPIRVType *Type,
      Scope Scope, const std::vector<SPIRVValue *> &Ops,
      SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addInstruction(SPIRVInstruction *Inst,
      SPIRVBasicBlock *BB);
  virtual SPIRVInstTemplateBase *addInstTemplate(Op OC,
      SPIRVBasicBlock* BB, SPIRVType *Ty);
  virtual SPIRVInstTemplateBase *addInstTemplate(Op OC,
      const std::vector<SPIRVWord>& Ops, SPIRVBasicBlock* BB, SPIRVType *Ty);
  virtual SPIRVInstruction *addMemoryBarrierInst(
      Scope ScopeKind, SPIRVWord MemFlag, SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addReturnInst(SPIRVBasicBlock *);
  virtual SPIRVInstruction *addReturnValueInst(SPIRVValue *, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addSelectInst(SPIRVValue *, SPIRVValue *, SPIRVValue *,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addStoreInst(SPIRVValue *, SPIRVValue *,
      const std::vector<SPIRVWord>&, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addSwitchInst(SPIRVValue *, SPIRVBasicBlock *,
      const std::vector<std::pair<SPIRVWord, SPIRVBasicBlock *>>&,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addUnaryInst(Op, SPIRVType *, SPIRVValue *,
      SPIRVBasicBlock *);
  virtual SPIRVInstruction *addVariable(SPIRVType *, bool, SPIRVLinkageTypeKind,
      SPIRVValue *, const std::string &, SPIRVStorageClassKind, SPIRVBasicBlock *);
  virtual SPIRVValue *addVectorShuffleInst(SPIRVType *Type, SPIRVValue *Vec1,
      SPIRVValue *Vec2, const std::vector<SPIRVWord> &Components,
      SPIRVBasicBlock *BB);
  virtual SPIRVInstruction *addVectorExtractDynamicInst(SPIRVValue *,
      SPIRVValue *, SPIRVBasicBlock *);
  virtual SPIRVInstruction *addVectorInsertDynamicInst(SPIRVValue *,
      SPIRVValue *, SPIRVValue*, SPIRVBasicBlock *);
  // I/O functions
  friend spv_ostream & operator<<(spv_ostream &O, SPIRVModule& M);
  friend std::istream & operator>>(std::istream &I, SPIRVModule& M);
private:
  SPIRVErrorLog ErrLog;
  SPIRVId NextId;
  SPIRVWord SPIRVVersion;
  unsigned short GeneratorId;
  unsigned short GeneratorVer;
  SPIRVInstructionSchemaKind InstSchema;
  SourceLanguage SrcLang;
  SPIRVWord SrcLangVer;
  std::set<std::string> SrcExtension;
  std::set<std::string> SPIRVExt;
  SPIRVAddressingModelKind AddrModel;
  SPIRVMemoryModelKind MemoryModel;
  typedef std::map<SPIRVId, SPIRVEntry *> SPIRVIdToEntryMap;
  typedef std::vector<SPIRVEntry *> SPIRVEntryVector;
  typedef std::set<SPIRVId> SPIRVIdSet;
  typedef std::vector<SPIRVId> SPIRVIdVec;
  typedef std::vector<SPIRVFunction *> SPIRVFunctionVector;
  typedef std::vector<SPIRVTypeForwardPointer *> SPIRVForwardPointerVec;
  typedef std::vector<SPIRVType *> SPIRVTypeVec;
  typedef std::vector<SPIRVValue *> SPIRVConstantVector;
  typedef std::vector<SPIRVVariable *> SPIRVVariableVec;
  typedef std::vector<SPIRVString *> SPIRVStringVec;
  typedef std::vector<SPIRVMemberName *> SPIRVMemberNameVec;
  typedef std::vector<SPIRVLine *> SPIRVLineVec;
  typedef std::vector<SPIRVDecorationGroup *> SPIRVDecGroupVec;
  typedef std::vector<SPIRVGroupDecorateGeneric *> SPIRVGroupDecVec;
  typedef std::map<SPIRVId, SPIRVExtInstSetKind> SPIRVIdToBuiltinSetMap;
  typedef std::map<SPIRVExecutionModelKind, SPIRVIdSet> SPIRVExecModelIdSetMap;
  typedef std::map<SPIRVExecutionModelKind, SPIRVIdVec> SPIRVExecModelIdVecMap;
  typedef std::unordered_map<std::string, SPIRVString*> SPIRVStringMap;
  typedef std::map<SPIRVTypeStruct *, std::vector<std::pair<unsigned, SPIRVId>>>
    SPIRVUnknownStructFieldMap;
  SPIRVForwardPointerVec ForwardPointerVec;
  SPIRVTypeVec TypeVec;
  SPIRVIdToEntryMap IdEntryMap;
  SPIRVFunctionVector FuncVec;
  SPIRVConstantVector ConstVec;
  SPIRVVariableVec VariableVec;
  SPIRVEntryVector EntryNoId;         // Entries without id
  SPIRVIdToBuiltinSetMap IdBuiltinMap;
  SPIRVIdSet NamedId;
  SPIRVStringVec StringVec;
  SPIRVMemberNameVec MemberNameVec;
  SPIRVLineVec LineVec;
  SPIRVDecorateSet DecorateSet;
  SPIRVDecGroupVec DecGroupVec;
  SPIRVGroupDecVec GroupDecVec;
  SPIRVExecModelIdSetMap EntryPointSet;
  SPIRVExecModelIdVecMap EntryPointVec;
  SPIRVStringMap StrMap;
  SPIRVCapMap CapMap;
  SPIRVUnknownStructFieldMap UnknownStructFieldMap;
  std::map<unsigned, SPIRVTypeInt*> IntTypeMap;
  std::map<unsigned, SPIRVConstant*> LiteralMap;
  void layoutEntry(SPIRVEntry* Entry);
};
SPIRVModuleImpl::~SPIRVModuleImpl() {
  // Only the capability objects are released here. Entries held in
  // IdEntryMap and EntryNoId are not deleted, because the ToDo notes that
  // doing so crashes. The disabled cleanup was:
  //   for (auto I : IdEntryMap) delete I.second;
  //   for (auto I : EntryNoId) delete I;
  // ToDo: Fix bug causing crash, then re-enable the loops above.
  for (const auto &CapEntry : CapMap)
    delete CapEntry.second;
}
// Registers an OpLine for entry E at FileName:Line:Column and attaches it to
// E. Returns the new line entry.
SPIRVLine *
SPIRVModuleImpl::addLine(SPIRVEntry *E, SPIRVString *FileName,
    SPIRVWord Line, SPIRVWord Column) {
  const SPIRVId FileId = FileName->getId();
  auto *NewLine = add(new SPIRVLine(E, FileId, Line, Column));
  E->setLine(NewLine);
  return NewLine;
}
// Creates a decoration group and a group decorate for every run of
// equivalent (non-member) decorations in DecorateSet that is shared by two
// or more targets. The individual decorations of each grouped run are
// removed from DecorateSet.
void
SPIRVModuleImpl::optimizeDecorates() {
  SPIRVDBG(spvdbgs() << "[optimizeDecorates] begin\n");
  for (auto I = DecorateSet.begin(), E = DecorateSet.end(); I != E;) {
    auto D = *I;
    SPIRVDBG(spvdbgs() << "  check " << *D << '\n');
    if (D->getOpCode() == OpMemberDecorate) {
      ++I;
      continue;
    }
    auto ER = DecorateSet.equal_range(D);
    SPIRVDBG(spvdbgs() << "  equal range " << **ER.first
                      << " to ";
             if (ER.second != DecorateSet.end())
               spvdbgs() << **ER.second;
             else
               spvdbgs() << "end";
             spvdbgs() << '\n');
    const auto NumShared = std::distance(ER.first, ER.second);
    if (NumShared < 2) {
      I = ER.second;
      SPIRVDBG(spvdbgs() << "  skip equal range \n");
      continue;
    }
    // WordCount is only 16 bits, so a group decorate can carry at most
    // 65535 - FixedWC targets. The targets collected below never outnumber
    // the equal range, so skip grouping when the range is too large. This
    // check must come BEFORE D is retargeted to a new group. Otherwise D
    // would stay in DecorateSet pointing at a group id that is never
    // emitted, and the group would be leaked.
    if (NumShared >= 65530) {
      I = ER.second;
      SPIRVDBG(spvdbgs() << "  skip equal range (too many targets)\n");
      continue;
    }
    SPIRVDBG(spvdbgs() << "  add deco group. erase equal range\n");
    auto G = new SPIRVDecorationGroup(this, getId());
    std::vector<SPIRVId> Targets;
    Targets.push_back(D->getTargetId());
    // D becomes the group's decoration: retarget it to the group id.
    const_cast<SPIRVDecorateGeneric*>(D)->setTargetId(G->getId());
    G->getDecorations().insert(D);
    for (I = ER.first; I != ER.second; ++I) {
      auto Dec = *I;
      if (*Dec == *D)
        continue;
      Targets.push_back(Dec->getTargetId());
    }
    // I == ER.second here, so erasing the range keeps I valid.
    DecorateSet.erase(ER.first, ER.second);
    auto GD = new SPIRVGroupDecorate(G, Targets);
    DecGroupVec.push_back(G);
    GroupDecVec.push_back(GD);
  }
}
// Creates and registers a sampler constant of TheType with a fresh id.
SPIRVValue *
SPIRVModuleImpl::addSamplerConstant(SPIRVType *TheType,
    SPIRVWord AddrMode, SPIRVWord ParametricMode, SPIRVWord FilterMode) {
  auto *Sampler = new SPIRVConstantSampler(this, TheType, getId(), AddrMode,
                                           ParametricMode, FilterMode);
  return addConstant(Sampler);
}
// Creates and registers a pipe-storage constant of TheType with a fresh id.
SPIRVValue *
SPIRVModuleImpl::addPipeStorageConstant(SPIRVType *TheType,
    SPIRVWord PacketSize, SPIRVWord PacketAlign, SPIRVWord Capacity) {
  auto *Storage = new SPIRVConstantPipeStorage(this, TheType, getId(),
                                               PacketSize, PacketAlign,
                                               Capacity);
  return addConstant(Storage);
}
// Adds capability Cap to the module. It first adds whatever
// SPIRV::getCapability(Cap) reports (presumably the capabilities Cap
// depends on), then records Cap itself if it is not already present.
void
SPIRVModuleImpl::addCapability(SPIRVCapabilityKind Cap) {
  addCapabilities(SPIRV::getCapability(Cap));
  SPIRVDBG(spvdbgs() << "addCapability: " << Cap << '\n');
  if (!hasCapability(Cap))
    CapMap.insert(std::make_pair(Cap, new SPIRVCapability(this, Cap)));
}
// Records Cap only when AutoAddCapability is enabled and Cap is not present
// yet. Unlike addCapability, it does not pull in any further capabilities.
void
SPIRVModuleImpl::addCapabilityInternal(SPIRVCapabilityKind Cap) {
  if (!AutoAddCapability || hasCapability(Cap))
    return;
  CapMap.insert(std::make_pair(Cap, new SPIRVCapability(this, Cap)));
}
// Returns the 32-bit integer constant for Literal. The constant is created
// and registered on first request, then served from LiteralMap.
SPIRVConstant *
SPIRVModuleImpl::getLiteralAsConstant(unsigned Literal) {
  auto Cached = LiteralMap.find(Literal);
  if (Cached != LiteralMap.end())
    return Cached->second;
  SPIRVTypeInt *Int32Ty = addIntegerType(32);
  SPIRVConstant *Const =
      new SPIRVConstant(this, Int32Ty, getId(), static_cast<uint64_t>(Literal));
  LiteralMap[Literal] = Const;
  addConstant(Const);
  return Const;
}
// Files entry E into the per-section collection it belongs to, based on its
// opcode. Opcodes not matched here are not collected.
void
SPIRVModuleImpl::layoutEntry(SPIRVEntry* E) {
  const auto OC = E->getOpCode();
  if (OC == OpString) {
    addTo(StringVec, E);
  } else if (OC == OpMemberName) {
    addTo(MemberNameVec, E);
  } else if (OC == OpLine) {
    addTo(LineVec, E);
  } else if (OC == OpVariable) {
    // Only variables without a parent are collected
    // (presumably module-scope globals).
    if (!static_cast<SPIRVVariable *>(E)->getParent())
      addTo(VariableVec, E);
  } else if (isTypeOpCode(OC)) {
    TypeVec.push_back(static_cast<SPIRVType *>(E));
  } else if (isConstantOpCode(OC)) {
    ConstVec.push_back(static_cast<SPIRVConstant *>(E));
  }
}
// Registers Entry with the module.
// - An entry with an id is mapped in IdEntryMap. If the id is currently held
//   by an OpForward placeholder, that placeholder is replaced; mapping an id
//   to a different non-forward entry asserts.
// - An entry without an id is appended to EntryNoId.
// The entry is then laid out into its section collection, and its required
// capabilities are added and/or validated according to the module flags.
SPIRVEntry *
SPIRVModuleImpl::addEntry(SPIRVEntry *Entry) {
  assert(Entry && "Invalid entry");
  if (!Entry->hasId()) {
    EntryNoId.push_back(Entry);
  } else {
    const SPIRVId Id = Entry->getId();
    assert(Id != SPIRVID_INVALID && "Invalid id");
    SPIRVEntry *Existing = nullptr;
    if (!exist(Id, &Existing)) {
      IdEntryMap[Id] = Entry;
    } else if (Existing->getOpCode() == OpForward) {
      replaceForward(static_cast<SPIRVForward *>(Existing), Entry);
    } else {
      assert(Existing == Entry && "Id used twice");
    }
  }

  Entry->setModule(this);
  layoutEntry(Entry);

  if (AutoAddCapability) {
    for (auto &Cap : Entry->getRequiredCapability())
      addCapability(Cap);
  }
  if (ValidateCapability) {
    for (auto &Cap : Entry->getRequiredCapability()) {
      (void) Cap;
      assert(CapMap.count(Cap));
    }
  }
  return Entry;
}
// Returns true if Id is registered in the module.
bool
SPIRVModuleImpl::exist(SPIRVId Id) const {
  SPIRVEntry **const NoOutput = nullptr;
  return exist(Id, NoOutput);
}
// Returns true if Id is registered. When Entry is non-null and the id is
// found, *Entry receives the mapped entry; otherwise *Entry is left alone.
bool
SPIRVModuleImpl::exist(SPIRVId Id, SPIRVEntry **Entry) const {
  assert (Id != SPIRVID_INVALID && "Invalid Id");
  auto Found = IdEntryMap.find(Id);
  const bool Present = Found != IdEntryMap.end();
  if (Present && Entry)
    *Entry = Found->second;
  return Present;
}
// Allocates or reserves ids.
// If Id is not valid, the current NextId is returned. Otherwise Id itself is
// returned and NextId is raised to at least Id. In both cases NextId is then
// advanced by increment.
SPIRVId
SPIRVModuleImpl::getId(SPIRVId Id, unsigned increment) {
  SPIRVId Result = Id;
  if (isValidId(Id)) {
    if (NextId < Id)
      NextId = Id;
  } else {
    Result = NextId;
  }
  NextId += increment;
  return Result;
}
// Returns the entry registered under Id.
// Debug builds assert that Id is valid and known. In release (NDEBUG)
// builds an unknown id yields nullptr instead of dereferencing
// IdEntryMap.end(), which would be undefined behavior.
SPIRVEntry *
SPIRVModuleImpl::getEntry(SPIRVId Id) const {
  assert (Id != SPIRVID_INVALID && "Invalid Id");
  SPIRVIdToEntryMap::const_iterator Loc = IdEntryMap.find(Id);
  assert (Loc != IdEntryMap.end() && "Id is not in map");
  if (Loc == IdEntryMap.end())
    return nullptr;
  return Loc->second;
}
// Returns the extended-instruction-set kind imported under SetId.
// Debug builds assert that SetId was imported. In release (NDEBUG) builds an
// unknown id yields SPIRVEIS_Count, the same "unknown" placeholder that
// importBuiltinSetWithId starts from, instead of dereferencing end().
SPIRVExtInstSetKind
SPIRVModuleImpl::getBuiltinSet(SPIRVId SetId) const {
  auto Loc = IdBuiltinMap.find(SetId);
  assert(Loc != IdBuiltinMap.end() && "Invalid builtin set id");
  if (Loc == IdBuiltinMap.end())
    return SPIRVEIS_Count;
  return Loc->second;
}
// Returns true if function id EP was registered as an entry point for
// ExecModel.
bool
SPIRVModuleImpl::isEntryPoint(SPIRVExecutionModelKind ExecModel, SPIRVId EP)
    const {
  assert(isValid(ExecModel) && "Invalid execution model");
  assert(EP != SPIRVID_INVALID && "Invalid function id");
  auto Found = EntryPointSet.find(ExecModel);
  return Found != EntryPointSet.end() && Found->second.count(EP) != 0;
}
// Module change functions
bool
SPIRVModuleImpl::importBuiltinSet(const std::string& BuiltinSetName,
SPIRVId *BuiltinSetId) {
SPIRVId TmpBuiltinSetId = getId();
if (!importBuiltinSetWithId(BuiltinSetName, TmpBuiltinSetId))
return false;
if (BuiltinSetId)
*BuiltinSetId = TmpBuiltinSetId;
return true;
}
bool
SPIRVModuleImpl::importBuiltinSetWithId(const std::string& BuiltinSetName,
SPIRVId BuiltinSetId) {
SPIRVExtInstSetKind BuiltinSet = SPIRVEIS_Count;
SPIRVCKRT(SPIRVBuiltinSetNameMap::rfind(BuiltinSetName, &BuiltinSet),
InvalidBuiltinSetName, "Actual is " + BuiltinSetName);
IdBuiltinMap[BuiltinSetId] = BuiltinSet;
return true;
}
void
SPIRVModuleImpl::setAlignment(SPIRVValue *V, SPIRVWord A) {
V->setAlignment(A);
}
// Sets E's name. For entries with an id, also keeps NamedId in sync: the id
// is added for a non-empty name and removed for an empty one.
void
SPIRVModuleImpl::setName(SPIRVEntry *E, const std::string &Name) {
  E->setName(Name);
  if (E->hasId()) {
    const SPIRVId Id = E->getId();
    if (Name.empty())
      NamedId.erase(Id);
    else
      NamedId.insert(Id);
  }
}
void SPIRVModuleImpl::resolveUnknownStructFields() {
for (auto &KV : UnknownStructFieldMap) {
auto *Struct = KV.first;
for (auto &Indices : KV.second) {
unsigned I = Indices.first;
SPIRVId ID = Indices.second;
auto Ty = static_cast<SPIRVType *>(getEntry(ID));
Struct->setMemberType(I, Ty);
}
}
}
// Type creation functions
// Registers type Ty with the module and returns it. A name set on Ty before
// registration is re-applied through setName so that NamedId records it.
template<class T>
T *
SPIRVModuleImpl::addType(T *Ty) {
  add(Ty);
  const std::string &TypeName = Ty->getName();
  if (!TypeName.empty())
    setName(Ty, TypeName);
  return Ty;
}
// Creates and registers a new void type.
SPIRVTypeVoid *
SPIRVModuleImpl::addVoidType() {
  auto *VoidTy = new SPIRVTypeVoid(this, getId());
  return addType(VoidTy);
}
// Creates and registers an array type of ElementType with length Length.
SPIRVTypeArray *
SPIRVModuleImpl::addArrayType(SPIRVType *ElementType, SPIRVConstant *Length) {
  auto *ArrayTy = new SPIRVTypeArray(this, getId(), ElementType, Length);
  return addType(ArrayTy);
}
// Creates and registers a new boolean type.
SPIRVTypeBool *
SPIRVModuleImpl::addBoolType() {
  auto *BoolTy = new SPIRVTypeBool(this, getId());
  return addType(BoolTy);
}
// Returns the integer type of BitWidth bits. It is created and registered on
// first request, then served from IntTypeMap. The trailing `false` constructor
// argument is passed unchanged (presumably signedness).
SPIRVTypeInt *
SPIRVModuleImpl::addIntegerType(unsigned BitWidth) {
  auto Known = IntTypeMap.find(BitWidth);
  if (Known != IntTypeMap.end())
    return Known->second;
  SPIRVTypeInt *IntTy = new SPIRVTypeInt(this, getId(), BitWidth, false);
  IntTypeMap.insert(std::make_pair(BitWidth, IntTy));
  return addType(IntTy);
}
// Creates and registers a float type of BitWidth bits. Unlike
// addIntegerType, this does not cache: every call creates a new type.
SPIRVTypeFloat *
SPIRVModuleImpl::addFloatType(unsigned BitWidth) {
  return addType(new SPIRVTypeFloat(this, getId(), BitWidth));
}
// Creates and registers a pointer to ElementType in StorageClass.
SPIRVTypePointer *
SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
    SPIRVType *ElementType) {
  auto *PtrTy = new SPIRVTypePointer(this, getId(), StorageClass, ElementType);
  return addType(PtrTy);
}
// Creates and registers a function type with the given return and parameter
// types.
SPIRVTypeFunction *
SPIRVModuleImpl::addFunctionType(SPIRVType *ReturnType,
    const std::vector<SPIRVType *>& ParameterTypes) {
  auto *FnTy = new SPIRVTypeFunction(this, getId(), ReturnType, ParameterTypes);
  return addType(FnTy);
}
// Creates and registers an opaque type named Name.
SPIRVTypeOpaque *
SPIRVModuleImpl::addOpaqueType(const std::string& Name) {
  auto *OpaqueTy = new SPIRVTypeOpaque(this, getId(), Name);
  return addType(OpaqueTy);
}
// Creates a struct type with NumMembers members and a fresh id but does NOT
// register it; closeStructType registers it (presumably after the caller
// has set the members).
SPIRVTypeStruct *SPIRVModuleImpl::openStructType(unsigned NumMembers,
                                                 const std::string &Name) {
  return new SPIRVTypeStruct(this, getId(), NumMembers, Name);
}
void SPIRVModuleImpl::closeStructType(SPIRVTypeStruct *T, bool Packed) {
addType(T);
T->setPacked(Packed);
}
// Creates and registers a vector of CompCount elements of CompType.
SPIRVTypeVector *
SPIRVModuleImpl::addVectorType(SPIRVType *CompType, SPIRVWord CompCount) {
  auto *VecTy = new SPIRVTypeVector(this, getId(), CompType, CompCount);
  return addType(VecTy);
}
// Creates and registers an opaque generic type identified by TheOpCode.
SPIRVType *
SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
  auto *GenericTy = new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId());
  return addType(GenericTy);
}
// Creates and registers a new device-event type.
SPIRVTypeDeviceEvent *
SPIRVModuleImpl::addDeviceEventType() {
  auto *EventTy = new SPIRVTypeDeviceEvent(this, getId());
  return addType(EventTy);
}
// Creates and registers a new queue type.
SPIRVTypeQueue *
SPIRVModuleImpl::addQueueType() {
  auto *QueueTy = new SPIRVTypeQueue(this, getId());
  return addType(QueueTy);
}
// Creates and registers a new pipe type.
SPIRVTypePipe *
SPIRVModuleImpl::addPipeType() {
  auto *PipeTy = new SPIRVTypePipe(this, getId());
  return addType(PipeTy);
}
// Creates and registers an image type. A null SampledType is passed to the
// image as id 0.
SPIRVTypeImage *
SPIRVModuleImpl::addImageType(SPIRVType *SampledType,
    const SPIRVTypeImageDescriptor &Desc) {
  SPIRVId SampledTypeId = 0;
  if (SampledType)
    SampledTypeId = SampledType->getId();
  return addType(new SPIRVTypeImage(this, getId(), SampledTypeId, Desc));
}
// Creates and registers an image type with access qualifier Acc. A null
// SampledType is passed to the image as id 0.
SPIRVTypeImage *
SPIRVModuleImpl::addImageType(SPIRVType *SampledType,
    const SPIRVTypeImageDescriptor &Desc, SPIRVAccessQualifierKind Acc) {
  SPIRVId SampledTypeId = 0;
  if (SampledType)
    SampledTypeId = SampledType->getId();
  return addType(new SPIRVTypeImage(this, getId(), SampledTypeId, Desc, Acc));
}
// Creates and registers a new sampler type.
SPIRVTypeSampler *
SPIRVModuleImpl::addSamplerType() {
  auto *SamplerTy = new SPIRVTypeSampler(this, getId());
  return addType(SamplerTy);
}
// Creates and registers a new pipe-storage type.
SPIRVTypePipeStorage *
SPIRVModuleImpl::addPipeStorageType() {
  auto *StorageTy = new SPIRVTypePipeStorage(this, getId());
  return addType(StorageTy);
}
// Creates and registers a sampled-image type wrapping image type T.
SPIRVTypeSampledImage *
SPIRVModuleImpl::addSampledImageType(SPIRVTypeImage *T) {
  auto *SampledTy = new SPIRVTypeSampledImage(this, getId(), T);
  return addType(SampledTy);
}
// Walks TypeVec in order and creates an OpTypeForwardPointer for every
// pointer type that is used as a struct member before the pointer type
// itself has appeared in TypeVec. Each such pointer is forward-declared at
// most once, even when several members or structs reference it before its
// declaration.
void SPIRVModuleImpl::createForwardPointers() {
  // Ids already declared, or already forward-declared, at this point of the walk.
  std::unordered_set<SPIRVId> Seen;
  for (auto *T : TypeVec) {
    if (T->hasId())
      Seen.insert(T->getId());
    if (!T->isTypeStruct())
      continue;
    auto ST = static_cast<SPIRVTypeStruct *>(T);
    for (unsigned i = 0; i < ST->getStructMemberCount(); ++i) {
      auto MemberTy = ST->getStructMemberType(i);
      if (!MemberTy->isTypePointer())
        continue;
      auto Ptr = static_cast<SPIRVTypePointer *>(MemberTy);
      // insert() reports whether the id was new. Marking the pointer as seen
      // prevents a second forward declaration of the same pointer.
      if (Seen.insert(Ptr->getId()).second)
        ForwardPointerVec.push_back(new SPIRVTypeForwardPointer(
            this, Ptr, Ptr->getPointerStorageClass()));
    }
  }
}
// Registers Func with the module and appends it to FuncVec in call order.
SPIRVFunction *
SPIRVModuleImpl::addFunction(SPIRVFunction *Func) {
  SPIRVFunction *Registered = add(Func);
  FuncVec.push_back(Registered);
  return Func;
}
// Creates a function of FuncType and registers it. getId reserves
// NumParameters + 1 ids starting at the function's id (presumably one for
// the function and one per parameter).
SPIRVFunction *
SPIRVModuleImpl::addFunction(SPIRVTypeFunction *FuncType, SPIRVId Id) {
  const SPIRVId FuncId = getId(Id, FuncType->getNumParameters() + 1);
  return addFunction(new SPIRVFunction(this, FuncType, FuncId));
}
// Creates a basic block with id Id (or a fresh id if Id is invalid) and
// hands it to Func.
SPIRVBasicBlock *
SPIRVModuleImpl::addBasicBlock(SPIRVFunction *Func, SPIRVId Id) {
  auto *Block = new SPIRVBasicBlock(getId(Id), Func);
  return Func->addBasicBlock(Block);
}
// Records decoration Dec. Its target must already be registered (asserted).
// Decorations without an owner go into DecorateSet; decorations with an
// owner (presumably a decoration group) are not tracked there. The
// decoration's required capabilities are added either way.
const SPIRVDecorateGeneric *
SPIRVModuleImpl::addDecorate(const SPIRVDecorateGeneric *Dec) {
  SPIRVEntry *Target = nullptr;
  const bool Found = exist(Dec->getTargetId(), &Target);
  (void) Found;
  assert (Found && "Decorate target does not exist");
  if (!Dec->getOwner())
    DecorateSet.insert(Dec);
  addCapabilities(Dec->getRequiredCapability());
  return Dec;
}
// Registers function id EntryPoint as an entry point for ExecModel. The set
// serves isEntryPoint lookups; the vector keeps insertion order for
// getEntryPoint/getNumEntryPoints. Capabilities for ExecModel are added.
void
SPIRVModuleImpl::addEntryPoint(SPIRVExecutionModelKind ExecModel,
    SPIRVId EntryPoint){
  assert(isValid(ExecModel) && "Invalid execution model");
  assert(EntryPoint != SPIRVID_INVALID && "Invalid entry point");
  auto &IdsInSet = EntryPointSet[ExecModel];
  IdsInSet.insert(EntryPoint);
  auto &IdsInOrder = EntryPointVec[ExecModel];
  IdsInOrder.push_back(EntryPoint);
  addCapabilities(SPIRV::getCapability(ExecModel));
}
// Creates and registers an OpForward placeholder of type Ty with a fresh id.
SPIRVForward *
SPIRVModuleImpl::addForward(SPIRVType *Ty) {
  auto *Placeholder = new SPIRVForward(this, Ty, getId());
  return add(Placeholder);
}
// Creates and registers an OpForward placeholder of type Ty under id Id.
SPIRVForward *
SPIRVModuleImpl::addForward(SPIRVId Id, SPIRVType *Ty) {
  auto *Placeholder = new SPIRVForward(this, Ty, Id);
  return add(Placeholder);
}
// Replaces forward placeholder Forward with Entry. Entry ends up mapped
// under the forward's id: if Entry had a different id, its old mapping is
// removed and its id is changed. Entry takes over the forward's annotations,
// and Forward is deleted.
SPIRVEntry *
SPIRVModuleImpl::replaceForward(SPIRVForward *Forward, SPIRVEntry *Entry) {
  const SPIRVId EntryId = Entry->getId();
  const SPIRVId ForwardId = Forward->getId();
  if (ForwardId != EntryId) {
    auto OldMapping = IdEntryMap.find(EntryId);
    assert(OldMapping != IdEntryMap.end());
    IdEntryMap.erase(OldMapping);
    Entry->setId(ForwardId);
  }
  IdEntryMap[ForwardId] = Entry;
  // Annotations include name, decorations, execution modes
  Entry->takeAnnotations(Forward);
  delete Forward;
  return Entry;
}
// Registers constant C via add(); no other bookkeeping is done here.
SPIRVValue *
SPIRVModuleImpl::addConstant(SPIRVValue *C) {
  SPIRVValue *Registered = add(C);
  return Registered;
}
// Creates a constant of type Ty with value V.
// - bool: a true/false constant, depending on whether V is non-zero.
// - integer: delegated to addIntegerConstant.
// - anything else: a plain SPIRVConstant built from V.
SPIRVValue *
SPIRVModuleImpl::addConstant(SPIRVType *Ty, uint64_t V) {
  if (Ty->isTypeBool()) {
    if (V == 0)
      return addConstant(new SPIRVConstantFalse(this, Ty, getId()));
    return addConstant(new SPIRVConstantTrue(this, Ty, getId()));
  }
  if (Ty->isTypeInt())
    return addIntegerConstant(static_cast<SPIRVTypeInt*>(Ty), V);
  return addConstant(new SPIRVConstant(this, Ty, getId(), V));
}
// Creates an integer constant of type Ty with value V. 32-bit values are
// served through getLiteralAsConstant (asserting that V fits in unsigned);
// other widths always create a new constant.
SPIRVValue *
SPIRVModuleImpl::addIntegerConstant(SPIRVTypeInt *Ty, uint64_t V) {
  if (Ty->getBitWidth() != 32)
    return addConstant(new SPIRVConstant(this, Ty, getId(), V));
  unsigned Narrow = V;
  assert(Narrow == V && "Integer value truncated");
  return getLiteralAsConstant(Narrow);
}
// Creates and registers a single-precision constant V of type Ty.
SPIRVValue *
SPIRVModuleImpl::addFloatConstant(SPIRVTypeFloat *Ty, float V) {
  auto *Const = new SPIRVConstant(this, Ty, getId(), V);
  return addConstant(Const);
}
// Creates and registers a double-precision constant V of type Ty.
SPIRVValue *
SPIRVModuleImpl::addDoubleConstant(SPIRVTypeFloat *Ty, double V) {
  auto *Const = new SPIRVConstant(this, Ty, getId(), V);
  return addConstant(Const);
}
// Creates and registers a null constant of type Ty.
SPIRVValue *
SPIRVModuleImpl::addNullConstant(SPIRVType *Ty) {
  auto *NullConst = new SPIRVConstantNull(this, Ty, getId());
  return addConstant(NullConst);
}
// Creates and registers a composite constant of type Ty from Elements.
SPIRVValue *
SPIRVModuleImpl::addCompositeConstant(SPIRVType *Ty,
    const std::vector<SPIRVValue*>& Elements) {
  auto *Composite = new SPIRVConstantComposite(this, Ty, getId(), Elements);
  return addConstant(Composite);
}
// Creates and registers an undef value of TheType.
SPIRVValue *
SPIRVModuleImpl::addUndef(SPIRVType *TheType) {
  auto *Undef = new SPIRVUndef(this, TheType, getId());
  return addConstant(Undef);
}
// Instruction creation functions
// Appends a store of Source to Target to BB. No result id is allocated. BB
// is used directly, so it must be non-null (unlike addInstruction).
SPIRVInstruction *
SPIRVModuleImpl::addStoreInst(SPIRVValue *Target, SPIRVValue *Source,
    const std::vector<SPIRVWord> &TheMemoryAccess, SPIRVBasicBlock *BB) {
  auto *Store = new SPIRVStore(Target->getId(), Source->getId(),
                               TheMemoryAccess, BB);
  return BB->addInstruction(Store);
}
// Appends a switch on Select to BB, with the given (literal, target) cases
// and a Default target. BB must be non-null.
SPIRVInstruction *
SPIRVModuleImpl::addSwitchInst(SPIRVValue *Select, SPIRVBasicBlock *Default,
    const std::vector<std::pair<SPIRVWord, SPIRVBasicBlock *>>& Pairs,
    SPIRVBasicBlock *BB) {
  auto *Switch = new SPIRVSwitch(Select, Default, Pairs, BB);
  return BB->addInstruction(Switch);
}
// Builds a group instruction OpCode. The operand words are the Scope value
// followed by the ids of Ops. Type, if given, must not be void.
SPIRVInstruction *
SPIRVModuleImpl::addGroupInst(Op OpCode, SPIRVType *Type,
    Scope Scope, const std::vector<SPIRVValue *> &Ops,
    SPIRVBasicBlock *BB) {
  assert(!Type || !Type->isTypeVoid());
  std::vector<SPIRVWord> Words;
  Words.reserve(Ops.size() + 1);
  Words.push_back(Scope);
  const auto OperandIds = getIds(Ops);
  Words.insert(Words.end(), OperandIds.begin(), OperandIds.end());
  return addInstTemplate(OpCode, Words, BB, Type);
}
// Adds Inst to BB when BB is non-null. Without a block, Inst is converted to
// an OpSpecConstantOp (unless it already is one) and registered as a
// module-level constant.
SPIRVInstruction *
SPIRVModuleImpl::addInstruction(SPIRVInstruction *Inst, SPIRVBasicBlock *BB) {
  if (!BB) {
    if (Inst->getOpCode() != OpSpecConstantOp)
      Inst = createSpecConstantOpInst(Inst);
    return static_cast<SPIRVInstruction *>(addConstant(Inst));
  }
  return BB->addInstruction(Inst);
}
// Creates a load from Source with a fresh result id and adds it via
// addInstruction.
SPIRVInstruction *
SPIRVModuleImpl::addLoadInst(SPIRVValue *Source,
    const std::vector<SPIRVWord> &TheMemoryAccess, SPIRVBasicBlock *BB) {
  auto *Load = new SPIRVLoad(getId(), Source->getId(), TheMemoryAccess, BB);
  return addInstruction(Load, BB);
}
// Creates a phi node of result type Type from IncomingPairs and adds it via
// addInstruction(). IncomingPairs is presumably a flat list of
// (value, predecessor block) operands; confirm against SPIRVPhi.
SPIRVInstruction *
SPIRVModuleImpl::addPhiInst(SPIRVType *Type,
    std::vector<SPIRVValue *> IncomingPairs, SPIRVBasicBlock *BB) {
  auto *Phi = new SPIRVPhi(Type, getId(), IncomingPairs, BB);
  return addInstruction(Phi, BB);
}
// Creates an extended-instruction-set call (instruction EntryPoint from the
// set BuiltinSet). The operands are given as raw words. The call is added
// via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addExtInst(SPIRVType *TheType, SPIRVWord BuiltinSet,
    SPIRVWord EntryPoint, const std::vector<SPIRVWord> &Args,
    SPIRVBasicBlock *BB) {
  auto *Ext = new SPIRVExtInst(TheType, getId(), BuiltinSet, EntryPoint,
      Args, BB);
  return addInstruction(Ext, BB);
}
// Same as the word-operand overload, except the operands are given as
// values.
SPIRVInstruction *
SPIRVModuleImpl::addExtInst(SPIRVType *TheType, SPIRVWord BuiltinSet,
    SPIRVWord EntryPoint, const std::vector<SPIRVValue *> &Args,
    SPIRVBasicBlock *BB) {
  auto *Ext = new SPIRVExtInst(TheType, getId(), BuiltinSet, EntryPoint,
      Args, BB);
  return addInstruction(Ext, BB);
}
// Creates a call to TheFunction with argument ids TheArguments and adds it
// via addInstruction().
SPIRVInstruction*
SPIRVModuleImpl::addCallInst(SPIRVFunction* TheFunction,
    const std::vector<SPIRVWord> &TheArguments, SPIRVBasicBlock *BB) {
  auto *Call = new SPIRVFunctionCall(getId(), TheFunction, TheArguments, BB);
  return addInstruction(Call, BB);
}
// Creates a two-operand instruction TheOpCode (Op1, Op2) with result type
// Type through the generic instruction template factory. The result is added
// via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addBinaryInst(Op TheOpCode, SPIRVType *Type,
    SPIRVValue *Op1, SPIRVValue *Op2, SPIRVBasicBlock *BB){
  auto *Binary = SPIRVInstTemplateBase::create(TheOpCode, Type, getId(),
      getVec(Op1->getId(), Op2->getId()), BB, this);
  return addInstruction(Binary, BB);
}
// Appends a return with no value (OpReturn) via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addReturnInst(SPIRVBasicBlock *BB) {
  auto *Ret = new SPIRVReturn(BB);
  return addInstruction(Ret, BB);
}
// Appends a return of ReturnValue via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addReturnValueInst(SPIRVValue *ReturnValue, SPIRVBasicBlock *BB) {
  auto *Ret = new SPIRVReturnValue(ReturnValue, BB);
  return addInstruction(Ret, BB);
}
// Creates a one-operand instruction TheOpCode applied to Op, with result type
// TheType, through the generic instruction template factory. The result is
// added via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addUnaryInst(Op TheOpCode, SPIRVType *TheType,
    SPIRVValue *Op, SPIRVBasicBlock *BB) {
  auto *Unary = SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
      getVec(Op->getId()), BB, this);
  return addInstruction(Unary, BB);
}
// Creates an extract of the component at the runtime Index from TheVector
// and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addVectorExtractDynamicInst(SPIRVValue *TheVector,
    SPIRVValue *Index, SPIRVBasicBlock *BB) {
  auto *Extract = new SPIRVVectorExtractDynamic(getId(), TheVector, Index, BB);
  return addInstruction(Extract, BB);
}
// Creates a copy of TheVector with TheComponent written at the runtime
// Index, and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addVectorInsertDynamicInst(SPIRVValue *TheVector,
    SPIRVValue *TheComponent, SPIRVValue*Index, SPIRVBasicBlock *BB) {
  auto *Insert = new SPIRVVectorInsertDynamic(getId(), TheVector,
      TheComponent, Index, BB);
  return addInstruction(Insert, BB);
}
// Creates a shuffle of Vec1/Vec2 that selects the lanes listed in
// Components, with result type Type. It is added via addInstruction().
// Note: this builder returns SPIRVValue*, unlike its siblings.
SPIRVValue *
SPIRVModuleImpl::addVectorShuffleInst(SPIRVType * Type, SPIRVValue *Vec1,
    SPIRVValue *Vec2, const std::vector<SPIRVWord> &Components,
    SPIRVBasicBlock *BB) {
  auto *Shuffle = new SPIRVVectorShuffle(getId(), Type, Vec1, Vec2,
      Components, BB);
  return addInstruction(Shuffle, BB);
}
// Appends an unconditional branch to TargetLabel via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addBranchInst(SPIRVLabel *TargetLabel, SPIRVBasicBlock *BB) {
  auto *Branch = new SPIRVBranch(TargetLabel, BB);
  return addInstruction(Branch, BB);
}
// Appends a conditional branch via addInstruction(). It goes to TrueLabel
// when Condition holds and to FalseLabel otherwise.
SPIRVInstruction *
SPIRVModuleImpl::addBranchConditionalInst(SPIRVValue *Condition,
    SPIRVLabel *TrueLabel, SPIRVLabel *FalseLabel, SPIRVBasicBlock *BB) {
  auto *CondBranch =
      new SPIRVBranchConditional(Condition, TrueLabel, FalseLabel, BB);
  return addInstruction(CondBranch, BB);
}
// Creates a comparison TheOpCode(Op1, Op2) with result type TheType through
// the generic instruction template factory. The result is added via
// addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addCmpInst(Op TheOpCode, SPIRVType *TheType,
    SPIRVValue *Op1, SPIRVValue *Op2, SPIRVBasicBlock *BB) {
  auto *Cmp = SPIRVInstTemplateBase::create(TheOpCode, TheType, getId(),
      getVec(Op1->getId(), Op2->getId()), BB, this);
  return addInstruction(Cmp, BB);
}
// Appends a control barrier, with its execution scope, memory scope and
// memory semantics given as values, via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addControlBarrierInst(SPIRVValue *ExecKind,
    SPIRVValue *MemKind, SPIRVValue *MemSema, SPIRVBasicBlock *BB) {
  auto *Barrier = new SPIRVControlBarrier(ExecKind, MemKind, MemSema, BB);
  return addInstruction(Barrier, BB);
}
// Appends an OpMemoryBarrier with operand words (ScopeKind, MemFlag) via
// addInstruction(). The instruction produces no result, so it has no result
// type and an invalid id.
SPIRVInstruction *
SPIRVModuleImpl::addMemoryBarrierInst(Scope ScopeKind,
    SPIRVWord MemFlag, SPIRVBasicBlock *BB) {
  auto *Barrier = SPIRVInstTemplateBase::create(OpMemoryBarrier, nullptr,
      SPIRVID_INVALID, getVec(static_cast<SPIRVWord>(ScopeKind), MemFlag),
      BB, this);
  return addInstruction(Barrier, BB);
}
// Creates a select that yields Op1 when Condition holds and Op2 otherwise,
// and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addSelectInst(SPIRVValue *Condition, SPIRVValue *Op1,
    SPIRVValue *Op2, SPIRVBasicBlock *BB) {
  const auto CondId = Condition->getId();
  const auto TrueId = Op1->getId();
  const auto FalseId = Op2->getId();
  auto *Select = new SPIRVSelect(getId(), CondId, TrueId, FalseId, BB);
  return addInstruction(Select, BB);
}
// Creates a pointer access chain from Base through Indices, with result type
// Type, and adds it via addInstruction(). IsInBounds selects the in-bounds
// variant of the opcode.
SPIRVInstruction *
SPIRVModuleImpl::addPtrAccessChainInst(SPIRVType *Type, SPIRVValue *Base,
    std::vector<SPIRVValue *> Indices, SPIRVBasicBlock *BB, bool IsInBounds){
  const Op ChainOpCode =
      IsInBounds ? OpInBoundsPtrAccessChain : OpPtrAccessChain;
  auto *Chain = SPIRVInstTemplateBase::create(ChainOpCode, Type, getId(),
      getVec(Base->getId(), Base->getIds(Indices)), BB, this);
  return addInstruction(Chain, BB);
}
// Creates an asynchronous group copy of NumElems elements from Src to Dest,
// with the given Stride, tied to Event and executed at Scope. It is added via
// addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addAsyncGroupCopy(SPIRVValue *Scope,
    SPIRVValue *Dest, SPIRVValue *Src, SPIRVValue *NumElems, SPIRVValue *Stride,
    SPIRVValue *Event, SPIRVBasicBlock *BB) {
  auto *Copy = new SPIRVGroupAsyncCopy(Scope, getId(), Dest, Src, NumElems,
      Stride, Event, BB);
  return addInstruction(Copy, BB);
}
// Creates an extract of the member addressed by the literal Indices from
// TheVector, with result type Type, and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addCompositeExtractInst(SPIRVType *Type, SPIRVValue *TheVector,
    const std::vector<SPIRVWord>& Indices, SPIRVBasicBlock *BB) {
  auto *Extract =
      new SPIRVCompositeExtract(Type, getId(), TheVector, Indices, BB);
  return addInstruction(Extract, BB);
}
// Creates a copy of Composite with Object written at the literal Indices,
// and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addCompositeInsertInst(SPIRVValue *Object,
    SPIRVValue *Composite, const std::vector<SPIRVWord>& Indices,
    SPIRVBasicBlock *BB) {
  auto *Insert =
      new SPIRVCompositeInsert(getId(), Object, Composite, Indices, BB);
  return addInstruction(Insert, BB);
}
// Creates a copy of Operand as a new value of type TheType and adds it via
// addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addCopyObjectInst(SPIRVType *TheType, SPIRVValue *Operand,
    SPIRVBasicBlock *BB) {
  auto *Copy = new SPIRVCopyObject(TheType, getId(), Operand, BB);
  return addInstruction(Copy, BB);
}
// Creates a memory copy from TheSource to TheTarget, with the optional
// memory-access operands TheMemoryAccess, and adds it via addInstruction().
SPIRVInstruction *
SPIRVModuleImpl::addCopyMemoryInst(SPIRVValue *TheTarget, SPIRVValue *TheSource,
    const std::vector<SPIRVWord> &TheMemoryAccess, SPIRVBasicBlock *BB) {
  auto *Copy = new SPIRVCopyMemory(TheTarget, TheSource, TheMemoryAccess, BB);
  return addInstruction(Copy, BB);
}
// Like addCopyMemoryInst, but the number of bytes to copy is given explicitly
// by TheSize.
SPIRVInstruction *
SPIRVModuleImpl::addCopyMemorySizedInst(SPIRVValue *TheTarget,
    SPIRVValue *TheSource, SPIRVValue *TheSize,
    const std::vector<SPIRVWord> &TheMemoryAccess, SPIRVBasicBlock *BB) {
  auto *Copy = new SPIRVCopyMemorySized(TheTarget, TheSource, TheSize,
      TheMemoryAccess, BB);
  return addInstruction(Copy, BB);
}
// Creates a variable of pointer type Type in StorageClass, with an optional
// Initializer and debug Name.
// - With a basic block, it is a function-local instruction added via
//   addInstruction(). LinkageType and IsConstant are NOT applied in that
//   case.
// - Otherwise it is registered as a module-scope entry. It is given
//   LinkageType (unless internal) and the IsConstant flag.
SPIRVInstruction*
SPIRVModuleImpl::addVariable(SPIRVType *Type, bool IsConstant,
    SPIRVLinkageTypeKind LinkageType, SPIRVValue *Initializer,
    const std::string &Name, SPIRVStorageClassKind StorageClass,
    SPIRVBasicBlock *BB) {
  auto *Var = new SPIRVVariable(Type, getId(), Initializer, Name,
      StorageClass, BB, this);
  if (BB != nullptr)
    return addInstruction(Var, BB);

  add(Var);
  const bool HasNonInternalLinkage = LinkageType != LinkageTypeInternal;
  if (HasNonInternalLinkage)
    Var->setLinkageType(LinkageType);
  Var->setIsConstant(IsConstant);
  return Var;
}
// Writes every pointee of V to O, in vector order.
template<class T>
spv_ostream &
operator<< (spv_ostream &O, const std::vector<T *>& V) {
  for (T *Elem : V)
    O << *Elem;
  return O;
}
// Writes every pointee of V to O, in the order defined by the comparator B.
template<class T, class B>
spv_ostream &
operator<< (spv_ostream &O, const std::multiset<T *, B>& V) {
  for (T *Elem : V)
    O << *Elem;
  return O;
}
// To satisfy SPIR-V spec requirement:
// "All operands must be declared before being used",
// we do DFS based topological sort
// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
class TopologicalSort {
enum DFSState : char {
Unvisited,
Discovered,
Visited
};
typedef std::vector<SPIRVType *> SPIRVTypeVec;
typedef std::vector<SPIRVValue *> SPIRVConstantVector;
typedef std::vector<SPIRVVariable *> SPIRVVariableVec;
typedef std::vector<SPIRVTypeForwardPointer *> SPIRVForwardPointerVec;
typedef std::function<bool(SPIRVEntry*, SPIRVEntry*)> IdComp;
typedef std::map<SPIRVEntry*, DFSState, IdComp> EntryStateMapTy;
SPIRVTypeVec TypeIntVec;
SPIRVConstantVector ConstIntVec;
SPIRVTypeVec TypeVec;
SPIRVConstantVector ConstVec;
SPIRVVariableVec VariableVec;
const SPIRVForwardPointerVec& ForwardPointerVec;
EntryStateMapTy EntryStateMap;
friend spv_ostream & operator<<(spv_ostream &O, const TopologicalSort &S);
// This method implements recursive depth-first search among all Entries in
// EntryStateMap. Traversing entries and adding them to corresponding container
// after visiting all dependent entries(post-order traversal) guarantees that
// the entry's operands will appear in the container before the entry itslef.
void visit(SPIRVEntry* E) {
DFSState& State = EntryStateMap[E];
assert(State != Discovered && "Cyclic dependency detected");
if (State == Visited)
return;
State = Discovered;
for (SPIRVEntry *Op : E->getNonLiteralOperands()) {
auto Comp = [&Op](SPIRVTypeForwardPointer *FwdPtr) {
return FwdPtr->getPointer() == Op;
};
// Skip forward referenced pointers
if (Op->getOpCode() == OpTypePointer &&
find_if(ForwardPointerVec.begin(), ForwardPointerVec.end(), Comp) !=
ForwardPointerVec.end())
continue;
visit(Op);
}
State = Visited;
Op OC = E->getOpCode();
if (OC == OpTypeInt)
TypeIntVec.push_back(static_cast<SPIRVType*>(E));
else if (isConstantOpCode(OC)) {
SPIRVConstant *C = static_cast<SPIRVConstant*>(E);
if (C->getType()->isTypeInt())
ConstIntVec.push_back(C);
else
ConstVec.push_back(C);
} else if (isTypeOpCode(OC))
TypeVec.push_back(static_cast<SPIRVType*>(E));
else if (E->isVariable())
VariableVec.push_back(static_cast<SPIRVVariable*>(E));
}
public:
TopologicalSort(const SPIRVTypeVec &_TypeVec,
const SPIRVConstantVector &_ConstVec,
const SPIRVVariableVec &_VariableVec,
const SPIRVForwardPointerVec &_ForwardPointerVec) :
ForwardPointerVec(_ForwardPointerVec),
EntryStateMap([](SPIRVEntry* a, SPIRVEntry* b) -> bool {
return a->getId() < b->getId();
})
{
// Collect entries for sorting
for (auto *T : _TypeVec)
EntryStateMap[T] = DFSState::Unvisited;
for (auto *C : _ConstVec)
EntryStateMap[C] = DFSState::Unvisited;
for (auto *V : _VariableVec)
EntryStateMap[V] = DFSState::Unvisited;
// Run topoligical sort
for (auto ES : EntryStateMap)
visit(ES.first);
}
};
// Emits the topologically sorted declarations. Integer types and integer
// constants go first, then the remaining types, the remaining constants, and
// finally the variables.
spv_ostream &
operator<< (spv_ostream &O, const TopologicalSort &S) {
  O << S.TypeIntVec;
  O << S.ConstIntVec;
  O << S.TypeVec;
  O << S.ConstVec;
  O << S.VariableVec;
  return O;
}
// Serializes module M to O.
//
// The header words are written first: magic number, version, generator
// word, id bound, and instruction schema.
//
// The module sections follow in the order below, which mirrors the SPIR-V
// logical module layout:
// - capabilities, extensions, and extended-instruction-set imports;
// - the memory model, entry points, and execution modes;
// - debug info: strings, source extensions, source, names, member names,
//   and lines;
// - annotations: decoration groups, decorations, and group decorations;
// - forward pointers, then the types/constants/variables in dependency order;
// - finally the function definitions.
spv_ostream &
operator<< (spv_ostream &O, SPIRVModule &M) {
  SPIRVModuleImpl &MI = *static_cast<SPIRVModuleImpl*>(&M);
  SPIRVEncoder Encoder(O);
  // The generator word packs the generator id into the upper 16 bits and the
  // generator version into the lower 16 bits.
  Encoder << MagicNumber
          << MI.SPIRVVersion
          << (((SPIRVWord)MI.GeneratorId << 16) | MI.GeneratorVer)
          << MI.NextId /* Bound for Id */
          << MI.InstSchema;
  O << SPIRVNL();
  for (auto &I:MI.CapMap)
    O << *I.second;
  for (auto &I:M.getExtension()) {
    assert(!I.empty() && "Invalid extension");
    O << SPIRVExtension(&M, I);
  }
  for (auto &I:MI.IdBuiltinMap)
    O << SPIRVExtInstImport(&M, I.first, SPIRVBuiltinSetNameMap::map(I.second));
  O << SPIRVMemoryModel(&M);
  // Entry points and execution modes are separate layout sections. Two passes
  // ensure every entry point is emitted before any execution mode.
  for (auto &I:MI.EntryPointVec)
    for (auto &II:I.second)
      O << SPIRVEntryPoint(&M, I.first, II,
          M.get<SPIRVFunction>(II)->getName());
  for (auto &I:MI.EntryPointVec)
    for (auto &II:I.second)
      MI.get<SPIRVFunction>(II)->encodeExecutionModes(O);
  O << MI.StringVec;
  for (auto &I:M.getSourceExtension()) {
    assert(!I.empty() && "Invalid source extension");
    O << SPIRVSourceExtension(&M, I);
  }
  O << SPIRVSource(&M);
  for (auto &I:MI.NamedId) {
    // Don't output name for entry point since it is redundant: the name is
    // already carried by the entry point instruction emitted above.
    bool IsEntryPoint = false;
    for (auto &EPS:MI.EntryPointSet)
      if (EPS.second.count(I)) {
        IsEntryPoint = true;
        break;
      }
    if (!IsEntryPoint)
      M.getEntry(I)->encodeName(O);
  }
  // Forward pointers precede the sorted declarations, which may refer to them.
  O << MI.MemberNameVec
    << MI.LineVec
    << MI.DecGroupVec
    << MI.DecorateSet
    << MI.GroupDecVec
    << MI.ForwardPointerVec
    << TopologicalSort(MI.TypeVec, MI.ConstVec, MI.VariableVec,
        MI.ForwardPointerVec)
    << SPIRVNL()
    << MI.FuncVec;
  return O;
}
// Appends E to V after downcasting it to the element type T. The caller must
// guarantee that E really is a T.
template<class T>
void SPIRVModuleImpl::addTo(std::vector<T*>& V, SPIRVEntry* E) {
  T *Typed = static_cast<T *>(E);
  V.push_back(Typed);
}
// Creates a new decoration group with a fresh id. The group receives every
// decoration recorded in the module since the previous group was created.
// So the first group holds all decorations defined before it, the second
// holds those defined between the first and second group, and so forth.
SPIRVDecorationGroup*
SPIRVModuleImpl::addDecorationGroup() {
  auto *Group = new SPIRVDecorationGroup(this, getId());
  return addDecorationGroup(Group);
}
// Registers Group with the module and hands it, via takeDecorates, every
// decoration currently pending in DecorateSet. The group is then recorded in
// DecGroupVec for serialization. DecorateSet is expected to be empty
// afterwards (asserted).
SPIRVDecorationGroup*
SPIRVModuleImpl::addDecorationGroup(SPIRVDecorationGroup* Group) {
  add(Group);
  Group->takeDecorates(DecorateSet);
  DecGroupVec.push_back(Group);
  SPIRVDBG(spvdbgs() << "[addDecorationGroup] {" << *Group << "}\n";
    spvdbgs() << " Remaining DecorateSet: {" << DecorateSet << "}\n");
  assert(DecorateSet.empty());
  return Group;
}
// Registers a group-decorate entry (the common base of SPIRVGroupDecorate
// and SPIRVGroupMemberDecorate) with the module. It then calls
// decorateTargets(), which presumably applies the group's decorations to
// each target (confirm in SPIRVDecorate). Finally it records the entry in
// GroupDecVec for serialization.
SPIRVGroupDecorateGeneric*
SPIRVModuleImpl::addGroupDecorateGeneric(SPIRVGroupDecorateGeneric *GDec) {
  add(GDec);
  GDec->decorateTargets();
  GroupDecVec.push_back(GDec);
  return GDec;
}
// Applies decoration group Group to every entry in Targets, then registers
// the resulting group-decorate instruction with the module.
SPIRVGroupDecorate*
SPIRVModuleImpl::addGroupDecorate(
    SPIRVDecorationGroup* Group, const std::vector<SPIRVEntry*>& Targets) {
  SPIRVGroupDecorate *Decorate =
      new SPIRVGroupDecorate(Group, getIds(Targets));
  addGroupDecorateGeneric(Decorate);
  return Decorate;
}
// Member-level counterpart of addGroupDecorate. The target list is passed to
// SPIRVGroupMemberDecorate as ids.
SPIRVGroupMemberDecorate*
SPIRVModuleImpl::addGroupMemberDecorate(
    SPIRVDecorationGroup* Group, const std::vector<SPIRVEntry*>& Targets) {
  SPIRVGroupMemberDecorate *MemberDecorate =
      new SPIRVGroupMemberDecorate(Group, getIds(Targets));
  addGroupDecorateGeneric(MemberDecorate);
  return MemberDecorate;
}
// Returns the module's string entry for Str, interning it on first use. Each
// distinct string is therefore created (and given an id) only once.
SPIRVString*
SPIRVModuleImpl::getString(const std::string& Str) {
  auto Found = StrMap.find(Str);
  if (Found == StrMap.end()) {
    auto Created = add(new SPIRVString(this, getId(), Str));
    StrMap[Str] = Created;
    return Created;
  }
  return Found->second;
}
// Records a debug name for member number MemberNumber of struct ST and
// registers it with the module.
SPIRVMemberName*
SPIRVModuleImpl::addMemberName(SPIRVTypeStruct* ST,
    SPIRVWord MemberNumber, const std::string& Name) {
  auto *MemberName = new SPIRVMemberName(ST, MemberNumber, Name);
  return add(MemberName);
}
// Remembers that field I of Struct refers to id ID, which could not be
// resolved yet. The pending pairs are kept in UnknownStructFieldMap;
// presumably resolveUnknownStructFields() consumes them after decoding.
void SPIRVModuleImpl::addUnknownStructField(SPIRVTypeStruct *Struct, unsigned I,
                                            SPIRVId ID) {
  auto &Pending = UnknownStructFieldMap[Struct];
  Pending.emplace_back(I, ID);
}
// Decodes a SPIR-V module from I into M.
//
// The header is read first: magic number, version, generator word, id
// bound, and instruction schema. Entries are then decoded until
// getWordCountAndOpCode() reports no further instruction. Finally the
// module is post-processed (decorate optimization, struct-field
// resolution, forward-pointer creation).
//
// Header validation is assert-only, so it is not enforced in NDEBUG builds.
std::istream &
operator>> (std::istream &I, SPIRVModule &M) {
  SPIRVDecoder Decoder(I, M);
  SPIRVModuleImpl &MI = *static_cast<SPIRVModuleImpl*>(&M);
  // Disable automatic capability filling.
  // NOTE(review): it is not re-enabled after decoding; confirm that this is
  // intended for modules that are later modified.
  MI.setAutoAddCapability(false);
  // Zero-initialized so the magic check never reads an indeterminate value
  // if the decoder fails without writing it.
  SPIRVWord Magic = 0;
  Decoder >> Magic;
  assert(Magic == MagicNumber && "Invalid magic number");
  Decoder >> MI.SPIRVVersion;
  assert(MI.SPIRVVersion <= SPV_VERSION && "Unsupported SPIRV version number");
  // The generator word packs the generator id (upper 16 bits) and the
  // generator version (lower 16 bits).
  SPIRVWord Generator = 0;
  Decoder >> Generator;
  MI.GeneratorId = Generator >> 16;
  MI.GeneratorVer = Generator & 0xFFFF;
  // Bound for Id
  Decoder >> MI.NextId;
  Decoder >> MI.InstSchema;
  assert(MI.InstSchema == SPIRVISCH_Default && "Unsupported instruction schema");
  while(Decoder.getWordCountAndOpCode())
    Decoder.getEntry();
  MI.optimizeDecorates();
  MI.resolveUnknownStructFields();
  MI.createForwardPointers();
  return I;
}
// Factory for the concrete module implementation. The returned object is
// heap-allocated and not retained here, so the caller owns it.
SPIRVModule *
SPIRVModule::createSPIRVModule() {
  SPIRVModuleImpl *Module = new SPIRVModuleImpl;
  return Module;
}
// Looks up the entry with id TheId as a SPIRVValue.
SPIRVValue *
SPIRVModuleImpl::getValue(SPIRVId TheId)const {
  SPIRVValue *Value = get<SPIRVValue>(TheId);
  return Value;
}
// Returns the type of the value with id TheId.
SPIRVType *
SPIRVModuleImpl::getValueType(SPIRVId TheId)const {
  SPIRVValue *Value = get<SPIRVValue>(TheId);
  return Value->getType();
}
std::vector<SPIRVValue *>
SPIRVModuleImpl::getValues(const std::vector<SPIRVId>& IdVec)const {
std::vector<SPIRVValue *> ValueVec;
for (auto i:IdVec)
ValueVec.push_back(getValue(i));
return ValueVec;
}
std::vector<SPIRVType *>
SPIRVModuleImpl::getValueTypes(const std::vector<SPIRVId>& IdVec)const {
std::vector<SPIRVType *> TypeVec;
for (auto i:IdVec)
TypeVec.push_back(getValue(i)->getType());
return TypeVec;
}
std::vector<SPIRVId>
SPIRVModuleImpl::getIds(const std::vector<SPIRVEntry *> &ValueVec)const {
std::vector<SPIRVId> IdVec;
for (auto i:ValueVec)
IdVec.push_back(i->getId());
return IdVec;
}
std::vector<SPIRVId>
SPIRVModuleImpl::getIds(const std::vector<SPIRVValue *> &ValueVec)const {
std::vector<SPIRVId> IdVec;
for (auto i:ValueVec)
IdVec.push_back(i->getId());
return IdVec;
}
SPIRVInstTemplateBase*
SPIRVModuleImpl::addInstTemplate(Op OC,
SPIRVBasicBlock* BB, SPIRVType *Ty) {
assert (!Ty || !Ty->isTypeVoid());
SPIRVId Id = Ty ? getId() : SPIRVID_INVALID;
auto Ins = SPIRVInstTemplateBase::create(OC, Ty, Id, BB, this);
BB->addInstruction(Ins);
return Ins;
}
SPIRVInstTemplateBase*
SPIRVModuleImpl::addInstTemplate(Op OC,
const std::vector<SPIRVWord>& Ops, SPIRVBasicBlock* BB, SPIRVType *Ty) {
assert (!Ty || !Ty->isTypeVoid());
SPIRVId Id = Ty ? getId() : SPIRVID_INVALID;
auto Ins = SPIRVInstTemplateBase::create(OC, Ty, Id, Ops, BB, this);
BB->addInstruction(Ins);
return Ins;
}
SPIRVDbgInfo::SPIRVDbgInfo(SPIRVModule *TM)
:M(TM){
}
std::string
SPIRVDbgInfo::getEntryPointFileStr(SPIRVExecutionModelKind EM, unsigned I) {
if (M->getNumEntryPoints(EM) == 0)
return "";
return getFunctionFileStr(M->getEntryPoint(EM, I));
}
std::string
SPIRVDbgInfo::getFunctionFileStr(SPIRVFunction *F) {
if (F->hasLine())
return F->getLine()->getFileNameStr();
return "";
}
unsigned
SPIRVDbgInfo::getFunctionLineNo(SPIRVFunction *F) {
if (F->hasLine())
return F->getLine()->getLine();
return 0;
}
bool IsSPIRVBinary(const std::string &Img) {
if (Img.size() < sizeof(unsigned))
return false;
auto Magic = reinterpret_cast<const unsigned*>(Img.data());
return *Magic == MagicNumber;
}
#ifdef _SPIRV_SUPPORT_TEXT_FMT
bool ConvertSPIRV(std::istream &IS, spv_ostream &OS,
std::string &ErrMsg, bool FromText, bool ToText) {
auto SaveOpt = SPIRVUseTextFormat;
SPIRVUseTextFormat = FromText;
SPIRVModuleImpl M;
IS >> M;
if (M.getError(ErrMsg) != SPIRVEC_Success) {
SPIRVUseTextFormat = SaveOpt;
return false;
}
SPIRVUseTextFormat = ToText;
OS << M;
if (M.getError(ErrMsg) != SPIRVEC_Success) {
SPIRVUseTextFormat = SaveOpt;
return false;
}
SPIRVUseTextFormat = SaveOpt;
return true;
}
bool IsSPIRVText(const std::string &Img) {
std::istringstream SS(Img);
unsigned Magic = 0;
SS >> Magic;
if (SS.bad())
return false;
return Magic == MagicNumber;
}
bool ConvertSPIRV(std::string &Input, std::string &Out,
std::string &ErrMsg, bool ToText) {
auto FromText = IsSPIRVText(Input);
if (ToText == FromText) {
Out = Input;
return true;
}
std::istringstream IS(Input);
#ifdef _SPIRV_LLVM_API
llvm::raw_string_ostream OS(Out);
#else
std::ostringstream OS;
#endif
if (!ConvertSPIRV(IS, OS, ErrMsg, FromText, ToText))
return false;
Out = OS.str();
return true;
}
#endif // _SPIRV_SUPPORT_TEXT_FMT
}