blob: bb00a7e852aadadb253c9a80a69fac3d03ae59f1 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the optimization pattern definition file for TensorFlow Lite.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Attribute constraint satisfied by an ElementsAttr whose element type is f32.
// The isa<> test runs first so that the predicate simply fails, instead of
// hitting an unchecked cast<>, if it is ever evaluated on an attribute that is
// not an ElementsAttr.
def F32ElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<ElementsAttr>() && "
        "$_self.cast<ElementsAttr>().getType().getElementType().isF32()">,
  "float constant tensor">;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Folds a stand-alone activation op into the convolution op that feeds it.
//
// For each of Conv2D and DepthwiseConv2D this matches `ActFnOp(conv(...))`
// where the convolution's fused activation is TFL_AF_None, and rebuilds the
// same convolution with ActFnAttr as its fused activation. Every other operand
// and attribute is forwarded unchanged.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
  // Activation applied to a Conv2D result.
  def : Pat<
    (ActFnOp
      (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                    TFL_AF_None, $padding, $stride_h, $stride_w)),
    (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                  ActFnAttr, $padding, $stride_h, $stride_w)>;

  // Activation applied to a DepthwiseConv2D result.
  def : Pat<
    (ActFnOp
      (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                             TFL_AF_None, $padding, $stride_h, $stride_w,
                             $multiplier)),
    (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor,
                           ActFnAttr, $padding, $stride_h, $stride_w,
                           $multiplier)>;
}

// TODO(hinsu): Also fuse ops corresponding to RELU_N1_TO_1 and SIGN_BIT fused
// activation functions.
// Activations such as tanh, sigmoid and hard_swish are currently not fused
// because they cannot be simply translated into clamping.
foreach opAndAttr = [[TFL_ReluOp, TFL_AF_Relu],
                     [TFL_Relu6Op, TFL_AF_Relu6]] in {
  defm : FuseActFnIntoConvOpPat<opAndAttr[0], opAndAttr[1]>;
}
// Fuses a binary op (add, sub) with a constant right-hand operand into the
// convolution that produces its left-hand operand, when that convolution has a
// constant bias and no fused activation.
//
// The rewritten convolution takes `binaryOp(bias, value)` (built with
// TFL_AF_None) as its new bias and inherits the binary op's activation
// `$act_fn`. Both constants must currently be f32 (F32ElementsAttr).
//
// NOTE(review): the patterns carry no constraint relating the shape of $value
// to the shape of $bias -- confirm that callers cannot reach this with a
// $value that would broadcast the bias to a larger shape.
multiclass FuseBinaryOpToPrecedingAffine<dag binaryOp> {
  // Conv2D variant.
  def : Pat<
    (binaryOp
      (TFL_Conv2DOp $input, $filter,
                    (ConstantOp F32ElementsAttr:$bias),
                    $h_factor, $w_factor, TFL_AF_None,
                    $padding, $stride_h, $stride_w),
      (ConstantOp F32ElementsAttr:$value),
      $act_fn),
    (TFL_Conv2DOp $input, $filter,
                  (binaryOp (ConstantOp $bias), (ConstantOp $value),
                            TFL_AF_None),
                  $h_factor, $w_factor, $act_fn,
                  $padding, $stride_h, $stride_w)>;

  // DepthwiseConv2D variant.
  def : Pat<
    (binaryOp
      (TFL_DepthwiseConv2DOp $input, $filter,
                             (ConstantOp F32ElementsAttr:$bias),
                             $h_factor, $w_factor, TFL_AF_None,
                             $padding, $stride_h, $stride_w,
                             $multiplier),
      (ConstantOp F32ElementsAttr:$value),
      $act_fn),
    (TFL_DepthwiseConv2DOp $input, $filter,
                           (binaryOp (ConstantOp $bias), (ConstantOp $value),
                                     TFL_AF_None),
                           $h_factor, $w_factor, $act_fn,
                           $padding, $stride_h, $stride_w,
                           $multiplier)>;
}

foreach addOrSub = [TFL_AddOp, TFL_SubOp] in {
  defm : FuseBinaryOpToPrecedingAffine<addOrSub>;
}
// Constraint that forwards its two bound arguments to the C++ helper
// TFL::CanFuseConvOrDepthwiseConv. `is_depthwise` is pasted verbatim into the
// generated C++ as the helper's third argument.
class CanFuseConvOrDepthwiseConv<string is_depthwise>
    : Constraint<CPred<
        "TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
// Native-code hooks used in result patterns; each one calls the C++ function
// of the same name on the bound attribute.
def ExpandTo4DForConv
    : NativeCodeCall<"ExpandTo4DForConv($0)">;
def ExpandTo4DForDepthwiseConv
    : NativeCodeCall<"ExpandTo4DForDepthwiseConv($0)">;
// Fuses a Mul or Div by a constant into the convolution that produces its
// left-hand operand, when that convolution has a constant filter, a constant
// bias and no fused activation.
//
// In the rewritten convolution:
//   * the filter becomes BinaryOp(filter, ExpandTo4DFor*(value)),
//   * the bias becomes BinaryOp(bias, value),
//   * the fused activation becomes the binary op's `$act_fn`.
// Both inner binary ops are built with TFL_AF_None. The patterns only fire for
// f32 constants and when TFL::CanFuseConvOrDepthwiseConv accepts the
// (filter, value) pair.
multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d<dag BinaryOp> {
  // DepthwiseConv2D variant.
  def : Pat<
    (BinaryOp
      (TFL_DepthwiseConv2DOp $input,
                             (ConstantOp F32ElementsAttr:$filter),
                             (ConstantOp F32ElementsAttr:$bias),
                             $h_factor, $w_factor, TFL_AF_None,
                             $padding, $stride_h, $stride_w,
                             $multiplier),
      (ConstantOp F32ElementsAttr:$value),
      $act_fn),
    (TFL_DepthwiseConv2DOp
      $input,
      (BinaryOp (ConstantOp $filter),
                (ConstantOp (ExpandTo4DForDepthwiseConv $value)),
                TFL_AF_None),
      (BinaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn,
      $padding, $stride_h, $stride_w,
      $multiplier),
    [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value)]>;

  // Conv2D variant.
  def : Pat<
    (BinaryOp
      (TFL_Conv2DOp $input,
                    (ConstantOp F32ElementsAttr:$filter),
                    (ConstantOp F32ElementsAttr:$bias),
                    $h_factor, $w_factor, TFL_AF_None,
                    $padding, $stride_h, $stride_w),
      (ConstantOp F32ElementsAttr:$value),
      $act_fn),
    (TFL_Conv2DOp
      $input,
      (BinaryOp (ConstantOp $filter),
                (ConstantOp (ExpandTo4DForConv $value)),
                TFL_AF_None),
      (BinaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None),
      $h_factor, $w_factor, $act_fn,
      $padding, $stride_h, $stride_w),
    [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value)]>;
}

foreach mulOrDiv = [TFL_DivOp, TFL_MulOp] in {
  defm : FuseMulOrDivWithConv2dOrDepthwiseConv2d<mulOrDiv>;
}
// Constraint: the type of value $0 equals the type stored in type attribute $1.
def ValueTypeMatchesTypeAttr : Constraint<
  CPred<"$0->getType() == $1.getValue()">>;

// Removes a Quantize that re-quantizes the result of a Dequantize back to the
// type the Dequantize started from, forwarding the Dequantize's input instead.
//
// The type-equality constraint is what makes the pair redundant: without it a
// Quantize to a *different* quantized type (a genuine requantization) would
// also match and be replaced by a value of the wrong type.
// TODO(fengliuai): move this to the sanity check of pre-quantize pass.
def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt),
          (replaceWithValue $in),
          [(ValueTypeMatchesTypeAttr $in, $qt)]>;
// Constraint satisfied when its two bound arguments compare equal with the
// C++ `==` operator, i.e. both symbols refer to the same operand.
def EqualOperands
    : Constraint<CPred<"$0 == $1">>;
// Checks that the operand is a ranked shaped type whose rank equals `n`.
// hasRank() is tested first so that an unranked operand makes the constraint
// fail rather than calling getRank() on a type that has no rank.
// The operand's type is still required to be a ShapedType (cast<> is used).
class OperandHasRank<int n> : Constraint<
  CPred<"$0->getType().cast<ShapedType>().hasRank() && "
        "$0->getType().cast<ShapedType>().getRank() == " # n>>;
// Matching HardSwish.
//
// Both patterns recognise a product of three factors -- a value, the result of
// adding the scalar constant 3.0f to that same value with a fused Relu6, and
// the scalar constant 0.166666666f -- and replace it with a single
// TFL_HardSwishOp on the value. They differ only in how the two multiplies
// are associated. EqualOperands enforces that the multiplied value and the
// value fed to the add are the same.

// Form 1: (input * relu6(input + 3)) * 0.166666666
def : Pat<
  (TFL_MulOp
    (TFL_MulOp
      $input,
      (TFL_AddOp
        $add_input,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      TFL_AF_None),
    (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
    TFL_AF_None),
  (TFL_HardSwishOp $input),
  [(EqualOperands $input, $add_input)]>;

// Form 2: input * (relu6(input + 3) * 0.166666666)
def : Pat<
  (TFL_MulOp
    $input,
    (TFL_MulOp
      (TFL_AddOp
        $add_input,
        (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
        TFL_AF_Relu6),
      (ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "0.166666666f">),
      TFL_AF_None),
    TFL_AF_None),
  (TFL_HardSwishOp $input),
  [(EqualOperands $input, $add_input)]>;
// Constraint that the attribute is a single-element f32 DenseElementsAttr
// whose absolute value is less than 'n'.
//
// The element type is checked before the value is read because the value is
// extracted with getValues<float>(); a constant of any other element type
// must make the constraint fail rather than be reinterpreted as float.
// Despite the name, the comparison is carried out on a `float`.
class ConstDoubleValueLessThan<string n> : Constraint<
  CPred<"$0.isa<DenseElementsAttr>() && "
        "$0.cast<DenseElementsAttr>().getType().getElementType().isF32() && "
        "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
        "std::abs(*$0.cast<DenseElementsAttr>().getValues<float>().begin()) < "
        # n>>;
// Currently L2Normalization doesn't support activation function
// in TFLite.
// TODO(karimnosseir): Add constraints that the kernel code assumes.
// constraint on axis and depth.
// NOTE(review): in all three patterns below $constant (the second operand of
// TFL_SumOp) and $keep_dims are bound but never constrained or reused, so the
// match does not depend on which dimensions are summed -- this appears to be
// what the TODO above refers to; confirm before relying on these rewrites.
multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
  // This pattern constructs L2NormalizationOp from
  //   Mul->Rsqrt->Sum->Square  or
  //   Div->Sqrt->Sum->Square
  // when the squared value is the same as the value being scaled.
  def : Pat<
    (FirstOp
      $operand1,
      (SecondOp
        (TFL_SumOp (TFL_SquareOp $square_operand),
                   (ConstantOp I32ElementsAttr:$constant),
                   $keep_dims)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
    [(EqualOperands $operand1, $square_operand)]>;

  // Same as above, but with an Add of a small constant scalar between the Sum
  // and SecondOp. The constant must be a single element whose magnitude is
  // below 1e-3.
  def : Pat<
    (FirstOp
      $operand1,
      (SecondOp
        (TFL_AddOp
          (TFL_SumOp (TFL_SquareOp $square_operand),
                     (ConstantOp I32ElementsAttr:$constant),
                     $keep_dims),
          (ConstantOp $epsilon),
          TFL_AF_None)),
      TFL_AF_None),
    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
    [(EqualOperands $operand1, $square_operand),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;

  // Same as above, but with a Maximum against a small constant scalar between
  // the Sum and SecondOp, under the same bound on the constant.
  def : Pat<
    (FirstOp
      $operand1,
      (SecondOp
        (TFL_MaximumOp
          (TFL_SumOp (TFL_SquareOp $square_operand),
                     (ConstantOp I32ElementsAttr:$constant),
                     $keep_dims),
          (ConstantOp $epsilon))),
      TFL_AF_None),
    (TFL_L2NormalizationOp $operand1, TFL_AF_None),
    [(EqualOperands $operand1, $square_operand),
     (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>;
}

foreach outerAndInner = [[TFL_MulOp, TFL_RsqrtOp],
                         [TFL_DivOp, TFL_SqrtOp]] in {
  defm : L2NormalizePatterns<outerAndInner[0], outerAndInner[1]>;
}
// Constraint that passes the types of its two bound values to the C++ helper
// TFL::IsBroadcastableElementsAttrAndType.
def AreBroadcastableTypes
    : Constraint<CPred<
        "TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
// Pattern for skipping a Tile (with a constant multiples operand) whose result
// feeds a binary op that already supports broadcasting: the binary op is
// rebuilt to consume the Tile's input directly. One pattern handles the Tile
// on the left-hand side, the other on the right-hand side.
//
// NOTE(review): the only guard is AreBroadcastableTypes on the un-tiled input
// and the other operand; nothing here compares the rebuilt op's result type
// with the original one -- confirm the helper (or a later pass) guarantees the
// broadcast result has the same shape as the tiled result.
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
  // Tile feeds the first operand.
  def : Pat<
    (BinaryOp (TFL_TileOp $input, (ConstantOp $tile)), $operand, $act_func),
    (BinaryOp $input, $operand, $act_func),
    [(AreBroadcastableTypes $input, $operand)]>;

  // Tile feeds the second operand.
  def : Pat<
    (BinaryOp $operand, (TFL_TileOp $input, (ConstantOp $tile)), $act_func),
    (BinaryOp $operand, $input, $act_func),
    [(AreBroadcastableTypes $operand, $input)]>;
}

foreach broadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
  defm : FuseTileBroadcastIntoFollowingBinary<broadcastingOp>;
}