[cpu] vectorize atanh (#107786)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107786
Approved by: https://github.com/jgong5, https://github.com/sanchitintel, https://github.com/ezyang
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
index c24d673..dc83718 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -320,6 +320,9 @@
   Vectorized<T> atan() const {
     return map(Sleef_atanf8_u10);
   }
+  Vectorized<T> atanh() const {
+    return map(Sleef_atanhf8_u10);
+  }
   Vectorized<T> atan2(const Vectorized<T> &b) const {
     __m256 lo, hi;
     __m256 b1, b2;
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
index 8d9f1dd..81144f8 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -214,6 +214,9 @@
     return _mm256_sub_pd(pi_2, asin());
   }
   Vectorized<c10::complex<double>> atan() const;
+  Vectorized<c10::complex<double>> atanh() const {
+    return map(std::atanh);
+  }
   Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>>&) const {
     AT_ERROR("not supported for complex numbers");
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
index 398fc20..18c55a3 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -247,6 +247,9 @@
     return map(std::acos);
   }
   Vectorized<c10::complex<float>> atan() const;
+  Vectorized<c10::complex<float>> atanh() const {
+    return map(std::atanh);
+  }
   Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>>& /*b*/) const {
     AT_ERROR("not supported for complex numbers");
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
index 0be6007..a6fb52e 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
@@ -143,6 +143,9 @@
   Vectorized<double> atan() const {
     return Vectorized<double>(Sleef_atand4_u10(values));
   }
+  Vectorized<double> atanh() const {
+    return Vectorized<double>(Sleef_atanhd4_u10(values));
+  }
   Vectorized<double> atan2(const Vectorized<double> &b) const {
     return Vectorized<double>(Sleef_atan2d4_u10(values, b));
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
index 923ffa4..28a58e3 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
@@ -149,6 +149,9 @@
   Vectorized<float> atan() const {
     return Vectorized<float>(Sleef_atanf8_u10(values));
   }
+  Vectorized<float> atanh() const {
+    return Vectorized<float>(Sleef_atanhf8_u10(values));
+  }
   Vectorized<float> atan2(const Vectorized<float> &b) const {
     return Vectorized<float>(Sleef_atan2f8_u10(values, b));
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
index 8719b0d..50a3377 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
@@ -352,6 +352,12 @@
       map(std::atan)
     );
   }
+  Vectorized<float> atanh() const {
+    return USE_SLEEF(
+      Vectorized<float>(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])),
+      map(std::atanh)
+    );
+  }
   Vectorized<float> atan2(const Vectorized<float> &exp) const {
     USE_SLEEF(
       {
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
index b80f7b7..3a4b8a9 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
@@ -319,6 +319,9 @@
     auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
     return ln * vd_imag_half; // i/2*ln()
   }
+  Vectorized<ComplexDbl> atanh() const {
+    return map(std::atanh);
+  }
 
   Vectorized<ComplexDbl> sin() const {
     return map(std::sin);
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
index 7efffa7..0d2aa0c 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
@@ -449,6 +449,9 @@
     auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
     return ln * imag_half; // i/2*ln()
   }
+  Vectorized<ComplexFlt> atanh() const {
+    return map(std::atanh);
+  }
 
   Vectorized<ComplexFlt> acos() const {
     // acos(x) = pi/2 - asin(x)
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
index bd61866..bd1675f 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
@@ -225,6 +225,9 @@
   Vectorized<double> atan() const {
      return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)};
   }
+  Vectorized<double> atanh() const {
+     return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)};
+  }
   Vectorized<double> atan2(const Vectorized<double>& b) const {
      return {Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)};
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
index e72ace0..31ef9ff 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
@@ -264,6 +264,9 @@
   Vectorized<float> atan() const {
     return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
   }
+  Vectorized<float> atanh() const {
+    return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)};
+  }
   Vectorized<float> atan2(const Vectorized<float>& b) const {
     return {Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)};
   }
diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
index a93cb8b..c7f5fbe 100644
--- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
+++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -1045,6 +1045,9 @@
   Vectorized<T> atan() const {
     return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10);
   }
+  Vectorized<T> atanh() const {
+    return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10);
+  }
 
   Vectorized<T> erf() const {
     return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10);
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
index cedb5c6..6efc834 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
@@ -408,6 +408,9 @@
   Vectorized<T> atan() const {
     return map(Sleef_atanf16_u10);
   }
+  Vectorized<T> atanh() const {
+    return map(Sleef_atanhf16_u10);
+  }
   Vectorized<T> atan2(const Vectorized<T> &b) const {
     __m512 lo, hi;
     __m512 b1, b2;
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
index 9644ff3..b015790 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -275,6 +275,9 @@
     return _mm512_sub_pd(pi_2, asin());
   }
   Vectorized<c10::complex<double>> atan() const;
+  Vectorized<c10::complex<double>> atanh() const {
+    return map(std::atanh);
+  }
   Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>> &b) const {
     AT_ERROR("not supported for complex numbers");
   }
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
index 4898ff7..f9d6040 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -778,6 +778,9 @@
     return map(std::acos);
   }
   Vectorized<c10::complex<float>> atan() const;
+  Vectorized<c10::complex<float>> atanh() const {
+    return map(std::atanh);
+  }
   Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>> &b) const {
     AT_ERROR("not supported for complex numbers");
   }
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
index 6adbe10..ee4969b 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
@@ -163,6 +163,9 @@
   Vectorized<double> atan() const {
     return Vectorized<double>(Sleef_atand8_u10(values));
   }
+  Vectorized<double> atanh() const {
+    return Vectorized<double>(Sleef_atanhd8_u10(values));
+  }
   Vectorized<double> atan2(const Vectorized<double> &b) const {
     return Vectorized<double>(Sleef_atan2d8_u10(values, b));
   }
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
index 3b9a095..d6e6c76 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
@@ -178,6 +178,9 @@
   Vectorized<float> atan() const {
     return Vectorized<float>(Sleef_atanf16_u10(values));
   }
+  Vectorized<float> atanh() const {
+    return Vectorized<float>(Sleef_atanhf16_u10(values));
+  }
   Vectorized<float> atan2(const Vectorized<float> &b) const {
     return Vectorized<float>(Sleef_atan2f16_u10(values, b));
   }
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index f68e9c6..38e8f3d 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -368,6 +368,9 @@
   Vectorized<T> atan() const {
     return map(std::atan);
   }
+  Vectorized<T> atanh() const {
+    return map(std::atanh);
+  }
   Vectorized<T> atan2(const Vectorized<T> &exp) const {
     Vectorized<T> ret;
     for (const auto i : c10::irange(size())) {
diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h
index 587a7a9..bff08aa 100644
--- a/aten/src/ATen/cpu/vml.h
+++ b/aten/src/ATen/cpu/vml.h
@@ -67,6 +67,7 @@
 IMPLEMENT_VML(acos)
 IMPLEMENT_VML(asin)
 IMPLEMENT_VML(atan)
+IMPLEMENT_VML(atanh)
 IMPLEMENT_VML(ceil)
 IMPLEMENT_VML(cos)
 // IMPLEMENT_VML(cosh)
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 17d922a..8dd3ab0 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -399,9 +399,10 @@
 
 static void atanh_kernel(TensorIteratorBase& iter) {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
-      cpu_kernel(
+      cpu_kernel_vec(
         iter,
-        [=](scalar_t a) -> scalar_t { return std::atanh(a); });
+        [=](scalar_t a) -> scalar_t { return std::atanh(a); },
+        [=](Vectorized<scalar_t> self_vec){return self_vec.atanh();});
     });
 }
 
@@ -846,11 +847,11 @@
 ALSO_REGISTER_AVX512_DISPATCH(logit_stub, &CPU_CAPABILITY::logit_kernel);
 ALSO_REGISTER_AVX512_DISPATCH(sinh_stub, &CPU_CAPABILITY::sinh_kernel);
 ALSO_REGISTER_AVX512_DISPATCH(cosh_stub, &CPU_CAPABILITY::cosh_kernel);
+ALSO_REGISTER_AVX512_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel);
 
 // Might enable AVX512 dispatch after enabling explicit vectorization for them
 REGISTER_DISPATCH(acosh_stub, &CPU_CAPABILITY::acosh_kernel);
 REGISTER_DISPATCH(asinh_stub, &CPU_CAPABILITY::asinh_kernel);
-REGISTER_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel);
 REGISTER_DISPATCH(digamma_stub, &CPU_CAPABILITY::digamma_kernel);
 REGISTER_DISPATCH(trigamma_stub, &CPU_CAPABILITY::trigamma_kernel);
 REGISTER_DISPATCH(polygamma_stub, &CPU_CAPABILITY::polygamma_kernel);
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index c02472b..5dd349f 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -43,6 +43,12 @@
 template <
     typename T,
     typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T atanh(T a) {
+  return std::atanh(float(a));
+}
+template <
+    typename T,
+    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
 inline T erf(T a) {
   return std::erf(float(a));
 }