[spirv] Use I32EnumAttr for enum attributes
This saves us the excessive string conversions and comparisons in
verification and transformation and scopes them only to parsing
and printing, which are meant for I/O so string conversions should
be fine.
In order to do this, changed the custom assembly format of
spv.module regarding addressing model and memory model.
PiperOrigin-RevId: 256149856
diff --git a/include/mlir/SPIRV/SPIRVBase.td b/include/mlir/SPIRV/SPIRVBase.td
index 52f8eef..9d9cafc 100644
--- a/include/mlir/SPIRV/SPIRVBase.td
+++ b/include/mlir/SPIRV/SPIRVBase.td
@@ -102,80 +102,82 @@
// Begin enum section. Generated from SPIR-V spec; DO NOT MODIFY!
-def SPV_AM_Logical : EnumAttrCase<"Logical", 0>;
-def SPV_AM_Physical32 : EnumAttrCase<"Physical32", 1>;
-def SPV_AM_Physical64 : EnumAttrCase<"Physical64", 2>;
-def SPV_AM_PhysicalStorageBuffer64EXT : EnumAttrCase<"PhysicalStorageBuffer64EXT", 5348>;
+def SPV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
+def SPV_AM_Physical32 : I32EnumAttrCase<"Physical32", 1>;
+def SPV_AM_Physical64 : I32EnumAttrCase<"Physical64", 2>;
+def SPV_AM_PhysicalStorageBuffer64EXT : I32EnumAttrCase<"PhysicalStorageBuffer64EXT", 5348>;
def SPV_AddressingModelAttr :
- EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
+ I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
SPV_AM_PhysicalStorageBuffer64EXT
]> {
+ let returnType = "::mlir::spirv::AddressingModel";
+ let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_D_1D : EnumAttrCase<"1D", 0>;
-def SPV_D_2D : EnumAttrCase<"2D", 1>;
-def SPV_D_3D : EnumAttrCase<"3D", 2>;
-def SPV_D_Cube : EnumAttrCase<"Cube", 3>;
-def SPV_D_Rect : EnumAttrCase<"Rect", 4>;
-def SPV_D_Buffer : EnumAttrCase<"Buffer", 5>;
-def SPV_D_SubpassData : EnumAttrCase<"SubpassData", 6>;
+def SPV_D_1D : I32EnumAttrCase<"1D", 0>;
+def SPV_D_2D : I32EnumAttrCase<"2D", 1>;
+def SPV_D_3D : I32EnumAttrCase<"3D", 2>;
+def SPV_D_Cube : I32EnumAttrCase<"Cube", 3>;
+def SPV_D_Rect : I32EnumAttrCase<"Rect", 4>;
+def SPV_D_Buffer : I32EnumAttrCase<"Buffer", 5>;
+def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6>;
def SPV_DimAttr :
- EnumAttr<"Dim", "valid SPIR-V Dim", [
+ I32EnumAttr<"Dim", "valid SPIR-V Dim", [
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
SPV_D_SubpassData
]> {
+ let returnType = "::mlir::spirv::Dim";
+ let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_IF_Unknown : EnumAttrCase<"Unknown", 0>;
-def SPV_IF_Rgba32f : EnumAttrCase<"Rgba32f", 1>;
-def SPV_IF_Rgba16f : EnumAttrCase<"Rgba16f", 2>;
-def SPV_IF_R32f : EnumAttrCase<"R32f", 3>;
-def SPV_IF_Rgba8 : EnumAttrCase<"Rgba8", 4>;
-def SPV_IF_Rgba8Snorm : EnumAttrCase<"Rgba8Snorm", 5>;
-def SPV_IF_Rg32f : EnumAttrCase<"Rg32f", 6>;
-def SPV_IF_Rg16f : EnumAttrCase<"Rg16f", 7>;
-def SPV_IF_R11fG11fB10f : EnumAttrCase<"R11fG11fB10f", 8>;
-def SPV_IF_R16f : EnumAttrCase<"R16f", 9>;
-def SPV_IF_Rgba16 : EnumAttrCase<"Rgba16", 10>;
-def SPV_IF_Rgb10A2 : EnumAttrCase<"Rgb10A2", 11>;
-def SPV_IF_Rg16 : EnumAttrCase<"Rg16", 12>;
-def SPV_IF_Rg8 : EnumAttrCase<"Rg8", 13>;
-def SPV_IF_R16 : EnumAttrCase<"R16", 14>;
-def SPV_IF_R8 : EnumAttrCase<"R8", 15>;
-def SPV_IF_Rgba16Snorm : EnumAttrCase<"Rgba16Snorm", 16>;
-def SPV_IF_Rg16Snorm : EnumAttrCase<"Rg16Snorm", 17>;
-def SPV_IF_Rg8Snorm : EnumAttrCase<"Rg8Snorm", 18>;
-def SPV_IF_R16Snorm : EnumAttrCase<"R16Snorm", 19>;
-def SPV_IF_R8Snorm : EnumAttrCase<"R8Snorm", 20>;
-def SPV_IF_Rgba32i : EnumAttrCase<"Rgba32i", 21>;
-def SPV_IF_Rgba16i : EnumAttrCase<"Rgba16i", 22>;
-def SPV_IF_Rgba8i : EnumAttrCase<"Rgba8i", 23>;
-def SPV_IF_R32i : EnumAttrCase<"R32i", 24>;
-def SPV_IF_Rg32i : EnumAttrCase<"Rg32i", 25>;
-def SPV_IF_Rg16i : EnumAttrCase<"Rg16i", 26>;
-def SPV_IF_Rg8i : EnumAttrCase<"Rg8i", 27>;
-def SPV_IF_R16i : EnumAttrCase<"R16i", 28>;
-def SPV_IF_R8i : EnumAttrCase<"R8i", 29>;
-def SPV_IF_Rgba32ui : EnumAttrCase<"Rgba32ui", 30>;
-def SPV_IF_Rgba16ui : EnumAttrCase<"Rgba16ui", 31>;
-def SPV_IF_Rgba8ui : EnumAttrCase<"Rgba8ui", 32>;
-def SPV_IF_R32ui : EnumAttrCase<"R32ui", 33>;
-def SPV_IF_Rgb10a2ui : EnumAttrCase<"Rgb10a2ui", 34>;
-def SPV_IF_Rg32ui : EnumAttrCase<"Rg32ui", 35>;
-def SPV_IF_Rg16ui : EnumAttrCase<"Rg16ui", 36>;
-def SPV_IF_Rg8ui : EnumAttrCase<"Rg8ui", 37>;
-def SPV_IF_R16ui : EnumAttrCase<"R16ui", 38>;
-def SPV_IF_R8ui : EnumAttrCase<"R8ui", 39>;
+def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
+def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1>;
+def SPV_IF_Rgba16f : I32EnumAttrCase<"Rgba16f", 2>;
+def SPV_IF_R32f : I32EnumAttrCase<"R32f", 3>;
+def SPV_IF_Rgba8 : I32EnumAttrCase<"Rgba8", 4>;
+def SPV_IF_Rgba8Snorm : I32EnumAttrCase<"Rgba8Snorm", 5>;
+def SPV_IF_Rg32f : I32EnumAttrCase<"Rg32f", 6>;
+def SPV_IF_Rg16f : I32EnumAttrCase<"Rg16f", 7>;
+def SPV_IF_R11fG11fB10f : I32EnumAttrCase<"R11fG11fB10f", 8>;
+def SPV_IF_R16f : I32EnumAttrCase<"R16f", 9>;
+def SPV_IF_Rgba16 : I32EnumAttrCase<"Rgba16", 10>;
+def SPV_IF_Rgb10A2 : I32EnumAttrCase<"Rgb10A2", 11>;
+def SPV_IF_Rg16 : I32EnumAttrCase<"Rg16", 12>;
+def SPV_IF_Rg8 : I32EnumAttrCase<"Rg8", 13>;
+def SPV_IF_R16 : I32EnumAttrCase<"R16", 14>;
+def SPV_IF_R8 : I32EnumAttrCase<"R8", 15>;
+def SPV_IF_Rgba16Snorm : I32EnumAttrCase<"Rgba16Snorm", 16>;
+def SPV_IF_Rg16Snorm : I32EnumAttrCase<"Rg16Snorm", 17>;
+def SPV_IF_Rg8Snorm : I32EnumAttrCase<"Rg8Snorm", 18>;
+def SPV_IF_R16Snorm : I32EnumAttrCase<"R16Snorm", 19>;
+def SPV_IF_R8Snorm : I32EnumAttrCase<"R8Snorm", 20>;
+def SPV_IF_Rgba32i : I32EnumAttrCase<"Rgba32i", 21>;
+def SPV_IF_Rgba16i : I32EnumAttrCase<"Rgba16i", 22>;
+def SPV_IF_Rgba8i : I32EnumAttrCase<"Rgba8i", 23>;
+def SPV_IF_R32i : I32EnumAttrCase<"R32i", 24>;
+def SPV_IF_Rg32i : I32EnumAttrCase<"Rg32i", 25>;
+def SPV_IF_Rg16i : I32EnumAttrCase<"Rg16i", 26>;
+def SPV_IF_Rg8i : I32EnumAttrCase<"Rg8i", 27>;
+def SPV_IF_R16i : I32EnumAttrCase<"R16i", 28>;
+def SPV_IF_R8i : I32EnumAttrCase<"R8i", 29>;
+def SPV_IF_Rgba32ui : I32EnumAttrCase<"Rgba32ui", 30>;
+def SPV_IF_Rgba16ui : I32EnumAttrCase<"Rgba16ui", 31>;
+def SPV_IF_Rgba8ui : I32EnumAttrCase<"Rgba8ui", 32>;
+def SPV_IF_R32ui : I32EnumAttrCase<"R32ui", 33>;
+def SPV_IF_Rgb10a2ui : I32EnumAttrCase<"Rgb10a2ui", 34>;
+def SPV_IF_Rg32ui : I32EnumAttrCase<"Rg32ui", 35>;
+def SPV_IF_Rg16ui : I32EnumAttrCase<"Rg16ui", 36>;
+def SPV_IF_Rg8ui : I32EnumAttrCase<"Rg8ui", 37>;
+def SPV_IF_R16ui : I32EnumAttrCase<"R16ui", 38>;
+def SPV_IF_R8ui : I32EnumAttrCase<"R8ui", 39>;
def SPV_ImageFormatAttr :
- EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
+ I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
@@ -186,64 +188,67 @@
SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui,
SPV_IF_R8ui
]> {
+ let returnType = "::mlir::spirv::ImageFormat";
+ let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_MA_None : EnumAttrCase<"None", 0x0000>;
-def SPV_MA_Volatile : EnumAttrCase<"Volatile", 0x0001>;
-def SPV_MA_Aligned : EnumAttrCase<"Aligned", 0x0002>;
-def SPV_MA_Nontemporal : EnumAttrCase<"Nontemporal", 0x0004>;
-def SPV_MA_MakePointerAvailableKHR : EnumAttrCase<"MakePointerAvailableKHR", 0x0008>;
-def SPV_MA_MakePointerVisibleKHR : EnumAttrCase<"MakePointerVisibleKHR", 0x0010>;
-def SPV_MA_NonPrivatePointerKHR : EnumAttrCase<"NonPrivatePointerKHR", 0x0020>;
+def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
+def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
+def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;
+def SPV_MA_Nontemporal : I32EnumAttrCase<"Nontemporal", 0x0004>;
+def SPV_MA_MakePointerAvailableKHR : I32EnumAttrCase<"MakePointerAvailableKHR", 0x0008>;
+def SPV_MA_MakePointerVisibleKHR : I32EnumAttrCase<"MakePointerVisibleKHR", 0x0010>;
+def SPV_MA_NonPrivatePointerKHR : I32EnumAttrCase<"NonPrivatePointerKHR", 0x0020>;
def SPV_MemoryAccessAttr :
- EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
+ I32EnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
SPV_MA_MakePointerAvailableKHR, SPV_MA_MakePointerVisibleKHR,
SPV_MA_NonPrivatePointerKHR
]> {
+ let returnType = "::mlir::spirv::MemoryAccess";
+ let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_MM_Simple : EnumAttrCase<"Simple", 0>;
-def SPV_MM_GLSL450 : EnumAttrCase<"GLSL450", 1>;
-def SPV_MM_OpenCL : EnumAttrCase<"OpenCL", 2>;
-def SPV_MM_VulkanKHR : EnumAttrCase<"VulkanKHR", 3>;
+def SPV_MM_Simple : I32EnumAttrCase<"Simple", 0>;
+def SPV_MM_GLSL450 : I32EnumAttrCase<"GLSL450", 1>;
+def SPV_MM_OpenCL : I32EnumAttrCase<"OpenCL", 2>;
+def SPV_MM_VulkanKHR : I32EnumAttrCase<"VulkanKHR", 3>;
def SPV_MemoryModelAttr :
- EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
+ I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_VulkanKHR
]> {
+ let returnType = "::mlir::spirv::MemoryModel";
+ let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_SC_UniformConstant : EnumAttrCase<"UniformConstant", 0>;
-def SPV_SC_Input : EnumAttrCase<"Input", 1>;
-def SPV_SC_Uniform : EnumAttrCase<"Uniform", 2>;
-def SPV_SC_Output : EnumAttrCase<"Output", 3>;
-def SPV_SC_Workgroup : EnumAttrCase<"Workgroup", 4>;
-def SPV_SC_CrossWorkgroup : EnumAttrCase<"CrossWorkgroup", 5>;
-def SPV_SC_Private : EnumAttrCase<"Private", 6>;
-def SPV_SC_Function : EnumAttrCase<"Function", 7>;
-def SPV_SC_Generic : EnumAttrCase<"Generic", 8>;
-def SPV_SC_PushConstant : EnumAttrCase<"PushConstant", 9>;
-def SPV_SC_AtomicCounter : EnumAttrCase<"AtomicCounter", 10>;
-def SPV_SC_Image : EnumAttrCase<"Image", 11>;
-def SPV_SC_StorageBuffer : EnumAttrCase<"StorageBuffer", 12>;
-def SPV_SC_CallableDataNV : EnumAttrCase<"CallableDataNV", 5328>;
-def SPV_SC_IncomingCallableDataNV : EnumAttrCase<"IncomingCallableDataNV", 5329>;
-def SPV_SC_RayPayloadNV : EnumAttrCase<"RayPayloadNV", 5338>;
-def SPV_SC_HitAttributeNV : EnumAttrCase<"HitAttributeNV", 5339>;
-def SPV_SC_IncomingRayPayloadNV : EnumAttrCase<"IncomingRayPayloadNV", 5342>;
-def SPV_SC_ShaderRecordBufferNV : EnumAttrCase<"ShaderRecordBufferNV", 5343>;
-def SPV_SC_PhysicalStorageBufferEXT : EnumAttrCase<"PhysicalStorageBufferEXT", 5349>;
+def SPV_SC_UniformConstant : I32EnumAttrCase<"UniformConstant", 0>;
+def SPV_SC_Input : I32EnumAttrCase<"Input", 1>;
+def SPV_SC_Uniform : I32EnumAttrCase<"Uniform", 2>;
+def SPV_SC_Output : I32EnumAttrCase<"Output", 3>;
+def SPV_SC_Workgroup : I32EnumAttrCase<"Workgroup", 4>;
+def SPV_SC_CrossWorkgroup : I32EnumAttrCase<"CrossWorkgroup", 5>;
+def SPV_SC_Private : I32EnumAttrCase<"Private", 6>;
+def SPV_SC_Function : I32EnumAttrCase<"Function", 7>;
+def SPV_SC_Generic : I32EnumAttrCase<"Generic", 8>;
+def SPV_SC_PushConstant : I32EnumAttrCase<"PushConstant", 9>;
+def SPV_SC_AtomicCounter : I32EnumAttrCase<"AtomicCounter", 10>;
+def SPV_SC_Image : I32EnumAttrCase<"Image", 11>;
+def SPV_SC_StorageBuffer : I32EnumAttrCase<"StorageBuffer", 12>;
+def SPV_SC_CallableDataNV : I32EnumAttrCase<"CallableDataNV", 5328>;
+def SPV_SC_IncomingCallableDataNV : I32EnumAttrCase<"IncomingCallableDataNV", 5329>;
+def SPV_SC_RayPayloadNV : I32EnumAttrCase<"RayPayloadNV", 5338>;
+def SPV_SC_HitAttributeNV : I32EnumAttrCase<"HitAttributeNV", 5339>;
+def SPV_SC_IncomingRayPayloadNV : I32EnumAttrCase<"IncomingRayPayloadNV", 5342>;
+def SPV_SC_ShaderRecordBufferNV : I32EnumAttrCase<"ShaderRecordBufferNV", 5343>;
+def SPV_SC_PhysicalStorageBufferEXT : I32EnumAttrCase<"PhysicalStorageBufferEXT", 5349>;
def SPV_StorageClassAttr :
- EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
+ I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
@@ -251,54 +256,51 @@
SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV,
SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBufferEXT
]> {
+ let returnType = "::mlir::spirv::StorageClass";
+ let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
// End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
// Enums added manually that are not part of SPIRV spec
-def SPV_IDI_NoDepth : EnumAttrCase<"NoDepth", 0>;
-def SPV_IDI_IsDepth : EnumAttrCase<"IsDepth", 1>;
-def SPV_IDI_DepthUnknown : EnumAttrCase<"DepthUnknown", 2>;
+def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>;
+def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>;
+def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
def SPV_DepthAttr :
- EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",[
- SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> {
+ I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
+ [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> {
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_IAI_NonArrayed : EnumAttrCase<"NonArrayed", 0>;
-def SPV_IAI_Arrayed : EnumAttrCase<"Arrayed", 1>;
+def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
+def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>;
def SPV_ArrayedAttr :
- EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification", [
- SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> {
+ I32EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
+ [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> {
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_ISI_SingleSampled : EnumAttrCase<"SingleSampled", 0>;
-def SPV_ISI_MultiSampled : EnumAttrCase<"MultiSampled", 1>;
+def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
+def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>;
def SPV_SamplingAttr:
- EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification", [
- SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> {
+ I32EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
+ [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> {
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
-def SPV_ISUI_SamplerUnknown : EnumAttrCase<"SamplerUnknown", 0>;
-def SPV_ISUI_NeedSampler : EnumAttrCase<"NeedSampler", 1>;
-def SPV_ISUI_NoSampler : EnumAttrCase<"NoSampler", 2>;
+def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
+def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>;
+def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>;
def SPV_SamplerUseAttr:
- EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification", [
- SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> {
+ I32EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
+ [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> {
let cppNamespace = "::mlir::spirv";
- let underlyingType = "uint32_t";
}
//===----------------------------------------------------------------------===//
diff --git a/include/mlir/SPIRV/SPIRVOps.h b/include/mlir/SPIRV/SPIRVOps.h
index 00f8548..e0345fb 100644
--- a/include/mlir/SPIRV/SPIRVOps.h
+++ b/include/mlir/SPIRV/SPIRVOps.h
@@ -23,6 +23,7 @@
#define MLIR_SPIRV_SPIRVOPS_H_
#include "mlir/IR/OpDefinition.h"
+#include "mlir/SPIRV/SPIRVTypes.h"
namespace mlir {
namespace spirv {
diff --git a/include/mlir/SPIRV/SPIRVStructureOps.td b/include/mlir/SPIRV/SPIRVStructureOps.td
index fed417f..41b8238 100644
--- a/include/mlir/SPIRV/SPIRVStructureOps.td
+++ b/include/mlir/SPIRV/SPIRVStructureOps.td
@@ -48,6 +48,31 @@
This op has only one region, which only contains one block. The block
must be terminated via the `spv._module_end` op.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ addressing-model ::= `"Logical"` | `"Physical32"` | `"Physical64"`
+ memory-model ::= `"Simple"` | `"GLSL450"` | `"OpenCL"` | `"VulkanKHR"`
+ spv-module-op ::= `spv.module` addressing-model memory-model
+ region
+ (`attributes` attribute-dict)?
+ ```
+
+ For example:
+
+ ```
+ spv.module "Logical" "VulkanKHR" { }
+
+ spv.module "Logical" "VulkanKHR" {
+ func @do_nothing() -> () {
+ spv.Return
+ }
+ } attributes {
+ capability = ["Shader"],
+ extension = ["SPV_KHR_16bit_storage"]
+ }
+ ```
}];
let arguments = (ins
diff --git a/include/mlir/SPIRV/SPIRVTypes.h b/include/mlir/SPIRV/SPIRVTypes.h
index 80370e8..6ade59b 100644
--- a/include/mlir/SPIRV/SPIRVTypes.h
+++ b/include/mlir/SPIRV/SPIRVTypes.h
@@ -76,7 +76,6 @@
Type getPointeeType();
StorageClass getStorageClass();
- StringRef getStorageClassStr();
};
// SPIR-V run-time array type
diff --git a/lib/SPIRV/SPIRVOps.cpp b/lib/SPIRV/SPIRVOps.cpp
index 1fd27c4..dbc08fe 100644
--- a/lib/SPIRV/SPIRVOps.cpp
+++ b/lib/SPIRV/SPIRVOps.cpp
@@ -28,8 +28,11 @@
using namespace mlir;
+// TODO(antiagainst): generate these strings using ODS.
+static constexpr const char kAddressingModelAttrName[] = "addressing_model";
static constexpr const char kBindingAttrName[] = "binding";
static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
+static constexpr const char kMemoryModelAttrName[] = "memory_model";
static constexpr const char kStorageClassAttrName[] = "storage_class";
static constexpr const char kValueAttrName[] = "value";
@@ -37,6 +40,15 @@
// Common utility functions
//===----------------------------------------------------------------------===//
+template <typename Dst, typename Src>
+inline Dst bitwiseCast(Src source) noexcept {
+ Dst dest;
+ static_assert(sizeof(source) == sizeof(dest),
+ "bitwiseCast requires same source and destination bitwidth");
+ std::memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
static ParseResult parseStorageClassAttribute(spirv::StorageClass &storageClass,
OpAsmParser *parser,
OperationState *state) {
@@ -54,7 +66,7 @@
auto storageClassOptional = spirv::symbolizeStorageClass(
storageClassAttr.cast<StringAttr>().getValue());
if (!storageClassOptional) {
- return parser->emitError(loc, "invalid storage class specifier :")
+ return parser->emitError(loc, "invalid storage class specifier: ")
<< storageClassAttr;
}
storageClass = storageClassOptional.getValue();
@@ -70,13 +82,13 @@
return success();
}
+ StringRef memAccessAttrName = LoadStoreOpTy::getMemoryAccessAttrName();
Attribute memAccessAttr;
+ SmallVector<NamedAttribute, 1> attrs;
auto loc = parser->getCurrentLocation();
- if (parser->parseAttribute(memAccessAttr,
- LoadStoreOpTy::getMemoryAccessAttrName(),
- state->attributes)) {
+
+ if (parser->parseAttribute(memAccessAttr, memAccessAttrName, attrs))
return failure();
- }
// Check that this is a memory attribute
if (!memAccessAttr.isa<StringAttr>()) {
return parser->emitError(loc, "expected a string memory access specifier");
@@ -84,9 +96,12 @@
auto memAccessOptional =
spirv::symbolizeMemoryAccess(memAccessAttr.cast<StringAttr>().getValue());
if (!memAccessOptional) {
- return parser->emitError(loc, "invalid memory access specifier :")
+ return parser->emitError(loc, "invalid memory access specifier: ")
<< memAccessAttr;
}
+ state->addAttribute(memAccessAttrName,
+ parser->getBuilder().getI32IntegerAttr(
+ bitwiseCast<int32_t>(*memAccessOptional)));
if (auto memAccess =
memAccessOptional.getValue() == spirv::MemoryAccess::Aligned) {
@@ -115,9 +130,9 @@
printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter *printer,
SmallVectorImpl<StringRef> &elidedAttrs) {
// Print optional memory access attribute.
- if (auto memaccess = loadStoreOp.memory_access()) {
+ if (auto memAccess = loadStoreOp.memory_access()) {
elidedAttrs.push_back(LoadStoreOpTy::getMemoryAccessAttrName());
- *printer << " [\"" << memaccess << "\"";
+ *printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
// Print integer alignment attribute.
if (auto alignment = loadStoreOp.alignment()) {
@@ -134,10 +149,10 @@
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = loadStoreOp.getOperation();
- auto memaccessAttr = op->getAttr(LoadStoreOpTy::getMemoryAccessAttrName());
- if (!memaccessAttr) {
- // Alignment attribute shouldnt be present if memory access attribute is not
- // present.
+ auto memAccessAttr = op->getAttr(LoadStoreOpTy::getMemoryAccessAttrName());
+ if (!memAccessAttr) {
+ // Alignment attribute shouldn't be present if memory access attribute is
+ // not present.
if (op->getAttr(LoadStoreOpTy::getAlignmentAttrName())) {
return loadStoreOp.emitOpError(
"invalid alignment specification without aligned memory access "
@@ -146,10 +161,15 @@
return success();
}
- if (auto memaccess =
- spirv::symbolizeMemoryAccess(
- memaccessAttr.template cast<StringAttr>().getValue()) ==
- spirv::MemoryAccess::Aligned) {
+ auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
+ auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+
+ if (!memAccess) {
+ return loadStoreOp.emitOpError("invalid memory access specifier: ")
+ << memAccessVal;
+ }
+
+ if (*memAccess == spirv::MemoryAccess::Aligned) {
if (!op->getAttr(LoadStoreOpTy::getAlignmentAttrName())) {
return loadStoreOp.emitOpError("missing alignment value");
}
@@ -284,10 +304,9 @@
static void print(spirv::LoadOp loadOp, OpAsmPrinter *printer) {
auto *op = loadOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs;
- *printer
- << spirv::LoadOp::getOperationName() << " \""
- << loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClassStr()
- << "\" ";
+ StringRef sc = stringifyStorageClass(
+ loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+ *printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" ";
// Print the pointer operand.
printer->printOperand(loadOp.ptr());
@@ -321,12 +340,56 @@
}
static ParseResult parseModuleOp(OpAsmParser *parser, OperationState *state) {
+ Builder builder = parser->getBuilder();
Region *body = state->addRegion();
- if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
- parser->parseKeyword("attributes") ||
- parser->parseOptionalAttributeDict(state->attributes))
+ Attribute addressingModel, memoryModel;
+ SmallVector<NamedAttribute, 2> attrs;
+
+ // Parse addressing model
+ auto loc = parser->getCurrentLocation();
+ if (parser->parseAttribute(addressingModel, kAddressingModelAttrName, attrs))
return failure();
+ if (!addressingModel.isa<StringAttr>()) {
+ return parser->emitError(loc,
+ "requires string for addressing model but found '")
+ << addressingModel << "'";
+ }
+ auto addrModel = spirv::symbolizeAddressingModel(
+ addressingModel.cast<StringAttr>().getValue());
+ if (!addrModel) {
+ return parser->emitError(loc, "unknown addressing model: ")
+ << addressingModel;
+ }
+ state->addAttribute(
+ kAddressingModelAttrName,
+ builder.getI32IntegerAttr(bitwiseCast<int32_t>(*addrModel)));
+
+ // Parse memory model
+ loc = parser->getCurrentLocation();
+ if (parser->parseAttribute(memoryModel, kMemoryModelAttrName, attrs))
+ return failure();
+ if (!memoryModel.isa<StringAttr>()) {
+ return parser->emitError(loc,
+ "requires string for memory model but found '")
+ << memoryModel << "'";
+ }
+ auto memModel =
+ spirv::symbolizeMemoryModel(memoryModel.cast<StringAttr>().getValue());
+ if (!memModel) {
+ return parser->emitError(loc, "unknown memory model: ") << memoryModel;
+ }
+ state->addAttribute(
+ kMemoryModelAttrName,
+ builder.getI32IntegerAttr(bitwiseCast<int32_t>(*memModel)));
+
+ if (parser->parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+
+ if (succeeded(parser->parseOptionalKeyword("attributes"))) {
+ if (parser->parseOptionalAttributeDict(state->attributes))
+ return failure();
+ }
ensureModuleEnd(body, parser->getBuilder(), state->location);
@@ -335,11 +398,33 @@
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter *printer) {
auto *op = moduleOp.getOperation();
- *printer << spirv::ModuleOp::getOperationName();
+
+ // Only print out addressing model and memory model in a nicer way if both
+ // presents. Otherwise, print them in the general form. This helps debugging
+ // ill-formed ModuleOp.
+ SmallVector<StringRef, 2> elidedAttrs;
+ if (op->getAttr(kAddressingModelAttrName) &&
+ op->getAttr(kMemoryModelAttrName)) {
+ *printer << spirv::ModuleOp::getOperationName() << " \""
+ << spirv::stringifyAddressingModel(moduleOp.addressing_model())
+ << "\" \"" << spirv::stringifyMemoryModel(moduleOp.memory_model())
+ << '"';
+ elidedAttrs.assign({kAddressingModelAttrName, kMemoryModelAttrName});
+ }
+
printer->printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
- *printer << " attributes";
- printer->printOptionalAttrDict(op->getAttrs());
+
+ bool printAttrDict = elidedAttrs.size() != 2 ||
+ llvm::any_of(op->getAttrs(), [](NamedAttribute attr) {
+ return attr.first != kAddressingModelAttrName &&
+ attr.first != kMemoryModelAttrName;
+ });
+
+ if (printAttrDict) {
+ *printer << " attributes";
+ printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+ }
}
static LogicalResult verify(spirv::ModuleOp moduleOp) {
@@ -419,12 +504,9 @@
static void print(spirv::StoreOp storeOp, OpAsmPrinter *printer) {
auto *op = storeOp.getOperation();
SmallVector<StringRef, 4> elidedAttrs;
- *printer << spirv::StoreOp::getOperationName() << " \""
- << storeOp.ptr()
- ->getType()
- .cast<spirv::PointerType>()
- .getStorageClassStr()
- << "\" ";
+ StringRef sc = stringifyStorageClass(
+ storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
+ *printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" ";
// Print the pointer operand
printer->printOperand(storeOp.ptr());
*printer << ", ";
@@ -501,9 +583,8 @@
state->addOperands(init);
}
- // TODO(antiagainst): The enum attribute should be integer backed so we don't
- // have these excessive string conversions.
- auto attr = parser->getBuilder().getStringAttr(ptrType.getStorageClassStr());
+ auto attr = parser->getBuilder().getI32IntegerAttr(
+ bitwiseCast<int32_t>(ptrType.getStorageClass()));
state->addAttribute(kStorageClassAttrName, attr);
return success();
@@ -538,11 +619,11 @@
// SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
// object. It cannot be Generic. It must be the same as the Storage Class
// operand of the Result Type."
- if (varOp.storage_class() == "Generic")
+ if (varOp.storage_class() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
- if (varOp.storage_class() != pointerType.getStorageClassStr())
+ if (varOp.storage_class() != pointerType.getStorageClass())
return varOp.emitOpError(
"storage class must match result pointer's storage class");
diff --git a/lib/SPIRV/SPIRVTypes.cpp b/lib/SPIRV/SPIRVTypes.cpp
index 6248a16..23acd65 100644
--- a/lib/SPIRV/SPIRVTypes.cpp
+++ b/lib/SPIRV/SPIRVTypes.cpp
@@ -280,10 +280,6 @@
return getImpl()->getStorageClass();
}
-StringRef PointerType::getStorageClassStr() {
- return stringifyStorageClass(getStorageClass());
-}
-
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
diff --git a/lib/SPIRV/Serialization/Deserializer.cpp b/lib/SPIRV/Serialization/Deserializer.cpp
index e8ad6fd..e7eaa99 100644
--- a/lib/SPIRV/Serialization/Deserializer.cpp
+++ b/lib/SPIRV/Serialization/Deserializer.cpp
@@ -31,6 +31,15 @@
using namespace mlir;
+template <typename Dst, typename Src>
+inline Dst bitwiseCast(Src source) noexcept {
+ Dst dest;
+ static_assert(sizeof(source) == sizeof(dest),
+ "bitwiseCast requires same source and destination bitwidth");
+ std::memcpy(&dest, &source, sizeof(dest));
+ return dest;
+}
+
namespace {
/// A SPIR-V module serializer.
///
@@ -152,22 +161,11 @@
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
- // TODO(antiagainst): use IntegerAttr-backed enum attributes to avoid the
- // excessive string conversions here.
-
- auto am = spirv::symbolizeAddressingModel(operands.front());
- if (!am)
- return emitError(unknownLoc, "unknown addressing model for OpMemoryModel");
-
- auto mm = spirv::symbolizeMemoryModel(operands.back());
- if (!mm)
- return emitError(unknownLoc, "unknown memory model for OpMemoryModel");
-
module->setAttr(
"addressing_model",
- opBuilder.getStringAttr(spirv::stringifyAddressingModel(*am)));
- module->setAttr("memory_model",
- opBuilder.getStringAttr(spirv::stringifyMemoryModel(*mm)));
+ opBuilder.getI32IntegerAttr(bitwiseCast<int32_t>(operands.front())));
+ module->setAttr("memory_model", opBuilder.getI32IntegerAttr(
+ bitwiseCast<int32_t>(operands.back())));
return success();
}
diff --git a/lib/SPIRV/Serialization/Serializer.cpp b/lib/SPIRV/Serialization/Serializer.cpp
index 93bd464..74d671f 100644
--- a/lib/SPIRV/Serialization/Serializer.cpp
+++ b/lib/SPIRV/Serialization/Serializer.cpp
@@ -142,12 +142,8 @@
}
void Serializer::processMemoryModel() {
- // TODO(antiagainst): use IntegerAttr-backed enum attributes to avoid the
- // excessive string conversions here.
- auto mm = static_cast<uint32_t>(*spirv::symbolizeMemoryModel(
- module.getAttrOfType<StringAttr>("memory_model").getValue()));
- auto am = static_cast<uint32_t>(*spirv::symbolizeAddressingModel(
- module.getAttrOfType<StringAttr>("addressing_model").getValue()));
+ uint32_t mm = module.getAttrOfType<IntegerAttr>("memory_model").getInt();
+ uint32_t am = module.getAttrOfType<IntegerAttr>("addressing_model").getInt();
constexpr uint32_t kNumWords = 3;
diff --git a/test/SPIRV/Serialization/minimal-module.mlir b/test/SPIRV/Serialization/minimal-module.mlir
index 4593c76..c74ba86 100644
--- a/test/SPIRV/Serialization/minimal-module.mlir
+++ b/test/SPIRV/Serialization/minimal-module.mlir
@@ -1,13 +1,9 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
-// CHECK: spv.module {
-// CHECK-NEXT: } attributes {addressing_model = "Logical", major_version = 1 : i32, memory_model = "VulkanKHR", minor_version = 0 : i32}
+// CHECK: spv.module "Logical" "VulkanKHR" {
+// CHECK-NEXT: } attributes {major_version = 1 : i32, minor_version = 0 : i32}
func @spirv_module() -> () {
- spv.module {
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
- }
+ spv.module "Logical" "VulkanKHR" { }
return
}
diff --git a/test/SPIRV/ops.mlir b/test/SPIRV/ops.mlir
index 407ce4f..a3a55da 100644
--- a/test/SPIRV/ops.mlir
+++ b/test/SPIRV/ops.mlir
@@ -144,6 +144,15 @@
// -----
+func @load_unknown_memory_access() -> () {
+ %0 = spv.Variable : !spv.ptr<f32, Function>
+ // expected-error @+1 {{invalid memory access specifier: "Something"}}
+ %1 = spv.Load "Function" %0 ["Something"] : f32
+ return
+}
+
+// -----
+
func @aligned_load_incorrect_attributes() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{expected ']'}}
@@ -165,14 +174,11 @@
// -----
func @return_mismatch_func_signature() -> () {
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
func @work() -> (i32) {
// expected-error @+1 {{cannot be used in functions returning value}}
spv.Return
}
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -340,7 +346,7 @@
func @storage_class_mismatch() -> () {
%0 = spv.constant 5.0 : f32
// expected-error @+1 {{storage class must match result pointer's storage class}}
- %1 = "spv.Variable"(%0) {storage_class = "Uniform"} : (f32) -> !spv.ptr<f32, Function>
+ %1 = "spv.Variable"(%0) {storage_class = 2: i32} : (f32) -> !spv.ptr<f32, Function>
return
}
diff --git a/test/SPIRV/structure-ops.mlir b/test/SPIRV/structure-ops.mlir
index bbe3be2..fd3c7ab 100644
--- a/test/SPIRV/structure-ops.mlir
+++ b/test/SPIRV/structure-ops.mlir
@@ -61,22 +61,17 @@
// CHECK-LABEL: func @module_without_cap_ext
func @module_without_cap_ext() -> () {
- // CHECK: spv.module
- spv.module { } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
- }
+ // CHECK: spv.module "Logical" "VulkanKHR"
+ spv.module "Logical" "VulkanKHR" { }
return
}
// CHECK-LABEL: func @module_with_cap_ext
func @module_with_cap_ext() -> () {
- // CHECK: spv.module
- spv.module { } attributes {
+ // CHECK: attributes {capability = ["Shader"], extension = ["SPV_KHR_16bit_storage"]}
+ spv.module "Logical" "VulkanKHR" { } attributes {
capability = ["Shader"],
- extension = ["SPV_KHR_16bit_storage"],
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
+ extension = ["SPV_KHR_16bit_storage"]
}
return
}
@@ -84,11 +79,8 @@
// CHECK-LABEL: func @module_with_explict_module_end
func @module_with_explict_module_end() -> () {
// CHECK: spv.module
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
spv._module_end
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -96,13 +88,10 @@
// CHECK-LABEL: func @module_with_func
func @module_with_func() -> () {
// CHECK: spv.module
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
spv.Return
}
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -110,24 +99,32 @@
// -----
func @missing_addressing_model() -> () {
- // expected-error@+1 {{requires attribute 'addressing_model'}}
- spv.module { } attributes {}
+ // expected-error@+1 {{requires string for addressing model}}
+ spv.module { }
return
}
// -----
func @wrong_addressing_model() -> () {
- // expected-error@+1 {{attribute 'addressing_model' failed to satisfy constraint}}
- spv.module { } attributes {addressing_model = "Physical", memory_model = "VulkanHKR"}
+ // expected-error@+1 {{unknown addressing model: "Physical"}}
+ spv.module "Physical" { }
return
}
// -----
func @missing_memory_model() -> () {
- // expected-error@+1 {{requires attribute 'memory_model'}}
- spv.module { } attributes {addressing_model = "Logical"}
+ // expected-error@+1 {{requires string for memory model}}
+ spv.module "Logical" { }
+ return
+}
+
+// -----
+
+func @wrong_memory_model() -> () {
+ // expected-error@+1 {{unknown memory model: "Bla"}}
+ spv.module "Logical" "Bla" { }
return
}
@@ -135,14 +132,11 @@
func @module_with_multiple_blocks() -> () {
// expected-error @+1 {{failed to verify constraint: region with 1 blocks}}
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
^first:
spv.Return
^second:
spv.Return
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -150,12 +144,9 @@
// -----
func @use_non_spv_op_inside_module() -> () {
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{'spv.module' can only contain func and spv.* ops}}
"dialect.op"() : () -> ()
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -163,14 +154,11 @@
// -----
func @use_non_spv_op_inside_func() -> () {
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
func @do_nothing() -> () {
// expected-error @+1 {{functions in 'spv.module' can only contain spv.* ops}}
"dialect.op"() : () -> ()
}
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -178,12 +166,9 @@
// -----
func @use_extern_func() -> () {
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
// expected-error @+1 {{'spv.module' cannot contain external functions}}
func @extern() -> ()
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
@@ -191,7 +176,7 @@
// -----
func @module_with_nested_func() -> () {
- spv.module {
+ spv.module "Logical" "VulkanKHR" {
func @outer_func() -> () {
// expected-error @+1 {{'spv.module' cannot contain nested functions}}
func @inner_func() -> () {
@@ -199,9 +184,6 @@
}
spv.Return
}
- } attributes {
- addressing_model = "Logical",
- memory_model = "VulkanKHR"
}
return
}
diff --git a/utils/spirv/gen_spirv_dialect.py b/utils/spirv/gen_spirv_dialect.py
index f76d163..88f640c 100755
--- a/utils/spirv/gen_spirv_dialect.py
+++ b/utils/spirv/gen_spirv_dialect.py
@@ -79,11 +79,11 @@
def gen_operand_kind_enum_attr(operand_kind):
- """Generates the TableGen EnumAttr definition for the given operand kind.
+ """Generates the TableGen I32EnumAttr definition for the given operand kind.
Returns:
- The operand kind's name
- - A string containing the TableGen EnumAttr definition
+ - A string containing the TableGen I32EnumAttr definition
"""
if 'enumerants' not in operand_kind:
return '', ''
@@ -96,7 +96,7 @@
# Generate the definition for each enum case
fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
- 'EnumAttrCase<"{symbol}", {value}>;'
+ 'I32EnumAttrCase<"{symbol}", {value}>;'
case_defs = [
fmt_str.format(
acronym=kind_acronym,
@@ -120,9 +120,11 @@
# Generate the enum attribute definition
enum_attr = 'def SPV_{name}Attr :\n '\
- 'EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\
- ' let cppNamespace = "::mlir::spirv";\n'\
- ' let underlyingType = "uint32_t";\n}}'.format(
+ 'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\
+ ' let returnType = "::mlir::spirv::{name}";\n'\
+ ' let convertFromStorage = '\
+ '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
+ ' let cppNamespace = "::mlir::spirv";\n}}'.format(
name=kind_name, cases=case_names)
return kind_name, case_defs + '\n\n' + enum_attr