Add support for SPIR-V Struct Types. Current support is limited to
supporting only Offset decorations

PiperOrigin-RevId: 256216704
diff --git a/g3doc/Dialects/SPIR-V.md b/g3doc/Dialects/SPIR-V.md
index 6101d3b..e149540 100644
--- a/g3doc/Dialects/SPIR-V.md
+++ b/g3doc/Dialects/SPIR-V.md
@@ -152,6 +152,24 @@
 !spv.rtarray<vector<4 x f32>>
 ```
 
+### Struct type
+
+This corresponds to SPIR-V [struct type][StructType]. Its syntax is
+
+``` {.ebnf}
+struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]` )?
+                (`, ` spirv-type ( ` [` integer-literal `] ` )? )* `>`
+```
+
+For Example,
+
+``` {.mlir}
+!spv.struct<f32>
+!spv.struct<f32 [0]>
+!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+!spv.struct<f32 [0], i32 [4]>
+```
+
 ## Serialization
 
 The serialization library provides two entry points, `mlir::spirv::serialize()`
@@ -168,7 +186,8 @@
 
 [SPIR-V]: https://www.khronos.org/registry/spir-v/
 [ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
+[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
 [PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
 [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
-[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
+[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
 [SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index adaa7d9..48c7cb3 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -145,10 +145,10 @@
   unsigned getKind() const;
 
   /// Return the LLVMContext in which this type was uniqued.
-  MLIRContext *getContext();
+  MLIRContext *getContext() const;
 
   /// Get the dialect this type is registered to.
-  Dialect &getDialect();
+  Dialect &getDialect() const;
 
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
diff --git a/include/mlir/SPIRV/SPIRVDialect.h b/include/mlir/SPIRV/SPIRVDialect.h
index 4272a72..66345ee 100644
--- a/include/mlir/SPIRV/SPIRVDialect.h
+++ b/include/mlir/SPIRV/SPIRVDialect.h
@@ -38,22 +38,6 @@
 
   /// Prints a type registered to this dialect.
   void printType(Type type, llvm::raw_ostream &os) const override;
-
-private:
-  /// Parses `spec` as a type and verifies it can be used in SPIR-V types.
-  Type parseAndVerifyType(StringRef spec, Location loc) const;
-
-  /// Parses `spec` as a SPIR-V array type.
-  Type parseArrayType(StringRef spec, Location loc) const;
-
-  /// Parses `spec` as a SPIR-V pointer type.
-  Type parsePointerType(StringRef spec, Location loc) const;
-
-  /// Parses `spec` as a SPIR-V run-time array type.
-  Type parseRuntimeArrayType(StringRef spec, Location loc) const;
-
-  /// Parses `spec` as a SPIR-V image type
-  Type parseImageType(StringRef spec, Location loc) const;
 };
 
 } // end namespace spirv
diff --git a/include/mlir/SPIRV/SPIRVOps.td b/include/mlir/SPIRV/SPIRVOps.td
index 3ce4f64..7951bf8 100644
--- a/include/mlir/SPIRV/SPIRVOps.td
+++ b/include/mlir/SPIRV/SPIRVOps.td
@@ -83,7 +83,7 @@
     ### Custom assembly form
 
     ``` {.ebnf}
-    memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` integer-literal
+    memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal
                     | `"NonTemporal"`
 
     load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use
@@ -118,6 +118,8 @@
       return "alignment";
     }
   }];
+
+  let opcode = 61;
 }
 
 def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
@@ -157,7 +159,7 @@
 
     ``` {.ebnf}
     store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, `
-                  (memory-access)? : spirv-element-type
+                  (`[` memory-access `]`)? `:` spirv-element-type
     ```
 
     For example:
@@ -185,6 +187,8 @@
       return "alignment";
     }
   }];
+
+  let opcode = 62;
 }
 
 def SPV_VariableOp : SPV_Op<"Variable"> {
diff --git a/include/mlir/SPIRV/SPIRVTypes.h b/include/mlir/SPIRV/SPIRVTypes.h
index 6ade59b..fbb0ce0 100644
--- a/include/mlir/SPIRV/SPIRVTypes.h
+++ b/include/mlir/SPIRV/SPIRVTypes.h
@@ -37,14 +37,16 @@
 struct ImageTypeStorage;
 struct PointerTypeStorage;
 struct RuntimeArrayTypeStorage;
+struct StructTypeStorage;
 } // namespace detail
 
 namespace TypeKind {
 enum Kind {
   Array = Type::FIRST_SPIRV_TYPE,
-  ImageType,
+  Image,
   Pointer,
   RuntimeArray,
+  Struct,
 };
 }
 
@@ -58,9 +60,9 @@
 
   static ArrayType get(Type elementType, int64_t elementCount);
 
-  Type getElementType();
+  Type getElementType() const;
 
-  int64_t getElementCount();
+  int64_t getElementCount() const;
 };
 
 // SPIR-V pointer type
@@ -73,9 +75,10 @@
 
   static PointerType get(Type pointeeType, StorageClass storageClass);
 
-  Type getPointeeType();
+  Type getPointeeType() const;
 
-  StorageClass getStorageClass();
+  StorageClass getStorageClass() const;
+  StringRef getStorageClassStr() const;
 };
 
 // SPIR-V run-time array type
@@ -89,16 +92,17 @@
 
   static RuntimeArrayType get(Type elementType);
 
-  Type getElementType();
+  Type getElementType() const;
 };
 
 // SPIR-V image type
+// TODO(ravishankarm) : Move this in alphabetical order
 class ImageType
     : public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
 public:
   using Base::Base;
 
-  static bool kindof(unsigned kind) { return kind == TypeKind::ImageType; }
+  static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
 
   static ImageType
   get(Type elementType, Dim dim,
@@ -118,16 +122,45 @@
       get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                      ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
 
-  Type getElementType();
-  Dim getDim();
-  ImageDepthInfo getDepthInfo();
-  ImageArrayedInfo getArrayedInfo();
-  ImageSamplingInfo getSamplingInfo();
-  ImageSamplerUseInfo getSamplerUseInfo();
-  ImageFormat getImageFormat();
+  Type getElementType() const;
+  Dim getDim() const;
+  ImageDepthInfo getDepthInfo() const;
+  ImageArrayedInfo getArrayedInfo() const;
+  ImageSamplingInfo getSamplingInfo() const;
+  ImageSamplerUseInfo getSamplerUseInfo() const;
+  ImageFormat getImageFormat() const;
   // TODO(ravishankarm): Add support for Access qualifier
 };
 
+// SPIR-V struct type
+class StructType
+    : public Type::TypeBase<StructType, Type, detail::StructTypeStorage> {
+
+public:
+  using Base::Base;
+
+  // Layout information used for members in a struct in SPIR-V
+  //
+  // TODO(ravishankarm) : For now this only supports the offset type, so uses
+  // uint64_t value to represent the offset, with
+  // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+  // something that can hold all the information needed for different member
+  // types
+  using LayoutInfo = uint64_t;
+
+  static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
+
+  static StructType get(ArrayRef<Type> memberTypes);
+
+  static StructType get(ArrayRef<Type> memberTypes,
+                        ArrayRef<LayoutInfo> layoutInfo);
+
+  size_t getNumMembers() const;
+  Type getMemberType(size_t) const;
+  bool hasLayout() const;
+  uint64_t getOffset(size_t) const;
+};
+
 } // end namespace spirv
 } // end namespace mlir
 
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 78bfc47..cd75176 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -27,9 +27,9 @@
 unsigned Type::getKind() const { return impl->getKind(); }
 
 /// Get the dialect this type is registered to.
-Dialect &Type::getDialect() { return impl->getDialect(); }
+Dialect &Type::getDialect() const { return impl->getDialect(); }
 
-MLIRContext *Type::getContext() { return getDialect().getContext(); }
+MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
 unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
 void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
diff --git a/lib/SPIRV/SPIRVDialect.cpp b/lib/SPIRV/SPIRVDialect.cpp
index 816d673..67dd549 100644
--- a/lib/SPIRV/SPIRVDialect.cpp
+++ b/lib/SPIRV/SPIRVDialect.cpp
@@ -1,26 +1,16 @@
 //===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===//
 //
-// Copyright 2019 The MLIR Authors.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
 //
 // This file defines the SPIR-V dialect in MLIR.
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/SPIRV/SPIRVDialect.h"
-
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
@@ -32,8 +22,6 @@
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
-#include <type_traits>
-
 using namespace mlir;
 using namespace mlir::spirv;
 
@@ -43,7 +31,7 @@
 
 SPIRVDialect::SPIRVDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType>();
+  addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
 
   addOperations<
 #define GET_OP_LIST
@@ -77,8 +65,9 @@
   return true;
 }
 
-static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc,
-                                   StringRef spec) {
+static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
+                               Location loc) {
+  spec = spec.trim();
   auto *context = dialect.getContext();
   auto type = mlir::parseType(spec.trim(), context);
   if (!type) {
@@ -116,17 +105,14 @@
   return type;
 }
 
-Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const {
-  return parseAndVerifyTypeImpl(*this, loc, spec);
-}
-
 // element-type ::= integer-type
 //                | floating-point-type
 //                | vector-type
 //                | spirv-type
 //
 // array-type ::= `!spv.array<` integer-literal `x` element-type `>`
-Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
+static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
+                           Location loc) {
   if (!spec.consume_front("array<") || !spec.consume_back(">")) {
     emitError(loc, "spv.array delimiter <...> mismatch");
     return Type();
@@ -145,20 +131,24 @@
     return Type();
   }
 
-  Type elementType = parseAndVerifyType(spec, loc);
+  Type elementType = parseAndVerifyType(dialect, spec, loc);
   if (!elementType)
     return Type();
 
   return ArrayType::get(elementType, count);
 }
 
+// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
+// methods in alphabetical order
+//
 // storage-class ::= `UniformConstant`
 //                 | `Uniform`
 //                 | `Workgroup`
 //                 | <and other storage classes...>
 //
 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
-Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const {
+static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec,
+                             Location loc) {
   if (!spec.consume_front("ptr<") || !spec.consume_back(">")) {
     emitError(loc, "spv.ptr delimiter <...> mismatch");
     return Type();
@@ -186,7 +176,7 @@
     return Type();
   }
 
-  auto pointeeType = parseAndVerifyType(ptSpec, loc);
+  auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc);
   if (!pointeeType)
     return Type();
 
@@ -194,7 +184,8 @@
 }
 
 // runtime-array-type ::= `!spv.rtarray<` element-type `>`
-Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
+static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec,
+                                  Location loc) {
   if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
     emitError(loc, "spv.rtarray delimiter <...> mismatch");
     return Type();
@@ -205,7 +196,7 @@
     return Type();
   }
 
-  Type elementType = parseAndVerifyType(spec, loc);
+  Type elementType = parseAndVerifyType(dialect, spec, loc);
   if (!elementType)
     return Type();
 
@@ -215,8 +206,8 @@
 // Specialize this function to parse each of the parameters that define an
 // ImageType
 template <typename ValTy>
-Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
-                               StringRef spec) {
+static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+                                      StringRef spec) {
   emitError(loc, "unexpected parameter while parsing '") << spec << "'";
   return llvm::None;
 }
@@ -225,15 +216,20 @@
 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
                                     StringRef spec) {
   // TODO(ravishankarm): Further verify that the element type can be sampled
-  return parseAndVerifyTypeImpl(dialect, loc, spec);
+  auto ty = parseAndVerifyType(dialect, spec, loc);
+  if (!ty) {
+    return llvm::None;
+  }
+  return ty;
 }
 
 template <>
 Optional<Dim> parseAndVerify<Dim>(SPIRVDialect const &dialect, Location loc,
                                   StringRef spec) {
   auto dim = symbolizeDim(spec);
-  if (!dim)
+  if (!dim) {
     emitError(loc, "unknown Dim in Image type: '") << spec << "'";
+  }
   return dim;
 }
 
@@ -242,8 +238,9 @@
 parseAndVerify<ImageDepthInfo>(SPIRVDialect const &dialect, Location loc,
                                StringRef spec) {
   auto depth = symbolizeImageDepthInfo(spec);
-  if (!depth)
+  if (!depth) {
     emitError(loc, "unknown ImageDepthInfo in Image type: '") << spec << "'";
+  }
   return depth;
 }
 
@@ -252,8 +249,9 @@
 parseAndVerify<ImageArrayedInfo>(SPIRVDialect const &dialect, Location loc,
                                  StringRef spec) {
   auto arrayedInfo = symbolizeImageArrayedInfo(spec);
-  if (!arrayedInfo)
+  if (!arrayedInfo) {
     emitError(loc, "unknown ImageArrayedInfo in Image type: '") << spec << "'";
+  }
   return arrayedInfo;
 }
 
@@ -262,8 +260,9 @@
 parseAndVerify<ImageSamplingInfo>(SPIRVDialect const &dialect, Location loc,
                                   StringRef spec) {
   auto samplingInfo = symbolizeImageSamplingInfo(spec);
-  if (!samplingInfo)
+  if (!samplingInfo) {
     emitError(loc, "unknown ImageSamplingInfo in Image type: '") << spec << "'";
+  }
   return samplingInfo;
 }
 
@@ -272,9 +271,10 @@
 parseAndVerify<ImageSamplerUseInfo>(SPIRVDialect const &dialect, Location loc,
                                     StringRef spec) {
   auto samplerUseInfo = symbolizeImageSamplerUseInfo(spec);
-  if (!samplerUseInfo)
+  if (!samplerUseInfo) {
     emitError(loc, "unknown ImageSamplerUseInfo in Image type: '")
         << spec << "'";
+  }
   return samplerUseInfo;
 }
 
@@ -283,11 +283,41 @@
                                                   Location loc,
                                                   StringRef spec) {
   auto format = symbolizeImageFormat(spec);
-  if (!format)
+  if (!format) {
     emitError(loc, "unknown ImageFormat in Image type: '") << spec << "'";
+  }
   return format;
 }
 
+template <>
+Optional<spirv::StructType::LayoutInfo>
+parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
+  uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
+  if (!spec.consume_front("[")) {
+    emitError(loc, "expected '[' while parsing layout specification in '")
+        << spec << "'";
+    return llvm::None;
+  }
+  if (spec.consumeInteger(10, offsetVal)) {
+    emitError(
+        loc,
+        "expected unsigned integer to specify offset of member in struct: '")
+        << spec << "'";
+    return llvm::None;
+  }
+  spec = spec.trim();
+  if (!spec.consume_front("]")) {
+    emitError(loc, "missing ']' in decorations spec: '") << spec << "'";
+    return llvm::None;
+  }
+  if (spec != "") {
+    emitError(loc, "unexpected extra tokens in layout information: '")
+        << spec << "'";
+    return llvm::None;
+  }
+  return spirv::StructType::LayoutInfo{offsetVal};
+}
+
 // Functor object to parse a comma separated list of specs. The function
 // parseAndVerify does the actual parsing and verification of individual
 // elements. This is a functor since parsing the last element of the list
@@ -350,7 +380,8 @@
 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
 //                              arrayed-info `,` sampling-info `,`
 //                              sampler-use-info `,` format `>`
-Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const {
+static Type parseImageType(SPIRVDialect const &dialect, StringRef spec,
+                           Location loc) {
   if (!spec.consume_front("image<") || !spec.consume_back(">")) {
     emitError(loc, "spv.image delimiter <...> mismatch");
     return Type();
@@ -359,7 +390,7 @@
   auto value =
       parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                               ImageSamplingInfo, ImageSamplerUseInfo,
-                              ImageFormat>{}(*this, loc, spec);
+                              ImageFormat>{}(dialect, loc, spec);
   if (!value) {
     return Type();
   }
@@ -367,15 +398,151 @@
   return ImageType::get(value.getValue());
 }
 
+// Method to parse one member of a struct (including Layout information)
+static ParseResult
+parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc,
+                   SmallVectorImpl<Type> &memberTypes,
+                   SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+  // Check for a '[' <layoutInfo> ']'
+  auto lastLSquare = spec.rfind('[');
+  auto typeSpec = spec.substr(0, lastLSquare);
+  auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("")
+                                                    : spec.substr(lastLSquare));
+  auto type = parseAndVerify<Type>(dialect, loc, typeSpec);
+  if (!type) {
+    return failure();
+  }
+  memberTypes.push_back(type.getValue());
+  if (layoutSpec.empty()) {
+    return success();
+  }
+  if (layoutInfo.size() != memberTypes.size() - 1) {
+    emitError(loc, "layout specification must be given for all members");
+    return failure();
+  }
+  auto layout =
+      parseAndVerify<StructType::LayoutInfo>(dialect, loc, layoutSpec);
+  if (!layout) {
+    return failure();
+  }
+  layoutInfo.push_back(layout.getValue());
+  return success();
+}
+
+// Helper method to record the position of the corresponding '>' for every '<'
+// encountered when parsing the string left to right. The relative position of
+// '>' w.r.t to the '<' is recorded.
+static bool
+computeMatchingRAngles(Location loc, StringRef const &spec,
+                       SmallVectorImpl<size_t> &matchingRAngleOffset) {
+  SmallVector<size_t, 4> openBrackets;
+  for (size_t i = 0, e = spec.size(); i != e; ++i) {
+    if (spec[i] == '<') {
+      openBrackets.push_back(i);
+    } else if (spec[i] == '>') {
+      if (openBrackets.empty()) {
+        emitError(loc, "unbalanced '<' in '") << spec << "'";
+        return false;
+      }
+      matchingRAngleOffset.push_back(i - openBrackets.pop_back_val());
+    }
+  }
+  return true;
+}
+
+static ParseResult
+parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc,
+                  ArrayRef<size_t> matchingRAngleOffset,
+                  SmallVectorImpl<Type> &memberTypes,
+                  SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+  // Check if the occurrence of ',' or '<' is before. If former, split using
+  // ','. If latter, split using matching '>' to get the entire type
+  // description
+  auto firstComma = spec.find(',');
+  auto firstLAngle = spec.find('<');
+  if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) {
+    return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo);
+  }
+  if (firstLAngle == StringRef::npos || firstComma < firstLAngle) {
+    // Parse the type before the ','
+    if (parseStructElement(dialect, spec.substr(0, firstComma), loc,
+                           memberTypes, layoutInfo)) {
+      return failure();
+    }
+    return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc,
+                             matchingRAngleOffset, memberTypes, layoutInfo);
+  }
+  auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle;
+  // Find the next ',' or '>'
+  auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size());
+  if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes,
+                         layoutInfo)) {
+    return failure();
+  }
+  auto rest = spec.substr(endLoc + 1).ltrim();
+  if (rest.empty()) {
+    return success();
+  }
+  if (rest.front() == ',') {
+    return parseStructHelper(
+        dialect, rest.drop_front().trim(), loc,
+        ArrayRef<size_t>(std::next(matchingRAngleOffset.begin()),
+                         matchingRAngleOffset.end()),
+        memberTypes, layoutInfo);
+  }
+  emitError(loc, "unexpected string : '") << rest << "'";
+  return failure();
+}
+
+// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)?
+//                 (`, ` spirv-type ( ` [` integer-literal `] ` )? )*
+static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
+                            Location loc) {
+  if (!spec.consume_front("struct<") || !spec.consume_back(">")) {
+    emitError(loc, "spv.struct delimiter <...> mismatch");
+    return Type();
+  }
+
+  if (spec.trim().empty()) {
+    emitError(loc, "expected SPIR-V type");
+    return Type();
+  }
+
+  SmallVector<Type, 4> memberTypes;
+  SmallVector<StructType::LayoutInfo, 4> layoutInfo;
+  SmallVector<size_t, 4> matchingRAngleOffset;
+  if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) ||
+      parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes,
+                        layoutInfo)) {
+    return Type();
+  }
+  if (layoutInfo.empty()) {
+    return StructType::get(memberTypes);
+  }
+  if (memberTypes.size() != layoutInfo.size()) {
+    emitError(loc, "layout specification must be given for all members");
+    return Type();
+  }
+  return StructType::get(memberTypes, layoutInfo);
+}
+
+// spirv-type ::= array-type
+//              | element-type
+//              | image-type
+//              | pointer-type
+//              | runtime-array-type
+//              | struct-type
 Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
   if (spec.startswith("array"))
-    return parseArrayType(spec, loc);
+    return parseArrayType(*this, spec, loc);
   if (spec.startswith("image"))
-    return parseImageType(spec, loc);
+    return parseImageType(*this, spec, loc);
   if (spec.startswith("ptr"))
-    return parsePointerType(spec, loc);
+    return parsePointerType(*this, spec, loc);
   if (spec.startswith("rtarray"))
-    return parseRuntimeArrayType(spec, loc);
+    return parseRuntimeArrayType(*this, spec, loc);
+  if (spec.startswith("struct"))
+    return parseStructType(*this, spec, loc);
 
   emitError(loc, "unknown SPIR-V type: ") << spec;
   return Type();
@@ -408,6 +575,19 @@
      << stringifyImageFormat(type.getImageFormat()) << ">";
 }
 
+static void print(StructType type, llvm::raw_ostream &os) {
+  os << "struct<";
+  std::string sep = "";
+  for (size_t i = 0, e = type.getNumMembers(); i != e; ++i) {
+    os << sep << type.getMemberType(i);
+    if (type.hasLayout()) {
+      os << " [" << type.getOffset(i) << "]";
+    }
+    sep = ", ";
+  }
+  os << ">";
+}
+
 void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
   switch (type.getKind()) {
   case TypeKind::Array:
@@ -419,9 +599,12 @@
   case TypeKind::RuntimeArray:
     print(type.cast<RuntimeArrayType>(), os);
     return;
-  case TypeKind::ImageType:
+  case TypeKind::Image:
     print(type.cast<ImageType>(), os);
     return;
+  case TypeKind::Struct:
+    print(type.cast<StructType>(), os);
+    return;
   default:
     llvm_unreachable("unhandled SPIR-V type");
   }
diff --git a/lib/SPIRV/SPIRVTypes.cpp b/lib/SPIRV/SPIRVTypes.cpp
index 23acd65..f62f3e1 100644
--- a/lib/SPIRV/SPIRVTypes.cpp
+++ b/lib/SPIRV/SPIRVTypes.cpp
@@ -56,9 +56,9 @@
                    elementCount);
 }
 
-Type ArrayType::getElementType() { return getImpl()->elementType; }
+Type ArrayType::getElementType() const { return getImpl()->elementType; }
 
-int64_t ArrayType::getElementCount() { return getImpl()->elementCount; }
+int64_t ArrayType::getElementCount() const { return getImpl()->elementCount; }
 
 //===----------------------------------------------------------------------===//
 // ImageType
@@ -216,28 +216,32 @@
 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
                           ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
                    value) {
-  return Base::get(std::get<0>(value).getContext(), TypeKind::ImageType, value);
+  return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
 }
 
-Type ImageType::getElementType() { return getImpl()->elementType; }
+Type ImageType::getElementType() const { return getImpl()->elementType; }
 
-Dim ImageType::getDim() { return getImpl()->getDim(); }
+Dim ImageType::getDim() const { return getImpl()->getDim(); }
 
-ImageDepthInfo ImageType::getDepthInfo() { return getImpl()->getDepthInfo(); }
+ImageDepthInfo ImageType::getDepthInfo() const {
+  return getImpl()->getDepthInfo();
+}
 
-ImageArrayedInfo ImageType::getArrayedInfo() {
+ImageArrayedInfo ImageType::getArrayedInfo() const {
   return getImpl()->getArrayedInfo();
 }
 
-ImageSamplingInfo ImageType::getSamplingInfo() {
+ImageSamplingInfo ImageType::getSamplingInfo() const {
   return getImpl()->getSamplingInfo();
 }
 
-ImageSamplerUseInfo ImageType::getSamplerUseInfo() {
+ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
   return getImpl()->getSamplerUseInfo();
 }
 
-ImageFormat ImageType::getImageFormat() { return getImpl()->getImageFormat(); }
+ImageFormat ImageType::getImageFormat() const {
+  return getImpl()->getImageFormat();
+}
 
 //===----------------------------------------------------------------------===//
 // PointerType
@@ -274,12 +278,16 @@
                    storageClass);
 }
 
-Type PointerType::getPointeeType() { return getImpl()->pointeeType; }
+Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
 
-StorageClass PointerType::getStorageClass() {
+StorageClass PointerType::getStorageClass() const {
   return getImpl()->getStorageClass();
 }
 
+StringRef PointerType::getStorageClassStr() const {
+  return stringifyStorageClass(getStorageClass());
+}
+
 //===----------------------------------------------------------------------===//
 // RuntimeArrayType
 //===----------------------------------------------------------------------===//
@@ -305,4 +313,88 @@
                    elementType);
 }
 
-Type RuntimeArrayType::getElementType() { return getImpl()->elementType; }
+Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
+
+//===----------------------------------------------------------------------===//
+// StructType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::StructTypeStorage : public TypeStorage {
+  StructTypeStorage(unsigned numMembers, Type const *memberTypes,
+                    StructType::LayoutInfo const *layoutInfo)
+      : TypeStorage(numMembers), memberTypes(memberTypes),
+        layoutInfo(layoutInfo) {}
+
+  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getMemberTypes(), getLayoutInfo());
+  }
+
+  static StructTypeStorage *construct(TypeStorageAllocator &allocator,
+                                      const KeyTy &key) {
+    ArrayRef<Type> keyTypes = key.first;
+
+    // Copy the member type and layout information into the bump pointer
+    auto typesList = allocator.copyInto(keyTypes).data();
+
+    const StructType::LayoutInfo *layoutInfoList = nullptr;
+    if (!key.second.empty()) {
+      ArrayRef<StructType::LayoutInfo> keyLayoutInfo = key.second;
+      assert(keyLayoutInfo.size() == keyTypes.size() &&
+             "size of layout information must be same as the size of number of "
+             "elements");
+      layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
+    }
+
+    return new (allocator.allocate<StructTypeStorage>())
+        StructTypeStorage(keyTypes.size(), typesList, layoutInfoList);
+  }
+
+  ArrayRef<Type> getMemberTypes() const {
+    return ArrayRef<Type>(memberTypes, getSubclassData());
+  }
+
+  ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+    if (layoutInfo) {
+      return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
+    }
+    return ArrayRef<StructType::LayoutInfo>(nullptr, size_t(0));
+  }
+
+  Type const *memberTypes;
+  StructType::LayoutInfo const *layoutInfo;
+};
+
+StructType StructType::get(ArrayRef<Type> memberTypes) {
+  assert(!memberTypes.empty() && "Struct needs at least one member type");
+  ArrayRef<StructType::LayoutInfo> noLayout(nullptr, size_t(0));
+  return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes,
+                   noLayout);
+}
+
+StructType StructType::get(ArrayRef<Type> memberTypes,
+                           ArrayRef<StructType::LayoutInfo> layoutInfo) {
+  assert(!memberTypes.empty() && "Struct needs at least one member type");
+  return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
+                   memberTypes, layoutInfo);
+}
+
+size_t StructType::getNumMembers() const {
+  return getImpl()->getSubclassData();
+}
+
+Type StructType::getMemberType(size_t i) const {
+  assert(
+      getNumMembers() > i &&
+      "element index is more than number of members of the SPIR-V StructType");
+  return getImpl()->memberTypes[i];
+}
+
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
+
+uint64_t StructType::getOffset(size_t i) const {
+  assert(
+      getNumMembers() > i &&
+      "element index is more than number of members of the SPIR-V StructType");
+  return getImpl()->layoutInfo[i];
+}
diff --git a/test/SPIRV/types.mlir b/test/SPIRV/types.mlir
index 857871a..ffd0fb3 100644
--- a/test/SPIRV/types.mlir
+++ b/test/SPIRV/types.mlir
@@ -200,3 +200,51 @@
 // expected-error @+1 {{expected more parameters for image type 'SamplerUnknown Unknown'}}
 func @image_parameters_nocomma_5(!spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown Unknown>) -> ()
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// StructType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @struct_type(!spv.struct<f32>)
+func @struct_type(!spv.struct<f32>) -> ()
+
+// CHECK: func @struct_type2(!spv.struct<f32 [0]>)
+func @struct_type2(!spv.struct<f32 [0]>) -> ()
+
+// CHECK: func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>)
+func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>) -> ()
+
+// CHECK: func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>)
+func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>) -> ()
+
+// CHECK: func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
+func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
+
+// CHECK: func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
+func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
+
+// -----
+
+// expected-error @+1 {{layout specification must be given for all members}}
+func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
+
+// -----
+
+// expected-error @+1 {{layout specification must be given for all members}}
+func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{cannot parse type: f32 i32}}
+func @struct_type_missing_comma1(!spv.struct<f32 i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unexpected extra tokens in layout information: ' i32'}}
+func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
+
+// -----
+
+//  expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
+func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()