[cpu] vectorize atanh (#107786)
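Adds `Vectorized<T>::atanh()` to the CPU vec backends (SLEEF routines for the real float/double/bfloat16 paths, `map(std::atanh)` for the complex and generic paths), switches `atanh_kernel` from `cpu_kernel` to `cpu_kernel_vec`, and registers `atanh_stub` for AVX512 dispatch.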
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107786
Approved by: https://github.com/jgong5, https://github.com/sanchitintel, https://github.com/ezyang
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
index c24d673..dc83718 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -320,6 +320,9 @@
Vectorized<T> atan() const {
return map(Sleef_atanf8_u10);
}
+ Vectorized<T> atanh() const { // inverse hyperbolic tangent of this vector
+ return map(Sleef_atanhf8_u10); // hands the SLEEF float routine to map(), the same way atan() passes Sleef_atanf8_u10
+ }
Vectorized<T> atan2(const Vectorized<T> &b) const {
__m256 lo, hi;
__m256 b1, b2;
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
index 8d9f1dd..81144f8 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -214,6 +214,9 @@
return _mm256_sub_pd(pi_2, asin());
}
Vectorized<c10::complex<double>> atan() const;
+ Vectorized<c10::complex<double>> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // no intrinsic formulation here: std::atanh is forwarded to map()
+ }
Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>>&) const {
AT_ERROR("not supported for complex numbers");
}
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
index 398fc20..18c55a3 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -247,6 +247,9 @@
return map(std::acos);
}
Vectorized<c10::complex<float>> atan() const;
+ Vectorized<c10::complex<float>> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // forwarded to map(), matching the map(std::acos) pattern used by acos()
+ }
Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>>& /*b*/) const {
AT_ERROR("not supported for complex numbers");
}
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
index 0be6007..a6fb52e 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
@@ -143,6 +143,9 @@
Vectorized<double> atan() const {
return Vectorized<double>(Sleef_atand4_u10(values));
}
+ Vectorized<double> atanh() const { // inverse hyperbolic tangent of this vector
+ return Vectorized<double>(Sleef_atanhd4_u10(values)); // SLEEF call on the raw register; the _u10 suffix is SLEEF's 1.0-ULP variant
+ }
Vectorized<double> atan2(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_atan2d4_u10(values, b));
}
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
index 923ffa4..28a58e3 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
@@ -149,6 +149,9 @@
Vectorized<float> atan() const {
return Vectorized<float>(Sleef_atanf8_u10(values));
}
+ Vectorized<float> atanh() const { // inverse hyperbolic tangent of this vector
+ return Vectorized<float>(Sleef_atanhf8_u10(values)); // SLEEF call on the raw register, result re-wrapped as Vectorized<float>
+ }
Vectorized<float> atan2(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_atan2f8_u10(values, b));
}
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
index 8719b0d..50a3377 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
@@ -352,6 +352,12 @@
map(std::atan)
);
}
+ Vectorized<float> atanh() const { // inverse hyperbolic tangent of this vector
+ return USE_SLEEF( // two alternatives are supplied; USE_SLEEF selects one (macro defined outside this hunk)
+ Vectorized<float>(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])), // SLEEF form: one call per values.val[] half
+ map(std::atanh) // non-SLEEF form; presumably used when SLEEF is unavailable -- confirm in USE_SLEEF
+ );
+ }
Vectorized<float> atan2(const Vectorized<float> &exp) const {
USE_SLEEF(
{
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
index b80f7b7..3a4b8a9 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
@@ -319,6 +319,9 @@
auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
return ln * vd_imag_half; // i/2*ln()
}
+ Vectorized<ComplexDbl> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // forwarded to map(), like sin() does with std::sin
+ }
Vectorized<ComplexDbl> sin() const {
return map(std::sin);
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
index 7efffa7..0d2aa0c 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
@@ -449,6 +449,9 @@
auto ln = (sum / sub).log(); // ln((i + z)/(i - z))
return ln * imag_half; // i/2*ln()
}
+ Vectorized<ComplexFlt> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // std::atanh is forwarded to map(); no vector formulation is used here
+ }
Vectorized<ComplexFlt> acos() const {
// acos(x) = pi/2 - asin(x)
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
index bd61866..bd1675f 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h
@@ -225,6 +225,9 @@
Vectorized<double> atan() const {
return {Sleef_atand2_u10(_vec0), Sleef_atand2_u10(_vec1)};
}
+ Vectorized<double> atanh() const { // inverse hyperbolic tangent of this vector
+ return {Sleef_atanhd2_u10(_vec0), Sleef_atanhd2_u10(_vec1)}; // one SLEEF call per half (_vec0, _vec1), same layout as atan()
+ }
Vectorized<double> atan2(const Vectorized<double>& b) const {
return {Sleef_atan2d2_u10(_vec0, b._vec0), Sleef_atan2d2_u10(_vec1, b._vec1)};
}
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
index e72ace0..31ef9ff 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h
@@ -264,6 +264,9 @@
Vectorized<float> atan() const {
return {Sleef_atanf4_u10(_vec0), Sleef_atanf4_u10(_vec1)};
}
+ Vectorized<float> atanh() const { // inverse hyperbolic tangent of this vector
+ return {Sleef_atanhf4_u10(_vec0), Sleef_atanhf4_u10(_vec1)}; // one SLEEF call per half (_vec0, _vec1), same layout as atan()
+ }
Vectorized<float> atan2(const Vectorized<float>& b) const {
return {Sleef_atan2f4_u10(_vec0, b._vec0), Sleef_atan2f4_u10(_vec1, b._vec1)};
}
diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
index a93cb8b..c7f5fbe 100644
--- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
+++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -1045,6 +1045,9 @@
Vectorized<T> atan() const {
return mapSleef(Sleef_atanf4_u10, Sleef_atand2_u10);
}
+ Vectorized<T> atanh() const { // inverse hyperbolic tangent of this vector
+ return mapSleef(Sleef_atanhf4_u10, Sleef_atanhd2_u10); // supplies both the float and the double SLEEF routine; presumably mapSleef picks by T -- confirm
+ }
Vectorized<T> erf() const {
return mapSleef(Sleef_erff4_u10, Sleef_erfd2_u10);
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
index cedb5c6..6efc834 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
@@ -408,6 +408,9 @@
Vectorized<T> atan() const {
return map(Sleef_atanf16_u10);
}
+ Vectorized<T> atanh() const { // inverse hyperbolic tangent of this vector
+ return map(Sleef_atanhf16_u10); // hands the SLEEF float routine to map(), the same way atan() passes Sleef_atanf16_u10
+ }
Vectorized<T> atan2(const Vectorized<T> &b) const {
__m512 lo, hi;
__m512 b1, b2;
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
index 9644ff3..b015790 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -275,6 +275,9 @@
return _mm512_sub_pd(pi_2, asin());
}
Vectorized<c10::complex<double>> atan() const;
+ Vectorized<c10::complex<double>> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // no intrinsic formulation here: std::atanh is forwarded to map()
+ }
Vectorized<c10::complex<double>> atan2(const Vectorized<c10::complex<double>> &b) const {
AT_ERROR("not supported for complex numbers");
}
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
index 4898ff7..f9d6040 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -778,6 +778,9 @@
return map(std::acos);
}
Vectorized<c10::complex<float>> atan() const;
+ Vectorized<c10::complex<float>> atanh() const { // complex inverse hyperbolic tangent
+ return map(std::atanh); // forwarded to map(), matching the map(std::acos) pattern used by acos()
+ }
Vectorized<c10::complex<float>> atan2(const Vectorized<c10::complex<float>> &b) const {
AT_ERROR("not supported for complex numbers");
}
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
index 6adbe10..ee4969b 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
@@ -163,6 +163,9 @@
Vectorized<double> atan() const {
return Vectorized<double>(Sleef_atand8_u10(values));
}
+ Vectorized<double> atanh() const { // inverse hyperbolic tangent of this vector
+ return Vectorized<double>(Sleef_atanhd8_u10(values)); // SLEEF call on the raw register, result re-wrapped as Vectorized<double>
+ }
Vectorized<double> atan2(const Vectorized<double> &b) const {
return Vectorized<double>(Sleef_atan2d8_u10(values, b));
}
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
index 3b9a095..d6e6c76 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
@@ -178,6 +178,9 @@
Vectorized<float> atan() const {
return Vectorized<float>(Sleef_atanf16_u10(values));
}
+ Vectorized<float> atanh() const { // inverse hyperbolic tangent of this vector
+ return Vectorized<float>(Sleef_atanhf16_u10(values)); // SLEEF call on the raw register, result re-wrapped as Vectorized<float>
+ }
Vectorized<float> atan2(const Vectorized<float> &b) const {
return Vectorized<float>(Sleef_atan2f16_u10(values, b));
}
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index f68e9c6..38e8f3d 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -368,6 +368,9 @@
Vectorized<T> atan() const {
return map(std::atan);
}
+ Vectorized<T> atanh() const { // generic (non-specialized) inverse hyperbolic tangent
+ return map(std::atanh); // std::atanh forwarded to map(), mirroring atan() with std::atan
+ }
Vectorized<T> atan2(const Vectorized<T> &exp) const {
Vectorized<T> ret;
for (const auto i : c10::irange(size())) {
diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h
index 587a7a9..bff08aa 100644
--- a/aten/src/ATen/cpu/vml.h
+++ b/aten/src/ATen/cpu/vml.h
@@ -67,6 +67,7 @@
IMPLEMENT_VML(acos)
IMPLEMENT_VML(asin)
IMPLEMENT_VML(atan)
+IMPLEMENT_VML(atanh) // adds atanh to the set of functions instantiated through IMPLEMENT_VML (macro body is outside this hunk)
IMPLEMENT_VML(ceil)
IMPLEMENT_VML(cos)
// IMPLEMENT_VML(cosh)
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 17d922a..8dd3ab0 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -399,9 +399,10 @@
static void atanh_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "atanh_cpu", [&]() {
- cpu_kernel(
+ cpu_kernel_vec( // replaces cpu_kernel: receives a scalar functor and a Vectorized functor
iter,
- [=](scalar_t a) -> scalar_t { return std::atanh(a); });
+ [=](scalar_t a) -> scalar_t { return std::atanh(a); }, // scalar functor, unchanged from the cpu_kernel version
+ [=](Vectorized<scalar_t> self_vec){return self_vec.atanh();}); // vector functor: delegates to Vectorized<scalar_t>::atanh()
});
}
@@ -846,11 +847,11 @@
ALSO_REGISTER_AVX512_DISPATCH(logit_stub, &CPU_CAPABILITY::logit_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sinh_stub, &CPU_CAPABILITY::sinh_kernel);
ALSO_REGISTER_AVX512_DISPATCH(cosh_stub, &CPU_CAPABILITY::cosh_kernel);
+ALSO_REGISTER_AVX512_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel); // atanh_kernel passes an explicit Vectorized functor, so it is registered with the AVX512-enabled group
// Might enable AVX512 dispatch after enabling explicit vectorization for them
REGISTER_DISPATCH(acosh_stub, &CPU_CAPABILITY::acosh_kernel);
REGISTER_DISPATCH(asinh_stub, &CPU_CAPABILITY::asinh_kernel);
-REGISTER_DISPATCH(atanh_stub, &CPU_CAPABILITY::atanh_kernel);
REGISTER_DISPATCH(digamma_stub, &CPU_CAPABILITY::digamma_kernel);
REGISTER_DISPATCH(trigamma_stub, &CPU_CAPABILITY::trigamma_kernel);
REGISTER_DISPATCH(polygamma_stub, &CPU_CAPABILITY::polygamma_kernel);
diff --git a/c10/util/BFloat16-math.h b/c10/util/BFloat16-math.h
index c02472b..5dd349f 100644
--- a/c10/util/BFloat16-math.h
+++ b/c10/util/BFloat16-math.h
@@ -43,6 +43,12 @@
template <
typename T,
typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+inline T atanh(T a) { // overload participates only when is_reduced_floating_point_v<T> holds
+ return std::atanh(float(a)); // evaluated in float; the float result converts back to T on return
+}
+template <
+ typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
inline T erf(T a) {
return std::erf(float(a));
}