blob: b52f649c1300b5a55f5abf1bce0a8fe8d50b2885 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the optimization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/lite/utils/utils.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// Attribute constraint: accepts an ElementsAttr whose element type is f32.
def F32ElementsAttr : ElementsAttrBase<
    CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
    "32 bit float constant tensor">;
// Attribute constraint: accepts an ElementsAttr whose element type is any
// FloatType.
def FloatElementsAttr : ElementsAttrBase<
    CPred<"$_self.isa<ElementsAttr>() && $_self.cast<ElementsAttr>().getType().getElementType().isa<FloatType>()">,
    "float constant tensor">;
// Constraint satisfied when the type of the bound entity is NoneType.
def IsNoneType : Constraint<
    CPred<"$0.getType().isa<NoneType>()">>;
// Native call that casts the attribute to ElementsAttr and forwards it to the
// C++ helper `ExtractSingleElementAsFloat`.
def ExtractSingleElementAsFloat : NativeCodeCall<
    "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
// Checks if the value has a known rank that is at most 'n'.
// ShapedType::getRank() is only valid on ranked types, so unranked values are
// rejected up front instead of being queried for their rank.
class HasRankAtMost<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() <= " # n>>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Patterns folding a stand-alone activation op into the convolution (regular
// or depthwise) that feeds it. The convolution must carry no fused activation
// yet (TFL_AF_None) and its result must have exactly one use.
multiclass FuseActFnIntoConvOpPat<Op ActFnOp, Attr ActFnAttr> {
  // act(conv2d(...)) => conv2d(..., fused activation = ActFnAttr)
  def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_Conv2DOp:$conv
        $input, $filter, $bias, $h_factor, $w_factor, TFL_AF_None,
        $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp
      $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
      $padding, $stride_h, $stride_w),
    [(HasOneUse $conv)]>;

  // act(depthwise_conv2d(...)) => depthwise_conv2d(..., ActFnAttr)
  def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_DepthwiseConv2DOp:$conv
        $input, $filter, $bias, $h_factor, $w_factor, TFL_AF_None,
        $padding, $stride_h, $stride_w, $multiplier)),
    (TFL_DepthwiseConv2DOp
      $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr,
      $padding, $stride_h, $stride_w, $multiplier),
    [(HasOneUse $conv)]>;
}
// Patterns folding a stand-alone activation op into the average/max pooling op
// that feeds it. The pooling op must carry no fused activation yet
// (TFL_AF_None) and its result must have exactly one use.
multiclass FuseActFnIntoPoolOpPat<Op ActFnOp, Attr ActFnAttr> {
  // act(average_pool_2d(...)) => average_pool_2d(..., ActFnAttr)
  def FuseActivationFuncWithAvgPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_AveragePool2DOp:$pool
        $input, $filter_height, $filter_width, $padding,
        $stride_h, $stride_w, TFL_AF_None)),
    (TFL_AveragePool2DOp
      $input, $filter_height, $filter_width, $padding,
      $stride_h, $stride_w, ActFnAttr),
    [(HasOneUse $pool)]>;

  // act(max_pool_2d(...)) => max_pool_2d(..., ActFnAttr)
  def FuseActivationFuncWithMaxPool#ActFnOp#ActFnAttr : Pat<
    (ActFnOp
      (TFL_MaxPool2DOp:$pool
        $input, $padding, $stride_w, $stride_h,
        $filter_width, $filter_height, TFL_AF_None)),
    (TFL_MaxPool2DOp
      $input, $padding, $stride_w, $stride_h,
      $filter_width, $filter_height, ActFnAttr),
    [(HasOneUse $pool)]>;
}
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
// activation functions.
// Currently we're not fusing tanh, sigmoid, hard_swish and other activations
// that cannot be simply translated into clamping.
// Instantiate the conv and pool activation-fusion patterns for every
// (activation op, fused-activation attribute) pair.
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6],
                     [TFL_Relu1Op, TFL_AF_Relu1]] in {
  defm : FuseActFnIntoConvOpPat<!cast<Op>(actFnPair[0]),
                                !cast<Attr>(actFnPair[1])>;
  defm : FuseActFnIntoPoolOpPat<!cast<Op>(actFnPair[0]),
                                !cast<Attr>(actFnPair[1])>;
}
// Constraint delegating to the C++ helper TFL::CanFuseConvOrDepthwiseConv with
// the two bound entities; `is_depthwise` is spliced in verbatim as the third
// argument (callers in this file pass "true" or "false").
class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
    CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
// If we see a binary op (add, sub) applying a constant value to the result of
// a convolution op with constant bias, we can fuse the binary op into the
// convolution op by constant folding the bias and the binary op's constant
// operand. The following patterns restrict to float constant values for now.
multiclass FuseBinaryOpToPrecedingAffine<Op binaryOp> {
  // binaryOp(conv2d(input, filter, bias), value)
  //   => conv2d(input, filter, binaryOp(bias, value))
  def FuseBinaryOpWithConv#binaryOp : Pat<
    (binaryOp (TFL_Conv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias), $h_factor, $w_factor,
                TFL_AF_None, $padding, $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_Conv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias),
                (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $output)]>;

  // Same rewrite for depthwise convolution.
  def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat<
    (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter,
                (ConstantOp FloatElementsAttr:$bias),
                $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
                $stride_w, $multiplier),
              (ConstantOp FloatElementsAttr:$value), $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
      (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasOneUse $output)]>;

  // Same rewrite for transpose convolution with a constant bias. The result
  // op takes no fused activation here, so the binary op must not carry one
  // (TFL_AF_None).
  def FuseBinaryOpWithTransposeConv#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (ConstantOp FloatElementsAttr:$bias), $padding,
                $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (binaryOp (ConstantOp $bias),
                (ConstantOp $value), TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;
}
foreach binaryOp = [TFL_AddOp, TFL_SubOp]<Op> in
  defm : FuseBinaryOpToPrecedingAffine<binaryOp>;

// Fuse for TransposeConv with no bias.
// The constant operand is installed directly as the new bias, which computes
// `conv + value`. That is only equivalent to the source for addition; for
// subtraction it would flip the sign of `value`, so this pattern is
// instantiated for TFL_AddOp only.
foreach binaryOp = [TFL_AddOp]<Op> in {
  def FuseBinaryOpWithTransposeConvNoneBias#binaryOp : Pat<
    (binaryOp (TFL_TransposeConvOp:$output $output_shape, $weights, $inputs,
                (ConstantOp $bias), $padding,
                $stride_h, $stride_w),
              (ConstantOp FloatElementsAttr:$value), TFL_AF_None),
    (TFL_TransposeConvOp $output_shape, $weights, $inputs,
      (ConstantOp $value),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
// Native call forwarding the bound attribute to the C++ helper
// `ExpandTo4DForConv`.
def ExpandTo4DForConv : NativeCodeCall<
    "ExpandTo4DForConv($0)">;
// Native call forwarding the bound attribute to the C++ helper
// `ExpandTo4DForDepthwiseConv`.
def ExpandTo4DForDepthwiseConv : NativeCodeCall<
    "ExpandTo4DForDepthwiseConv($0)">;
// If a div or mul op applies a constant value to the result of a convolution
// op with constant filter and bias, the div/mul can be fused into the
// convolution by constant folding the filter/bias with the div/mul op's
// constant operand.
// The following patterns restrict to float constant values for now.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<Op BinaryOp> {
  // Depthwise convolution: fold the constant into both filter and bias.
  def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_DepthwiseConv2DOp:$output
        $input,
        (ConstantOp FloatElementsAttr:$filter),
        (ConstantOp FloatElementsAttr:$bias),
        $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h,
        $stride_w, $multiplier),
      (ConstantOp FloatElementsAttr:$value),
      $act_fn),
    (TFL_DepthwiseConv2DOp
      $input,
      (BinaryOp
        (ConstantOp $filter),
        (ConstantOp (ExpandTo4DForDepthwiseConv $value)),
        TFL_AF_None),
      (BinaryOp
        (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h,
      $stride_w, $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value),
     (HasOneUse $output)]>;

  // Regular convolution: fold the constant into both filter and bias.
  def FuseMulOrDivWithConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_Conv2DOp:$conv_output
        $input,
        (ConstantOp FloatElementsAttr:$filter),
        (ConstantOp FloatElementsAttr:$bias),
        $h_factor, $w_factor, TFL_AF_None,
        $padding, $stride_h, $stride_w),
      (ConstantOp FloatElementsAttr:$value),
      $act_fn),
    (TFL_Conv2DOp
      $input,
      (BinaryOp
        (ConstantOp $filter),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      (BinaryOp
        (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value),
     (HasOneUse $conv_output)]>;

  // Transpose convolution with constant bias: fold the constant into both
  // weights and bias. The binary op must have no fused activation.
  def FuseMulOrDivWithTransposeConv#BinaryOp : Pat<
    (BinaryOp
      (TFL_TransposeConvOp:$output
        $output_shape,
        (ConstantOp FloatElementsAttr:$weights),
        $input,
        (ConstantOp FloatElementsAttr:$bias),
        $padding, $stride_h, $stride_w),
      (ConstantOp $value),
      TFL_AF_None),
    (TFL_TransposeConvOp
      $output_shape,
      (BinaryOp
        (ConstantOp $weights),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      (BinaryOp
        (ConstantOp $bias),
        (ConstantOp $value),
        TFL_AF_None),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (HasOneUse $output)]>;

  // Transpose convolution whose bias is of NoneType: only the weights are
  // folded, the bias operand is rebuilt unchanged.
  def FuseMulOrDivWithTransposeConvWithNoneBias#BinaryOp : Pat<
    (BinaryOp
      (TFL_TransposeConvOp:$output
        $output_shape,
        (ConstantOp FloatElementsAttr:$weights),
        $input,
        (ConstantOp $bias),
        $padding, $stride_h, $stride_w),
      (ConstantOp $value),
      TFL_AF_None),
    (TFL_TransposeConvOp
      $output_shape,
      (BinaryOp
        (ConstantOp $weights),
        (ConstantOp (ExpandTo4DForConv $value)),
        TFL_AF_None),
      $input,
      (ConstantOp $bias),
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value),
     (IsNoneType $bias),
     (HasOneUse $output)]>;
}
// Instantiate the mul/div-into-convolution fusion patterns for both ops.
foreach BinaryOp = [TFL_DivOp, TFL_MulOp]<Op> in {
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<BinaryOp>;
}
// Removes a redundant dequantize -> quantize pair: quantize(dequantize(in))
// is replaced by `in` when the NotFromQuantOpOrSameQuantType constraint holds
// for `in` and the quantize op's type attribute.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def eliminate_dq_q_pairs : Pat<
  (TFL_QuantizeOp
    (TFL_DequantizeOp $in),
    $qt),
  (replaceWithValue $in),
  [(NotFromQuantOpOrSameQuantType $in, $qt)]>;
// Checks if the operand has a known rank equal to 'n'.
// ShapedType::getRank() is only valid on ranked types, so unranked operands
// are rejected up front instead of being queried for their rank.
class OperandHasRank<int n> : Constraint<
    CPred<"$0.getType().cast<ShapedType>().hasRank() && "
          "$0.getType().cast<ShapedType>().getRank() == " # n>>;
// Matching HardSwish, written as (x * relu6(x + 3)) * 0.166666666.
def MatchHardSwishPattern1 : Pat<
  (TFL_MulOp
    (TFL_MulOp
      $x,
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
// Matching HardSwish, written as x * (relu6(x + 3) * 0.166666666).
def MatchHardSwishPattern2 : Pat<
  (TFL_MulOp
    $x,
    (TFL_MulOp
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
// Matching HardSwish, written as (relu6(x + 3) * x) * 0.166666666.
// This is the operand-swapped form of the inner multiplication; it covers the
// one ordering of the three factors that no other MatchHardSwishPattern in
// this family matches.
def MatchHardSwishPattern3 : Pat<
  (TFL_MulOp
    (TFL_MulOp
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      $x,
      TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
// Matching HardSwish, written as (relu6(x + 3) * 0.166666666) * x.
def MatchHardSwishPattern4 : Pat<
  (TFL_MulOp
    (TFL_MulOp
      (TFL_AddOp
        $x,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
    $x,
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
// Matching HardSwish with extra quantize/dequantize pairs wrapped around the
// intermediate results. These FakeQuant ops were due to incorrect placement in
// the quantization aware training.
def MatchHardSwishQuantized : Pat<
  (TFL_MulOp
    (TFL_DequantizeOp
      (TFL_QuantizeOp
        (TFL_MulOp
          $x,
          (TFL_DequantizeOp
            (TFL_QuantizeOp
              (TFL_AddOp
                $x,
                (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
                TFL_AF_Relu6),
              $qattr2)),
          TFL_AF_None),
        $qattr1)),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $x)>;
// Constraint that the attribute is a single-element f32 DenseElementsAttr
// whose absolute value is less than 'n'.
// The element type is checked before the value is read through
// getValues<float>(), so attributes of other element types are rejected
// instead of being reinterpreted as float.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
        "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
        "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
        "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
        # n>>;
// Constraint delegating to the C++ helper `L2NormalizeReduceAxis`, passing the
// first entity and the second one cast to DenseElementsAttr.
def L2NormValidReduceIndex : Constraint<
    CPred<"L2NormalizeReduceAxis($0, $1.cast<DenseElementsAttr>())">>;
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes.
// constraint on axis and depth.
multiclass L2NormalizePatterns<Op FirstOp, Op SecondOp> {
  // Builds L2NormalizationOp from either of
  //   Mul -> Rsqrt -> Sum -> Square
  //   Div -> Sqrt  -> Sum -> Square
  def L2NormalizePattern1#FirstOp#SecondOp : Pat<
    (FirstOp
      $x,
      (SecondOp
        (TFL_SumOp
          (TFL_SquareOp:$sq_op $x),
          (ConstantOp I32ElementsAttr:$axis),
          $keep_dims)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis)]>;

  // Variants of the above where an Add or a Maximum applies a small constant
  // scalar to the sum before the (r)sqrt.
  def L2NormalizePattern2#FirstOp#SecondOp : Pat<
    (FirstOp
      $x,
      (SecondOp
        (TFL_AddOp
          (TFL_SumOp
            (TFL_SquareOp:$sq_op $x),
            (ConstantOp I32ElementsAttr:$axis),
            $keep_dims),
          (ConstantOp $epsilon),
          TFL_AF_None)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  def L2NormalizePattern3#FirstOp#SecondOp : Pat<
    (FirstOp
      $x,
      (SecondOp
        (TFL_MaximumOp
          (TFL_SumOp
            (TFL_SquareOp:$sq_op $x),
            (ConstantOp I32ElementsAttr:$axis),
            $keep_dims),
          (ConstantOp $epsilon))),
      TFL_AF_None),
    (TFL_L2NormalizationOp $x, TFL_AF_None),
    [(L2NormValidReduceIndex $sq_op, $axis),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
}
// Instantiate the L2 normalization patterns for the Mul/Rsqrt and Div/Sqrt
// spellings.
foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp],
                            [TFL_DivOp, TFL_SqrtOp]] in {
  defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
}
//===----------------------------------------------------------------------===//
// Binary ops patterns.
//===----------------------------------------------------------------------===//
// Constraint delegating to TFL::IsBroadcastableElementsAttrAndType on the
// types of the two bound entities.
def AreBroadcastableTypes : Constraint<
    CPred<"TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
// Constraint delegating to TFL::OperandsBroadcastToOutputType on the types of
// the three bound entities (two operands, then the output).
def OperandsBroadcastToOutputType : Constraint<
    CPred<"TFL::OperandsBroadcastToOutputType($0.getType(), $1.getType(), "
          "$2.getType())">>;
// Constraint delegating to TFL::IsTailOfShape on the types of the two bound
// entities.
def IsTailOfShape : Constraint<
    CPred<"TFL::IsTailOfShape($0.getType(), $1.getType())">>;
// Native call reshaping a DenseElementsAttr into a rank-1 tensor attribute
// with the same element type and the same total number of elements.
def Flatten : NativeCodeCall<
    "$0.cast<DenseElementsAttr>()"
    ".reshape(RankedTensorType::get({$0.getType().cast<ShapedType>().getNumElements()}, "
    "$0.getType().cast<ShapedType>().getElementType()))">;
// Constraint: the first entity has rank >= 1 and the size of its last
// dimension equals the total number of elements of the second entity.
def IsLastDimEqualToNumElements : Constraint<
    CPred<"$0.getType().cast<ShapedType>().getRank() >= 1 && "
          "$0.getType().cast<ShapedType>().getDimSize($0.getType().cast<ShapedType>().getRank() - 1) == "
          "$1.getType().cast<ShapedType>().getNumElements()">>;
// Constraint: the value is produced by a TFL::FullyConnectedOp.
def IsDefinedByFullyConnectedOp : Constraint<
    CPred<"$0.getDefiningOp<TFL::FullyConnectedOp>() != nullptr">>;
// Patterns that drop a Tile feeding a binary op when the Tile only performs a
// broadcast that the binary op already supports itself.
multiclass FuseTileBroadcastIntoFollowingBinary<Op BinaryOp> {
  // Tile on the left-hand operand.
  def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat<
    (BinaryOp:$result
      (TFL_TileOp $input, (ConstantOp $tile)),
      $operand,
      $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(OperandsBroadcastToOutputType $input, $operand, $result),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $operand)]>;

  // Tile on the right-hand operand.
  def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat<
    (BinaryOp:$result
      $operand,
      (TFL_TileOp $input, (ConstantOp $tile)),
      $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(OperandsBroadcastToOutputType $operand, $input, $result),
     (HasRankAtMost<4> $operand),
     (HasRankAtMost<4> $input)]>;
}
// For each supported activation: folds an activation op that follows a binary
// op (which has no fused activation and a single use) into that binary op's
// fused activation attribute.
multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                       [TFL_Relu6Op, TFL_AF_Relu6],
                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
    def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat<
      (actFnPair[0]
        (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
      (BinaryOp $lhs, $rhs, actFnPair[1]),
      [(HasOneUse $binary_out)]>;
  }
}
// Patterns for binary ops that carry a fused activation attribute.
foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;

  // Instantiated FusedBinary patterns for the from-to pairs of ops.
  defm : FusedBinaryActivationFuncOpPat<BinaryOp>;

  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, its shape is the
  // tail of the other operand's shape, and the intermediate result isn't used
  // by other ops.
  // $rhs is required to be the tail shape of $lhs so that the shape of the
  // binary op result stays valid after the transformation. For example, with
  // $input, $lhs and $rhs of shapes [1600], [1,40,40] and [40x1], the
  // transformed binary op would produce [40x1600], which couldn't be reshaped
  // to [1,40,40]. The `IsTailOfShape` constraint rules this out.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs $a),
      $act_fn),
    (TFL_ReshapeOp
      (BinaryOp $input, $rhs, $act_fn),
      $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both operands of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
      (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2)),
      $act_fn),
    (TFL_ReshapeOp
      (BinaryOp $input1, $input2, $act_fn),
      $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after transformation broadcast of binary op is always
  // applied to the last dimension of $input.
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs ElementsAttr:$rhs_attr),
      $act_fn),
    (TFL_ReshapeOp
      (BinaryOp
        $input,
        (ConstantOp (Flatten $rhs_attr)),
        $act_fn),
      $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}
// The same reshape-reordering patterns for binary ops that take no fused
// activation attribute.
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
                    TFL_GreaterEqualOp] in {
  // Move binary op before reshape: reshape -> binary => binary -> reshape.
  // This is valid only when the binary operand is constant, its shape is the
  // tail of the other operand's shape, and the intermediate result isn't used
  // by other ops.
  // $rhs is required to be the tail shape of $lhs so that the shape of the
  // binary op result stays valid after the transformation. For example, with
  // $input, $lhs and $rhs of shapes [1600], [1,40,40] and [40x1], the
  // transformed binary op would produce [40x1600], which couldn't be reshaped
  // to [1,40,40]. The `IsTailOfShape` constraint rules this out.
  def MoveBinaryOpConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs $a)),
    (TFL_ReshapeOp
      (BinaryOp $input, $rhs),
      $shape),
    // The broadcasting of "BinaryOp" only happens in the lower
    // dimensions, and the higher dimensions are same, so we know the
    // result and input of the "BinaryOp" in the source pattern have
    // the same shape, which is defined by `shape`.
    [(IsTailOfShape $rhs, $lhs),
     (HasOneUse $lhs),
     // The result of the new "BinaryOp" will have the same shape as
     // `input`. In other words, the shape of the `Reshape` op is not
     // changed after the transformation.
     (IsTailOfShape $rhs, $input),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), reshape(rhs)) => reshape(binary(lhs, rhs))
  // This is valid only when both operands of the binary op are reshaped, and
  // the sizes are the same both before and after the reshape.
  def MoveBinaryOpBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input1, (ConstantOp:$shape1 $s1)),
      (TFL_ReshapeOp:$rhs $input2, (ConstantOp:$shape2 $s2))),
    (TFL_ReshapeOp
      (BinaryOp $input1, $input2),
      $shape1),
    [(IsTailOfShape $rhs, $lhs),
     (IsTailOfShape $lhs, $rhs),
     (IsTailOfShape $input1, $input2),
     (IsTailOfShape $input2, $input1)]>;

  // Move binary op before reshape:
  // binary(reshape(lhs), rhs) => reshape(binary(lhs, flatten(rhs)))
  // This is valid only when the last dimension of lhs is equal to the
  // number of elements in constant rhs.
  // Therefore, after transformation broadcast of binary op is always
  // applied to the last dimension of $input.
  def MoveBinaryOpFlattenConstBeforeReshape#BinaryOp : Pat<
    (BinaryOp
      (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
      (ConstantOp:$rhs ElementsAttr:$rhs_attr)),
    (TFL_ReshapeOp
      (BinaryOp
        $input,
        (ConstantOp (Flatten $rhs_attr))),
      $shape),
    [(AnyStaticShapeTensor $input),
     (IsTailOfShape $rhs, $lhs),
     (IsLastDimEqualToNumElements $input, $rhs),
     (HasOneUse $lhs),
     // Restrict operands to have at most rank 4 because TFLite binary
     // kernel supports up to 4D broadcast.
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $lhs),
     (HasRankAtMost<4> $rhs),
     (IsDefinedByFullyConnectedOp $input)]>;
}
// Reorder element-wise value operations and element move operations so that
// the value operation runs before the move operation:
//   value(move(input, def)) => move(value(input), def)
// Applied only when the move op's result has a single use.
foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp,
                   TFL_ReluOp, TFL_Relu1Op, TFL_Relu6Op, TFL_RoundOp,
                   TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp, TFL_LogisticOp] in {
  foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp,
                    TFL_ReshapeOp, TFL_TransposeOp] in {
    def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat<
      (ValueOp:$value
        (MoveOp:$move $input, $move_def)),
      (MoveOp
        (ValueOp $input),
        $move_def),
      [(HasOneUse $move)]>;
  }
}
// Native call forwarding the value to the C++ helper `GetShape`, which
// returns the shape of a ranked tensor.
// It fails if called on something that is not a ranked tensor.
def GetShape : NativeCodeCall<
    "GetShape($0)">;
// Constraint: the operand type is a RankedTensorType with at most one dynamic
// dimension.
def HasValidRankedTensor : Constraint<
    CPred<"$0.getType().isa<RankedTensorType>() && "
          "$0.getType().cast<RankedTensorType>().getNumDynamicDims() <= 1">>;
// Convert squeeze to reshape when the squeeze result is a ranked tensor with
// at most one dynamic dimension; the target shape is taken from the squeeze
// result.
def ConvertSqueezeToReshape : Pat<
  (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
  (TFL_ReshapeOp
    $input,
    (ConstantOp (GetShape $squeeze_op))),
  [(HasValidRankedTensor $squeeze_op)]>;
// Convert expand_dims to reshape when the expand_dims result has a static
// shape; the target shape is taken from the expand_dims result.
def ConvertExpandDimsToReshape : Pat<
  (TFL_ExpandDimsOp:$expand_dims_op $input, $dim),
  (TFL_ReshapeOp
    $input,
    (ConstantOp (GetShape $expand_dims_op))),
  [(AnyStaticShapeTensor $expand_dims_op)]>;
// Constraint delegating to the C++ helper `FloatValueEquals` with the bound
// attribute and `val` spliced in verbatim as the second argument.
class FloatValueEquals<string val> : Constraint<
    CPred<"FloatValueEquals($0, " # val # ")">>;
// ReLU patterns
// maximum(input, 0) => relu(input)
def MatchReluPattern : Pat<
  (TFL_MaximumOp
    $input,
    (ConstantOp $Zero)),
  (TFL_ReluOp $input),
  [(FloatValueEquals<"0"> $Zero)]>;
// minimum(maximum(input, -1), 1) => relu1(input)
def MatchRelu1Pattern1 : Pat<
  (TFL_MinimumOp
    (TFL_MaximumOp
      $input,
      (ConstantOp $NegOne)),
    (ConstantOp $One)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne),
   (FloatValueEquals<"1"> $One)]>;
// maximum(minimum(input, 1), -1) => relu1(input)
def MatchRelu1Pattern2 : Pat<
  (TFL_MaximumOp
    (TFL_MinimumOp
      $input,
      (ConstantOp $One)),
    (ConstantOp $NegOne)),
  (TFL_Relu1Op $input),
  [(FloatValueEquals<"-1"> $NegOne),
   (FloatValueEquals<"1"> $One)]>;
// maximum(x * alpha, x) => leaky_relu(x, alpha)
// for a single-element f32 constant alpha accepted by
// ConstDoubleValueLessThan<"1">, when the multiplication has a single use.
def MatchLeakyRelu : Pat<
  (TFL_MaximumOp
    (TFL_MulOp:$mul_out
      $x,
      (ConstantOp F32ElementsAttr:$alpha),
      TFL_AF_None),
    $x),
  (TFL_LeakyReluOp $x, ExtractSingleElementAsFloat:$alpha),
  [(ConstDoubleValueLessThan<"1"> $alpha),
   (HasOneUse $mul_out)]>;
// Constraint: every user of the value belongs to the "tf" or "tfl" dialect,
// i.e. does not need exact shape matching. This prevents removing a cast on
// return values, which can break the verifier on function type mismatch.
def AllUsersInTF : Constraint<CPred<[{
  llvm::all_of($0.getUsers(), [&](Operation *user) {
    auto name = user->getName().getDialectNamespace();
    return name == "tf" || name == "tfl";
  })
  }]>, "all users are TF/TFL operations.">;
// Removes a cast whose input and output share the same element type, provided
// all users of the cast result are TF/TFL operations.
def RemoveShapeOnlyCast : Pat<
  (TFL_CastOp:$output $input),
  (replaceWithValue $input),
  [(SameElementType $input, $output),
   (AllUsersInTF $output)]>;
// Checks if the operand0's rank is one less than operand1's rank.
// Both operands must be ranked: ShapedType::getRank() is only valid on ranked
// types, so unranked operands are rejected up front.
def PReluAlphaRankCheck : Constraint<
  CPred<"$0.getType().cast<ShapedType>().hasRank() && "
        "$1.getType().cast<ShapedType>().hasRank() && "
        "$0.getType().cast<ShapedType>().getRank() == "
        "$1.getType().cast<ShapedType>().getRank() - 1">>;
// PReLU pattern from Keras:
//   f(x) = Relu(x) + (-alpha * Relu(-x))
// The matched multiplier is the negated alpha, so the result negates it again.
// Every intermediate result must have a single use.
def MatchPRelu : Pat<
  (TFL_AddOp
    (TFL_ReluOp:$relu_out $x),
    (TFL_MulOp:$mul_out
      (TFL_ReluOp
        (TFL_NegOp:$input_neg_out $x)),
      $neg_alpha,
      TFL_AF_None),
    TFL_AF_None),
  (TFL_PReluOp
    $x,
    (TFL_NegOp $neg_alpha)),
  [(PReluAlphaRankCheck $neg_alpha, $x),
   (HasOneUse $relu_out),
   (HasOneUse $mul_out),
   (HasOneUse $input_neg_out)]>;
// The constant folding in this pass might produce constants in the tf
// dialect. This rule legalizes such constants to the tfl dialect.
def LegalizeConstOp : Pat<
  (TF_ConstOp ElementsAttr:$value),
  (TFL_ConstOp $value)>;
// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//    \--> $constantB
// To
// Add --> $input
//    \--> Add ($constantA, $constantB)
// The inner add must have no fused activation and a single use; the outer
// add's fused activation is preserved.
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat<
    (TFL_AddOp
      (TFL_AddOp:$first_output
        $input,
        (ConstantOp $a),
        TFL_AF_None),
      (ConstantOp $b),
      ActFun),
    (TFL_AddOp
      $input,
      (TFL_AddOp
        (ConstantOp $a),
        (ConstantOp $b),
        TFL_AF_None),
      ActFun),
    [(HasOneUse $first_output),
     (HasRankAtMost<4> $input),
     (HasRankAtMost<4> $a),
     (HasRankAtMost<4> $b)]>;
}
// We can eliminate Relu from Relu(SquaredDifference(x, y)),
// since the result of SquaredDifference is always non-negative.
// The TFLite interpreter doesn't support Relu+int32 for now, so test cases
// fail without this pattern; optimizing the Relu away fixes the problem.
def OptimizeReluSquaredDifference : Pat<
  (TFL_ReluOp
    (TFL_SquaredDifferenceOp $l, $r)),
  (TFL_SquaredDifferenceOp $l, $r)>;
// Optimize X^1 to X
def OptimizePow1ToIdentity : Pat<
  (TFL_PowOp
    $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">)),
  (replaceWithValue $input)>;
// Optimize X^2 to X*X
def OptimizePow2ToSquare : Pat<
  (TFL_PowOp
    $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "2.0f">)),
  (TFL_MulOp $input, $input, TFL_AF_None)>;
// Optimize X^(1/2) to sqrt(X)
def OptimizePow2ToSqrt : Pat<
  (TFL_PowOp
    $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.5f">)),
  (TFL_SqrtOp $input)>;
// Optimize X^(-1/2) to 1/sqrt(X) == rsqrt(X)
def OptimizePow2ToRsqrt : Pat<
  (TFL_PowOp
    $input,
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "-0.5f">)),
  (TFL_RsqrtOp $input)>;
// Constraint delegating to TFL::CanOptimizeIdentityGatherNdOrScatterNdOp with
// the params value and the indices cast to DenseIntElementsAttr.
def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint<
    CPred<"TFL::CanOptimizeIdentityGatherNdOrScatterNdOp("
          "$0, $1.cast<DenseIntElementsAttr>())">>;
// Replaces a GatherNd with its params when the constant i32 indices pass the
// CanOptimizeIdentityGatherNdOrScatterNdOp check.
def OptimizeIdentityGatherNdOp : Pat<
  (TFL_GatherNdOp
    $params,
    (ConstantOp I32ElementsAttr: $indices)),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;
// Replaces a ScatterNd with its params when the constant i32 indices pass the
// CanOptimizeIdentityGatherNdOrScatterNdOp check.
def OptimizeIdentityScatterNdOp : Pat<
  (TFL_ScatterNdOp
    (ConstantOp I32ElementsAttr: $indices),
    $params,
    $ignored),
  (replaceWithValue $params),
  [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>;
// Constraint delegating to the C++ helper `ShapeMatchesReduceWithKeepAxes`
// with the three bound entities.
def ShapeMatchesReduceWithKeepAxes : Constraint<
    CPred<"ShapeMatchesReduceWithKeepAxes($0, $1, $2)">>;
// Fold a reshape that re-inserts the reduced dimensions into the result of a
// reduction with `keep_dims=false` by switching the reduction to
// `keep_dims=true`. The reduction must have a single use.
foreach ReduceOp = [TFL_MeanOp, TFL_ReduceMaxOp, TFL_ReduceMinOp,
                    TFL_ReduceProdOp, TFL_SumOp] in {
  def FoldReshapeTo#ReduceOp : Pat<
    (TFL_ReshapeOp
      (ReduceOp:$reduce
        $input,
        (ConstantOp I32ElementsAttr: $axes),
        ConstBoolAttrFalse),
      (ConstantOp I32ElementsAttr: $shape)),
    (ReduceOp
      $input,
      (ConstantOp $axes),
      ConstBoolAttrTrue),
    [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape),
     (HasOneUse $reduce)]>;
}
// Constraint: the two bound entities compare equal.
def IsSame : Constraint<
    CPred<"$0 == $1">>;
// Constraint: the value has exactly two uses.
def HasTwoUse : Constraint<
    CPred<"std::distance($0.use_begin(), $0.use_end()) == 2">>;
// Constraint: the axes attribute holds exactly one element and that element
// designates the last dimension of the second entity, either as -1 or as
// rank - 1.
// The -1 form is checked first because it needs no rank; the rank - 1 form is
// only evaluated for ranked types, since ShapedType::getRank() is not valid
// on unranked ones.
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "($0.cast<DenseIntElementsAttr>().getValue<int32_t>({0}) == -1 || "
  "($1.getType().cast<ShapedType>().hasRank() && "
  "$0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
  "$1.getType().cast<ShapedType>().getRank() - 1))">>;
// Convert exp(x)/sum(exp(x)) into softmax with beta 1.0.
// The sum must reduce over the last dimension with keep_dims=true and take
// the very same exp value as input; exp may only be used by the div and the
// sum, and the sum only by the div.
def OptimizeToSoftmax : Pat<
  (TFL_DivOp
    (TFL_ExpOp:$exp $input),
    (TFL_SumOp:$sum
      $sum_input,
      (ConstantOp I32ElementsAttr: $axes),
      ConstBoolAttrTrue),
    TFL_AF_None),
  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
  [(IsSame $exp, $sum_input),
   (AxesIsLastDimension $axes, $sum_input),
   (HasTwoUse $exp),
   (HasOneUse $sum)]>;
// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization. The max must reduce over the last dimension
// with keep_dims=true and take the same value that is being subtracted from.
def FoldNormalizationIntoSoftmax : Pat<
  (TFL_SoftmaxOp
    (TFL_SubOp:$sub
      $input,
      (TFL_ReduceMaxOp:$max
        $max_input,
        (ConstantOp I32ElementsAttr: $axes),
        ConstBoolAttrTrue),
      TFL_AF_None),
    $beta),
  (TFL_SoftmaxOp $input, $beta),
  [(IsSame $input, $max_input),
   (AxesIsLastDimension $axes, $max_input),
   (HasOneUse $sub),
   (HasOneUse $max)]>;
// Constraint: the two bound entities have identical types.
def HaveSameType : Constraint<
    CPred<"($0.getType() == $1.getType())">>;
// Constraint: the attribute is a DenseElementsAttr with f32 element type whose
// elements all compare equal to `val` (spliced in verbatim).
class AllElementsAreF32<string val> : Constraint<CPred<
    "($0.isa<DenseElementsAttr>() && "
    "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isF32() && "
    "std::all_of($0.cast<DenseElementsAttr>().getValues<float>().begin(), "
    "$0.cast<DenseElementsAttr>().getValues<float>().end(), "
    "[](float v){ return v == " #val# ";}))">>;
// Optimize X*1 to X, provided the multiplication does not change the type
// (i.e. no broadcasting of X) and has no fused activation.
def OptimizeMul1ToIdentity : Pat<
  (TFL_MulOp:$result
    $input,
    (ConstantOp $constant),
    TFL_AF_None),
  (replaceWithValue $input),
  [(HaveSameType $input, $result),
   (AllElementsAreF32<"1.0f"> $constant)]>;
// Constraint: the attribute is a DenseElementsAttr with 1-bit integer element
// type whose elements all compare equal to `val` (spliced in verbatim).
class AllElementsAreBool<string val> : Constraint<CPred<
    "($0.isa<DenseElementsAttr>() && "
    "$0.cast<DenseElementsAttr>().getType().cast<ShapedType>().getElementType().isInteger(1) && "
    "std::all_of($0.cast<DenseElementsAttr>().getValues<bool>().begin(), "
    "$0.cast<DenseElementsAttr>().getValues<bool>().end(), "
    "[](bool v){ return v == " #val# ";}))">>;
// Remove select operators whose condition is a constant tensor of all-true or
// all-false values, so the result is known in advance. The chosen input must
// already have the result type.
foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in {
  // select(true_tensor, A, B) -> A
  def Optimize#SelectOp#True : Pat<
    (SelectOp:$result
      (ConstantOp $constant),
      $input1,
      $input2),
    (replaceWithValue $input1),
    [(HaveSameType $input1, $result),
     (AllElementsAreBool<"true"> $constant)]>;

  // select(false_tensor, A, B) -> B
  def Optimize#SelectOp#False : Pat<
    (SelectOp:$result
      (ConstantOp $constant),
      $input1,
      $input2),
    (replaceWithValue $input2),
    [(HaveSameType $input2, $result),
     (AllElementsAreBool<"false"> $constant)]>;
}
// Remove (log-)softmax before arg-minmax as (log-)softmax is monotonic.
// Applies when the (log-)softmax has a single use and the arg-minmax reduces
// over the last dimension.
// NOTE(review): TFL_FloatNonNegative presumably also admits beta == 0, for
// which softmax is constant rather than strictly monotonic -- confirm whether
// that case needs to be excluded.
foreach ArgMinMaxOp = [TFL_ArgMinOp, TFL_ArgMaxOp] in {
  def RemoveSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp
      (TFL_SoftmaxOp:$softmax $logits, TFL_FloatNonNegative:$beta),
      (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $softmax),
     (AxesIsLastDimension $axes, $logits)]>;

  def RemoveLogSoftmaxOpBefore#ArgMinMaxOp : Pat<
    (ArgMinMaxOp
      (TFL_LogSoftmaxOp:$log_softmax $logits),
      (ConstantOp:$const_axes I32ElementsAttr:$axes)),
    (ArgMinMaxOp $logits, $const_axes),
    [(HasOneUse $log_softmax),
     (AxesIsLastDimension $axes, $logits)]>;
}
// Constraint delegating to TFL::CanOptimizeIdentitySliceOp with the three
// bound entities.
def CanOptimizeIdentitySliceOp : Constraint<
    CPred<"TFL::CanOptimizeIdentitySliceOp($0, $1, $2)">>;
// Remove Slice ops that slice the whole input tensor and are thus effectively
// a no-op.
def OptimizeSliceOp : Pat<
  (TFL_SliceOp:$output
    $input,
    (ConstantOp $begin),
    (ConstantOp $size)),
  (replaceWithValue $input),
  [(CanOptimizeIdentitySliceOp $input, $begin, $size)]>;
// Native call forwarding the bound attribute to the C++ helper
// `GetNumElementsOrOne`.
def GetNumElementsOrOne : NativeCodeCall<
    "GetNumElementsOrOne($0)">;
// Constraint delegating to the C++ helper TFL::IsLastElementEqualsOne.
def IsLastElementEqualsOne : Constraint<
    CPred<"TFL::IsLastElementEqualsOne($0)">>;
// Constraint delegating to the C++ helper TFL::IsOneHotIndexAttribute.
def IsOneHotIndexAttribute : Constraint<
    CPred<"TFL::IsOneHotIndexAttribute($0)">>;
// Replace
//   Equal(Reshape(X, shape), indices)
// With
//   OneHot(X, N, true, false, -1)
// where
//  - the last value in shape is 1
//  - indices is an incrementing series from 0 to N-1. (N elements total.)
def ReshapeEqualOpToOneHotOp : Pat<
  (TFL_EqualOp
    (TFL_ReshapeOp
      $x,
      (ConstantOp $shape)),
    (ConstantOp $series)),
  (TFL_OneHotOp
    $x,
    (ConstantOp (GetNumElementsOrOne $series)),
    (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "true">),
    (ConstantOp ConstantAttr<RankedSignlessIntElementsAttr<1, []>, "false">),
    ConstantAttr<I32Attr, "-1">),
  [(IsLastElementEqualsOne $shape),
   (IsOneHotIndexAttribute $series)]>;
// Constraint: the value is a tensor whose element type is f32.
def F32ElementsVal : Constraint<
    CPred<"$0.getType().cast<TensorType>().getElementType().isF32()">,
    "32 bit float tensor">;
// Native call forwarding the bound attribute to the C++ helper
// `ConvertSingleElementAttrToFloatAttr`.
def ConvertSingleElementAttrToFloatAttr : NativeCodeCall<
    "ConvertSingleElementAttrToFloatAttr($0)">;
// Replace
//   (float)OneHot(index, depth, on_val, off_val, axis)
// With
//   OneHot(index, depth, (float)on_val, (float)off_val, axis)
// Applies when the cast result is an f32 tensor and on/off values are
// constants.
def FuseOneHotAndCastToFloat : Pat<
  (TFL_CastOp:$output
    (TFL_OneHotOp
      $indices,
      $depth,
      (ConstantOp $on_val),
      (ConstantOp $off_val),
      $axis)),
  (TFL_OneHotOp
    $indices,
    $depth,
    (ConstantOp (ConvertSingleElementAttrToFloatAttr $on_val)),
    (ConstantOp (ConvertSingleElementAttrToFloatAttr $off_val)),
    $axis),
  [(F32ElementsVal $output)]>;