[inductor] simplify expr when looking up size hint (#123140)
## Context
Suppose we have two symbols, `u0` and `s0`, where we know that `u0 = s0`. Now suppose we try to look up the size hint for `u0 + 1`.
* Before this PR, we would use a fallback hint if one was provided.
https://github.com/pytorch/pytorch/blob/3f6acf65fd9b6094513cf28898a42b90dd1169a0/torch/_inductor/sizevars.py#L406-L407
* With this PR, we first `simplify()` the expression, which replaces `u0` with `s0`, before resorting to a fallback hint (see the sketch after this list). https://github.com/pytorch/pytorch/blob/3f6acf65fd9b6094513cf28898a42b90dd1169a0/torch/_inductor/sizevars.py#L46-L47
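To make the before/after concrete, here is a minimal sympy sketch of the lookup. The `var_to_val` and `replacements` dicts are hypothetical stand-ins for the allocator's internal tables, not the actual `SizeVarAllocator` API:

```python
import sympy

# u0 is unbacked (no known value); s0 is backed with a concrete value,
# and the two symbols are known to be equal.
u0, s0 = sympy.symbols("u0 s0", integer=True, positive=True)
var_to_val = {s0: sympy.Integer(7)}  # hints for backed symbols
replacements = {u0: s0}              # learned equality: u0 == s0

expr = u0 + 1

# Before this PR: substituting hints directly leaves u0 unresolved,
# so size_hint() has to use the fallback (e.g. 8192).
print(expr.xreplace(var_to_val))  # u0 + 1  -> not an integer

# After this PR: apply the replacements first (what simplify() does),
# then substitute hints -- the result is exact, no fallback needed.
print(expr.xreplace(replacements).xreplace(var_to_val))  # 8
```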
## Concrete Example
A scenario where this is useful: benchmarking `bmm` choices during autotuning, where the two input nodes have `s0` and `u0` as their respective batch sizes. During benchmarking, we create two example input tensors; the one whose batch size is `u0` has to use a fallback hint, while the other uses the real value of `s0`, so the batch sizes of the example inputs disagree.
https://github.com/pytorch/pytorch/blob/e3d80f2fa98d7ab02f88023d381b2e5981dd99ff/torch/_inductor/select_algorithm.py#L991-L997
Using the fallback hint (here, 8192) produces the batch-size mismatch below:
```
# Note: s0 = 7, u0 = 7, and the fallback hint is 8192.
LoweringException: ErrorFromChoice: Expected size for first two dimensions of batch2 tensor to be: [7, 30] but got: [8192, 30].
From choice ExternKernelCaller(extern_kernels.bmm)
```
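The mismatch can be reproduced in isolation with plain eager `bmm` (a sketch with hand-picked shapes mirroring the test below; `fallback` stands in for inductor's default hint):

```python
import torch

# Runtime truth: u0 == s0 == 7, but u0 is unbacked, so benchmarking
# materializes its example input from the fallback hint instead.
s0, fallback = 7, 8192

mat1 = torch.randn(s0, 16, 30)        # batch dim from s0's exact hint
mat2 = torch.randn(fallback, 30, 32)  # batch dim from u0's fallback hint

try:
    torch.bmm(mat1, mat2)
except RuntimeError as e:
    print(e)  # batch-size mismatch: expected [7, 30], got [8192, 30]
```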
Differential Revision: D55619331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123140
Approved by: https://github.com/aakhundov
diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py
index d3a0163..8d7228e 100644
--- a/test/inductor/test_unbacked_symints.py
+++ b/test/inductor/test_unbacked_symints.py
@@ -156,6 +156,38 @@
torch.testing.assert_close(actual, expected)
+ @inductor_config.patch({"max_autotune": True})
+ @dynamo_config.patch({"capture_scalar_outputs": True})
+ def test_equivalent_backed_unbacked(self, device):
+ # Tests the scenario where a backed and an unbacked symint are known to
+ # be equal, but a size-hint lookup on the unbacked symint naively falls
+ # back to the default hint.
+
+ def fn(x, w, a, b):
+ # Make tensors where 1st dim is unbacked/backed.
+ u0, s0 = a.item(), b.size(0)
+ unbacked = x.expand(u0, *x.shape)
+ backed = x.expand(s0, *x.shape)
+
+ # The cat unifies u0 and s0 -- i.e. u0 == s0.
+ cat = torch.cat([backed, unbacked, unbacked], dim=1) # [s0, 30, 16]
+ mat1 = torch.permute(cat, [0, 2, 1]) # [s0, 16, 30]
+ mat2 = w.expand(u0, *w.shape) # [u0, 30, 32]
+ bmm = torch.ops.aten.bmm(mat1, mat2)
+ return bmm
+
+ example_inputs = (
+ torch.randn((10, 16), dtype=torch.float32, device=device),
+ torch.randn((30, 32), dtype=torch.float32, device=device),
+ torch.tensor(7, device=device),
+ backed := torch.randn((7,), device=device),
+ )
+ torch._dynamo.mark_dynamic(backed, 0) # create backed symint
+
+ actual = torch.compile(fn, fullgraph=True)(*example_inputs)
+ expected = fn(*example_inputs)
+ torch.testing.assert_close(actual, expected)
+
instantiate_device_type_tests(
TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu")
diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py
index ceff1bd..7da7172 100644
--- a/torch/_inductor/sizevars.py
+++ b/torch/_inductor/sizevars.py
@@ -415,6 +415,7 @@
return sympy_subs(expr, self.var_to_val)
def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
+ expr = self.simplify(expr)
out = self.symbolic_hint(expr)
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
# Use the provided heuristic fallback hint