Don't attempt to compute hints for unbacked expressions (#132060)
This breaks the inference we made that if you cat an N-D tensor with a 1-D tensor of size (u0,), the u0 must be zero, but no one really wanted that anyway...
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132060
Approved by: https://github.com/Skylion007
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index dda8e47..0c66e3a 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -8755,7 +8755,9 @@
z = y.item()
return torch.cat([x, torch.ones(z)])
- fn(torch.randn(2, 3), torch.tensor([0]))
+ self.assertRaises(
+ RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([0]))
+ )
self.assertRaises(
RuntimeError, lambda: fn(torch.randn(2, 3), torch.tensor([1]))
)
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 76a70da..127dbfd 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -5411,13 +5411,17 @@
z = y.item()
return torch.cat([x, x.new_ones(z)])
- self.common(
- fn,
- (
- torch.randn([2, 3]),
- torch.tensor([0]),
- ),
- )
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "Expected 2-D tensors, but got 1-D for tensor number 1 in the list",
+ ):
+ self.common(
+ fn,
+ (
+ torch.randn([2, 3]),
+ torch.tensor([0]),
+ ),
+ )
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_cat_unbacked_empty_1d(self):
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 6161f4c..d2eedc5 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -2780,7 +2780,17 @@
assert tensor.ndim == 1 # we've already checked this above
# Don't suggest the legacy behavior in the error message
torch._check(
- tensor.shape[0] == 0,
+ # NB: it is not enough to simply assert that tensor.shape[0] == 0;
+ # this MUST be true even under guard size oblivious.
+ # Effectively, we must actually know that the shape is zero,
+ # passing an unbacked SymInt which we will defer a runtime
+ # assert on won't cut it. This is a policy decision (size
+ # oblivious semantics say that u0 tensors never are inferred
+ # to be zero size, even if they must be that for the cat to go
+ # through), and is load bearing for our Inductor lowerings
+ # (which assume that size oblivious tests are OK to determine
+ # if a shape is permissibly zero.)
+ guard_size_oblivious(tensor.shape[0] == 0),
lambda: f"Number of dimensions of tensors must match. "
f"Expected {example.ndim}-D tensors, but got 1-D for "
f"tensor number {tensor_idx} in the list",
diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py
index 370185b..b097036 100644
--- a/torch/fx/experimental/sym_node.py
+++ b/torch/fx/experimental/sym_node.py
@@ -110,9 +110,15 @@
# in sync, so we've deleted it for now.)
def compute_hint():
+ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
+
# This occasionally gets exercised by, e.g.,
# convert_shape_to_symint. It's just a nicety so you don't HAVE
# to have a correct hint on hand when making a SymNode.
+ # Don't attempt to compute for unbacked, this can be quite
+ # expensive.
+ if free_unbacked_symbols(self.expr):
+ return None
hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
if hint is not None:
hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint