[inductor] Do type promotion in pointless cumsum pattern replacement (#109960)
Fixes #109925
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109960
Approved by: https://github.com/Fidget-Spinner, https://github.com/lezcano
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index 066b7bd..aaf2076 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -496,7 +496,16 @@
x = torch.full([100], 0.1, dtype=torch.float32)
return torch.cumsum(x, 0)
- for fn in (fn1, fn2, fn3, fn4):
+ def fn5():
+ t1 = torch.full([2, 4], 1)
+ t2 = t1.to(dtype=torch.bool)
+ return torch.cumsum(t2, 1)
+
+ def fn6():
+ x = torch.full([10, 10], True, dtype=torch.int32)
+ return torch.cumsum(x, 1)
+
+ for fn in (fn1, fn2, fn3, fn4, fn5, fn6):
result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
self.assertNotIn("aten.cumsum", code)
self.assertEqual(result, fn())
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index c751274..6a3d57c 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -10,7 +10,7 @@
import torch
import torch._inductor as inductor
from torch._decomp import register_decomposition
-from torch._prims_common import is_integer_dtype
+from torch._prims_common import is_boolean_dtype, is_integer_dtype
from .. import config, ir, pattern_matcher
from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
@@ -282,6 +282,10 @@
def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
"""Based on a pattern in OPTForCausalLM"""
+ if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
+ # cumsum promotes all integral types (bool included) to int64
+ dtype = torch.int64
+
def repl(*shape):
dim_size = shape[dim]
idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)