Convert mhlo FloorMod pattern with a cst > 0, to a tf.FloorMod.
PiperOrigin-RevId: 373265180
Change-Id: I0617b7679642d9eb03e144b42c6c48676fdbeb2d
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index e079d6d..3d9ca08 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1903,6 +1903,24 @@
return %8 : tensor<192x8xbf16>
}
+// CHECK-LABEL: func @convert_floor_mod_cst
+// CHECK: %[[CST1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
+// CHECK: %[[CST2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
+// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
+// CHECK: return %[[RESULT]] : tensor<192x8xbf16>
+// CHECK: }
+func @convert_floor_mod_cst(%arg0: tensor<192x8xbf16>) -> tensor<192x8xbf16> {
+ %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xbf16>
+ %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xbf16>
+ %2 = mhlo.remainder %arg0, %1 : tensor<192x8xbf16>
+ %3 = "mhlo.compare"(%2, %0) {comparison_direction = "LT"} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
+ %4 = "mhlo.compare"(%2, %0) {comparison_direction = "NE"} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
+ %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+ %6 = mhlo.add %2, %1 : tensor<192x8xbf16>
+ %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
+ return %7 : tensor<192x8xbf16>
+}
+
// CHECK-LABEL: func @convert_gather(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<147456xf16>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<192x256x1xi32>)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
index 6d35024..8e45f03 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
@@ -229,6 +229,9 @@
class FloatValueEquals<string val> : Constraint<CPred<
"$0.isa<SplatElementsAttr>() && "
"$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isExactlyValue(" # val # ")">>;
+def FloatValueGreaterThanZero : Constraint<CPred<
+ "$0.isa<SplatElementsAttr>() && "
+ "!$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isNegative()">>;
def SameValue : Constraint<CPred<"$0 == $1">>;
def FloatOrDefaultCompare : Constraint<CPred<
"!$0 || $0.getValue() == \"FLOAT\"">>;
@@ -315,3 +318,35 @@
(FloatOrDefaultCompare $compare_type),
(FloatOrDefaultCompare $compare_type1),
(FloatOrDefaultCompare $compare_type2)]>;
+
+// Converts a soup of HLOs representing floor_mod with a constant to
+// tf.FloorMod. The pattern matched executes the following computation:
+//
+// cst = value that is > 0
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg1):
+// if (rem[i] < 0 && rem[i] != 0)
+// rem[i] += cst
+// return rem
+def : Pat<
+(HLO_SelectOp
+ (HLO_AndOp
+ (HLO_CompareOp:$rltz
+ (HLO_RemOp:$rem $arg, (HLO_ConstOp $cst)),
+ (HLO_ConstOp $cst1),
+ HLO_COMPARISON_DIRECTION_LT,
+ $compare_type),
+ (HLO_CompareOp:$rnz $rem1, (HLO_ConstOp $cst2), HLO_COMPARISON_DIRECTION_NE, $compare_type3)),
+ (HLO_AddOp $rem2, (HLO_ConstOp $cst3)),
+ $rem3),
+(TF_FloorModOp $arg, (TF_ConstOp $cst3)),
+[(FloatValueGreaterThanZero $cst),
+(FloatValueEquals<"0.0"> $cst1),
+(FloatValueEquals<"0.0"> $cst2),
+(FloatValueGreaterThanZero $cst3),
+(SameValue $cst, $cst3),
+(SameValue $rem, $rem1),
+(SameValue $rem, $rem2),
+(SameValue $rem, $rem3),
+(FloatOrDefaultCompare $compare_type),
+(FloatOrDefaultCompare $compare_type3)]>;