[inductor] Do type promotion in pointless cumsum pattern replacement (#109960)

Fixes #109925
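
For context, eager `torch.cumsum` promotes boolean and all integral inputs to `int64` (floating-point dtypes are left unchanged), so the pointless-cumsum replacement must promote as well or the compiled graph returns a result with the wrong dtype. A minimal illustration of the eager behavior:

```python
import torch

# Eager cumsum promotes bool and all integral dtypes to int64:
print(torch.cumsum(torch.full([2, 4], True, dtype=torch.bool), 1).dtype)    # torch.int64
print(torch.cumsum(torch.full([2, 4], 1, dtype=torch.int32), 1).dtype)      # torch.int64
# Floating-point inputs keep their dtype:
print(torch.cumsum(torch.full([2, 4], 0.1, dtype=torch.float32), 1).dtype)  # torch.float32
```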

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109960
Approved by: https://github.com/Fidget-Spinner, https://github.com/lezcano
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index 066b7bd..aaf2076 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -496,7 +496,16 @@
             x = torch.full([100], 0.1, dtype=torch.float32)
             return torch.cumsum(x, 0)
 
-        for fn in (fn1, fn2, fn3, fn4):
+        def fn5():
+            t1 = torch.full([2, 4], 1)
+            t2 = t1.to(dtype=torch.bool)
+            return torch.cumsum(t2, 1)
+
+        def fn6():
+            x = torch.full([10, 10], True, dtype=torch.int32)
+            return torch.cumsum(x, 1)
+
+        for fn in (fn1, fn2, fn3, fn4, fn5, fn6):
             result, (code,) = run_and_get_code(torch.compile(fn, fullgraph=True))
             self.assertNotIn("aten.cumsum", code)
             self.assertEqual(result, fn())
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index c751274..6a3d57c 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -10,7 +10,7 @@
 import torch
 import torch._inductor as inductor
 from torch._decomp import register_decomposition
-from torch._prims_common import is_integer_dtype
+from torch._prims_common import is_boolean_dtype, is_integer_dtype
 
 from .. import config, ir, pattern_matcher
 from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage
@@ -282,6 +282,10 @@
 def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim):
     """Based on a pattern in OPTForCausalLM"""
 
+    if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
+        # cumsum promotes all integral types to int64
+        dtype = torch.int64
+
     def repl(*shape):
         dim_size = shape[dim]
         idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)
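
For reference, here is a standalone sketch (not the inductor code itself) of the equivalence the pattern exploits, including the dtype promotion added above: cumsum over a constant-filled tensor along `dim` is just `arange(1, n + 1) * fill_value` broadcast back to the original shape. The helper name is hypothetical:

```python
import torch

# Hypothetical standalone sketch of the replacement's math, mirroring the
# is_integer_dtype/is_boolean_dtype promotion added in this PR.
def pointless_cumsum_sketch(shape, fill_value, dim, dtype, device="cpu"):
    # cumsum promotes bool and all integral dtypes to int64.
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.int64
    dim_size = shape[dim]
    idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype)
    # Reshape so idx broadcasts along `dim`, then expand to the full shape.
    view = [1] * len(shape)
    view[dim] = dim_size
    return (idx * fill_value).view(view).expand(shape)

# Matches eager semantics, including the promoted dtype:
expected = torch.cumsum(torch.full([10, 10], True, dtype=torch.int32), 1)
actual = pointless_cumsum_sketch([10, 10], True, 1, dtype=torch.int32)
assert actual.dtype == expected.dtype == torch.int64
assert (actual == expected).all()
```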