[MLIR][KernelGen] Add `tf.Atanh` kernels

PiperOrigin-RevId: 352393602
Change-Id: I2431e39759a12735241e9efb9ff778bdb287e6d3
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 558da58..c633bd2 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -397,6 +397,20 @@
   }];
 }
 
+def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
+    HLO_FpOrComplexTensor> {
+  let summary = "Atanh operator";
+
+  let description = [{
+    Returns `Atanh(operand)` element-wise.
+
+    $$
+    \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
+              = nan                          otherwise
+    $$
+  }];
+}
+
 def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
     HLO_FpOrComplexTensor> {
   let summary = "Conj operator";
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
index b8b6abb..a2b97a8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
@@ -175,6 +175,29 @@
     (HLO_ConstantLike<"1"> $input)
   )>;
 
+// Express `atanh` via log1p; note log1p(x) - log1p(-x) = log((1+x) / (1-x)):
+//   atanh(x) = 0.5 * (log1p(x) - log1p(-x)) if abs(x) <= 1
+//   atanh(x) = nan                          otherwise
+def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
+  (HLO_SelectOp
+    (HLO_CompareOp
+      (HLO_AbsOp $input),
+      (HLO_ConstantLike<"1"> $input),
+      HLO_COMPARISON_DIRECTION_GT,
+      (HLO_DEFAULT_COMPARISON_TYPE)
+    ),
+    (HLO_ConstantLike<"NAN"> $input),
+    (HLO_MulOp
+      (HLO_SubOp
+        (HLO_Log1pOp $input),
+        (HLO_Log1pOp
+          (HLO_NegOp $input)
+        )
+      ),
+      (HLO_ConstantLike<"0.5"> $input)
+    )
+  )>;
+
 // Express `conj` as
 //   conj(x) = (re(x), -im(x)).
 def : Pat<(HLOClient_ConjOp $v),
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index bd6d891..70d5d38 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -50,9 +50,10 @@
               sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
 
 // TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                           \
-  fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
-      sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
+#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \
+  fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
+      sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp)           \
+          sep fn(SinhOp) sep fn(TanOp)
 
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 7707b4f..9603bf0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -588,6 +588,7 @@
                    [TF_AcosOp, HLOClient_AcosOp],
                    [TF_AsinOp, HLOClient_AsinOp],
                    [TF_AtanOp, HLOClient_AtanOp],
+                   [TF_AtanhOp, HLOClient_AtanhOp],
                    [TF_CeilOp, HLO_CeilOp],
                    [TF_CoshOp, HLOClient_CoshOp],
                    [TF_ComplexAbsOp, HLO_AbsOp],
diff --git a/tensorflow/core/kernels/cwise_op_atanh.cc b/tensorflow/core/kernels/cwise_op_atanh.cc
index 2404cd1..def2013 100644
--- a/tensorflow/core/kernels/cwise_op_atanh.cc
+++ b/tensorflow/core/kernels/cwise_op_atanh.cc
@@ -20,8 +20,11 @@
 REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64,
           complex128);
 
-
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+    !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
 REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);
 #endif
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index e45360a..e78e00c 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -50,6 +50,7 @@
         "gpu_op_asin.cc",
         "gpu_op_asinh.cc",
         "gpu_op_atan.cc",
+        "gpu_op_atanh.cc",
         "gpu_op_ceil.cc",
         "gpu_op_complex.cc",
         "gpu_op_complex_abs.cc",
@@ -118,6 +119,7 @@
         ":asin_kernels",
         ":asinh_kernels",
         ":atan_kernels",
+        ":atanh_kernels",
         ":ceil_kernels",
         ":complex_abs_kernels",
         ":complex_kernels",
@@ -350,6 +352,16 @@
 )
 
 gen_kernel_library(
+    name = "atanh",
+    tile_size = "256",
+    types = [
+        "f32",
+        "f64",
+    ],
+    unroll_factors = "4",
+)
+
+gen_kernel_library(
     name = "conj",
     tile_size = "256",
     types = [
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
new file mode 100644
index 0000000..5f16218
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
@@ -0,0 +1,24 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
index fc6f767..7854391 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
@@ -207,6 +207,16 @@
 GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan,
                       test::GpuOpsTestConfig())
 
+/// Test `tf.Atanh`.
+
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
+    std::atanh, test::GpuOpsTestConfig())
+
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+    Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
+    std::atanh, test::GpuOpsTestConfig())
+
 /// Test `tf.Ceil`.
 
 GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl
new file mode 100644
index 0000000..5604833
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl
@@ -0,0 +1,5 @@
+func @Atanh_elem_type(%arg0: tensor<*xelem_type>)
+    -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+  %0 = "tf.Atanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+  return %0 : tensor<*xelem_type>
+}