Remove pattern to map xla_hlo.divide to tf.RealDiv.
xla_hlo.divide is already lowered to tf.Div. tf.RealDiv expects its operands to be real and does not have the same semantics as xla_hlo.divide.
PiperOrigin-RevId: 308842783
Change-Id: I88313f2644aacf0216366b2a1994fa9a94d20558
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 1cc2f2f..ddbd6ec 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -728,13 +728,13 @@
// CHECK-LABEL: func @div(
// CHECK-SAME: [[VAL_18:%.*]]: tensor<2xi32>) -> tensor<2xi32> {
-// CHECK: [[VAL_19:%.*]] = "tf.RealDiv"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: [[VAL_19:%.*]] = "tf.Div"([[VAL_18]], [[VAL_18]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK: return [[VAL_19]] : tensor<2xi32>
// CHECK: }
// CHECK-LABEL: func @broadcast_div(
// CHECK-SAME: [[VAL_20:%.*]]: tensor<1xi32>, [[VAL_21:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
-// CHECK: [[VAL_22:%.*]] = "tf.RealDiv"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+// CHECK: [[VAL_22:%.*]] = "tf.Div"([[VAL_20]], [[VAL_21]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK: return [[VAL_22]] : tensor<1x2xi32>
// CHECK: }
@@ -746,7 +746,7 @@
// CHECK-LABEL: func @div_dynamic(
// CHECK-SAME: [[VAL_26:%.*]]: tensor<?xi32>, [[VAL_27:%.*]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
-// CHECK: [[VAL_28:%.*]] = "tf.RealDiv"([[VAL_26]], [[VAL_27]]) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK: [[VAL_28:%.*]] = "tf.Div"([[VAL_26]], [[VAL_27]]) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: return [[VAL_28]] : tensor<?x?xi32>
// CHECK: }
@@ -776,13 +776,13 @@
// CHECK-LABEL: func @real_div(
// CHECK-SAME: [[VAL_40:%.*]]: tensor<2xi32>) -> tensor<2xi32> {
-// CHECK: [[VAL_41:%.*]] = "tf.RealDiv"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+// CHECK: [[VAL_41:%.*]] = "tf.Div"([[VAL_40]], [[VAL_40]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
// CHECK: return [[VAL_41]] : tensor<2xi32>
// CHECK: }
// CHECK-LABEL: func @broadcast_real_div(
// CHECK-SAME: [[VAL_42:%.*]]: tensor<1xi32>, [[VAL_43:%.*]]: tensor<1x2xi32>) -> tensor<1x2xi32> {
-// CHECK: [[VAL_44:%.*]] = "tf.RealDiv"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
+// CHECK: [[VAL_44:%.*]] = "tf.Div"([[VAL_42]], [[VAL_43]]) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
// CHECK: return [[VAL_44]] : tensor<1x2xi32>
// CHECK: }
@@ -901,7 +901,7 @@
// CHECK: [[VAL_98:%.*]] = "tf.Const"() {value = dense<0> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK: [[VAL_99:%.*]] = "tf.Less"([[VAL_95]], [[VAL_98]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
// CHECK: [[VAL_100:%.*]] = "tf.Equal"([[VAL_97]], [[VAL_99]]) {incompatible_shape_error = true} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
-// CHECK: [[VAL_101:%.*]] = "tf.RealDiv"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+// CHECK: [[VAL_101:%.*]] = "tf.Div"([[VAL_94]], [[VAL_95]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_102:%.*]] = "tf.Abs"([[VAL_94]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_103:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK: [[VAL_104:%.*]] = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32>
@@ -909,7 +909,7 @@
// CHECK: [[VAL_106:%.*]] = "tf.AddV2"([[VAL_102]], [[VAL_105]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_107:%.*]] = "tf.Neg"([[VAL_106]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_108:%.*]] = "tf.Abs"([[VAL_95]]) : (tensor<3xi32>) -> tensor<3xi32>
-// CHECK: [[VAL_109:%.*]] = "tf.RealDiv"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
+// CHECK: [[VAL_109:%.*]] = "tf.Div"([[VAL_107]], [[VAL_108]]) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_110:%.*]] = "tf.Select"([[VAL_100]], [[VAL_101]], [[VAL_109]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: return [[VAL_110]] : tensor<2x3xi32>
// CHECK: }
@@ -921,7 +921,7 @@
// CHECK: [[VAL_115:%.*]] = "tf.Const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
// CHECK: [[VAL_116:%.*]] = "tf.Less"([[VAL_112]], [[VAL_115]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
// CHECK: [[VAL_117:%.*]] = "tf.Equal"([[VAL_114]], [[VAL_116]]) {incompatible_shape_error = true} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
-// CHECK: [[VAL_118:%.*]] = "tf.RealDiv"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: [[VAL_118:%.*]] = "tf.Div"([[VAL_111]], [[VAL_112]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_119:%.*]] = "tf.Abs"([[VAL_111]]) : (tensor<3xi32>) -> tensor<3xi32>
// CHECK: [[VAL_120:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_121:%.*]] = "tf.Const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
@@ -929,23 +929,23 @@
// CHECK: [[VAL_123:%.*]] = "tf.AddV2"([[VAL_119]], [[VAL_122]]) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_124:%.*]] = "tf.Neg"([[VAL_123]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_125:%.*]] = "tf.Abs"([[VAL_112]]) : (tensor<2x3xi32>) -> tensor<2x3xi32>
-// CHECK: [[VAL_126:%.*]] = "tf.RealDiv"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: [[VAL_126:%.*]] = "tf.Div"([[VAL_124]], [[VAL_125]]) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: [[VAL_127:%.*]] = "tf.Select"([[VAL_117]], [[VAL_118]], [[VAL_126]]) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: return [[VAL_127]] : tensor<2x3xi32>
// CHECK: }
// CHECK-LABEL: func @floordiv_f32(
// CHECK-SAME: [[VAL_128:%.*]]: tensor<2xf32>) -> tensor<2xf32> {
-// CHECK: [[VAL_129:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-// CHECK: [[VAL_130:%.*]] = "tf.RealDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+// CHECK: [[VAL_129:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+// CHECK: [[VAL_130:%.*]] = "tf.Div"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: [[VAL_131:%.*]] = "tf.FloorDiv"([[VAL_128]], [[VAL_128]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: return [[VAL_131]] : tensor<2xf32>
// CHECK: }
// CHECK-LABEL: func @floordiv_f16_broadcast(
// CHECK-SAME: [[VAL_132:%.*]]: tensor<2x3xf16>, [[VAL_133:%.*]]: tensor<3xf16>) -> tensor<2x3xf16> {
-// CHECK: [[VAL_134:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
-// CHECK: [[VAL_135:%.*]] = "tf.RealDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
+// CHECK: [[VAL_134:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
+// CHECK: [[VAL_135:%.*]] = "tf.Div"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
// CHECK: [[VAL_136:%.*]] = "tf.FloorDiv"([[VAL_132]], [[VAL_133]]) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
// CHECK: return [[VAL_136]] : tensor<2x3xf16>
// CHECK: }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
index 6abceed..f337198 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td
@@ -38,7 +38,6 @@
[HLO_MinOp, TF_MinimumOp],
[HLO_MulOp, TF_MulOp],
[HLO_PowOp, TF_PowOp],
- [HLO_DivOp, TF_RealDivOp],
[HLO_SubOp, TF_SubOp],
[HLO_Atan2Op, TF_Atan2Op],
[HLO_RemOp, TF_ModOp]] in