Add mhlo.round_nearest_afz to linalg.generic with scalar transform

Round nearest afz can be represented with an abs, add, ceil, and copy sign to
handle rounding.

PiperOrigin-RevId: 416975089
Change-Id: I49de5318da00236eb77d6d28ee0f09473d1fde06
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
index 74e744f..413f0e2 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
@@ -795,6 +795,30 @@
 }
 
 template <>
+inline Value MapMhloOpToStdScalarOp<mhlo::RoundOp>(Location loc,
+                                                   ArrayRef<Type> result_types,
+                                                   ArrayRef<Type> arg_types,
+                                                   ValueRange args,
+                                                   OpBuilder* b) {
+  mhlo::RoundOp::Adaptor adaptor(args);
+  auto lb = ImplicitLocOpBuilder(loc, *b);
+  auto operand = adaptor.operand();
+  auto operand_ty = operand.getType();
+  auto element_ty = getElementTypeOrSelf(operand_ty);
+
+  if (auto float_type = element_ty.dyn_cast<FloatType>()) {
+    Value half =
+        b->create<arith::ConstantOp>(loc, b->getFloatAttr(float_type, 0.5));
+    auto abs = lb.create<math::AbsOp>(operand_ty, operand);
+    auto add = lb.create<arith::AddFOp>(abs, half);
+    auto floor = lb.create<math::FloorOp>(add);
+    return lb.create<mlir::math::CopySignOp>(floor, operand);
+  }
+
+  return nullptr;
+}
+
+template <>
 inline Value MapMhloOpToStdScalarOp<mhlo::SelectOp>(Location loc,
                                                     ArrayRef<Type> result_types,
                                                     ArrayRef<Type> arg_types,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index ba9a13f..947d34f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2841,6 +2841,7 @@
       PointwiseToLinalgConverter<mhlo::PowOp>,
       PointwiseToLinalgConverter<mhlo::RealOp>,
       PointwiseToLinalgConverter<mhlo::RemOp>,
+      PointwiseToLinalgConverter<mhlo::RoundOp>,
       PointwiseToLinalgConverter<mhlo::RsqrtOp>,
       PointwiseToLinalgConverter<mhlo::SelectOp>,
       PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index 727f284..46d3e76 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -494,6 +494,20 @@
 
 // -----
 
+// CHECK-LABEL: func @round
+func @round(%val: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // CHECK: %[[HALF:.+]] = arith.constant 5.000000e-01
+  // CHECK: %[[ABS:.+]] = math.abs %arg1
+  // CHECK: %[[ADD:.+]] = arith.addf %[[ABS]], %[[HALF]]
+  // CHECK: %[[CEIL:.+]] = math.floor %[[ADD]]
+  // CHECK: %[[COPY:.+]] = math.copysign %[[CEIL]], %arg1
+  // CHECK: linalg.yield %[[COPY]]
+  %0 = "mhlo.round_nearest_afz"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>)
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
              %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {