[CpuInductor] Implement masked_load for integral types (#122608)
Use `if constexpr` to separate float vs integral masked load for avx512
Discovered while looking at `test_comprehensive_fft_ihfft2_cpu_int64` on
non-AVX512 capable CPUs where (5, 6, 7) shape were big enough to start a vectorized loop
Added `test_pad_cast` regression test
Fixes https://github.com/pytorch/pytorch/issues/122606
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122608
Approved by: https://github.com/jansel
ghstack dependencies: #122607
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 303cc44..552763e 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8497,6 +8497,13 @@
x = torch.rand(48, 3, 512, 512)
self.common(fn, (x,))
+ def test_pad_cast(self):
+ def fn(x):
+ return torch.nn.functional.pad(x.to(torch.float32), (0, 3, 0, 0))
+
+ for dtype in [torch.int32, torch.int64]:
+ self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),))
+
@unittest.skipIf(not HAS_CPU or not RUN_CPU, "requires C++ compiler")
def test_data_type_propogation(self):
from torch._dynamo.utils import detect_fake_mode
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index 9bec9d5..c922274 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -280,19 +280,36 @@
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR)
-inline at::vec::Vectorized<float> masked_load(const float* src, at::vec::Vectorized<float> mask) {
+template<typename T>
+typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>, at::vec::Vectorized<T>>
+inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
# if defined(CPU_CAPABILITY_AVX512)
- at::vec::Vectorized<float> zero_vec(0);
+ at::vec::Vectorized<T> zero_vec(0);
auto all_ones = _mm512_set1_epi32(0xFFFFFFFF);
auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask), all_ones, _MM_CMPINT_EQ);
- return _mm512_mask_loadu_ps(zero_vec, mmask, src);
+
+ if constexpr (std::is_same_v<T, float>) {
+ return _mm512_mask_loadu_ps(zero_vec, mmask, src);
+ } else {
+ return _mm512_mask_loadu_epi32(zero_vec, mmask, src);
+ }
# elif defined(CPU_CAPABILITY_AVX2)
auto all_ones = _mm256_set1_epi32(0xFFFFFFFF);
auto mmask = _mm256_cmpeq_epi32(_mm256_castps_si256(mask), all_ones);
- return _mm256_maskload_ps(src, mmask);
+ if constexpr (std::is_same_v<T, float>) {
+ return _mm256_maskload_ps(src, mmask);
+ } else {
+ return _mm256_maskload_epi32(src, mmask);
+ }
# elif defined(CPU_CAPABILITY_ZVECTOR)
- auto result = at::vec::Vectorized<float>::loadu(src);
- return (result & mask);
+ auto result = at::vec::Vectorized<T>::loadu(src);
+ if constexpr (std::is_same_v<T, float>) {
+ return result & mask;
+ } else {
+ T maskdata[at::vec::Vectorized<T>::size()];
+ mask.store(maskdata);
+ return result & at::vec::Vectorized<T>::loadu(maskdata);
+ }
# else
# error Unsupported vectorization CPU capability
# endif
@@ -326,7 +343,7 @@
maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFFFF: 0;
}
auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
- return (result & maskvector);
+ return result & maskvector;
# else
# error Unsupported vectorization CPU capability
# endif
@@ -360,13 +377,29 @@
maskdata_dest[i] = (maskdata[i] == 0xFFFFFFFF) ? 0xFF: 0;
}
auto maskvector = at::vec::Vectorized<T>::loadu(maskdata_dest);
- return (result & maskvector);
+ return result & maskvector;
# else
# error Unsupported vectorization CPU capability
# endif
}
template <typename T>
+typename std::enable_if_t<std::is_same_v<T, uint64_t> || std::is_same_v<T, int64_t>, at::vec::VectorizedN<T, 2>>
+inline masked_load(const T* src, at::vec::Vectorized<float> mask) {
+ // TODO: Add vectorized variants for the load
+ constexpr auto vec_size = decltype(mask)::size();
+ __at_align__ uint32_t maskdata[vec_size];
+ __at_align__ uint64_t mask_res[vec_size];
+ mask.store(maskdata);
+ #pragma unroll
+ for(auto i = 0; i < vec_size; ++i) {
+ mask_res[i] = maskdata[i] ? std::numeric_limits<uint64_t>::max() : 0;
+ }
+ auto result = at::vec::VectorizedN<T, 2>::loadu(src);
+ return result & at::vec::VectorizedN<T, 2>::loadu(mask_res);
+}
+
+template <typename T>
inline at::vec::Vectorized<float> flag_to_float_vec(const T* src) {
__at_align__ float dst_tmp[at::vec::Vectorized<float>::size()];
#pragma unroll