Make SPIR-V lowering infrastructure follow Vulkan SPIR-V validation.

The lowering infrastructure needs to be enhanced to lower into a
spv.Module that is consistent with the SPIR-V spec. The following
changes are needed:
1) The Vulkan/SPIR-V validation rules dictate that entry functions have
a signature of void(void). This requires changes to the function
signature conversion infrastructure within the dialect conversion
framework. When an argument is dropped from the original function
signature, a function can be specified that, when invoked, returns
the value to use as a replacement for that argument of the original
function.
2) Some changes to the type converter to make the converted types
consistent with the Vulkan/SPIR-V validation rules:
   a) Add support for converting dynamically shaped tensors to
   spv.rtarray type.
   b) Make the global variables have type !spv.ptr<!spv.struct<...>>.
3) Generate the entry point operation for the kernel functions and
automatically compute all the interface variables needed.

PiperOrigin-RevId: 273784229
diff --git a/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h b/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
index e31c58c..e92ad03 100644
--- a/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
+++ b/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
@@ -49,16 +49,8 @@
   explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
       : basicTypeConverter(basicTypeConverter) {}
 
-  /// Convert types to SPIR-V types using the basic type converter.
-  Type convertType(Type t) override {
-    return basicTypeConverter->convertType(t);
-  }
-
-  /// Method to convert argument of a function. The `type` is converted to
-  /// spv.ptr<type, Uniform>.
-  // TODO(ravishankarm) : Support other storage classes.
-  LogicalResult convertSignatureArg(unsigned inputNo, Type type,
-                                    SignatureConversion &result) override;
+  /// Converts types to SPIR-V types using the basic type converter.
+  Type convertType(Type t) override;
 
   /// Gets the basic type converter.
   SPIRVBasicTypeConverter *getBasicTypeConverter() const {
@@ -163,17 +155,20 @@
 };
 
 /// Legalizes a function as a non-entry function.
-LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                            SPIRVTypeConverter *typeConverter,
+LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter,
                             FuncOp &newFuncOp);
 
 /// Legalizes a function as an entry function.
-LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+LogicalResult lowerAsEntryFunction(FuncOp funcOp,
                                    SPIRVTypeConverter *typeConverter,
                                    ConversionPatternRewriter &rewriter,
                                    FuncOp &newFuncOp);
 
+/// Finalizes entry function legalization. Inserts the spv.EntryPoint and
+/// spv.ExecutionMode ops.
+LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder);
+
 /// Appends to a pattern list additional patterns for translating StandardOps to
 /// SPIR-V ops.
 void populateStandardToSPIRVPatterns(MLIRContext *context,
diff --git a/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 2689572..4031385 100644
--- a/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -86,18 +86,15 @@
   auto funcOp = cast<FuncOp>(op);
   FuncOp newFuncOp;
   if (!gpu::GPUDialect::isKernel(funcOp)) {
-    return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
-                                   newFuncOp))
+    return succeeded(lowerFunction(funcOp, &typeConverter, rewriter, newFuncOp))
                ? matchSuccess()
                : matchFailure();
   }
 
-  if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
-                                  newFuncOp))) {
+  if (failed(
+          lowerAsEntryFunction(funcOp, &typeConverter, rewriter, newFuncOp))) {
     return matchFailure();
   }
-  newFuncOp.getOperation()->removeAttr(Identifier::get(
-      gpu::GPUDialect::getKernelFuncAttrName(), op->getContext()));
   return matchSuccess();
 }
 
@@ -164,6 +161,24 @@
                                  &typeConverter))) {
     return signalPassFailure();
   }
+
+  // After the SPIR-V modules have been generated, some finalization is needed
+  // for the entry functions. For example, adding spv.EntryPoint op,
+  // spv.ExecutionMode op, etc.
+  for (auto *spvModule : spirvModules) {
+    for (auto op :
+         cast<spirv::ModuleOp>(spvModule).getBlock().getOps<FuncOp>()) {
+      if (gpu::GPUDialect::isKernel(op)) {
+        OpBuilder builder(op.getContext());
+        builder.setInsertionPointAfter(op);
+        if (failed(finalizeEntryFunction(op, builder))) {
+          return signalPassFailure();
+        }
+        op.getOperation()->removeAttr(Identifier::get(
+            gpu::GPUDialect::getKernelFuncAttrName(), op.getContext()));
+      }
+    }
+  }
 }
 
 OpPassBase<ModuleOp> *createGPUToSPIRVPass() { return new GPUToSPIRVPass(); }
diff --git a/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 12a58c0..b104b53 100644
--- a/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
+#include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
@@ -30,7 +31,7 @@
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-Type SPIRVBasicTypeConverter::convertType(Type t) {
+static Type basicTypeConversion(Type t) {
   // Check if the type is SPIR-V supported. If so return the type.
   if (spirv::SPIRVDialect::isValidType(t)) {
     return t;
@@ -42,75 +43,107 @@
   }
 
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
+    auto elementType = memRefType.getElementType();
     if (memRefType.hasStaticShape()) {
-      // Convert MemrefType to a multi-dimensional spv.array if size is known.
-      auto elementType = memRefType.getElementType();
+      // Convert to a multi-dimensional spv.array if size is known.
       for (auto size : reverse(memRefType.getShape())) {
         elementType = spirv::ArrayType::get(elementType, size);
       }
-      // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
-      // to support other Storage Classes.
       return spirv::PointerType::get(elementType,
                                      spirv::StorageClass::StorageBuffer);
+    } else {
+      // Vulkan SPIR-V validation rules require runtime array type to be the
+      // last member of a struct.
+      return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType),
+                                     spirv::StorageClass::StorageBuffer);
     }
   }
   return Type();
 }
 
+Type SPIRVBasicTypeConverter::convertType(Type t) {
+  return basicTypeConversion(t);
+}
+
 //===----------------------------------------------------------------------===//
 // Entry Function signature Conversion
 //===----------------------------------------------------------------------===//
 
-LogicalResult
-SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
-                                        SignatureConversion &result) {
-  // Try to convert the given input type.
-  auto convertedType = basicTypeConverter->convertType(type);
-  // TODO(ravishankarm) : Vulkan spec requires these to be a
-  // spirv::StructType. This is not a SPIR-V requirement, so just making this a
-  // pointer type for now.
-  if (!convertedType)
-    return failure();
-  // For arguments to entry functions, convert the type into a pointer type if
-  // it is already not one, unless the original type was an index type.
-  // TODO(ravishankarm): For arguments that are of index type, keep the
-  // arguments as the scalar converted type, i.e. i32. These are still not
-  // handled effectively. These are potentially best handled as specialization
-  // constants.
-  if (!convertedType.isa<spirv::PointerType>() && !type.isa<IndexType>()) {
-    // TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
-    // to support other Storage classes.
-    convertedType = spirv::PointerType::get(convertedType,
-                                            spirv::StorageClass::StorageBuffer);
+/// Generates the type of variable given the type of object.
+static Type getGlobalVarTypeForEntryFnArg(Type t) {
+  auto convertedType = basicTypeConversion(t);
+  if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) {
+    if (!ptrType.getPointeeType().isa<spirv::StructType>()) {
+      return spirv::PointerType::get(
+          spirv::StructType::get(ptrType.getPointeeType()),
+          ptrType.getStorageClass());
+    }
+  } else {
+    return spirv::PointerType::get(spirv::StructType::get(convertedType),
+                                   spirv::StorageClass::StorageBuffer);
   }
-
-  // Add the new inputs.
-  result.addInputs(inputNo, convertedType);
-  return success();
+  return convertedType;
 }
 
-static LogicalResult lowerFunctionImpl(
-    FuncOp funcOp, ArrayRef<Value *> operands,
-    ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
-    TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
-  auto fnType = funcOp.getType();
+Type SPIRVTypeConverter::convertType(Type t) {
+  return getGlobalVarTypeForEntryFnArg(t);
+}
 
-  if (fnType.getNumResults()) {
-    return funcOp.emitError("SPIR-V dialect only supports functions with no "
-                            "return values right now");
+/// Computes the replacement value for an argument of an entry function. It
+/// allocates a global variable for this argument and adds statements in the
+/// entry block to get a replacement value within function scope.
+static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
+                                                  size_t origArgNum,
+                                                  Value *origArg) {
+  // Create a global variable for this argument.
+  auto insertionOp = rewriter.getInsertionBlock()->getParent();
+  auto module = insertionOp->getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return nullptr;
   }
-
-  for (auto &argType : enumerate(fnType.getInputs())) {
-    // Get the type of the argument
-    if (failed(typeConverter->convertSignatureArg(
-            argType.index(), argType.value(), signatureConverter))) {
-      return funcOp.emitError("unable to convert argument type ")
-             << argType.value() << " to SPIR-V type";
-    }
+  auto funcOp = insertionOp->getParentOfType<FuncOp>();
+  spirv::GlobalVariableOp var;
+  {
+    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+    rewriter.setInsertionPointToStart(&module.getBlock());
+    std::string varName =
+        funcOp.getName().str() + "_arg_" + std::to_string(origArgNum);
+    var = rewriter.create<spirv::GlobalVariableOp>(
+        funcOp.getLoc(),
+        rewriter.getTypeAttr(getGlobalVarTypeForEntryFnArg(origArg->getType())),
+        rewriter.getStringAttr(varName), nullptr);
+    var.setAttr(
+        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
+        rewriter.getI32IntegerAttr(0));
+    var.setAttr(
+        spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
+        rewriter.getI32IntegerAttr(origArgNum));
   }
+  // Insert the addressOf and load instructions, to get back the converted value
+  // type.
+  auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
+  auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
+                                                 rewriter.getIntegerType(32),
+                                                 rewriter.getI32IntegerAttr(0));
+  auto accessChain = rewriter.create<spirv::AccessChainOp>(
+      funcOp.getLoc(), addressOf.pointer(), zero.constant());
+  // If the original argument is a tensor/memref type, the value is not
+  // loaded. Instead the pointer value is returned to allow its use in access
+  // chain ops.
+  auto origArgType = origArg->getType();
+  if (origArgType.isa<MemRefType>()) {
+    return accessChain;
+  }
+  return rewriter.create<spirv::LoadOp>(
+      funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr,
+      /*alignment=*/nullptr);
+}
 
+static FuncOp applySignatureConversion(
+    FuncOp funcOp, ConversionPatternRewriter &rewriter,
+    TypeConverter::SignatureConversion &signatureConverter) {
   // Create a new function with an updated signature.
-  newFuncOp = rewriter.cloneWithoutRegions(funcOp);
+  auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                               newFuncOp.end());
   newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
@@ -119,72 +152,113 @@
   // Tell the rewriter to convert the region signature.
   rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
   rewriter.replaceOp(funcOp.getOperation(), llvm::None);
+  return newFuncOp;
+}
+
+/// Gets the global variables that need to be specified as interface variables
+/// with an spv.EntryPointOp. Traverses the body of an entry function to do so.
+LogicalResult getInterfaceVariables(FuncOp funcOp,
+                                    SmallVectorImpl<Attribute> &interfaceVars) {
+  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
+  if (!module) {
+    return failure();
+  }
+  llvm::SetVector<Operation *> interfaceVarSet;
+  for (auto &block : funcOp) {
+    // TODO(ravishankarm) : This should in reality traverse the entry function
+    // call graph and collect all the interfaces. For now, just traverse the
+    // instructions in this function.
+    auto callOps = block.getOps<CallOp>();
+    if (std::distance(callOps.begin(), callOps.end())) {
+      return funcOp.emitError("Collecting interface variables through function "
+                              "calls unimplemented");
+    }
+    for (auto op : block.getOps<spirv::AddressOfOp>()) {
+      auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable());
+      if (var.type().cast<spirv::PointerType>().getStorageClass() ==
+          spirv::StorageClass::StorageBuffer) {
+        continue;
+      }
+      interfaceVarSet.insert(var.getOperation());
+    }
+  }
+  for (auto &var : interfaceVarSet) {
+    interfaceVars.push_back(SymbolRefAttr::get(
+        cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
+  }
   return success();
 }
 
 namespace mlir {
-LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                            SPIRVTypeConverter *typeConverter,
+LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter,
                             FuncOp &newFuncOp) {
   auto fnType = funcOp.getType();
+  if (fnType.getNumResults()) {
+    return funcOp.emitError("SPIR-V lowering only supports functions with no "
+                            "return values right now");
+  }
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  return lowerFunctionImpl(funcOp, operands, rewriter,
-                           typeConverter->getBasicTypeConverter(),
-                           signatureConverter, newFuncOp);
+  auto basicTypeConverter = typeConverter->getBasicTypeConverter();
+  for (auto origArgType : enumerate(fnType.getInputs())) {
+    auto convertedType = basicTypeConverter->convertType(origArgType.value());
+    // Report the original type; convertedType is null on this failure path.
+    if (!convertedType)
+      return funcOp.emitError("unable to convert argument of type '")
+             << origArgType.value() << "'";
+    signatureConverter.addInputs(origArgType.index(), convertedType);
+  }
+  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
+  return success();
 }
 
-LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+LogicalResult lowerAsEntryFunction(FuncOp funcOp,
                                    SPIRVTypeConverter *typeConverter,
                                    ConversionPatternRewriter &rewriter,
                                    FuncOp &newFuncOp) {
   auto fnType = funcOp.getType();
+  if (fnType.getNumResults()) {
+    return funcOp.emitError("SPIR-V lowering only supports functions with no "
+                            "return values right now");
+  }
+  // Entry functions need to have the signature void(void). Compute the
+  // replacement value for all arguments and replace all uses.
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
-                               signatureConverter, newFuncOp))) {
+  {
+    OpBuilder::InsertionGuard moduleInsertionGuard(rewriter);
+    rewriter.setInsertionPointToStart(&funcOp.front());
+    for (auto origArg : enumerate(funcOp.getArguments())) {
+      auto replacement = createAndLoadGlobalVarForEntryFnArg(
+          rewriter, origArg.index(), origArg.value());
+      rewriter.replaceUsesOfBlockArgument(origArg.value(), replacement);
+    }
+  }
+  newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter);
+  return success();
+}
+
+LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
+  // Add the spv.EntryPointOp after collecting all the interface variables
+  // needed.
+  SmallVector<Attribute, 1> interfaceVars;
+  if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) {
     return failure();
   }
-  // Create spv.globalVariable ops for each of the arguments. These need to be
-  // bound by the runtime. For now use descriptor_set 0, and arg number as the
-  // binding number.
-  auto module = funcOp.getParentOfType<spirv::ModuleOp>();
-  if (!module) {
-    return funcOp.emitError("expected op to be within a spv.module");
-  }
-  auto ip = rewriter.saveInsertionPoint();
-  rewriter.setInsertionPointToStart(&module.getBlock());
-  SmallVector<Attribute, 4> interface;
-  for (auto &convertedArgType :
-       llvm::enumerate(signatureConverter.getConvertedTypes())) {
-    // TODO(ravishankarm) : The arguments to the converted function are either
-    // spirv::PointerType or i32 type, the latter due to conversion of index
-    // type to i32. Eventually entry function should be of signature
-    // void(void). Arguments converted to spirv::PointerType, will be made
-    // variables and those converted to i32 will be made specialization
-    // constants. Latter is not implemented.
-    if (!convertedArgType.value().isa<spirv::PointerType>()) {
-      continue;
-    }
-    std::string varName = funcOp.getName().str() + "_arg_" +
-                          std::to_string(convertedArgType.index());
-    auto variableOp = rewriter.create<spirv::GlobalVariableOp>(
-        funcOp.getLoc(), rewriter.getTypeAttr(convertedArgType.value()),
-        rewriter.getStringAttr(varName), nullptr);
-    variableOp.setAttr("descriptor_set", rewriter.getI32IntegerAttr(0));
-    variableOp.setAttr("binding",
-                       rewriter.getI32IntegerAttr(convertedArgType.index()));
-    interface.push_back(rewriter.getSymbolRefAttr(variableOp.sym_name()));
-  }
-  // Create an entry point instruction for this function.
-  // TODO(ravishankarm) : Add execution mode for the entry function
-  rewriter.setInsertionPoint(&(module.getBlock().back()));
-  rewriter.create<spirv::EntryPointOp>(
-      funcOp.getLoc(),
-      rewriter.getI32IntegerAttr(
-          static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
-      rewriter.getSymbolRefAttr(newFuncOp.getName()),
-      rewriter.getArrayAttr(interface));
-  rewriter.restoreInsertionPoint(ip);
+  builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(),
+                                      spirv::ExecutionModel::GLCompute,
+                                      newFuncOp, interfaceVars);
+  // Specify the spv.ExecutionModeOp.
+
+  /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a
+  /// LocalSize execution mode or an object decorated with the WorkgroupSize
+  /// decoration must be specified." Better approach is to use the
+  /// WorkgroupSize GlobalVariable with initializer being a specialization
+  /// constant. But current support for specialization constant does not allow
+  /// for this. So for now use the execution mode. Hard-wiring this to {1, 1,
+  /// 1} for now. To be fixed ASAP.
+  builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp,
+                                         spirv::ExecutionMode::LocalSize,
+                                         ArrayRef<int32_t>{1, 1, 1});
   return success();
 }
 } // namespace mlir
diff --git a/test/Conversion/GPUToSPIRV/load_store.mlir b/test/Conversion/GPUToSPIRV/load_store.mlir
index d362ce1..fc3f12d 100644
--- a/test/Conversion/GPUToSPIRV/load_store.mlir
+++ b/test/Conversion/GPUToSPIRV/load_store.mlir
@@ -16,13 +16,52 @@
   }
 
   // CHECK-LABEL: spv.module "Logical" "GLSL450"
-  // CHECK: spv.globalVariable {{@.*}} bind(0, 0) : [[TYPE1:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
-  // CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 1) : [[TYPE2:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
-  // CHECK-NEXT: spv.globalVariable {{@.*}} bind(0, 2) : [[TYPE3:!spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer>]]
-  // CHECK: func @load_store_kernel([[ARG0:%.*]]: [[TYPE1]], [[ARG1:%.*]]: [[TYPE2]], [[ARG2:%.*]]: [[TYPE3]], [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: i32, [[ARG6:%.*]]: i32)
   module @kernels attributes {gpu.kernel_module} {
+    // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZEVAR:@.*]] built_in("WorkgroupSize") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
+    // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR2:@.*]] bind(0, 2) : !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR3:@.*]] bind(0, 3) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR4:@.*]] bind(0, 4) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR5:@.*]] bind(0, 5) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
+    // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32>, StorageBuffer>
+    // CHECK: func [[FN:@.*]]()
     func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
       attributes  {gpu.kernel} {
+      // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+      // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG0:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+      // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+      // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+      // CHECK: [[ADDRESSARG2:%.*]] = spv._address_of [[VAR2]]
+      // CHECK: [[CONST2:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG2:%.*]] = spv.AccessChain [[ADDRESSARG2]]{{\[}}[[CONST2]]
+      // CHECK: [[ADDRESSARG3:%.*]] = spv._address_of [[VAR3]]
+      // CHECK: [[CONST3:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG3PTR:%.*]] = spv.AccessChain [[ADDRESSARG3]]{{\[}}[[CONST3]]
+      // CHECK: [[ARG3:%.*]] = spv.Load "StorageBuffer" [[ARG3PTR]]
+      // CHECK: [[ADDRESSARG4:%.*]] = spv._address_of [[VAR4]]
+      // CHECK: [[CONST4:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG4PTR:%.*]] = spv.AccessChain [[ADDRESSARG4]]{{\[}}[[CONST4]]
+      // CHECK: [[ARG4:%.*]] = spv.Load "StorageBuffer" [[ARG4PTR]]
+      // CHECK: [[ADDRESSARG5:%.*]] = spv._address_of [[VAR5]]
+      // CHECK: [[CONST5:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG5PTR:%.*]] = spv.AccessChain [[ADDRESSARG5]]{{\[}}[[CONST5]]
+      // CHECK: [[ARG5:%.*]] = spv.Load "StorageBuffer" [[ARG5PTR]]
+      // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
+      // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
+      // CHECK: [[ARG6:%.*]] = spv.Load "StorageBuffer" [[ARG6PTR]]
+      // CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
+      // CHECK: [[WORKGROUPID:%.*]] = spv.Load "Input" [[ADDRESSWORKGROUPID]]
+      // CHECK: [[WORKGROUPIDX:%.*]] = spv.CompositeExtract [[WORKGROUPID]]{{\[}}0 : i32{{\]}}
+      // CHECK: [[ADDRESSLOCALINVOCATIONID:%.*]] = spv._address_of [[LOCALINVOCATIONIDVAR]]
+      // CHECK: [[LOCALINVOCATIONID:%.*]] = spv.Load "Input" [[ADDRESSLOCALINVOCATIONID]]
+      // CHECK: [[LOCALINVOCATIONIDX:%.*]] = spv.CompositeExtract [[LOCALINVOCATIONID]]{{\[}}0 : i32{{\]}}
       %0 = "gpu.block_id"() {dimension = "x"} : () -> index
       %1 = "gpu.block_id"() {dimension = "y"} : () -> index
       %2 = "gpu.block_id"() {dimension = "z"} : () -> index
@@ -35,9 +74,9 @@
       %9 = "gpu.block_dim"() {dimension = "x"} : () -> index
       %10 = "gpu.block_dim"() {dimension = "y"} : () -> index
       %11 = "gpu.block_dim"() {dimension = "z"} : () -> index
-      // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], {{%.*}}
+      // CHECK: [[INDEX1:%.*]] = spv.IAdd [[ARG3]], [[WORKGROUPIDX]]
       %12 = addi %arg3, %0 : index
-      // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], {{%.*}}
+      // CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
       %13 = addi %arg4, %3 : index
       // CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
       // CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
diff --git a/test/Conversion/GPUToSPIRV/simple.mlir b/test/Conversion/GPUToSPIRV/simple.mlir
index 73c72cb..e1642ea 100644
--- a/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/test/Conversion/GPUToSPIRV/simple.mlir
@@ -2,15 +2,23 @@
 
 module attributes {gpu.container_module} {
 
-  // CHECK:       spv.module "Logical" "GLSL450" {
-  // CHECK-NEXT:    spv.globalVariable [[VAR1:@.*]] bind(0, 0) : !spv.ptr<f32, StorageBuffer>
-  // CHECK-NEXT:    spv.globalVariable [[VAR2:@.*]] bind(0, 1) : !spv.ptr<!spv.array<12 x f32>, StorageBuffer>
-  // CHECK-NEXT:    func @kernel_1
-  // CHECK-NEXT:      spv.Return
-  // CHECK:       spv.EntryPoint "GLCompute" @kernel_1, [[VAR1]], [[VAR2]]
   module @kernels attributes {gpu.kernel_module} {
+    // CHECK:       spv.module "Logical" "GLSL450" {
+    // CHECK-DAG:    spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    // CHECK-DAG:    spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
+    // CHECK:    func [[FN:@.*]]()
     func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
         attributes { gpu.kernel } {
+      // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
+      // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG0PTR:%.*]] = spv.AccessChain [[ADDRESSARG0]]{{\[}}[[CONST0]]
+      // CHECK: [[ARG0:%.*]] = spv.Load "StorageBuffer" [[ARG0PTR]]
+      // CHECK: [[ADDRESSARG1:%.*]] = spv._address_of [[VAR1]]
+      // CHECK: [[CONST1:%.*]] = spv.constant 0 : i32
+      // CHECK: [[ARG1:%.*]] = spv.AccessChain [[ADDRESSARG1]]{{\[}}[[CONST1]]
+      // CHECK-NEXT: spv.Return
+      // CHECK: spv.EntryPoint "GLCompute" [[FN]]
+      // CHECK: spv.ExecutionMode [[FN]] "LocalSize"
       return
     }
   }
@@ -23,5 +31,4 @@
         : (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
     return
   }
-
 }