Reland: Value range refinement using multi-variate expressions (#105491)
Trying to re-land: #97964.
Test strategy:
```
buck2 test '@fbcode//mode/dev-nosan' fbcode//pye/model_inventory/inside_out_tracking_model:inside_out_tracking_model_test -- --exact 'pye/model_inventory/inside_out_tracking_model:inside_out_tracking_model_test - test_executorch_e2e_output_consistency_aten (pye.model_inventory.inside_out_tracking_model.InsideOutTrackingModelTest.InsideOutTrackingModelTest)'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105491
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 1b5f2b1..6c971c1 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1360,6 +1360,28 @@
tensor = make_fx(f, tracing_mode="symbolic")(torch.randn(15))
self.assertExpectedInline(show_guards(tensor), """L['a'].size()[0] < 20""")
+ def test_guard_upperbound_range_refinement_multivariate(self):
+ def f(a):
+ assert a.shape[0] > 5 and a.shape[0] > 12
+ assert a.shape[1] > 5 and a.shape[1] > a.shape[0]
+ return a.cos()
+ tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 20)))
+ self.assertExpectedInline(show_guards(tensor), """\
+L['a'].size()[1] > L['a'].size()[0]
+L['a'].size()[0] > 12""")
+
+ def test_guard_lowerbound_range_refinement_multivariate(self):
+ def f(a):
+ assert a.shape[0] < 20 and a.shape[0] < 30
+ assert a.shape[1] < 30 and a.shape[1] < a.shape[0]
+ return a.cos()
+ tensor = make_fx(f, tracing_mode="symbolic")(torch.randn((15, 5)))
+ self.assertExpectedInline(
+ show_guards(tensor),
+ """\
+L['a'].size()[1] < L['a'].size()[0]
+L['a'].size()[0] < 20""")
+
def test_sym_storage_offset(self):
def f(x, y):
return x + y
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 5a3a23d..4ffac7b 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -3415,10 +3415,6 @@
):
continue
- # Use only univariate functions.
- if len(expr.rhs.free_symbols) > 0:
- continue
-
# Update the value range of the left-hand side, if the
# right-hand side provides a better range.
symbol = expr.lhs