blob: 298f962d09617c5cee32f6d1aaa4ae04bf8cd4a2 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for TensorFlow Lite.
#ifdef TFL_OPS
#else
#define TFL_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
include "mlir/Dialect/QuantOps/QuantPredicates.td"
// The TFLite dialect definition. Ops in this file are registered under the
// `tfl` prefix and their generated C++ is placed in namespace `TFL`
// (cppNamespace).
def TFL_Dialect : Dialect {
  let name = "tfl";

  let description = [{
The TensorFlow Lite dialect.
This dialect maps to TensorFlow Lite operations.
Invariants:
* All values are of Tensor type (in particular, scalars are
represented using zero-dimensional tensors).
}];

  let cppNamespace = "TFL";
}
//===----------------------------------------------------------------------===//
// TFLite dialect string type - uses the TF string type as implementation
//===----------------------------------------------------------------------===//
// String element type for TFLite tensors, reusing the TF dialect's string type:
// the predicate accepts mlir::TF::StringType and buildable instances are made
// with getType<mlir::TF::StringType>().
def TFL_Str
    : Type<CPred<"$_self.isa<mlir::TF::StringType>()">, "TFLite string type">,
      BuildableType<"getType<mlir::TF::StringType>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect uint8 type - uses the TF uint8 type as implementation
//===----------------------------------------------------------------------===//
// Unsigned 8-bit element type, backed by mlir::TF::Uint8Type the same way.
def TFL_Uint8
    : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">, "TFLite uint8 type">,
      BuildableType<"getType<mlir::TF::Uint8Type>()">;
//===----------------------------------------------------------------------===//
// Activation function enum definitions.
//===----------------------------------------------------------------------===//
// Allowed fused activation function cases. The string values must stay in
// sync with the ActivationFunctionType enum in the TFLite schema.
def TFL_AF_None  : StrEnumAttrCase<"NONE">;
def TFL_AF_Relu  : StrEnumAttrCase<"RELU">;
def TFL_AF_Relu1 : StrEnumAttrCase<"RELU_N1_TO_1">;
def TFL_AF_Relu6 : StrEnumAttrCase<"RELU6">;
def TFL_AF_Tanh  : StrEnumAttrCase<"TANH">;
def TFL_AF_Sign  : StrEnumAttrCase<"SIGN_BIT">;

// String-enum attribute accepting exactly the cases above.
def TFL_AFAttr : StrEnumAttr<"ActivationFunctionType", "fused activation enum",
    [TFL_AF_None, TFL_AF_Relu, TFL_AF_Relu1, TFL_AF_Relu6, TFL_AF_Tanh,
     TFL_AF_Sign]>;
//===----------------------------------------------------------------------===//
// Padding enum definitions.
//===----------------------------------------------------------------------===//
// Allowed padding cases. The string values should match the padding enums in
// the TFLite schema.
def TFL_PAD_Same  : StrEnumAttrCase<"SAME">;
def TFL_PAD_Valid : StrEnumAttrCase<"VALID">;

def TFL_MIRRORPAD_Reflect   : StrEnumAttrCase<"REFLECT">;
def TFL_MIRRORPAD_Symmetric : StrEnumAttrCase<"SYMMETRIC">;

// Regular (SAME / VALID) padding attribute.
def TFL_PaddingAttr : StrEnumAttr<"Padding", "padding enum",
    [TFL_PAD_Same, TFL_PAD_Valid]>;

// Mirror-pad mode attribute (REFLECT / SYMMETRIC).
// NOTE(review): this reuses the enum name "Padding" from TFL_PaddingAttr; if
// C++ enum declarations are ever generated from these attrs the names would
// collide — confirm before relying on enum generation.
def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum",
    [TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric]>;
//===----------------------------------------------------------------------===//
// Min-max range pair definitions.
//===----------------------------------------------------------------------===//
// A pair of floating point values which defines the min and max of a value
// range for quantization. The attribute is allowed to be empty or to have
// exactly 2 elements (element types are not checked here).
//
// The predicate first checks that the attribute is an ArrayAttr, so any other
// attribute kind is rejected by the verifier instead of hitting the assertion
// inside `cast<ArrayAttr>()`. The size checks only run after that check
// succeeds (the conjunction short-circuits).
def MinMaxAttr : Attr<And<[
    CPred<"$_self.isa<ArrayAttr>()">,
    Or<[CPred<"$_self.cast<ArrayAttr>().size() == 0">,
        CPred<"$_self.cast<ArrayAttr>().size() == 2">]>]>,
    "min-max range pair"> {
  let storageType = [{ ArrayAttr }];
  let returnType = [{ ArrayRef<Attribute> }];
}
//===----------------------------------------------------------------------===//
// QuantizedType definitions.
//===----------------------------------------------------------------------===//
// The base class of a quantized type.
//
// Template arguments:
//   n      - name of the quantized type family (e.g. "Uniform"); stored in
//            `name` and spliced into the TFL::FixedResult<name>Scale trait
//            name by TFL_FixedResultScale.
//   params - integer parameters; the first element is the storage bit width
//            checked by the predicate, and the whole list is joined into
//            `asTraitArgsStr`.
//   signed - 1 for signed storage, 0 for unsigned; selects "QI" vs "QUI" in
//            the type description and appends ", true" / ", false" to
//            `asTraitArgsStr`.
//
// NOTE(review): the predicate only checks that the type is a
// mlir::quant::QuantizedType with the given storage width; `signed` is not
// part of it, so e.g. TFL_QI8 and TFL_QUI8 accept the same set of types.
// Confirm that this is intended.
class TFL_QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
"Q" # !if (signed, "I", "UI") # !head(params) # " type"> {
// Quantized type family name (template argument `n`).
string name = n;
// Comma-separated `params` followed by the signedness flag; used as the
// template argument list of the generated native trait.
string asTraitArgsStr =
StrJoinInt<params>.result # !if(signed, ", true", ", false");
}
// Uniform quantized types. The floating-point scale is encoded with two
// integers, "smantissa" and "sexp", so that scale = smantissa * 10 ^ sexp
// (the parameter list of TFL_QuantizedType is a list<int>).
//
// Parameters forwarded to TFL_QuantizedType, in order: storage width (8),
// zero point, scale mantissa, scale exponent, then what appears to be the
// storage range (0..255 for unsigned, -128..127 for signed).
class TFL_UInt8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
    : TFL_QuantizedType<"Uniform",
        [8, zero_pt, smantissa, sexp, 0, 255], 0>;

class TFL_Int8UniformQuantizedType<int zero_pt, int smantissa, int sexp>
    : TFL_QuantizedType<"Uniform",
        [8, zero_pt, smantissa, sexp, -128, 127], 1>;

// 8-bit quantized element types for tensor type constraints:
// unsigned (QUI8) and signed (QI8).
def TFL_QUI8 : TFL_QuantizedType<"Uniform", [8], 0>;
def TFL_QI8  : TFL_QuantizedType<"Uniform", [8], 1>;
//===----------------------------------------------------------------------===//
// TensorType attribute definitions.
//===----------------------------------------------------------------------===//
// Type attribute whose value is a TensorType.
def TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
//===----------------------------------------------------------------------===//
// Derived shape attribute class.
//===----------------------------------------------------------------------===//
// Derived attribute producing an ArrayRef<int64_t> from the C++ in `body`.
class DerivedShapeAttr<code body> : DerivedAttr<"ArrayRef<int64_t>", body>;

// Derived attribute producing a tflite::TensorType from the C++ in `body`.
class DerivedTFLiteTypeAttr<code body>
    : DerivedAttr<"tflite::TensorType", body>;

// Shared tensor type constraints.
def TFL_Int32Or64      : IntOfWidths<[32, 64]>;
def TFL_FpTensor       : TensorOf<[AnyFloat]>;
def TFL_I32OrI64Tensor : TensorOf<[TFL_Int32Or64]>;
def TFL_BoolTensor     : TypeAlias<I1Tensor>;

// TODO(jpienaar): Expand to all int types.
// NOTE: despite the description, this currently only admits i32 tensors.
def TFL_IntTensor : TypeAlias<I32Tensor, "tensor of any integer type">;

// Type of "ref tensors", i.e. tensors used as variables that track state.
def TFL_StatefulTensor : TypeAlias<AnyTensor, "stateful tensor">;
// Either a tensor whose element type is one of `allowedTypes`, or NoneType
// (e.g. for an optional operand such as a bias).
class TFL_TensorOfOrNone<list<Type> allowedTypes, string description = "">
    : AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;

// Type constraint: the type of operand `idx` is NOT `type`.
// TODO(b/131936589): Once this bug is fixed, we should be able to use
// Neg<TCopVTEtIs<idx, NoneType>> and can remove this.
class TFL_TCopIsNot<int idx, Type type>
    : Neg<CPred<"$_op.getOperand(" # idx # ")->getType().isa<" # type # ">()">>;

// Tensor of any float, i32 or i64 element type.
def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
//===----------------------------------------------------------------------===//
// Rank/Shape helpers.
//===----------------------------------------------------------------------===//
// TODO: Some of these could be generalized and/or moved to more general
// location.
// Op trait: operand `n` is either unranked or has rank exactly `m`.
class TFL_OperandHasRank<int n, int m>
    : PredOpTrait<"operand " # n # " is " # m # "-D",
        Or<[
          CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
          CPred<"$_op.getOperand(" # n #
                ")->getType().cast<ShapedType>().getRank() == " # m>
        ]>>;
// Op trait: operand `n` is either unranked or has rank of at least `m`.
// The diagnostic says "at least" so that it matches the `>=` check (it
// previously read "is m-D", which misdescribes e.g. a rank-3 operand as
// invalid for m = 1).
class TFL_OperandHasAtleastRank<int n, int m> :
  PredOpTrait<"operand " # n # " is at least " # m # "-D",
    Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
        CPred<"$_op.getOperand(" # n #
              ")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
// Op trait: the rank of operand `x` equals the size of the first (only)
// dimension of operand `y`.
//
// The check passes when it cannot be decided statically: either operand is
// unranked, or `y` does not have a fully static shape (its dimension could be
// dynamic). A 0-D `y` fails the check instead of reading getShape()[0] out of
// bounds. Each disjunct only runs if the previous ones are false, so getRank()
// and getShape() are never called on an unranked type.
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
  PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
    Or<[
      CPred<"$_op.getOperand(" # x # ")->getType().isa<UnrankedTensorType>()">,
      CPred<"$_op.getOperand(" # y # ")->getType().isa<UnrankedTensorType>()">,
      CPred<"!$_op.getOperand(" # y #
            ")->getType().cast<ShapedType>().hasStaticShape()">,
      CPred<"!$_op.getOperand(" # y #
            ")->getType().cast<ShapedType>().getShape().empty() && "
            "$_op.getOperand(" # x #
            ")->getType().cast<ShapedType>().getRank() == "
            "$_op.getOperand(" # y #
            ")->getType().cast<ShapedType>().getShape()[0]">]>>;
// This is a quantization-aware version of TCresVTEtIsSameAsOp.
//
// Holds when TCOpResIsShapedTypePred<i, j> holds AND either:
//   * TCresVTEtIsSameAsOpBase<i, j> holds (the element type of result `i`
//     matches that of operand `j`), or
//   * the element type of operand `j` is a mlir::quant::QuantizedType and
//     the quantization storage types (castToStorageType) of result `i` and
//     operand `j` are equal.
// This lets a result differ from the operand only in quantization
// parameters while keeping the same storage type.
// NOTE(review): the exact checks of TCOpResIsShapedTypePred and
// TCresVTEtIsSameAsOpBase live in OpBase.td; confirm there.
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
TCOpResIsShapedTypePred<i, j>,
Or<[
TCresVTEtIsSameAsOpBase<i, j>,
And<[
SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # j # "))",
quant_QuantizedType.predicate>,
CPred<"quant::QuantizedType::castToStorageType("
"getElementTypeOrSelf($_op.getResult(" # i # "))) == "
"quant::QuantizedType::castToStorageType("
"getElementTypeOrSelf($_op.getOperand(" # j # ")))">]>]>]>;
//===----------------------------------------------------------------------===//
// TFL op common constraints.
//===----------------------------------------------------------------------===//
// Shared constraint for most binary ops (add, mul, div, ...): operand 0 (lhs)
// and operand 1 (rhs) must have the same element type.
def BinaryOpSameElementTypeConstraint
    : PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<0, 1>>;
//===----------------------------------------------------------------------===//
// TFL common builders.
//===----------------------------------------------------------------------===//
// Builder for broadcastable binary ops that take no extra attributes.
// Adds `lhs` and `rhs` as operands. The single result type is computed with
// OpTrait::util::getBroadcastedType(lhs type, rhs type).
//
// NOTE(review): if the operand types are not broadcast-compatible, the builder
// emits "non-broadcastable operands" but still adds the operands and pushes the
// null result type. Callers presumably must not build such ops; confirm
// whether this should bail out instead.
def TFL_BroadcastableBinaryBuilder : OpBuilder<
"Builder *builder, OperationState *result, Value *lhs, Value *rhs",
[{
auto resultType =
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
if (!resultType)
mlir::emitError(result->location, "non-broadcastable operands");
result->addOperands({lhs, rhs});
result->types.push_back(resultType);
}]>;
// Builder for broadcastable binary ops that also carry a fused activation
// function attribute. All work is delegated to the C++ helper
// buildFusedBroadcastableBinOp.
def TFL_FusedBroadcastableBinaryBuilder : OpBuilder<
  "Builder *builder, OperationState *result, Value *lhs, Value *rhs, "
  "StringAttr fusedActivationFunction",
  [{
    buildFusedBroadcastableBinOp(builder, result, lhs, rhs,
                                 fusedActivationFunction);
  }]>;

// Builder for comparison ops taking only lhs and rhs. Delegates to the C++
// helper buildComparisonBinOp.
def TFL_ComparisonBinaryBuilder : OpBuilder<
  "Builder *builder, OperationState *result, Value *lhs, Value *rhs",
  [{ buildComparisonBinOp(builder, result, lhs, rhs); }]>;
//===----------------------------------------------------------------------===//
// TFL native op traits (for quantization).
//
// Ops in this link should have those traits specified:
// https://www.tensorflow.org/lite/performance/quantization_spec
//===----------------------------------------------------------------------===//
// Specify this trait if the op has a fixed output value range.
// Expands to the native trait
//   TFL::FixedResult<qt.name>Scale<qt.asTraitArgsStr>::Impl
// so the fixed quantization parameters of `qt` become C++ template arguments.
class TFL_FixedResultScale<TFL_QuantizedType qt> : NativeOpTrait<!strconcat(
"TFL::FixedResult", qt.name, "Scale<", qt.asTraitArgsStr, ">::Impl")>;
// Specify this trait if the op requires same inputs and outputs quantization
// scales.
def TFL_SameOperandsAndResultsScale : NativeOpTrait<
"TFL::SameOperandsAndResultsScale">;
// Specify this trait if the b-th input of the op is a bias input, which needs
// a scale based on the scales of op1 and op2.
// Expands to TFL::AccumulatorUniformScale<bias, op1, op2>::Impl.
class TFL_AccumulatorUniformScale<int bias, int op1, int op2> : NativeOpTrait<
!strconcat("TFL::AccumulatorUniformScale<",
StrJoinInt<[bias, op1, op2]>.result,
">::Impl")>;
// Specify this trait if the op doesn't have a quantizable output. We shouldn't
// apply quantization on this op.
def TFL_NoQuantizableResult : NativeOpTrait<"TFL::NoQuantizableResult">;
//===----------------------------------------------------------------------===//
// TFL native op trait for stateful operands: the parameterized native trait
// TFL::StatefulOperands, with the operand indices in `operands` joined into
// its parameter string.
class StatefulOperands<list<int> operands>
    : ParamNativeOpTrait<"TFL::StatefulOperands",
                         StrJoinInt<operands>.result>;
//===----------------------------------------------------------------------===//
// TFL op base class.
//===----------------------------------------------------------------------===//
// Base class of every TFLite dialect op defined with TFL_Op.
class TFL_Op<string mnemonic, list<OpTrait> traits = []>
    : Op<TFL_Dialect, mnemonic, traits> {
  // FlatBuffer generation specific information.
  // -------------------------------------------
  // When serialized to a FlatBuffer, some operations carry Options (as defined
  // in the TFLite schema). These Options are effectively the op's attributes,
  // e.g. the padding used by a pooling operator. Not every op has Options, and
  // some ops share one Options type. The two fields below describe whether and
  // how the Options are emitted.

  // Whether the TFLite operator has options in the schema representation.
  bit hasOptions = 0;

  // Custom Options type for operators whose Options name does not match the
  // operator name. When left unset and hasOptions is set, <name>Options is
  // used.
  string customOption = ?;
}
// Shared base for convolution ops. Operand 2 (bias) gets a quantization scale
// derived from operands 0 (input) and 1 (filter) via
// TFL_AccumulatorUniformScale<2, 0, 1>. Serialized with FlatBuffer Options.
// NOTE(review): the description calls the bias optional, but it is typed
// AnyTensor, which does not admit NoneType (compare TFL_TensorOfOrNone).
class TFL_ConvOp<string mnemonic, string opSummary>
    : TFL_Op<mnemonic, [NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
  let summary = opSummary # " operator";
  let description = [{
Performs convolution operation on inputs.
Inputs:
`inputs[0]`: required: the input activation tensor
`inputs[1]`: required: the filter weight tensor
`inputs[2]`: optional: the bias tensor
}];
  let hasOptions = 1;

  let arguments = (ins
    AnyTensor:$input,
    AnyTensor:$filter,
    AnyTensor:$bias,
    I32Attr:$dilation_h_factor,
    I32Attr:$dilation_w_factor,
    TFL_AFAttr:$fused_activation_function,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w
  );
  let results = (outs AnyTensor:$output);
}
//===----------------------------------------------------------------------===//
// TFL op definitions.
//===----------------------------------------------------------------------===//
// Element-wise absolute value; operand and result share one type.
def TFL_AbsOp : TFL_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Absolute value operator";
  let description = [{
Given a tensor `x`, this operation returns a tensor containing the absolute
value of each element in `x`. For example, if x is an input element and y is
an output element, this operation computes \\(y = |x|\\).
}];
  let results = (outs AnyTensor:$y);
  let arguments = (ins AnyTensor:$x);
}
// Broadcastable, commutative element-wise addition with a fused activation.
// Has a folder and uses the shared binary-op custom assembly format.
def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> {
  let summary = "Addition operator";
  let description = [{
Element-wise addition operation.
}];
  let hasOptions = 1;
  let hasFolder = 1;

  let arguments = (ins
    AnyTensor:$lhs,
    AnyTensor:$rhs,
    TFL_AFAttr:$fused_activation_function
  );
  let results = (outs AnyTensor:$output);

  let builders = [TFL_FusedBroadcastableBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// TODO(haoliang): Implement legalization pass after pattern rewrite generator
// supports variadic inputs.
// Sums a variadic list of f32 / i32 tensors element-wise.
def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> {
  let summary = "add_n operator";
  let description = [{
Adds all input tensors element-wise.
}];
  let results = (outs TensorOf<[F32, I32]>:$sum);
  let arguments = (ins Variadic<TensorOf<[F32, I32]>>:$inputs);
}
// 2-D average pooling. Serialized with the shared Pool2DOptions table;
// operands and results share the quantization scale.
def TFL_AveragePool2DOp : TFL_Op<"average_pool_2d",
    [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
  let summary = "Average_pool_2d operator";
  let description = [{
Performs average-pooling operation on input.
}];
  let hasOptions = 1;
  let customOption = "Pool2DOptions";

  let arguments = (ins
    AnyTensor:$input,
    I32Attr:$filter_height,
    I32Attr:$filter_width,
    TFL_PaddingAttr:$padding,
    I32Attr:$stride_h,
    I32Attr:$stride_w,
    TFL_AFAttr:$fused_activation_function
  );
  let results = (outs AnyTensor:$output);
}
// Index of the maximum along `dim`. The serialized output_type is INT64 when
// the result element width exceeds 32 bits, otherwise INT32.
def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
  let summary = "ArgMax operator";
  let description = [{
Returns the index with the largest value across dimensions of a tensor.
}];
  let hasOptions = 1;

  let arguments = (ins
    // TODO: Add support for uint8.
    TensorOf<[F32, I32, I8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );
  let results = (outs TFL_I32OrI64Tensor:$output);

  DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
    return getResult()->getType().cast<TensorType>().getElementType()
                   .cast<IntegerType>().getWidth() > 32
               ? tflite::TensorType_INT64
               : tflite::TensorType_INT32;
  }]>;
}
// Index of the minimum along `dim`. The serialized output_type is INT64 when
// the result element width exceeds 32 bits, otherwise INT32.
// The description previously had a stray `"` and an unfenced example; the
// example is now fenced and its result stated.
def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
  let summary = "ArgMin operator";
  let description = [{
Returns the index with the smallest value across dimensions of a tensor.
For example:
```
a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmin(input = a)
c = tf.keras.backend.eval(b)
# c = 0
```
}];
  let arguments = (
    // TODO(pkanwar): Add support for uint8.
    ins TensorOf<[F32, I32, I8]>:$input,
    TFL_I32OrI64Tensor:$dim
  );
  let results = (outs
    TFL_I32OrI64Tensor:$output
  );
  let hasOptions = 1;
  DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
    return getResult()->getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
        tflite::TensorType_INT32;
  }]>;
}
// Element-wise ceiling over floating-point tensors.
def TFL_CeilOp : TFL_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Ceil operator";
  let description = [{
Returns element-wise ceil value of the input.
}];
  let results = (outs TFL_FpTensor:$y);
  let arguments = (ins TFL_FpTensor:$x);
}
// Concatenates a variadic list of tensors along `axis`. The result element
// type must match operand 0's (quantization-aware check), and operands and
// result share one quantization scale.
def TFL_ConcatenationOp : TFL_Op<"concatenation", [
    NoSideEffect,
    PredOpTrait<"values and output must have same element type",
      TFL_TCresVTEtIsSameAsOp<0, 0>>,
    TFL_SameOperandsAndResultsScale
  ]> {
  let summary = "Concatenation operator";
  let description = [{
Concatenates tensors along one dimension
}];
  let hasOptions = 1;

  let arguments = (ins
    Variadic<TensorOf<
      [F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>>:$values,
    I32Attr:$axis,
    TFL_AFAttr:$fused_activation_function
  );
  let results = (outs
    TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output
  );
}
// Constant pseudo op carrying its payload in the `value` attribute, with the
// FirstAttrDerivedResultType trait. Defined directly on Op (not TFL_Op), so it
// has no FlatBuffer Options fields.
def TFL_ConstOp : Op<TFL_Dialect, "pseudo_const",
    [NoSideEffect, FirstAttrDerivedResultType]> {
  let summary = "Constant pseudo op.";
  let description = [{
Represents a constant value in TensorFlow Lite dialect. This is not an
actual operation and it will be lowered to buffer instead.
The op is allowed to have all the same type of attributes as tf.Const does
(e.g., opaque TF attributes are allowed).
}];
  let results = (outs AnyTensor:$output);
  let arguments = (ins ElementsAttr:$value);
}
// Plain 2-D convolution; everything comes from TFL_ConvOp.
def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution"> {}
// Element-wise cosine over floating-point tensors.
def TFL_CosOp : TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Cosine operator";
  let description = [{
Computes element-wise Cosine of input
}];
  let results = (outs TFL_FpTensor:$y);
  let arguments = (ins TFL_FpTensor:$x);
}
// Depthwise convolution: all TFL_Conv2DOp arguments plus a trailing
// `depth_multiplier` attribute.
def TFL_DepthwiseConv2DOp
    : TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution"> {
  let arguments =
      !con(TFL_Conv2DOp.arguments, (ins I32Attr:$depth_multiplier));
}
// Weight layouts accepted by fully_connected's `weights_format` attribute.
def TFL_FCWO_Default         : StrEnumAttrCase<"DEFAULT">;
def TFL_FCWO_Shuffled4x16i8  : StrEnumAttrCase<"SHUFFLED4x16INT8">;

// NOTE(review): the enum name is spelled "FullyConectedOptionsWeightsFormat"
// (missing 'n'); it is left unchanged because generated code may depend on it.
def TFL_FullyConnectedOptionsWeightFormatAttr
    : StrEnumAttr<"FullyConectedOptionsWeightsFormat",
        "fully connected options weights format",
        [TFL_FCWO_Default, TFL_FCWO_Shuffled4x16i8]>;
// TODO(jpienaar): Update post discussion on semantics of FC OP.
// TODO(jpienaar): Include more shape verification.
// Operands: input, filter and a bias that may be a tensor or None. Operand 2
// (bias) gets a quantization scale derived from operands 0 and 1.
def TFL_FullyConnectedOp : TFL_Op<"fully_connected",
    [NoSideEffect, TFL_AccumulatorUniformScale<2, 0, 1>]> {
  let summary = "Fully connected op";
  let hasOptions = 1;

  let arguments = (ins
    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$input,
    TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$filter,
    TFL_TensorOfOrNone<[F32, TFL_QI8, TFL_QUI8]>:$bias,
    TFL_AFAttr:$fused_activation_function,
    TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format,
    BoolAttr:$keep_num_dims
  );

  // Depending on the weights format, this op can have one or two outputs.
  let results = (outs Variadic<TensorOf<[F32, TFL_QI8, TFL_QUI8]>>:$output);
}
// Gathers slices of `params` along `axis` at the positions in `indices`.
// `params` must be at least 1-D (or unranked), and the output element type
// must equal the params element type. The result type list therefore mirrors
// the params list exactly: it previously allowed I16 (unreachable, since
// params never admits I16) but not I8, which rejected valid i8 gathers.
def TFL_GatherOp : TFL_Op<"gather", [
    NoSideEffect,
    TFL_SameOperandsAndResultsScale,
    TFL_OperandHasAtleastRank<0, 1>,
    PredOpTrait<"params and output must have same element type",
      TCresVTEtIsSameAsOp<0, 0>>
  ]> {
  let summary = "Gather operator";
  let description = [{
Gather slices from `params` axis `axis` according to `indices`.
}];
  let arguments = (ins
    TensorOf<[F32, I8, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$params,
    TensorOf<[I32, I64]>:$indices,
    I32Attr:$axis
  );
  let builders =
  [
    OpBuilder<"Builder *builder, OperationState *result, "
      "Value *params, Value *indices, IntegerAttr axis",
      [{ BuildGatherOp(builder, result, params, indices, axis); }]>
  ];
  let results = (outs
    TensorOf<[F32, I8, I32, I64, TFL_Str, TFL_QI8, TFL_QUI8]>:$output
  );
  let hasOptions = 1;
}
// N-d gather of `params` using i32/i64 `indices`; no FlatBuffer Options.
def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> {
  let summary = "Gather_nd operator";
  let description = [{
Gather slices from `params` into a Tensor with shape specified by `indices`.
}];
  // TODO: missing Uint8.
  let arguments = (ins
    TensorOf<[F32, I8, I64, I32]>:$params,
    TFL_I32OrI64Tensor:$indices
  );
  let results = (outs TensorOf<[F32, I8, I64, I32]>:$output);
}
// Element-wise lhs <= rhs producing a bool tensor. The lhs/rhs same-type check
// is left to the Broadcastable trait.
def TFL_LessEqualOp : TFL_Op<"less_equal",
    [Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
  let summary = "Less_equal operator";
  let description = [{
Element-wise less_equal operation.
}];
  let hasOptions = 0;

  let arguments = (ins
    TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$lhs,
    TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$rhs
  );
  let results = (outs TFL_BoolTensor:$output);

  let builders = [TFL_ComparisonBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Local response normalization over f32 tensors, parameterized by radius,
// bias, alpha and beta. Serialized with FlatBuffer Options.
def TFL_LocalResponseNormalizationOp
    : TFL_Op<"local_response_normalization", [NoSideEffect]> {
  let summary = "Local Response Normalization.";
  let description = [{
The 4-D `input` tensor is treated as a 3-D array of 1-D vectors (along the last
dimension), and each vector is normalized independently. Within a given vector,
each component is divided by the weighted, squared sum of inputs within
`depth_radius`. In detail,
sqr_sum[a, b, c, d] =
sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
output = input / (bias + alpha * sqr_sum) ** beta
For details, see [Krizhevsky et al., ImageNet classification with deep
convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
}];
  let hasOptions = 1;

  let arguments = (ins
    TensorOf<[F32]>:$input,
    I32Attr:$radius,
    F32Attr:$bias,
    F32Attr:$alpha,
    F32Attr:$beta
  );
  let results = (outs TensorOf<[F32]>:$output);
}
// Element-wise lhs >= rhs producing a bool tensor.
def TFL_GreaterEqualOp : TFL_Op<"greater_equal",
    [Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
  let summary = "Greater_equal operator";
  let description = [{
Element-wise greater_equal operation.
}];
  let hasOptions = 0;

  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs TFL_BoolTensor:$output);

  let builders = [TFL_ComparisonBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise lhs != rhs producing a bool tensor. Uses the shared comparison
// builder, which has the same signature and body as the previous inline one.
def TFL_NotEqualOp : TFL_Op<"not_equal", [
    Broadcastable, Commutative, NoSideEffect, TFL_NoQuantizableResult]> {
  let summary = "Not_equal operator";
  let description = [{
Element-wise not_equal operation.
}];
  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs TFL_BoolTensor:$output);

  let builders = [TFL_ComparisonBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Broadcastable element-wise division with a fused activation function.
def TFL_DivOp : TFL_Op<"div", [Broadcastable, NoSideEffect]> {
  let summary = "Division operator";
  let description = [{
Element-wise division operation.
}];
  let hasOptions = 1;

  let arguments = (ins
    AnyTensor:$lhs,
    AnyTensor:$rhs,
    TFL_AFAttr:$fused_activation_function
  );
  let results = (outs AnyTensor:$output);

  let builders = [TFL_FusedBroadcastableBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise ELU over floating-point input; the result has the operand's
// type (SameOperandsAndResultType).
def TFL_EluOp : TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Exponential Linear Unit operator";
  let description = [{
Computes the exponential linear
f(x) -> exp(x) - 1 for x < 0, x for x >= 0.
element-wise.
}];
  let hasOptions = 0;
  let results = (outs AnyTensor:$y);
  let arguments = (ins TFL_FpTensor:$x);
}
// Element-wise x == y producing a bool tensor. Both operands must have the
// same value type.
// Now marked NoSideEffect, like every other comparison op in this file; it
// was previously missing, which kept the op from being treated as pure.
def TFL_EqualOp : TFL_Op<"equal", [
    Commutative,
    Broadcastable,
    NoSideEffect,
    TFL_NoQuantizableResult,
    PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
  let summary = "Equal operator";
  let description = [{
Returns the truth element of x == y element-wise
}];
  let arguments = (
    ins
    TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$x,
    TensorOf<[I1, F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$y
  );
  let results = (outs TFL_BoolTensor:$output);
  let builders = [TFL_ComparisonBinaryBuilder];
}
// Element-wise e^x; serialized with FlatBuffer Options.
def TFL_ExpOp : TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Natural exponentiation operator";
  let description = [{
Performs element-wise natural exponentiation operation on input.
}];
  let hasOptions = 1;
  let results = (outs AnyTensor:$y);
  let arguments = (ins AnyTensor:$x);
}
// Inserts a size-1 dimension at the index given by the `dim` operand.
// The description now names that operand `dim`; it previously referred to a
// nonexistent `axis`.
def TFL_ExpandDimsOp: TFL_Op<"expand_dims", [NoSideEffect]> {
  let summary = "Inserts a dimension of 1 into a tensor's shape.";
  let description = [{
Given a tensor `input`, this operation inserts a dimension of 1 at the
dimension index `dim` of `input`'s shape. The dimension index `dim` starts at
zero; if you specify a negative number for `dim` it is counted backward from
the end.
This operation is useful if you want to add a batch dimension to a single
element. For example, if you have a single image of shape `[height, width,
channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`,
which will make the shape `[1, height, width, channels]`.
Other examples:
```
# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]
# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
```
This operation requires that:
`-1-input.dims() <= dim <= input.dims()`
This operation is related to `squeeze()`, which removes dimensions of
size 1.
}];
  // TODO: Restriction on dim's size and valid range are not modeled here.
  let arguments = (ins AnyTensor:$input, TFL_IntTensor:$dim);
  let results = (outs AnyTensor:$output);
  let hasOptions = 1;
}
// Removes size-1 dimensions; `squeeze_dims` (default: empty list) selects
// specific ones. The description now names the attribute `squeeze_dims`; it
// previously referred to a nonexistent `axis`.
def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect,
                                      TFL_SameOperandsAndResultsScale]> {
  let summary = "Removes dimensions of size 1 from the shape of a tensor.";
  let description = [{
Given a tensor `input`, this operation returns a tensor of the same type with
all dimensions of size 1 removed. If you don't want to remove all size 1
dimensions, you can remove specific size 1 dimensions by specifying
`squeeze_dims`.
For example:
```
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t)) ==> [2, 3]
```
Or, to remove specific size 1 dimensions:
```
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
```
}];
  let arguments = (ins
    AnyTensor:$input,
    DefaultValuedAttr<I64ArrayAttr, "{}">:$squeeze_dims
  );
  let results = (outs
    AnyTensor:$output
  );
  let hasOptions = 1;
  let customOption = "SqueezeOptions";
}
// Produces a tensor with the shape given by `dims`, filled with `value`.
def TFL_FillOp : TFL_Op<"fill", [NoSideEffect]> {
  let summary = "Fill the tensor with given value.";
  let description = [{
Fill the tensor with given value.
}];
  let hasOptions = 0;
  let results = (outs AnyTensor:$res);
  let arguments = (ins TFL_I32OrI64Tensor:$dims, AnyTensor:$value);
}
// Element-wise floor over floating-point tensors.
def TFL_FloorOp : TFL_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Floor operator";
  let description = [{
Returns element-wise floor value of the input.
}];
  let results = (outs TFL_FpTensor:$y);
  let arguments = (ins TFL_FpTensor:$x);
}
// Broadcastable element-wise floor division; lhs and rhs must share an
// element type.
def TFL_FloorDivOp : TFL_Op<"floor_div",
    [Broadcastable, NoSideEffect, BinaryOpSameElementTypeConstraint]> {
  let summary = "Floor div operator";
  let description = [{
Element-wise floor div operation.
}];
  let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
  let results = (outs AnyTensor:$output);

  let builders = [TFL_BroadcastableBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Broadcastable element-wise floor-modulo (division remainder).
// Summary and description corrected from "reminder" to "remainder".
def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
  let summary = "Division remainder";
  let description = [{
Element-wise division remainder operation.
}];
  let arguments = (
    ins AnyTensor:$lhs,
    AnyTensor:$rhs);
  let results = (outs AnyTensor:$output);
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise lhs > rhs producing a bool tensor.
// Brought in line with the other comparison ops (less, less_equal,
// greater_equal):
// - the result is a TFL_BoolTensor (it was AnyTensor);
// - the op gets the shared comparison builder;
// - the op gets the Broadcastable trait, which also provides the lhs/rhs
//   same-type check.
def TFL_GreaterOp : TFL_Op<"greater", [
    Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
  let summary = "Greater operator";
  let description = [{
Element-wise greater operation.
}];
  let arguments = (
    ins AnyTensor:$lhs,
    AnyTensor:$rhs);
  let results = (outs TFL_BoolTensor:$output);
  let builders = [TFL_ComparisonBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Input pseudo op wrapping a function argument so attributes (such as a tensor
// name) can be attached to it.
// NoSideEffect is deliberately omitted so the op is not removed when its
// result is unused: the generated FlatBuffer needs a tensor, plus metadata,
// for every subgraph input.
def TFL_InputOp : Op<TFL_Dialect, "pseudo_input",
    [SameOperandsAndResultType]> {
  let summary = "Input pseudo operator";
  let description = [{
Takes one of the function arguments as input and returns it as result. This
is a NOP and is used to attach attributes such as tensor name to function
arguments.
}];
  let results = (outs AnyTensor:$output);
  let arguments = (ins AnyTensor:$input);
}
// Element-wise leaky ReLU with slope `alpha` for negative inputs.
// TODO(jpienaar): Add type restriction. This op is only defined for
// restricted (floating point) types.
def TFL_LeakyReluOp : TFL_Op<"leaky_relu",
    [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Leaky Relu operator";
  let description = [{
Element-wise Leaky ReLU operator
x -> x >= 0 ? x : (alpha * x)
}];
  let hasOptions = 1;

  let arguments = (ins
    AnyTensor:$input,
    // Slope of the activation function at x < 0.
    F32Attr:$alpha
  );
  let results = (outs AnyTensor:$output);
}
// Element-wise lhs < rhs producing a bool tensor.
// Now carries Broadcastable, like less_equal / greater_equal; that trait
// performs the lhs/rhs same-type and broadcast-compatibility checks. Without
// it, any two tensors were accepted as operands.
def TFL_LessOp : TFL_Op<"less", [
    Broadcastable, NoSideEffect, TFL_NoQuantizableResult]> {
  let summary = "Less operator";
  let description = [{
Element-wise less operation.
}];
  let arguments = (
    ins AnyTensor:$lhs,
    AnyTensor:$rhs);
  let results = (outs TFL_BoolTensor:$output);
  let builders = [TFL_ComparisonBinaryBuilder];
  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise logical AND of two i1 tensors.
def TFL_LogicalAndOp : TFL_Op<"logical_and", [NoSideEffect]> {
  let summary = "Logical AND operator";
  let description = [{
Element-wise logical AND operation.
}];
  let arguments = (ins I1Tensor:$lhs, I1Tensor:$rhs);
  let results = (outs I1Tensor:$output);

  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise logical NOT of an i1 tensor (the single operand is named lhs).
def TFL_LogicalNotOp : TFL_Op<"logical_not", [NoSideEffect]> {
  let summary = "Logical NOT operator";
  let description = [{
Element-wise logical NOT operation.
}];
  let results = (outs I1Tensor:$output);
  let arguments = (ins I1Tensor:$lhs);
}
// Element-wise logical OR of two i1 tensors.
def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> {
  let summary = "Logical OR operator";
  let description = [{
Element-wise logical OR operation.
}];
  let arguments = (ins I1Tensor:$lhs, I1Tensor:$rhs);
  let results = (outs I1Tensor:$output);

  let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
  let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
// Element-wise natural logarithm; operand and result share one type.
def TFL_LogOp : TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Natural logarithm operator";
  let description = [{
Performs element-wise natural logarithm operation on input.
}];
  let results = (outs AnyTensor:$y);
  let arguments = (ins AnyTensor:$x);
}
// TODO(b/130643170): Add some constraint for the input/output element types.
def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [
NoSideEffect,
SameOperandsAndResultShape,
// Fixed output quantization, identical in real terms for int8 and uint8:
// zero_point = max_value
// scale = -log_softmax_output_min / (max_value - min_value + 1)
// The scale appears encoded as mantissa * 10^exponent: 625e-4 = 0.0625,
// i.e. 16 / 256 for both storage types. (Dividing by max_value + 1 alone
// would give 16 / 128 for int8, which does not match the value below.)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<127, 625, -4>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<255, 625, -4>>]> {
let summary = "Log softmax operator";
let description = [{
Computes element-wise log softmax activations with the following formula
input - log(reduce_sum(exp(input), dim))
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
let hasOptions = 1;
}
// TODO(ashwinm): Revisit how coarse these PredOpTraits are. Splitting them
// into smaller PredOpTraits with more specific messages would make failures
// easier to trace; alternatively tablegen would need a way to attach a
// description to each predicate inside a trait and use it when emitting the
// error message.
//
// Operand 0 must have element type F32, TFL_QI8 or TFL_QUI8, and result 0
// must have the same element type as operand 0.
def MaxPoolOperandAndResultConstraints : PredOpTrait<
"MaxPool2D operand and " "result types match specified constraints",
And<[
TCopVTEtIs<0, AnyTypeOf<[F32, TFL_QI8, TFL_QUI8]>>,
TFL_TCresVTEtIsSameAsOp<0, 0>
]>>;
// 2-D max pooling. Element-type rules come from
// MaxPoolOperandAndResultConstraints; options use the "Pool2DOptions"
// custom option name.
def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
NoSideEffect,
MaxPoolOperandAndResultConstraints,
TFL_SameOperandsAndResultsScale]> {
let summary = "Max Pool 2D op";
let description = [{
Performs max pool 2D on input.
Inputs:
`inputs[0]`: required: the input tensor
}];
let arguments = (ins
AnyTensor:$input,
TFL_PaddingAttr:$padding,
I32Attr:$stride_w,
I32Attr:$stride_h,
I32Attr:$filter_width,
I32Attr:$filter_height,
TFL_AFAttr:$fused_activation_function);
let results = (outs AnyTensor:$output);
let customOption = "Pool2DOptions";
let hasOptions = 1;
}
// Broadcasting element-wise maximum over float / i32 / i64 tensors.
// Carries no options table (hasOptions = 0).
def TFL_MaximumOp : TFL_Op<"maximum", [
Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
let summary = "Max operator";
let description = [{
Element-wise max operation.
}];
let builders = [TFL_BroadcastableBinaryBuilder];
let arguments = (ins
TFL_FpOrI32OrI64Tensor:$lhs,
TFL_FpOrI32OrI64Tensor:$rhs);
let results = (outs TFL_FpOrI32OrI64Tensor:$max);
let hasOptions = 0;
}
// Mean reduction over the dimensions listed in `axis`; `keep_dims` controls
// whether reduced dimensions are kept with length 1. Uses ReducerOptions.
def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "Mean operator";
let description = [{
Computes the mean of elements across dimensions of a tensor.
Reduces input_tensor along the dimensions given in axis.
Unless keepdims is true, the rank of the tensor is reduced by 1 for
each entry in axis. If keepdims is true, the reduced dimensions are retained
with length 1.
}];
let customOption = "ReducerOptions";
let hasOptions = 1;
let results = (outs
TensorOf<[F32, I32, I64, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output);
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$input,
TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims);
}
// One-hot encoding: `on_value`/`off_value` and the result share the same
// set of accepted element types; `axis` is an i32 attribute.
def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> {
let summary = "OneHot operator";
let description = [{
Returns a one-hot tensor. The locations represented by indices in `indices`
take value `on_value`, while all other locations take value `off_value`.
If the input `indices` is rank `N`, the output will have rank `N+1`.
The new axis is created at dimension `axis` (default: the new axis is
appended at the end).
}];
let arguments = (ins
TensorOf<[I32, I64]>:$indices,
I32Tensor:$depth,
TensorOf<[F32, I32, I64, I1]>:$on_value,
TensorOf<[F32, I32, I64, I1]>:$off_value,
I32Attr:$axis
);
let results = (outs
TensorOf<[F32, I32, I64, I1]>:$output
);
let hasOptions = 1;
}
// Extracts a contiguous slice described by `begin` and `size` (both i32 or
// i64 tensors). No options table is declared.
def TFL_SliceOp : TFL_Op<"slice", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "Return a slice from 'input'.";
let description = [{
The output tensor is a tensor with dimensions described by 'size'
whose values are extracted from 'input' starting at the offsets in
'begin'.
*Requirements*:
0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)
}];
let results = (outs AnyTensor:$output);
let arguments = (ins
AnyTensor:$input,
TFL_I32OrI64Tensor:$begin,
TFL_I32OrI64Tensor:$size);
}
// Sum reduction over `axes`; shares the "ReducerOptions" option name with
// the other reductions. The single result is unnamed.
def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> {
let summary = "Sum operator";
let description = [{
Computes the sum reduction along the specified axes
}];
let customOption = "ReducerOptions";
let hasOptions = 1;
let results = (outs AnyTensor);
let arguments = (ins
AnyTensor:$input,
TFL_I32OrI64Tensor:$axes,
BoolAttr:$keep_dims);
}
// Min reduction over `axes`; uses "ReducerOptions". Unnamed result.
def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> {
let summary = "Min-reduction operator";
let description = [{
Computes the min reduction along the specified axes
}];
let customOption = "ReducerOptions";
let hasOptions = 1;
let results = (outs AnyTensor);
let arguments = (ins
AnyTensor:$input,
TFL_I32OrI64Tensor:$axes,
BoolAttr:$keep_dims);
}
// Max reduction over `axes`; uses "ReducerOptions". Unnamed result.
def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> {
let summary = "Max-reduction operator";
let description = [{
Computes the max reduction along the specified axes
}];
let customOption = "ReducerOptions";
let hasOptions = 1;
let results = (outs AnyTensor);
let arguments = (ins
AnyTensor:$input,
TFL_I32OrI64Tensor:$axes,
BoolAttr:$keep_dims);
}
// Product reduction over `axes`; input restricted to f32/i8/i32/i64, result
// left unconstrained and unnamed. Uses "ReducerOptions".
def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> {
let summary = "Prod-reduction operator";
let description = [{
Computes the product along the specified axes
}];
let customOption = "ReducerOptions";
let hasOptions = 1;
let results = (outs AnyTensor);
let arguments = (ins
TensorOf<[F32, I8, I32, I64]>:$input,
TFL_I32OrI64Tensor:$axes,
BoolAttr:$keep_dims);
}
// Broadcasting element-wise minimum over float / i32 / i64 tensors.
// Carries no options table (hasOptions = 0).
def TFL_MinimumOp : TFL_Op<"minimum", [
Broadcastable, NoSideEffect, Commutative, TFL_SameOperandsAndResultsScale]> {
let summary = "Min operator";
let description = [{
Element-wise min operation.
}];
let builders = [TFL_BroadcastableBinaryBuilder];
let arguments = (ins
TFL_FpOrI32OrI64Tensor:$lhs,
TFL_FpOrI32OrI64Tensor:$rhs);
let results = (outs TFL_FpOrI32OrI64Tensor:$min);
let hasOptions = 0;
}
// Broadcasting element-wise multiply with a fused activation attribute.
// Has a folder (hasFolder) and uses the generic binary-op parse/print.
def TFL_MulOp : TFL_Op<"mul", [Broadcastable, NoSideEffect, Commutative]> {
let summary = "Multiplication operator";
let description = [{
Element-wise multiplication operation.
}];
let hasOptions = 1;
let hasFolder = 1;
let results = (outs AnyTensor:$output);
let arguments = (ins
AnyTensor:$lhs,
AnyTensor:$rhs,
TFL_AFAttr:$fused_activation_function);
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let builders = [TFL_FusedBroadcastableBinaryBuilder];
}
// Element-wise negation; operand and result share one type.
def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Negation operator";
let description = [{
Computes element-wise negation of input
}];
let hasOptions = 1;
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x);
}
// Stacks `values_count` same-rank tensors along `axis` into one tensor of
// rank R+1. Extra checks live in the C++ `Verify(PackOp)` overload.
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> {
let summary = "Packs a list of tensors along a dimension into one tensor";
let description = [{
Packs a list of `values_count` rank-`R` tensors into one rank-`(R+1)`
tensor.
Packs the `values_count` tensors in `values` into a tensor with rank one
higher than each tensor in `values`, by packing them along the `axis`
dimension.
Given a list of tensors of shape `(A, B, C)`;
if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
Etc.
For example:
```
# 'x' is [1, 4]
# 'y' is [2, 5]
# 'z' is [3, 6]
pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
```
This is the opposite of `unpack`.
}];
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let results = (outs TensorOf<[F32, I8, I16, I32, I64]>:$output);
let arguments = (ins
Variadic<TensorOf<[F32, I8, I16, I32, I64]>>:$values,
I32Attr:$values_count,
I32Attr:$axis);
}
// Zero padding. Operand 1 (`padding`, called `paddings` in the description)
// must be rank 2, and the traits tie its leading dimension to operand 0's
// rank.
def TFL_PadOp : TFL_Op<"pad", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
TFL_OperandHasRank<1, 2>,
TFL_OperandRankEquals1DimOfOperand<0, 1>]> {
let summary = "Padding operator";
let description = [{
This operation pads an `input` with zeros according to the `paddings` you
specify. `paddings` is an integer tensor with shape `[n, 2]`, where n is
the rank of `input`. For each dimension D of `input`, `paddings[D, 0]`
indicates how many zeros to add before the contents of `input` in that
dimension, and `paddings[D, 1]` indicates how many zeros to add after the
contents of `input` in that dimension.
The padded size of each dimension D of the output is:
`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
For example:
```
# 't' is [[1, 1], [2, 2]]
# 'paddings' is [[1, 1], [2, 2]]
# rank of 't' is 2
pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
[0, 0, 1, 1, 0, 0]
[0, 0, 2, 2, 0, 0]
[0, 0, 0, 0, 0, 0]]
```
}];
let arguments = (
ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TFL_I32OrI64Tensor:$padding);
let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
let hasOptions = 1;
}
// Constant-value padding. Operand 1 (`padding`) must be rank 2 and operand
// 2 (`constant_values`) rank 0. The trait below requires `constant_values`
// to have the same element type as `input`, so its accepted types include
// the quantized types `input` accepts; otherwise quantized PadV2 could
// never verify.
def TFL_PadV2Op : TFL_Op<"padv2", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
TFL_OperandHasRank<1, 2>,
TFL_OperandHasRank<2, 0>,
TFL_OperandRankEquals1DimOfOperand<0, 1>,
PredOpTrait<"input and constant value operands must have same element type",
TCopVTEtAreSameAt<[0, 2]>>]> {
let summary = "Padding operator v2";
let description = [{
This operation pads an `input` according to the `paddings` and
`constant_values` you specify. `paddings` is an integer tensor with shape
`[n, 2]`, where n is the rank of `input`. For each dimension D of `input`,
`paddings[D, 0]` indicates how many padding values to add before the
contents of `input` in that dimension, and `paddings[D, 1]` indicates how
many padding values to add after the contents of `input` in that dimension.
`constant_values` is a scalar tensor of the same type as `input` that
indicates the value to use for padding `input`.
The padded size of each dimension D of the output is:
`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
For example:
```
# 't' is [[1, 1], [2, 2]]
# 'paddings' is [[1, 1], [2, 2]]
# 'constant_values' is 0
# rank of 't' is 2
pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
[0, 0, 1, 1, 0, 0]
[0, 0, 2, 2, 0, 0]
[0, 0, 0, 0, 0, 0]]
```
}];
let arguments = (
ins TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TFL_I32OrI64Tensor:$padding,
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$constant_values);
let results = (outs TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output);
let hasOptions = 1;
}
// Broadcasting element-wise power; generic binary-op parse/print.
def TFL_PowOp : TFL_Op<"pow", [Broadcastable, NoSideEffect]> {
let summary = "Power operator";
let description = [{
Element-wise power operation.
}];
let results = (outs AnyTensor:$output);
let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
}
// Returns the rank of any tensor as a TFL_IntTensor.
def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> {
let summary = "Rank operator.";
let description = [{
Returns the rank of a tensor.
}];
let results = (outs TFL_IntTensor:$output);
let arguments = (ins AnyTensor:$input);
}
// Element-wise ReLU; operand and result share one type.
def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Relu operator";
let description = [{
Element-wise Relu operator
x -> max(0, x)
}];
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x);
}
// Element-wise ReLU clamped at 6; operand and result share one type.
def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Relu6 operator";
let description = [{
Element-wise Relu6 operator
x -> max(0, min(6, x))
}];
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x);
}
// Reshape whose target shape is not an operand: `new_shape` is derived from
// the result type's static shape. Has both a canonicalizer and a folder.
def TFL_ReshapeOp: TFL_Op<"reshape", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "Reshape operator";
let description = [{
Produces a tensor with the same values but different static shape defined
by the output type.
}];
let hasOptions = 1;
let hasFolder = 1;
let hasCanonicalizer = 1;
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
DerivedShapeAttr new_shape = DerivedShapeAttr<[{
return getResult()->getType().cast<ShapedType>().getShape();
}]>;
}
// Per-batch partial reversal along `seq_dim`, with lengths supplied by
// `seq_lengths`. Input and result accept the same element types.
def TFL_ReverseSequenceOp : TFL_Op<"reverse_sequence", [NoSideEffect]> {
let summary = "Reverses variable length slices.";
let description = [{
This op first slices `input` along the dimension `batch_dim`, and for each
slice `i`, reverses the first `seq_lengths[i]` elements along
the dimension `seq_dim`.
The elements of `seq_lengths` must obey `seq_lengths[i] <= input.dims[seq_dim]`,
and `seq_lengths` must be a vector of length `input.dims[batch_dim]`.
The output slice `i` along dimension `batch_dim` is then given by input
slice `i`, with the first `seq_lengths[i]` slices along dimension
`seq_dim` reversed.
}];
let hasOptions = 1;
let results = (outs TensorOf<[F32, I16, I32, I64]>:$output);
// uint8 is not among the accepted element types yet.
let arguments = (ins
TensorOf<[F32, I16, I32, I64]>:$input,
TFL_I32OrI64Tensor:$seq_lengths,
I32Attr:$seq_dim,
I32Attr:$batch_dim);
}
// Element-wise 1 / sqrt(x); operand and result share one type.
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Reciprocal of square root operator";
let description = [{
Computes element-wise reciprocal square root of input
}];
let arguments = (ins AnyTensor:$x);
let results = (outs AnyTensor:$y);
}
// Returns the shape of `input`. `out_type` is not a stored attribute; it is
// derived from the result tensor's element type.
def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> {
let summary = "Shape operator";
let description = [{
Returns the shape of a tensor.
}];
let hasOptions = 1;
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$output);
DerivedTypeAttr out_type = DerivedTypeAttr<[{
return getResult()->getType().cast<TensorType>().getElementType();
}]>;
}
// Element-wise sigmoid on floating-point tensors.
def TFL_LogisticOp: TFL_Op<"logistic", [
NoSideEffect,
SameOperandsAndResultType,
// Fixed output quantization for both int8 and uint8:
// zero_point = min_value (-128 for int8, 0 for uint8)
// scale = 1. / (max_value - min_value + 1), i.e. 1/256 for both types;
// the scale appears encoded as mantissa * 10^exponent (390625e-8).
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
let summary = "Logistic operator";
let description = [{
Computes element-wise Sigmoid of input
}];
let arguments = (ins TFL_FpTensor:$x);
let results = (outs TFL_FpTensor:$y);
}
// TODO(jpienaar): Flesh this out.
// 1-D arithmetic sequence; operand and result types are not yet constrained.
def TFL_RangeOp: TFL_Op<"range", [NoSideEffect]> {
let summary = "Range operator";
let description = [{
Returns a 1D tensor defined by a sequence from `start` to `limit` with
a given `delta`.
}];
let results = (outs AnyTensor:$result);
let arguments = (ins AnyTensor:$start, AnyTensor:$limit, AnyTensor:$delta);
}
// Reverses `input` along the axis given in the rank-1 `axis` operand.
// Reversal does not change the element type, so the result accepts exactly
// the element types the input accepts (an I8 result could never come from
// any accepted input).
def TFL_ReverseV2Op: TFL_Op<"reverse_v2",
[NoSideEffect, TFL_OperandHasRank<1,1>]> {
let summary = "ReverseV2 Operator";
let description = [{
Reverses specific dimensions of a tensor.
Given a tensor, and an int32/int64 tensor axis representing the set
of dimensions of tensor to reverse.
This operation reverses each dimension i for
which there exists j s.t. axis[j] == i.
Args:
tensor: A Tensor. Must be one of the following types:
int16, int32, int64, float32 Up to 8-D.
axis: A Tensor. Must be one of the following types: int32, int64.
with only 1 element which is the axis index.
TODO: Add support for multiple elements.
}];
let arguments = (
ins
TensorOf<[F32, I16, I32, I64]>:$input,
TensorOf<[I32, I64]>:$axis
);
let results = (outs
TensorOf<[F32, I16, I32, I64]>:$output
);
}
// Picks between `x` and `y` using a boolean `condition`. `x`, `y` and the
// result must share an element type; the custom builder takes the result
// type from `x`.
def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
// TODO(jpienaar): This is too restrictive, rank 1 input is also allowed.
SameOperandsAndResultShape,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TCresVTEtIsSameAsOp<0, 1>>]> {
let summary = "Select operator";
// TODO: missing the shape constraints.
let description = [{
Select values of 'x' if the corresponding value of 'condition' is true or
the value of 'y' if false. There are two valid condition input sizes:
1. Either the same shape (in which case the select is elementwise), or
2. condition must be Rank 1 and match over the first dimension.
}];
let arguments = (ins
TFL_BoolTensor:$condition,
// TODO: Missing uint8.
TensorOf<[F32, I1, I8, I16, I32, I64]>:$x,
TensorOf<[F32, I1, I8, I16, I32, I64]>:$y);
let results = (outs AnyTensor:$output);
// TODO(jpienaar): autogenerate this.
let builders = [OpBuilder<"Builder *builder, OperationState *result, "
"Value *condition, Value *x, Value *y",
[{
auto resultType = x->getType();
result->addOperands({condition, x, y});
result->types.push_back(resultType);
}]>];
let hasOptions = 1;
}
// Element-wise sine on floating-point tensors.
def TFL_SinOp: TFL_Op<"sin", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Sine operator";
let description = [{
Computes element-wise Sine of input
}];
let results = (outs TFL_FpTensor:$y);
let arguments = (ins TFL_FpTensor:$x);
}
// TODO(b/130643170): Add some constraint for the input/output element types.
def TFL_SoftmaxOp : TFL_Op<"softmax", [
NoSideEffect,
SameOperandsAndResultShape,
// Fixed output quantization for both int8 and uint8:
// zero_point = min_value (-128 for int8, 0 for uint8)
// scale = 1. / (max_value - min_value + 1), i.e. 1/256 for both types;
// the scale appears encoded as mantissa * 10^exponent (390625e-8).
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<-128, 390625, -8>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<0, 390625, -8>>]> {
let summary = "Softmax operator";
let description = [{
Computes element-wise softmax activations with the following formula
exp(input * beta) / tf.reduce_sum(exp(input * beta), dim)
}];
let arguments = (
ins AnyTensor:$input,
F32Attr:$beta
);
let results = (outs AnyTensor:$output);
let hasOptions = 1;
}
// Element-wise square root; operand and result share one type.
def TFL_SqrtOp: TFL_Op<"sqrt", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Square root operator";
let description = [{
Computes element-wise Square root of input
}];
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x);
}
// Element-wise x * x; operand and result share one type.
def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Square operator";
let description = [{
Computes element-wise Square of input
}];
let hasOptions = 1;
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x);
}
// Broadcasting element-wise subtraction with a fused activation attribute.
// Has a folder and uses the generic binary-op parse/print helpers.
def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> {
let summary = "Subtraction operator";
let description = [{
Element-wise subtraction operation.
}];
let hasOptions = 1;
let hasFolder = 1;
let results = (outs AnyTensor:$output);
let arguments = (ins
AnyTensor:$lhs,
AnyTensor:$rhs,
TFL_AFAttr:$fused_activation_function);
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let builders = [TFL_FusedBroadcastableBinaryBuilder];
}
// TODO(jpienaar): Extend the kernel implementation beyond I32 and F32 to
// support all types.
def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [Broadcastable, NoSideEffect]> {
let summary = "Squared difference operator";
let description = [{
Element-wise squared difference operation.
}];
let results = (outs AnyTensor:$output);
let arguments = (ins AnyTensor:$lhs, AnyTensor:$rhs);
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let builders = [TFL_BroadcastableBinaryBuilder];
}
// Element-wise hyperbolic tangent with fixed output quantization.
def TFL_TanhOp: TFL_Op<"tanh", [
NoSideEffect,
SameOperandsAndResultType,
// Output quantization is fixed:
// central_value = min_value / 2 + (max_value - 1) / 2 + 1
// zero_point = central_value (0 for int8, 128 for uint8)
// scale = 1. / (central_value - min_value), i.e. 1/128 (78125e-7)
TFL_FixedResultScale<TFL_Int8UniformQuantizedType<0, 78125, -7>>,
TFL_FixedResultScale<TFL_UInt8UniformQuantizedType<128, 78125, -7>>]> {
let summary = "Hyperbolic tangent operator";
let description = [{
Computes element-wise Hyperbolic tangent of input
}];
let results = (outs TensorOf<[F32, I16, I8]>:$y);
// TODO(haoliang): missing Uint8.
let arguments = (ins TensorOf<[F32, I16, I8]>:$x);
}
// Repeats `input` `multiples[i]` times along each dimension i. The result
// element type must equal the input element type; no options table.
def TFL_TileOp: TFL_Op<"tile", [NoSideEffect,
PredOpTrait<"resultant element type needs to match first operand type",
TCresVTEtIsSameAsOp<0,0>>]> {
let summary = "Tile operator.";
let description = [{
Constructs a tensor by tiling a given tensor.
This operation creates a new tensor by replicating input
multiples times. The output tensor's i'th dimension has
input.dims(i) * multiples[i] elements, and the values of input
are replicated multiples[i] times along the 'i'th dimension.
For example, tiling [a b c d] by [2] produces [a b c d a b c d].
}];
let hasOptions = 0;
let results = (outs AnyTensor:$output);
let arguments = (ins
AnyTensor:$input,
TFL_I32OrI64Tensor:$multiples);
}
// TODO(jpienaar): Maybe make it accept any single element tensor as `k`.
// TODO(jpienaar): Check that input has one or more dimensions.
// TODO(jpienaar): Check that k is less or equal the internal dimension
// `k` must be a rank-0 i32 tensor; `values` shares the input element type,
// and the custom builder delegates to the C++ helper BuildTopKOp.
def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>,
PredOpTrait<"result and input element type match",
TCresVTEtIsSameAsOp<0,0>>]> {
let summary = "TopK operator";
let description = [{
Returns the top `k` largest elements along each last dimensional slice of
`input` and the indices of values within the last dimension of the input
tensor.
}];
let arguments = (ins
// TODO: Missing uint8
TensorOf<[F32, I8, I32, I64]>:$input,
I32Tensor:$k);
let results = (outs
AnyTensor:$values,
I32Tensor:$indices);
let builders = [OpBuilder<"Builder *builder, OperationState *result, "
"Value *input, Value *k",
[{ BuildTopKOp(builder, result, input, k); }]>];
let hasOptions = 1;
}
// TODO: Verify that the result shape is a permutation of the first input's
// dimensions.
def TFL_TransposeOp : TFL_Op<"transpose",
[NoSideEffect,
// TODO(jpienaar): the two constraints below only hold dynamically; rework
// them so they tolerate unknown shapes before enabling.
// TFL_OperandHasRank<1,1>,
// TFL_OperandRankEquals1DimOfOperand<0, 1>,
TFL_SameOperandsAndResultsScale]> {
let summary = "Transpose operator";
let description = [{
Returns the Transpose of x
}];
let results = (outs AnyTensor:$y);
let arguments = (ins AnyTensor:$x, AnyTensor:$perm);
}
// Splits `input` into `num` tensors along `axis`, dropping that dimension.
// Extra checks live in the C++ `Verify(UnpackOp)` overload.
def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
let summary = "Unpacks a tensor along a dimension into multiple tensors";
let description = [{
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
For example, given a tensor of shape `(A, B, C, D)`;
If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]`
and each tensor in `output` will have shape `(B, C, D)`. (Note that the
dimension unpacked along is gone, unlike `split`).
If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]`
and each tensor in `output` will have shape `(A, C, D)`.
Etc.
This is the opposite of `pack`.
}];
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let results = (outs Variadic<TensorOf<[F32, I8, I32]>>:$outputs);
let arguments = (ins
TensorOf<[F32, I8, I32]>:$input,
I32Attr:$num,
I32Attr:$axis);
}
// Zero-filled tensor with the same shape and type as `input`.
def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> {
let summary = "ZerosLike operator";
let description = [{
Returns a tensor of zeros with the same shape and type as the input tensor.
}];
let hasOptions = 1;
let results = (outs AnyTensor:$output);
let arguments = (ins AnyTensor:$input);
}
// Moves data out of the batch dimension (0) back into spatial dimensions.
// The PredOpTrait requires result and input element types to match, so the
// result type list mirrors the input list exactly (I8, not I16).
// NOTE(review): `indices` presumably corresponds to TF's `crops` operand;
// the name is kept because it determines the generated accessor.
def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
let summary = "BatchToSpaceNd operator";
let description = [{
This operation reshapes the "batch" dimension 0 into space dimensions.
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[I32]>:$block_shape,
TensorOf<[I32]>:$indices
);
let results = (outs
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output
);
}
// Moves spatial blocks into the batch dimension (0). The PredOpTrait
// requires result and input element types to match, so the result type
// list mirrors the input list exactly (I8, not I16).
def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [
NoSideEffect,
TFL_SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>
]> {
let summary = "SpaceToBatchNd operator";
let description = [{
This operation reshapes space dimensions into the "batch" dimension 0
}];
let arguments = (ins
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[I32]>:$block_shape,
TensorOf<[I32]>:$paddings
);
let results = (outs
TensorOf<[F32, I8, I32, I64, TFL_QI8, TFL_QUI8]>:$output
);
}
// Even split into `num_splits` pieces. Unlike TFL_SplitVOp, `split_dim` is
// operand 0 here and `value` is operand 1.
def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> {
let summary = "Splits a tensor into `num_splits` tensors along one dimension.";
let description = [{
Splits the `value` tensor along `split_dim` into a number of sub-tensors
with same shape as the original one, except for `split_dim`. Same as
tf.Split.
}];
let arguments = (ins
I32Tensor:$split_dim,
TensorOf<[F32, I16, I32, I64]>:$value,
I32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
);
let hasOptions = 1;
}
// Uneven split into `num_splits` pieces whose sizes come from `size_splits`.
// Here `value` is operand 0 and `split_dim` operand 2.
def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> {
let summary = "Splits a tensor into `num_splits` tensors along one dimension.";
let description = [{
Splits the `value` tensor along `split_dim` into a number of sub-tensors
with same shape as the original one, except for `split_dim`. The grouping
of the resultant sub-tensors is decided by `size_splits`. Same as tf.SplitV.
}];
let arguments = (ins
TensorOf<[F32, I16, I32, I64]>:$value,
I32Tensor:$size_splits,
I32Tensor:$split_dim,
I32Attr:$num_splits
);
let results = (outs
Variadic<TensorOf<[F32, I16, I32, I64]>>:$outputs
);
let hasOptions = 1;
}
// Bilinear resize of `input` to the spatial size given by `size`.
// NOTE(review): `input` accepts I32 but `output` does not list it; confirm
// whether I32 input is really meant to be accepted.
def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [
NoSideEffect, TFL_SameOperandsAndResultsScale]> {
let summary = "ResizeBilinear Op";
let description = [{
Resize `images` to `size` using bilinear interpolation.
}];
let hasOptions = 1;
let results = (outs TensorOf<[F32, TFL_QI8, TFL_QUI8]>:$output);
let arguments = (ins
// TODO(ycling): Support quantized types.
// NOTE(review): TFL_QI8/TFL_QUI8 are already accepted below; this TODO
// may be stale.
TensorOf<[F32, I32, TFL_QI8, TFL_QUI8]>:$input,
TensorOf<[I32]>:$size,
BoolAttr:$align_corners);
}
// Strided slice controlled by i32 `begin`/`end`/`strides` tensors and five
// i32 bit-mask attributes. Result element type must equal the input's.
def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "StridedSlice Op";
let description = [{
Return a strided slice from `input`.
}];
let hasOptions = 1;
let results = (outs TensorOf<[F32, I32, I64, I8]>:$output);
let arguments = (ins
TensorOf<[F32, I32, I64, I8]>:$input,
TensorOf<[I32]>:$begin,
TensorOf<[I32]>:$end,
TensorOf<[I32]>:$strides,
I32Attr:$begin_mask,
I32Attr:$end_mask,
I32Attr:$ellipsis_mask,
I32Attr:$new_axis_mask,
I32Attr:$shrink_axis_mask);
}
// Shape-preserving element-type conversion among f32/i1/i32/i64.
def TFL_CastOp : TFL_Op<"cast", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Cast operator";
let description = [{
Casts input from input type to output type.
}];
// TFLite's cast kernel ignores CastOptions and takes the source/target
// types from the TfLiteTensors, so no options table is emitted.
let hasOptions = 0;
let results = (outs TensorOf<[F32, I1, I32, I64]>:$output);
// TODO(b/135538711): Add complex types here.
let arguments = (ins TensorOf<[F32, I1, I32, I64]>:$input);
}
// Mirror (reflect/symmetric) padding. The `pad` operand must be rank 2;
// `mode` selects the mirroring variant.
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
let description = [{
This operation pads a input with mirrored values according to the paddings
you specify. paddings is an integer tensor with shape [n, 2],
where n is the rank of input.
For each dimension D of input, paddings[D, 0] indicates how many values
to add before the contents of input in that dimension,
and paddings[D, 1] indicates how many values to add after the contents of
input in that dimension.
Both paddings[D, 0] and paddings[D, 1] must be no greater than
input.dim_size(D) (or input.dim_size(D) - 1)
if copy_border is true (if false, respectively).
The padded size of each dimension D of the output is:
paddings(D, 0) + input.dim_size(D) + paddings(D, 1)
}];
let hasOptions = 1;
let results = (outs TensorOf<[F32, I32, I64]>:$output);
let arguments = (ins
// TODO: add uint8 support when ready.
TensorOf<[F32, I32, I64]>:$input,
TensorOf<[I32, I64]>:$pad,
TFL_MirrorPaddingAttr:$mode);
}
// Deduplicates `input`. `idx_out_type` is derived from the index result's
// integer width: INT64 when wider than 32 bits, otherwise INT32.
def TFL_UniqueOp: TFL_Op<"unique", [NoSideEffect]> {
let summary = "Unique Op.";
let description = [{
This operation returns a tensor `y` containing all of the unique elements of `x`
sorted in the same order that they occur in `x`. This operation also returns a
tensor `idx` the same size as `x` that contains the index of each value of `x`
in the unique output `y`. In other words:
y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]
}];
let arguments = (ins
// TODO: add uint8 support after quantize support.
TensorOf<[I8, I16, I32, I64, F32]>:$input
);
let results = (outs
TensorOf<[I8, I16, I32, I64, F32]>:$output,
TensorOf<[I32, I64]>:$idx
);
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
return getResult(1)->getType().cast<TensorType>().getElementType().
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
tflite::TensorType_INT32;
}]>;
let hasOptions = 1;
}
//===----------------------------------------------------------------------===//
// Quantization ops.
//===----------------------------------------------------------------------===//
// Quantized -> float conversion; its result is marked non-quantizable.
def TFL_DequantizeOp: TFL_Op<"dequantize", [
NoSideEffect, TFL_NoQuantizableResult]> {
let summary = "Dequantize operator";
let description = [{
Converts quantized array of integers to floating-points according to the
quantization parameters.
}];
let results = (outs AnyTensor:$output);
let arguments = (ins AnyTensor:$input);
}
// Simulated quantization of a float tensor; has a canonicalizer.
def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> {
let summary = "FakeQuant operator";
let description = [{
Fake-quantize the 'inputs' tensor of type float via float scalars min and
max to 'outputs' tensor of same shape as inputs.
}];
let hasCanonicalizer = 1;
let results = (outs AnyTensor:$output);
let arguments = (ins
AnyTensor:$input,
// Expected [min, max] value range.
MinMaxAttr:$minmax,
// Quantization bit width, 2 to 16 inclusive.
I32Attr:$num_bits,
// When true the quantized range starts at 1 instead of 0.
BoolAttr:$narrow_range);
}
// Pseudo op for a quantized constant. Defined with plain `Op` rather than
// TFL_Op; the result type is derived from the first attribute (`qtype`).
def TFL_QConstOp : Op<TFL_Dialect, "pseudo_qconst", [
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
let summary = "Quantized constant pseudo op";
let description = [{
Represents a quantized constant value in TensorFlow Lite dialect. This is
not an actual operation and it will be lowered to buffer instead. The
quantization parameters are stored as a type attribute in this constant.
}];
let results = (outs AnyTensor:$output);
let arguments = (ins TensorTypeAttr:$qtype, ElementsAttr:$value);
}
// Float -> quantized conversion; the result type comes from the `qtype`
// attribute (FirstAttrDerivedResultType looks at the first attribute).
def TFL_QuantizeOp: TFL_Op<"quantize", [
NoSideEffect, FirstAttrDerivedResultType, TFL_NoQuantizableResult]> {
let summary = "Quantize operator";
let description = [{
Converts floating point tensors to quantized integer tensors according to
the quantization parameters defined in the type attribute.
}];
let results = (outs AnyTensor:$output);
let arguments = (ins AnyTensor:$input, TensorTypeAttr:$qtype);
}
//===----------------------------------------------------------------------===//
// LSTM Ops
//===----------------------------------------------------------------------===//
// String-enum attribute selecting the LSTM kernel variant: "FULL" or "BASIC".
def TFL_LSTM_KT_FULL : StrEnumAttrCase<"FULL">;
def TFL_LSTM_KT_BASIC : StrEnumAttrCase<"BASIC">;
def TFL_LSTMKernelTypeAttr : StrEnumAttr<"LSTMKernelType", "lstm kernel type enum",
[TFL_LSTM_KT_FULL, TFL_LSTM_KT_BASIC]>;
// Every mandatory LSTM operand (listed by index) must share one element type.
// TODO(ashwinm): Use input tensor names instead of indices once that is
// supported.
def LstmMandatoryInputsConstraint : PredOpTrait<
"mandatory operands element types should match",
TCopVTEtAreSameAt<[0, 2, 3, 4, 6, 7, 8, 13, 14, 15, 18, 19]>>;
// Operands 9, 10 and 11 (peephole weights) must have identical element
// types, which rules out mixing present and absent (none) weights.
def LstmOptionalPeepholeWeightConstraint : PredOpTrait<
"the optional peephole weights should all be specified or none",
TCopVTEtAreSameAt<[9, 10, 11]>>;
// Accepts either: both operand 16 (projection weight) and operand 17
// (projection bias) are NoneType, or operand 16 is not NoneType.
def LstmProjectionWeightBiasConstraint : PredOpTrait<
"either projection weight must be specified or both projection weight and "
"projection bias must not be specified",
Or<[
And<[TCopVTEtIs<16, NoneType>, TCopVTEtIs<17, NoneType>]>,
TFL_TCopIsNot<16, NoneType>
]>>;
// TODO(b/137798843): Two additional constraints still need to be added for
// both LSTM and UnidirectionalSequenceLstm.
// For coupling of input and forget gates (CIFG): if CIFG is true,
// tensors {1, 5, 9, 12, 20} are null; if CIFG is false, tensors {1, 5, 12}
// are not null; tensor {9} is not null if peephole is additionally true;
// tensor {20} is not null if layer norm is additionally true. For layer
// norm: if layer norm is false, tensors {20, 21, 22, 23} are null; if layer
// norm is true, tensors {21, 22, 23} are not null; tensor {20} is not null
// if CIFG is additionally false.
// Result 0 must have the same element type as operand 0 (`input`).
def LstmResultConstraint : PredOpTrait<
    "the input and result tensor elemental types must be same",
    TCresVTEtIsSameAsOp<0, 0>>;
// This is the FULL kernel type LSTM op.
//
// The traits in the list below refer to operands by numeric index only, so
// each operand group is annotated with its indices. Reordering operands
// requires updating those index lists (including StatefulOperands<[18, 19]>).
def TFL_LSTMOp :
  TFL_Op<"lstm",
    [LstmMandatoryInputsConstraint,
     LstmOptionalPeepholeWeightConstraint,
     LstmProjectionWeightBiasConstraint,
     LstmResultConstraint,
     StatefulOperands<[18, 19]>]> {
  let summary = "The full lstm operator";

  let description = [{
Long short-term memory unit (LSTM) recurrent network layer.
The default non-peephole implementation is based on:
http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation,
9(8):1735-1780, 1997.
The peephole implementation is based on:
https://research.google.com/pubs/archive/43905.pdf
Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
recurrent neural network architectures for large scale acoustic modeling."
INTERSPEECH, 2014.
The coupling of input and forget gate (CIFG) is based on:
http://arxiv.org/pdf/1503.04069.pdf
Greff et al. "LSTM: A Search Space Odyssey"
The layer normalization is based on:
https://arxiv.org/pdf/1607.06450.pdf
Ba et al. “Layer Normalization”
}];

  let arguments = (
    // [0] Input.
    ins TensorOf<[F32]>:$input,

    // [1-4] Input weights; [1] is optional.
    TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
    TensorOf<[F32, I8]>:$input_to_forget_weights,
    TensorOf<[F32, I8]>:$input_to_cell_weights,
    TensorOf<[F32, I8]>:$input_to_output_weights,

    // [5-8] Recurrent weights; [5] is optional.
    TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
    TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
    TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
    TensorOf<[F32, I8]>:$recurrent_to_output_weights,

    // [9-11] Optional peephole (cell-to-gate) weights.
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,

    // [12-15] Gate biases; [12] is optional.
    TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
    TensorOf<[F32]>:$forget_gate_bias,
    TensorOf<[F32]>:$cell_bias,
    TensorOf<[F32]>:$output_gate_bias,

    // [16-17] Optional projection weight and bias.
    TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
    TFL_TensorOfOrNone<[F32]>:$projection_bias,

    // [18-19] Stateful activation and cell states.
    TFL_StatefulTensor:$input_activation_state,
    TFL_StatefulTensor:$input_cell_state,

    // [20-23] Optional layer norm coefficients.
    TFL_TensorOfOrNone<[F32]>:$input_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32]>:$forget_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32]>:$cell_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32]>:$output_layer_norm_coefficients,

    // Attributes
    TFL_AFAttr:$fused_activation_function,
    DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
    DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
    // Since this op is the FULL kernel only, constrain it.
    Confined<
      DefaultValuedAttr<TFL_LSTMKernelTypeAttr, "FULL">,
      [TFL_LSTM_KT_FULL]>:$kernel_type
  );

  let results = (outs AnyTensor:$output);

  let hasOptions = 1;

  // Additional checks live in a C++ `Verify` overload defined outside this
  // file (presumably alongside the dialect's op implementations).
  let verifier = [{ return Verify(*this); }];
}
// UnidirectionalSequenceLstm op: runs the LSTM cell over every element of a
// sequence. It uses the same constraint traits as TFL_LSTMOp, and those
// traits address operands by index, so the operand order below must stay
// identical to TFL_LSTMOp's. Differences from TFL_LSTMOp: `input` and the
// layer norm coefficients also accept I8, there is no `kernel_type`
// attribute, and a required `time_major` attribute is added.
// TODO(ashwinm): Add constraint to validate the combination of operands
// that are valid for hybrid vs fully quantized vs float only semantics
def TFL_UnidirectionalSequenceLSTMOp :
  TFL_Op<"unidirectional_sequence_lstm",
    [LstmMandatoryInputsConstraint,
     LstmOptionalPeepholeWeightConstraint,
     LstmProjectionWeightBiasConstraint,
     LstmResultConstraint,
     StatefulOperands<[18, 19]>]> {
  let summary = "Unidirectional sequence lstm operator";

  let description = [{
A recurrent neural network specified by an LSTM cell. This Op supports
unrolling the input along the time or batch dimensions, and
implements the following operation for
each element in the sequence s = 1...sequence_length:
outputs[s] = state = activation(LSTMOp(inputs[s]))
where LSTMOp is LSTM TF Lite Op and the “activation” is the function passed
as the “fused_activation_function” argument (if not “NONE”).
}];

  let arguments = (
    // [0] Input sequence.
    ins TensorOf<[F32, I8]>:$input,

    // [1-4] Input weights; [1] is optional.
    TFL_TensorOfOrNone<[F32, I8]>:$input_to_input_weights,
    TensorOf<[F32, I8]>:$input_to_forget_weights,
    TensorOf<[F32, I8]>:$input_to_cell_weights,
    TensorOf<[F32, I8]>:$input_to_output_weights,

    // [5-8] Recurrent weights; [5] is optional.
    TFL_TensorOfOrNone<[F32, I8]>:$recurrent_to_input_weights,
    TensorOf<[F32, I8]>:$recurrent_to_forget_weights,
    TensorOf<[F32, I8]>:$recurrent_to_cell_weights,
    TensorOf<[F32, I8]>:$recurrent_to_output_weights,

    // [9-11] Optional peephole (cell-to-gate) weights.
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_input_weights,
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_forget_weights,
    TFL_TensorOfOrNone<[F32, I8]>:$cell_to_output_weights,

    // [12-15] Gate biases; [12] is optional.
    TFL_TensorOfOrNone<[F32]>:$input_gate_bias,
    TensorOf<[F32]>:$forget_gate_bias,
    TensorOf<[F32]>:$cell_bias,
    TensorOf<[F32]>:$output_gate_bias,

    // [16-17] Optional projection weight and bias.
    TFL_TensorOfOrNone<[F32, I8]>:$projection_weights,
    TFL_TensorOfOrNone<[F32]>:$projection_bias,

    // [18-19] Stateful activation and cell states.
    TFL_StatefulTensor:$input_activation_state,
    TFL_StatefulTensor:$input_cell_state,

    // [20-23] Optional layer norm coefficients.
    TFL_TensorOfOrNone<[F32, I8]>:$input_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32, I8]>:$forget_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32, I8]>:$cell_layer_norm_coefficients,
    TFL_TensorOfOrNone<[F32, I8]>:$output_layer_norm_coefficients,

    // Attributes; `time_major` has no default and must be provided.
    TFL_AFAttr:$fused_activation_function,
    DefaultValuedAttr<F32Attr, "0.0f">:$cell_clip,
    DefaultValuedAttr<F32Attr, "0.0f">:$proj_clip,
    BoolAttr:$time_major
  );

  let results = (outs AnyTensor:$output);

  let hasOptions = 1;

  // Further checks are delegated to a C++ `Verify` overload defined outside
  // this file.
  let verifier = [{ return Verify(*this); }];
}
#endif // TFL_OPS