blob: 3ca32349efef33a6919286a102d5f09a4b4b8e85 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for HLO to TF.
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td"
include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
// Constraint over (lhs, rhs, broadcast_dimensions): holds when
// hlo::isLegalNumpyRankedBroadcast accepts the triple, i.e. the broadcast is
// one the TF ops can express.
def IsLegalNumpyRankedBroadcast
    : Constraint<
          CPred<"hlo::isLegalNumpyRankedBroadcast($0, $1, $2)">,
          "broadcasting should be compatible with TF ops">;
// Native helper call: produces a constant op carrying the shape of the value
// bound to $0.
def ShapeToConst
    : NativeCodeCall<"ShapeToConst($_builder, $0)">;
// Constraint over (broadcast_dimensions, output): holds when the broadcast
// dimensions follow the TensorFlow convention, as decided by the native
// IsTFStyleBroadcast helper.
def IsTFStyleBroadcast
    : Constraint<
          CPred<"IsTFStyleBroadcast($0, $1)">,
          "new dimensions are added as prefix">;
// Logical negation of the IsTFStyleBroadcast predicate: holds when the
// broadcast dimensions do not follow the TensorFlow convention.
def IsNotTFStyleBroadcast
    : Constraint<
          Neg<CPred<"IsTFStyleBroadcast($0, $1)">>,
          "new dimensions are inserted in intermediate positions">;
// Native helper call: produces, wrapped in a constant op, the intermediate
// shape the input is reshaped to before it is broadcast.
def ExpandedShape
    : NativeCodeCall<"ExpandedShape($_builder, $0, $1, $2)">;
// HLO constant -> tf.Const carrying the same value attribute. Applies only
// when the HLO result satisfies the TF_Tensor type constraint.
def : Pat<(HLO_ConstantOp:$result $cst_attr),
          (TF_ConstOp $cst_attr),
          [(TF_Tensor $result)]>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
// Each row is [plain HLO op, chlo.broadcast_* variant, TF replacement].
// The plain op is rewritten unconditionally; the broadcasting variant is
// rewritten only when IsLegalNumpyRankedBroadcast accepts its operands and
// broadcast dimensions.
foreach ops = [
    [HLO_AddOp,       CHLO_BroadcastAddOp,       TF_AddV2Op],
    [HLO_DivOp,       CHLO_BroadcastDivOp,       TF_DivOp],
    [HLO_ShiftLeftOp, CHLO_BroadcastShiftLeftOp, TF_LeftShiftOp],
    [HLO_MaxOp,       CHLO_BroadcastMaxOp,       TF_MaximumOp],
    [HLO_MinOp,       CHLO_BroadcastMinOp,       TF_MinimumOp],
    [HLO_MulOp,       CHLO_BroadcastMulOp,       TF_MulOp],
    [HLO_PowOp,       CHLO_BroadcastPowOp,       TF_PowOp],
    [HLO_SubtractOp,  CHLO_BroadcastSubOp,       TF_SubOp],
    [HLO_Atan2Op,     CHLO_BroadcastAtan2Op,     TF_Atan2Op]
  ] in {
  def : Pat<(ops[0] $lhs, $rhs),
            (ops[2] $lhs, $rhs)>;
  def : Pat<(ops[1] $lhs, $rhs, $bcast_dims),
            (ops[2] $lhs, $rhs),
            [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
}
// And/Or/Xor whose operands both satisfy TF_IntTensor are rewritten to the
// TF bitwise ops. Each row is [plain HLO op, chlo.broadcast_* variant, TF op].
foreach ops = [
    [HLO_AndOp, CHLO_BroadcastAndOp, TF_BitwiseAndOp],
    [HLO_OrOp,  CHLO_BroadcastOrOp,  TF_BitwiseOrOp],
    [HLO_XorOp, CHLO_BroadcastXorOp, TF_BitwiseXorOp]
  ] in {
  def : Pat<(ops[0] TF_IntTensor:$lhs, TF_IntTensor:$rhs),
            (ops[2] $lhs, $rhs)>;
  def : Pat<(ops[1] TF_IntTensor:$lhs, TF_IntTensor:$rhs, $bcast_dims),
            (ops[2] $lhs, $rhs),
            [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
}
// And/Or whose operands both satisfy I1Tensor are rewritten to the TF
// logical ops. Each row is [plain HLO op, chlo.broadcast_* variant, TF op].
foreach ops = [
    [HLO_AndOp, CHLO_BroadcastAndOp, TF_LogicalAndOp],
    [HLO_OrOp,  CHLO_BroadcastOrOp,  TF_LogicalOrOp]
  ] in {
  def : Pat<(ops[0] I1Tensor:$lhs, I1Tensor:$rhs),
            (ops[2] $lhs, $rhs)>;
  def : Pat<(ops[1] I1Tensor:$lhs, I1Tensor:$rhs, $bcast_dims),
            (ops[2] $lhs, $rhs),
            [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
}
// Arithmetic right shift (plain and broadcasting forms) -> tf.RightShift.
def : Pat<(HLO_ShiftRightArithmeticOp $lhs, $rhs),
          (TF_RightShiftOp $lhs, $rhs)>;
def : Pat<(CHLO_BroadcastShiftRightArithmeticOp $lhs, $rhs, $bcast_dims),
          (TF_RightShiftOp $lhs, $rhs),
          [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
// Logical right shift (plain and broadcasting forms) -> tf.RightShift, the
// same TF op the arithmetic right shift is rewritten to.
// NOTE(review): confirm tf.RightShift performs a logical (zero-filling) shift
// for every element type that can reach these patterns, in particular signed
// integers.
def : Pat<(HLO_ShiftRightLogicalOp $lhs, $rhs),
          (TF_RightShiftOp $lhs, $rhs)>;
def : Pat<(CHLO_BroadcastShiftRightLogicalOp $lhs, $rhs, $bcast_dims),
          (TF_RightShiftOp $lhs, $rhs),
          [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
// floor(div(l, r)) -> tf.FloorDiv(l, r), for both the plain division and the
// chlo broadcasting division (the latter only for a legal numpy-style
// broadcast).
def : Pat<(HLO_FloorOp
            (HLO_DivOp $lhs, $rhs)),
          (TF_FloorDivOp $lhs, $rhs)>;
def : Pat<(HLO_FloorOp
            (CHLO_BroadcastDivOp $lhs, $rhs, $bcast_dims)),
          (TF_FloorDivOp $lhs, $rhs),
          [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
// Build a complex value from real and imaginary parts: mhlo.complex ->
// tf.Complex with operands in the same order.
def : Pat<(HLO_ComplexOp $real, $imag),
          (TF_ComplexOp $real, $imag)>;
// Remainder (plain and broadcasting forms) -> tf.Mod. Both operands must
// satisfy TF_FpOrI32OrI64Tensor.
def : Pat<(HLO_RemOp TF_FpOrI32OrI64Tensor:$lhs,
                     TF_FpOrI32OrI64Tensor:$rhs),
          (TF_ModOp $lhs, $rhs)>;
def : Pat<(CHLO_BroadcastRemOp TF_FpOrI32OrI64Tensor:$lhs,
                               TF_FpOrI32OrI64Tensor:$rhs,
                               $bcast_dims),
          (TF_ModOp $lhs, $rhs),
          [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
// Element type conversion: mhlo.convert -> tf.Cast, with the cast's attribute
// operand set to ConstBoolAttrFalse.
def : Pat<(HLO_ConvertOp HLO_Tensor:$src),
          (TF_CastOp $src, ConstBoolAttrFalse)>;
// Element-wise unary ops with a one-to-one TF counterpart. Each row is
// [HLO op, TF op]; the rewrite applies only when the operand satisfies
// TF_IntOrFpTensor.
foreach Mapping = [[HLO_AbsOp, TF_AbsOp],
                   [HLO_BitcastConvertOp, TF_BitcastOp],
                   [HLO_CeilOp, TF_CeilOp],
                   [HLO_CosineOp, TF_CosOp],
                   [HLO_ExpOp, TF_ExpOp],
                   [HLO_Expm1Op, TF_Expm1Op],
                   [HLO_FloorOp, TF_FloorOp],
                   [HLO_IsFiniteOp, TF_IsFiniteOp],
                   [HLO_LogOp, TF_LogOp],
                   [HLO_Log1pOp, TF_Log1pOp],
                   [HLO_LogisticOp, TF_SigmoidOp],
                   [HLO_NegOp, TF_NegOp],
                   [HLO_RsqrtOp, TF_RsqrtOp],
                   [HLO_SineOp, TF_SinOp],
                   [HLO_SignOp, TF_SignOp],
                   [HLO_SqrtOp, TF_SqrtOp],
                   [HLO_TanhOp, TF_TanhOp]] in
  def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>;

// Real/Imag are kept out of the list above: tf.Real and tf.Imag take complex
// operands, so matching an integer/floating-point operand would build a TF op
// that fails verification while the complex case would never be legalized.
// They are matched on complex operands instead.
foreach Mapping = [[HLO_ImagOp, TF_ImagOp],
                   [HLO_RealOp, TF_RealOp]] in
  def : Pat<(Mapping[0] TF_ComplexTensor:$input), (Mapping[1] $input)>;
// Boolean negation: mhlo.not on a TF_BoolTensor operand -> tf.LogicalNot.
def : Pat<(HLO_NotOp TF_BoolTensor:$operand),
          (TF_LogicalNotOp $operand)>;
// Absolute value of a complex operand -> tf.ComplexAbs.
def : Pat<(HLO_AbsOp TF_ComplexTensor:$operand),
          (TF_ComplexAbsOp $operand)>;
// mhlo.broadcast -> tf.BroadcastTo.
// The op's attribute lists only the dimensions that are prepended to the
// operand's shape, whereas tf.BroadcastTo expects the complete target shape.
// Passing the attribute directly is therefore only right for scalar operands,
// so the target shape is taken from the op's result instead.
// The result is required to have a static shape so that ShapeToConst has
// concrete extents to put in the constant.
// NOTE(review): assumes ShapeToConst behaves here as it does for the
// BroadcastInDim patterns in this file — confirm.
def : Pat<(HLO_BroadcastOp:$output $arg, $shape),
          (TF_BroadcastToOp $arg, (ShapeToConst $output)),
          [(AnyStaticShapeTensor $output)]>;
// mhlo.broadcast_in_dim whose dimensions satisfy IsTFStyleBroadcast: a single
// tf.BroadcastTo to the result's shape is enough.
def : Pat<(HLO_BroadcastInDimOp:$result $operand, $bcast_dims),
          (TF_BroadcastToOp
            $operand,
            (ShapeToConst $result)),
          [(IsTFStyleBroadcast $bcast_dims, $result)]>;
// mhlo.broadcast_in_dim whose dimensions do not satisfy IsTFStyleBroadcast:
// first reshape the operand to the shape computed by ExpandedShape, then
// broadcast that to the result's shape.
def : Pat<(HLO_BroadcastInDimOp:$result $operand, $bcast_dims),
          (TF_BroadcastToOp
            (TF_ReshapeOp
              $operand,
              (ExpandedShape $operand, $bcast_dims, $result)),
            (ShapeToConst $result)),
          [(IsNotTFStyleBroadcast $bcast_dims, $result)]>;
// mhlo.transpose -> tf.Transpose; the permutation attribute becomes a
// tf.Const operand.
def : Pat<(HLO_TransposeOp $operand, $perm),
          (TF_TransposeOp $operand, (TF_ConstOp $perm))>;
// mhlo.reverse -> tf.ReverseV2; the dimensions attribute becomes a tf.Const
// operand.
def : Pat<(HLO_ReverseOp $operand, $axes),
          (TF_ReverseV2Op $operand, (TF_ConstOp $axes))>;
// mhlo.reshape -> tf.Reshape; the target shape is taken from the HLO result.
def : Pat<(HLO_ReshapeOp:$result $operand),
          (TF_ReshapeOp $operand, (ShapeToConst $result))>;
//===----------------------------------------------------------------------===//
// Ternary op patterns.
//===----------------------------------------------------------------------===//
// clamp(min, x, max) -> maximum(minimum(x, max), min).
def : Pat<(HLO_ClampOp $lower, $operand, $upper),
          (TF_MaximumOp
            (TF_MinimumOp $operand, $upper),
            $lower)>;
// mhlo.select -> tf.Select with the operands in the same order.
def : Pat<(HLO_SelectOp $pred, $on_true, $on_false),
          (TF_SelectOp $pred, $on_true, $on_false)>;
//===----------------------------------------------------------------------===//
// Variadic op patterns.
//===----------------------------------------------------------------------===//
// mhlo.concatenate -> tf.ConcatV2; the dimension attribute becomes a tf.Const
// operand.
def : Pat<(HLO_ConcatenateOp $operands, $axis),
          (TF_ConcatV2Op $operands, (TF_ConstOp $axis))>;
// Predicate on an attribute: casts it to ::mlir::mhlo::ComparisonTypeAttr and
// compares its value against the C++ expression given in `value`.
class HasCompareType<string value>
    : CPred<"$_self.cast<::mlir::mhlo::ComparisonTypeAttr>().getValue() == "
            # value>;
// Accepts a comparison-type attribute that is either absent or one of FLOAT,
// SIGNED, UNSIGNED, NOTYPE, i.e. a comparison kind TensorFlow supports.
// The null check is listed first so the casts that follow are only evaluated
// on a present attribute.
def IsTFCompareType
    : AttrConstraint<
          Or<[
            CPred<"!$_self">,
            HasCompareType<"::mlir::mhlo::ComparisonType::FLOAT">,
            HasCompareType<"::mlir::mhlo::ComparisonType::SIGNED">,
            HasCompareType<"::mlir::mhlo::ComparisonType::UNSIGNED">,
            HasCompareType<"::mlir::mhlo::ComparisonType::NOTYPE">
          ]>,
          "compare type supported by TensorFlow">;
//===----------------------------------------------------------------------===//
// Compare op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
// semantically compatible with the corresponding TF ops. Depending on
// context, getting to these ops may require some raising.
//===----------------------------------------------------------------------===//
// Constant comparison-direction attribute for the enumerator named `enumStr`
// of ::mlir::mhlo::ComparisonDirection.
class HLO_ComparisonDirectionValue<string enumStr>
    : ConstantAttr<HLO_ComparisonDirectionAttr,
                   "::mlir::mhlo::ComparisonDirection::" # enumStr>;
// Equality comparisons. Each row is [TF op, HLO comparison direction]. The TF
// op is built with ConstBoolAttrTrue as its trailing attribute operand. The
// comparison type must satisfy IsTFCompareType.
foreach eq = [
    [TF_EqualOp,    HLO_ComparisonDirectionValue<"EQ">],
    [TF_NotEqualOp, HLO_ComparisonDirectionValue<"NE">]
  ] in {
  def : Pat<(CHLO_BroadcastCompareOp $lhs, $rhs, $bcast_dims, eq[1],
                                     IsTFCompareType:$cmp_type),
            (eq[0] $lhs, $rhs, ConstBoolAttrTrue),
            [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
  def : Pat<(HLO_CompareOp $lhs, $rhs, eq[1], IsTFCompareType:$cmp_type),
            (eq[0] $lhs, $rhs, ConstBoolAttrTrue)>;
}
// Ordering comparisons. Each row is [TF op, HLO comparison direction]. The
// comparison type must satisfy IsTFCompareType.
foreach cmp = [
    [TF_GreaterEqualOp, HLO_ComparisonDirectionValue<"GE">],
    [TF_GreaterOp,      HLO_ComparisonDirectionValue<"GT">],
    [TF_LessEqualOp,    HLO_ComparisonDirectionValue<"LE">],
    [TF_LessOp,         HLO_ComparisonDirectionValue<"LT">]
  ] in {
  def : Pat<(CHLO_BroadcastCompareOp $lhs, $rhs, $bcast_dims, cmp[1],
                                     IsTFCompareType:$cmp_type),
            (cmp[0] $lhs, $rhs),
            [(IsLegalNumpyRankedBroadcast $lhs, $rhs, $bcast_dims)]>;
  def : Pat<(HLO_CompareOp $lhs, $rhs, cmp[1], IsTFCompareType:$cmp_type),
            (cmp[0] $lhs, $rhs)>;
}
// mhlo.dot is rewritten by the native ConvertDotOp helper, which receives the
// matched op itself. Both operands must be statically shaped tensors of a TF
// element type.
def ConvertDotOp
    : NativeCodeCall<"ConvertDotOp($_builder, $0.getDefiningOp())">;
def : Pat<(HLO_DotOp:$dot
            StaticShapeTensorOf<[TF_ElementType]>:$lhs,
            StaticShapeTensorOf<[TF_ElementType]>:$rhs,
            $precision_config),
          (ConvertDotOp $dot)>;
// mhlo.dot_general is rewritten by the native ConvertDotGeneralOp helper,
// which receives the matched op itself. Both operands must be statically
// shaped tensors.
def ConvertDotGeneralOp
    : NativeCodeCall<"ConvertDotGeneralOp($_builder, $0.getDefiningOp())">;
def : Pat<(HLO_DotGeneralOp:$dot
            AnyStaticShapeTensor:$lhs,
            AnyStaticShapeTensor:$rhs,
            $dot_dimension_numbers,
            $precision_config),
          (ConvertDotGeneralOp $dot)>;
// Holds when the attribute is a splat whose (integer) splat value is zero.
def IsZero
    : Constraint<
          CPred<"$0.isSplat() && $0.getSplatValue<APInt>() == 0">>;
// Native helper that rewrites the matched pad op; it receives the op itself.
def ConvertPadOp
    : NativeCodeCall<"ConvertPadOp($_builder, $0.getDefiningOp())">;
// mhlo.pad is rewritten only when its interior padding satisfies IsZero.
def : Pat<(HLO_PadOp:$pad $operand, $padding_value, $edge_low, $edge_high,
                          $interior),
          (ConvertPadOp $pad),
          [(IsZero $interior)]>;
// Holds when the attribute is a splat of a floating-point element type whose
// splat value is exactly `val`.
// The element-type check comes before getSplatValue<APFloat>() so that an
// integer-typed splat makes the constraint fail instead of being read as an
// APFloat. The patterns using this constraint match op structure first, so an
// integer-typed DAG of the same shape can reach it.
class FloatValueEquals<string val> : Constraint<CPred<
  "$0.isa<SplatElementsAttr>() && "
  "$0.cast<SplatElementsAttr>().getElementType().isa<::mlir::FloatType>() && "
  "$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isExactlyValue(" # val # ")">>;
// Holds when the attribute is a splat of a floating-point element type whose
// splat value is strictly greater than zero.
// The element-type check comes before getSplatValue<APFloat>() so that an
// integer-typed splat makes the constraint fail instead of being read as an
// APFloat. NaN is rejected explicitly: a NaN with a clear sign bit is neither
// negative nor zero, yet it is not greater than zero.
def FloatValueGreaterThanZero : Constraint<CPred<
  "$0.isa<SplatElementsAttr>() && "
  "$0.cast<SplatElementsAttr>().getElementType().isa<::mlir::FloatType>() && "
  "!$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isNaN() && "
  "!$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isNegative() && "
  "!$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isZero()">>;
// Holds when both attributes are splats of the same floating-point element
// type and the product of their splat values is exactly 1.0.
// The element types are checked before the values are read: an integer-typed
// splat must not be read as an APFloat, and the APFloat product is only formed
// when both operands share one element type.
def FloatValueIsReciprocal : Constraint<CPred<
  "$0.isa<SplatElementsAttr>() && $1.isa<SplatElementsAttr>() && "
  "$0.cast<SplatElementsAttr>().getElementType().isa<::mlir::FloatType>() && "
  "$0.cast<SplatElementsAttr>().getElementType() == "
  "$1.cast<SplatElementsAttr>().getElementType() && "
  "($0.cast<SplatElementsAttr>().getSplatValue<APFloat>() * "
  "$1.cast<SplatElementsAttr>().getSplatValue<APFloat>()).isExactlyValue(1.0)">>;
// Delegates to the native FloatTensorIsSign helper with the two bound values.
def FloatTensorIsSign
    : Constraint<
          CPred<"FloatTensorIsSign($_builder, $0, $1)">>;
// Holds when the two bound entities compare equal with operator==.
def SameValue
    : Constraint<
          CPred<"$0 == $1">>;
// Holds when the comparison-type attribute is absent or equal to FLOAT.
def FloatOrDefaultCompare
    : Constraint<
          CPred<"!$0 || $0.getValue() == ::mlir::mhlo::ComparisonType::FLOAT">>;
// Converts a dag of HLOs representing banker's rounding (round x.5 to the
// nearest even value) to tf.Round.
// The pattern matched executes the following computation:
//   frac = x - floor(x)
//   to_even = (floor(x) - 2 * floor(0.5 * x)) == 1
//   if frac > 0.5 || (frac == 0.5 && to_even)
//     return floor(x) + 1
//   else
//     return floor(x)
//
// Sub-expressions that must be one and the same value ($floor..$floor3,
// $frac/$frac1) are bound under distinct names and tied together with
// SameValue in the constraint list. Every constant must be a float splat with
// the expected value, and every comparison type must be FLOAT or absent.
def : Pat<(HLO_SelectOp
(HLO_OrOp
// frac > 0.5
(HLO_CompareOp (HLO_SubtractOp:$frac
$input,
(HLO_FloorOp:$floor $input)),
(HLO_ConstantOp $half),
HLO_ComparisonDirectionValue<"GT">,
$compare_type0),
(HLO_AndOp
// frac == 0.5
(HLO_CompareOp
$frac1,
(HLO_ConstantOp $half1),
HLO_ComparisonDirectionValue<"EQ">,
$compare_type1),
// floor(x) - floor(x * 0.5) * 2 == 1
(HLO_CompareOp
(HLO_SubtractOp
$floor1,
(HLO_MulOp
(HLO_FloorOp (HLO_MulOp $input, (HLO_ConstantOp $half2))),
(HLO_ConstantOp $two))),
(HLO_ConstantOp $one1),
HLO_ComparisonDirectionValue<"EQ">,
$compare_type2))),
// Selected when the condition holds: floor(x) + 1.
(HLO_AddOp $floor2, (HLO_ConstantOp $one)),
// Selected otherwise: floor(x).
$floor3),
(TF_RoundOp $input),
[(FloatValueEquals<"1.0"> $one),
(FloatValueEquals<"1.0"> $one1),
(FloatValueEquals<"2.0"> $two),
(FloatValueEquals<"0.5"> $half),
(FloatValueEquals<"0.5"> $half1),
(FloatValueEquals<"0.5"> $half2),
(SameValue $floor, $floor1),
(SameValue $floor, $floor2),
(SameValue $floor, $floor3),
(SameValue $frac, $frac1),
(FloatOrDefaultCompare $compare_type0),
(FloatOrDefaultCompare $compare_type1),
(FloatOrDefaultCompare $compare_type2)]>;
// Converts a dag of HLOs representing floor_mod to tf.FloorMod.
// The pattern matched executes the following computation:
//
//   rem = remainder(arg0, arg1)
//   for i in 0 to len(arg0):
//     if ((rem[i] < 0) != (arg1[i] < 0) && rem[i] != 0)
//       rem[i] += arg1[i]
//   return rem
//
// $rem1..$rem3 are tied to $rem with SameValue. All three zero constants must
// be float splats equal to 0.0, and all four comparison types must be FLOAT or
// absent. That includes $compare_type3 on the `rem != 0` comparison, the same
// check the constant-divisor floor_mod pattern in this file applies to it.
def : Pat<(HLO_SelectOp
            (HLO_AndOp
              // (rem < 0) != (arg1 < 0)
              (HLO_CompareOp
                (HLO_CompareOp:$rltz
                  (HLO_RemOp:$rem $arg, $arg1),
                  (HLO_ConstantOp $cst),
                  HLO_ComparisonDirectionValue<"LT">,
                  $compare_type),
                (HLO_CompareOp:$arg1ltz
                  $arg1,
                  (HLO_ConstantOp $cst1),
                  HLO_ComparisonDirectionValue<"LT">,
                  $compare_type1),
                HLO_ComparisonDirectionValue<"NE">,
                $compare_type2),
              // rem != 0
              (HLO_CompareOp:$rnz
                $rem1,
                (HLO_ConstantOp $cst2),
                HLO_ComparisonDirectionValue<"NE">,
                $compare_type3)),
            (HLO_AddOp $rem2, $arg1),
            $rem3),
          (TF_FloorModOp $arg, $arg1),
          [(FloatValueEquals<"0.0"> $cst),
           (FloatValueEquals<"0.0"> $cst1),
           (FloatValueEquals<"0.0"> $cst2),
           (SameValue $rem, $rem1),
           (SameValue $rem, $rem2),
           (SameValue $rem, $rem3),
           (FloatOrDefaultCompare $compare_type),
           (FloatOrDefaultCompare $compare_type1),
           (FloatOrDefaultCompare $compare_type2),
           (FloatOrDefaultCompare $compare_type3)]>;
// Converts a dag of HLOs representing floor_mod with a constant to
// tf.FloorMod. The pattern matched executes the following computation:
//
//   cst = value that is > 0
//   rem = remainder(arg0, cst)
//   for i in 0 to len(arg0):
//     if (rem[i] < 0 && rem[i] != 0)
//       rem[i] += cst
//   return rem
//
// The divisor constant ($cst) and the constant added back ($cst3) must be the
// same attribute; $rem1..$rem3 are tied to $rem with SameValue. Both zero
// constants must be float splats equal to 0.0, and both comparison types must
// be FLOAT or absent.
def : Pat<(HLO_SelectOp
(HLO_AndOp
// rem < 0
(HLO_CompareOp:$rltz
(HLO_RemOp:$rem $arg, (HLO_ConstantOp $cst)),
(HLO_ConstantOp $cst1),
HLO_ComparisonDirectionValue<"LT">,
$compare_type),
// rem != 0
(HLO_CompareOp:$rnz $rem1, (HLO_ConstantOp $cst2), HLO_ComparisonDirectionValue<"NE">, $compare_type3)),
(HLO_AddOp $rem2, (HLO_ConstantOp $cst3)),
$rem3),
(TF_FloorModOp $arg, (TF_ConstOp $cst3)),
[(FloatValueGreaterThanZero $cst),
(FloatValueEquals<"0.0"> $cst1),
(FloatValueEquals<"0.0"> $cst2),
(SameValue $cst, $cst3),
(SameValue $rem, $rem1),
(SameValue $rem, $rem2),
(SameValue $rem, $rem3),
(FloatOrDefaultCompare $compare_type),
(FloatOrDefaultCompare $compare_type3)]>;
// Converts a dag of HLOs representing floor_div to tf.FloorDiv.
// The pattern matched executes the following computation:
//
//   rem = remainder(arg0, arg1)
//   for i in 0 to len(arg0):
//     div[i] = (arg0[i] - rem[i]) / arg1[i]
//     if (rem[i] != 0 && sign(arg1[i]) != sign(rem[i]))
//       div[i] += -1.0
//   return round_nearest_afz(div)
//
// As a dag this looks like the following:
//                       round
//                         |
//              -------- select
//              |          |    \
//             &&          +    div
//           /   |        / \
//        !=    !=       div -1
//       / |   / |      / |
//    rem 0.0 sn sn1   -  $1
//   / |      |   |  / |
// $0 $1     $1 rem $0 rem
//
// 'sn' and 'sn1' are two different values produced by the same function
// (sign). Names that appear twice, such as 'div' and 'rem', refer to one and
// the same value. Mathematical symbols do not indicate re-use of a value.
//
// The divisor of the inner division is bound as $arg1b and tied to $arg1 with
// SameValue. Without that constraint a dag dividing by an unrelated value
// would still match and be rewritten to FloorDiv($arg0, $arg1).
def : Pat<(HLO_RoundOp
            (HLO_SelectOp
              (HLO_AndOp
                // rem != 0
                (HLO_CompareOp
                  (HLO_RemOp:$rem $arg0, $arg1),
                  (HLO_ConstantOp $cst),
                  HLO_ComparisonDirectionValue<"NE">,
                  $compare_type),
                // sign(arg1) != sign(rem)
                (HLO_CompareOp
                  (HLO_SignOp $arg1),
                  (HLO_SignOp $rem1),
                  HLO_ComparisonDirectionValue<"NE">,
                  $compare_type1)),
              // (arg0 - rem) / arg1 + (-1)
              (HLO_AddOp
                (HLO_DivOp:$div
                  (HLO_SubtractOp $arg0, $rem2),
                  $arg1b),
                (HLO_ConstantOp $cst_neg1)),
              $div1)),
          (TF_FloorDivOp $arg0, $arg1),
          [(FloatValueEquals<"0.0"> $cst),
           (FloatValueEquals<"-1.0"> $cst_neg1),
           (SameValue $arg1, $arg1b),
           (SameValue $div, $div1),
           (SameValue $rem, $rem1),
           (SameValue $rem, $rem2),
           (FloatOrDefaultCompare $compare_type),
           (FloatOrDefaultCompare $compare_type1)]>;
// Converts a dag of HLOs representing floor_div with a splat constant to
// tf.FloorDiv. This particular pattern matches multiplication with the
// reciprocal of the constant instead of division by the constant.
// The pattern matched executes the following computation:
//
//   rem = remainder(arg0, cst)
//   for i in 0 to len(arg0):
//     mul[i] = (arg0[i] - rem[i]) * (1 / cst)
//     if (rem[i] != 0 && sign(cst) != sign(rem[i]))
//       mul[i] += -1.0
//   return round_nearest_afz(mul)
//
// As a dag this looks like the following:
//                       round
//                         |
//              -------- select
//              |          |    \
//             &&          +    mul
//           /   |        / \
//        !=    !=       mul -1
//       / |   /  |      / |
//    rem 0.0 cs1 sn1   -  cs2
//   / |           |   / |
// $0 cst         rem $0 rem
//
// cs1 == sign(cst)  (a constant, checked with FloatTensorIsSign)
// cs2 == 1 / cst    (the reciprocal, checked with FloatValueIsReciprocal)
//
// Names that appear twice, such as 'mul' and 'rem', refer to one and the same
// value. Mathematical symbols do not indicate re-use of a value.
def : Pat<(HLO_RoundOp
(HLO_SelectOp
(HLO_AndOp
// rem != 0
(HLO_CompareOp
(HLO_RemOp:$rem $arg0, (HLO_ConstantOp:$cst $cstv)),
(HLO_ConstantOp $cst_zero),
HLO_ComparisonDirectionValue<"NE">,
$compare_type),
// sign(cst) != sign(rem), with sign(cst) given as a constant
(HLO_CompareOp
(HLO_ConstantOp $cst_sgn),
(HLO_SignOp $rem1),
HLO_ComparisonDirectionValue<"NE">,
$compare_type1)),
// (arg0 - rem) * (1 / cst) + (-1)
(HLO_AddOp
(HLO_MulOp:$mul
(HLO_SubtractOp $arg0, $rem2),
(HLO_ConstantOp $cst_recip)),
(HLO_ConstantOp $cst_neg1)),
$mul1)),
(TF_FloorDivOp $arg0, $cst),
[(FloatValueEquals<"0.0"> $cst_zero),
(FloatValueEquals<"-1.0"> $cst_neg1),
(FloatTensorIsSign $cstv, $cst_sgn),
(FloatValueIsReciprocal $cstv, $cst_recip),
(SameValue $mul, $mul1),
(SameValue $rem, $rem1),
(SameValue $rem, $rem2),
(FloatOrDefaultCompare $compare_type),
(FloatOrDefaultCompare $compare_type1)]>;
// Converts a dag of HLOs representing floor_div with a splat constant to
// tf.FloorDiv. This particular pattern matches division by the constant.
// The pattern matched executes the following computation:
//
//   rem = remainder(arg0, cst)
//   for i in 0 to len(arg0):
//     div[i] = (arg0[i] - rem[i]) / cst
//     if (rem[i] != 0 && sign(cst) != sign(rem[i]))
//       div[i] += -1.0
//   return round_nearest_afz(div)
//
// As a dag this looks like the following:
//                       round
//                         |
//              -------- select
//              |          |    \
//             &&          +    div
//           /   |        / \
//        !=    !=       div -1
//       / |   /  |      / |
//    rem 0.0 cs1 sn1   -  cst
//   / |           |   / |
// $0 cst         rem $0 rem
//
// cs1 == sign(cst)  (a constant, checked with FloatTensorIsSign)
// The divisor of the inner division is a constant whose value ($cstv1) must
// equal the value of the remainder's divisor ($cstv).
//
// Names that appear twice, such as 'div' and 'rem', refer to one and the same
// value. Mathematical symbols do not indicate re-use of a value.
def : Pat<(HLO_RoundOp
(HLO_SelectOp
(HLO_AndOp
// rem != 0
(HLO_CompareOp
(HLO_RemOp:$rem $arg0, (HLO_ConstantOp:$cst $cstv)),
(HLO_ConstantOp $cst_zero),
HLO_ComparisonDirectionValue<"NE">,
$compare_type),
// sign(cst) != sign(rem), with sign(cst) given as a constant
(HLO_CompareOp
(HLO_ConstantOp $cst_sgn),
(HLO_SignOp $rem1),
HLO_ComparisonDirectionValue<"NE">,
$compare_type1)),
// (arg0 - rem) / cst + (-1)
(HLO_AddOp
(HLO_DivOp:$div
(HLO_SubtractOp $arg0, $rem2),
(HLO_ConstantOp $cstv1)),
(HLO_ConstantOp $cst_neg1)),
$div1)),
(TF_FloorDivOp $arg0, $cst),
[(FloatValueEquals<"0.0"> $cst_zero),
(FloatValueEquals<"-1.0"> $cst_neg1),
(FloatTensorIsSign $cstv, $cst_sgn),
(SameValue $div, $div1),
(SameValue $rem, $rem1),
(SameValue $rem, $rem2),
(SameValue $cstv1, $cstv),
(FloatOrDefaultCompare $compare_type),
(FloatOrDefaultCompare $compare_type1)]>;
// Converts a dag of HLOs representing floor_div with a broadcasted vector
// constant to tf.FloorDiv. The pattern matched executes the following
// computation:
//
//   scs = sign(cst)
//   bcst = broadcast(cst)
//   rem = remainder(arg0, bcst)
//   for i in 0 to len(arg0):
//     div[i] = (arg0[i] - rem[i]) / bcst[i]
//     if (rem[i] != 0 && scs != sign(rem[i]))
//       div[i] += -1.0
//   return round_nearest_afz(div)
//
// Where scs is a splat constant folded sign on the unbroadcasted tensor.
//
// As a dag this looks like the following:
//                       round
//                         |
//              -------- select
//              |          |    \
//             &&          +    div
//           /   |        / \
//        !=    !=       div -1
//       / |   /  |      / |
//    rem 0.0 scs sn1   -  bcst
//   / |           |   / |
// $0 bcst        rem $0 rem
//     |
//    cst
//
// scs == sign(cst) == sign(bcst)
//
// Names that appear twice, such as 'div', 'rem' and 'bcst', refer to one and
// the same value. Mathematical symbols do not indicate re-use of a value.
def : Pat<(HLO_RoundOp
(HLO_SelectOp
(HLO_AndOp
// rem != 0
(HLO_CompareOp
(HLO_RemOp:$rem $arg0,
(HLO_BroadcastInDimOp:$bcst
(HLO_ConstantOp $cstv),
$broadcast_dimension)),
(HLO_ConstantOp $cst_zero),
HLO_ComparisonDirectionValue<"NE">,
$compare_type),
// sign(cst) != sign(rem), with sign(cst) given as a constant
(HLO_CompareOp
(HLO_ConstantOp $cst_sgn),
(HLO_SignOp $rem1),
HLO_ComparisonDirectionValue<"NE">,
$compare_type1)),
// (arg0 - rem) / bcst + (-1)
(HLO_AddOp
(HLO_DivOp:$div
(HLO_SubtractOp $arg0, $rem2),
$bcst1),
(HLO_ConstantOp $cst_neg1)),
$div1)),
(TF_FloorDivOp $arg0, $bcst),
[(FloatValueEquals<"0.0"> $cst_zero),
(FloatValueEquals<"-1.0"> $cst_neg1),
(FloatTensorIsSign $cstv, $cst_sgn),
(SameValue $bcst, $bcst1),
(SameValue $div, $div1),
(SameValue $rem, $rem1),
(SameValue $rem, $rem2),
(FloatOrDefaultCompare $compare_type),
(FloatOrDefaultCompare $compare_type1)]>;
//===----------------------------------------------------------------------===//
// TorchIndexSelect op patterns.
//===----------------------------------------------------------------------===//
// mhlo.torch_index_select -> tf.GatherV2. The axis attribute becomes a
// tf.Const operand; the batch-dims attribute is forwarded unchanged.
def : Pat<(HLO_TorchIndexSelectOp $operand, $index, $dim, $batch_dims),
          (TF_GatherV2Op
            $operand,
            $index,
            (TF_ConstOp $dim),
            $batch_dims)>;