Unify the dtype to VecMask<float, N> in ops.masked (#126662)
**Summary**
Fixes https://github.com/pytorch/pytorch/issues/126449. For `ops.masked` in the CPP backend, a `bool` input is actually loaded as `VecMask<float, N>`, so the types of `other` and `mask` must be unified to `VecMask<float, N>` before the `blendv` method can be invoked.
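As a quick illustration, here is a minimal repro distilled from the new test added in this PR: `constant_pad_nd` on a `bool` tensor reaches the `ops.masked` path once the Inductor CPP backend vectorizes the kernel (the 129-element size is taken from the test and presumably forces a vectorized main loop plus a scalar tail).
```
import torch

# Minimal repro distilled from test_ops_masked_with_bool_input below.
# Padding a bool tensor goes through ops.masked in the Inductor CPP backend.
x = torch.zeros(129, dtype=torch.bool)
pad = [2, 3]

res_eager = torch.constant_pad_nd(x, pad)
res_compiled = torch.compile(torch.constant_pad_nd)(x, pad)

assert torch.equal(res_eager, res_compiled)
```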
**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_ops_masked_with_bool_input
clear && PYTORCH_ALL_SAMPLES=1 python -u -m pytest -s -v test/inductor/test_torchinductor_opinfo.py -k test_comprehensive__chunk_cat_cpu_bool
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126662
Approved by: https://github.com/isuruf, https://github.com/jgong5, https://github.com/peterbell10
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 1705cd5..1110fde 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1873,6 +1873,16 @@
         # For forward and backward kernel
         check_metrics_vec_kernel_count(2)
 
+    @requires_vectorization
+    def test_ops_masked_with_bool_input(self):
+        x = torch.zeros(129, dtype=torch.bool)
+        size = [2, 3]
+        res_aten_eager = torch.constant_pad_nd(x, size)
+        cfn = torch.compile(torch.constant_pad_nd)
+        res = cfn(x, size)
+        self.assertEqual(res_aten_eager, res)
+        check_metrics_vec_kernel_count(1)
+
     @patch("torch.cuda.is_available", lambda: False)
     def test_scatter_using_atomic_add(self):
         def fn(a, dim, index, b):
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 9bd873a..2a7995d 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -440,6 +440,8 @@
"triu",
"cummax",
"cummin",
+ "_chunk_cat",
+ "constant_pad_nd",
}
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index a0beddb..d4e314e 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -1412,10 +1412,14 @@
else f"{V.kernel._get_vec_type(dtype)}({body_code})"
)
other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype])
- other_code_vec = f"{V.kernel._get_vec_type(dtype)}({other_code})"
+ # loading bool as VecMask<float, N>
+ other_code_vec = (
+ f"{V.kernel._get_mask_type()}::from({other_code})"
+ if dtype == torch.bool
+ else f"{V.kernel._get_vec_type(dtype)}({other_code})"
+ )
assert isinstance(new_mask, CppCSEVariable), new_mask
if new_mask.is_vec:
- type = f"decltype({body_code_vec})"
code = BracesBuffer()
code.writeline("[&]")
with V.kernel.swap_buffers(code), code.indent():
@@ -1424,8 +1428,21 @@
code.writeline(f"return {other_code_vec};")
code.writeline("else")
with code.indent():
+ # Create cse variable to reuse kernel.overrides.where
+ body_vec_var = V.kernel.cse.generate(
+ V.kernel.compute,
+ body_code_vec,
+ )
+ other_vec_var = V.kernel.cse.generate(
+ V.kernel.compute,
+ other_code_vec,
+ )
+ assert isinstance(body_vec_var, CppCSEVariable), body_vec_var
+ assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
+ body_vec_var.dtype = dtype
+ other_vec_var.dtype = dtype
code.writeline(
- f"return {type}::blendv({other_code_vec}, {body_code_vec}, {V.kernel._get_mask_cast(new_mask, dtype)});"
+ f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};"
)
code.writeline("()")
csevar = V.kernel.cse.generate(