Convert mhlo FloorMod pattern with a cst > 0, to a tf.FloorMod.

PiperOrigin-RevId: 373265180
Change-Id: I0617b7679642d9eb03e144b42c6c48676fdbeb2d
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index e079d6d..3d9ca08 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -1903,6 +1903,24 @@
   return %8 : tensor<192x8xbf16>
 }
 
+// CHECK-LABEL: func @convert_floor_mod_cst
+// CHECK: %[[CST1:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
+// CHECK: %[[CST2:.*]] = "tf.Const"() {value = dense<2.000000e+00> : tensor<192x8xbf16>} : () -> tensor<192x8xbf16>
+// CHECK: %[[RESULT:.*]] = "tf.FloorMod"(%arg0, %[[CST2]]) : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
+// CHECK: return %[[RESULT]] : tensor<192x8xbf16>
+// CHECK: }
+func @convert_floor_mod_cst(%arg0: tensor<192x8xbf16>) -> tensor<192x8xbf16> {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xbf16>
+  %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xbf16>
+  %2 = mhlo.remainder %arg0, %1 : tensor<192x8xbf16>
+  %3 = "mhlo.compare"(%2, %0) {comparison_direction = "LT"} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
+  %4 = "mhlo.compare"(%2, %0) {comparison_direction = "NE"} : (tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xi1>
+  %5 = mhlo.and %3, %4 : tensor<192x8xi1>
+  %6 = mhlo.add %2, %1 : tensor<192x8xbf16>
+  %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xbf16>, tensor<192x8xbf16>) -> tensor<192x8xbf16>
+  return %7 : tensor<192x8xbf16>
+}
+
 // CHECK-LABEL:   func @convert_gather(
 // CHECK-SAME:                         %[[ARG_0:.*]]: tensor<147456xf16>,
 // CHECK-SAME:                         %[[ARG_1:.*]]: tensor<192x256x1xi32>)
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
index 6d35024..8e45f03 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
@@ -229,6 +229,9 @@
 class FloatValueEquals<string val> : Constraint<CPred<
   "$0.isa<SplatElementsAttr>() && "
   "$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isExactlyValue(" # val # ")">>;
+def FloatValueGreaterThanZero : Constraint<CPred<
+  "$0.isa<SplatElementsAttr>() && "
+  "!$0.cast<SplatElementsAttr>().getSplatValue<APFloat>().isNegative()">>;
 def SameValue : Constraint<CPred<"$0 == $1">>;
 def FloatOrDefaultCompare : Constraint<CPred<
   "!$0 || $0.getValue() == \"FLOAT\"">>;
@@ -315,3 +318,35 @@
 (FloatOrDefaultCompare $compare_type),
 (FloatOrDefaultCompare $compare_type1),
 (FloatOrDefaultCompare $compare_type2)]>;
+
+// Converts a soup of HLOs representing floor_mod with a constant to
+// tf.FloorMod. The pattern matched executes the following computation:
+//
+// cst = value that is > 0
+// rem = remainder(arg0, cst)
+// for i in 0 to len(arg1):
+//    if (rem[i] < 0 && rem[i] != 0)
+//       rem[i] += cst
+// return rem
+def : Pat<
+(HLO_SelectOp
+  (HLO_AndOp
+    (HLO_CompareOp:$rltz
+      (HLO_RemOp:$rem $arg, (HLO_ConstOp $cst)),
+      (HLO_ConstOp $cst1),
+      HLO_COMPARISON_DIRECTION_LT,
+      $compare_type),
+    (HLO_CompareOp:$rnz $rem1, (HLO_ConstOp $cst2), HLO_COMPARISON_DIRECTION_NE, $compare_type3)),
+  (HLO_AddOp $rem2, (HLO_ConstOp $cst3)),
+  $rem3),
+(TF_FloorModOp $arg, (TF_ConstOp $cst3)),
+[(FloatValueGreaterThanZero $cst),
+(FloatValueEquals<"0.0"> $cst1),
+(FloatValueEquals<"0.0"> $cst2),
+(FloatValueGreaterThanZero $cst3),
+(SameValue $cst, $cst3),
+(SameValue $rem, $rem1),
+(SameValue $rem, $rem2),
+(SameValue $rem, $rem3),
+(FloatOrDefaultCompare $compare_type),
+(FloatOrDefaultCompare $compare_type3)]>;