Add spirv::GlobalVariableOp that allows module level definition of variables
FuncOps in MLIR use explicit capture. So global variables defined in
module scope need to have a symbol name and this should be used to
refer to the variable within the function. This deviates from SPIR-V
spec, which assigns an SSA value to variables at all scopes that can
be used to refer to the variable, which requires SPIR-V functions to
allow implicit capture. To handle this add a new op,
spirv::GlobalVariableOp that can be used to define module scope
variables.
Since instructions need an SSA value, an new spirv::AddressOfOp is
added to convert a symbol reference to an SSA value for use with other
instructions.
This also means the spirv::EntryPointOp instruction needs to change to
allow initializers to be specified using symbol reference instead of
SSA value
The current spirv::VariableOp which returns an SSA value (as defined
by SPIR-V spec) can still be used to define function-scope variables.
PiperOrigin-RevId: 263951109
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index ba95a76..de496a7 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -146,67 +146,6 @@
// -----
-def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
- let summary = [{
- Declare an entry point, its execution model, and its interface.
- }];
-
- let description = [{
- Execution Model is the execution model for the entry point and its
- static call tree. See Execution Model.
-
- Entry Point must be the Result <id> of an OpFunction instruction.
-
- Name is a name string for the entry point. A module cannot have two
- OpEntryPoint instructions with the same Execution Model and the same
- Name string.
-
- Interface is a list of <id> of global OpVariable instructions. These
- declare the set of global variables from a module that form the
- interface of this entry point. The set of Interface <id> must be equal
- to or a superset of the global OpVariable Result <id> referenced by the
- entry point’s static call tree, within the interface’s storage classes.
- Before version 1.4, the interface’s storage classes are limited to the
- Input and Output storage classes. Starting with version 1.4, the
- interface’s storage classes are all storage classes used in declaring
- all global variables referenced by the entry point’s call tree.
-
- Interface <id> are forward references. Before version 1.4, duplication
- of these <id> is tolerated. Starting with version 1.4, an <id> must not
- appear more than once.
-
- ### Custom assembly form
-
- ``` {.ebnf}
- execution-model ::= "Vertex" | "TesellationControl" |
- <and other SPIR-V execution models...>
-
- entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name
- (ssa-use ( `, ` ssa-use)* ` : `
- pointer-type ( `, ` pointer-type)* )?
- ```
-
- For example:
-
- ```
- spv.EntryPoint "GLCompute" @foo
- spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
-
- ```
- }];
-
- let arguments = (ins
- SPV_ExecutionModelAttr:$execution_model,
- SymbolRefAttr:$fn,
- Variadic<SPV_AnyPtr>:$interface
- );
-
- let results = (outs);
- let autogenSerialization = 0;
-}
-
-// -----
-
def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
let summary = "Declare an execution mode for an entry point.";
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b44d8ef..d475639 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -30,6 +30,160 @@
include "mlir/SPIRV/SPIRVBase.td"
#endif // SPIRV_BASE
+def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
+ let summary = "Get the address of a global variable.";
+
+ let description = [{
+ Variables in module scope are defined using symbol names. This
+ instruction generates an SSA value that can be used to refer to
+ the symbol within function scope for use in instructions that
+ expect an SSA value. This operation has no equivalent SPIR-V
+ instruction. Since variables in module scope in SPIR-V dialect are
+ of pointer type, this instruction returns a pointer type as well,
+ and the type is same as the variable referenced.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ address-of-op ::= ssa-id `=` `spv.addressOf` `@`string-literal : pointer-type
+ ```
+
+ For example:
+
+ ```
+ %0 = spv.addressOf @var1 : !spv.ptr<f32, Input>
+ ```
+ }];
+
+ let arguments = (ins
+ SymbolRefAttr:$variable
+ );
+
+ let results = (outs
+ SPV_AnyPtr:$pointer
+ );
+
+ let hasOpcode = 0;
+}
+
+def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
+ let summary = [{
+ Declare an entry point, its execution model, and its interface.
+ }];
+
+ let description = [{
+ Execution Model is the execution model for the entry point and its
+ static call tree. See Execution Model.
+
+ Entry Point must be the Result <id> of an OpFunction instruction.
+
+ Name is a name string for the entry point. A module cannot have two
+ OpEntryPoint instructions with the same Execution Model and the same
+ Name string.
+
+ Interface is a list of symbol references to spv.globalVariable
+ operations. These declare the set of global variables from a
+ module that form the interface of this entry point. The set of
+ Interface symbols must be equal to or a superset of the
+ spv.globalVariables referenced by the entry point’s static call
+ tree, within the interface’s storage classes. Before version 1.4,
+ the interface’s storage classes are limited to the Input and
+ Output storage classes. Starting with version 1.4, the interface’s
+ storage classes are all storage classes used in declaring all
+ global variables referenced by the entry point’s call tree.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ execution-model ::= "Vertex" | "TesellationControl" |
+ <and other SPIR-V execution models...>
+
+ entry-point-op ::= ssa-id `=` `spv.EntryPoint` execution-model
+ symbol-reference (`, ` symbol-reference)*
+ ```
+
+ For example:
+
+ ```
+ spv.EntryPoint "GLCompute" @foo
+ spv.EntryPoint "Kernel" @foo, @var1, @var2
+
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_ExecutionModelAttr:$execution_model,
+ SymbolRefAttr:$fn,
+ OptionalAttr<SymbolRefArrayAttr>:$interface
+ );
+
+ let results = (outs);
+ let autogenSerialization = 0;
+}
+
+
+def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> {
+ let summary = [{
+ Allocate an object in memory at module scope. The object is
+ referenced using a symbol name.
+ }];
+
+ let description = [{
+ The variable type must be an OpTypePointer. Its type operand is the type of
+ object in memory.
+
+ Storage Class is the Storage Class of the memory holding the object. It
+ cannot be Generic. It must be the same as the Storage Class operand of
+ the variable types. Only those storage classes that are valid at module
+ scope (like Input, Output, StorageBuffer, etc.) are valid.
+
+ Initializer is optional. If Initializer is present, it will be
+ the initial value of the variable’s memory content. Initializer
+ must be an symbol defined from a constant instruction or other
+ spv.globalVariable operation in module scope. Initializer must
+ have the same type as the type of the defined symbol.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ variable-op ::= `spv.globalVariable` spirv-type string-literal
+ (`initializer(` symbol-reference `)`)?
+ (`bind(` integer-literal, integer-literal `)`)?
+ (`built_in(` string-literal `)`)?
+ attribute-dict?
+ ```
+
+ where `initializer` specifies initializer and `bind` specifies the
+ descriptor set and binding number. `built_in` specifies SPIR-V
+ BuiltIn decoration associated with the op.
+
+ For example:
+
+ ```
+ spv.Variable !spv.ptr<f32, Input> @var0
+ spv.Variable !spv.ptr<f32, Output> @var2 initializer(@var0)
+ spv.Variable !spv.ptr<f32, Uniform> @var bind(1, 2)
+ spv.Variable !spv.ptr<vector<3xi32>> @var3 built_in("GlobalInvocationID")
+ ```
+ }];
+
+ let arguments = (ins
+ TypeAttr:$type,
+ StrAttr:$sym_name,
+ OptionalAttr<SymbolRefAttr>:$initializer
+ );
+
+ let results = (outs);
+
+ let hasOpcode = 0;
+
+ let extraClassDeclaration = [{
+ ::mlir::spirv::StorageClass storageClass() {
+ return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
+ }
+ }];
+}
+
def SPV_ModuleOp : SPV_Op<"module",
[SingleBlockImplicitTerminator<"ModuleEndOp">,
NativeOpTrait<"SymbolTable">]> {
diff --git a/third_party/mlir/include/mlir/IR/OpBase.td b/third_party/mlir/include/mlir/IR/OpBase.td
index 519222c..3183a76 100644
--- a/third_party/mlir/include/mlir/IR/OpBase.td
+++ b/third_party/mlir/include/mlir/IR/OpBase.td
@@ -872,6 +872,11 @@
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
}
+def SymbolRefArrayAttr :
+ TypedArrayAttrBase<SymbolRefAttr, "symbol ref array attribute"> {
+ let constBuilderCall = ?;
+}
+
//===----------------------------------------------------------------------===//
// Derive attribute kinds
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 53a40df..035de4f 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -136,26 +136,26 @@
signatureConverter, newFuncOp))) {
return failure();
}
- // Create spv.Variable ops for each of the arguments. These need to be bound
- // by the runtime. For now use descriptor_set 0, and arg number as the binding
- // number.
+ // Create spv.globalVariable ops for each of the arguments. These need to be
+ // bound by the runtime. For now use descriptor_set 0, and arg number as the
+ // binding number.
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
OpBuilder builder(module.getOperation()->getRegion(0));
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
- auto variableOp = builder.create<spirv::VariableOp>(
- funcOp.getLoc(), convertedArgType.value(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
- llvm::None);
+ std::string varName = funcOp.getName().str() + "_arg_" +
+ std::to_string(convertedArgType.index());
+ auto variableOp = builder.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
+ builder.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
variableOp.setAttr("binding",
builder.getI32IntegerAttr(convertedArgType.index()));
- interface.push_back(variableOp.getResult());
+ interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
@@ -164,7 +164,8 @@
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
- builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+ builder.getSymbolRefAttr(newFuncOp.getName()),
+ builder.getArrayAttr(interface));
return success();
}
} // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 4bea441..9947c02 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -32,11 +32,15 @@
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
+static constexpr const char kInitializerAttrName[] = "initializer";
+static constexpr const char kInterfaceAttrName[] = "interface";
static constexpr const char kIsSpecConstName[] = "is_spec_const";
+static constexpr const char kTypeAttrName[] = "type";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
-static constexpr const char kFnNameAttrName[] = "fn";
+static constexpr const char kVariableAttrName[] = "variable";
//===----------------------------------------------------------------------===//
// Common utility functions
@@ -239,6 +243,71 @@
printer->printOptionalAttrDict(op->getAttrs());
}
+static ParseResult parseVariableDecorations(OpAsmParser *parser,
+ OperationState *state) {
+ auto builtInName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+ if (succeeded(parser->parseOptionalKeyword("bind"))) {
+ Attribute set, binding;
+ // Parse optional descriptor binding
+ auto descriptorSetName = convertToSnakeCase(
+ stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ Type i32Type = parser->getBuilder().getIntegerType(32);
+ if (parser->parseLParen() ||
+ parser->parseAttribute(set, i32Type, descriptorSetName,
+ state->attributes) ||
+ parser->parseComma() ||
+ parser->parseAttribute(binding, i32Type, bindingName,
+ state->attributes) ||
+ parser->parseRParen()) {
+ return failure();
+ }
+ } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
+ StringAttr builtIn;
+ if (parser->parseLParen() ||
+ parser->parseAttribute(builtIn, Type(), builtInName,
+ state->attributes) ||
+ parser->parseRParen()) {
+ return failure();
+ }
+ }
+
+ // Parse other attributes
+ if (parser->parseOptionalAttributeDict(state->attributes))
+ return failure();
+
+ return success();
+}
+
+static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
+ SmallVectorImpl<StringRef> &elidedAttrs) {
+ // Print optional descriptor binding
+ auto descriptorSetName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
+ auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
+ if (descriptorSet && binding) {
+ elidedAttrs.push_back(descriptorSetName);
+ elidedAttrs.push_back(bindingName);
+ *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+ << ")";
+ }
+
+ // Print BuiltIn attribute if present
+ auto builtInName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+ if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
+ *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
+ elidedAttrs.push_back(builtInName);
+ }
+
+ printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+}
+
//===----------------------------------------------------------------------===//
// spv.AccessChainOp
//===----------------------------------------------------------------------===//
@@ -363,6 +432,53 @@
}
//===----------------------------------------------------------------------===//
+// spv._address_of
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAddressOfOp(OpAsmParser *parser,
+ OperationState *state) {
+ SymbolRefAttr varRefAttr;
+ Type type;
+ if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName,
+ state->attributes) ||
+ parser->parseColonType(type)) {
+ return failure();
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return parser->emitError(parser->getCurrentLocation(),
+ "expected spv.ptr type");
+ }
+ state->addTypes(ptrType);
+ return success();
+}
+
+static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
+ SmallVector<StringRef, 4> elidedAttrs;
+ *printer << spirv::AddressOfOp::getOperationName();
+
+ // Print symbol name.
+ *printer << " @" << addressOfOp.variable();
+
+ // Print the type.
+ *printer << " : " << addressOfOp.pointer();
+}
+
+static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
+ auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
+ auto varOp =
+ moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+ if (!varOp) {
+ return addressOfOp.emitError("expected spv.globalVariable symbol");
+ }
+ if (addressOfOp.pointer()->getType() != varOp.type()) {
+ return addressOfOp.emitError(
+ "mismatch in result type and type of global variable referenced");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
@@ -541,18 +657,28 @@
SmallVector<OpAsmParser::OperandType, 0> identifiers;
SmallVector<Type, 0> idTypes;
- Attribute fn;
- auto loc = parser->getCurrentLocation();
-
+ SymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
- parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
- parser->parseTrailingOperandList(identifiers) ||
- parser->parseOptionalColonTypeList(idTypes) ||
- parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
+ parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) {
return failure();
}
- if (!fn.isa<SymbolRefAttr>()) {
- return parser->emitError(loc, "expected symbol reference attribute");
+
+ if (!parser->parseOptionalComma()) {
+ // Parse the interface variables
+ SmallVector<Attribute, 4> interfaceVars;
+ do {
+ // The name of the interface variable attribute isnt important
+ auto attrName = "var_symbol";
+ SymbolRefAttr var;
+ SmallVector<NamedAttribute, 1> attrs;
+ if (parser->parseAttribute(var, Type(), attrName, attrs)) {
+ return failure();
+ }
+ interfaceVars.push_back(var);
+ } while (!parser->parseOptionalComma());
+ state->attributes.push_back(
+ {parser->getBuilder().getIdentifier(kInterfaceAttrName),
+ parser->getBuilder().getArrayAttr(interfaceVars)});
}
return success();
}
@@ -561,27 +687,16 @@
*printer << spirv::EntryPointOp::getOperationName() << " \""
<< stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
<< entryPointOp.fn();
- if (!entryPointOp.getNumOperands()) {
- return;
+ if (auto interface = entryPointOp.interface()) {
+ *printer << ", ";
+ mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(),
+ [&](Attribute a) { printer->printAttribute(a); });
}
- *printer << ", ";
- mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
- [&](Value *a) { printer->printOperand(a); });
- *printer << " : ";
- mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
- [&](const Value *a) { *printer << a->getType(); });
}
static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
- // Verify that all the interface ops are created from VariableOp
- for (auto interface : entryPointOp.interface()) {
- if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) {
- return entryPointOp.emitOpError("interface operands to entry point must "
- "be generated from a variable op");
- }
- // TODO: Before version 1.4 the variables can only have storage_class of
- // Input or Output. That needs to be verified.
- }
+ // Checks for fn and interface symbol reference are done in spirv::ModuleOp
+ // verification.
return success();
}
@@ -628,6 +743,95 @@
}
//===----------------------------------------------------------------------===//
+// spv.globalVariable
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
+ OperationState *state) {
+ // Parse variable type.
+ TypeAttr typeAttr;
+ auto loc = parser->getCurrentLocation();
+ if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName,
+ state->attributes)) {
+ return failure();
+ }
+ auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return parser->emitError(loc, "expected spv.ptr type");
+ }
+
+ // Parse variable name.
+ StringAttr nameAttr;
+ if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ state->attributes)) {
+ return failure();
+ }
+
+ // Parse optional initializer
+ if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) {
+ SymbolRefAttr initSymbol;
+ if (parser->parseLParen() ||
+ parser->parseAttribute(initSymbol, Type(), kInitializerAttrName,
+ state->attributes) ||
+ parser->parseRParen())
+ return failure();
+ }
+
+ if (parseVariableDecorations(parser, state)) {
+ return failure();
+ }
+
+ return success();
+}
+
+static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
+ auto *op = varOp.getOperation();
+ SmallVector<StringRef, 4> elidedAttrs{
+ spirv::attributeName<spirv::StorageClass>()};
+ *printer << spirv::GlobalVariableOp::getOperationName();
+
+ // Print variable type.
+ *printer << " " << varOp.type();
+ elidedAttrs.push_back(kTypeAttrName);
+
+ // Print variable name.
+ *printer << " @" << varOp.sym_name();
+ elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+
+ // Print optional initializer
+ if (auto initializer = varOp.initializer()) {
+ *printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
+ << ")";
+ elidedAttrs.push_back(kInitializerAttrName);
+ }
+ printVariableDecorations(op, printer, elidedAttrs);
+}
+
+static LogicalResult verify(spirv::GlobalVariableOp varOp) {
+ // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+ // object. It cannot be Generic. It must be the same as the Storage Class
+ // operand of the Result Type."
+ if (varOp.storageClass() == spirv::StorageClass::Generic)
+ return varOp.emitOpError("storage class cannot be 'Generic'");
+
+ if (auto initializer =
+ varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+ // Get the module
+ auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
+ // TODO: Currently only variable initialization with other variables is
+ // supported. They could be constants as well, but this needs module-level
+ // constants to have symbol name as well.
+ if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>(
+ initializer.getValue())) {
+ return varOp.emitOpError(
+ "initializer must be result of a spv.globalVariable op");
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
@@ -773,13 +977,33 @@
for (auto &op : body) {
if (op.getDialect() == dialect) {
// For EntryPoint op, check that the function and execution model is not
- // duplicated in EntryPointOps
+ // duplicated in EntryPointOps. Also verify that the interface specified
+ // comes from globalVariables here to make this check cheaper.
if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
if (!funcOp) {
return entryPointOp.emitError("function '")
<< entryPointOp.fn() << "' not found in 'spv.module'";
}
+ if (auto interface = entryPointOp.interface()) {
+ for (auto varRef : interface.getValue().getValue()) {
+ auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+ if (!varSymRef) {
+ return entryPointOp.emitError(
+ "expected symbol reference for interface "
+ "specification instead of '")
+ << varRef;
+ }
+ auto variableOp =
+ table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
+ if (!variableOp) {
+ return entryPointOp.emitError("expected spv.globalVariable "
+ "symbol reference instead of'")
+ << varSymRef << "'";
+ }
+ }
+ }
+
auto key = std::pair<FuncOp, spirv::ExecutionModel>(
funcOp, entryPointOp.execution_model());
auto entryPtIt = entryPoints.find(key);
@@ -898,42 +1122,9 @@
return failure();
}
- auto builtInName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
- if (succeeded(parser->parseOptionalKeyword("bind"))) {
- Attribute set, binding;
- // Parse optional descriptor binding
- auto descriptorSetName = convertToSnakeCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
- Type i32Type = parser->getBuilder().getIntegerType(32);
- if (parser->parseLParen() ||
- parser->parseAttribute(set, i32Type, descriptorSetName,
- state->attributes) ||
- parser->parseComma() ||
- parser->parseAttribute(binding, i32Type, bindingName,
- state->attributes) ||
- parser->parseRParen()) {
- return failure();
- }
- } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
- Attribute builtIn;
- if (parser->parseLParen() ||
- parser->parseAttribute(builtIn, Type(), builtInName,
- state->attributes) ||
- parser->parseRParen()) {
- return failure();
- }
- if (!builtIn.isa<StringAttr>()) {
- return parser->emitError(parser->getCurrentLocation(),
- "expected string value for built_in attribute");
- }
- }
-
- // Parse other attributes
- if (parser->parseOptionalAttributeDict(state->attributes))
+ if (parseVariableDecorations(parser, state)) {
return failure();
+ }
// Parse result pointer type
Type type;
@@ -976,29 +1167,8 @@
*printer << ")";
}
- // Print optional descriptor binding
- auto descriptorSetName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
- auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
- auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
- if (descriptorSet && binding) {
- elidedAttrs.push_back(descriptorSetName);
- elidedAttrs.push_back(bindingName);
- *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
- << ")";
- }
+ printVariableDecorations(op, printer, elidedAttrs);
- // Print BuiltIn attribute if present
- auto builtInName =
- convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
- if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) {
- *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
- elidedAttrs.push_back(builtInName);
- }
-
- printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
*printer << " : " << varOp.getType();
}
@@ -1006,8 +1176,11 @@
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
- if (varOp.storage_class() == spirv::StorageClass::Generic)
- return varOp.emitOpError("storage class cannot be 'Generic'");
+ if (varOp.storage_class() != spirv::StorageClass::Function) {
+ return varOp.emitOpError(
+ "can only be used to model function-level variables. Use "
+ "spv.globalVariable for module-level variables.");
+ }
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
if (varOp.storage_class() != pointerType.getStorageClass())
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 1aad717..a3d71ed 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -90,9 +90,20 @@
/// them to their handler method accordingly.
LogicalResult processFunction(ArrayRef<uint32_t> operands);
+ /// Process the OpVariable instructions at current `offset` into `binary`. It
+ /// is expected that this method is used for variables that are to be defined
+ /// at module scope and will be deserialized into a spv.globalVariable
+ /// instruction.
+ LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
/// Get the FuncOp associated with a result <id> of OpFunction.
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+ /// Get the global variable associated with a result <id> of OpVariable
+ spirv::GlobalVariableOp getVariable(uint32_t id) {
+ return globalVariableMap.lookup(id);
+ }
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
@@ -138,7 +149,15 @@
//===--------------------------------------------------------------------===//
/// Get the Value associated with a result <id>.
- Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+ Value *getValue(uint32_t id) {
+ if (auto varOp = getVariable(id)) {
+ auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+ unknownLoc, varOp.type(),
+ opBuilder.getSymbolRefAttr(varOp.getOperation()));
+ return addressOfOp.pointer();
+ }
+ return valueMap.lookup(id);
+ }
/// Slices the first instruction out of `binary` and returns its opcode and
/// operands via `opcode` and `operands` respectively. Returns failure if
@@ -198,6 +217,9 @@
// Result <id> to function mapping.
DenseMap<uint32_t, FuncOp> funcMap;
+ // Result <id> to variable mapping;
+ DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
// Result <id> to value mapping.
DenseMap<uint32_t, Value *> valueMap;
@@ -452,6 +474,76 @@
return success();
}
+LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+ unsigned wordIndex = 0;
+ if (operands.size() < 3) {
+ return emitError(
+ unknownLoc,
+ "OpVariable needs at least 3 operands, type, <id> and storage class");
+ }
+
+ // Result Type.
+ auto type = getType(operands[wordIndex]);
+ if (!type) {
+ return emitError(unknownLoc, "unknown result type <id> : ")
+ << operands[wordIndex];
+ }
+ auto ptrType = type.dyn_cast<spirv::PointerType>();
+ if (!ptrType) {
+ return emitError(unknownLoc,
+ "expected a result type <id> to be a spv.ptr, found : ")
+ << type;
+ }
+ wordIndex++;
+
+ // Result <id>.
+ auto variableID = operands[wordIndex];
+ auto variableName = nameMap.lookup(variableID).str();
+ if (variableName.empty()) {
+ variableName = "spirv_var_" + std::to_string(variableID);
+ }
+ wordIndex++;
+
+ // Storage class.
+ auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
+ if (ptrType.getStorageClass() != storageClass) {
+ return emitError(unknownLoc, "mismatch in storage class of pointer type ")
+ << type << " and that specified in OpVariable instruction : "
+ << stringifyStorageClass(storageClass);
+ }
+ wordIndex++;
+
+ // Initializer.
+ SymbolRefAttr initializer = nullptr;
+ if (wordIndex < operands.size()) {
+ auto initializerOp = getVariable(operands[wordIndex]);
+ if (!initializerOp) {
+ return emitError(unknownLoc, "unknown <id> ")
+ << operands[wordIndex] << "used as initializer";
+ }
+ wordIndex++;
+ initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+ }
+ if (wordIndex != operands.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "OpVariable instruction, only ")
+ << wordIndex << " of " << operands.size() << " processed";
+ }
+ auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
+ unknownLoc, opBuilder.getTypeAttr(type),
+ opBuilder.getStringAttr(variableName), initializer);
+
+ // Decorations.
+ if (decorations.count(variableID)) {
+ for (auto attr : decorations[variableID].getAttrs()) {
+ varOp.setAttr(attr.first, attr.second);
+ }
+ }
+ globalVariableMap[variableID] = varOp;
+ return success();
+}
+
LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc, "OpName needs at least 2 operands");
@@ -887,6 +979,11 @@
return success();
}
break;
+ case spirv::Opcode::OpVariable:
+ if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+ return processGlobalVariable(operands);
+ }
+ break;
case spirv::Opcode::OpName:
return processName(operands);
case spirv::Opcode::OpTypeVoid:
@@ -954,18 +1051,19 @@
"and OpFunction with <id> ")
<< fnID << ": " << fnName << " vs. " << parsedFunc.getName();
}
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
while (wordIndex < words.size()) {
- auto arg = getValue(words[wordIndex]);
+ auto arg = getVariable(words[wordIndex]);
if (!arg) {
return emitError(unknownLoc, "undefined result <id> ")
<< words[wordIndex] << " while decoding OpEntryPoint";
}
- interface.push_back(arg);
+ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
wordIndex++;
}
- opBuilder.create<spirv::EntryPointOp>(
- unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+ opBuilder.getSymbolRefAttr(fnName),
+ opBuilder.getArrayAttr(interface));
return success();
}
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index d06363a..575d995 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -125,9 +125,19 @@
return funcIDMap.lookup(fnName);
}
+ uint32_t findVariableID(StringRef varName) const {
+ return globalVarIDMap.lookup(varName);
+ }
+
+ /// Emit OpName for the given `resultID`.
+ LogicalResult processName(uint32_t resultID, StringRef name);
+
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
+ /// Process a SPIR-V GlobalVariableOp
+ LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
+
/// Process attributes that translate to decorations on the result <id>
LogicalResult processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr);
@@ -215,6 +225,9 @@
uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
+ /// Process spv.addressOf operations.
+ LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
+
/// Main dispatch method for serializing an operation.
LogicalResult processOperation(Operation *op);
@@ -265,6 +278,9 @@
/// Map from FuncOps name to <id>s.
llvm::StringMap<uint32_t> funcIDMap;
+ /// Map from GlobalVariableOps name to <id>s
+ llvm::StringMap<uint32_t> globalVarIDMap;
+
/// Map from results of normal operations to their <id>s
DenseMap<Value *, uint32_t> valueIDMap;
};
@@ -372,6 +388,15 @@
return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
}
+LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
+ SmallVector<uint32_t, 4> nameOperands;
+ nameOperands.push_back(resultID);
+ if (failed(encodeStringLiteralInto(nameOperands, name))) {
+ return failure();
+ }
+ return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+}
+
namespace {
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
@@ -416,10 +441,9 @@
encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
// Add function name.
- SmallVector<uint32_t, 4> nameOperands;
- nameOperands.push_back(funcID);
- encodeStringLiteralInto(nameOperands, op.getName());
- encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+ if (failed(processName(funcID, op.getName()))) {
+ return failure();
+ }
// Declare the parameters.
for (auto arg : op.getArguments()) {
@@ -450,6 +474,61 @@
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
}
+LogicalResult
+Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
+ // Get TypeID.
+ uint32_t resultTypeID = 0;
+ SmallVector<StringRef, 4> elidedAttrs;
+ if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
+ return failure();
+ }
+ elidedAttrs.push_back("type");
+ SmallVector<uint32_t, 4> operands;
+ operands.push_back(resultTypeID);
+ auto resultID = getNextID();
+
+ // Encode the name.
+ auto varName = varOp.sym_name();
+ elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+ if (failed(processName(resultID, varName))) {
+ return failure();
+ }
+ globalVarIDMap[varName] = resultID;
+ operands.push_back(resultID);
+
+ // Encode StorageClass.
+ operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
+
+ // Encode initialization.
+ if (auto initializer = varOp.initializer()) {
+ auto initializerID = findVariableID(initializer.getValue());
+ if (!initializerID) {
+ return emitError(varOp.getLoc(),
+ "invalid usage of undefined variable as initializer");
+ }
+ operands.push_back(initializerID);
+ elidedAttrs.push_back("initializer");
+ }
+
+ if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
+ operands))) {
+ elidedAttrs.push_back("initializer");
+ return failure();
+ }
+
+ // Encode decorations.
+ for (auto attr : varOp.getAttrs()) {
+ if (llvm::any_of(elidedAttrs,
+ [&](StringRef elided) { return attr.first.is(elided); })) {
+ continue;
+ }
+ if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
+ return failure();
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//
@@ -912,6 +991,17 @@
// Operation
//===----------------------------------------------------------------------===//
+LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
+ auto varName = addressOfOp.variable();
+ auto variableID = findVariableID(varName);
+ if (!variableID) {
+ return addressOfOp.emitError("unknown result <id> for variable ")
+ << varName;
+ }
+ valueIDMap[addressOfOp.pointer()] = variableID;
+ return success();
+}
+
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
@@ -924,6 +1014,12 @@
if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
+ if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+ return processGlobalVariableOp(varOp);
+ }
+ if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
+ return processAddressOfOp(addressOfOp);
+ }
return dispatchToAutogenSerialization(op);
}
@@ -947,14 +1043,16 @@
encodeStringLiteralInto(operands, op.fn());
// Add the interface values.
- for (auto val : op.interface()) {
- auto id = findValueID(val);
- if (!id) {
- return op.emitError("referencing unintialized variable <id>. "
- "spv.EntryPoint is at the end of spv.module. All "
- "referenced variables should already be defined");
+ if (auto interface = op.interface()) {
+ for (auto var : interface.getValue()) {
+ auto id = findVariableID(var.cast<SymbolRefAttr>().getValue());
+ if (!id) {
+ return op.emitError("referencing undefined global variable."
+ "spv.EntryPoint is at the end of spv.module. All "
+ "referenced variables should already be defined");
+ }
+ operands.push_back(id);
}
- operands.push_back(id);
}
return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
operands);