Add support for SPIR-V Struct Types. Current support is limited to
supporting only Offset decorations
PiperOrigin-RevId: 256216704
diff --git a/g3doc/Dialects/SPIR-V.md b/g3doc/Dialects/SPIR-V.md
index 6101d3b..e149540 100644
--- a/g3doc/Dialects/SPIR-V.md
+++ b/g3doc/Dialects/SPIR-V.md
@@ -152,6 +152,24 @@
!spv.rtarray<vector<4 x f32>>
```
+### Struct type
+
+This corresponds to SPIR-V [struct type][StructType]. Its syntax is
+
+``` {.ebnf}
+struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]` )?
+ (`, ` spirv-type ( ` [` integer-literal `] ` )? )* `>`
+```
+
+For Example,
+
+``` {.mlir}
+!spv.struct<f32>
+!spv.struct<f32 [0]>
+!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
+!spv.struct<f32 [0], i32 [4]>
+```
+
## Serialization
The serialization library provides two entry points, `mlir::spirv::serialize()`
@@ -168,7 +186,8 @@
[SPIR-V]: https://www.khronos.org/registry/spir-v/
[ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
+[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
[PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
-[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
+[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
[SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index adaa7d9..48c7cb3 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -145,10 +145,10 @@
unsigned getKind() const;
/// Return the LLVMContext in which this type was uniqued.
- MLIRContext *getContext();
+ MLIRContext *getContext() const;
/// Get the dialect this type is registered to.
- Dialect &getDialect();
+ Dialect &getDialect() const;
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
diff --git a/include/mlir/SPIRV/SPIRVDialect.h b/include/mlir/SPIRV/SPIRVDialect.h
index 4272a72..66345ee 100644
--- a/include/mlir/SPIRV/SPIRVDialect.h
+++ b/include/mlir/SPIRV/SPIRVDialect.h
@@ -38,22 +38,6 @@
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
-
-private:
- /// Parses `spec` as a type and verifies it can be used in SPIR-V types.
- Type parseAndVerifyType(StringRef spec, Location loc) const;
-
- /// Parses `spec` as a SPIR-V array type.
- Type parseArrayType(StringRef spec, Location loc) const;
-
- /// Parses `spec` as a SPIR-V pointer type.
- Type parsePointerType(StringRef spec, Location loc) const;
-
- /// Parses `spec` as a SPIR-V run-time array type.
- Type parseRuntimeArrayType(StringRef spec, Location loc) const;
-
- /// Parses `spec` as a SPIR-V image type
- Type parseImageType(StringRef spec, Location loc) const;
};
} // end namespace spirv
diff --git a/include/mlir/SPIRV/SPIRVOps.td b/include/mlir/SPIRV/SPIRVOps.td
index 3ce4f64..7951bf8 100644
--- a/include/mlir/SPIRV/SPIRVOps.td
+++ b/include/mlir/SPIRV/SPIRVOps.td
@@ -83,7 +83,7 @@
### Custom assembly form
``` {.ebnf}
- memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` integer-literal
+ memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal
| `"NonTemporal"`
load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use
@@ -118,6 +118,8 @@
return "alignment";
}
}];
+
+ let opcode = 61;
}
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
@@ -157,7 +159,7 @@
``` {.ebnf}
store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, `
- (memory-access)? : spirv-element-type
+ (`[` memory-access `]`)? `:` spirv-element-type
```
For example:
@@ -185,6 +187,8 @@
return "alignment";
}
}];
+
+ let opcode = 62;
}
def SPV_VariableOp : SPV_Op<"Variable"> {
diff --git a/include/mlir/SPIRV/SPIRVTypes.h b/include/mlir/SPIRV/SPIRVTypes.h
index 6ade59b..fbb0ce0 100644
--- a/include/mlir/SPIRV/SPIRVTypes.h
+++ b/include/mlir/SPIRV/SPIRVTypes.h
@@ -37,14 +37,16 @@
struct ImageTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
+struct StructTypeStorage;
} // namespace detail
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
- ImageType,
+ Image,
Pointer,
RuntimeArray,
+ Struct,
};
}
@@ -58,9 +60,9 @@
static ArrayType get(Type elementType, int64_t elementCount);
- Type getElementType();
+ Type getElementType() const;
- int64_t getElementCount();
+ int64_t getElementCount() const;
};
// SPIR-V pointer type
@@ -73,9 +75,10 @@
static PointerType get(Type pointeeType, StorageClass storageClass);
- Type getPointeeType();
+ Type getPointeeType() const;
- StorageClass getStorageClass();
+ StorageClass getStorageClass() const;
+ StringRef getStorageClassStr() const;
};
// SPIR-V run-time array type
@@ -89,16 +92,17 @@
static RuntimeArrayType get(Type elementType);
- Type getElementType();
+ Type getElementType() const;
};
// SPIR-V image type
+// TODO(ravishankarm) : Move this in alphabetical order
class ImageType
: public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
public:
using Base::Base;
- static bool kindof(unsigned kind) { return kind == TypeKind::ImageType; }
+ static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
static ImageType
get(Type elementType, Dim dim,
@@ -118,16 +122,45 @@
get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
- Type getElementType();
- Dim getDim();
- ImageDepthInfo getDepthInfo();
- ImageArrayedInfo getArrayedInfo();
- ImageSamplingInfo getSamplingInfo();
- ImageSamplerUseInfo getSamplerUseInfo();
- ImageFormat getImageFormat();
+ Type getElementType() const;
+ Dim getDim() const;
+ ImageDepthInfo getDepthInfo() const;
+ ImageArrayedInfo getArrayedInfo() const;
+ ImageSamplingInfo getSamplingInfo() const;
+ ImageSamplerUseInfo getSamplerUseInfo() const;
+ ImageFormat getImageFormat() const;
// TODO(ravishankarm): Add support for Access qualifier
};
+// SPIR-V struct type
+class StructType
+ : public Type::TypeBase<StructType, Type, detail::StructTypeStorage> {
+
+public:
+ using Base::Base;
+
+ // Layout information used for members in a struct in SPIR-V
+ //
+ // TODO(ravishankarm) : For now this only supports the offset type, so uses
+ // uint64_t value to represent the offset, with
+ // std::numeric_limit<uint64_t>::max indicating no offset. Change this to
+ // something that can hold all the information needed for different member
+ // types
+ using LayoutInfo = uint64_t;
+
+ static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
+
+ static StructType get(ArrayRef<Type> memberTypes);
+
+ static StructType get(ArrayRef<Type> memberTypes,
+ ArrayRef<LayoutInfo> layoutInfo);
+
+ size_t getNumMembers() const;
+ Type getMemberType(size_t) const;
+ bool hasLayout() const;
+ uint64_t getOffset(size_t) const;
+};
+
} // end namespace spirv
} // end namespace mlir
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 78bfc47..cd75176 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -27,9 +27,9 @@
unsigned Type::getKind() const { return impl->getKind(); }
/// Get the dialect this type is registered to.
-Dialect &Type::getDialect() { return impl->getDialect(); }
+Dialect &Type::getDialect() const { return impl->getDialect(); }
-MLIRContext *Type::getContext() { return getDialect().getContext(); }
+MLIRContext *Type::getContext() const { return getDialect().getContext(); }
unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }
diff --git a/lib/SPIRV/SPIRVDialect.cpp b/lib/SPIRV/SPIRVDialect.cpp
index 816d673..67dd549 100644
--- a/lib/SPIRV/SPIRVDialect.cpp
+++ b/lib/SPIRV/SPIRVDialect.cpp
@@ -1,26 +1,16 @@
//===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===//
//
-// Copyright 2019 The MLIR Authors.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
+//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#include "mlir/SPIRV/SPIRVDialect.h"
-
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
@@ -32,8 +22,6 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/raw_ostream.h"
-#include <type_traits>
-
using namespace mlir;
using namespace mlir::spirv;
@@ -43,7 +31,7 @@
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType>();
+ addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
addOperations<
#define GET_OP_LIST
@@ -77,8 +65,9 @@
return true;
}
-static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc,
- StringRef spec) {
+static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
+ spec = spec.trim();
auto *context = dialect.getContext();
auto type = mlir::parseType(spec.trim(), context);
if (!type) {
@@ -116,17 +105,14 @@
return type;
}
-Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const {
- return parseAndVerifyTypeImpl(*this, loc, spec);
-}
-
// element-type ::= integer-type
// | floating-point-type
// | vector-type
// | spirv-type
//
// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
-Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
+static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
emitError(loc, "spv.array delimiter <...> mismatch");
return Type();
@@ -145,20 +131,24 @@
return Type();
}
- Type elementType = parseAndVerifyType(spec, loc);
+ Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
return ArrayType::get(elementType, count);
}
+// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
+// methods in alphabetical order
+//
// storage-class ::= `UniformConstant`
// | `Uniform`
// | `Workgroup`
// | <and other storage classes...>
//
// pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
-Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const {
+static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
if (!spec.consume_front("ptr<") || !spec.consume_back(">")) {
emitError(loc, "spv.ptr delimiter <...> mismatch");
return Type();
@@ -186,7 +176,7 @@
return Type();
}
- auto pointeeType = parseAndVerifyType(ptSpec, loc);
+ auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc);
if (!pointeeType)
return Type();
@@ -194,7 +184,8 @@
}
// runtime-array-type ::= `!spv.rtarray<` element-type `>`
-Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
+static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
emitError(loc, "spv.rtarray delimiter <...> mismatch");
return Type();
@@ -205,7 +196,7 @@
return Type();
}
- Type elementType = parseAndVerifyType(spec, loc);
+ Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
@@ -215,8 +206,8 @@
// Specialize this function to parse each of the parameters that define an
// ImageType
template <typename ValTy>
-Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
- StringRef spec) {
+static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
+ StringRef spec) {
emitError(loc, "unexpected parameter while parsing '") << spec << "'";
return llvm::None;
}
@@ -225,15 +216,20 @@
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
// TODO(ravishankarm): Further verify that the element type can be sampled
- return parseAndVerifyTypeImpl(dialect, loc, spec);
+ auto ty = parseAndVerifyType(dialect, spec, loc);
+ if (!ty) {
+ return llvm::None;
+ }
+ return ty;
}
template <>
Optional<Dim> parseAndVerify<Dim>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto dim = symbolizeDim(spec);
- if (!dim)
+ if (!dim) {
emitError(loc, "unknown Dim in Image type: '") << spec << "'";
+ }
return dim;
}
@@ -242,8 +238,9 @@
parseAndVerify<ImageDepthInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto depth = symbolizeImageDepthInfo(spec);
- if (!depth)
+ if (!depth) {
emitError(loc, "unknown ImageDepthInfo in Image type: '") << spec << "'";
+ }
return depth;
}
@@ -252,8 +249,9 @@
parseAndVerify<ImageArrayedInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto arrayedInfo = symbolizeImageArrayedInfo(spec);
- if (!arrayedInfo)
+ if (!arrayedInfo) {
emitError(loc, "unknown ImageArrayedInfo in Image type: '") << spec << "'";
+ }
return arrayedInfo;
}
@@ -262,8 +260,9 @@
parseAndVerify<ImageSamplingInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto samplingInfo = symbolizeImageSamplingInfo(spec);
- if (!samplingInfo)
+ if (!samplingInfo) {
emitError(loc, "unknown ImageSamplingInfo in Image type: '") << spec << "'";
+ }
return samplingInfo;
}
@@ -272,9 +271,10 @@
parseAndVerify<ImageSamplerUseInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto samplerUseInfo = symbolizeImageSamplerUseInfo(spec);
- if (!samplerUseInfo)
+ if (!samplerUseInfo) {
emitError(loc, "unknown ImageSamplerUseInfo in Image type: '")
<< spec << "'";
+ }
return samplerUseInfo;
}
@@ -283,11 +283,41 @@
Location loc,
StringRef spec) {
auto format = symbolizeImageFormat(spec);
- if (!format)
+ if (!format) {
emitError(loc, "unknown ImageFormat in Image type: '") << spec << "'";
+ }
return format;
}
+template <>
+Optional<spirv::StructType::LayoutInfo>
+parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
+ uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
+ if (!spec.consume_front("[")) {
+ emitError(loc, "expected '[' while parsing layout specification in '")
+ << spec << "'";
+ return llvm::None;
+ }
+ if (spec.consumeInteger(10, offsetVal)) {
+ emitError(
+ loc,
+ "expected unsigned integer to specify offset of member in struct: '")
+ << spec << "'";
+ return llvm::None;
+ }
+ spec = spec.trim();
+ if (!spec.consume_front("]")) {
+ emitError(loc, "missing ']' in decorations spec: '") << spec << "'";
+ return llvm::None;
+ }
+ if (spec != "") {
+ emitError(loc, "unexpected extra tokens in layout information: '")
+ << spec << "'";
+ return llvm::None;
+ }
+ return spirv::StructType::LayoutInfo{offsetVal};
+}
+
// Functor object to parse a comma separated list of specs. The function
// parseAndVerify does the actual parsing and verification of individual
// elements. This is a functor since parsing the last element of the list
@@ -350,7 +380,8 @@
// image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
// arrayed-info `,` sampling-info `,`
// sampler-use-info `,` format `>`
-Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const {
+static Type parseImageType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
if (!spec.consume_front("image<") || !spec.consume_back(">")) {
emitError(loc, "spv.image delimiter <...> mismatch");
return Type();
@@ -359,7 +390,7 @@
auto value =
parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo,
- ImageFormat>{}(*this, loc, spec);
+ ImageFormat>{}(dialect, loc, spec);
if (!value) {
return Type();
}
@@ -367,15 +398,151 @@
return ImageType::get(value.getValue());
}
+// Method to parse one member of a struct (including Layout information)
+static ParseResult
+parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc,
+ SmallVectorImpl<Type> &memberTypes,
+ SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+ // Check for a '[' <layoutInfo> ']'
+ auto lastLSquare = spec.rfind('[');
+ auto typeSpec = spec.substr(0, lastLSquare);
+ auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("")
+ : spec.substr(lastLSquare));
+ auto type = parseAndVerify<Type>(dialect, loc, typeSpec);
+ if (!type) {
+ return failure();
+ }
+ memberTypes.push_back(type.getValue());
+ if (layoutSpec.empty()) {
+ return success();
+ }
+ if (layoutInfo.size() != memberTypes.size() - 1) {
+ emitError(loc, "layout specification must be given for all members");
+ return failure();
+ }
+ auto layout =
+ parseAndVerify<StructType::LayoutInfo>(dialect, loc, layoutSpec);
+ if (!layout) {
+ return failure();
+ }
+ layoutInfo.push_back(layout.getValue());
+ return success();
+}
+
+// Helper method to record the position of the corresponding '>' for every '<'
+// encountered when parsing the string left to right. The relative position of
+// '>' w.r.t to the '<' is recorded.
+static bool
+computeMatchingRAngles(Location loc, StringRef const &spec,
+ SmallVectorImpl<size_t> &matchingRAngleOffset) {
+ SmallVector<size_t, 4> openBrackets;
+ for (size_t i = 0, e = spec.size(); i != e; ++i) {
+ if (spec[i] == '<') {
+ openBrackets.push_back(i);
+ } else if (spec[i] == '>') {
+ if (openBrackets.empty()) {
+ emitError(loc, "unbalanced '<' in '") << spec << "'";
+ return false;
+ }
+ matchingRAngleOffset.push_back(i - openBrackets.pop_back_val());
+ }
+ }
+ return true;
+}
+
+static ParseResult
+parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc,
+ ArrayRef<size_t> matchingRAngleOffset,
+ SmallVectorImpl<Type> &memberTypes,
+ SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
+ // Check if the occurrence of ',' or '<' is before. If former, split using
+ // ','. If latter, split using matching '>' to get the entire type
+ // description
+ auto firstComma = spec.find(',');
+ auto firstLAngle = spec.find('<');
+ if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) {
+ return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo);
+ }
+ if (firstLAngle == StringRef::npos || firstComma < firstLAngle) {
+ // Parse the type before the ','
+ if (parseStructElement(dialect, spec.substr(0, firstComma), loc,
+ memberTypes, layoutInfo)) {
+ return failure();
+ }
+ return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc,
+ matchingRAngleOffset, memberTypes, layoutInfo);
+ }
+ auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle;
+ // Find the next ',' or '>'
+ auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size());
+ if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes,
+ layoutInfo)) {
+ return failure();
+ }
+ auto rest = spec.substr(endLoc + 1).ltrim();
+ if (rest.empty()) {
+ return success();
+ }
+ if (rest.front() == ',') {
+ return parseStructHelper(
+ dialect, rest.drop_front().trim(), loc,
+ ArrayRef<size_t>(std::next(matchingRAngleOffset.begin()),
+ matchingRAngleOffset.end()),
+ memberTypes, layoutInfo);
+ }
+ emitError(loc, "unexpected string : '") << rest << "'";
+ return failure();
+}
+
+// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)?
+// (`, ` spirv-type ( ` [` integer-literal `] ` )? )*
+static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
+ Location loc) {
+ if (!spec.consume_front("struct<") || !spec.consume_back(">")) {
+ emitError(loc, "spv.struct delimiter <...> mismatch");
+ return Type();
+ }
+
+ if (spec.trim().empty()) {
+ emitError(loc, "expected SPIR-V type");
+ return Type();
+ }
+
+ SmallVector<Type, 4> memberTypes;
+ SmallVector<StructType::LayoutInfo, 4> layoutInfo;
+ SmallVector<size_t, 4> matchingRAngleOffset;
+ if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) ||
+ parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes,
+ layoutInfo)) {
+ return Type();
+ }
+ if (layoutInfo.empty()) {
+ return StructType::get(memberTypes);
+ }
+ if (memberTypes.size() != layoutInfo.size()) {
+ emitError(loc, "layout specification must be given for all members");
+ return Type();
+ }
+ return StructType::get(memberTypes, layoutInfo);
+}
+
+// spirv-type ::= array-type
+// | element-type
+// | image-type
+// | pointer-type
+// | runtime-array-type
+// | struct-type
Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
if (spec.startswith("array"))
- return parseArrayType(spec, loc);
+ return parseArrayType(*this, spec, loc);
if (spec.startswith("image"))
- return parseImageType(spec, loc);
+ return parseImageType(*this, spec, loc);
if (spec.startswith("ptr"))
- return parsePointerType(spec, loc);
+ return parsePointerType(*this, spec, loc);
if (spec.startswith("rtarray"))
- return parseRuntimeArrayType(spec, loc);
+ return parseRuntimeArrayType(*this, spec, loc);
+ if (spec.startswith("struct"))
+ return parseStructType(*this, spec, loc);
emitError(loc, "unknown SPIR-V type: ") << spec;
return Type();
@@ -408,6 +575,19 @@
<< stringifyImageFormat(type.getImageFormat()) << ">";
}
+static void print(StructType type, llvm::raw_ostream &os) {
+ os << "struct<";
+ std::string sep = "";
+ for (size_t i = 0, e = type.getNumMembers(); i != e; ++i) {
+ os << sep << type.getMemberType(i);
+ if (type.hasLayout()) {
+ os << " [" << type.getOffset(i) << "]";
+ }
+ sep = ", ";
+ }
+ os << ">";
+}
+
void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
switch (type.getKind()) {
case TypeKind::Array:
@@ -419,9 +599,12 @@
case TypeKind::RuntimeArray:
print(type.cast<RuntimeArrayType>(), os);
return;
- case TypeKind::ImageType:
+ case TypeKind::Image:
print(type.cast<ImageType>(), os);
return;
+ case TypeKind::Struct:
+ print(type.cast<StructType>(), os);
+ return;
default:
llvm_unreachable("unhandled SPIR-V type");
}
diff --git a/lib/SPIRV/SPIRVTypes.cpp b/lib/SPIRV/SPIRVTypes.cpp
index 23acd65..f62f3e1 100644
--- a/lib/SPIRV/SPIRVTypes.cpp
+++ b/lib/SPIRV/SPIRVTypes.cpp
@@ -56,9 +56,9 @@
elementCount);
}
-Type ArrayType::getElementType() { return getImpl()->elementType; }
+Type ArrayType::getElementType() const { return getImpl()->elementType; }
-int64_t ArrayType::getElementCount() { return getImpl()->elementCount; }
+int64_t ArrayType::getElementCount() const { return getImpl()->elementCount; }
//===----------------------------------------------------------------------===//
// ImageType
@@ -216,28 +216,32 @@
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
value) {
- return Base::get(std::get<0>(value).getContext(), TypeKind::ImageType, value);
+ return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
}
-Type ImageType::getElementType() { return getImpl()->elementType; }
+Type ImageType::getElementType() const { return getImpl()->elementType; }
-Dim ImageType::getDim() { return getImpl()->getDim(); }
+Dim ImageType::getDim() const { return getImpl()->getDim(); }
-ImageDepthInfo ImageType::getDepthInfo() { return getImpl()->getDepthInfo(); }
+ImageDepthInfo ImageType::getDepthInfo() const {
+ return getImpl()->getDepthInfo();
+}
-ImageArrayedInfo ImageType::getArrayedInfo() {
+ImageArrayedInfo ImageType::getArrayedInfo() const {
return getImpl()->getArrayedInfo();
}
-ImageSamplingInfo ImageType::getSamplingInfo() {
+ImageSamplingInfo ImageType::getSamplingInfo() const {
return getImpl()->getSamplingInfo();
}
-ImageSamplerUseInfo ImageType::getSamplerUseInfo() {
+ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
return getImpl()->getSamplerUseInfo();
}
-ImageFormat ImageType::getImageFormat() { return getImpl()->getImageFormat(); }
+ImageFormat ImageType::getImageFormat() const {
+ return getImpl()->getImageFormat();
+}
//===----------------------------------------------------------------------===//
// PointerType
@@ -274,12 +278,16 @@
storageClass);
}
-Type PointerType::getPointeeType() { return getImpl()->pointeeType; }
+Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
-StorageClass PointerType::getStorageClass() {
+StorageClass PointerType::getStorageClass() const {
return getImpl()->getStorageClass();
}
+StringRef PointerType::getStorageClassStr() const {
+ return stringifyStorageClass(getStorageClass());
+}
+
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
@@ -305,4 +313,88 @@
elementType);
}
-Type RuntimeArrayType::getElementType() { return getImpl()->elementType; }
+Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
+
+//===----------------------------------------------------------------------===//
+// StructType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::StructTypeStorage : public TypeStorage {
+ StructTypeStorage(unsigned numMembers, Type const *memberTypes,
+ StructType::LayoutInfo const *layoutInfo)
+ : TypeStorage(numMembers), memberTypes(memberTypes),
+ layoutInfo(layoutInfo) {}
+
+ using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>>;
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(getMemberTypes(), getLayoutInfo());
+ }
+
+ static StructTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ ArrayRef<Type> keyTypes = key.first;
+
+ // Copy the member type and layout information into the bump pointer
+ auto typesList = allocator.copyInto(keyTypes).data();
+
+ const StructType::LayoutInfo *layoutInfoList = nullptr;
+ if (!key.second.empty()) {
+ ArrayRef<StructType::LayoutInfo> keyLayoutInfo = key.second;
+ assert(keyLayoutInfo.size() == keyTypes.size() &&
+ "size of layout information must be same as the size of number of "
+ "elements");
+ layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>())
+ StructTypeStorage(keyTypes.size(), typesList, layoutInfoList);
+ }
+
+ ArrayRef<Type> getMemberTypes() const {
+ return ArrayRef<Type>(memberTypes, getSubclassData());
+ }
+
+ ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
+ if (layoutInfo) {
+ return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
+ }
+ return ArrayRef<StructType::LayoutInfo>(nullptr, size_t(0));
+ }
+
+ Type const *memberTypes;
+ StructType::LayoutInfo const *layoutInfo;
+};
+
+StructType StructType::get(ArrayRef<Type> memberTypes) {
+ assert(!memberTypes.empty() && "Struct needs at least one member type");
+ ArrayRef<StructType::LayoutInfo> noLayout(nullptr, size_t(0));
+ return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes,
+ noLayout);
+}
+
+StructType StructType::get(ArrayRef<Type> memberTypes,
+ ArrayRef<StructType::LayoutInfo> layoutInfo) {
+ assert(!memberTypes.empty() && "Struct needs at least one member type");
+ return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
+ memberTypes, layoutInfo);
+}
+
+size_t StructType::getNumMembers() const {
+ return getImpl()->getSubclassData();
+}
+
+Type StructType::getMemberType(size_t i) const {
+ assert(
+ getNumMembers() > i &&
+ "element index is more than number of members of the SPIR-V StructType");
+ return getImpl()->memberTypes[i];
+}
+
+bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
+
+uint64_t StructType::getOffset(size_t i) const {
+ assert(
+ getNumMembers() > i &&
+ "element index is more than number of members of the SPIR-V StructType");
+ return getImpl()->layoutInfo[i];
+}
diff --git a/test/SPIRV/types.mlir b/test/SPIRV/types.mlir
index 857871a..ffd0fb3 100644
--- a/test/SPIRV/types.mlir
+++ b/test/SPIRV/types.mlir
@@ -200,3 +200,51 @@
// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown Unknown'}}
func @image_parameters_nocomma_5(!spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown Unknown>) -> ()
+// -----
+
+//===----------------------------------------------------------------------===//
+// StructType
+//===----------------------------------------------------------------------===//
+
+// CHECK: func @struct_type(!spv.struct<f32>)
+func @struct_type(!spv.struct<f32>) -> ()
+
+// CHECK: func @struct_type2(!spv.struct<f32 [0]>)
+func @struct_type2(!spv.struct<f32 [0]>) -> ()
+
+// CHECK: func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>)
+func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>) -> ()
+
+// CHECK: func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>)
+func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>) -> ()
+
+// CHECK: func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
+func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
+
+// CHECK: func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
+func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
+
+// -----
+
+// expected-error @+1 {{layout specification must be given for all members}}
+func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
+
+// -----
+
+// expected-error @+1 {{layout specification must be given for all members}}
+func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{cannot parse type: f32 i32}}
+func @struct_type_missing_comma1(!spv.struct<f32 i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{unexpected extra tokens in layout information: ' i32'}}
+func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
+
+// -----
+
+// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
+func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()