[export] Error when constraining on static values (#101655)
Fixes https://github.com/pytorch/pytorch/issues/100415
With this change, attempting to constrain a value that has been evaluated to a static constant results in the following error:
```
Traceback (most recent call last):
File "/scratch/angelayi/work/pytorch/test/export/test_export.py", line 572, in test_export_constrain_static
export(f, example_inputs, constraints)
File "/scratch/angelayi/work/pytorch/torch/_export/__init__.py", line 348, in export
method_name_to_graph_module[compile_spec.method_name] = _export(
File "/scratch/angelayi/work/pytorch/torch/_export/__init__.py", line 119, in _export
raise UserError(UserErrorType.CONSTRAIN_VIOLATION, str(e))
torch._dynamo.exc.UserError: File "/scratch/angelayi/work/pytorch/test/export/test_export.py", line 561, in f
constrain_as_value(c, min=1, max=3)
It appears that you're trying to set a constraint on a value which we evaluated to have a static value of 3. Scroll up to see where this constraint was set.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101655
Approved by: https://github.com/avikchaudhuri
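For context, here is a minimal repro, a sketch distilled from the new test in the diff below (it assumes the `torch._export` APIs exactly as imported there): `y.dim()` is evaluated to the constant 3 during tracing, so constraining it now raises the error above instead of being silently accepted.
```
import torch
from torch._export import export, dynamic_dim
from torch._export.constraints import constrain_as_value

def f(x, y):
    c = y.dim()  # evaluated to the static value 3 during tracing
    constrain_as_value(c, min=1, max=3)  # constraining a static value now errors
    return y[0:c]

x = torch.tensor([3])
y = torch.randn([8, 8, 6])
# Raises torch._dynamo.exc.UserError with the message shown above.
export(f, (x, y), [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10])
```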
diff --git a/test/export/test_export.py b/test/export/test_export.py
index efbf89f..83065b8 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -5,12 +5,13 @@
import torch._dynamo as torchdynamo
from torch._export import _export, export, dynamic_dim
from torch._export.trace import do_not_use_experimental_export
-from torch._export.constraints import constrain_as_size
+from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.graph_module import get_export_meta
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
+@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExperimentalExport(TestCase):
@unittest.skip("TypeError: <lambda>() missing 1 required positional argument")
def test_export_simple_model_with_attr(self):
@@ -29,7 +30,6 @@
        exported_program = do_not_use_experimental_export(mod, inp)
        self.assertEqual(exported_program.fw_module(*inp)[0], mod(*inp))
-    @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
    def test_export_simple_model(self):
        class Foo(torch.nn.Module):
            def __init__(self, float_val):
@@ -551,6 +551,25 @@
        # There should be nonzero view nodes in the graph
        self.assertTrue(view_count > 0)
+    def test_export_constrain_static(self):
+        def f(x, y):
+            b = x.item()
+            constrain_as_size(b, min=2, max=5)
+            c = y.dim()
+            constrain_as_value(c, min=1, max=3)
+            z = y[0:c]
+            return torch.empty((b, y.shape[0])), z
+
+        x = torch.tensor([3])
+        y = torch.randn([8, 8, 6])
+        example_inputs = (x, y)
+        constraints = [dynamic_dim(y, 0) >= 6, dynamic_dim(y, 0) <= 10]
+        with self.assertRaisesRegex(
+            torchdynamo.exc.UserError, "It appears that you're trying to set a constraint " +
+            "on a value which we evaluated to have a static value of 3. "
+        ):
+            export(f, example_inputs, constraints)
+
if __name__ == '__main__':
    run_tests()
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 587f1bf..ee9fc1c 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -897,6 +897,16 @@
"Summary of dimension constraints:%s",
msg,
)
+
+ # Error if we have any constraints on static values
+ for k in shape_env.var_to_range.keys():
+ if isinstance(k, sympy.Integer):
+ constraint_violation_error = ConstraintViolationError(
+ f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
+ "It appears that you're trying to set a constraint on a "
+ f"value which we evaluated to have a static value of {k}. "
+ "Scroll up to see where this constraint was set."
+ )
if constraint_violation_error:
raise constraint_violation_error
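The loop above keys off the fact that `var_to_range` entries for genuinely symbolic values are keyed by `sympy.Symbol` (e.g. the `i0`/`f0` symbols created in `create_unbacked_symint`/`create_unbacked_symfloat`), while the new `constrain_range` path below records concrete values under a `sympy.Integer` key. A standalone sketch of that distinction (plain sympy, illustrative values only):
```
import sympy

var_to_range = {
    sympy.Symbol("i0", integer=True): "dynamic: unbacked symint",
    sympy.Integer(3): "static: constraint recorded on a concrete value",
}

# Mirrors the new check: only the Integer key trips the error.
static_keys = [k for k in var_to_range if isinstance(k, sympy.Integer)]
assert static_keys == [sympy.Integer(3)]
```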
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 9580484..98bcf54 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -31,7 +31,7 @@
    SymFloat,
    SymInt,
)
-from torch._guards import ShapeGuard, Source, TracingContext
+from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges, ValueRangeError
@@ -323,12 +323,26 @@
    if not isinstance(a, SymInt):
        if not (min <= a <= max):
            raise ValueRangeError(f"Invalid value {a} for range [{min}:{max}]")
+
+        if (
+            (fake_mode := detect_fake_mode()) is not None and
+            getattr(fake_mode, "shape_env", None) is not None
+        ):
+            # If we are tracing with a fake mode then add this integer to the
+            # shape_env's var_to_range
+            sym_integer = sympy.Integer(a)
+            shape_env = fake_mode.shape_env
+            shape_env.var_to_range[sym_integer] = ValueRanges(min, max)
+            shape_env.var_to_stack[sym_integer] = TracingContext(fake_mode).extract_stack()
+
        return
+
    if isinstance(a.node.expr, sympy.Integer):
        if not (min <= int(a.node.expr) <= max):
            raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]")
        return
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+
    # TODO: Shouldn't we install a guard if the symbol is backed? Or is the
    # semantics that this is an "unchecked" assert (but is this actually
    # something useful? Might be better to restrict only for unbacked
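So when `constrain_range` is given a plain Python int while tracing under a fake mode that carries a `shape_env`, it validates the range as before but additionally records the value's range and the current call stack keyed by `sympy.Integer(a)`; `eval_frame.py` later converts any such entry into the `ConstraintViolationError` above. A toy model of that handshake (hypothetical simplified names, not the real classes):
```
import sympy
import traceback

class ToyShapeEnv:  # stand-in for the real ShapeEnv
    def __init__(self):
        self.var_to_range = {}  # key -> (min, max)
        self.var_to_stack = {}  # key -> traceback.StackSummary

def toy_constrain_range(shape_env, a, *, min, max):
    if isinstance(a, int):  # concrete value seen under tracing
        key = sympy.Integer(a)
        shape_env.var_to_range[key] = (min, max)
        shape_env.var_to_stack[key] = traceback.extract_stack()[:-1]

env = ToyShapeEnv()
toy_constrain_range(env, 3, min=1, max=3)
# The export-time check would now flag this entry as a static-value constraint.
assert any(isinstance(k, sympy.Integer) for k in env.var_to_range)
```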
@@ -1824,7 +1838,7 @@
        # practice
        self.var_to_range: Dict["sympy.Symbol", ValueRanges] = {}
        self.var_to_sources: Dict["sympy.Symbol", List[Source]] = {}
-        self.var_to_stack: Dict["sympy.Symbol", str] = {}
+        self.var_to_stack: Dict["sympy.Symbol", traceback.StackSummary] = {}
        # Maps from sympy ints to expressions representing them
        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
        self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {}  #
@@ -1987,19 +2001,19 @@
    def create_unbacked_symfloat(self):
        symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
-        self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
+        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
        self.var_to_range[symbol] = ValueRanges.unknown()
        return SymFloat(SymNode(symbol, self, float, None))

    def create_unbacked_symint(self):
        symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
-        self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
+        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
        self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize)
        return SymInt(SymNode(symbol, self, int, None))

    def create_unbacked_symbool(self):
        symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
-        self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
+        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
        self.var_to_range[symbol] = ValueRanges(0, 1)
        return SymBool(SymNode(sympy.Eq(symbol, 1), self, bool, None))
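The annotation change above reflects that `traceback.extract_stack()` already returns a `traceback.StackSummary`; storing the raw frames and formatting them only when a message is actually emitted (the error in `eval_frame.py` and the debug log below) avoids paying the string-formatting cost for stacks that are never reported. A quick standard-library illustration of the pattern:
```
import traceback

def allocate_symbol_stack():
    # Store raw frames (cheap); drop this frame, as the diff does with [:-1].
    return traceback.extract_stack()[:-1]

stack = allocate_symbol_stack()  # a traceback.StackSummary
# Format lazily, only when the stack is actually shown to the user.
formatted = ''.join(traceback.format_list(stack))
print(formatted)
```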
@@ -2634,7 +2648,8 @@
        # TODO: in a Dynamo context, having user code, and having the
        # name of the local, will be much better
        for s in expr.free_symbols:
-            self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, self.var_to_stack[s])
+            stacktrace = ''.join(traceback.format_list(self.var_to_stack[s]))
+            self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
        return GuardOnDataDependentSymNode(
            "It appears that you're trying to get a value out of symbolic int/float "
            "whose value is data-dependent (and thus we do not know the true value.) "