[inductor] Use dense masks for indirect indexing (#89524)
Fixes https://github.com/pytorch/torchdynamo/issues/1654
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89524
Approved by: https://github.com/jansel
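
With indirect (tensor-valued) indexing, the generated index expression could keep a reduced shape while the load it feeds is guarded by a mask of the full dense block shape, so the mask no longer applied elementwise. The change below broadcasts the index up to the dense size with an added `tl.zeros(<dense size>, tl.int32)` term and guards the load with the dense mask.

A minimal sketch of what that broadcast achieves, using plain PyTorch tensors to stand in for the Triton block values (shapes and variable names are illustrative, not actual Inductor output):

    import torch

    # A per-row indirect index, e.g. the `sum(ne) - 1` computed in the test
    # below; it has the reduced shape [XBLOCK, 1].
    index = torch.tensor([[127]], dtype=torch.int64)

    # Adding zeros of the dense block shape broadcasts the index up to
    # [XBLOCK, RBLOCK], so a dense mask can be applied elementwise.
    dense_index = index + torch.zeros((1, 8), dtype=torch.int64)
    dense_mask = torch.ones((1, 8), dtype=torch.bool)
    assert dense_index.shape == dense_mask.shape == (1, 8)
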
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 4f672af..89b94ce 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5377,6 +5377,24 @@
         fn_optimized = torch._dynamo.optimize("inductor")(fn)
         assert same(fn(a), fn_optimized(a))
+    @requires_cuda()
+    def test_indirect_indexing_dense_mask(self):
+        def fn(x, y):
+            ne = torch.ops.aten.ne.Scalar(x, 1)
+            sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
+            sub = torch.ops.aten.sub.Tensor(sum_1, 1)
+            unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
+            gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
+            squeeze = torch.ops.aten.squeeze.default(gather)
+            out = torch.ops.aten.multiply(y, squeeze)
+            return (out,)
+
+        a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
+        b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
+
+        fn_optimized = torch._dynamo.optimize("inductor")(fn)
+        assert same(fn(a, b), fn_optimized(a, b))
+
 class TritonCodeGenTests(TestCase):
     from torch._inductor.triton_ops.autotune import CachingAutotuner
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 2504bd2..e14b417 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -756,6 +756,12 @@
             mask = dense_mask
             index_str = f"{index_str} + tl.zeros({copy_shape}.shape, tl.int32)"
         elif indirect_indexing:
+            # Use a dense mask for indirect indexing.
+            # See https://github.com/pytorch/torchdynamo/issues/1654
+            # TODO - An optimization would be to hoist this load out of the
+            # reduction loop when it is independent of rmask. An example can be
+            # found in https://github.com/pytorch/torchdynamo/issues/1654
+            index_str = f"{index_str} + tl.zeros({self.dense_size_str()}, tl.int32)"
             mask = dense_mask
         if self._load_mask:
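
For reference, roughly how this changes the emitted load (a hand-written sketch, not captured Inductor output; `tmp2`, `tmp3`, `in_ptr1`, and the `XBLOCK`/`RBLOCK` names are assumptions):

    # Before: the indirect index tmp2 keeps its reduced shape, so the dense
    # mask cannot be applied to the load elementwise:
    #   tmp3 = tl.load(in_ptr1 + tmp2, xmask)
    # After: the index is broadcast to the dense [XBLOCK, RBLOCK] shape and
    # the load is guarded by the dense mask:
    #   tmp3 = tl.load(in_ptr1 + (tmp2 + tl.zeros([XBLOCK, RBLOCK], tl.int32)), rmask & xmask)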