[Inductor] Fix CPU vectorized implementation of mask calculation that breaks torch.where (#93922)

Fix https://github.com/pytorch/pytorch/issues/93374

The cause of the issue is that the original vectorized float mask calculation doesn't handle the broadcast case: a broadcast boolean was splatted into the float vector as the value 1.0f rather than as the all-ones bit pattern (0xFFFFFFFF) that vectorized selection expects, so torch.where selected the wrong operands. This PR adds support for the broadcast case by routing broadcast bool/uint8 loads through flag_to_float as well.
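For illustration, a minimal sketch of the two code paths (assuming 8 float lanes, e.g. AVX2; the helper name and lane count are made up for this example and are not the exact Inductor output):

```cpp
#include <ATen/cpu/vec/vec.h>
#include <cstdint>
#include <cstring>

// Hypothetical helper: load a single broadcast bool as a float vector mask.
at::vec::Vectorized<float> load_broadcast_mask(const bool* in_ptr0) {
  // Broken broadcast path: a value splat fills every lane with 1.0f
  // (bit pattern 0x3F800000), which bitwise mask selection misreads:
  //   return at::vec::Vectorized<float>(in_ptr0[0]);

  // Fixed path: splat the all-ones bit pattern for true (zeros for false)
  // into a temporary buffer, then load the buffer as a float vector.
  float buf[8] = {0};
  uint32_t bits = in_ptr0[0] ? 0xFFFFFFFFu : 0u;
  for (int64_t i = 0; i < 8; i++) {
    std::memcpy(&buf[i], &bits, sizeof(bits));
  }
  return at::vec::Vectorized<float>::loadu(buf);
}
```

With the all-ones pattern in place, the blendv-style selection that torch.where lowers to on the vectorized CPU path picks the intended lanes.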

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93922
Approved by: https://github.com/XiaobingSuper, https://github.com/desertfire, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 897cac1..a5013cd 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5479,6 +5479,17 @@
             (torch.randn(1, 16, 64, 72).to(memory_format=torch.channels_last),),
         )
 
+    def test_where(self):
+        # https://github.com/pytorch/pytorch/issues/93374
+        def fn(x, p1, p0):
+            o = torch.where(x, p1, p0)
+            return o
+
+        self.common(
+            fn,
+            (torch.tensor([[True]]), torch.rand(13, 7, 3), torch.rand(1, 1)),
+        )
+
 
 test_skips = {
     "test_alexnet_prefix_dynamic_shapes": ("cuda",),
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index f8c48c4..845a8ee 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -931,22 +931,27 @@
         expanded_index = sympy.expand(index)
         new_index = self.scale_index_with_offset(index, self.tiling_factor)
 
-        if expanded_index == new_index:
-            line = f"at::vec::Vectorized<float>({var}[{cexpr(index)}])"
-        else:
-            if V.graph.get_dtype(name) in [torch.bool, torch.uint8]:
-                nelements = codecache.pick_vec_isa().nelements()
-                if var not in self.var_vec_buf_map:
-                    self.var_vec_buf_map[var] = f"g_tmp_buffer_{var}"
-                    self.loads.writeline(
-                        f"float {self.var_vec_buf_map[var]}[{nelements}] = {{0}};"
-                    )
+        is_broadcast = expanded_index == new_index
+
+        var_expr = (
+            f"{var}[{cexpr(index)}]" if is_broadcast else f"{var} + {cexpr(new_index)}"
+        )
+
+        if V.graph.get_dtype(name) in [torch.bool, torch.uint8]:
+            nelements = codecache.pick_vec_isa().nelements()
+            if var not in self.var_vec_buf_map:
+                self.var_vec_buf_map[var] = f"g_tmp_buffer_{var}"
                 self.loads.writeline(
-                    f"flag_to_float({var} + {cexpr(new_index)}, {self.var_vec_buf_map[var]}, {nelements});"
+                    f"float {self.var_vec_buf_map[var]}[{nelements}] = {{0}};"
                 )
-                line = f"at::vec::Vectorized<float>::loadu({self.var_vec_buf_map[var]})"
-            else:
-                line = f"at::vec::Vectorized<float>::loadu({var} + {cexpr(new_index)})"
+            self.loads.writeline(
+                f"flag_to_float({var_expr}, {self.var_vec_buf_map[var]}, {nelements});"
+            )
+            line = f"at::vec::Vectorized<float>::loadu({self.var_vec_buf_map[var]})"
+        elif is_broadcast:
+            line = f"at::vec::Vectorized<float>({var_expr})"
+        else:
+            line = f"at::vec::Vectorized<float>::loadu({var_expr})"
 
         return self.cse.generate(self.loads, line)
 
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index cfb8ca1..5f3ae07 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -70,6 +70,15 @@
   }
 }
 
+template <typename T>
+void flag_to_float(T src, float* dst, int64_t n) {
+#pragma unroll
+  for (int64_t i = 0; i < n; i++) {
+    uint32_t* dst_u32 = (uint32_t*)dst;
+    dst_u32[i] = src ? 0xFFFFFFFF : 0;
+  }
+}
+
 #if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
 template <typename SRC>
 inline at::vec::Vectorized<float> to_float_mask(at::vec::Vectorized<SRC>& src) {