Add mhlo.round_nearest_afz to linalg.generic with scalar transform
Round nearest afz can be represented with an abs, add, ceil, and copy sign to
handle rounding.
PiperOrigin-RevId: 416975089
Change-Id: I49de5318da00236eb77d6d28ee0f09473d1fde06
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
index 74e744f..413f0e2 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h
@@ -795,6 +795,30 @@
}
template <>
+inline Value MapMhloOpToStdScalarOp<mhlo::RoundOp>(Location loc,
+ ArrayRef<Type> result_types,
+ ArrayRef<Type> arg_types,
+ ValueRange args,
+ OpBuilder* b) {
+ mhlo::RoundOp::Adaptor adaptor(args);
+ auto lb = ImplicitLocOpBuilder(loc, *b);
+ auto operand = adaptor.operand();
+ auto operand_ty = operand.getType();
+ auto element_ty = getElementTypeOrSelf(operand_ty);
+
+ if (auto float_type = element_ty.dyn_cast<FloatType>()) {
+ Value half =
+ b->create<arith::ConstantOp>(loc, b->getFloatAttr(float_type, 0.5));
+ auto abs = lb.create<math::AbsOp>(operand_ty, operand);
+ auto add = lb.create<arith::AddFOp>(abs, half);
+ auto floor = lb.create<math::FloorOp>(add);
+ return lb.create<mlir::math::CopySignOp>(floor, operand);
+ }
+
+ return nullptr;
+}
+
+template <>
inline Value MapMhloOpToStdScalarOp<mhlo::SelectOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Type> arg_types,
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index ba9a13f..947d34f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2841,6 +2841,7 @@
PointwiseToLinalgConverter<mhlo::PowOp>,
PointwiseToLinalgConverter<mhlo::RealOp>,
PointwiseToLinalgConverter<mhlo::RemOp>,
+ PointwiseToLinalgConverter<mhlo::RoundOp>,
PointwiseToLinalgConverter<mhlo::RsqrtOp>,
PointwiseToLinalgConverter<mhlo::SelectOp>,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index 727f284..46d3e76 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -494,6 +494,20 @@
// -----
+// CHECK-LABEL: func @round
+func @round(%val: tensor<2x2xf32>) -> tensor<2x2xf32> {
+ // CHECK: %[[HALF:.+]] = arith.constant 5.000000e-01
+ // CHECK: %[[ABS:.+]] = math.abs %arg1
+ // CHECK: %[[ADD:.+]] = arith.addf %[[ABS]], %[[HALF]]
+ // CHECK: %[[CEIL:.+]] = math.floor %[[ADD]]
+ // CHECK: %[[COPY:.+]] = math.copysign %[[CEIL]], %arg1
+ // CHECK: linalg.yield %[[COPY]]
+ %0 = "mhlo.round_nearest_afz"(%val) : (tensor<2x2xf32>) -> (tensor<2x2xf32>)
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {