Legalize tf.Round to xla_hlo.round
PiperOrigin-RevId: 307841780
Change-Id: Ic2c053eb23e298780c833bd7fa6b7adb38c1875b
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 3c9d0c1..0650e5a 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -2166,6 +2166,12 @@
return %0 : tensor<*xf32>
}
+// CHECK-LABEL: func @round
+func @round(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+ // CHECK: "xla_hlo.round_nearest_afz"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ %0 = "tf.Round"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+ return %0 : tensor<2xf32>
+}
// CHECK-LABEL: func @rsqrt
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 6a36f3e..3e89890 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -541,6 +541,7 @@
[TF_LogicalNotOp, HLO_NotOp],
[TF_NegOp, HLO_NegOp],
[TF_RealOp, HLO_RealOp],
+ [TF_RoundOp, HLO_RoundOp],
[TF_RsqrtOp, HLO_RsqrtOp],
[TF_SinOp, HLO_SinOp],
[TF_SqrtOp, HLO_SqrtOp],