Add spirv::GlobalVariableOp that allows module level definition of variables

FuncOps in MLIR use explicit capture. So global variables defined in
module scope need to have a symbol name and this should be used to
refer to the variable within the function. This deviates from SPIR-V
spec, which assigns an SSA value to variables at all scopes that can
be used to refer to the variable, which requires SPIR-V functions to
allow implicit capture. To handle this add a new op,
spirv::GlobalVariableOp that can be used to define module scope
variables.
Since instructions need an SSA value, an new spirv::AddressOfOp is
added to convert a symbol reference to an SSA value for use with other
instructions.
This also means the spirv::EntryPointOp instruction needs to change to
allow initializers to be specified using symbol reference instead of
SSA value
The current spirv::VariableOp which returns an SSA value (as defined
by SPIR-V spec) can still be used to define function-scope variables.
PiperOrigin-RevId: 263951109
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index ba95a76..de496a7 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -146,67 +146,6 @@
 
 // -----
 
-def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
-  let summary = [{
-    Declare an entry point, its execution model, and its interface.
-  }];
-
-  let description = [{
-    Execution Model is the execution model for the entry point and its
-    static call tree. See Execution Model.
-
-    Entry Point must be the Result <id> of an OpFunction instruction.
-
-    Name is a name string for the entry point. A module cannot have two
-    OpEntryPoint instructions with the same Execution Model and the same
-    Name string.
-
-    Interface is a list of <id> of global OpVariable instructions. These
-    declare the set of global variables from a module that form the
-    interface of this entry point. The set of Interface <id> must be equal
-    to or a superset of the global OpVariable Result <id> referenced by the
-    entry point’s static call tree, within the interface’s storage classes.
-    Before version 1.4, the interface’s storage classes are limited to the
-    Input and Output storage classes. Starting with version 1.4, the
-    interface’s storage classes are all storage classes used in declaring
-    all global variables referenced by the entry point’s call tree.
-
-    Interface <id> are forward references. Before version 1.4, duplication
-    of these <id> is tolerated. Starting with version 1.4, an <id> must not
-    appear more than once.
-
-    ### Custom assembly form
-
-    ``` {.ebnf}
-    execution-model ::= "Vertex" | "TesellationControl" |
-                        <and other SPIR-V execution models...>
-
-    entry-point-op ::= ssa-id ` = spv.EntryPoint ` execution-model fn-name
-                       (ssa-use ( `, ` ssa-use)* ` : `
-                        pointer-type ( `, ` pointer-type)* )?
-    ```
-
-    For example:
-
-    ```
-    spv.EntryPoint "GLCompute" @foo
-    spv.EntryPoint "Kernel" @foo, %1, %2 : !spv.ptr<f32, Input>, !spv.ptr<f32, Output>
-
-    ```
-  }];
-
-  let arguments = (ins
-    SPV_ExecutionModelAttr:$execution_model,
-    SymbolRefAttr:$fn,
-    Variadic<SPV_AnyPtr>:$interface
-  );
-
-  let results = (outs);
-  let autogenSerialization = 0;
-}
-
-// -----
-
 def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [ModuleOnly]> {
   let summary = "Declare an execution mode for an entry point.";
 
diff --git a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index b44d8ef..d475639 100644
--- a/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/third_party/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -30,6 +30,160 @@
 include "mlir/SPIRV/SPIRVBase.td"
 #endif // SPIRV_BASE
 
+def SPV_AddressOfOp : SPV_Op<"_address_of", [NoSideEffect]> {
+  let summary = "Get the address of a global variable.";
+
+  let description = [{
+    Variables in module scope are defined using symbol names. This
+    instruction generates an SSA value that can be used to refer to
+    the symbol within function scope for use in instructions that
+    expect an SSA value. This operation has no equivalent SPIR-V
+    instruction. Since variables in module scope in SPIR-V dialect are
+    of pointer type, this instruction returns a pointer type as well,
+    and the type is same as the variable referenced.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    address-of-op ::= ssa-id `=` `spv.addressOf` `@`string-literal : pointer-type
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.addressOf @var1 : !spv.ptr<f32, Input>
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolRefAttr:$variable
+  );
+
+  let results = (outs
+    SPV_AnyPtr:$pointer
+  );
+
+  let hasOpcode = 0;
+}
+
+def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
+  let summary = [{
+    Declare an entry point, its execution model, and its interface.
+  }];
+
+  let description = [{
+    Execution Model is the execution model for the entry point and its
+    static call tree. See Execution Model.
+
+    Entry Point must be the Result <id> of an OpFunction instruction.
+
+    Name is a name string for the entry point. A module cannot have two
+    OpEntryPoint instructions with the same Execution Model and the same
+    Name string.
+
+    Interface is a list of symbol references to spv.globalVariable
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    spv.globalVariables referenced by the entry point’s static call
+    tree, within the interface’s storage classes.  Before version 1.4,
+    the interface’s storage classes are limited to the Input and
+    Output storage classes. Starting with version 1.4, the interface’s
+    storage classes are all storage classes used in declaring all
+    global variables referenced by the entry point’s call tree.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    execution-model ::= "Vertex" | "TesellationControl" |
+                        <and other SPIR-V execution models...>
+
+    entry-point-op ::= ssa-id `=` `spv.EntryPoint` execution-model
+                       symbol-reference (`, ` symbol-reference)*
+    ```
+
+    For example:
+
+    ```
+    spv.EntryPoint "GLCompute" @foo
+    spv.EntryPoint "Kernel" @foo, @var1, @var2
+
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ExecutionModelAttr:$execution_model,
+    SymbolRefAttr:$fn,
+    OptionalAttr<SymbolRefArrayAttr>:$interface
+  );
+
+  let results = (outs);
+  let autogenSerialization = 0;
+}
+
+
+def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [ModuleOnly]> {
+  let summary = [{
+    Allocate an object in memory at module scope. The object is
+    referenced using a symbol name.
+  }];
+
+  let description = [{
+    The variable type must be an OpTypePointer. Its type operand is the type of
+    object in memory.
+
+    Storage Class is the Storage Class of the memory holding the object. It
+    cannot be Generic. It must be the same as the Storage Class operand of
+    the variable types. Only those storage classes that are valid at module
+    scope (like Input, Output, StorageBuffer, etc.) are valid.
+
+    Initializer is optional.  If Initializer is present, it will be
+    the initial value of the variable’s memory content. Initializer
+    must be an symbol defined from a constant instruction or other
+    spv.globalVariable operation in module scope. Initializer must
+    have the same type as the type of the defined symbol.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    variable-op ::= `spv.globalVariable` spirv-type string-literal
+                    (`initializer(` symbol-reference `)`)?
+                    (`bind(` integer-literal, integer-literal `)`)?
+                    (`built_in(` string-literal `)`)?
+                    attribute-dict?
+    ```
+
+    where `initializer` specifies initializer and `bind` specifies the
+    descriptor set and binding number. `built_in` specifies SPIR-V
+    BuiltIn decoration associated with the op.
+
+    For example:
+
+    ```
+    spv.Variable !spv.ptr<f32, Input> @var0
+    spv.Variable !spv.ptr<f32, Output> @var2 initializer(@var0)
+    spv.Variable !spv.ptr<f32, Uniform> @var bind(1, 2)
+    spv.Variable !spv.ptr<vector<3xi32>> @var3 built_in("GlobalInvocationID")
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    OptionalAttr<SymbolRefAttr>:$initializer
+  );
+
+  let results = (outs);
+
+  let hasOpcode = 0;
+
+  let extraClassDeclaration = [{
+    ::mlir::spirv::StorageClass storageClass() {
+      return this->type().cast<::mlir::spirv::PointerType>().getStorageClass();
+    }
+  }];
+}
+
 def SPV_ModuleOp : SPV_Op<"module",
                           [SingleBlockImplicitTerminator<"ModuleEndOp">,
                            NativeOpTrait<"SymbolTable">]> {
diff --git a/third_party/mlir/include/mlir/IR/OpBase.td b/third_party/mlir/include/mlir/IR/OpBase.td
index 519222c..3183a76 100644
--- a/third_party/mlir/include/mlir/IR/OpBase.td
+++ b/third_party/mlir/include/mlir/IR/OpBase.td
@@ -872,6 +872,11 @@
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
 }
 
+def SymbolRefArrayAttr :
+  TypedArrayAttrBase<SymbolRefAttr, "symbol ref array attribute"> {
+  let constBuilderCall = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // Derive attribute kinds
 
diff --git a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 53a40df..035de4f 100644
--- a/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/third_party/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -136,26 +136,26 @@
                                signatureConverter, newFuncOp))) {
     return failure();
   }
-  // Create spv.Variable ops for each of the arguments. These need to be bound
-  // by the runtime. For now use descriptor_set 0, and arg number as the binding
-  // number.
+  // Create spv.globalVariable ops for each of the arguments. These need to be
+  // bound by the runtime. For now use descriptor_set 0, and arg number as the
+  // binding number.
   auto module = funcOp.getParentOfType<spirv::ModuleOp>();
   if (!module) {
     return funcOp.emitError("expected op to be within a spv.module");
   }
   OpBuilder builder(module.getOperation()->getRegion(0));
-  SmallVector<Value *, 4> interface;
+  SmallVector<Attribute, 4> interface;
   for (auto &convertedArgType :
        llvm::enumerate(signatureConverter.getConvertedTypes())) {
-    auto variableOp = builder.create<spirv::VariableOp>(
-        funcOp.getLoc(), convertedArgType.value(),
-        builder.getI32IntegerAttr(
-            static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
-        llvm::None);
+    std::string varName = funcOp.getName().str() + "_arg_" +
+                          std::to_string(convertedArgType.index());
+    auto variableOp = builder.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
+        builder.getStringAttr(varName), nullptr);
     variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
     variableOp.setAttr("binding",
                        builder.getI32IntegerAttr(convertedArgType.index()));
-    interface.push_back(variableOp.getResult());
+    interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
   }
   // Create an entry point instruction for this function.
   // TODO(ravishankarm) : Add execution mode for the entry function
@@ -164,7 +164,8 @@
       funcOp.getLoc(),
       builder.getI32IntegerAttr(
           static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
-      builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+      builder.getSymbolRefAttr(newFuncOp.getName()),
+      builder.getArrayAttr(interface));
   return success();
 }
 } // namespace mlir
diff --git a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 4bea441..9947c02 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -32,11 +32,15 @@
 
 // TODO(antiagainst): generate these strings using ODS.
 static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kFnNameAttrName[] = "fn";
 static constexpr const char kIndicesAttrName[] = "indices";
+static constexpr const char kInitializerAttrName[] = "initializer";
+static constexpr const char kInterfaceAttrName[] = "interface";
 static constexpr const char kIsSpecConstName[] = "is_spec_const";
+static constexpr const char kTypeAttrName[] = "type";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
-static constexpr const char kFnNameAttrName[] = "fn";
+static constexpr const char kVariableAttrName[] = "variable";
 
 //===----------------------------------------------------------------------===//
 // Common utility functions
@@ -239,6 +243,71 @@
   printer->printOptionalAttrDict(op->getAttrs());
 }
 
+static ParseResult parseVariableDecorations(OpAsmParser *parser,
+                                            OperationState *state) {
+  auto builtInName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+  if (succeeded(parser->parseOptionalKeyword("bind"))) {
+    Attribute set, binding;
+    // Parse optional descriptor binding
+    auto descriptorSetName = convertToSnakeCase(
+        stringifyDecoration(spirv::Decoration::DescriptorSet));
+    auto bindingName =
+        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+    Type i32Type = parser->getBuilder().getIntegerType(32);
+    if (parser->parseLParen() ||
+        parser->parseAttribute(set, i32Type, descriptorSetName,
+                               state->attributes) ||
+        parser->parseComma() ||
+        parser->parseAttribute(binding, i32Type, bindingName,
+                               state->attributes) ||
+        parser->parseRParen()) {
+      return failure();
+    }
+  } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
+    StringAttr builtIn;
+    if (parser->parseLParen() ||
+        parser->parseAttribute(builtIn, Type(), builtInName,
+                               state->attributes) ||
+        parser->parseRParen()) {
+      return failure();
+    }
+  }
+
+  // Parse other attributes
+  if (parser->parseOptionalAttributeDict(state->attributes))
+    return failure();
+
+  return success();
+}
+
+static void printVariableDecorations(Operation *op, OpAsmPrinter *printer,
+                                     SmallVectorImpl<StringRef> &elidedAttrs) {
+  // Print optional descriptor binding
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
+  if (descriptorSet && binding) {
+    elidedAttrs.push_back(descriptorSetName);
+    elidedAttrs.push_back(bindingName);
+    *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+             << ")";
+  }
+
+  // Print BuiltIn attribute if present
+  auto builtInName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
+  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
+    *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
+    elidedAttrs.push_back(builtInName);
+  }
+
+  printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.AccessChainOp
 //===----------------------------------------------------------------------===//
@@ -363,6 +432,53 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv._address_of
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAddressOfOp(OpAsmParser *parser,
+                                    OperationState *state) {
+  SymbolRefAttr varRefAttr;
+  Type type;
+  if (parser->parseAttribute(varRefAttr, Type(), kVariableAttrName,
+                             state->attributes) ||
+      parser->parseColonType(type)) {
+    return failure();
+  }
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType) {
+    return parser->emitError(parser->getCurrentLocation(),
+                             "expected spv.ptr type");
+  }
+  state->addTypes(ptrType);
+  return success();
+}
+
+static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter *printer) {
+  SmallVector<StringRef, 4> elidedAttrs;
+  *printer << spirv::AddressOfOp::getOperationName();
+
+  // Print symbol name.
+  *printer << " @" << addressOfOp.variable();
+
+  // Print the type.
+  *printer << " : " << addressOfOp.pointer();
+}
+
+static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
+  auto moduleOp = addressOfOp.getParentOfType<spirv::ModuleOp>();
+  auto varOp =
+      moduleOp.lookupSymbol<spirv::GlobalVariableOp>(addressOfOp.variable());
+  if (!varOp) {
+    return addressOfOp.emitError("expected spv.globalVariable symbol");
+  }
+  if (addressOfOp.pointer()->getType() != varOp.type()) {
+    return addressOfOp.emitError(
+        "mismatch in result type and type of global variable referenced");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
 
@@ -541,18 +657,28 @@
   SmallVector<OpAsmParser::OperandType, 0> identifiers;
   SmallVector<Type, 0> idTypes;
 
-  Attribute fn;
-  auto loc = parser->getCurrentLocation();
-
+  SymbolRefAttr fn;
   if (parseEnumAttribute(execModel, parser, state) ||
-      parser->parseAttribute(fn, kFnNameAttrName, state->attributes) ||
-      parser->parseTrailingOperandList(identifiers) ||
-      parser->parseOptionalColonTypeList(idTypes) ||
-      parser->resolveOperands(identifiers, idTypes, loc, state->operands)) {
+      parser->parseAttribute(fn, Type(), kFnNameAttrName, state->attributes)) {
     return failure();
   }
-  if (!fn.isa<SymbolRefAttr>()) {
-    return parser->emitError(loc, "expected symbol reference attribute");
+
+  if (!parser->parseOptionalComma()) {
+    // Parse the interface variables
+    SmallVector<Attribute, 4> interfaceVars;
+    do {
+      // The name of the interface variable attribute isnt important
+      auto attrName = "var_symbol";
+      SymbolRefAttr var;
+      SmallVector<NamedAttribute, 1> attrs;
+      if (parser->parseAttribute(var, Type(), attrName, attrs)) {
+        return failure();
+      }
+      interfaceVars.push_back(var);
+    } while (!parser->parseOptionalComma());
+    state->attributes.push_back(
+        {parser->getBuilder().getIdentifier(kInterfaceAttrName),
+         parser->getBuilder().getArrayAttr(interfaceVars)});
   }
   return success();
 }
@@ -561,27 +687,16 @@
   *printer << spirv::EntryPointOp::getOperationName() << " \""
            << stringifyExecutionModel(entryPointOp.execution_model()) << "\" @"
            << entryPointOp.fn();
-  if (!entryPointOp.getNumOperands()) {
-    return;
+  if (auto interface = entryPointOp.interface()) {
+    *printer << ", ";
+    mlir::interleaveComma(interface.getValue().getValue(), printer->getStream(),
+                          [&](Attribute a) { printer->printAttribute(a); });
   }
-  *printer << ", ";
-  mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
-                        [&](Value *a) { printer->printOperand(a); });
-  *printer << " : ";
-  mlir::interleaveComma(entryPointOp.getOperands(), printer->getStream(),
-                        [&](const Value *a) { *printer << a->getType(); });
 }
 
 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
-  // Verify that all the interface ops are created from VariableOp
-  for (auto interface : entryPointOp.interface()) {
-    if (!llvm::isa_and_nonnull<spirv::VariableOp>(interface->getDefiningOp())) {
-      return entryPointOp.emitOpError("interface operands to entry point must "
-                                      "be generated from a variable op");
-    }
-    // TODO:  Before version 1.4 the variables can only have storage_class of
-    // Input or Output. That needs to be verified.
-  }
+  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
+  // verification.
   return success();
 }
 
@@ -628,6 +743,95 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.globalVariable
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGlobalVariableOp(OpAsmParser *parser,
+                                         OperationState *state) {
+  // Parse variable type.
+  TypeAttr typeAttr;
+  auto loc = parser->getCurrentLocation();
+  if (parser->parseAttribute(typeAttr, Type(), kTypeAttrName,
+                             state->attributes)) {
+    return failure();
+  }
+  auto ptrType = typeAttr.getValue().dyn_cast<spirv::PointerType>();
+  if (!ptrType) {
+    return parser->emitError(loc, "expected spv.ptr type");
+  }
+
+  // Parse variable name.
+  StringAttr nameAttr;
+  if (parser->parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                              state->attributes)) {
+    return failure();
+  }
+
+  // Parse optional initializer
+  if (succeeded(parser->parseOptionalKeyword(kInitializerAttrName))) {
+    SymbolRefAttr initSymbol;
+    if (parser->parseLParen() ||
+        parser->parseAttribute(initSymbol, Type(), kInitializerAttrName,
+                               state->attributes) ||
+        parser->parseRParen())
+      return failure();
+  }
+
+  if (parseVariableDecorations(parser, state)) {
+    return failure();
+  }
+
+  return success();
+}
+
+static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter *printer) {
+  auto *op = varOp.getOperation();
+  SmallVector<StringRef, 4> elidedAttrs{
+      spirv::attributeName<spirv::StorageClass>()};
+  *printer << spirv::GlobalVariableOp::getOperationName();
+
+  // Print variable type.
+  *printer << " " << varOp.type();
+  elidedAttrs.push_back(kTypeAttrName);
+
+  // Print variable name.
+  *printer << " @" << varOp.sym_name();
+  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+
+  // Print optional initializer
+  if (auto initializer = varOp.initializer()) {
+    *printer << " " << kInitializerAttrName << "(@" << initializer.getValue()
+             << ")";
+    elidedAttrs.push_back(kInitializerAttrName);
+  }
+  printVariableDecorations(op, printer, elidedAttrs);
+}
+
+static LogicalResult verify(spirv::GlobalVariableOp varOp) {
+  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+  // object. It cannot be Generic. It must be the same as the Storage Class
+  // operand of the Result Type."
+  if (varOp.storageClass() == spirv::StorageClass::Generic)
+    return varOp.emitOpError("storage class cannot be 'Generic'");
+
+  if (auto initializer =
+          varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+    // Get the module
+    auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
+    // TODO: Currently only variable initialization with other variables is
+    // supported. They could be constants as well, but this needs module-level
+    // constants to have symbol name as well.
+    if (!moduleOp.lookupSymbol<spirv::GlobalVariableOp>(
+            initializer.getValue())) {
+      return varOp.emitOpError(
+          "initializer must be result of a spv.globalVariable op");
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.LoadOp
 //===----------------------------------------------------------------------===//
 
@@ -773,13 +977,33 @@
   for (auto &op : body) {
     if (op.getDialect() == dialect) {
       // For EntryPoint op, check that the function and execution model is not
-      // duplicated in EntryPointOps
+      // duplicated in EntryPointOps. Also verify that the interface specified
+      // comes from globalVariables here to make this check cheaper.
       if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
         auto funcOp = table.lookup<FuncOp>(entryPointOp.fn());
         if (!funcOp) {
           return entryPointOp.emitError("function '")
                  << entryPointOp.fn() << "' not found in 'spv.module'";
         }
+        if (auto interface = entryPointOp.interface()) {
+          for (auto varRef : interface.getValue().getValue()) {
+            auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+            if (!varSymRef) {
+              return entryPointOp.emitError(
+                         "expected symbol reference for interface "
+                         "specification instead of '")
+                     << varRef;
+            }
+            auto variableOp =
+                table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
+            if (!variableOp) {
+              return entryPointOp.emitError("expected spv.globalVariable "
+                                            "symbol reference instead of'")
+                     << varSymRef << "'";
+            }
+          }
+        }
+
         auto key = std::pair<FuncOp, spirv::ExecutionModel>(
             funcOp, entryPointOp.execution_model());
         auto entryPtIt = entryPoints.find(key);
@@ -898,42 +1122,9 @@
       return failure();
   }
 
-  auto builtInName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
-  if (succeeded(parser->parseOptionalKeyword("bind"))) {
-    Attribute set, binding;
-    // Parse optional descriptor binding
-    auto descriptorSetName = convertToSnakeCase(
-        stringifyDecoration(spirv::Decoration::DescriptorSet));
-    auto bindingName =
-        convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
-    Type i32Type = parser->getBuilder().getIntegerType(32);
-    if (parser->parseLParen() ||
-        parser->parseAttribute(set, i32Type, descriptorSetName,
-                               state->attributes) ||
-        parser->parseComma() ||
-        parser->parseAttribute(binding, i32Type, bindingName,
-                               state->attributes) ||
-        parser->parseRParen()) {
-      return failure();
-    }
-  } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) {
-    Attribute builtIn;
-    if (parser->parseLParen() ||
-        parser->parseAttribute(builtIn, Type(), builtInName,
-                               state->attributes) ||
-        parser->parseRParen()) {
-      return failure();
-    }
-    if (!builtIn.isa<StringAttr>()) {
-      return parser->emitError(parser->getCurrentLocation(),
-                               "expected string value for built_in attribute");
-    }
-  }
-
-  // Parse other attributes
-  if (parser->parseOptionalAttributeDict(state->attributes))
+  if (parseVariableDecorations(parser, state)) {
     return failure();
+  }
 
   // Parse result pointer type
   Type type;
@@ -976,29 +1167,8 @@
     *printer << ")";
   }
 
-  // Print optional descriptor binding
-  auto descriptorSetName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
-  auto bindingName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
-  auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
-  auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
-  if (descriptorSet && binding) {
-    elidedAttrs.push_back(descriptorSetName);
-    elidedAttrs.push_back(bindingName);
-    *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
-             << ")";
-  }
+  printVariableDecorations(op, printer, elidedAttrs);
 
-  // Print BuiltIn attribute if present
-  auto builtInName =
-      convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn));
-  if (auto builtin = varOp.getAttrOfType<StringAttr>(builtInName)) {
-    *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
-    elidedAttrs.push_back(builtInName);
-  }
-
-  printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
   *printer << " : " << varOp.getType();
 }
 
@@ -1006,8 +1176,11 @@
   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
   // object. It cannot be Generic. It must be the same as the Storage Class
   // operand of the Result Type."
-  if (varOp.storage_class() == spirv::StorageClass::Generic)
-    return varOp.emitOpError("storage class cannot be 'Generic'");
+  if (varOp.storage_class() != spirv::StorageClass::Function) {
+    return varOp.emitOpError(
+        "can only be used to model function-level variables. Use "
+        "spv.globalVariable for module-level variables.");
+  }
 
   auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
   if (varOp.storage_class() != pointerType.getStorageClass())
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 1aad717..a3d71ed 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -90,9 +90,20 @@
   /// them to their handler method accordingly.
   LogicalResult processFunction(ArrayRef<uint32_t> operands);
 
+  /// Process the OpVariable instructions at current `offset` into `binary`. It
+  /// is expected that this method is used for variables that are to be defined
+  /// at module scope and will be deserialized into a spv.globalVariable
+  /// instruction.
+  LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
   /// Get the FuncOp associated with a result <id> of OpFunction.
   FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
 
+  /// Get the global variable associated with a result <id> of OpVariable
+  spirv::GlobalVariableOp getVariable(uint32_t id) {
+    return globalVariableMap.lookup(id);
+  }
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//
@@ -138,7 +149,15 @@
   //===--------------------------------------------------------------------===//
 
   /// Get the Value associated with a result <id>.
-  Value *getValue(uint32_t id) { return valueMap.lookup(id); }
+  Value *getValue(uint32_t id) {
+    if (auto varOp = getVariable(id)) {
+      auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
+          unknownLoc, varOp.type(),
+          opBuilder.getSymbolRefAttr(varOp.getOperation()));
+      return addressOfOp.pointer();
+    }
+    return valueMap.lookup(id);
+  }
 
   /// Slices the first instruction out of `binary` and returns its opcode and
   /// operands via `opcode` and `operands` respectively. Returns failure if
@@ -198,6 +217,9 @@
   // Result <id> to function mapping.
   DenseMap<uint32_t, FuncOp> funcMap;
 
+  // Result <id> to variable mapping;
+  DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
   // Result <id> to value mapping.
   DenseMap<uint32_t, Value *> valueMap;
 
@@ -452,6 +474,76 @@
   return success();
 }
 
+LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+  unsigned wordIndex = 0;
+  if (operands.size() < 3) {
+    return emitError(
+        unknownLoc,
+        "OpVariable needs at least 3 operands, type, <id> and storage class");
+  }
+
+  // Result Type.
+  auto type = getType(operands[wordIndex]);
+  if (!type) {
+    return emitError(unknownLoc, "unknown result type <id> : ")
+           << operands[wordIndex];
+  }
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType) {
+    return emitError(unknownLoc,
+                     "expected a result type <id> to be a spv.ptr, found : ")
+           << type;
+  }
+  wordIndex++;
+
+  // Result <id>.
+  auto variableID = operands[wordIndex];
+  auto variableName = nameMap.lookup(variableID).str();
+  if (variableName.empty()) {
+    variableName = "spirv_var_" + std::to_string(variableID);
+  }
+  wordIndex++;
+
+  // Storage class.
+  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
+  if (ptrType.getStorageClass() != storageClass) {
+    return emitError(unknownLoc, "mismatch in storage class of pointer type ")
+           << type << " and that specified in OpVariable instruction  : "
+           << stringifyStorageClass(storageClass);
+  }
+  wordIndex++;
+
+  // Initializer.
+  SymbolRefAttr initializer = nullptr;
+  if (wordIndex < operands.size()) {
+    auto initializerOp = getVariable(operands[wordIndex]);
+    if (!initializerOp) {
+      return emitError(unknownLoc, "unknown <id> ")
+             << operands[wordIndex] << "used as initializer";
+    }
+    wordIndex++;
+    initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation());
+  }
+  if (wordIndex != operands.size()) {
+    return emitError(unknownLoc,
+                     "found more operands than expected when deserializing "
+                     "OpVariable instruction, only ")
+           << wordIndex << " of " << operands.size() << " processed";
+  }
+  auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
+      unknownLoc, opBuilder.getTypeAttr(type),
+      opBuilder.getStringAttr(variableName), initializer);
+
+  // Decorations.
+  if (decorations.count(variableID)) {
+    for (auto attr : decorations[variableID].getAttrs()) {
+      varOp.setAttr(attr.first, attr.second);
+    }
+  }
+  globalVariableMap[variableID] = varOp;
+  return success();
+}
+
 LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
   if (operands.size() < 2) {
     return emitError(unknownLoc, "OpName needs at least 2 operands");
@@ -887,6 +979,11 @@
       return success();
     }
     break;
+  case spirv::Opcode::OpVariable:
+    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
+      return processGlobalVariable(operands);
+    }
+    break;
   case spirv::Opcode::OpName:
     return processName(operands);
   case spirv::Opcode::OpTypeVoid:
@@ -954,18 +1051,19 @@
                                  "and OpFunction with <id> ")
            << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
   }
-  SmallVector<Value *, 4> interface;
+  SmallVector<Attribute, 4> interface;
   while (wordIndex < words.size()) {
-    auto arg = getValue(words[wordIndex]);
+    auto arg = getVariable(words[wordIndex]);
     if (!arg) {
       return emitError(unknownLoc, "undefined result <id> ")
              << words[wordIndex] << " while decoding OpEntryPoint";
     }
-    interface.push_back(arg);
+    interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
     wordIndex++;
   }
-  opBuilder.create<spirv::EntryPointOp>(
-      unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), interface);
+  opBuilder.create<spirv::EntryPointOp>(unknownLoc, exec_model,
+                                        opBuilder.getSymbolRefAttr(fnName),
+                                        opBuilder.getArrayAttr(interface));
   return success();
 }
 
diff --git a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index d06363a..575d995 100644
--- a/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/third_party/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -125,9 +125,19 @@
     return funcIDMap.lookup(fnName);
   }
 
+  uint32_t findVariableID(StringRef varName) const {
+    return globalVarIDMap.lookup(varName);
+  }
+
+  /// Emit OpName for the given `resultID`.
+  LogicalResult processName(uint32_t resultID, StringRef name);
+
   /// Processes a SPIR-V function op.
   LogicalResult processFuncOp(FuncOp op);
 
+  /// Process a SPIR-V GlobalVariableOp
+  LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
+
   /// Process attributes that translate to decorations on the result <id>
   LogicalResult processDecoration(Location loc, uint32_t resultID,
                                   NamedAttribute attr);
@@ -215,6 +225,9 @@
 
   uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
 
+  /// Process spv.addressOf operations.
+  LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
+
   /// Main dispatch method for serializing an operation.
   LogicalResult processOperation(Operation *op);
 
@@ -265,6 +278,9 @@
   /// Map from FuncOps name to <id>s.
   llvm::StringMap<uint32_t> funcIDMap;
 
+  /// Map from GlobalVariableOps name to <id>s
+  llvm::StringMap<uint32_t> globalVarIDMap;
+
   /// Map from results of normal operations to their <id>s
   DenseMap<Value *, uint32_t> valueIDMap;
 };
@@ -372,6 +388,15 @@
   return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
 }
 
+LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
+  SmallVector<uint32_t, 4> nameOperands;
+  nameOperands.push_back(resultID);
+  if (failed(encodeStringLiteralInto(nameOperands, name))) {
+    return failure();
+  }
+  return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+}
+
 namespace {
 template <>
 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
@@ -416,10 +441,9 @@
   encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
 
   // Add function name.
-  SmallVector<uint32_t, 4> nameOperands;
-  nameOperands.push_back(funcID);
-  encodeStringLiteralInto(nameOperands, op.getName());
-  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
+  if (failed(processName(funcID, op.getName()))) {
+    return failure();
+  }
 
   // Declare the parameters.
   for (auto arg : op.getArguments()) {
@@ -450,6 +474,61 @@
   return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
 }
 
+LogicalResult
+Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
+  // Get TypeID.
+  uint32_t resultTypeID = 0;
+  SmallVector<StringRef, 4> elidedAttrs;
+  if (failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
+    return failure();
+  }
+  elidedAttrs.push_back("type");
+  SmallVector<uint32_t, 4> operands;
+  operands.push_back(resultTypeID);
+  auto resultID = getNextID();
+
+  // Encode the name.
+  auto varName = varOp.sym_name();
+  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+  if (failed(processName(resultID, varName))) {
+    return failure();
+  }
+  globalVarIDMap[varName] = resultID;
+  operands.push_back(resultID);
+
+  // Encode StorageClass.
+  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
+
+  // Encode initialization.
+  if (auto initializer = varOp.initializer()) {
+    auto initializerID = findVariableID(initializer.getValue());
+    if (!initializerID) {
+      return emitError(varOp.getLoc(),
+                       "invalid usage of undefined variable as initializer");
+    }
+    operands.push_back(initializerID);
+    elidedAttrs.push_back("initializer");
+  }
+
+  if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
+                                   operands))) {
+    elidedAttrs.push_back("initializer");
+    return failure();
+  }
+
+  // Encode decorations.
+  for (auto attr : varOp.getAttrs()) {
+    if (llvm::any_of(elidedAttrs,
+                     [&](StringRef elided) { return attr.first.is(elided); })) {
+      continue;
+    }
+    if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
+      return failure();
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type
 //===----------------------------------------------------------------------===//
@@ -912,6 +991,17 @@
 // Operation
 //===----------------------------------------------------------------------===//
 
+LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
+  auto varName = addressOfOp.variable();
+  auto variableID = findVariableID(varName);
+  if (!variableID) {
+    return addressOfOp.emitError("unknown result <id> for variable ")
+           << varName;
+  }
+  valueIDMap[addressOfOp.pointer()] = variableID;
+  return success();
+}
+
 LogicalResult Serializer::processOperation(Operation *op) {
   // First dispatch the methods that do not directly mirror an operation from
   // the SPIR-V spec
@@ -924,6 +1014,12 @@
   if (isa<spirv::ModuleEndOp>(op)) {
     return success();
   }
+  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+    return processGlobalVariableOp(varOp);
+  }
+  if (auto addressOfOp = dyn_cast<spirv::AddressOfOp>(op)) {
+    return processAddressOfOp(addressOfOp);
+  }
   return dispatchToAutogenSerialization(op);
 }
 
@@ -947,14 +1043,16 @@
   encodeStringLiteralInto(operands, op.fn());
 
   // Add the interface values.
-  for (auto val : op.interface()) {
-    auto id = findValueID(val);
-    if (!id) {
-      return op.emitError("referencing unintialized variable <id>. "
-                          "spv.EntryPoint is at the end of spv.module. All "
-                          "referenced variables should already be defined");
+  if (auto interface = op.interface()) {
+    for (auto var : interface.getValue()) {
+      auto id = findVariableID(var.cast<SymbolRefAttr>().getValue());
+      if (!id) {
+        return op.emitError("referencing undefined global variable."
+                            "spv.EntryPoint is at the end of spv.module. All "
+                            "referenced variables should already be defined");
+      }
+      operands.push_back(id);
     }
-    operands.push_back(id);
   }
   return encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint,
                                operands);