feat: Add min, max ranges to mark_dynamic API (#119737)
Fixes https://github.com/pytorch/pytorch/issues/115137
This PR adds:
- mark_dynamic API will accept `min`, `max` values to create a bounded constraint on the dim.
- test case in test_misc.py which checks if `ConstraintViolationError` is triggered if `torch.compile` gets a input dimension out of bounds.
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119737
Approved by: https://github.com/ezyang, https://github.com/jansel
diff --git a/docs/source/torch.compiler_dynamic_shapes.rst b/docs/source/torch.compiler_dynamic_shapes.rst
index 4c1be52..3325684 100644
--- a/docs/source/torch.compiler_dynamic_shapes.rst
+++ b/docs/source/torch.compiler_dynamic_shapes.rst
@@ -32,7 +32,8 @@
when guards are added and why.
- If you know ahead of time something will be dynamic, you can skip the first
- recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``.
+ recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``. If you know ahead of time
+ the ``min`` and ``max`` value this dimension can take, you can specify ``torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)``
- If you say ``torch.compile(dynamic=False)``, we will turn off automatic
dynamic shapes on recompiles and always recompile for each distinct size.
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 77748f7..a73de8b 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7123,6 +7123,20 @@
with self.assertRaises(ConstraintViolationError):
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
+ # Translation validation changes the exception type, don't run with it
+ @torch.fx.experimental._config.patch(translation_validation=False)
+ def test_mark_dynamic_with_ranges(self):
+ y = torch.randn([8, 3, 3])
+
+ def my_dyn_fn(x):
+ if x.shape[0] == 3:
+ return x.sin()
+ return x.cos()
+
+ torch._dynamo.mark_dynamic(y, 0, min=2, max=5)
+ with self.assertRaises(ConstraintViolationError):
+ torch._dynamo.optimize("eager")(my_dyn_fn)(y)
+
def test_mark_static(self):
counter = CompileCounter()
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index 762fa2a..8c82bf5 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
@@ -168,20 +169,33 @@
# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
-def _apply_func_to_inner_tensors_of_same_dim(func, t, *args):
+def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__()
for attr in attrs:
inner = getattr(t, attr)
if inner.dim() == t.dim():
- func(inner, *args)
+ func(inner, *args, **kwargs)
+
+
+@dataclass(frozen=True)
+class _DimRange:
+ """
+ This represents an dimension of a tensor and the corresponding
+ min and max values it can take. Don't create this
+ class directly; instead, use :func:`mark_dynamic`.
+ """
+
+ dim: int
+ min: int
+ max: int
@forbid_in_graph
-def mark_dynamic(t, index):
+def mark_dynamic(t, index, *, min=None, max=None):
"""
- Mark a tensor as having a dynamic dim.
+ Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
[Note - on the state of mark_dynamic]
@@ -206,18 +220,22 @@
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
- _apply_func_to_inner_tensors_of_same_dim(mark_dynamic, t, index)
+ _apply_func_to_inner_tensors_of_same_dim(
+ mark_dynamic, t, index, min=min, max=max
+ )
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
t._dynamo_dynamic_indices = set()
+ t._dynamo_dynamic_range = set()
# TODO(voz): Should we bounds check?
t._dynamo_dynamic_indices.add(index)
+ t._dynamo_dynamic_range.add(_DimRange(index, min, max))
return
assert isinstance(index, (list, tuple))
for i in index:
- mark_dynamic(t, i)
+ mark_dynamic(t, i, min=min, max=max)
@forbid_in_graph
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 246ae9e..25407c2 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -15,6 +15,8 @@
import types
from typing import List, NamedTuple, Optional, Union
+from torch.utils._sympy.value_ranges import ValueRanges
+
try:
import numpy as np
except ModuleNotFoundError:
@@ -1779,7 +1781,24 @@
constraint = dim2constraint.get(i)
if constraint is None:
if marked_dynamic and not config.allow_ignore_mark_dynamic:
- constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+ if hasattr(e, "_dynamo_dynamic_range"):
+ dim_range = [
+ dr for dr in e._dynamo_dynamic_range if dr.dim == i
+ ].pop()
+ if dim_range.min is None and dim_range.max is None:
+ constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+ else:
+ from torch.fx.experimental.symbolic_shapes import (
+ StrictMinMaxConstraint,
+ )
+
+ constraint_dim = StrictMinMaxConstraint(
+ vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
+ warn_only=False,
+ )
+ else:
+ constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+
elif not marked_static and automatic_dynamic:
constraint_dim = RelaxedUnspecConstraint(warn_only=True)
else: