blob: 020ffdc84be4b4d50787567ac24d60f8fa7c2650 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.
#ifndef TF_OP_BASE
#define TF_OP_BASE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//
// The `tf` dialect. Ops defined on it are spelled `tf.<Mnemonic>` (e.g.
// `tf.Sub`, `tf.Sqrt`), and the generated C++ is placed in the `::mlir::TF`
// namespace (see `cppNamespace`). The `description` block below is free-form
// documentation text for the dialect.
def TF_Dialect : Dialect {
let name = "tf";
let description = [{
The TensorFlow dialect.
This dialect maps to TensorFlow operations.
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimensional tensors);
TODO: Make invariants more structured so that we can reference them in ops.
}];
let cppNamespace = "::mlir::TF";
}
//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//
// Native (C++-implemented) op traits. Each `NativeOpTrait<"TF::X">` below only
// names a C++ trait class; that class, and any verification it performs, is
// implemented outside this file.
//
// Specify this trait if the op requires all outputs to have the same type and
// the inputs either have the same type as the result or a ref type
// corresponding to the result type.
def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait<
"TF::OperandsSameAsResultsTypeOrRef">;
// Op has the same operand and result element types (or type itself, if scalar)
// after resolving reference types (i.e., after converting reference types to
// their corresponding TensorFlow or standard types).
def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait<
"TF::SameOperandsAndResultElementTypeResolveRef">;
// Op has the same operand and result types after resolving reference types
// (i.e., after converting reference types to their corresponding TensorFlow or
// standard types).
def TF_SameOperandsAndResultTypeResolveRef : NativeOpTrait<
"TF::SameOperandsAndResultTypeResolveRef">;
// Layout-agnostic operations do not depend on the data layout (data format) of
// their operands; for example, all element-wise operations are layout
// agnostic.
def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">;
// Trait to indicate operations that cannot be duplicated as they might carry
// certain state around within their implementations.
def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">;
// Trait to indicate an operation cannot be constant folded.
def TF_NoConstantFold : NativeOpTrait<"TF::NoConstantFold">;
// Coefficient-wise binary operation with implicit broadcasting support, for
// example the tf.Sub operation.
def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">;
// Coefficient-wise unary operation, for example the tf.Sqrt operation.
def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">;
// Variant of the broadcastable constraint that considers TF's subtype behavior.
//
// This is a predicate (an `And<>` of two predicates), intended for use when
// building an op trait/constraint. It holds when:
//   1. the `opId`-th operand and `resId`-th result are shaped types
//      (`TCOpResIsShapedTypePred`), and
//   2. the C++ helper `mlir::tf_type::BroadcastCompatible` accepts the
//      (operand type, result type) pair.
// The subtype-aware rules live in that helper, not in this file.
class TF_OpIsBroadcastableToRes<int opId, int resId> : And<[
TCOpResIsShapedTypePred<opId, resId>,
CPred<"mlir::tf_type::BroadcastCompatible("
"$_op.getOperand(" # opId # ").getType(), "
"$_op.getResult(" # resId # ").getType())">]>;
// Predicate that holds when the C++ expressions in `values` (each expected to
// evaluate to an `mlir::Type`) are accepted together by the C++ helper
// `tf_type::AreCastCompatible`. The expressions are joined with ", " into a
// braced initializer list wrapped in `llvm::makeArrayRef`.
class TF_AllTypesMatchPred<list<string> values> :
  CPred<"tf_type::AreCastCompatible(llvm::makeArrayRef({" #
        !interleave(values, ", ") # "}))">;

// Op trait asserting that the named operands/results have "dynamically equal"
// types, i.e. types accepted by `TF_AllTypesMatchPred`.
//
// `names` are ODS operand/result names without the leading `$`; each name `n`
// is rewritten to the C++ expression `$n.getType()` before the predicate is
// formed. The trait summary reads
// "all of {a, b, ...} have dynamically equal types" (no trailing whitespace).
class TF_AllTypesMatch<list<string> names> :
  PredOpTrait<
    "all of {" # !interleave(names, ", ") #
    "} have dynamically equal types",
    TF_AllTypesMatchPred<
      !foreach(n, names, !subst("$_self", "$" # n, "$_self.getType()"))>>;
// This trait indicates that all returned resources are unique for a
// resource-allocating op (i.e. op with `MemAlloc` side effect).
//
// Note that if the trait is used where this invariant is not true, then this
// might lead to incorrect execution order, while if not used where it should
// be, it can only lead to reduced performance due to conservative ordering.
// Example op where the invariant is not true: `TF_VarHandleOp`.
//
// It bundles two pieces: the `TF_ResourceHandleAllocatorInterface` op
// interface and the native `TF::UniqueResourceAllocation` trait. Both are
// defined outside this file.
def TF_UniqueResourceAllocation: OpTraitList<[
TF_ResourceHandleAllocatorInterface,
NativeOpTrait<"TF::UniqueResourceAllocation">
]>;
//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
// True if the n-th operand's type is an `UnrankedTensorType`.
class TF_OperandIsUnrankedPred<int n> :
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
// True if the n-th result's type is an `UnrankedTensorType`.
class TF_ResultIsUnrankedPred<int n> :
CPred<"$_op.getResult(" # n # ").getType().isa<UnrankedTensorType>()">;
// Op trait: the n-th operand has unknown rank (unranked tensor) or has rank m.
// The ranked branch casts the type to `ShapedType`, so the operand is assumed
// to be shaped. NOTE(review): confirm no non-shaped operand reaches this.
class TF_OperandHasRank<int n, int m> :
PredOpTrait<"operand " # n # " is " # m # "-D",
Or<[TF_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
// Op trait: the n-th result has unknown rank (unranked tensor) or has rank m.
// As above, the ranked branch assumes the result type is a `ShapedType`.
class TF_ResultHasRank<int n, int m> :
PredOpTrait<"result " # n # " is " # m # "-D",
Or<[TF_ResultIsUnrankedPred<n>,
CPred<"$_op.getResult(" # n #
").getType().cast<ShapedType>().getRank() == " # m>]>>;
//===----------------------------------------------------------------------===//
// TensorFlow resources and side effects
//===----------------------------------------------------------------------===//
// Base class for TensorFlow-specific side-effect resources.
//
// The wrapped `Resource` is named by the C++ class
// `::mlir::TF::ResourceEffects::<resourceKind>`. For example,
// `TF_ResourceBase<"Variable">` refers to
// `::mlir::TF::ResourceEffects::Variable`.
class TF_ResourceBase<string resourceKind> :
  Resource<"::mlir::TF::ResourceEffects::" # resourceKind>;
// Resource types
//
// Each def names one resource kind via TF_ResourceBase, i.e. the C++ class
// `::mlir::TF::ResourceEffects::<Kind>`. The MemRead/MemWrite/MemAlloc/MemFree
// effect traits below are parameterized on these.
def TF_VariableResource : TF_ResourceBase<"Variable">;
def TF_StackResource : TF_ResourceBase<"Stack">;
def TF_TensorArrayResource : TF_ResourceBase<"TensorArray">;
def TF_SummaryResource : TF_ResourceBase<"Summary">;
def TF_LookupTableResource : TF_ResourceBase<"LookupTable">;
def TF_DatasetSeedGeneratorResource : TF_ResourceBase<"DatasetSeedGenerator">;
def TF_DatasetMemoryCacheResource : TF_ResourceBase<"DatasetMemoryCache">;
def TF_DatasetIteratorResource : TF_ResourceBase<"DatasetIterator">;
def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;
def TF_SendRecvResource : TF_ResourceBase<"SendRecv">;
def TF_TPUCompileExecuteResource : TF_ResourceBase<"TPUCompileExecute">;
// Fake resource, see `TF_MustExecute` below.
def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">;
// Value-based side effects
//
// Value-based side effect traits are attached to op operands or results to
// signal what type of resource is accessed and in which way.
//
// They are grouped below by effect kind (read, write, alloc, free). Not every
// resource kind has every effect defined here; for example, this file defines
// no TF_VariableFree, TF_LookupTableFree, or TF_SummaryRead.
// Reads.
def TF_VariableRead : MemRead<TF_VariableResource>;
def TF_StackRead : MemRead<TF_StackResource>;
def TF_TensorArrayRead : MemRead<TF_TensorArrayResource>;
def TF_LookupTableRead : MemRead<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorRead : MemRead<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheRead : MemRead<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorRead : MemRead<TF_DatasetIteratorResource>;
// Writes.
def TF_VariableWrite : MemWrite<TF_VariableResource>;
def TF_StackWrite : MemWrite<TF_StackResource>;
def TF_TensorArrayWrite : MemWrite<TF_TensorArrayResource>;
def TF_SummaryWrite : MemWrite<TF_SummaryResource>;
def TF_LookupTableWrite : MemWrite<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorWrite : MemWrite<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheWrite : MemWrite<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorWrite : MemWrite<TF_DatasetIteratorResource>;
// Allocations.
def TF_VariableAlloc : MemAlloc<TF_VariableResource>;
def TF_StackAlloc : MemAlloc<TF_StackResource>;
def TF_TensorArrayAlloc : MemAlloc<TF_TensorArrayResource>;
def TF_SummaryAlloc : MemAlloc<TF_SummaryResource>;
def TF_LookupTableAlloc : MemAlloc<TF_LookupTableResource>;
def TF_DatasetSeedGeneratorAlloc : MemAlloc<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheAlloc : MemAlloc<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorAlloc : MemAlloc<TF_DatasetIteratorResource>;
// Frees.
def TF_StackFree : MemFree<TF_StackResource>;
def TF_TensorArrayFree : MemFree<TF_TensorArrayResource>;
def TF_SummaryFree : MemFree<TF_SummaryResource>;
def TF_DatasetSeedGeneratorFree : MemFree<TF_DatasetSeedGeneratorResource>;
def TF_DatasetMemoryCacheFree : MemFree<TF_DatasetMemoryCacheResource>;
def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
// Op-based side effects
// Op-based side effect traits can be used to enforce certain execution order
// constraints, in particular for ops that don't use resource handles (those
// typically have value-based side effects). For a `read` effect, all instances
// of ops with the trait keep their order to all ops with unknown side effects
// (e.g. `stateful` ops). For a `write` effect, all instances of ops with the
// trait stay in order, and they also keep their order to all unknown side-
// effecting ops. Note that for `read` effects ops might be pruned if nothing
// depends on them.
//
// Unlike the value-based traits above, each of these is a whole-op
// `MemoryEffects<[...]>` trait rather than one attached to a specific operand
// or result.
def TF_GeneratorOpSideEffect : MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;
def TF_TPUEmbeddingWriteEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_TPUEmbeddingReadEffect : MemoryEffects<[MemRead<TF_TPUEmbeddingResource>]>;
def TF_SendRecvSideEffect : MemoryEffects<[MemWrite<TF_SendRecvResource>]>;
def TF_TPUCompileExecuteSideEffect : MemoryEffects<[MemWrite<TF_TPUCompileExecuteResource>]>;
// Trait for enforcing that a side-effecting op is executed, even if it would be
// considered dead by MLIR (see b/195782952).
// The trait is implemented as a write effect for a fake resource which is
// ignored by side effect analysis, so it does not affect execution order
// constraints and control dependencies at all (for example, multiple ops with
// this trait do not have to execute in order).
def TF_MustExecute : MemoryEffects<[MemWrite<TF_MustExecuteResource>]>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//
// Base class for every TensorFlow dialect op: an `Op` registered in
// `TF_Dialect` with the given mnemonic and trait list (empty by default).
class TF_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TF_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//
// Attribute constraint matching attributes whose C++ class is
// `mlir::TF::<name>Attr`. The constraint summary is
// "TensorFlow <description> attribute".
class TF_TensorFlowAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::TF::" # name # "Attr>()">,
"TensorFlow " # description # " attribute">;
// Shape attribute backed by `mlir::TF::ShapeAttr`. The C++ accessor returns
// `llvm::Optional<llvm::ArrayRef<int64_t>>`. NOTE(review): presumably an empty
// Optional means an unranked shape; confirm against ShapeAttr::getValue().
def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> {
let returnType = "llvm::Optional<llvm::ArrayRef<int64_t>>";
let convertFromStorage = "$_self.cast<mlir::TF::ShapeAttr>().getValue()";
// Create a ranked shape attr by default.
let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)";
}
// Array attribute whose elements are all TF_ShapeAttr.
def TF_ShapeAttrArray :
TypedArrayAttrBase<TF_ShapeAttr, "tensorflow shape attribute array">;
//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//
// Any tensor element type defined in the TensorFlow dialect, i.e. any type for
// which `isa<mlir::TF::TensorFlowType>()` holds.
def TF_TFDialectType :
Type<CPred<"$_self.isa<mlir::TF::TensorFlowType>()">, "TensorFlow type">;
// Class for any TensorFlow dialect specific type.
//
// `name` selects the C++ class `mlir::TF::<name>Type`. `description` is
// embedded in the summary as "TensorFlow <description> type". The type is also
// marked buildable via `BuildableType`, with the builder call
// `getType<mlir::TF::<name>Type>()`.
class TF_TensorFlowType <string name, string description> :
Type<CPred<"$_self.isa<mlir::TF::" # name # "Type>()">,
"TensorFlow " # description # " type">,
BuildableType<"getType<mlir::TF::" # name # "Type>()">;
//===----------------------------------------------------------------------===//
// Reference types
//
// Each def below is a TF_TensorFlowType matching `mlir::TF::<first arg>Type`;
// for example, TF_Float16Ref matches `mlir::TF::HalfRefType`. The value-type
// constraints further down accept either the plain type or its ref type
// (e.g. TF_Float16 accepts F16 or TF_Float16Ref).
// Float reference types
def TF_Float16Ref : TF_TensorFlowType<"HalfRef", "f16ref">;
def TF_Float32Ref : TF_TensorFlowType<"FloatRef", "f32ref">;
def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
def TF_Complex128Ref : TF_TensorFlowType<"Complex128Ref", "complex128ref">;
// Integer reference types
def TF_Int8Ref : TF_TensorFlowType<"Int8Ref", "i8ref">;
def TF_Int16Ref : TF_TensorFlowType<"Int16Ref", "i16ref">;
def TF_Int32Ref : TF_TensorFlowType<"Int32Ref", "i32ref">;
def TF_Int64Ref : TF_TensorFlowType<"Int64Ref", "i64ref">;
def TF_Uint8Ref : TF_TensorFlowType<"Uint8Ref", "ui8ref">;
def TF_Uint16Ref : TF_TensorFlowType<"Uint16Ref", "ui16ref">;
def TF_Uint32Ref : TF_TensorFlowType<"Uint32Ref", "ui32ref">;
def TF_Uint64Ref : TF_TensorFlowType<"Uint64Ref", "ui64ref">;
// Quantized reference types
def TF_Qint8Ref : TF_TensorFlowType<"Qint8Ref", "qint8ref">;
def TF_Qint16Ref : TF_TensorFlowType<"Qint16Ref", "qint16ref">;
def TF_Qint32Ref : TF_TensorFlowType<"Qint32Ref", "qint32ref">;
def TF_Quint8Ref : TF_TensorFlowType<"Quint8Ref", "quint8ref">;
def TF_Quint16Ref : TF_TensorFlowType<"Quint16Ref", "quint16ref">;
// Other reference types
def TF_BoolRef : TF_TensorFlowType<"BoolRef", "boolref">;
def TF_ResourceRef : TF_TensorFlowType<"ResourceRef", "resourceref">;
def TF_StrRef : TF_TensorFlowType<"StringRef", "stringref">;
def TF_VariantRef : TF_TensorFlowType<"VariantRef", "variantref">;
//===----------------------------------------------------------------------===//
// Integer types (including corresponding reference types)
//
// Each scalar constraint accepts either a builtin MLIR integer type or the
// matching TF reference type. Note that TF_Bool is separate: it is not part of
// TF_SInt, TF_UInt, or TF_Int.
def TF_Bool : AnyTypeOf<[I<1>, TF_BoolRef], "bool">;
def TF_Int8 : AnyTypeOf<[I8, TF_Int8Ref], "8-bit integer">;
def TF_Int16 : AnyTypeOf<[I16, TF_Int16Ref], "16-bit integer">;
def TF_Int32 : AnyTypeOf<[I32, TF_Int32Ref], "32-bit integer">;
def TF_Int64 : AnyTypeOf<[I64, TF_Int64Ref], "64-bit integer">;
def TF_I32OrI64 : AnyTypeOf<[I32, I64, TF_Int32Ref, TF_Int64Ref],
"32/64-bit signed integer">;
def TF_Uint8 : AnyTypeOf<[UI<8>, TF_Uint8Ref], "8-bit unsigned integer">;
def TF_Uint16 : AnyTypeOf<[UI<16>, TF_Uint16Ref], "16-bit unsigned integer">;
def TF_Uint32 : AnyTypeOf<[UI<32>, TF_Uint32Ref], "32-bit unsigned integer">;
def TF_Uint64 : AnyTypeOf<[UI<64>, TF_Uint64Ref], "64-bit unsigned integer">;
// Any unsigned integer type
def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64],
"unsigned integer">;
// Any signed integer type
def TF_SInt : AnyTypeOf<[TF_Int8, TF_Int16, TF_Int32, TF_Int64],
"signed integer">;
// Any integer type
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt], "integer">;
// Tensor types: tensors whose element type satisfies the named constraint.
def TF_BoolTensor : TensorOf<[TF_Bool]>;
def TF_IntTensor : TensorOf<[TF_Int]>;
def TF_Int8Tensor : TensorOf<[TF_Int8]>;
def TF_Int16Tensor : TensorOf<[TF_Int16]>;
def TF_Int32Tensor : TensorOf<[TF_Int32]>;
def TF_Int64Tensor : TensorOf<[TF_Int64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32OrI64]>;
def TF_Uint8Tensor : TensorOf<[TF_Uint8]>;
def TF_Uint16Tensor : TensorOf<[TF_Uint16]>;
def TF_Uint32Tensor : TensorOf<[TF_Uint32]>;
def TF_Uint64Tensor : TensorOf<[TF_Uint64]>;
//===----------------------------------------------------------------------===//
// Quantized types (including corresponding reference types)
//
// Quantized dtypes are TF dialect types (`mlir::TF::Qint8Type`, etc.) rather
// than builtin MLIR types. Each constraint accepts the value type or its ref.
def TF_Qint8 : AnyTypeOf<
[TF_TensorFlowType<"Qint8", "qint8">, TF_Qint8Ref],
"8-bit quantized integer">;
def TF_Qint16 : AnyTypeOf<
[TF_TensorFlowType<"Qint16", "qint16">, TF_Qint16Ref],
"16-bit quantized integer">;
def TF_Qint32 : AnyTypeOf<
[TF_TensorFlowType<"Qint32", "qint32">, TF_Qint32Ref],
"32-bit quantized integer">;
def TF_Quint8 : AnyTypeOf<
[TF_TensorFlowType<"Quint8", "quint8">, TF_Quint8Ref],
"8-bit quantized unsigned integer">;
def TF_Quint16 : AnyTypeOf<
[TF_TensorFlowType<"Quint16", "quint16">, TF_Quint16Ref],
"16-bit quantized unsigned integer">;
// Any quantized type
def TF_Quantized : AnyTypeOf<
[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, TF_Quint16], "quantized">;
//===----------------------------------------------------------------------===//
// Floating-point types (including corresponding reference types)
//
// Each scalar constraint accepts the builtin float type or the TF ref type.
def TF_Float16 : AnyTypeOf<[F16, TF_Float16Ref], "16-bit float">;
def TF_Float32 : AnyTypeOf<[F32, TF_Float32Ref], "32-bit float">;
def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
// 32- or 64-bit float only (excludes f16 and bf16).
def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;
// Any floating-point type listed above.
def TF_Float : AnyTypeOf<
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16],
"floating-point">;
// Tensor types
def TF_FloatTensor : TensorOf<[TF_Float]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32OrF64]>;
def TF_Float16Tensor : TensorOf<[TF_Float16]>;
def TF_Float32Tensor : TensorOf<[TF_Float32]>;
def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.
//
// TF_Complex64 is a builtin complex of f32 (64 bits total) and TF_Complex128 a
// builtin complex of f64, each also accepting the matching TF ref type.
def TF_Complex64 : AnyTypeOf<[Complex<F<32>>, TF_Complex64Ref],
"64-bit complex">;
def TF_Complex128 : AnyTypeOf<[Complex<F<64>>, TF_Complex128Ref],
"128-bit complex">;
def TF_Complex : AnyTypeOf<[TF_Complex64, TF_Complex128], "complex">;
// Tensor types
def TF_ComplexTensor : TensorOf<[TF_Complex]>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types (including corresponding reference types)
//
// These are TF dialect types (`mlir::TF::StringType`, `mlir::TF::VariantType`,
// `mlir::TF::ResourceType`), each also accepting its ref type.
def TF_Str : AnyTypeOf<
[TF_TensorFlowType<"String", "str">, TF_StrRef], "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;
def TF_Variant : AnyTypeOf<
[TF_TensorFlowType<"Variant", "var">, TF_VariantRef], "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;
def TF_Resource : AnyTypeOf<
[TF_TensorFlowType<"Resource", "res">, TF_ResourceRef], "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Multi-category type constraints
//
// Unions of the categories above. Names list the accepted categories, e.g.
// TF_SintOrFpTensor accepts signed-integer or floating-point element types.
def TF_IntOrF32OrF64Tensor: TensorOf<[TF_Int, TF_F32OrF64]>;
def TF_FpOrI32OrI64Tensor : TensorOf<[TF_Float, TF_I32OrI64]>;
def TF_IntOrFpTensor : TensorOf<[TF_Int, TF_Float]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, TF_Float]>;
def TF_FpOrComplexTensor : TensorOf<[TF_Float, TF_Complex]>;
// Any numeric element type, including quantized and complex (excludes bool).
def TF_Number : AnyTypeOf<
[TF_Int, TF_Float, TF_Quantized, TF_Complex], "number">;
def TF_NumberTensor : TensorOf<[TF_Number]>;
// Numeric but not quantized. Note: among unsigned integers only TF_Uint8 is
// accepted here; 16/32/64-bit unsigned types are not.
def TF_NumberNotQuantizedTensor : TensorOf<
[TF_Float, TF_SInt, TF_Complex, TF_Uint8]>;
// As above, plus TF_Str.
def TF_NumberNotQuantizedOrStr :
AnyTypeOf<[TF_Float, TF_SInt, TF_Complex, TF_Uint8, TF_Str]>;
def TF_NumberNotQuantizedOrStrTensor : TensorOf<[TF_NumberNotQuantizedOrStr]>;
//===----------------------------------------------------------------------===//
// Tensor and tensor element types
// Any tensor element type allowed in TensorFlow ops
// (see https://www.tensorflow.org/api_docs/python/tf/dtypes/DType)
//
// The predicate is the OR of the float, complex, integer, and bool constraints
// plus TF_TFDialectType (any `mlir::TF::TensorFlowType`). Quantized, string,
// variant, and resource types are not listed explicitly. NOTE(review):
// presumably they are covered through TF_TFDialectType; verify that they
// derive from TensorFlowType.
def TF_ElementType : Type<Or<[TF_Float.predicate,
TF_Complex.predicate,
TF_Int.predicate,
TF_Bool.predicate,
TF_TFDialectType.predicate]>,
"tf.dtype">;
// Any TensorFlow tensor type
def TF_Tensor : TensorOf<[TF_ElementType]>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// TensorFlow devices metadata
// TensorFlow GPU device metadata: a dialect struct attribute with two i32
// fields holding the GPU compute capability.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
// GPU device compute capability: major:minor.
StructFieldAttr<"cc_major", I32Attr>,
StructFieldAttr<"cc_minor", I32Attr>
]>;
//===----------------------------------------------------------------------===//
// String attribute constraints
// A string attribute whose value is one of the strings in `cases`.
//
// The predicate compares the attribute's value against each case, joined with
// " || ". The summary lists the cases joined with ", or ", for example
// "string attribute whose value is NHWC, or NCHW". `cases` must be non-empty.
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
  CPred<!interleave(
    !foreach(option, cases,
             "$_self.cast<StringAttr>().getValue() == \"" # option # "\""),
    " || ")>,
  "string attribute whose value is " # !interleave(cases, ", or ")>;
// TODO: Use EnumAttr to define the common attribute cases
// String attribute restricted to exactly "NHWC" or "NCHW".
def TF_ConvnetDataFormatAttr : StringBasedAttr<
CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || " #
"$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
"'NHWC' or 'NCHW' convnet data format">;
//===----------------------------------------------------------------------===//
// Type attributes
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// operand. The count is computed as the distance across the ODS operand group
// and materialized as an i64 integer attribute.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<
"size_t",
"auto range = getODSOperands(" # idx # ");\n"
"return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared
// operand. If the `idx`-th operand is a variadic operand, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic operand has at least one tensor and the tensors all have
// the same element type.
// NOTE(review): the code dereferences `begin()` without an emptiness check, so
// an empty variadic group would dereference past the end. Callers must ensure
// the group is non-empty.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
"return mlir::getElementTypeOrSelf(*getODSOperands(" # idx # ").begin());">;
// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// operand. This returns a list of element types so it is used for variadic
// operands that can have different element types.
// Materialized as an ArrayAttr with one TypeAttr per operand.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
"mlir::OperandElementTypeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::OperandElementTypeIterator(values.begin()), "
"mlir::OperandElementTypeIterator(values.end())};",
[{
ArrayAttr::get($_ctx,
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}())
}]
>;
// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic operand.
// This returns a list of shapes so it is used for variadic operands that
// can have different shapes.
// Materialized as an ArrayAttr with one mlir::TF::ShapeAttr per operand.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
"::mlir::TF::OperandShapeRange",
"auto values = getODSOperands(" # idx # ");\n"
"return {mlir::TF::OperandShapeIterator(values.begin()), "
"mlir::TF::OperandShapeIterator(values.end())};",
[{
ArrayAttr::get($_ctx,
[&](){
llvm::SmallVector<Attribute, 4> ret;
for (auto shape : $_self)
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
return ret;
}())
}]
>;
// A derived attribute that returns the size of `idx`-th ODS-declared variadic
// result. The count is computed as the distance across the ODS result group and
// materialized as an i64 integer attribute.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<
"size_t",
"auto range = getODSResults(" # idx # ");\n"
"return std::distance(range.begin(), range.end());",
[{ $_builder.getI64IntegerAttr($_self) }]>;
// A derived attribute that returns the element type of `idx`-th ODS-declared
// result. If the `idx`-th result is a variadic result, then this attribute
// just returns the element type of its first tensor, which is only meaningful
// when the variadic result has at least one tensor and the tensors all have
// the same element type.
// NOTE(review): as with the operand variant, `begin()` is dereferenced without
// an emptiness check.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
"return mlir::getElementTypeOrSelf(*getODSResults(" # idx # ").begin());">;
// A derived attribute that returns the element types of the tensors in the
// actual value pack that corresponds to the `idx`-th ODS-declared variadic
// result. This returns a list of element types so it is used for variadic
// results that can have different element types.
// Materialized as an ArrayAttr with one TypeAttr per result.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
"mlir::ResultElementTypeRange",
"auto values = getODSResults(" # idx # ");\n"
"return {mlir::ResultElementTypeIterator(values.begin()), "
"mlir::ResultElementTypeIterator(values.end())};",
[{
ArrayAttr::get($_ctx,
[&]() {
llvm::SmallVector<Attribute, 4> ret;
for (auto t : $_self)
ret.push_back(TypeAttr::get(t));
return ret;
}())
}]
>;
// A derived attribute that returns the shapes of the tensors in the actual
// value pack that corresponds to the `idx`-th ODS-declared variadic result.
// This returns a list of shapes so it is used for variadic results that
// can have different shapes.
// Materialized as an ArrayAttr with one mlir::TF::ShapeAttr per result.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
"mlir::TF::ResultShapeRange",
"auto values = getODSResults(" # idx # ");\n"
"return {mlir::TF::ResultShapeIterator(values.begin()), "
"mlir::TF::ResultShapeIterator(values.end())};",
[{
ArrayAttr::get($_ctx,
[&](){
llvm::SmallVector<Attribute, 4> ret;
for (auto shape : $_self)
ret.push_back(mlir::TF::ShapeAttr::get($_ctx, shape));
return ret;
}())
}]
>;
// A derived attribute that returns the shape of the first result type.
// The first result's type is cast to ShapedType (so it is assumed to be
// shaped) and materialized as an mlir::TF::ShapeAttr.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
"return (*getOperation()->result_type_begin()).cast<ShapedType>();",
[{ mlir::TF::ShapeAttr::get($_ctx, $_self) }]>;
// Type attribute constrained to IntegerType; its C++ accessor returns a plain
// `Type`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
let returnType = "Type";
}
//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//
// Mixin class defining a builder for binary ops supporting broadcast
// behavior. The result type has the same element type as both operands.
//
// The builder takes the two operands `x` and `y` and infers the result type
// with `OpTrait::util::getBroadcastedType`. If the operands are not
// broadcast-compatible, it emits "non-broadcastable operands" at the op
// location. It then falls back to an unranked tensor of `x`'s element type, so
// the op is never constructed with a null result type.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
      Type resultType =
          OpTrait::util::getBroadcastedType(x.getType(), y.getType());
      if (!resultType) {
        mlir::emitError($_state.location, "non-broadcastable operands");
        // Keep the built op well-formed; the diagnostic above reports the
        // failure, and a null Type here would break later IR handling.
        resultType =
            UnrankedTensorType::get(mlir::getElementTypeOrSelf(x.getType()));
      }
      return build($_builder, $_state, resultType, x, y);
    }]>];
}
// Mixin class defining a builder for comparison ops supporting broadcast
// behavior. The result type has bool element type.
//
// The result is:
//   - an unranked i1 tensor if either operand is an unranked tensor;
//   - a ranked i1 tensor of the broadcast shape if both operands are ranked
//     and their shapes broadcast;
//   - otherwise an unranked i1 tensor, after emitting
//     "operands have no broadcastable shapes" at the op location. The builder
//     does not read getBroadcastedShape's output after a failed call.
// Ranked operands are cast to ShapedType, so they are assumed to be shaped.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<(ins "Value":$x, "Value":$y),
    [{
      Type i1Type = $_builder.getI1Type();
      // Default covers both the unranked-input case and the failure case.
      Type resultType = UnrankedTensorType::get(i1Type);
      if (!x.getType().isa<UnrankedTensorType>() &&
          !y.getType().isa<UnrankedTensorType>()) {
        SmallVector<int64_t, 4> resultShape;
        if (OpTrait::util::getBroadcastedShape(
                x.getType().cast<ShapedType>().getShape(),
                y.getType().cast<ShapedType>().getShape(), resultShape)) {
          resultType = RankedTensorType::get(resultShape, i1Type);
        } else {
          mlir::emitError($_state.location,
                          "operands have no broadcastable shapes");
        }
      }
      return build($_builder, $_state, resultType, x, y);
    }]>];
}
#endif // TF_OP_BASE