[Inductor] Fix CPU vectorized implementation of mask calculation that breaks torch.where (#93922)
Fix https://github.com/pytorch/pytorch/issues/93374
The root cause is that the original vectorized float mask calculation does not account for the broadcast case, in which the mask is indexed as a single scalar rather than loaded as a full vector. This PR adds support for it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93922
Approved by: https://github.com/XiaobingSuper, https://github.com/desertfire, https://github.com/jansel
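
A minimal repro sketch of the failing pattern, adapted from the regression test added below (compiling with torch.compile and checking against eager is an assumption about how the mismatch from issue #93374 surfaces):

    import torch

    def fn(x, p1, p0):
        return torch.where(x, p1, p0)

    # The condition broadcasts from shape [1, 1] against [13, 7, 3],
    # which is the case the vectorized mask load previously mishandled.
    args = (torch.tensor([[True]]), torch.rand(13, 7, 3), torch.rand(1, 1))
    torch.testing.assert_close(torch.compile(fn)(*args), fn(*args))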
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 897cac1..a5013cd 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5479,6 +5479,17 @@
(torch.randn(1, 16, 64, 72).to(memory_format=torch.channels_last),),
)
+ def test_where(self):
+ # https://github.com/pytorch/pytorch/issues/93374
+ def fn(x, p1, p0):
+ o = torch.where(x, p1, p0)
+ return o
+
+ self.common(
+ fn,
+ (torch.tensor([[True]]), torch.rand(13, 7, 3), torch.rand(1, 1)),
+ )
+
test_skips = {
"test_alexnet_prefix_dynamic_shapes": ("cuda",),
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index f8c48c4..845a8ee 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -931,22 +931,27 @@
expanded_index = sympy.expand(index)
new_index = self.scale_index_with_offset(index, self.tiling_factor)
- if expanded_index == new_index:
- line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
- else:
- if V.graph.get_dtype(name) in [torch.bool, torch.uint8]:
- nelements = codecache.pick_vec_isa().nelements()
- if var not in self.var_vec_buf_map:
- self.var_vec_buf_map[var] = f"g_tmp_buffer_{var}"
- self.loads.writeline(
- f"float {self.var_vec_buf_map[var]}[{nelements}] = {{0}};"
- )
+ is_broadcast = expanded_index == new_index
+
+ var_expr = (
+ f"{var}[{cexpr(index)}]" if is_broadcast else f"{var} + {cexpr(new_index)}"
+ )
+
+ if V.graph.get_dtype(name) in [torch.bool, torch.uint8]:
+ nelements = codecache.pick_vec_isa().nelements()
+ if var not in self.var_vec_buf_map:
+ self.var_vec_buf_map[var] = f"g_tmp_buffer_{var}"
self.loads.writeline(
- f"flag_to_float({var} + {cexpr(new_index)}, {self.var_vec_buf_map[var]}, {nelements});"
+ f"float {self.var_vec_buf_map[var]}[{nelements}] = {{0}};"
)
- line = f"at::vec::Vectorized<float>::loadu({self.var_vec_buf_map[var]})"
- else:
- line = f"at::vec::Vectorized<float>::loadu({var} + {cexpr(new_index)})"
+ self.loads.writeline(
+ f"flag_to_float({var_expr}, {self.var_vec_buf_map[var]}, {nelements});"
+ )
+ line = f"at::vec::Vectorized<float>::loadu({self.var_vec_buf_map[var]})"
+ elif is_broadcast:
+ line = f"at::vec::Vectorized<float>({var_expr})"
+ else:
+ line = f"at::vec::Vectorized<float>::loadu({var_expr})"
return self.cse.generate(self.loads, line)
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index cfb8ca1..5f3ae07 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -70,6 +70,15 @@
}
}
+template <typename T>
+void flag_to_float(T src, float* dst, int64_t n) {
+#pragma unroll
+ for (int64_t i = 0; i < n; i++) {
+ uint32_t* dst_u32 = (uint32_t*)dst;
+ dst_u32[i] = src ? 0xFFFFFFFF : 0;
+ }
+}
+
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
template <typename SRC>
inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<SRC>& src) {
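
For illustration, a Python sketch of what the new scalar flag_to_float overload computes (flag_to_float_scalar is a hypothetical name; the real helper is the C++ template above): a single broadcast flag is expanded into n mask lanes whose bit pattern is all-ones or all-zeros, since vectorized select operations typically choose lanes by bit pattern rather than by numeric value:

    import struct

    def flag_to_float_scalar(flag, n):
        # Reinterpret 0xFFFFFFFF / 0x00000000 as a float, as the C++ helper
        # does by writing through a uint32_t* view of the float buffer.
        bits = 0xFFFFFFFF if flag else 0
        lane = struct.unpack("f", struct.pack("I", bits))[0]
        return [lane] * n

    # All 8 lanes carry the broadcast flag (all-ones bits print as nan).
    print(flag_to_float_scalar(True, 8))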