| //===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This is the base operation definition file. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifdef OP_BASE |
| #else |
| #define OP_BASE |
| |
| //===----------------------------------------------------------------------===// |
| // Common utilities for defining TableGen mechanisms |
| //===----------------------------------------------------------------------===// |
| |
| // Concatenates a list of strings with a separator (default ", "). |
| // |
| // The joined string is exposed through the `result` field, e.g. |
| // `StrJoin<["a", "b", "c"], "/">.result` is "a/b/c". An empty list |
| // yields "". |
| class StrJoin<list<string> strings, string sep = ", "> { |
| // The empty list is special-cased so that !head/!tail below are only |
| // ever applied to a non-empty list. |
| string result = |
| !if(!empty(strings), "", |
| !foldl(!head(strings), !tail(strings), prev, cur, prev # sep # cur)); |
| } |
| |
| // Concatenates a list of integers into a string with a separator (default |
| // ", "). Each integer is converted with !cast<string> before joining, e.g. |
| // `StrJoinInt<[8, 16], "/">.result` is "8/16". |
| class StrJoinInt<list<int> integers, string sep = ", "> : |
| StrJoin<!foreach(i, integers, !cast<string>(i)), sep>; |
| |
| //===----------------------------------------------------------------------===// |
| // Predicate definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Base class for logical predicates. |
| // |
| // Predicates are used to compose constraints (see next section for details). |
| // There are two categories of predicates: |
| // |
| // 1. CPred: the primitive leaf predicate. |
| // 2. Compound predicate: a predicate composed from child predicates using |
| // predicate combiners ("conjunction", "disjunction", "negation", |
| // "substitution" or "concatenation"); see CombinedPred. |
| // |
| // `Pred` itself declares no fields; it only serves as the common base type. |
| class Pred; |
| |
| // A logical predicate wrapping any C++ expression. |
| // |
| // This is the basis for composing more complex predicates. It is the "atom" |
| // predicate from the perspective of TableGen and the "interface" between |
| // TableGen and C++. What is inside is already C++ code, which will be treated |
| // as opaque strings with special placeholders to be substituted. |
| // |
| // ## Special placeholders |
| // |
| // Special placeholders can be used to refer to entities in the context where |
| // this predicate is used. They serve as "hooks" to the enclosing environment. |
| // The following special placeholders are supported in constraints for an op: |
| // |
| // * `$_builder` will be replaced by a mlir::Builder instance. |
| // * `$_op` will be replaced by the current operation. |
| // * `$_self` will be replaced with the entity this predicate is attached to. |
| // E.g., `BoolAttr` is an attribute constraint that wraps a |
| // `CPred<"$_self.isa<BoolAttr>()">` (see the following sections for details). |
| // Then for `BoolAttr:$attr`, `$_self` will be replaced by `$attr`. |
| // Type constraints are a little bit special: so that each type definition |
| // reads naturally and type constraints can be attached directly to an |
| // operand/result, `$_self` will be replaced by the operand/result's type. |
| // E.g., for `F32` in `F32:$operand`, its `$_self` will be expanded as |
| // `getOperand(...)->getType()`. |
| // |
| // The stored `predExpr` is `pred` wrapped in parentheses, so the expression |
| // keeps its meaning when it is combined with other predicates. |
| class CPred<code pred> : Pred { |
| code predExpr = "(" # pred # ")"; |
| } |
| |
| // Kinds of predicate combiners. These must closely match the predicates |
| // implemented by the C++ backend (tblgen::PredCombinerKind). |
| // |
| // Each def below is selected by exactly one combiner class in this file: |
| // And, Or, Neg, SubstLeaves and Concat respectively. |
| class PredCombinerKind; |
| def PredCombinerAnd : PredCombinerKind; |
| def PredCombinerOr : PredCombinerKind; |
| def PredCombinerNot : PredCombinerKind; |
| def PredCombinerSubstLeaves : PredCombinerKind; |
| def PredCombinerConcat : PredCombinerKind; |
| |
| // A predicate that combines other predicates as defined by PredCombinerKind. |
| // Instantiated below. |
| // |
| // `kind` selects the combiner and `children` holds the operand predicates. |
| // The subclasses below fix `kind` so users only supply the children. |
| class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred { |
| PredCombinerKind kind = k; |
| list<Pred> children = c; |
| } |
| |
| // Predicate combiners |
| // |
| // Each class below only records the combiner kind and its children; the |
| // resulting C++ expression is assembled by the TableGen backend. |
| |
| // A predicate that holds if all of its children hold. Always holds for zero |
| // children. |
| class And<list<Pred> children> : CombinedPred<PredCombinerAnd, children>; |
| |
| // A predicate that holds if any of its children hold. Never holds for zero |
| // children. |
| class Or<list<Pred> children> : CombinedPred<PredCombinerOr, children>; |
| |
| // A predicate that holds if its child does not. Takes exactly one child. |
| class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>; |
| |
| // A predicate that substitutes "pat" with "repl" in predicate calls of the |
| // leaves of the predicate tree (i.e., not CombinedPred). |
| // |
| // This is plain string substitution without regular expressions or captures. |
| // New predicates with more complex logic can be introduced should the need |
| // arise. |
| // |
| // Example: ContainerType uses SubstLeaves<"$_self", <element type call>, ...> |
| // to apply an element type's predicate to the container's element type. |
| class SubstLeaves<string pat, string repl, Pred child> |
| : CombinedPred<PredCombinerSubstLeaves, [child]> { |
| string pattern = pat; |
| string replacement = repl; |
| } |
| |
| // A predicate that prepends `pre` and appends `suf` to the final predicate |
| // string composed from `child`. This is plain string concatenation and there |
| // will be no substitution happening for `pre` and `suf`. |
| // |
| // Example: TupleOf and TypedArrayAttrBase use it to wrap a child predicate |
| // inside an `llvm::all_of(..., [](...) { return <child>; })` expression. |
| class Concat<string pre, Pred child, string suf> : |
| CombinedPred<PredCombinerConcat, [child]> { |
| string prefix = pre; |
| string suffix = suf; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Constraint definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Base class for named constraints. |
| // |
| // An op's operands/attributes/results can have various requirements, e.g., |
| // having certain types, having values inside a certain range, and so on. |
| // Besides, for a graph rewrite rule, the source pattern used to match against |
| // the existing graph has conditions, like the op's operand must be of a more |
| // constrained subtype, the attribute must have a certain value, and so on. |
| // |
| // These requirements and conditions are modeled using this class. Records of |
| // this class are used to generate verification code in op verifier, and |
| // matching code in pattern matcher. |
| // |
| // Constraints are predicates with descriptive names, to facilitate inspection, |
| // provide nice error messages, etc. |
| class Constraint<Pred pred, string desc = ""> { |
| // The predicate that this constraint requires (possibly a CombinedPred). |
| Pred predicate = pred; |
| // User-readable description used in error reporting messages. If empty, a |
| // generic message will be used. |
| string description = desc; |
| } |
| |
| // Subclasses used to differentiate different constraint kinds. These are used |
| // as markers for the TableGen backend to handle different constraint kinds |
| // differently if needed. Constraints not deriving from the following subclasses |
| // are considered as uncategorized constraints. |
| |
| // Subclass for constraints on a type. Adds no fields to Constraint. |
| class TypeConstraint<Pred predicate, string description = ""> : |
| Constraint<predicate, description>; |
| |
| // Subclass for constraints on an attribute. Adds no fields to Constraint. |
| class AttrConstraint<Pred predicate, string description = ""> : |
| Constraint<predicate, description>; |
| |
| // How to use these constraint categories: |
| // |
| // * Use TypeConstraint to specify |
| // * Constraints on an op's operand/result definition |
| // * Further constraints to match an op's operand/result in source pattern |
| // |
| // * Use Attr (a subclass of AttrConstraint) for |
| // * Constraints on an op's attribute definition |
| // * Use AttrConstraint to specify |
| // * Further constraints to match an op's attribute in source pattern |
| // |
| // * Use uncategorized constraint to specify |
| // * Multi-entity constraints in rewrite rules |
| |
| //===----------------------------------------------------------------------===// |
| // Common predicates |
| //===----------------------------------------------------------------------===// |
| |
| // Whether a type is a VectorType. |
| def IsVectorTypePred : CPred<"$_self.isa<VectorType>()">; |
| |
| // Whether a type is a TensorType. |
| def IsTensorTypePred : CPred<"$_self.isa<TensorType>()">; |
| |
| // Whether a type is a MemRefType. |
| def IsMemRefTypePred : CPred<"$_self.isa<MemRefType>()">; |
| |
| // Whether a type is a ShapedType. |
| def IsShapedTypePred : CPred<"$_self.isa<ShapedType>()">; |
| |
| // For a ShapedType, verify that it has a static shape. |
| // This casts `$_self` unconditionally, so it must be combined after a check |
| // that the type is shaped (StaticShapeTensorOf puts the tensor check first). |
| def HasStaticShapePred : CPred<"$_self.cast<ShapedType>().hasStaticShape()">; |
| |
| // Whether a type is a TupleType. |
| def IsTupleTypePred : CPred<"$_self.isa<TupleType>()">; |
| |
| //===----------------------------------------------------------------------===// |
| // Dialect definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Base class for dialect definitions. Fields initialized to `?` have no |
| // default value and are expected to be set by the concrete dialect def. |
| class Dialect { |
| // The name of the dialect. Also used as the default `cppNamespace`. |
| string name = ?; |
| |
| // Short summary of the dialect. |
| string summary = ?; |
| |
| // The description of the dialect. |
| string description = ?; |
| |
| // The C++ namespace that ops of this dialect should be placed into. |
| // |
| // By default, uses the name of the dialect as the only namespace. To avoid |
| // placing in any namespace, use "". To specify nested namespaces, use "::" |
| // as the delimiter, e.g., given "A::B", ops will be placed in |
| // `namespace A { namespace B { <ops> } }`. |
| // |
| // Note that this works in conjunction with dialect C++ code. Depending on how |
| // the generated files are included into the dialect, you may want to specify |
| // a full namespace path or a partial one. |
| string cppNamespace = name; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Type definitions |
| //===----------------------------------------------------------------------===// |
| |
| // A type, carries type constraints. `condition` is the predicate a type must |
| // satisfy and `descr` is its human-readable description. |
| class Type<Pred condition, string descr = ""> : |
| TypeConstraint<condition, descr>; |
| |
| // Allows providing an alternative name (the name of the new def) and |
| // description to an existing type def. The predicate is copied from `t`; |
| // the description defaults to `t`'s. |
| class TypeAlias<Type t, string description = t.description> : |
| Type<t.predicate, description>; |
| |
| // A variadic type constraint. It expands to zero or more of the base type. This |
| // class is used for supporting variadic operands/results. An op can declare no |
| // more than one variadic operand/result, and that operand/result must be the |
| // last one in the operand/result list. |
| // |
| // Note: the predicate is currently always true (`CPred<"true">`); the base |
| // type is only recorded in `baseType` and is not checked here (see TODO). |
| class Variadic<Type type, string descr = ""> |
| // TODO(b/132908002): support variadic type conditions |
| : TypeConstraint<CPred<"true">, descr> { |
| Type baseType = type; |
| } |
| |
| // A type that can be constructed using mlir::Builder. |
| // Note that this does not "inherit" from Type because it would require |
| // duplicating Type subclasses for buildable and non-buildable cases to avoid |
| // diamond "inheritance". |
| // TODO(zinenko): we may extend this to a more general 'Buildable' trait, |
| // making some Types and some Attrs buildable. |
| class BuildableType<code builder> { |
| // The builder call to invoke (if specified) to construct the BuildableType. |
| // Format: this will be affixed to the builder. For example, I<32> sets it |
| // to "getIntegerType(32)", which TypedAttrBase emits as |
| // `$_builder.getIntegerType(32)`. |
| code builderCall = builder; |
| } |
| |
| // Any type at all. |
| def AnyType : Type<CPred<"true">, "any type">; |
| |
| // None type |
| def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">; |
| |
| // Any type from the given list. |
| // If `description` is empty, the allowed types' descriptions are joined |
| // with " or " (e.g. "32-bit integer or 64-bit integer"). |
| class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type< |
| // Satisfy any of the allowed type's condition |
| Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>, |
| !if(!eq(description, ""), |
| StrJoin<!foreach(t, allowedTypes, t.description), " or ">.result, |
| description)>; |
| |
| // Integer types. |
| // Any integer type irrespective of its width. |
| def Integer : Type<CPred<"$_self.isa<IntegerType>()">, "integer">; |
| |
| // Index type. |
| def Index : Type<CPred<"$_self.isa<IndexType>()">, "index">; |
| |
| // Integer type of a specific width. Buildable via |
| // `getIntegerType(<width>)`; the width is also exposed as `bitwidth`. |
| class I<int width> |
| : Type<CPred<"$_self.isInteger(" # width # ")">, |
| width # "-bit integer">, |
| BuildableType<"getIntegerType(" # width # ")"> { |
| int bitwidth = width; |
| } |
| |
| // Any integer type whose width is one of `widths`, e.g. IntOfWidths<[8, 16]> |
| // is described as "8/16-bit integer". |
| class IntOfWidths<list<int> widths> : |
| AnyTypeOf<!foreach(w, widths, I<w>), |
| StrJoinInt<widths, "/">.result # "-bit integer">; |
| |
| def I1 : I<1>; |
| def I8 : I<8>; |
| def I16 : I<16>; |
| def I32 : I<32>; |
| def I64 : I<64>; |
| |
| // Floating point types. |
| |
| // Any float type irrespective of its width. |
| def Float : Type<CPred<"$_self.isa<FloatType>()">, "floating-point">; |
| |
| // Float type of a specific width. Checked via `isF<width>()` and buildable |
| // via `getF<width>Type()`; the width is also exposed as `bitwidth`. |
| class F<int width> |
| : Type<CPred<"$_self.isF" # width # "()">, |
| width # "-bit float">, |
| BuildableType<"getF" # width # "Type()"> { |
| int bitwidth = width; |
| } |
| |
| // Any float type whose width is one of `widths`, e.g. FloatOfWidths<[16, 32]> |
| // is described as "16/32-bit float". |
| class FloatOfWidths<list<int> widths> : |
| AnyTypeOf<!foreach(w, widths, F<w>), |
| StrJoinInt<widths, "/">.result # "-bit float">; |
| |
| def F16 : F<16>; |
| def F32 : F<32>; |
| def F64 : F<64>; |
| |
| // bfloat16 is not an instance of F<> (it has no `bitwidth` field), so it is |
| // defined directly with its own predicate and builder call. |
| def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">, |
| BuildableType<"getBF16Type()">; |
| |
| // Function Type |
| |
| // Any function type, regardless of its inputs and results. |
| def FunctionType : Type<CPred<"$_self.isa<FunctionType>()">, "function type">; |
| |
| // A container type is a type that has another type embedded within it. |
| // |
| // The description is "<descr> of <element description> values", e.g. |
| // "tensor of 32-bit float values". |
| class ContainerType<Type etype, Pred containerPred, code elementTypeCall, |
| string descr> : |
| // First, check the container predicate. Then, substitute the extracted |
| // element into the element type checker. |
| Type<And<[containerPred, |
| SubstLeaves<"$_self", !cast<string>(elementTypeCall), |
| etype.predicate>]>, |
| descr # " of " # etype.description # " values"> { |
| // The type of elements in the container. |
| Type elementType = etype; |
| |
| // C++ expression (written in terms of `$_self`) that retrieves the element |
| // type from the container. |
| code getElementTypeCall = elementTypeCall; |
| } |
| |
| // A container whose element type is read through ShapedType and must be one |
| // of `allowedTypes`. |
| class ShapedContainerType<list<Type> allowedTypes, Pred containerPred, string descr> : |
| ContainerType<AnyTypeOf<allowedTypes>, containerPred, |
| "$_self.cast<ShapedType>().getElementType()", descr>; |
| |
| // Vector types. |
| |
| // Any vector type whose element type is from the given `allowedTypes` list. |
| class VectorOf<list<Type> allowedTypes> : |
| ShapedContainerType<allowedTypes, IsVectorTypePred, "vector">; |
| |
| def AnyVector : VectorOf<[AnyType]>; |
| |
| // Tensor types. |
| |
| // Any tensor type whose element type is from the given `allowedTypes` list |
| class TensorOf<list<Type> allowedTypes> : |
| ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor">; |
| |
| def AnyTensor : TensorOf<[AnyType]>; |
| |
| // A TensorOf<allowedTypes> that additionally requires a static shape. The |
| // tensor predicate is listed first so HasStaticShapePred's cast is only |
| // reached for tensor types. |
| // TODO(b/130064155) Have an easy way to add another constraint to a type. |
| class StaticShapeTensorOf<list<Type> allowedTypes> |
| : Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>, |
| "statically shaped " # TensorOf<allowedTypes>.description>; |
| |
| def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; |
| |
| def I1Tensor : TensorOf<[I1]>; |
| def I8Tensor : TensorOf<[I8]>; |
| def I16Tensor : TensorOf<[I16]>; |
| def I32Tensor : TensorOf<[I32]>; |
| def I64Tensor : TensorOf<[I64]>; |
| |
| def BF16Tensor : TensorOf<[BF16]>; |
| def F16Tensor : TensorOf<[F16]>; |
| def F32Tensor : TensorOf<[F32]>; |
| def F64Tensor : TensorOf<[F64]>; |
| |
| // Memref type. |
| |
| // TODO(b/132735995) Use ShapedContainerType when MemRef subclasses ShapedType. |
| // Memrefs are blocks of data with fixed type and rank. |
| // Until then, the element type is read through MemRefType directly. |
| class MemRefOf<list<Type> allowedTypes> : |
| ContainerType<AnyTypeOf<allowedTypes>, IsMemRefTypePred, |
| "$_self.cast<MemRefType>().getElementType()", "memref">; |
| |
| def AnyMemRef : MemRefOf<[AnyType]>; |
| |
| // Memref declarations handle any memref, independent of rank, size, (static or |
| // dynamic), layout, or memory space. |
| def I1MemRef : MemRefOf<[I1]>; |
| def I8MemRef : MemRefOf<[I8]>; |
| def I16MemRef : MemRefOf<[I16]>; |
| def I32MemRef : MemRefOf<[I32]>; |
| def I64MemRef : MemRefOf<[I64]>; |
| |
| def BF16MemRef : MemRefOf<[BF16]>; |
| def F16MemRef : MemRefOf<[F16]>; |
| def F32MemRef : MemRefOf<[F32]>; |
| def F64MemRef : MemRefOf<[F64]>; |
| |
| // This represents a generic tuple without any constraints on element type. |
| def AnyTuple : Type<IsTupleTypePred, "tuple">; |
| |
| // TODO(b/132952417) Make this accept a list of types like the classes above. |
| // A Tuple that only holds elements of a certain type. This cannot inherit from |
| // ContainerType because tuples do not always have a single element type that |
| // could be retrieved with elementTypeCall. |
| // |
| // The generated check is `llvm::all_of(<tuple types>, [](Type t) { return |
| // <t.predicate with $_self replaced by t>; })`. Note that the description is |
| // just "tuple" and does not mention `t`. |
| class TupleOf<Type t> : |
| Type<And<[ |
| IsTupleTypePred, |
| Concat< |
| [{ |
| llvm::all_of( |
| $_self.cast<TupleType>().getTypes(), |
| [](Type t) { |
| return |
| }], |
| SubstLeaves<"$_self", "t", t.predicate>, |
| "; })"> |
| ]>, "tuple">; |
| |
| //===----------------------------------------------------------------------===// |
| // Common type constraints |
| //===----------------------------------------------------------------------===// |
| |
| // Type constraint for integer-like types: integers, indices, vectors of |
| // integers, tensors of integers. Note: vectors/tensors of index are not |
| // included (their element constraint is `Integer`, not `Index`). |
| def IntegerLike : TypeConstraint<Or<[Integer.predicate, Index.predicate, |
| VectorOf<[Integer]>.predicate, TensorOf<[Integer]>.predicate]>, |
| "integer-like">; |
| |
| // Type constraint for float-like types: floats, vectors or tensors thereof. |
| def FloatLike : TypeConstraint<Or<[Float.predicate, |
| VectorOf<[Float]>.predicate, TensorOf<[Float]>.predicate]>, |
| "floating-point-like">; |
| |
| |
| //===----------------------------------------------------------------------===// |
| // Attribute definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Base class for all attributes. |
| // |
| // `condition` is the predicate the attribute must satisfy; `descr` is used in |
| // error messages. Fields initialized to `?` have no default and are set by |
| // subclasses/defs as needed. |
| class Attr<Pred condition, string descr = ""> : |
| AttrConstraint<condition, descr> { |
| code storageType = ?; // The backing mlir::Attribute type |
| code returnType = ?; // The underlying C++ value type |
| |
| // The call expression to convert from the storage type to the return |
| // type. For example, an enum can be stored as an int but returned as an |
| // enum class. |
| // |
| // Format: $_self will be expanded to the attribute. |
| // |
| // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will |
| // expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`. |
| code convertFromStorage = "$_self.getValue()"; |
| |
| // The call expression to build an attribute from a constant value. |
| // |
| // Format: $0 will be expanded to the constant value of the attribute. |
| // |
| // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will |
| // expand to `builder.getStringAttr("foo")`. |
| string constBuilderCall = ?; |
| |
| // Default value for attribute. |
| // Requires a constBuilderCall defined. |
| string defaultValue = ?; |
| |
| // Whether the attribute is optional. Typically requires a custom |
| // convertFromStorage method to handle the case where the attribute is |
| // not present. |
| bit isOptional = 0; |
| } |
| |
| // Decorates an attribute to have an (unvalidated) default value if not present. |
| // |
| // The predicate, description, storage/return types, conversion and builder |
| // call are copied from `attr`; only `defaultValue` is set (to `val`). |
| // NOTE(review): `isOptional` is not copied, so it keeps Attr's default of 0 |
| // even if `attr` is optional -- confirm this is intended. |
| class DefaultValuedAttr<Attr attr, string val> : |
| Attr<attr.predicate, attr.description> { |
| // Construct this attribute with the input attribute and change only |
| // the default value. |
| // Note: this has to be kept up to date with Attr above. |
| let storageType = attr.storageType; |
| let returnType = attr.returnType; |
| let convertFromStorage = attr.convertFromStorage; |
| let constBuilderCall = attr.constBuilderCall; |
| let defaultValue = val; |
| |
| // Remember `attr`'s def name. |
| // TODO(b/132458159): consider embedding Attr as a field. |
| string baseAttr = !cast<string>(attr); |
| } |
| |
| // Decorates an attribute as optional. The return type of the generated |
| // attribute accessor method will be Optional<>. |
| // |
| // The converted value is `$_self ? Optional<...>(<attr's conversion>) : |
| // (llvm::None)`. `constBuilderCall` and `defaultValue` are not carried over |
| // from `attr`. |
| class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> { |
| // Rewrite the attribute to be optional. |
| // Note: this has to be kept up to date with Attr above. |
| let storageType = attr.storageType; |
| let returnType = "Optional<" # attr.returnType #">"; |
| let convertFromStorage = "$_self ? " # returnType # "(" # |
| attr.convertFromStorage # ") : (llvm::None)"; |
| let isOptional = 1; |
| |
| // Remember `attr`'s def name. |
| // TODO(b/132458159): consider embedding Attr as a field. |
| string baseAttr = !cast<string>(attr); |
| } |
| |
| // A generic attribute that must be constructed around a specific type |
| // `attrValType`. Backed by MLIR attribute kind `attrKind`. |
| // |
| // For example, for I32Attr the constant builder call becomes |
| // `$_builder.getIntegerAttr($_builder.getIntegerType(32), $0)`. |
| class TypedAttrBase<BuildableType attrValType, string attrKind, |
| Pred condition, string descr> : |
| Attr<condition, descr> { |
| let constBuilderCall = "$_builder.get" # attrKind # "($_builder." # |
| attrValType.builderCall # ", $0)"; |
| let storageType = attrKind; |
| } |
| |
| // Any attribute. The attribute itself is both stored and returned, and the |
| // constant builder call is the constant value verbatim. |
| def AnyAttr : Attr<CPred<"true">, "any attribute"> { |
| let storageType = "Attribute"; |
| let returnType = "Attribute"; |
| let convertFromStorage = "$_self"; |
| let constBuilderCall = "$0"; |
| } |
| |
| // Boolean attribute, returned as a C++ `bool` via the default |
| // `$_self.getValue()` conversion. |
| def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> { |
| let storageType = [{ BoolAttr }]; |
| let returnType = [{ bool }]; |
| let constBuilderCall = "$_builder.getBoolAttr($0)"; |
| } |
| |
| // Base class for integer attributes of fixed width. |
| // The predicate first checks the attribute kind, then the integer width of |
| // its type; the value is returned as an APInt. |
| class IntegerAttrBase<I attrValType, string descr> : |
| TypedAttrBase<attrValType, "IntegerAttr", |
| And<[CPred<"$_self.isa<IntegerAttr>()">, |
| CPred<"$_self.cast<IntegerAttr>().getType()." |
| "isInteger(" # attrValType.bitwidth # ")">]>, |
| descr> { |
| let returnType = [{ APInt }]; |
| } |
| |
| // Integer attribute of any width. No constant builder call is defined. |
| def APIntAttr : Attr<CPred<"$_self.isa<IntegerAttr>()">, |
| "arbitrary integer attribute"> { |
| let storageType = [{ IntegerAttr }]; |
| let returnType = [{ APInt }]; |
| } |
| |
| def I32Attr : IntegerAttrBase<I32, "32-bit integer attribute">; |
| def I64Attr : IntegerAttrBase<I64, "64-bit integer attribute">; |
| |
| // Base class for float attributes of fixed width. |
| // The predicate first checks the attribute kind, then the float width of |
| // its type; the value is returned as an APFloat. |
| class FloatAttrBase<F attrValType, string descr> : |
| TypedAttrBase<attrValType, "FloatAttr", |
| And<[CPred<"$_self.isa<FloatAttr>()">, |
| CPred<"$_self.cast<FloatAttr>().getType().isF" # |
| attrValType.bitwidth # "()">]>, |
| descr> { |
| let returnType = [{ APFloat }]; |
| } |
| |
| def F32Attr : FloatAttrBase<F32, "32-bit float attribute">; |
| def F64Attr : FloatAttrBase<F64, "64-bit float attribute">; |
| |
| // An attribute backed by a string type. Stored as StringAttr and returned as |
| // StringRef; constants are built with `$_builder.getStringAttr("<value>")`. |
| class StringBasedAttr<Pred condition, string descr> : Attr<condition, descr> { |
| let constBuilderCall = "$_builder.getStringAttr(\"$0\")"; |
| let storageType = [{ StringAttr }]; |
| let returnType = [{ StringRef }]; |
| } |
| |
| def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">, |
| "string attribute">; |
| |
| // An enum attribute case. |
| // The predicate casts `$_self` to StringAttr unconditionally and compares the |
| // value with `sym`; EnumAttr checks StrAttr's predicate before the cases. |
| class EnumAttrCase<string sym> : StringBasedAttr< |
| CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">, |
| "case " # sym> { |
| // The C++ enumerant symbol |
| string symbol = sym; |
| } |
| |
| // An enum attribute. Its value can only be one from the given list of `cases`. |
| // Enum attributes are emulated via mlir::StringAttr, plus extra verification |
| // on the string: only the symbols of the allowed cases are permitted as the |
| // string value. |
| class EnumAttr<string name, string description, list<EnumAttrCase> cases> : |
| StringBasedAttr<And<[StrAttr.predicate, |
| Or<!foreach(case, cases, case.predicate)>]>, |
| description> { |
| // The C++ enum class name |
| string className = name; |
| // List of all accepted cases |
| list<EnumAttrCase> enumerants = cases; |
| } |
| |
| // Base class for elements attributes. The attribute itself is both the |
| // stored and the returned value. |
| class ElementsAttrBase<Pred condition, string description> : |
| Attr<condition, description> { |
| let storageType = [{ ElementsAttr }]; |
| let returnType = [{ ElementsAttr }]; |
| let convertFromStorage = "$_self"; |
| } |
| |
| def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">, |
| "constant vector/tensor attribute">; |
| |
| // Base class for array attributes. The ArrayAttr itself is both the stored |
| // and the returned value. |
| class ArrayAttrBase<Pred condition, string description> : |
| Attr<condition, description> { |
| let storageType = [{ ArrayAttr }]; |
| let returnType = [{ ArrayAttr }]; |
| let convertFromStorage = "$_self"; |
| } |
| |
| def ArrayAttr : ArrayAttrBase<CPred<"$_self.isa<ArrayAttr>()">, |
| "array attribute">; |
| |
| // Base class for array attributes whose elements are of the same kind. |
| // `element` specifies the element attribute kind stored in this array. |
| // |
| // The element predicate is applied to every element via |
| // `llvm::all_of($_self.cast<ArrayAttr>(), [](Attribute attr) { ... })`. |
| class TypedArrayAttrBase<Attr element, string description>: ArrayAttrBase< |
| And<[ |
| // Guarantee this is an ArrayAttr first |
| CPred<"$_self.isa<ArrayAttr>()">, |
| // Guarantee all elements satisfy the constraints from `element` |
| Concat<"llvm::all_of($_self.cast<ArrayAttr>(), " |
| "[](Attribute attr) { return ", |
| SubstLeaves<"$_self", "attr", element.predicate>, |
| "; })">]>, |
| description> { |
| let constBuilderCall = "$_builder.getArrayAttr($0)"; |
| } |
| |
| // Typed arrays below override the constant builder with the Builder's |
| // dedicated helper for that element kind. |
| def I32ArrayAttr : TypedArrayAttrBase<I32Attr, |
| "32-bit integer array attribute"> { |
| let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; |
| } |
| def I64ArrayAttr : TypedArrayAttrBase<I64Attr, |
| "64-bit integer array attribute"> { |
| let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; |
| } |
| def F32ArrayAttr : TypedArrayAttrBase<F32Attr, "32-bit float array attribute"> { |
| let constBuilderCall = "$_builder.getF32ArrayAttr($0)"; |
| } |
| def F64ArrayAttr : TypedArrayAttrBase<F64Attr, "64-bit float array attribute"> { |
| let constBuilderCall = "$_builder.getF64ArrayAttr($0)"; |
| } |
| def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> { |
| let constBuilderCall = "$_builder.getStrArrayAttr($0)"; |
| } |
| |
| // Attributes containing functions. The accessor returns a `Function *` |
| // obtained through the default `$_self.getValue()` conversion. |
| def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">, |
| "function attribute"> { |
| let storageType = [{ FunctionAttr }]; |
| let returnType = [{ Function * }]; |
| let constBuilderCall = "$_builder.getFunctionAttr($0)"; |
| } |
| |
| // Base class for attributes containing types. Example: |
| // def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> |
| // defines a type attribute containing an integer type. |
| // |
| // The predicate checks that the attribute is a TypeAttr whose type isa |
| // `retType`; the accessor returns the contained type cast to `retType`. |
| class TypeAttrBase<string retType, string description> : |
| Attr<And<[ |
| CPred<"$_self.isa<TypeAttr>()">, |
| CPred<"$_self.cast<TypeAttr>().getValue().isa<" # retType # ">()">]>, |
| description> { |
| let storageType = [{ TypeAttr }]; |
| let returnType = retType; |
| let convertFromStorage = "$_self.getValue().cast<" # retType # ">()"; |
| } |
| |
| def TypeAttr : TypeAttrBase<"Type", "any type attribute">; |
| |
| // DerivedAttr are attributes whose value is computed from properties |
| // of the operation. They do not require additional storage and are |
| // materialized as needed. |
| // |
| // `ret` is the C++ return type and `b` the C++ code that computes the value |
| // (presumably emitted as the body of the generated accessor -- TODO confirm |
| // against mlir-tblgen). The predicate is always true. |
| class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> { |
| let returnType = ret; |
| code body = b; |
| } |
| |
| // Derived attribute that returns a mlir::Type. |
| class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>; |
| |
| // Represents a constant attribute of specific Attr type. A constant |
| // attribute can be specified only for attributes that have a constant |
| // builder call defined. The constant value is specified as a string. |
| // |
| // If used as a constraint, it generates a matcher on a constant attribute by |
| // using the constant value builder of the attribute and the value. |
| // |
| // For example, ConstF32Attr<"1.0"> yields the predicate |
| // `$_self == $_builder.getFloatAttr($_builder.getF32Type(), 1.0)`. |
| class ConstantAttr<Attr attribute, string val> : AttrConstraint< |
| CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>, |
| "constant attribute " # val> { |
| Attr attr = attribute; |
| string value = val; |
| } |
| |
| // Shorthand for a constant F32Attr. |
| class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>; |
| |
| //===----------------------------------------------------------------------===// |
| // Common attribute constraints |
| //===----------------------------------------------------------------------===// |
| |
| // A general mechanism to further confine the given `attr` with all the |
| // `constraints`. This allows to compose complex constraints out of a series |
| // of more primitive ones. |
| // |
| // `attr`'s predicate is placed first in the And, so constraint predicates that |
| // cast `$_self` (e.g. IntMinValue) only see attributes that passed it. The |
| // description appends each constraint's description separated by a space, |
| // and all storage/conversion/builder/default/optional fields come from `attr`. |
| class Confined<Attr attr, list<AttrConstraint> constraints> : Attr< |
| And<!listconcat([attr.predicate], |
| !foreach(pred, constraints, pred.predicate))>, |
| !foldl(/*init*/attr.description, /*list*/constraints, |
| prev, cur, prev # " " # cur.description)> { |
| let storageType = attr.storageType; |
| let returnType = attr.returnType; |
| let convertFromStorage = attr.convertFromStorage; |
| let constBuilderCall = attr.constBuilderCall; |
| let defaultValue = attr.defaultValue; |
| let isOptional = attr.isOptional; |
| } |
| |
| // An AttrConstraint that holds if all attr constraints specified in |
| // `constraints` hold. |
| // |
| // The predicate is the And of every constraint's predicate, in list order, |
| // and the description joins their descriptions with " and ". |
| // |
| // StrJoin (which handles the empty list) and a plain !foreach are used instead |
| // of !head/!tail so that an empty `constraints` list is also accepted: it |
| // yields an always-true And<[]> with an empty description. |
| class AllAttrConstraintsOf<list<AttrConstraint> constraints> : AttrConstraint< |
| And<!foreach(pred, constraints, pred.predicate)>, |
| StrJoin<!foreach(c, constraints, c.description), " and ">.result> { |
| } |
| |
| // Attribute constraints below cast `$_self` unconditionally (to IntegerAttr |
| // or ArrayAttr), so they are meant to be applied after the base attribute's |
| // own predicate, e.g. via Confined. |
| |
| // The integer attribute's value is at least `n`. |
| class IntMinValue<int n> : AttrConstraint< |
| CPred<"$_self.cast<IntegerAttr>().getInt() >= " # n>, |
| "whose minimal value is " # n>; |
| |
| // The array attribute has at least `n` elements. |
| class ArrayMinCount<int n> : AttrConstraint< |
| CPred<"$_self.cast<ArrayAttr>().size() >= " # n>, |
| "with at least " # n # " elements">; |
| |
| // The array has an element at `index` (checked first) and that element, read |
| // as an IntegerAttr, equals `value`. |
| class IntArrayNthElemEq<int index, int value> : AttrConstraint< |
| And<[ |
| CPred<"$_self.cast<ArrayAttr>().size() > " # index>, |
| CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]" |
| ".cast<IntegerAttr>().getInt() == " # value> |
| ]>, |
| "whose " # index # "-th element must be " # value>; |
| |
| // The array has an element at `index` (checked first) and that element, read |
| // as an IntegerAttr, is at least `min`. |
| class IntArrayNthElemMinValue<int index, int min> : AttrConstraint< |
| And<[ |
| CPred<"$_self.cast<ArrayAttr>().size() > " # index>, |
| CPred<"$_self.cast<ArrayAttr>().getValue()[" # index # "]" |
| ".cast<IntegerAttr>().getInt() >= " # min> |
| ]>, |
| "whose " # index # "-th element must be at least " # min>; |
| |
| // Holds when the attribute is null, i.e. an optional attribute that is absent. |
| def IsNullAttr : AttrConstraint< |
| CPred<"!$_self">, "empty attribute (for optional attributes)">; |
| |
| //===----------------------------------------------------------------------===// |
| // OpTrait definitions |
| //===----------------------------------------------------------------------===// |
| |
| // OpTrait represents a trait regarding an op. It declares no fields; the |
| // subclasses below define how a trait is realized. |
| class OpTrait; |
| |
| // NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The |
| // purpose to wrap around C++ symbol string with this class is to make |
| // traits specified for ops in TableGen less alien and more integrated. |
| // `prop` is the name of the C++ trait, e.g. "IsCommutative". |
| class NativeOpTrait<string prop> : OpTrait { |
| string trait = prop; |
| } |
| |
| // GenInternalOpTrait is an op trait that does not have direct C++ mapping but |
| // affects op definition generator internals, like how op builders and |
| // operand/attribute/result getters are generated. |
| class GenInternalOpTrait<string prop> : OpTrait { |
| string trait = prop; |
| } |
| |
| // PredOpTrait is an op trait implemented by way of a predicate on the op. |
| // `descr` describes the property and `pred` is the predicate to check. |
| class PredOpTrait<string descr, Pred pred> : OpTrait { |
| string description = descr; |
| Pred predicate = pred; |
| } |
| |
| // Commonly used native op traits. Each def wraps the C++ OpTrait whose name |
| // is given as its string argument. |
| // |
| // Op supports operand broadcast behavior. |
| // (Per the C++ trait name, this applies to ops with two operands and one |
| // result.) |
| def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">; |
| // X op Y == Y op X |
| def Commutative : NativeOpTrait<"IsCommutative">; |
| // Op results are float or vectors/tensors thereof. |
| def FloatLikeResults : NativeOpTrait<"ResultsAreFloatLike">; |
| // Op has no side effect. |
| def NoSideEffect : NativeOpTrait<"HasNoSideEffect">; |
| // Op has same operand and result shape. |
| def SameValueShape : NativeOpTrait<"SameOperandsAndResultShape">; |
| // Op has the same operand and result type. |
| def SameValueType : NativeOpTrait<"SameOperandsAndResultType">; |
| // Op has the same operand and result element type. |
| def SameOperandsAndResultElementType : |
| NativeOpTrait<"SameOperandsAndResultElementType">; |
| // Op is a terminator. |
| def Terminator : NativeOpTrait<"IsTerminator">; |
| |
| // Op result type is derived from the first attribute. If the attribute is a |
| // subclass of `TypeAttrBase`, its value is used; otherwise, the type of the |
| // attribute content is used. |
| // As a GenInternalOpTrait, this only influences the generated op code |
| // (presumably the builders) and has no C++ OpTrait counterpart. |
| def FirstAttrDerivedResultType : |
| GenInternalOpTrait<"FirstAttrDerivedResultType">; |
| |
| // Both traits below are GenInternalOpTraits: they affect how the op's |
| // operand/result getters are generated rather than mapping to a C++ OpTrait. |
| // |
| // All variadic operands of the op have the same number of values. |
| // A variadic operand contains an array of values whose array size is only |
| // known at runtime. This trait requires all variadic operands of an op |
| // to have the same array size. |
| def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; |
| // All variadic results of the op have the same number of values. |
| // A variadic result contains an array of values whose array size is only |
| // known at runtime. This trait requires all variadic results of an op |
| // to have the same array size. |
| def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; |
| |
| //===----------------------------------------------------------------------===// |
| // Op definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Marker used to identify the argument list for an op. |
| // It is the operator of an op's `arguments` dag, e.g. `(ins)` for no arguments. |
| def ins; |
| |
| // Marker used to identify the result list for an op. |
| // It is the operator of an op's `results` dag, e.g. `(outs)` for no results. |
| def outs; |
| |
| // Class for defining a custom builder. |
| // |
| // TableGen generates several generic builders for each op by default (see |
| // comment in the `Op` class). If the default generated ones cannot cover |
| // some use case, custom builders can be defined using instances of this class. |
| // |
| // The signature of the builder is always |
| // |
| // ```c++ |
| // static void build(Builder *builder, OperationState *state, |
| // <other-parameters>...) { |
| // <body>... |
| // } |
| // ``` |
| // |
| // To define a custom builder, the parameter list (*including* the `Builder |
| // *builder, OperationState *state` part) and body should be passed in |
| // as separate template arguments to this class. This is because we generate |
| // op declaration and definition into separate files. If an empty string is |
| // passed in for `body`, then *only* the builder declaration will be |
| // generated; this provides a way to define complicated builders entirely |
| // in C++. |
| class OpBuilder<string p, code b = ""> { |
| // The full builder parameter list, including the leading |
| // `Builder *builder, OperationState *state` part. |
| string params = p; |
| // The builder body; when empty, only the declaration is generated. |
| code body = b; |
| } |
| |
| // Base class for all ops. |
| // |
| // `dialect` and `mnemonic` identify the op (stored in `opDialect` and |
| // `opName`); `props` is the op's trait list (stored in `traits`). Fields |
| // initialized to `?` stay unset unless an op definition provides them. |
| class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> { |
| // The dialect of the op. |
| Dialect opDialect = dialect; |
| |
| // The mnemonic of the op. |
| string opName = mnemonic; |
| |
| // One-line human-readable description of what the op does. |
| string summary = ""; |
| |
| // Additional, longer human-readable description of what the op does. |
| string description = ""; |
| |
| // Dag containing the arguments of the op. Defaults to 0 arguments. |
| dag arguments = (ins); |
| |
| // Dag containing the results of the op. Defaults to 0 results. |
| dag results = (outs); |
| |
| // Attribute getters can be added to the op by adding an Attr member |
| // with the name and type of the attribute. E.g., adding an int attribute |
| // with name "value" and type "i32": |
| // I32Attr value; |
| |
| // Define the hooks used for building, parsing, printing, verification. |
| |
| // Custom builders. |
| // In addition to the custom builders provided here, two default builders |
| // are generated, with the following signatures: |
| // |
| // ```c++ |
| // static void build(Builder *, OperationState *tblgen_state, |
| // Type <result0-name>, Type <result1-name>, ..., |
| // Value <arg0-name>, Value <arg1-name>, ..., |
| // Attribute <attr0-name>, Attribute <attr1-name>, ...); |
| // ``` |
| // * where the attributes follow the same declaration order as in the op. |
| // |
| // ```c++ |
| // static void build(Builder *, OperationState *tblgen_state, |
| // ArrayRef<Type> resultTypes, |
| // ArrayRef<Value> operands, |
| // ArrayRef<NamedAttribute> attributes); |
| // ``` |
| list<OpBuilder> builders = ?; |
| |
| // Custom parser. |
| code parser = ?; |
| |
| // Custom printer. |
| code printer = ?; |
| |
| // Custom verifier. |
| code verifier = ?; |
| |
| // Whether this op has associated canonicalization patterns. |
| // TODO(b/120163349): figure out a better way to write canonicalization |
| // patterns in TableGen rules directly instead of using this marker |
| // and C++ implementations. |
| bit hasCanonicalizer = 0; |
| |
| // Whether this op has a folder. |
| bit hasFolder = 0; |
| |
| // Op traits. |
| list<OpTrait> traits = props; |
| |
| // Additional code that will be added to the public part of the generated |
| // C++ code of the op declaration. |
| code extraClassDeclaration = ?; |
| } |
| |
| // The arguments of an op. |
| // NOTE(review): the field name matches `Op`'s `arguments` field, so this is |
| // presumably inherited alongside `Op` to supply the argument dag from a |
| // base-class list — confirm against op definitions. |
| class Arguments<dag args> { |
| dag arguments = args; |
| } |
| |
| // The results of an op. |
| // NOTE(review): the counterpart of `Arguments` for `Op`'s `results` field. |
| class Results<dag rets> { |
| dag results = rets; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Common op type constraints |
| //===----------------------------------------------------------------------===// |
| |
| // Type Constraint operand `idx`'s Element type is `type`. |
| // |
| // Conjuncts, in list order: |
| // 1) the op has more than `idx` operands; |
| // 2) operand `idx`'s type satisfies IsShapedTypePred (with `$_self` |
| // substituted by that type); |
| // 3) its ShapedType element type satisfies `type.predicate` (again via |
| // `$_self` substitution). |
| class TCopVTEtIs<int idx, Type type> : And<[ |
| CPred<"$_op.getNumOperands() > " # idx>, |
| SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()", |
| IsShapedTypePred>, |
| SubstLeaves<"$_self", "$_op.getOperand(" # idx # |
| ")->getType().cast<ShapedType>().getElementType()", |
| type.predicate>]>; |
| |
| // Predicate to verify that the i'th operand and the j'th operand have the same |
| // elemental type. |
| // Type Constraint operand `i`'s Element type is Same As operand `j`'s Element |
| // type. |
| // |
| // Conjuncts, in list order: both indices are in range, both operand types |
| // satisfy IsShapedTypePred, and finally the two element types are compared. |
| // NOTE(review): the emitted C++ compares the operand count (presumably |
| // unsigned) against `std::max(i,j)` of int literals, which may draw a |
| // sign-compare warning. |
| class TCopVTEtIsSameAs<int i, int j> : And<[ |
| CPred<"$_op.getNumOperands() > std::max(" # i # "," # j # ")">, |
| SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()", |
| IsShapedTypePred>, |
| SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", |
| IsShapedTypePred>, |
| // TODO: This could be made into a C++ function instead. |
| CPred<"$_op.getOperand(" # i # ")->getType().cast<ShapedType>()." |
| "getElementType() == $_op.getOperand(" # j # ")->getType()." |
| "cast<ShapedType>().getElementType()">]>; |
| |
| // Predicate to verify that the i'th result and the j'th operand have the same |
| // elemental type. |
| // Type Constraint result `i`'s Element type is Same As operand `j`'s Element |
| // type. |
| // |
| // Conjuncts, in list order: result `i` and operand `j` exist, both types |
| // satisfy IsShapedTypePred, and finally the two element types are compared. |
| class TCresVTEtIsSameAsOp<int i, int j> : And<[ |
| CPred<"$_op.getNumResults() > " # i>, |
| CPred<"$_op.getNumOperands() > " # j>, |
| SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()", |
| IsShapedTypePred>, |
| SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()", |
| IsShapedTypePred>, |
| // TODO: This could be made into a C++ function instead. |
| CPred<"$_op.getResult(" # i # ")->getType().cast<ShapedType>()." |
| "getElementType() == $_op.getOperand(" # j # ")->getType()." |
| "cast<ShapedType>().getElementType()">]>; |
| |
| // Predicate to verify that all the operands at the given `indices` |
| // have the same element type. |
| // Type Constraint operands' Element type are all Same At the given `indices`. |
| // We query the operands' types into a list and check they are all the same. |
| // Precondition: |
| // 1) all operands involved are of vector or tensor type and |
| // 2) the indices are not out of range. |
| // Neither precondition is checked here: there is no getNumOperands() guard |
| // and the ShapedType cast is unconditional. `indices` is rendered into the |
| // C++ initializer list via StrJoinInt (", "-separated). |
| // NOTE(review): unlike the sibling TCop* predicates, this one refers to |
| // `this` rather than `$_op`, so it presumably only compiles where it is |
| // expanded inside an op member function (e.g. the generated verifier). |
| // Behavior for an empty `indices` list depends on llvm::is_splat — confirm. |
| class TCopVTEtAreSameAt<list<int> indices> : |
| CPred<"llvm::is_splat(mlir::functional::map(" |
| "[this](unsigned i) { return this->getOperand(i)->getType()" |
| ".cast<ShapedType>().getElementType(); }, " |
| "llvm::ArrayRef<unsigned>({" # StrJoinInt<indices>.result # "})))">; |
| |
| //===----------------------------------------------------------------------===// |
| // Pattern definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Marker used to identify the delta value added to the default benefit value. |
| // It is the operator of the `benefitAdded` dag taken by `Pattern`/`Pat`, |
| // e.g. `(addBenefit 0)`. |
| def addBenefit; |
| |
| // Base class for op+ -> op+ rewrite rules. These allow declaratively |
| // specifying rewrite rules. |
| // |
| // A rewrite rule contains two components: a source pattern and one or more |
| // result patterns. Each pattern is specified as a (recursive) DAG node (tree) |
| // in the form of `(node arg0, arg1, ...)`. |
| // |
| // Each `node` is normally an MLIR op, but it can also be one of the directives |
| // defined in the "Common directives" section later in this file. |
| // |
| // In the source pattern, `argN` can be used to specify matchers (e.g., using |
| // type/attribute constraints, etc.) and bound to a name for later use. |
| // We can also bind names to op instances to reference them later in |
| // multi-entity constraints. |
| // |
| // In the result pattern, `argN` can be used to refer to a previously bound |
| // name, with potential transformations (e.g., using tAttr, etc.). `argN` can |
| // itself be a nested DAG node. We can also bind names to ops to reference |
| // them later in other result patterns. |
| // |
| // For example, |
| // |
| // ``` |
| // def : Pattern<(OneResultOp1:$op1 $arg0, $arg1), |
| // [(OneResultOp2:$op2 $arg0, $arg1), |
| // (OneResultOp3 $op2 (OneResultOp4))], |
| // [(HasStaticShapePred $op1)]>; |
| // ``` |
| // |
| // `$argN` is bound to `OneResultOp1`'s N-th argument and used later to |
| // build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to |
| // check whether the result's shape is static. `$op2` is bound to |
| // `OneResultOp2` and used to build `OneResultOp3`. |
| class Pattern<dag source, list<dag> results, list<dag> preds = [], |
| dag benefitAdded = (addBenefit 0)> { |
| dag sourcePattern = source; |
| // Result patterns. Each result pattern is expected to replace one result |
| // of the root op in the source pattern. If there are more result patterns |
| // than needed to replace the source op, only the results generated by the |
| // last N result patterns are used to replace an N-result source op, so the |
| // earlier result patterns can be used to generate additional ops that help |
| // build the results used for replacement. |
| list<dag> resultPatterns = results; |
| // Multi-entity constraints. Each constraint here involves multiple entities |
| // matched in the source pattern and places further constraints on them as a |
| // whole. |
| list<dag> constraints = preds; |
| // The delta value added to the default benefit value. The default value is |
| // the number of ops in the source pattern. If multiple rules match, the rule |
| // with the highest final benefit value is applied first. This delta value |
| // can be either positive or negative. |
| dag benefitDelta = benefitAdded; |
| } |
| |
| // Form of a pattern which produces a single result. |
| // Shorthand for `Pattern`: `result` is wrapped into the one-element list |
| // `[result]`, and `preds` and `benefitAdded` are forwarded unchanged. |
| class Pat<dag pattern, dag result, list<dag> preds = [], |
| dag benefitAdded = (addBenefit 0)> : |
| Pattern<pattern, [result], preds, benefitAdded>; |
| |
| // Native code call wrapper. This allows invoking an arbitrary C++ expression |
| // to create an op operand/attribute or replace an op result. |
| // |
| // ## Placeholders |
| // |
| // If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`, |
| // the wrapped expression can take special placeholders listed below: |
| // |
| // * `$_builder` will be replaced by the current `mlir::PatternRewriter`. |
| // * `$_self` will be replaced with the entity this transformer is attached to. |
| // E.g., with the definition `def transform : tAttr<$_self...>`, `$_self` in |
| // `transform:$attr` will be replaced by the value for `$attr`. |
| // |
| // If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`, |
| // then positional placeholders are also supported; placeholder `$N` in the |
| // wrapped C++ expression will be replaced by `<argN>`. |
| class NativeCodeCall<string expr> { |
| // The wrapped C++ expression template, stored verbatim. |
| string expression = expr; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Common directives |
| //===----------------------------------------------------------------------===// |
| |
| // Directive used in a result pattern to indicate that no new ops are |
| // generated; instead, the matched DAG is replaced with an existing SSA value. |
| def replaceWithValue; |
| |
| // Directive used in a result pattern to indicate that no replacement is |
| // generated for the current result. Predicates are generated to make sure the |
| // corresponding result in the source pattern is unused. |
| // Syntax: (verifyUnusedValue) |
| def verifyUnusedValue; |
| |
| #endif // OP_BASE |