Add ceil and cos HLO ops and TF lowerings to MLIR for element-wise unary operations
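
The new ops are ceil and cos in the HLO and LHLO dialects; legalization
patterns from the TF dialect to HLO are added for abs, ceil, cos, exp,
floor, log, neg, and tanh in legalize_tf_patterns.td, with CHECK tests in
legalize-tf.mlir.

A minimal before/after sketch of the lowering, mirroring the added tests:

  // TF dialect input
  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>

  // After TF-to-HLO legalization
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>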

PiperOrigin-RevId: 269835024
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index a16e5db..6135848 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -112,6 +112,8 @@
 
 def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AbsOp;
 
+def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CeilOp;
+
 def HLO_ConvertOp : HLO_UnaryElementwiseOp<
       "convert", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ConvertOp {
   let hasFolder = 1;
@@ -121,6 +123,8 @@
   let hasCustomHLOConverter = 1;
 }
 
+def HLO_CosOp: HLO_UnaryElementwiseOp<"cos", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CosOp;
+
 def HLO_ExpOp: HLO_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ExpOp;
 
 def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_FloorOp;
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
index 6623c21..c95e279 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -61,6 +61,17 @@
   }];
 }
 
+class BASE_HLO_CeilOp {
+  string summary = "Ceil operator";
+
+  string description = [{
+    Returns `Ceil(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
 class BASE_HLO_ConvertOp {
   string summary = "Convert operator";
 
@@ -72,6 +83,17 @@
   }];
 }
 
+class BASE_HLO_CosOp {
+  string summary = "Cos operator";
+
+  string description = [{
+    Returns `Cos(operand)` element-wise.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+  }];
+}
+
 class BASE_HLO_ExpOp {
   string summary = "Exponential operator";
 
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index 597e5b3..d792271 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -88,8 +88,12 @@
 
 def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
 
+def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;
+
 def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert">, BASE_HLO_ConvertOp;
 
+def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cos">, BASE_HLO_CosOp;
+
 def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exp">, BASE_HLO_ExpOp;
 
 def LHLO_NegOp: LHLO_UnaryElementwiseOp<"neg">, BASE_HLO_NegOp;
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 42fb73f..a76bb95 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -438,6 +438,55 @@
 // Unary op legalizations.
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @abs
+func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @ceil
+func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @cos
+func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @exp
+func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @floor
+func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @neg
+func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// CHECK-LABEL: @tanh
+func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
+  // CHECK:  "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
 // CHECK-LABEL: reshape
 func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> {
   // CHECK:  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index f289350..ce83ed8 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -181,6 +181,20 @@
 // Unary op patterns.
 //===----------------------------------------------------------------------===//
 
+foreach Mapping = [
+                   [TF_AbsOp, HLO_AbsOp],
+                   [TF_CeilOp, HLO_CeilOp],
+                   [TF_CosOp, HLO_CosOp],
+                   [TF_ExpOp, HLO_ExpOp],
+                   [TF_FloorOp, HLO_FloorOp],
+                   [TF_LogOp, HLO_LogOp],
+                   [TF_NegOp, HLO_NegOp],
+                   [TF_TanhOp, HLO_TanhOp],
+                  ] in {
+  def : Pat<(Mapping[0] AnyTensor:$input),
+            (Mapping[1] $input)>;
+}
+
 foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp] in {
   def : Pat<(TfOp:$res AnyStaticShapeTensor:$arg, $ignored),
             (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>;