[BE][veclib] Use `is_same_v`/`enable_if_t` (#122533)
`enable_if_t` helper is part of C++14
`is_same_v` helper is part of C++17
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122533
Approved by: https://github.com/Skylion007
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h
index 392a22b..6263efd 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h
@@ -494,7 +494,7 @@
template <typename T>
class Vectorized8 : public Vectorizedi {
static_assert(
- std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+ std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
"Only int8_t/uint8_t are supported");
protected:
static const Vectorized<T> ones;
@@ -1382,7 +1382,7 @@
return c;
}
-template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0>
+template <bool left_shift, typename T, typename std::enable_if_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>, int> = 0>
Vectorized<T> inline shift_256_8(const Vectorized<T>& a, const Vectorized<T>& b) {
// No vector instruction for shifting int8_t/uint8_t, so emulating
// it instead.
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
index 4128841..4d81398 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
@@ -97,7 +97,7 @@
}
template <typename T>
-typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<float>>::type
+typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<T> src) {
// Note: this function only convert inputs number of elements equal to at::vec::Vectorized<float>.size()
// Only handle first 8*8 bits
@@ -113,7 +113,7 @@
}
template <typename T>
-typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<T>>::type
+typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<T>>
inline convert_float_to_int8(at::vec::Vectorized<float> src) {
// Convert from float32 to int32 with truncation
__m256i x_values_int32 = _mm256_cvttps_epi32(src);
@@ -402,7 +402,7 @@
__m256 multiplier,
__m256i zp) {
static_assert(
- std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+ std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
"Only int8_t/uint8_t are supported");
constexpr auto min_val = std::numeric_limits<T>::min();
constexpr auto max_val = std::numeric_limits<T>::max();
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h
index 2610d34..381bb95 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h
@@ -540,7 +540,7 @@
template <typename T>
class Vectorized8 : public Vectorizedi {
static_assert(
- std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+ std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
"Only int8_t/uint8_t are supported");
protected:
static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0};
@@ -1320,7 +1320,7 @@
return (*this <= other) & Vectorized<uint8_t>(1);
}
-template <bool left_shift, typename T, typename std::enable_if_t<std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, int> = 0>
+template <bool left_shift, typename T, typename std::enable_if_t<std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>, int> = 0>
Vectorized<T> inline shift_512_8(const Vectorized<T>& a, const Vectorized<T>& b) {
// No vector instruction for shifting int8_t/uint8_t, so emulating
// it instead.
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
index e0713d0..6584b0b 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
@@ -99,7 +99,7 @@
}
template <typename T>
-typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<float>>::type
+typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<float>>
inline convert_int8_to_float(at::vec::Vectorized<T> src) {
// Note: this function only convert inputs number of elements equal to at::vec::Vectorized<float>.size()
// Only handle first 16*8 bits
@@ -115,7 +115,7 @@
}
template <typename T>
-typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, at::vec::Vectorized<T>>::type
+typename std::enable_if_t<std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>, at::vec::Vectorized<T>>
inline convert_float_to_int8(at::vec::Vectorized<float> src) {
// Convert from float32 to int32 with truncation
__m512i x_values_int32 = _mm512_cvttps_epi32(src);
@@ -414,7 +414,7 @@
__m512 multiplier,
__m512i zp) {
static_assert(
- std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
+ std::is_same_v<T, int8_t> || std::is_same_v<T, uint8_t>,
"Only int8_t/uint8_t are supported");
constexpr auto min_val = std::numeric_limits<T>::min();
constexpr auto max_val = std::numeric_limits<T>::max();
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index adf81dd..ddf5ec6 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -66,9 +66,9 @@
template <typename T>
struct is_floating_point:
std::integral_constant<bool,
- std::is_floating_point<T>::value ||
- std::is_same<T, at::Half>::value ||
- std::is_same<T, at::BFloat16>::value> {
+ std::is_floating_point_v<T> ||
+ std::is_same_v<T, at::Half> ||
+ std::is_same_v<T, at::BFloat16>> {
};
template<typename T>
@@ -77,8 +77,8 @@
template <typename T>
struct is_reduced_floating_point:
std::integral_constant<bool,
- std::is_same<T, at::Half>::value ||
- std::is_same<T, at::BFloat16>::value> {
+ std::is_same_v<T, at::Half> ||
+ std::is_same_v<T, at::BFloat16>> {
};
template <typename T>
@@ -275,90 +275,90 @@
return ret;
}
template <typename other_t_abs = T,
- typename std::enable_if<!is_floating_point_v<other_t_abs> && !c10::is_complex<other_t_abs>::value, int>::type = 0>
+ typename std::enable_if_t<!is_floating_point_v<other_t_abs> && !c10::is_complex<other_t_abs>::value, int> = 0>
Vectorized<T> abs() const {
// other_t_abs is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
+ static_assert(std::is_same_v<other_t_abs, T>, "other_t_abs must be T");
return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
}
template <typename float_t_abs = T,
- typename std::enable_if<is_floating_point_v<float_t_abs>, int>::type = 0>
+ typename std::enable_if_t<is_floating_point_v<float_t_abs>, int> = 0>
Vectorized<T> abs() const {
// float_t_abs is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
+ static_assert(std::is_same_v<float_t_abs, T>, "float_t_abs must be T");
// Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
// 0.0) properly.
return map([](T x) -> T { return std::abs(x); });
}
template <typename complex_t_abs = T,
- typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_abs>::value, int> = 0>
Vectorized<T> abs() const {
// complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
+ static_assert(std::is_same_v<complex_t_abs, T>, "complex_t_abs must be T");
// Specifically map() does not perform the type conversion needed by abs.
return map([](T x) { return static_cast<T>(std::abs(x)); });
}
template <typename other_t_sgn = T,
- typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<other_t_sgn>::value, int> = 0>
Vectorized<T> sgn() const {
return map(at::native::sgn_impl);
}
template <typename other_t_angle = T,
- typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<other_t_angle>::value, int> = 0>
Vectorized<T> angle() const {
// other_t_angle is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
+ static_assert(std::is_same_v<other_t_angle, T>, "other_t_angle must be T");
return map(at::native::angle_impl<T>); // compiler is unable to resolve the overload without <T>
}
template <typename complex_t_angle = T,
- typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_angle>::value, int> = 0>
Vectorized<T> angle() const {
// complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
+ static_assert(std::is_same_v<complex_t_angle, T>, "complex_t_angle must be T");
return map([](T x) { return static_cast<T>(std::arg(x)); });
}
template <typename other_t_real = T,
- typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<other_t_real>::value, int> = 0>
Vectorized<T> real() const {
// other_t_real is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
+ static_assert(std::is_same_v<other_t_real, T>, "other_t_real must be T");
return *this;
}
template <typename complex_t_real = T,
- typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_real>::value, int> = 0>
Vectorized<T> real() const {
// complex_t_real is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
+ static_assert(std::is_same_v<complex_t_real, T>, "complex_t_real must be T");
return map([](T x) { return static_cast<T>(x.real()); });
}
template <typename other_t_imag = T,
- typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<other_t_imag>::value, int> = 0>
Vectorized<T> imag() const {
// other_t_imag is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
+ static_assert(std::is_same_v<other_t_imag, T>, "other_t_imag must be T");
return Vectorized(0);
}
template <typename complex_t_imag = T,
- typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_imag>::value, int> = 0>
Vectorized<T> imag() const {
// complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
+ static_assert(std::is_same_v<complex_t_imag, T>, "complex_t_imag must be T");
return map([](T x) { return static_cast<T>(x.imag()); });
}
template <typename other_t_conj = T,
- typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<other_t_conj>::value, int> = 0>
Vectorized<T> conj() const {
// other_t_conj is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
+ static_assert(std::is_same_v<other_t_conj, T>, "other_t_conj must be T");
return *this;
}
template <typename complex_t_conj = T,
- typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_conj>::value, int> = 0>
Vectorized<T> conj() const {
// complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
+ static_assert(std::is_same_v<complex_t_conj, T>, "complex_t_conj must be T");
return map([](T x) { return static_cast<T>(std::conj(x)); });
}
Vectorized<T> acos() const {
@@ -422,7 +422,7 @@
typename std::enable_if_t<is_floating_point_v<U>, int> = 0>
Vectorized<T> fmod(const Vectorized<T>& q) const {
// U is for SFINAE purposes only. Make sure it is not changed.
- static_assert(std::is_same<U, T>::value, "U must be T");
+ static_assert(std::is_same_v<U, T>, "U must be T");
Vectorized<T> ret;
for (const auto i : c10::irange(size())) {
ret[i] = std::fmod(values[i], q[i]);
@@ -439,17 +439,17 @@
return map(std::log1p);
}
template <typename other_t_log2 = T,
- typename std::enable_if<!c10::is_complex<other_t_log2>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<other_t_log2>::value, int> = 0>
Vectorized<T> log2() const {
// other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<other_t_log2, T>::value, "other_t_log2 must be T");
+ static_assert(std::is_same_v<other_t_log2, T>, "other_t_log2 must be T");
return map(std::log2);
}
template <typename complex_t_log2 = T,
- typename std::enable_if<c10::is_complex<complex_t_log2>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<complex_t_log2>::value, int> = 0>
Vectorized<T> log2() const {
// complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
- static_assert(std::is_same<complex_t_log2, T>::value, "complex_t_log2 must be T");
+ static_assert(std::is_same_v<complex_t_log2, T>, "complex_t_log2 must be T");
const T log_2 = T(std::log(2.0));
return Vectorized(map(std::log))/Vectorized(log_2);
}
@@ -622,7 +622,7 @@
}
template <class T,
- typename std::enable_if<!is_floating_point_v<T>, int>::type = 0>
+ typename std::enable_if_t<!is_floating_point_v<T>, int> = 0>
Vectorized<T> inline operator%(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
return a - a / b * b;
}
@@ -639,7 +639,7 @@
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <class T,
- typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -655,7 +655,7 @@
}
template <class T,
- typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -673,7 +673,7 @@
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <class T,
- typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -689,7 +689,7 @@
}
template <class T,
- typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -705,7 +705,7 @@
}
template <class T,
- typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -715,7 +715,7 @@
}
template <class T,
- typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {
@@ -725,7 +725,7 @@
}
template <class T,
- typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
+ typename std::enable_if_t<!c10::is_complex<T>::value, int> = 0>
Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
Vectorized<T> c;
for (int i = 0; i != Vectorized<T>::size(); i++) {