Make static shape tensor container match other container types
Suffix with Of and take a list of types. Also give a better description that includes the element type information.
--
PiperOrigin-RevId: 249341159
diff --git a/include/mlir/IR/OpBase.td b/include/mlir/IR/OpBase.td
index e49e7dc..15477f5 100644
--- a/include/mlir/IR/OpBase.td
+++ b/include/mlir/IR/OpBase.td
@@ -370,12 +370,12 @@
def AnyTensor : TensorOf<[AnyType]>;
-// TODO(b/130807343) Fix description to contain element information.
-class StaticShapeTensor<Type t>
- : Type<And<[ TensorOf<[t]>.predicate, HasStaticShapePred ]>,
- "statically shaped tensor">;
+// TODO(b/130064155) Have an easy way to add another constraint to a type.
+class StaticShapeTensorOf<list<Type> allowedTypes>
+ : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
+ "statically shaped " # TensorOf<allowedTypes>.description>;
-def AnyStaticShapeTensor : StaticShapeTensor<AnyType>;
+def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;