blob: 60b2acd46b479666269ace7104ce35a874bf20be [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "tensorflow/compiler/mlir/tensorflow/transforms/optimize.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
// Attribute constraint that matches a BoolAttr holding `false`. The explicit
// cast keeps the predicate valid regardless of the static C++ type the
// pattern matcher binds to $_self.
def FalseBoolAttr : AttrConstraint<
    CPred<"!$_self.cast<BoolAttr>().getValue()">, "false bool attribute">;
// Constraint that holds when the value bound to $0 has an empty use list.
def HasNoUse : Constraint<
    CPred<"$0->use_empty()">,
    "has no use">;
// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 in inference mode
// (is_training = false) into a sequence of more primitive arithmetic
// operations. The batch norm computes
//
//   (x - mean) * scale / sqrt(variance + epsilon) + offset
//
// Defining multiplier = scale / sqrt(variance + epsilon), emitted as
// scale * rsqrt(variance + epsilon), the same value is produced by
//
//   (x * multiplier) + (offset - mean * multiplier).
//
// This tf.FusedBatchNorm pattern only matches data_format == "NHWC". The
// per-channel operands (scale, offset, mean, variance, hence multiplier)
// broadcast against the innermost dimension of x, which is the channel
// dimension only in NHWC; for NCHW the rewrite would be wrong.
// NOTE(review): assumes scale/offset/mean/variance are 1-D [C] tensors.
def : Pattern<
  (TF_FusedBatchNormOp:$root
    $x, $scale, $offset, $mean, $variance,
    F32Attr:$epsilon,
    AttrConstraint<CPred<[{$_self.cast<StringAttr>().getValue() == "NHWC"}]>,
                   "NHWC data format">:$data_format,
    FalseBoolAttr:$is_training),
  [(TF_AddOp
     (TF_MulOp
       $x,
       (TF_MulOp:$multiplier
         $scale,
         (TF_RsqrtOp
           (TF_AddOp $variance,
                     (TF_ConstOp $epsilon))))),
     (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
   // The HasNoUse constraints below guarantee that the last four results have
   // no use, so the replacement value is irrelevant; $x is a placeholder.
   /*batch_mean=*/(replaceWithValue $x),
   /*batch_variance=*/(replaceWithValue $x),
   /*reserve_space_1=*/(replaceWithValue $x),
   /*reserve_space_2=*/(replaceWithValue $x)],
  [(HasNoUse $root__1), (HasNoUse $root__2),
   (HasNoUse $root__3), (HasNoUse $root__4)]>;
// Decomposes tf.FusedBatchNormV3 in inference mode (is_training = false) into
//   (x * multiplier) + (offset - mean * multiplier),
// where multiplier = scale * rsqrt(variance + epsilon).
// Only matches data_format == "NHWC". The per-channel operands broadcast
// against the innermost dimension of x, which is the channel dimension only
// in NHWC; for NCHW the rewrite would be wrong.
// NOTE(review): assumes scale/offset/mean/variance are 1-D [C] tensors.
def : Pattern<
  (TF_FusedBatchNormV3Op:$root
    $x, $scale, $offset, $mean, $variance,
    F32Attr:$epsilon,
    AttrConstraint<CPred<[{$_self.cast<StringAttr>().getValue() == "NHWC"}]>,
                   "NHWC data format">:$data_format,
    FalseBoolAttr:$is_training),
  [(TF_AddOp
     (TF_MulOp
       $x,
       (TF_MulOp:$multiplier
         $scale,
         (TF_RsqrtOp
           (TF_AddOp $variance,
                     (TF_ConstOp $epsilon))))),
     (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
   // The HasNoUse constraints below guarantee that the last five results have
   // no use, so the replacement value is irrelevant; $x is a placeholder.
   /*batch_mean=*/(replaceWithValue $x),
   /*batch_variance=*/(replaceWithValue $x),
   /*reserve_space_1=*/(replaceWithValue $x),
   /*reserve_space_2=*/(replaceWithValue $x),
   /*reserve_space_3=*/(replaceWithValue $x)],
  [(HasNoUse $root__1), (HasNoUse $root__2),
   (HasNoUse $root__3), (HasNoUse $root__4),
   (HasNoUse $root__5)]>;
// TODO(jpienaar): Move to opbase something more general.
// Attribute backed by a DenseIntElementsAttr. When instantiated as a constant,
// the builder below produces a rank-0 tensor of i32 holding the integer $0.
// NOTE(review): the predicate only checks for DenseIntElementsAttr; it does not
// verify the element width or that the tensor is a scalar.
def TFi32ElementsAttr : Attr<CPred<"$_self.isa<DenseIntElementsAttr>()">,
                                "scalar int attribute"> {
  // C++ attribute class used to store this attribute.
  let storageType = [{ DenseIntElementsAttr }];
  // Rank-0 (shape {}) i32 tensor containing the single value $0.
  let constBuilderCall = "$_builder.getDenseElementsAttr("
    "$_builder.getTensorType({}, $_builder.getIntegerType(32)), "
    "{$_builder.getI32IntegerAttr($0)})";
}
// Constant TFi32ElementsAttr holding the integer `num`, e.g. TFi32<-1>.
class TFi32<int num>
    : ConstantAttr<TFi32ElementsAttr, !cast<string>(num)>;
// MatMul with neither operand transposed: insert an explicit tf.Transpose on
// b and set transpose_b on the new MatMul, so the product is unchanged.
// The permutation is Range(start=Rank(b), limit=0, delta=-1) - 1, i.e.
// [rank-1, ..., 1, 0]: every dimension of b reversed.
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrFalse:$trans_a, ConstBoolAttrFalse),
          (TF_MatMulOp
             $a,
             (TF_TransposeOp
                $b,
                (TF_SubOp
                   (TF_RangeOp
                      /*start=*/(TF_RankOp $b),
                      /*limit=*/(ConstantOp TFi32<0>),
                      /*delta=*/(ConstantOp TFi32<-1>)),
                   (ConstantOp TFi32<1>))),
             $trans_a,
             ConstBoolAttrTrue)>;
// MatMul with transpose_a set: replace the implicit transpose with an explicit
// tf.Transpose of a (all dimensions reversed) and clear transpose_a;
// transpose_b is carried over as-is.
def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $trans_b),
          (TF_MatMulOp
             (TF_TransposeOp
                $a,
                (TF_SubOp
                   (TF_RangeOp
                      /*start=*/(TF_RankOp $a),
                      /*limit=*/(ConstantOp TFi32<0>),
                      /*delta=*/(ConstantOp TFi32<-1>)),
                   (ConstantOp TFi32<1>))),
             $b,
             ConstBoolAttrFalse,
             $trans_b)>;
// Rewrite tf.Snapshot as tf.Identity on the same operand.
def : Pat<(TF_SnapshotOp $input),
          (TF_IdentityOp $input)>;
// Rewrite tf.StopGradient as tf.Identity on the same operand.
def : Pat<(TF_StopGradientOp $input),
          (TF_IdentityOp $input)>;
//===----------------------------------------------------------------------===//
// Op removal patterns.
//===----------------------------------------------------------------------===//
// Drop tf.Identity by forwarding its operand to all of its users.
def : Pat<(TF_IdentityOp $operand),
          (replaceWithValue $operand)>;
//===----------------------------------------------------------------------===//
// Op quantization pass-through patterns.
//===----------------------------------------------------------------------===//
// TODO(fengliuai): Implement a similar rule in the QuantizePass once the
// constant-folding hooks of tfl.transpose and tfl.reshape are implemented.
// Swap tf.Transpose(tf.FakeQuantWithMinMaxVars(x)) into
// tf.FakeQuantWithMinMaxVars(tf.Transpose(x)), keeping the same min, max,
// num_bits and narrow_range operands.
def : Pat<(TF_TransposeOp
             (TF_FakeQuantWithMinMaxVarsOp
                $src, $lo, $hi, $num_bits, $narrow_range),
             $perm),
          (TF_FakeQuantWithMinMaxVarsOp
             (TF_TransposeOp $src, $perm),
             $lo, $hi, $num_bits, $narrow_range)>;
// Swap tf.Reshape(tf.FakeQuantWithMinMaxVars(x)) into
// tf.FakeQuantWithMinMaxVars(tf.Reshape(x)), keeping the same min, max,
// num_bits and narrow_range operands.
def : Pat<(TF_ReshapeOp
             (TF_FakeQuantWithMinMaxVarsOp
                $src, $lo, $hi, $num_bits, $narrow_range),
             $shape),
          (TF_FakeQuantWithMinMaxVarsOp
             (TF_ReshapeOp $src, $shape),
             $lo, $hi, $num_bits, $narrow_range)>;
// Native helper call: CastQuantizedTypeAttrFromExpressedType receives the
// builder, the type attribute $0 (which supplies the quantization parameters)
// and the first result type of op $1 (which supplies the expressed type),
// and yields the quantized type for $1's result.
def UpdateShape : NativeCodeCall<
    "CastQuantizedTypeAttrFromExpressedType("
    "$_builder, $0, GetFirstResultType($1))">;
// Move a tf.Transpose that consumes dequantize(quantize(x)) onto x itself.
// The quantize op stores its result type in the "qtype" attribute, so that
// attribute is rebuilt from the original transpose's result type to keep the
// quantized type's shape in sync with the rewritten value.
def : Pat<(TF_TransposeOp:$orig_op
             (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)),
             $perm),
          (TFL_DequantizeOp
             (TFL_QuantizeOp
                (TF_TransposeOp $input, $perm),
                (UpdateShape $qtype, $orig_op)))>;
// Move a tf.Reshape that consumes dequantize(quantize(x)) onto x itself,
// rebuilding the quantize op's "qtype" attribute from the original reshape's
// result type so the quantized type carries the new shape.
def : Pat<(TF_ReshapeOp:$orig_op
             (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)),
             $shape),
          (TFL_DequantizeOp
             (TFL_QuantizeOp
                (TF_ReshapeOp $input, $shape),
                (UpdateShape $qtype, $orig_op)))>;