blob: 773025c58df31984dd9b325fe7522a1b4c9ff48e [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the base operation definition file for TensorFlow.
//
// This file includes the definition for the TensorFlow dialect, base TensorFlow
// op, and various commonly used TensorFlow traits, types, attributes, and
// builders.
#ifndef TF_OP_BASE
#define TF_OP_BASE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffects.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.td"
//===----------------------------------------------------------------------===//
// TensorFlow dialect definitions
//===----------------------------------------------------------------------===//
// The `tf` dialect. Generated C++ entities live in the `TF` namespace, which
// the rest of this file refers to as `mlir::TF`.
def TF_Dialect : Dialect {
  let cppNamespace = "TF";
  let name = "tf";
  let description = [{
The TensorFlow dialect.
This dialect maps to TensorFlow operations.
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimensional tensors);
TODO: Make invariants more structured so that we can reference them in ops.
}];
}
//===----------------------------------------------------------------------===//
// TensorFlow traits
//===----------------------------------------------------------------------===//
// Trait for ops whose results all share a single type, and whose operands each
// have either that same type or the ref type corresponding to it.
def TF_OperandsSameAsResultsTypeOrRef
    : NativeOpTrait<"TF::OperandsSameAsResultsTypeOrRef">;

// Trait for layout-agnostic ops, i.e. ops that do not depend on the data
// layout (data format) of their operands. For example, all element-wise ops
// are layout agnostic.
def TF_LayoutAgnostic
    : NativeOpTrait<"TF::LayoutAgnostic">;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//
// Base class for every op in the TensorFlow dialect: forwards `mnemonic` and
// the optional `traits` list to `Op`, bound to `TF_Dialect`.
class TF_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TF_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// TensorFlow type definitions
//===----------------------------------------------------------------------===//
// Matches any type defined by the TensorFlow dialect, i.e. anything that
// isa<mlir::TF::TensorFlowType>.
def TF_TFDialectType : Type<
    CPred<"$_self.isa<mlir::TF::TensorFlowType>()">,
    "TensorFlow type">;

// Constraint for one specific TensorFlow dialect type
// `mlir::TF::<typeName>Type`, described as "TensorFlow <typeDesc> type". It
// also mixes in BuildableType with the builder call
// `getType<mlir::TF::<typeName>Type>()`.
class TF_TensorFlowType<string typeName, string typeDesc>
    : Type<CPred<"$_self.isa<mlir::TF::" # typeName # "Type>()">,
           "TensorFlow " # typeDesc # " type">,
      BuildableType<"getType<mlir::TF::" # typeName # "Type>()">;
// Element types a TensorFlow op tensor may carry: any float, any signless or
// unsigned integer, any complex type, or a TensorFlow dialect type.
def TF_ElementType : Type<
    Or<[AnyFloat.predicate, AnySignlessInteger.predicate,
        AnyUnsignedInteger.predicate, AnyComplex.predicate,
        TF_TFDialectType.predicate]>,
    "tf.dtype">;

// Tensors whose element type is any of the above.
def TF_Tensor : TensorOf<[TF_ElementType]>;
//===----------------------------------------------------------------------===//
// Integer types
// 32- or 64-bit signless integer, and tensors of it.
def TF_I32Or64 : SignlessIntOfWidths<[32, 64]>;
def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>;
// Individual unsigned integer widths.
def TF_Uint8 : UI<8>;
def TF_Uint16 : UI<16>;
def TF_Uint32 : UI<32>;
def TF_Uint64 : UI<64>;
// Any unsigned integer type of width 8, 16, 32 or 64.
def TF_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>;
// Any *signless* integer type of width 8, 16, 32 or 64. Despite the "SInt"
// name, this does not match signed (si*) integers; presumably TensorFlow's
// signed integer dtypes are modeled as signless integers — confirm.
def TF_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>;
// Any integer type: signless or unsigned, of width 8, 16, 32 or 64.
def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>;
// Tensors of any integer type accepted by TF_Int.
def TF_IntTensor : TensorOf<[TF_Int]>;
//===----------------------------------------------------------------------===//
// Quantized types
// One constraint per TensorFlow dialect quantized dtype.
def TF_Qint8 : TF_TensorFlowType<"Qint8", "qint8">;
def TF_Qint16 : TF_TensorFlowType<"Qint16", "qint16">;
def TF_Qint32 : TF_TensorFlowType<"Qint32", "qint32">;
def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">;
def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">;

// Any one of the quantized dtypes above.
def TF_AnyQuantized : AnyTypeOf<[
  TF_Qint8, TF_Qint16, TF_Qint32,
  TF_Quint8, TF_Quint16
]>;
//===----------------------------------------------------------------------===//
// Floating-point types
// 32- or 64-bit float, and tensors of it.
def TF_F32Or64 : FloatOfWidths<[32, 64]>;
def TF_F32OrF64Tensor : TensorOf<[TF_F32Or64]>;

// Tensors of any floating-point element type.
def TF_FpTensor : TensorOf<[AnyFloat]>;
//===----------------------------------------------------------------------===//
// Complex types
// TODO(suderman): Remove TF_Complex64 and use a standard ops declaration, along
// with the associated cleanup.

// Complex numbers with f32 parts, and tensors of them.
def TF_Complex64 : Complex<F<32>>;
def TF_Complex64Tensor : TensorOf<[TF_Complex64]>;

// Complex numbers with f64 parts, and tensors of them.
def TF_Complex128 : Complex<F<64>>;
def TF_Complex128Tensor : TensorOf<[TF_Complex128]>;

// Either complex width above, and tensors of it.
def TF_AnyComplex : AnyTypeOf<
    [TF_Complex64, TF_Complex128],
    "64/128-bit complex type">;
def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>;
//===----------------------------------------------------------------------===//
// String/variant/resource types
// TensorFlow dialect string type, and tensors of it.
def TF_Str : TF_TensorFlowType<"String", "string">;
def TF_StrTensor : TensorOf<[TF_Str]>;

// TensorFlow dialect variant type, and tensors of it.
def TF_Variant : TF_TensorFlowType<"Variant", "variant">;
def TF_VariantTensor : TensorOf<[TF_Variant]>;

// TensorFlow dialect resource type, and tensors of it.
def TF_Resource : TF_TensorFlowType<"Resource", "resource">;
def TF_ResourceTensor : TensorOf<[TF_Resource]>;
//===----------------------------------------------------------------------===//
// Multi-category type constraints
// Tensors of any TF_Int integer, or of 32-/64-bit floats.
def TF_IntOrF32OrF64Tensor : TensorOf<[TF_Int, TF_F32Or64]>;
// Tensors of any float, or of 32-/64-bit signless integers.
def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
// Tensors of any TF_Int integer or any float.
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
// Tensors of any TF_SInt (signless) integer or any float.
def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
// Tensors of any float or 64/128-bit complex type.
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;

// Any numeric element type: integer, float, quantized, or 64/128-bit complex.
def TF_AnyNumber : AnyTypeOf<
    [TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],
    "number">;
def TF_NumberTensor : TensorOf<[TF_AnyNumber]>;

// Floats, signless integers, 64/128-bit complex, uint8, or TensorFlow strings.
// NOTE(review): uint8 is the only unsigned width accepted here — confirm that
// this is intentional.
def TF_NumberOrStr : AnyTypeOf<
    [AnyFloat, TF_SInt, TF_AnyComplex, TF_Uint8, TF_Str]>;
def TF_NumberOrStrTensor : TensorOf<[TF_NumberOrStr]>;
//===----------------------------------------------------------------------===//
// TensorFlow attribute definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Tensorflow devices metadata
// Tensorflow GPU device metadata: a struct attribute whose two I32 fields hold
// the device's compute capability as major:minor.
def TF_GpuDeviceMetadata : StructAttr<"GpuDeviceMetadata", TF_Dialect, [
  StructFieldAttr<"cc_major", I32Attr>,  // Compute capability, major part.
  StructFieldAttr<"cc_minor", I32Attr>   // Compute capability, minor part.
]>;
//===----------------------------------------------------------------------===//
// String attribute constraints
// A string attribute whose value is one of the strings in `cases`.
//
// Both the predicate and the description are built by folding over `cases`,
// seeded with `!head(cases)`, so `cases` presumably has to be non-empty.
// For cases ["a", "b"] the predicate is
//   $_self.cast<StringAttr>().getValue() == "a" ||
//   $_self.cast<StringAttr>().getValue() == "b"
// and the description is "string attribute whose value is a, or b".
class TF_AnyStrAttrOf<list<string> cases> : StringBasedAttr<
CPred<!foldl(
"$_self.cast<StringAttr>().getValue() == \"" # !head(cases) # "\"",
!foreach(case, !tail(cases),
"$_self.cast<StringAttr>().getValue() == \"" # case # "\""),
prev, cur, prev # " || " # cur)>,
"string attribute whose value is " #
!foldl(/*init*/!head(cases), /*list*/!tail(cases),
prev, cur, prev # ", or " # cur)>;
// TODO: Use EnumAttr to define the common attribute cases
// String attribute restricted to the two convnet data formats "NHWC" and
// "NCHW". The two adjacent literals below are joined by TableGen into one
// predicate string.
def TF_ConvnetDataFormatAttr : StringBasedAttr<
    CPred<"$_self.cast<StringAttr>().getValue() == \"NHWC\" || "
          "$_self.cast<StringAttr>().getValue() == \"NCHW\"">,
    "'NHWC' or 'NCHW' convnet data format">;
//===----------------------------------------------------------------------===//
// Type attributes
// Derived attribute (C++ type `size_t`): the number of values in the `idx`-th
// ODS-declared (variadic) operand group.
class TF_DerivedOperandSizeAttr<int idx> : DerivedAttr<"size_t",
  "auto groupValues = getODSOperands(" # idx # ");\n"
  "return std::distance(groupValues.begin(), groupValues.end());">;
// Derived type attribute: the element type of the first value in the `idx`-th
// ODS-declared operand group. For a variadic group this looks only at the
// first tensor. The result is therefore meaningful only if the group is
// non-empty and all of its tensors share one element type.
class TF_DerivedOperandTypeAttr<int idx> : DerivedTypeAttr<
  "auto groupValues = getODSOperands(" # idx # ");\n"
  "return mlir::getElementTypeOrSelf(*groupValues.begin());">;
// Derived attribute (C++ type `mlir::OperandElementTypeRange`): the element
// type of every value in the `idx`-th ODS-declared variadic operand group. Use
// this for variadic operands whose tensors may have different element types.
class TF_DerivedOperandTypeListAttr<int idx> : DerivedAttr<
  "mlir::OperandElementTypeRange",
  "auto groupValues = getODSOperands(" # idx # ");\n"
  "return {mlir::OperandElementTypeIterator(groupValues.begin()), "
  "mlir::OperandElementTypeIterator(groupValues.end())};">;
// Derived attribute (C++ type `mlir::TF::OperandShapeRange`): the shape of
// every value in the `idx`-th ODS-declared variadic operand group. Use this for
// variadic operands whose tensors may have different shapes.
class TF_DerivedOperandShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::OperandShapeRange",
  "auto groupValues = getODSOperands(" # idx # ");\n"
  "return {mlir::TF::OperandShapeIterator(groupValues.begin()), "
  "mlir::TF::OperandShapeIterator(groupValues.end())};">;
// Derived attribute (C++ type `size_t`): the number of values in the `idx`-th
// ODS-declared (variadic) result group.
class TF_DerivedResultSizeAttr<int idx> : DerivedAttr<"size_t",
  "auto groupValues = getODSResults(" # idx # ");\n"
  "return std::distance(groupValues.begin(), groupValues.end());">;
// Derived type attribute: the element type of the first value in the `idx`-th
// ODS-declared result group. For a variadic group this looks only at the first
// tensor. The result is therefore meaningful only if the group is non-empty
// and all of its tensors share one element type.
class TF_DerivedResultTypeAttr<int idx> : DerivedTypeAttr<
  "auto groupValues = getODSResults(" # idx # ");\n"
  "return mlir::getElementTypeOrSelf(*groupValues.begin());">;
// Derived attribute (C++ type `mlir::ResultElementTypeRange`): the element type
// of every value in the `idx`-th ODS-declared variadic result group. Use this
// for variadic results whose tensors may have different element types.
class TF_DerivedResultTypeListAttr<int idx> : DerivedAttr<
  "mlir::ResultElementTypeRange",
  "auto groupValues = getODSResults(" # idx # ");\n"
  "return {mlir::ResultElementTypeIterator(groupValues.begin()), "
  "mlir::ResultElementTypeIterator(groupValues.end())};">;
// Derived attribute (C++ type `mlir::TF::ResultShapeRange`): the shape of every
// value in the `idx`-th ODS-declared variadic result group. Use this for
// variadic results whose tensors may have different shapes.
class TF_DerivedResultShapeListAttr<int idx> : DerivedAttr<
  "mlir::TF::ResultShapeRange",
  "auto groupValues = getODSResults(" # idx # ");\n"
  "return {mlir::TF::ResultShapeIterator(groupValues.begin()), "
  "mlir::TF::ResultShapeIterator(groupValues.end())};">;
// Derived attribute (C++ type `ShapedType`): the type of the op's first result,
// cast to ShapedType. The generated code assumes the op has at least one
// result.
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType",
  "auto firstResultType = *getOperation()->result_type_begin();\n"
  "return firstResultType.cast<ShapedType>();">;
// Derived type attribute for the resource-typed operand or result whose
// accessor is `valueName`. It returns the element type of the first subtype
// recorded on that value's TF::ResourceType. The subtype list is only checked
// by `assert`, so release builds assume it is non-empty.
class TF_DerivedOperandOrResultHandleTypeAttr<string valueName>
    : DerivedTypeAttr<
  "auto handleType = mlir::getElementTypeOrSelf(this->" # valueName # "())"
  ".cast<TF::ResourceType>();\n"
  "auto subtypes = handleType.getSubtypes();\n"
  "assert(!subtypes.empty() && \"unknown type\");\n"
  "return mlir::getElementTypeOrSelf(*subtypes.begin());">;
// Derived attribute (C++ type `ShapedType`) for the resource-typed operand or
// result whose accessor is `valueName`. It returns the first subtype recorded
// on that value's TF::ResourceType, cast to ShapedType. The subtype list is
// only checked by `assert`, so release builds assume it is non-empty.
class TF_DerivedOperandOrResultHandleShapeAttr<string valueName>
    : DerivedAttr<"ShapedType",
  "auto handleType = mlir::getElementTypeOrSelf(this->" # valueName # "())"
  ".cast<TF::ResourceType>();\n"
  "auto subtypes = handleType.getSubtypes();\n"
  "assert(!subtypes.empty() && \"unknown shape\");\n"
  "return subtypes.begin()->cast<ShapedType>();">;
// Type attribute constrained via TypeAttrBase to `IntegerType` ("integer
// type"). The generated accessor returns the generic `Type`, not
// `IntegerType`.
def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> {
  let returnType = "Type";
}
//===----------------------------------------------------------------------===//
// TensorFlow common builders
//===----------------------------------------------------------------------===//
// Mixin that gives a binary op a `build(builder, result, x, y)` overload. The
// overload infers the result type with OpTrait::util::getBroadcastedType from
// the two operand types, then forwards to the explicit-result-type builder.
class WithBroadcastableBinOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<"Builder *builder, OperationState &result, Value x, Value y", [{
  Type broadcastType =
      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
  // NOTE(review): if the operands do not broadcast, an error is emitted but
  // the op is still built, with a null result type.
  if (!broadcastType)
    mlir::emitError(result.location, "non-broadcastable operands");
  return build(builder, result, broadcastType, x, y);
    }]>
  ];
}
// Mixin that gives a comparison op a `build(builder, result, x, y)` overload
// producing an i1 tensor:
//  - if either operand type is unranked, the result is an unranked i1 tensor;
//  - otherwise the result is a ranked i1 tensor whose shape comes from
//    OpTrait::util::getBroadcastedShape on the two operand shapes.
class WithBroadcastableCmpOpBuilder {
  list<OpBuilder> builders = [
    OpBuilder<"Builder *builder, OperationState &result, Value x, Value y", [{
  Type boolType = builder->getI1Type();
  bool eitherUnranked = x.getType().isa<UnrankedTensorType>() ||
                        y.getType().isa<UnrankedTensorType>();
  Type resultType;
  if (eitherUnranked) {
    resultType = UnrankedTensorType::get(boolType);
  } else {
    auto lhsShape = x.getType().cast<ShapedType>().getShape();
    auto rhsShape = y.getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> broadcastShape;
    bool broadcastable =
        OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, broadcastShape);
    // NOTE(review): on failure an error is emitted but the op is still built
    // with whatever shape getBroadcastedShape left behind.
    if (!broadcastable)
      mlir::emitError(result.location,
                      "operands have no broadcastable shapes");
    resultType = RankedTensorType::get(broadcastShape, boolType);
  }
  return build(builder, result, resultType, x, y);
    }]>
  ];
}
#endif // TF_OP_BASE