[inductor] don't match indirect indexing in fusion (#96273)
Fixes #96064
When deciding whether to fuse nodes, we match indexing expressions like `c0 + 5 * tmp0`, but `tmp0` in the different nodes can refer to totally different values. Even when `tmp0` is the same (as in the added test), inductor still generates wrongly ordered loads and stores (loads come before stores), so it's better to just disable this fusion altogether. We should also fix the wrong ordering itself:
```
@pointwise(size_hints=[8], filename=__file__, meta={'signature': {0: '*i64', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': ['out_ptr0'], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]})
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 5
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0_load = tl.load(in_ptr0 + (0))
    tmp0 = tl.broadcast_to(tmp0_load, [XBLOCK])
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tl.load(out_ptr0 + (x0 + (5*tmp0)), xmask)
    tl.store(out_ptr0 + (x0 + (5*tmp0) + tl.zeros([XBLOCK], tl.int32)), tmp1, xmask)
    tl.store(out_ptr1 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
```
Note: we are loading from `out_ptr0` here (which shouldn't happen at all), and the load comes before the store to it.
After this PR, the kernel above is split into two kernels.
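For reference, here is the pattern the added test exercises, written out in eager mode (a minimal sketch using the same ops and shapes as the test below): the `index` read has to observe the `index_put` store, which is exactly the ordering the fused kernel above violates.

```
import torch

ind = torch.tensor([1], dtype=torch.int64)
x = torch.randn(8, 4)
src = torch.randn(4)

# store into row x[ind], then load the same row back
y = torch.ops.aten.index_put.default(x, [ind], src)
out = torch.ops.aten.index.Tensor(y, [ind])

# the load must observe the store; a fused kernel that emits the
# load before the store would return the stale row of x instead
assert torch.equal(out, src.expand(1, 4))
```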
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96273
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index ad5be4b..a82e96b 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -4461,6 +4461,14 @@
             ),
         )
 
+    def test_index_put_index(self):
+        def fn(ind, x, src):
+            y = torch.ops.aten.index_put.default(x, [ind], src)
+            return torch.ops.aten.index.Tensor(y, [ind])
+
+        args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)]
+        self.common(fn, args)
+
     @config.patch(fallback_random=True)
     def test_bernoulli1(self):
         def fn(a):
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 531b3ab..52a8ccc 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -16,7 +16,7 @@
 from . import config, dependencies, ir, metrics
 from .dependencies import StarDep, WeakDep
 from .sizevars import SimplifyIndexing
-from .utils import cache_on_self, cmp, has_triton
+from .utils import cache_on_self, cmp, free_symbol_has, has_triton
 from .virtualized import V
 
 log = logging.getLogger(__name__)
@@ -1001,9 +1001,12 @@
         # StarDep doesn't match MemoryDep, different indices don't match
         # However, broadcasting sometimes strips dimensions, and if that's the case
         # we still can match unmet dep
+        # if there's indirect indexing, don't match it
         if (
             rd.name == cd.name
             and type(rd) == type(cd)
+            and not free_symbol_has(rd.index, "tmp")
+            and not free_symbol_has(cd.index, "tmp")
             and rd.index == cd.index
             and len(rd.size) >= len(cd.size)
             and rd.size[: len(cd.size)] == cd.size
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 51692d5..9956efd 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -324,6 +324,10 @@
     return any(v.name.startswith(prefix) for v in index.free_symbols)
 
 
+def free_symbol_has(index: sympy.Expr, pattern: str):
+    return any(pattern in v.name for v in index.free_symbols)
+
+
 def has_incompatible_cudagraph_ops(gm):
     forbidden_list = {
         "aten._fused_moving_avg_obs_fq_helper.default",