[MLIR][KernelGen] Add `tf.Atanh` kernels
PiperOrigin-RevId: 352393602
Change-Id: I2431e39759a12735241e9efb9ff778bdb287e6d3
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 558da58..c633bd2 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -397,6 +397,20 @@
}];
}
+def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
+ HLO_FpOrComplexTensor> {
+ let summary = "Atanh operator";
+
+ let description = [{
+ Returns `Atanh(operand)` element-wise.
+
+ $$
+ \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
+ = nan otherwise
+ $$
+ }];
+}
+
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
HLO_FpOrComplexTensor> {
let summary = "Conj operator";
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
index b8b6abb..a2b97a8 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td
@@ -175,6 +175,29 @@
(HLO_ConstantLike<"1"> $input)
)>;
+// Express `atanh` as follows:
+// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
+// atanh(x) = nan otherwise
+def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
+ (HLO_SelectOp
+ (HLO_CompareOp
+ (HLO_AbsOp $input),
+ (HLO_ConstantLike<"1"> $input),
+ HLO_COMPARISON_DIRECTION_GT,
+ (HLO_DEFAULT_COMPARISON_TYPE)
+ ),
+ (HLO_ConstantLike<"NAN"> $input),
+ (HLO_MulOp
+ (HLO_SubOp
+ (HLO_Log1pOp $input),
+ (HLO_Log1pOp
+ (HLO_NegOp $input)
+ )
+ ),
+ (HLO_ConstantLike<"0.5"> $input)
+ )
+ )>;
+
// Express `conj` as
// conj(x) = (re(x), -im(x)).
def : Pat<(HLOClient_ConjOp $v),
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
index bd6d891..70d5d38 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc
@@ -50,9 +50,10 @@
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions.
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
- fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
- sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
+#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
+ fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
+ sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
+ sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 7707b4f..9603bf0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -588,6 +588,7 @@
[TF_AcosOp, HLOClient_AcosOp],
[TF_AsinOp, HLOClient_AsinOp],
[TF_AtanOp, HLOClient_AtanOp],
+ [TF_AtanhOp, HLOClient_AtanhOp],
[TF_CeilOp, HLO_CeilOp],
[TF_CoshOp, HLOClient_CoshOp],
[TF_ComplexAbsOp, HLO_AbsOp],
diff --git a/tensorflow/core/kernels/cwise_op_atanh.cc b/tensorflow/core/kernels/cwise_op_atanh.cc
index 2404cd1..def2013 100644
--- a/tensorflow/core/kernels/cwise_op_atanh.cc
+++ b/tensorflow/core/kernels/cwise_op_atanh.cc
@@ -20,8 +20,11 @@
REGISTER4(UnaryOp, CPU, "Atanh", functor::atanh, float, double, complex64,
complex128);
-
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if !defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) || \
+ !defined(MLIR_GENERATED_EXPERIMENTAL_GPU_KERNELS_ENABLED)
REGISTER2(UnaryOp, GPU, "Atanh", functor::atanh, float, double);
#endif
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD
index e45360a..e78e00c 100644
--- a/tensorflow/core/kernels/mlir_generated/BUILD
+++ b/tensorflow/core/kernels/mlir_generated/BUILD
@@ -50,6 +50,7 @@
"gpu_op_asin.cc",
"gpu_op_asinh.cc",
"gpu_op_atan.cc",
+ "gpu_op_atanh.cc",
"gpu_op_ceil.cc",
"gpu_op_complex.cc",
"gpu_op_complex_abs.cc",
@@ -118,6 +119,7 @@
":asin_kernels",
":asinh_kernels",
":atan_kernels",
+ ":atanh_kernels",
":ceil_kernels",
":complex_abs_kernels",
":complex_kernels",
@@ -350,6 +352,16 @@
)
gen_kernel_library(
+ name = "atanh",
+ tile_size = "256",
+ types = [
+ "f32",
+ "f64",
+ ],
+ unroll_factors = "4",
+)
+
+gen_kernel_library(
name = "conj",
tile_size = "256",
types = [
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
new file mode 100644
index 0000000..5f16218
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/gpu_op_atanh.cc
@@ -0,0 +1,24 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/mlir_generated/gpu_ops_base.h"
+
+namespace tensorflow {
+
+GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f32, DT_FLOAT, float);
+GENERATE_AND_REGISTER_UNARY_KERNEL(Atanh, f64, DT_DOUBLE, double);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
index fc6f767..7854391 100644
--- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
+++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc
@@ -207,6 +207,16 @@
GENERATE_DEFAULT_TEST(Atan, DT_DOUBLE, DT_DOUBLE, std::atan,
test::GpuOpsTestConfig())
+/// Test `tf.Atanh`.
+
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+ Atanh, DT_FLOAT, DT_FLOAT, test::DefaultInputBetweenZeroAndOne<float>(),
+ std::atanh, test::GpuOpsTestConfig())
+
+GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES(
+ Atanh, DT_DOUBLE, DT_DOUBLE, test::DefaultInputBetweenZeroAndOne<double>(),
+ std::atanh, test::GpuOpsTestConfig())
+
/// Test `tf.Ceil`.
GENERATE_DEFAULT_TEST(Ceil, DT_FLOAT, DT_FLOAT, std::ceil,
diff --git a/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl
new file mode 100644
index 0000000..5604833
--- /dev/null
+++ b/tensorflow/core/kernels/mlir_generated/op_definitions/atanh.mlir.tmpl
@@ -0,0 +1,5 @@
+func @Atanh_elem_type(%arg0: tensor<*xelem_type>)
+ -> tensor<*xelem_type> attributes {tf_entry, llvm.emit_c_interface} {
+ %0 = "tf.Atanh"(%arg0) : (tensor<*xelem_type>) -> tensor<*xelem_type>
+ return %0 : tensor<*xelem_type>
+}