| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
// This file defines the canonicalization patterns for TensorFlow dialect ops.
| |
| include "mlir/IR/OpBase.td" |
| include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" |
| |
/// Holds when result #0 of the op bound to $0 has the same element type as
/// the value bound to $1. Both types are cast to ShapedType before comparing.
/// TODO(b/130756570): Support OpBase constraints in PatternRewrites.
def SingleResultAndOperandHaveSameElementType : Constraint<CPred<
  "$0->getResult(0)->getType().cast<ShapedType>().getElementType() == "
  "$1->getType().cast<ShapedType>().getElementType()">>;
| |
| //===----------------------------------------------------------------------===// |
| // Add op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Add(lhs, rhs)  ==>  AddV2(lhs, rhs), only when both operands match
/// TF_NumberTensor (presumably because Add admits non-numeric operands that
/// AddV2 does not — confirm against tf_ops.td).
def AddToAddV2 : Pat<
  (TF_AddOp TF_NumberTensor:$lhs, TF_NumberTensor:$rhs),
  (TF_AddV2Op $lhs, $rhs)>;
| |
| //===----------------------------------------------------------------------===// |
| // AddV2 op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// (-x) + y  ==>  y - x
def AddV2OfNegLeft : Pat<
  (TF_AddV2Op (TF_NegOp $x), $y),
  (TF_SubOp $y, $x)>;
| |
/// x + (-y)  ==>  x - y
def AddV2OfNegRight : Pat<
  (TF_AddV2Op $x, (TF_NegOp $y)),
  (TF_SubOp $x, $y)>;
| |
| //===----------------------------------------------------------------------===// |
| // Bitcast op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Replaces a Bitcast with its input when the Bitcast's result element type
/// equals the input's element type.
def BitcastSameType : Pat<
  (TF_BitcastOp:$bitcast $input),
  (replaceWithValue $input),
  [(SingleResultAndOperandHaveSameElementType $bitcast, $input)]>;
| |
/// Bitcast(Bitcast(x))  ==>  Bitcast(x): the inner bitcast is skipped and the
/// new Bitcast replaces the outer one.
def BitcastNested : Pat<
  (TF_BitcastOp (TF_BitcastOp $input)),
  (TF_BitcastOp $input)>;
| |
| //===----------------------------------------------------------------------===// |
| // Cast op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Replaces a Cast with its input when the Cast's result element type equals
/// the input's element type. The second matched argument is not consulted.
def CastSameType : Pat<
  (TF_CastOp:$cast $input, $truncate_attr),
  (replaceWithValue $input),
  [(SingleResultAndOperandHaveSameElementType $cast, $input)]>;
| |
| //===----------------------------------------------------------------------===// |
| // Conj op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Conj(Conj(x))  ==>  x
def ConjNested : Pat<
  (TF_ConjOp (TF_ConjOp $x)),
  (replaceWithValue $x)>;
| |
| //===----------------------------------------------------------------------===// |
| // Div op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// x / sqrt(y)  ==>  x * rsqrt(y), favoring Mul over Div.
def DivWithSqrtDivisor : Pat<
  (TF_DivOp $numerator, (TF_SqrtOp $radicand)),
  (TF_MulOp $numerator, (TF_RsqrtOp $radicand))>;
| |
| //===----------------------------------------------------------------------===// |
| // Invert op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Invert(Invert(x))  ==>  x
def InvertNested : Pat<
  (TF_InvertOp (TF_InvertOp $x)),
  (replaceWithValue $x)>;
| |
| //===----------------------------------------------------------------------===// |
| // Log op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Log(Softmax(x))  ==>  LogSoftmax(x)
def LogOfSoftmax : Pat<
  (TF_LogOp (TF_SoftmaxOp $logits)),
  (TF_LogSoftmaxOp $logits)>;
| |
| //===----------------------------------------------------------------------===// |
| // LogicalNot op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// !(!x)  ==>  x
def LogicalNotNested : Pat<
  (TF_LogicalNotOp (TF_LogicalNotOp $x)),
  (replaceWithValue $x)>;
| |
/// !(x == y)  ==>  x != y
def LogicalNotOfEqual : Pat<
  (TF_LogicalNotOp (TF_EqualOp $x, $y)),
  (TF_NotEqualOp $x, $y)>;
| |
/// !(x != y)  ==>  x == y
def LogicalNotOfNotEqual : Pat<
  (TF_LogicalNotOp (TF_NotEqualOp $x, $y)),
  (TF_EqualOp $x, $y)>;
| |
/// Holds when the value bound to $0 has an integer element type.
///
/// Rewriting the negation of an ordered comparison into its complementary
/// comparison is only sound when no operand can be NaN. Under IEEE-754,
/// !(NaN > y) is true while (NaN <= y) is false, so the four rewrites below
/// are limited to integer operands.
def ComparisonOperandHasIntegerElementType : Constraint<
  CPred<"$0->getType().cast<ShapedType>().getElementType()"
        ".isa<IntegerType>()">>;

/// !(x > y)  ==>  x <= y   (integer operands only; see constraint above).
def LogicalNotOfGreater : Pat<(TF_LogicalNotOp (TF_GreaterOp $arg0, $arg1)),
                              (TF_LessEqualOp $arg0, $arg1),
                              [(ComparisonOperandHasIntegerElementType $arg0)]>;

/// !(x >= y)  ==>  x < y   (integer operands only; see constraint above).
def LogicalNotOfGreaterEqual : Pat<(TF_LogicalNotOp (TF_GreaterEqualOp $arg0,
                                                                       $arg1)),
                                   (TF_LessOp $arg0, $arg1),
                                   [(ComparisonOperandHasIntegerElementType
                                       $arg0)]>;

/// !(x < y)  ==>  x >= y   (integer operands only; see constraint above).
def LogicalNotOfLess : Pat<(TF_LogicalNotOp (TF_LessOp $arg0, $arg1)),
                           (TF_GreaterEqualOp $arg0, $arg1),
                           [(ComparisonOperandHasIntegerElementType $arg0)]>;

/// !(x <= y)  ==>  x > y   (integer operands only; see constraint above).
def LogicalNotOfLessEqual : Pat<(TF_LogicalNotOp (TF_LessEqualOp $arg0, $arg1)),
                                (TF_GreaterOp $arg0, $arg1),
                                [(ComparisonOperandHasIntegerElementType
                                    $arg0)]>;
| |
| //===----------------------------------------------------------------------===// |
| // Neg op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// -(-x)  ==>  x
def NegNested : Pat<
  (TF_NegOp (TF_NegOp $x)),
  (replaceWithValue $x)>;
| |
| //===----------------------------------------------------------------------===// |
| // RealDiv op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// RealDiv(x, sqrt(y))  ==>  x * rsqrt(y)
def RealDivWithSqrtDivisor : Pat<
  (TF_RealDivOp $numerator, (TF_SqrtOp $radicand)),
  (TF_MulOp $numerator, (TF_RsqrtOp $radicand))>;
| |
| //===----------------------------------------------------------------------===// |
| // Reciprocal op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Holds when the value bound to $0 does NOT have an integer element type.
def ReciprocalOperandHasNonIntegerElementType : Constraint<
  CPred<"!$0->getType().cast<ShapedType>().getElementType()"
        ".isa<IntegerType>()">>;

/// Reciprocal(Reciprocal(x))  ==>  x, restricted to non-integer element types.
/// NOTE(review): TF Reciprocal appears to admit integer element types (confirm
/// in tf_ops.td). With integer division, 1/(1/2) evaluates to 1/0 rather than 2,
/// so the identity does not hold there.
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
                           (replaceWithValue $arg),
                           [(ReciprocalOperandHasNonIntegerElementType $arg)]>;
| |
| //===----------------------------------------------------------------------===// |
| // Square op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Holds when the value bound to $0 does NOT have a complex element type.
def SquareOperandHasNonComplexElementType : Constraint<
  CPred<"!$0->getType().cast<ShapedType>().getElementType()"
        ".isa<ComplexType>()">>;

/// Square(x - y)  ==>  SquaredDifference(x, y), restricted to non-complex
/// element types. Square computes z * z, while SquaredDifference is documented
/// as conj(x - y) * (x - y). These differ for complex values, e.g. for
/// z = i: z * z = -1 but conj(z) * z = 1.
def SquareOfSub : Pat<(TF_SquareOp (TF_SubOp $arg0, $arg1)),
                      (TF_SquaredDifferenceOp $arg0, $arg1),
                      [(SquareOperandHasNonComplexElementType $arg0)]>;
| |
| //===----------------------------------------------------------------------===// |
| // Sub op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// x - (-y)  ==>  x + y
def SubOfNeg : Pat<
  (TF_SubOp $x, (TF_NegOp $y)),
  (TF_AddV2Op $x, $y)>;
| |
| //===----------------------------------------------------------------------===// |
| // TruncateDiv op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// TruncateDiv(x, sqrt(y))  ==>  x * rsqrt(y)
/// NOTE(review): the replacement performs no truncation. Confirm that
/// TruncateDiv on the floating-point types produced by Sqrt does not truncate
/// its result, otherwise this rewrite changes semantics.
def TruncateDivWithSqrtDivisor : Pat<
  (TF_TruncateDivOp $numerator, (TF_SqrtOp $radicand)),
  (TF_MulOp $numerator, (TF_RsqrtOp $radicand))>;
| |
| //===----------------------------------------------------------------------===// |
| // Xdivy op patterns. |
| //===----------------------------------------------------------------------===// |
| |
/// Xdivy(x, sqrt(y))  ==>  MulNoNan(rsqrt(y), x)
/// Operand order is deliberate. MulNoNan is expected to yield 0 when its
/// second operand is 0, mirroring Xdivy yielding 0 when x == 0. Confirm
/// against the op definitions.
def XdivyWithSqrtDivisor : Pat<
  (TF_XdivyOp $numerator, (TF_SqrtOp $radicand)),
  (TF_MulNoNanOp (TF_RsqrtOp $radicand), $numerator)>;