blob: 185a391105a1f63ea06b3b3727a99b251f05e471 [file] [log] [blame]
//===-- op_base.td - Base op definition file ---------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the base operation definition file.
//
//===----------------------------------------------------------------------===//
#ifdef OP_BASE
#else
#define OP_BASE
//===----------------------------------------------------------------------===//
// Predicates.
//===----------------------------------------------------------------------===//
// Base class for logical predicates.
class Pred;

// A leaf predicate wrapping a C expression. The expression is stored in
// `predCall`, surrounded by parentheses.
class CPred<code pred> : Pred {
  code predCall = "(" # pred # ")";
}
// The ways in which logical predicates can be combined. These must closely
// match the predicates implemented by the C++ backend
// (tblgen::PredCombinerKind).
class PredCombinerKind;

def PredCombinerAnd : PredCombinerKind;
def PredCombinerOr : PredCombinerKind;
def PredCombinerNot : PredCombinerKind;
def PredCombinerSubstLeaves : PredCombinerKind;
// A composite predicate: combines the child predicates `c` in the manner
// selected by the PredCombinerKind `k`. Instantiated below.
class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
  PredCombinerKind kind = k;
  list<Pred> children = c;
}
// Conjunction: holds if all of its children hold. Always holds for zero
// children.
class AllOf<list<Pred> children>
    : CombinedPred<PredCombinerAnd, children>;

// Disjunction: holds if any of its children hold. Never holds for zero
// children.
class AnyOf<list<Pred> children>
    : CombinedPred<PredCombinerOr, children>;

// Negation: holds if its single child does not.
class Neg<Pred child>
    : CombinedPred<PredCombinerNot, [child]>;
// A predicate that substitutes "pat" with "repl" in the predicate calls of
// the leaves of the predicate tree (i.e., not CombinedPredicates). This is
// plain string substitution without regular expressions or captures; new
// predicates with more complex logic can be introduced should the need arise.
class SubstLeaves<string pat, string repl, Pred child>
    : CombinedPred<PredCombinerSubstLeaves, [child]> {
  string pattern = pat;
  string replacement = repl;
}
//===----------------------------------------------------------------------===//
// Type predicates. ({0} is replaced by an instance of mlir::Type)
//===----------------------------------------------------------------------===//
// Holds if the type is a VectorType.
def IsVectorTypePred : CPred<"{0}.isa<VectorType>()">;

// Holds if the type is a TensorType.
def IsTensorTypePred : CPred<"{0}.isa<TensorType>()">;

// For a TensorType, holds if the tensor is statically shaped. The expression
// casts to TensorType before querying the shape.
def IsStaticShapeTensorTypePred
    : CPred<"{0}.cast<TensorType>().hasStaticShape()">;
//===----------------------------------------------------------------------===//
// Type constraints and types.
//===----------------------------------------------------------------------===//
// A constraint on types, usable for checking the validity of instruction
// arguments.
class TypeConstraint<Pred condition, string descr> {
  // The predicate a type must satisfy to meet this constraint.
  // Format: {0} will be expanded to the type.
  Pred predicate = condition;

  // User-readable description used, e.g., for error reporting. If empty, a
  // generic message will be used instead.
  string description = descr;
}
// A type. It carries a type constraint; the description defaults to the empty
// string.
class Type<Pred condition, string descr = "">
    : TypeConstraint<condition, descr>;
// A variadic type, expanding to zero or more of the base type.
//
// This class is used for supporting variadic operands/results. An op can
// declare no more than one variadic operand/result, and that operand/result
// must be the last one in the operand/result list.
//
// TODO: support variadic type conditions. Until then the predicate is the
// constant "true".
class Variadic<Type type, string descr = "">
    : Type<CPred<"true">, descr> {
  Type baseType = type;
}
// A type that can be constructed using mlir::Builder.
//
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
// diamond "inheritance".
//
// TODO(zinenko): we may extend this to a more general 'Buildable' trait,
// making some Types and some Attrs buildable.
class BuildableType<code builder> {
  // The builder call to invoke (if specified) to construct the BuildableType.
  // Format: this will be affixed to the builder.
  code builderCall = builder;
}
// Integer types. Common base class for the integer type records below.
class IntegerBase<CPred pred, string descr> : Type<pred, descr>;

// An integer type of any width.
def Integer : IntegerBase<CPred<"{0}.isa<IntegerType>()">, "integer">;

// The index type.
def Index : IntegerBase<CPred<"{0}.isa<IndexType>()">, "index">;
// An integer type of exactly `width` bits. Also a BuildableType, built via
// getIntegerType(width).
class I<int width>
    : IntegerBase<CPred<"{0}.isInteger(" # width # ")">,
                  width # "-bit integer">,
      BuildableType<"getIntegerType(" # width # ")"> {
  int bitwidth = width;
}

def I1 : I<1>;
def I32 : I<32>;
// Floating point types. Common base class for the float type records below.
class FloatBase<CPred pred, string descr> : Type<pred, descr>;

// A float type of any width.
def Float : FloatBase<CPred<"{0}.isa<FloatType>()">, "floating-point">;
// A float type of exactly `width` bits. Also a BuildableType; the predicate
// and the builder call are both spelled from the width (isF<width>() and
// getF<width>Type()).
class F<int width>
    : FloatBase<CPred<"{0}.isF" # width # "()">,
                width # "-bit float">,
      BuildableType<"getF" # width # "Type()"> {
  int bitwidth = width;
}

def F32 : F<32>;
// A container type is a type that has another type embedded within it.
//
// The resulting predicate first checks the container predicate and then
// checks the element type's predicate with "{0}" substituted by
// `elementTypeCall`, i.e. applied to the extracted element type.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                    string descr>
    : Type<AllOf<[containerPred,
                  SubstLeaves<"{0}", !cast<string>(elementTypeCall),
                              etype.predicate>]>,
           descr # " of " # etype.description # " values"> {
  // The type of elements in the container.
  Type elementType = etype;

  // Call to retrieve the element type.
  code getElementTypeCall = elementTypeCall;
}
// Vector types.

// A vector whose element type is `t`; the shape is not constrained.
class TypedVector<Type t>
    : ContainerType<t, IsVectorTypePred,
                    "{0}.cast<VectorType>().getElementType()", "vector">;
// A vector whose element type is `t` and whose shape equals `dims`.
//
// The shape check compares getShape() against an ArrayRef whose elements are
// produced by folding over `dims`, joining them with ",".
// NOTE(review): the generated expression opens with "{{" but closes with a
// single "}" -- presumably "{{" is an escape for the downstream formatter;
// confirm before touching the string.
class Vector<Type t, list<int> dims>
    : ContainerType<
          t,
          AllOf<[
            IsVectorTypePred,
            CPred<"{0}.cast<VectorType>().getShape() == ArrayRef{{" #
                  !foldl("", dims, sum, element, sum #
                         !if(!empty(sum), "", ",") # !cast<string>(element)) #
                  "}">]>,
          "{0}.cast<VectorType>().getElementType()",
          "vector"> {
  list<int> dimensions = dims;
}
// Tensor type.
//
// A generic tensor with no constraints on elemental type, rank or size.
// Because the elemental type is unconstrained, this derives from Type
// directly instead of ContainerType.
def Tensor : Type<IsTensorTypePred, "tensor">;
// A tensor with static shape but no other constraints.
//
// Note: Tensor is a def, so this cannot derive from it; instead it reuses the
// predicate that must hold for a type to be a tensor.
def StaticShapeTensor
    : Type<AllOf<[Tensor.predicate, IsStaticShapeTensorTypePred]>,
           "statically shaped tensor">;
// A tensor whose element type is `t`.
class TypedTensor<Type t>
    : ContainerType<t, Tensor.predicate,
                    "{0}.cast<TensorType>().getElementType()",
                    "tensor">;

def F32Tensor : TypedTensor<F32>;
// Type constraint for integer-like types: integers, indices, vectors of
// integers, tensors of integers.
def IntegerLike
    : TypeConstraint<AnyOf<[Integer.predicate, Index.predicate,
                            TypedVector<Integer>.predicate,
                            TypedTensor<Integer>.predicate]>,
                     "integer-like">;

// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike
    : TypeConstraint<AnyOf<[Float.predicate,
                            TypedVector<Float>.predicate,
                            TypedTensor<Float>.predicate]>,
                     "floating-point-like">;
//===----------------------------------------------------------------------===//
// Attributes
//===----------------------------------------------------------------------===//
// A constraint on attributes, usable for checking the validity of instruction
// attributes.
class AttrConstraint<Pred condition, string descr> {
  // The predicate an attribute must satisfy to meet this constraint.
  // Format: {0} will be expanded to the attribute.
  Pred predicate = condition;

  // User-readable description used, e.g., for error reporting.
  // If empty, a generic message will be used instead.
  string description = descr;
}
// Base class for all attributes. Fields initialized to `?` are left unset
// here and are expected to be filled in by derived classes/defs.
class Attr<Pred condition, string descr = "">
    : AttrConstraint<condition, descr> {
  // The backing mlir::Attribute type.
  code storageType = ?;

  // The underlying C++ value type.
  code returnType = ?;

  // Converter from the storage type to the return type. For example, an enum
  // can be stored as an int but returned as an enum class.
  //
  // Format: {0} will be expanded to the attribute. So
  // '{0}.getValue().convertToFloat()' for 'FloatAttr val' will expand to
  // 'getAttrOfType<FloatAttr>("val").getValue().convertToFloat()'.
  code convertFromStorage = "{0}.getValue()";

  // The call expression that builds an attribute from a constant value.
  //
  // Format: {0} will be expanded to an instance of mlir::Builder, {1} will be
  // expanded to the constant value of the attribute. For example,
  // '{0}.getStringAttr("{1}")' for 'StringAttr:"foo"' will expand to
  // 'builder.getStringAttr("foo")'.
  code constBuilderCall = ?;

  // Default value for the attribute.
  // Requires a constBuilderCall defined.
  string defaultValue = ?;
}
// A generic attribute that must be constructed around a specific type.
// Backed by a C++ class "attrName".
//
// The constant builder call is assembled from `attrName` and the builder call
// of the buildable type `t`; the storage type is `attrName` itself.
class TypeBasedAttr<BuildableType t, string attrName, string descr>
    : Attr<CPred<"true">, descr> {
  let constBuilderCall =
      "{0}.get" # attrName # "({0}." # t.builderCall # ", {1})";
  let storageType = attrName;
}
// An attribute backed by a string type: stored as a StringAttr and built
// through getStringAttr.
class StringBasedAttr<string descr>
    : Attr<CPred<"true">, descr> {
  let constBuilderCall = [{ {0}.getStringAttr("{1}") }];
  let storageType = [{ StringAttr }];
}
// Base class for instantiating float attributes of fixed width. Backed by
// "FloatAttr".
class FloatAttrBase<BuildableType t, string descr>
    : TypeBasedAttr<t, "FloatAttr", descr>;

// Base class for instantiating integer attributes of fixed width. Backed by
// "IntegerAttr".
class IntegerAttrBase<BuildableType t, string descr>
    : TypeBasedAttr<t, "IntegerAttr", descr>;
// Boolean attribute: stored as a BoolAttr, returned as a C++ bool and built
// through getBoolAttr.
def BoolAttr : Attr<CPred<"true">, "bool"> {
  let storageType = [{ BoolAttr }];
  let returnType = [{ bool }];
  let constBuilderCall = [{ {0}.getBoolAttr({1}) }];
}
// Array attribute: stored and returned as an ArrayAttr.
def ArrayAttr : Attr<CPred<"true">, "array"> {
  let storageType = [{ ArrayAttr }];
  let returnType = [{ ArrayAttr }];
  // The storage type is returned as is, so the conversion is the identity.
  // `convertFromStorage` is inherited from Attr and is therefore overridden
  // with `let` (as ElementsAttr does) rather than redeclared as a new field.
  let convertFromStorage = "{0}";
}
// Constant vector/tensor attribute: stored and returned as an ElementsAttr.
def ElementsAttr : Attr<CPred<"true">, "constant vector/tensor"> {
  let storageType = [{ ElementsAttr }];
  let returnType = [{ ElementsAttr }];
  // The storage type is returned as is, so the conversion is the identity.
  let convertFromStorage = "{0}";
}
// 32-bit float attribute, returned as a C++ float.
def F32Attr : FloatAttrBase<F32, "32-bit float"> {
  let returnType = [{ float }];
  let convertFromStorage = [{ {0}.getValue().convertToFloat() }];
}
// 32-bit integer attribute: stored as an IntegerAttr, returned as a C++ int
// obtained through getSExtValue().
def I32Attr : IntegerAttrBase<I32, "32-bit integer"> {
  let storageType = [{ IntegerAttr }];
  let returnType = [{ int }];
  let convertFromStorage = [{ {0}.getValue().getSExtValue() }];
}
// String attribute, returned as a StringRef.
//
// The storage type (StringAttr) and the constant builder call
// (getStringAttr) are inherited from StringBasedAttr; only the return type
// needs to be set here.
def StrAttr : StringBasedAttr<"string"> {
  let returnType = [{ StringRef }];
}
// DerivedAttrs are attributes whose value is computed from properties of the
// operation. They do not require additional storage and are materialized as
// needed.
class DerivedAttr<code ret, code b>
    : Attr<CPred<"true">, "derived"> {
  let returnType = ret;
  // The code computing the attribute value.
  code body = b;
}

// Derived attribute that returns a mlir::Type.
class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
// Represents a constant attribute of a specific Attr type. A constant
// attribute can be specified only for attributes that have a constant builder
// call defined. The constant value is specified as a string.
//
// If used as a constraint, it generates a matcher on a constant attribute by
// using the constant value builder of the attribute and the value: in the
// attribute's constBuilderCall, "{1}" is replaced by the value and "{0}" by
// "mlir::Builder(ctx)", and the result is compared against the attribute.
class ConstantAttr<Attr attribute, string val>
    : AttrConstraint<
          CPred<"{0} == " #
                !subst("{0}", "mlir::Builder(ctx)",
                       !subst("{1}", val,
                              !cast<string>(attribute.constBuilderCall)))>,
          "constant attribute " # val> {
  Attr attr = attribute;
  string value = val;
}

// A constant 32-bit float attribute.
class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>;
//===----------------------------------------------------------------------===//
// Op Traits
//===----------------------------------------------------------------------===//
// OpTrait represents a trait regarding an op. It corresponds to the MLIR C++
// OpTrait mechanism. The purpose of wrapping the C++ symbol string in this
// class is to make traits specified for ops in TableGen less alien and more
// integrated.
class OpTrait<string prop> {
  string trait = prop;
}
// The op supports operand broadcast behavior.
def Broadcastable : OpTrait<"BroadcastableTwoOperandsOneResult">;

// The op is commutative: X op Y == Y op X.
def Commutative : OpTrait<"IsCommutative">;

// The op has no side effect.
def NoSideEffect : OpTrait<"HasNoSideEffect">;

// The op has the same operand and result type.
def SameValueType : OpTrait<"SameOperandsAndResultType">;

// The op is a terminator.
def Terminator : OpTrait<"IsTerminator">;
//===----------------------------------------------------------------------===//
// Ops
//===----------------------------------------------------------------------===//
// Marker identifying the argument list of an op.
def ins;

// Marker identifying the result list of an op.
def outs;
// Base class for all ops.
class Op<string mnemonic, list<OpTrait> props = []> {
  // The mnemonic of the op.
  string opName = mnemonic;

  // One-line human-readable description of what the op does.
  string summary = "";

  // Additional, longer human-readable description of what the op does.
  string description = "";

  // Dag containing the arguments of the op. Defaults to 0 arguments. Operands
  // to the op need to precede attributes to ops in the argument
  // specification.
  dag arguments = (ins);

  // The list of results of the op. Defaults to 0 results.
  dag results = (outs);

  // Attribute getters can be added to the op by adding an Attr member
  // with the name and type of the attribute. E.g., adding int attribute
  // with name "value" and type "i32":
  //   I32Attr value;

  // The hooks used for building, parsing, printing and verification follow.

  // Custom builder.
  // If a derived class/def does not override this, then two default builders
  // are generated, with the following signatures:
  //
  //   static void build(Builder* builder, OperationState* result,
  //                     Type resultType0, Type resultType1, ...,
  //                     Value arg0, Value arg1, ...,
  //                     Attribute <attr0-name>, Attribute <attr1-name>, ...);
  //
  //   * where the attributes follow the same declaration order as in the op.
  //
  //   static void build(Builder* builder, OperationState* result,
  //                     ArrayRef<Type> resultTypes,
  //                     ArrayRef<Value> args,
  //                     ArrayRef<NamedAttribute> attributes);
  code builder = ?;

  // Custom parser.
  code parser = ?;

  // Custom printer.
  code printer = ?;

  // Custom verifier.
  code verifier = ?;

  // Whether this op has associated canonicalization patterns.
  // TODO(b/120163349): figure out a better way to write canonicalization
  // patterns in TableGen rules directly instead of using this marker
  // and C++ implementations.
  bit hasCanonicalizer = 0b0;

  // Whether this op has a constant folder.
  bit hasConstantFolder = 0b0;

  // Whether this op has a folder.
  bit hasFolder = 0b0;

  // Op traits: the internal strings of all OpTraits in `props`, collected
  // into a list in order.
  list<string> traits = !foldl(
      []<string>, props, prev, cur, !listconcat(prev, [cur.trait]));
}
// Holds the arguments of an op in an `arguments` dag field.
class Arguments<dag args> {
  dag arguments = args;
}

// Holds the results of an op in a `results` dag field.
class Results<dag rets> {
  dag results = rets;
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
// Base class for op+ -> op+ rewrite patterns. These allow rewrite patterns to
// be specified declaratively.
// TODO(jpienaar): Add the constraint list along with the Pattern.
class Pattern<dag patternToMatch, list<dag> resultOps> {
  dag PatternToMatch = patternToMatch;
  list<dag> ResultOps = resultOps;
}

// Form of a pattern which produces a single result.
class Pat<dag pattern, dag result>
    : Pattern<pattern, [result]>;
// Attribute matcher. This is the base class for specifying a predicate that
// has to match. Used on the input attributes of a rewrite rule.
class mAttr<Pred pred> : AttrConstraint<pred, "">;

// Combines a list of attribute matchers into a single attribute matcher that
// holds if any of the original matchers does. The predicates of `attrs` are
// gathered into a list and wrapped in AnyOf.
class mAttrAnyOf<list<AttrConstraint> attrs>
    : mAttr<AnyOf<!foldl([]<Pred>, attrs, prev, attr,
                         !listconcat(prev, [attr.predicate]))>>;
// Attribute transform. This is the base class for specifying a
// transformation of a matched attribute. Used on the output of a rewrite
// rule.
class tAttr<code transform> {
  // Code to transform the attribute.
  // Format: {0} represents the attribute.
  code attrTransform = transform;
}
// Native code op creation method. This allows performing an arbitrary op
// creation/replacement by invoking a C++ function with the operands and
// attributes. The function specified needs to have the signature:
//
//   void f(OperationInst *op, ArrayRef<Value *> operands,
//          ArrayRef<Attribute> attrs, PatternRewriter &rewriter);
//
// The operands and attributes are passed to this function in the order of
// the DAG specified. It is the responsibility of this function to replace the
// matched op(s) using the rewriter. This is intended for the long tail of op
// creation and replacement.
class cOp<string f> {
  // Function to invoke with the given arguments to construct a new op. The
  // operands will be passed to the function first, followed by the attributes
  // (as in the function signature above and as required by Op arguments).
  string function = f;
}
// Marker used to indicate that no new result ops are generated by applying
// the rewrite pattern, so that the matched DAG is replaced with an existing
// SSA value.
def replaceWithValue;
#endif // OP_BASE