| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // TFLite legalization patterns |
| |
| include "mlir/IR/OpBase.td" |
| include "mlir/Dialect/StandardOps/IR/Ops.td" |
| include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" |
| include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" |
| |
| // Matches any attribute that is not an OpaqueElementsAttr. |
| def NonOpaqueElementsAttr : ElementsAttrBase< |
|     Neg<CPred<"$_self.isa<OpaqueElementsAttr>()">>, |
|     "non-opaque constant tensor">; |
| |
| // Matches an ElementsAttr whose element type is 32-bit float. The isa<> guard |
| // keeps the predicate from asserting inside cast<> when it is evaluated on an |
| // attribute that is not an ElementsAttr; in that case it simply fails. |
| def F32ElementsAttr : ElementsAttrBase< |
|     CPred<"$_self.isa<ElementsAttr>() && " |
|           "$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, |
|     "float constant tensor">; |
| |
| // Matches an ElementsAttr whose element type is a 64-bit integer. Guarded with |
| // isa<> so that a non-ElementsAttr fails the match instead of asserting. |
| def Int64ElementsAttr : ElementsAttrBase< |
|     CPred<"$_self.isa<ElementsAttr>() && " |
|           "$_self.cast<ElementsAttr>().getType().getElementType()." |
|           "isInteger(64)">, |
|     "Int 64 constant tensor">; |
| |
| // Extracts element `i` of the ArrayAttr $_self and rebuilds it as a 32-bit |
| // IntegerAttr using the builder. |
| class ExtractI32At<int i> : NativeCodeCall< |
|     "$_builder.getI32IntegerAttr(" # |
|     "$_self.cast<ArrayAttr>().getValue()[" # i # "]" # |
|     ".cast<IntegerAttr>().getInt())">; |
| |
| // Builds a QuantizedType attribute from the type of $0 together with min $1, |
| // max $2, numBits $3 and narrowRange $4. The call always passes -1 and |
| // is_signed=false for the remaining arguments. |
| def ConvertToQuantTypeFromAttrs : NativeCodeCall< |
|     "quant::GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, " # |
|     "$3, $4, /*is_signed=*/false)">; |
| |
| // Rebuilds the integer attribute $0 as a 32-bit IntegerAttr with the builder. |
| def convertIntAttrTo32Bit : NativeCodeCall< |
|     "$_builder.getI32IntegerAttr(" # |
|     "$0.cast<IntegerAttr>().getInt())">; |
| |
| // Calls the ExtractSingleElementAsInteger helper on $_self, which is cast to |
| // an ElementsAttr. |
| def ExtractSingleElementAsInteger : NativeCodeCall< |
|     "ExtractSingleElementAsInteger(" # |
|     "$_self.cast<ElementsAttr>())">; |
| |
| // Same extraction as above, with the result rebuilt as a 32-bit IntegerAttr. |
| def ExtractSingleElementAsInt32 : NativeCodeCall< |
|     "$_builder.getI32IntegerAttr(" # |
|     "ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">; |
| |
| // Converts a tensor with int64 elements to int32 through the |
| // CreateCastToInt32 helper, using the pattern's location and builder. |
| def CreateTFCastToInt32Op : NativeCodeCall< |
|     "CreateCastToInt32(" # "$0, $_loc, $_builder)">; |
| |
| // Predicate over the op defining $0: true when the HasSameStaticShapes helper |
| // reports that the op has static shapes and all inputs share the same shape. |
| def HasSameStaticShapesPred |
|     : CPred<"HasSameStaticShapes($0.getDefiningOp())">; |
| |
| // Positive and negated constraint forms of the predicate above. |
| def HasSameStaticShapes |
|     : Constraint<HasSameStaticShapesPred, |
|                  "op must have static same input shapes">; |
| def HasNotSameStaticShapes |
|     : Constraint<Neg<HasSameStaticShapesPred>, |
|                  "op must have not static same input shapes">; |
| |
| // Creates a ConstantOp of none type carrying a unit attribute, located at the |
| // location of $0. |
| def CreateNoneValue : NativeCodeCall< |
|     "$_builder.create<mlir::ConstantOp>($0.getLoc(), " # |
|     "$_builder.getNoneType(), $_builder.getUnitAttr())">; |
| |
| //===----------------------------------------------------------------------===// |
| // Nullary ops patterns. |
| //===----------------------------------------------------------------------===// |
| |
| // A TF constant holding an ElementsAttr becomes a TFL constant with the same |
| // value. |
| def LegalizeTFConstToTFLConst : Pat< |
|     (TF_ConstOp ElementsAttr:$cst), |
|     (TFL_ConstOp $cst)>; |
| |
| // Statically shaped, non-opaque TF constants are converted to std constants. |
| // This pattern carries an added benefit of 10. |
| def ConvertTfConstToStdConst : Pat< |
|     (TF_ConstOp:$result NonOpaqueElementsAttr:$cst), |
|     (ConstantOp $cst), |
|     [(AnyStaticShapeTensor $result)], |
|     (addBenefit 10)>; |
| |
| //===----------------------------------------------------------------------===// |
| // Unary ops patterns. |
| //===----------------------------------------------------------------------===// |
| // Matches a data format attribute equal to "NHWC". |
| def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">; |
| |
| // Attribute must hold the values [1, X, Y, 1] (checked by TFIntListIs1XY1). |
| def IsIntList1XY1 |
|     : AttrConstraint<CPred<"TFIntListIs1XY1($_self)">>; |
| |
| // Every value of the list attribute must be one (checked by |
| // TFIntListIsAllOnes). |
| def IsAllOnes |
|     : AttrConstraint<CPred<"TFIntListIsAllOnes($_self)">>; |
| |
| // String attribute must be either "SAME" or "VALID". |
| def IsSameOrValid : AttrConstraint< |
|     Or<[CPred<"$_self.cast<StringAttr>().getValue() == \"SAME\"">, |
|         CPred<"$_self.cast<StringAttr>().getValue() == \"VALID\"">]>, |
|     "'SAME' or 'VALID' paddings">; |
| |
| // One-to-one lowerings: the TF op is replaced by the TFL op with the same |
| // operands. |
| def LegalizeAbs : Pat< |
|     (TF_AbsOp $input), |
|     (TFL_AbsOp $input)>; |
| def LegalizeAddN : Pat< |
|     (TF_AddNOp $operands), |
|     (TFL_AddNOp $operands)>; |
| |
| // TF_AvgPoolOp -> TFL_AveragePool2DOp. Matches only NHWC data with ksize and |
| // strides of the form [1, X, Y, 1]; elements 1 and 2 of each list become the |
| // height and width attributes. No fused activation is set. |
| def LegalizeAveragePool : Pat< |
|     (TF_AvgPoolOp $input, |
|         IsIntList1XY1:$window, |
|         IsIntList1XY1:$stride, |
|         $pad_mode, |
|         IsDataFormatNHWC:$data_format), |
|     (TFL_AveragePool2DOp $input, |
|         /*filter_height=*/ExtractI32At<1>:$window, |
|         /*filter_width=*/ExtractI32At<2>:$window, |
|         /*padding=*/$pad_mode, |
|         /*stride_h=*/ExtractI32At<1>:$stride, |
|         /*stride_w=*/ExtractI32At<2>:$stride, |
|         /*fused_activation_function=*/TFL_AF_None)>; |
| |
| // ArgMax / ArgMin: both operands are forwarded unchanged. |
| def LegalizeArgMax : Pat< |
|     (TF_ArgMaxOp $tensor, $axis), |
|     (TFL_ArgMaxOp $tensor, $axis)>; |
| def LegalizeArgMin : Pat< |
|     (TF_ArgMinOp $tensor, $axis), |
|     (TFL_ArgMinOp $tensor, $axis)>; |
| |
| def LegalizeBroadcastTo : Pat< |
|     (TF_BroadcastToOp $tensor, $target), |
|     (TFL_BroadcastToOp $tensor, $target)>; |
| |
| // Single-operand ops with a direct TFL counterpart. |
| def LegalizeCeil : Pat<(TF_CeilOp $input), (TFL_CeilOp $input)>; |
| |
| def LegalizeCos : Pat<(TF_CosOp $input), (TFL_CosOp $input)>; |
| |
| def LegalizeElu : Pat<(TF_EluOp $input), (TFL_EluOp $input)>; |
| |
| def LegalizeExpandDims : Pat< |
|     (TF_ExpandDimsOp $tensor, $axis), |
|     (TFL_ExpandDimsOp $tensor, $axis)>; |
| |
| // TF_FakeQuantWithMinMaxArgsOp becomes a TFL quantize followed by a TFL |
| // dequantize. The quantized type is derived from the input together with the |
| // min, max, num_bits and narrow_range attributes. |
| def LegalizeFakeQuantToDequantizeQuantize : Pat< |
|     (TF_FakeQuantWithMinMaxArgsOp |
|         $input, $lo, $hi, $bits, $narrow), |
|     (TFL_DequantizeOp |
|         (TFL_QuantizeOp |
|             $input, |
|             (ConvertToQuantTypeFromAttrs $input, $lo, $hi, $bits, $narrow)))>; |
| |
| def LegalizeFill : Pat< |
|     (TF_FillOp $dims, $fill_value), |
|     (TFL_FillOp $dims, $fill_value)>; |
| |
| def LegalizeFloor : Pat<(TF_FloorOp $input), (TFL_FloorOp $input)>; |
| |
| // The attribute must be an F32Attr; it is forwarded unchanged. |
| def LegalizeLeakyRelu : Pat< |
|     (TF_LeakyReluOp $input, F32Attr:$alpha), |
|     (TFL_LeakyReluOp $input, $alpha)>; |
| |
| def LegalizeLog : Pat<(TF_LogOp $input), (TFL_LogOp $input)>; |
| |
| // For F32 tensors, Log1p(x) is expanded into Log(Add(x, 1.0)). |
| def LegalizeLog1p : Pat< |
|     (TF_Log1pOp F32Tensor:$input), |
|     (TFL_LogOp |
|         (TFL_AddOp |
|             $input, |
|             (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">), |
|             TFL_AF_None))>; |
| |
| def LegalizeNot : Pat< |
|     (TF_LogicalNotOp $input), |
|     (TFL_LogicalNotOp $input)>; |
| |
| def LegalizeLogSoftmax : Pat< |
|     (TF_LogSoftmaxOp $input), |
|     (TFL_LogSoftmaxOp $input)>; |
| |
| // TF_MaxPoolOp -> TFL_MaxPool2DOp. Matches only NHWC data with ksize and |
| // strides of the form [1, X, Y, 1]; elements 1 and 2 of each list become the |
| // height and width attributes. No fused activation is set. |
| // |
| // $explicit_paddings is not forwarded to the TFL op, so $padding is |
| // constrained to 'SAME' or 'VALID'. An op using any other padding mode is left |
| // unmatched instead of being rewritten with its explicit paddings dropped. |
| def LegalizeMaxPool2D : Pat< |
|     (TF_MaxPoolOp $value, |
|         IsIntList1XY1:$ksize, |
|         IsIntList1XY1:$strides, |
|         IsSameOrValid:$padding, |
|         $explicit_paddings, |
|         IsDataFormatNHWC:$format), |
|     (TFL_MaxPool2DOp $value, |
|         /*padding=*/$padding, |
|         /*stride_w=*/ExtractI32At<2>:$strides, |
|         /*stride_h=*/ExtractI32At<1>:$strides, |
|         /*filter_width=*/ExtractI32At<2>:$ksize, |
|         /*filter_height=*/ExtractI32At<1>:$ksize, |
|         /*fused_activation_function=*/TFL_AF_None)>; |
| |
| // Maximum / Minimum map to their TFL counterparts with the same operands. |
| def LegalizeMaximum : Pat< |
|     (TF_MaximumOp $lhs, $rhs), |
|     (TFL_MaximumOp $lhs, $rhs)>; |
| |
| def LegalizeMinimum : Pat< |
|     (TF_MinimumOp $lhs, $rhs), |
|     (TFL_MinimumOp $lhs, $rhs)>; |
| |
| def LegalizeNeg : Pat<(TF_NegOp $input), (TFL_NegOp $input)>; |
| |
| // The axis attribute is rebuilt as a 32-bit integer attribute. |
| def LegalizeOneHot : Pat< |
|     (TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis), |
|     (TFL_OneHotOp |
|         $indices, $depth, $on_value, $off_value, |
|         (convertIntAttrTo32Bit $axis))>; |
| |
| def LegalizePow : Pat< |
|     (TF_PowOp $base, $exponent), |
|     (TFL_PowOp $base, $exponent)>; |
| |
| def LegalizeRange : Pat< |
|     (TF_RangeOp $start, $limit, $delta), |
|     (TFL_RangeOp $start, $limit, $delta)>; |
| |
| def LegalizeRelu6 : Pat<(TF_Relu6Op $input), (TFL_Relu6Op $input)>; |
| def LegalizeRelu : Pat<(TF_ReluOp $input), (TFL_ReluOp $input)>; |
| |
| // seq_dim and batch_dim are rebuilt as 32-bit integer attributes. |
| def LegalizeReverseSequence : Pat< |
|     (TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim), |
|     (TFL_ReverseSequenceOp |
|         $input, $seq_lengths, |
|         (convertIntAttrTo32Bit $seq_dim), |
|         (convertIntAttrTo32Bit $batch_dim))>; |
| |
| def LegalizeRound : Pat<(TF_RoundOp $input), (TFL_RoundOp $input)>; |
| def LegalizeRsqrt : Pat<(TF_RsqrtOp $input), (TFL_RsqrtOp $input)>; |
| def LegalizeSqrt : Pat<(TF_SqrtOp $input), (TFL_SqrtOp $input)>; |
| def LegalizeSquare : Pat<(TF_SquareOp $input), (TFL_SquareOp $input)>; |
| |
| // The segment ids operand goes through the cast-to-int32 helper. |
| def LegalizeSegmentSum : Pat< |
|     (TF_SegmentSumOp $data, $segment_ids), |
|     (TFL_SegmentSumOp $data, (CreateTFCastToInt32Op $segment_ids))>; |
| |
| def LegalizeSelect : Pat< |
|     (TF_SelectOp $condition, $on_true, $on_false), |
|     (TFL_SelectOp $condition, $on_true, $on_false)>; |
| // TF_SelectV2Op is lowered to TFL_SelectOp when the HasSameStaticShapes |
| // constraint holds for the matched op, and to TFL_SelectV2Op otherwise. |
| def LegalizeSelectV2SameStaticShape : Pat< |
|     (TF_SelectV2Op:$matched $condition, $on_true, $on_false), |
|     (TFL_SelectOp $condition, $on_true, $on_false), |
|     [(HasSameStaticShapes $matched)]>; |
| def LegalizeSelectV2NotSameStaticShape : Pat< |
|     (TF_SelectV2Op:$matched $condition, $on_true, $on_false), |
|     (TFL_SelectV2Op $condition, $on_true, $on_false), |
|     [(HasNotSameStaticShapes $matched)]>; |
| def LegalizeShape : Pat<(TF_ShapeOp $input), (TFL_ShapeOp $input)>; |
| |
| // TF_SigmoidOp is replaced by TFL_LogisticOp. |
| def LegalizeSigmoid : Pat< |
|     (TF_SigmoidOp $input), |
|     (TFL_LogisticOp $input)>; |
| |
| // Only matches F32 tensors. |
| def LegalizeSin : Pat< |
|     (TF_SinOp F32Tensor:$input), |
|     (TFL_SinOp $input)>; |
| |
| def LegalizeSlice : Pat< |
|     (TF_SliceOp $tensor, $offsets, $extents), |
|     (TFL_SliceOp $tensor, $offsets, $extents)>; |
| |
| // The TFL softmax op is created with a constant F32 attribute of 1.0. |
| def LegalizeSoftmax : Pat< |
|     (TF_SoftmaxOp $input), |
|     (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">)>; |
| |
| // For F32 tensors, Softplus(x) is expanded into Log(Add(Exp(x), 1.0)). |
| def LegalizeSoftPlus : Pat< |
|     (TF_SoftplusOp F32Tensor:$input), |
|     (TFL_LogOp |
|         (TFL_AddOp |
|             (TFL_ExpOp $input), |
|             (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">), |
|             TFL_AF_None))>; |
| |
| def LegalizeSqueeze : Pat< |
|     (TF_SqueezeOp $input, $dims), |
|     (TFL_SqueezeOp $input, $dims)>; |
| |
| def LegalizeTanh : Pat<(TF_TanhOp $input), (TFL_TanhOp $input)>; |
| |
| // The permutation operand goes through the cast-to-int32 helper. |
| def LegalizeTranspose : Pat< |
|     (TF_TransposeOp $input, $permutation), |
|     (TFL_TransposeOp $input, (CreateTFCastToInt32Op $permutation))>; |
| |
| def LegalizeWhere : Pat<(TF_WhereOp $input), (TFL_WhereOp $input)>; |
| def LegalizeZerosLike : Pat< |
|     (TF_ZerosLikeOp $input), |
|     (TFL_ZerosLikeOp $input)>; |
| |
| //===----------------------------------------------------------------------===// |
| // Binary ops patterns. |
| //===----------------------------------------------------------------------===// |
| // Comparison ops map to the TFL op of the same kind with the same operands. |
| def LegalizeLess : Pat< |
|     (TF_LessOp $lhs, $rhs), |
|     (TFL_LessOp $lhs, $rhs)>; |
| def LegalizeGreater : Pat< |
|     (TF_GreaterOp $lhs, $rhs), |
|     (TFL_GreaterOp $lhs, $rhs)>; |
| |
| def LegalizeLessEqual : Pat< |
|     (TF_LessEqualOp $lhs, $rhs), |
|     (TFL_LessEqualOp $lhs, $rhs)>; |
| def LegalizeGreaterEqual : Pat< |
|     (TF_GreaterEqualOp $lhs, $rhs), |
|     (TFL_GreaterEqualOp $lhs, $rhs)>; |
| |
| // TF Gather becomes TFL Gather with axis = 0. The deprecated |
| // 'validate_indices' attribute is matched but not used. |
| def LegalizeGather : Pat< |
|     (TF_GatherOp $table, $ids, $unused_validate_indices), |
|     (TFL_GatherOp $table, $ids, ConstantAttr<I32Attr, "0">)>; |
| |
| def LegalizeGatherNd : Pat< |
|     (TF_GatherNdOp $table, $ids), |
|     (TFL_GatherNdOp $table, $ids)>; |
| |
| // Matches only when the axis operand is a constant and batch_dims is 0. The |
| // single element of the axis constant becomes the TFL op's 32-bit axis |
| // attribute. |
| def LegalizeGatherV2 : Pat< |
|     (TF_GatherV2Op |
|         $table, $ids, |
|         (ConstantOp ElementsAttr:$axis_cst), |
|         ConstantAttr<I64Attr, "0">:$batch_dims), |
|     (TFL_GatherOp $table, $ids, ExtractSingleElementAsInt32:$axis_cst)>; |
| |
| def LegalizeFloorDiv : Pat< |
|     (TF_FloorDivOp $lhs, $rhs), |
|     (TFL_FloorDivOp $lhs, $rhs)>; |
| |
| // Matches only when incompatible_shape_error is true. |
| def LegalizeNotEqual : Pat< |
|     (TF_NotEqualOp |
|         $lhs, $rhs, /*incompatible_shape_error=*/ConstBoolAttrTrue), |
|     (TFL_NotEqualOp $lhs, $rhs)>; |
| |
| def LegalizeLogicalAnd : Pat< |
|     (TF_LogicalAndOp $lhs, $rhs), |
|     (TFL_LogicalAndOp $lhs, $rhs)>; |
| |
| def LegalizeLogicalOr : Pat< |
|     (TF_LogicalOrOp $lhs, $rhs), |
|     (TFL_LogicalOrOp $lhs, $rhs)>; |
| |
| // Binary arithmetic ops. The TFL ops are created with no fused activation |
| // (TFL_AF_None). |
| def LegalizeAdd : Pat< |
|     (TF_AddOp $a, $b), |
|     (TFL_AddOp $a, $b, TFL_AF_None)>; |
| def LegalizeAddv2 : Pat< |
|     (TF_AddV2Op $a, $b), |
|     (TFL_AddOp $a, $b, TFL_AF_None)>; |
| |
| // BiasAdd on F32 tensors with NHWC data is rewritten into TF_AddV2Op. |
| def LegalizeBiasAdd : Pat< |
|     (TF_BiasAddOp |
|         F32Tensor:$value, F32Tensor:$bias, IsDataFormatNHWC:$data_format), |
|     (TF_AddV2Op $value, $bias)>; |
| |
| def LegalizeSub : Pat< |
|     (TF_SubOp $a, $b), |
|     (TFL_SubOp $a, $b, TFL_AF_None)>; |
| def LegalizeMul : Pat< |
|     (TF_MulOp $a, $b), |
|     (TFL_MulOp $a, $b, TFL_AF_None)>; |
| |
| // Both RealDiv and Div map to TFL_DivOp. |
| def LegalizeRealDiv : Pat< |
|     (TF_RealDivOp $a, $b), |
|     (TFL_DivOp $a, $b, TFL_AF_None)>; |
| def LegalizeDiv : Pat< |
|     (TF_DivOp $a, $b), |
|     (TFL_DivOp $a, $b, TFL_AF_None)>; |
| |
| // With a known batch size, TF BatchMatMul is unfolded to TFL FullyConnected |
| // plus additional ops. When the batch size is unknown, matching falls |
| // through to these patterns, which convert to TF Lite BatchMatMul. |
| def LegalizeBatchMatMulV2UnknownBatch : Pat< |
|     (TF_BatchMatMulV2Op $x, $y, $adjoint_x, $adjoint_y), |
|     (TFL_BatchMatMulOp $x, $y, $adjoint_x, $adjoint_y)>; |
| def LegalizeBatchMatMulUnknownBatch : Pat< |
|     (TF_BatchMatMulOp $x, $y, $adjoint_x, $adjoint_y), |
|     (TFL_BatchMatMulOp $x, $y, $adjoint_x, $adjoint_y)>; |
| |
| // Matches only when min and max are F32 constants. The op becomes a TFL |
| // quantize followed by a TFL dequantize, with the quantized type derived |
| // from the input, min, max, num_bits and narrow_range. |
| def LegalizeFakeQuantWithMinMaxVars : Pat< |
|     (TF_FakeQuantWithMinMaxVarsOp |
|         $input, |
|         (ConstantOp F32ElementsAttr:$lo), |
|         (ConstantOp F32ElementsAttr:$hi), |
|         $bits, $narrow), |
|     (TFL_DequantizeOp |
|         (TFL_QuantizeOp |
|             $input, |
|             (ConvertToQuantTypeFromAttrs $input, $lo, $hi, $bits, $narrow)))>; |
| |
| // TODO(rocky): Not all of the attributes are handled correctly. Make this |
| // more general if there is a need. |
| // Only min, max, num_bits and narrow_range feed the quantized type; the |
| // other matched attributes are not used in the result. |
| def LegalizeQuantizeAndDequantizeV2 : Pat< |
|     (TF_QuantizeAndDequantizeV2Op |
|         $input, |
|         (ConstantOp F32ElementsAttr:$lo), |
|         (ConstantOp F32ElementsAttr:$hi), |
|         $signed_input, $bits, $range_given, $round_mode, $narrow, $axis), |
|     (TFL_DequantizeOp |
|         (TFL_QuantizeOp |
|             $input, |
|             (ConvertToQuantTypeFromAttrs $input, $lo, $hi, $bits, $narrow)))>; |
| |
| def LegalizeRank : Pat<(TF_RankOp $tensor), (TFL_RankOp $tensor)>; |
| |
| def LegalizeSquaredDifference : Pat< |
|     (TF_SquaredDifferenceOp $lhs, $rhs), |
|     (TFL_SquaredDifferenceOp $lhs, $rhs)>; |
| |
| // The axis operand goes through the cast-to-int32 helper. |
| def LegalizeReverseV2 : Pat< |
|     (TF_ReverseV2Op $tensor, $axis), |
|     (TFL_ReverseV2Op $tensor, (CreateTFCastToInt32Op $axis))>; |
| |
| // Matches only when incompatible_shape_error is true. |
| def LegalizeEqual : Pat< |
|     (TF_EqualOp |
|         $lhs, $rhs, /*incompatible_shape_error=*/ConstBoolAttrTrue), |
|     (TFL_EqualOp $lhs, $rhs)>; |
| |
| def LegalizePad : Pat< |
|     (TF_PadOp $tensor, $paddings), |
|     (TFL_PadOp $tensor, $paddings)>; |
| |
| def LegalizeTile : Pat< |
|     (TF_TileOp $tensor, $multiples), |
|     (TFL_TileOp $tensor, $multiples)>; |
| |
| def LegalizePadV2 : Pat< |
|     (TF_PadV2Op $tensor, $paddings, $pad_value), |
|     (TFL_PadV2Op $tensor, $paddings, $pad_value)>; |
| |
| // The third argument must be a BoolAttr; it is forwarded unchanged. |
| def LegalizeMean : Pat< |
|     (TF_MeanOp $tensor, $axes, BoolAttr:$keep_dims), |
|     (TFL_MeanOp $tensor, $axes, $keep_dims)>; |
| |
| // Like Mean, but the axes operand goes through the cast-to-int32 helper. |
| def LegalizeSum : Pat< |
|     (TF_SumOp $tensor, $axes, BoolAttr:$keep_dims), |
|     (TFL_SumOp $tensor, (CreateTFCastToInt32Op $axes), $keep_dims)>; |
| |
| // TopK in TFL is always sorted, so the 'sorted' attribute is ignored here. |
| def LegalizeTopKV2 : Pat< |
|     (TF_TopKV2Op $tensor, $k, $unused_sorted), |
|     (TFL_TopKV2Op $tensor, $k)>; |
| |
| // Reductions: the axes operand goes through the cast-to-int32 helper and the |
| // boolean attribute is forwarded unchanged. |
| def LegalizeMin : Pat< |
|     (TF_MinOp $tensor, $axes, BoolAttr:$keep_dims), |
|     (TFL_ReduceMinOp $tensor, (CreateTFCastToInt32Op $axes), $keep_dims)>; |
| |
| def LegalizeMax : Pat< |
|     (TF_MaxOp $tensor, $axes, BoolAttr:$keep_dims), |
|     (TFL_ReduceMaxOp $tensor, (CreateTFCastToInt32Op $axes), $keep_dims)>; |
| |
| def LegalizeProd : Pat< |
|     (TF_ProdOp $tensor, $axes, BoolAttr:$keep_dims), |
|     (TFL_ReduceProdOp $tensor, (CreateTFCastToInt32Op $axes), $keep_dims)>; |
| |
| def LegalizeAny : Pat< |
|     (TF_AnyOp $tensor, $axes, $keep_dims), |
|     (TFL_ReduceAnyOp $tensor, (CreateTFCastToInt32Op $axes), $keep_dims)>; |
| |
| // The BoolAttr of TF_CastOp is matched but not passed to the TFL op. |
| def LegalizeCast : Pat< |
|     (TF_CastOp $tensor, BoolAttr:$unused_flag), |
|     (TFL_CastOp $tensor)>; |
| |
| // The block shape and crops/paddings operands go through the cast-to-int32 |
| // helper. |
| def LegalizeBatchToSpaceND : Pat< |
|     (TF_BatchToSpaceNDOp $tensor, $block, $crop), |
|     (TFL_BatchToSpaceNdOp |
|         $tensor, |
|         (CreateTFCastToInt32Op $block), |
|         (CreateTFCastToInt32Op $crop))>; |
| |
| def LegalizeSpaceToBatchND : Pat< |
|     (TF_SpaceToBatchNDOp $tensor, $block, $pad), |
|     (TFL_SpaceToBatchNdOp |
|         $tensor, |
|         (CreateTFCastToInt32Op $block), |
|         (CreateTFCastToInt32Op $pad))>; |
| |
| // Matches only NHWC data; block_size is rebuilt as a 32-bit integer |
| // attribute. |
| def LegalizeSpaceToDepth : Pat< |
|     (TF_SpaceToDepthOp $tensor, $block, IsDataFormatNHWC:$format), |
|     (TFL_SpaceToDepthOp $tensor, (convertIntAttrTo32Bit $block))>; |
| |
| def LegalizeDepthToSpace : Pat< |
|     (TF_DepthToSpaceOp $tensor, $block, IsDataFormatNHWC:$format), |
|     (TFL_DepthToSpaceOp $tensor, (convertIntAttrTo32Bit $block))>; |
| |
| // Resize ops: all operands and attributes are forwarded unchanged. |
| def LegalizeResizeBilinear : Pat< |
|     (TF_ResizeBilinearOp $img, $new_size, $align, $half_pixel), |
|     (TFL_ResizeBilinearOp $img, $new_size, $align, $half_pixel)>; |
| def LegalizeResizeNearestNeighbor : Pat< |
|     (TF_ResizeNearestNeighborOp $img, $new_size, $align, $half_pixel), |
|     (TFL_ResizeNearestNeighborOp $img, $new_size, $align, $half_pixel)>; |
| |
| def LegalizeMirrorPad : Pat< |
|     (TF_MirrorPadOp $tensor, $paddings, $mode), |
|     (TFL_MirrorPadOp $tensor, $paddings, $mode)>; |
| |
| // validate_indices is matched but not passed to the TFL op. |
| def LegalizeSparseToDense : Pat< |
|     (TF_SparseToDenseOp |
|         $indices, $dense_shape, $values, $fallback, $unused_validate_indices), |
|     (TFL_SparseToDenseOp $indices, $dense_shape, $values, $fallback)>; |
| |
| def LegalizeUnique : Pat<(TF_UniqueOp $tensor), (TFL_UniqueOp $tensor)>; |
| |
| def LegalizeFloorMod : Pat< |
|     (TF_FloorModOp $lhs, $rhs), |
|     (TFL_FloorModOp $lhs, $rhs)>; |
| def LegalizeExp : Pat<(TF_ExpOp $tensor), (TFL_ExpOp $tensor)>; |
| |
| // bias, alpha and beta must be F32 attributes; the radius is rebuilt as a |
| // 32-bit integer attribute. |
| def LegalizeLRN : Pat< |
|     (TF_LRNOp |
|         $tensor, $depth_radius, |
|         F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), |
|     (TFL_LocalResponseNormalizationOp |
|         $tensor, (convertIntAttrTo32Bit $depth_radius), |
|         $bias, $alpha, $beta)>; |
| |
| // pad_to_max_output_size is matched but not passed to the TFL ops. |
| def LegalizeNonMaxSuppressionV4 : Pat< |
|     (TF_NonMaxSuppressionV4Op |
|         $boxes, $scores, $max_out, $iou_thresh, $score_thresh, |
|         $unused_pad_to_max), |
|     (TFL_NonMaxSuppressionV4Op |
|         $boxes, $scores, $max_out, $iou_thresh, $score_thresh)>; |
| |
| def LegalizeNonMaxSuppressionV5 : Pat< |
|     (TF_NonMaxSuppressionV5Op |
|         $boxes, $scores, $max_out, $iou_thresh, $score_thresh, $sigma, |
|         $unused_pad_to_max), |
|     (TFL_NonMaxSuppressionV5Op |
|         $boxes, $scores, $max_out, $iou_thresh, $score_thresh, $sigma)>; |
| |
| def LegalizeMatrixDiag : Pat< |
|     (TF_MatrixDiagOp $diag), |
|     (TFL_MatrixDiagOp $diag)>; |
| |
| // Elements attribute holding dense 32-bit signless integers. The predicate |
| // checks only the attribute kind and element type. `len` appears in the |
| // summary and sets the length of the 1-D tensor type built by |
| // constBuilderCall. |
| // NOTE(review): the predicate does not verify the shape named in the |
| // summary - confirm whether that is intended. |
| class I32VectorElementsAttr<int len> : ElementsAttrBase< |
|     CPred<"$_self.isa<DenseIntElementsAttr>() &&" # |
|           "$_self.cast<DenseIntElementsAttr>().getType()." # |
|           "getElementType().isSignlessInteger(32)">, |
|     "32-bit int elements attribute of shape [" # len # "]"> { |
|   let returnType = [{ DenseIntElementsAttr }]; |
|   let storageType = [{ DenseIntElementsAttr }]; |
| |
|   let constBuilderCall = |
|       "DenseElementsAttr::get(" # |
|       "RankedTensorType::get({" # len # "}, $_builder.getIntegerType(32)), $0)"; |
| } |
| |
| // TF_Conv2DBackpropInputOp -> TFL_TransposeConvOp. Matches only NHWC data, |
| // strides of the form [1, X, Y, 1], 'SAME' or 'VALID' padding and all-ones |
| // dilations. The filter is transposed with the constant permutation |
| // {2, 0, 1, 3}, and the bias operand is a none value. |
| def LegalizeConv2DBackpropInput : Pat< |
|     (TF_Conv2DBackpropInputOp |
|         $output_shape, $weights, $grad, |
|         IsIntList1XY1:$stride, |
|         BoolAttr:$use_cudnn_on_gpu, |
|         IsSameOrValid:$pad_mode, |
|         I64ArrayAttr:$explicit_paddings, |
|         IsDataFormatNHWC:$data_format, |
|         IsAllOnes:$dilations), |
|     (TFL_TransposeConvOp |
|         $output_shape, |
|         (TF_TransposeOp |
|             $weights, |
|             (ConstantOp |
|                 ConstantAttr<I32VectorElementsAttr<4>, "{2, 0, 1, 3}">)), |
|         $grad, |
|         /*bias=*/ (CreateNoneValue $output_shape), |
|         /*padding=*/ $pad_mode, |
|         /*stride_h=*/ ExtractI32At<1>:$stride, |
|         /*stride_w=*/ ExtractI32At<2>:$stride)>; |
| |
| // True when $_self is a DenseElementsAttr of rank zero. The isa<> guard makes |
| // the predicate fail, rather than assert inside cast<>, on any other |
| // attribute kind. |
| def IsRankZeroAttr |
|   : CPred<"$_self.isa<DenseElementsAttr>() && " |
|           "$_self.cast<DenseElementsAttr>().getType().getRank() == 0">; |
| |
| // True when $_self is a dense integer splat whose value is zero. The guards |
| // run before getSplatValue() and the IntegerAttr cast, so non-integer or |
| // non-splat attributes fail the match instead of asserting. |
| def HasValueZero |
|   : CPred<"$_self.isa<DenseIntElementsAttr>() && " |
|           "$_self.cast<DenseElementsAttr>().isSplat() && " |
|           "$_self.cast<DenseElementsAttr>().getSplatValue()." |
|           "cast<::mlir::IntegerAttr>().getInt() == 0">; |
| |
| // TFLite only supports MatrixSetDiag ops with scalar zero k attribute. |
| def IsSupportedByTFLiteMatrixSetDiag |
|   : ElementsAttrBase<And<[ElementsAttr.predicate, |
|                           IsRankZeroAttr, HasValueZero]>, |
|                      "MatrixSetDiag attribute verification">; |
| |
| // Matches only when k is a constant accepted by |
| // IsSupportedByTFLiteMatrixSetDiag. The align attribute doesn't matter when |
| // k is zero, so it is not forwarded. |
| def LegalizeMatrixSetDiag : Pat< |
|     (TF_MatrixSetDiagV3Op |
|         $matrix, $diag, |
|         (ConstantLikeMatcher IsSupportedByTFLiteMatrixSetDiag:$k), |
|         $unused_align), |
|     (TFL_MatrixSetDiagOp $matrix, $diag)>; |
| |
| // The indices and shape operands go through the cast-to-int32 helper. |
| def LegalizeScatterNd : Pat< |
|     (TF_ScatterNdOp $ids, $values, $out_shape), |
|     (TFL_ScatterNdOp |
|         (CreateTFCastToInt32Op $ids), |
|         $values, |
|         (CreateTFCastToInt32Op $out_shape))>; |
| |
| // The axis operand goes through the cast-to-int32 helper; the two attributes |
| // are forwarded unchanged. |
| def LegalizeCumsum : Pat< |
|     (TF_CumsumOp $tensor, $dim, $is_exclusive, $is_reverse), |
|     (TFL_CumsumOp |
|         $tensor, (CreateTFCastToInt32Op $dim), $is_exclusive, $is_reverse)>; |
| |
| // The shape operand goes through the cast-to-int32 helper. |
| def LegalizeReshape : Pat< |
|     (TF_ReshapeOp $tensor, $new_shape), |
|     (TFL_ReshapeOp $tensor, (CreateTFCastToInt32Op $new_shape))>; |
| |
| // Matches an integer attribute whose value is 0. |
| def ZeroIntAttr : AttrConstraint< |
|     CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() == 0">>; |
| |
| // Matches only when ellipsis_mask and new_axis_mask are zero. The begin, end |
| // and strides operands go through the cast-to-int32 helper, and every mask |
| // attribute is rebuilt as a 32-bit integer attribute. |
| def LegalizeStridedSlice : Pat< |
|     (TF_StridedSliceOp |
|         $tensor, $start, $stop, $step, |
|         $start_mask, $stop_mask, |
|         ZeroIntAttr:$ellipsis, ZeroIntAttr:$new_axis, |
|         $shrink_axis), |
|     (TFL_StridedSliceOp |
|         $tensor, |
|         (CreateTFCastToInt32Op $start), |
|         (CreateTFCastToInt32Op $stop), |
|         (CreateTFCastToInt32Op $step), |
|         (convertIntAttrTo32Bit $start_mask), |
|         (convertIntAttrTo32Bit $stop_mask), |
|         (convertIntAttrTo32Bit $ellipsis), |
|         (convertIntAttrTo32Bit $new_axis), |
|         (convertIntAttrTo32Bit $shrink_axis))>; |