feat: Add min, max ranges to mark_dynamic API (#119737)

Fixes https://github.com/pytorch/pytorch/issues/115137

This PR adds:

- The `mark_dynamic` API now accepts `min` and `max` values to create a bounded constraint on the dim (see the sketch below).
- A test case in `test_misc.py` that checks that `ConstraintViolationError` is raised when `torch.compile` receives an input dimension out of bounds.

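For example, a minimal sketch adapted from the new test (note the test patches translation validation off because it changes the exception type):

```python
# Illustrative sketch of the new API; mirrors test_mark_dynamic_with_ranges.
import torch
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError

def fn(x):
    return x.cos()

y = torch.randn(8, 3, 3)
# Declare dim 0 dynamic and constrain it to 2 <= size(0) <= 5.
torch._dynamo.mark_dynamic(y, 0, min=2, max=5)

try:
    torch.compile(fn, backend="eager")(y)  # actual size 8 is out of bounds
except ConstraintViolationError:
    print("size 8 violates the declared [2, 5] range")
```
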
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119737
Approved by: https://github.com/ezyang, https://github.com/jansel
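
Under the hood, the patch records the requested range on the tensor as a small frozen dataclass and later, in Dynamo's variable builder, turns it into a `StrictMinMaxConstraint` backed by a `ValueRanges`. Below is a simplified, standalone model of that bookkeeping; the helper names `record_range` and `pick_constraint` are illustrative, not real PyTorch APIs:

```python
# Simplified standalone model of the bookkeeping this patch adds; the
# real logic lives in torch/_dynamo/decorators.py and
# torch/_dynamo/variables/builder.py.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class DimRange:  # mirrors the _DimRange dataclass added below
    dim: int
    min: Optional[int]
    max: Optional[int]

def record_range(ranges: set, dim: int, *, min=None, max=None):
    # mark_dynamic stores one entry per marked dim on the tensor
    # (as t._dynamo_dynamic_range in the real code)
    ranges.add(DimRange(dim, min, max))

def pick_constraint(ranges: set, dim: int) -> str:
    # builder.py looks the range back up by dim; when no bounds were
    # given, the pre-existing relaxed (unbounded) constraint is kept
    dr = [r for r in ranges if r.dim == dim].pop()
    if dr.min is None and dr.max is None:
        return "RelaxedUnspecConstraint(warn_only=False)"
    return f"StrictMinMaxConstraint(vr=ValueRanges(lower={dr.min}, upper={dr.max}))"

ranges: set = set()
record_range(ranges, 0, min=2, max=5)
print(pick_constraint(ranges, 0))
# -> StrictMinMaxConstraint(vr=ValueRanges(lower=2, upper=5))
```
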
diff --git a/docs/source/torch.compiler_dynamic_shapes.rst b/docs/source/torch.compiler_dynamic_shapes.rst
index 4c1be52..3325684 100644
--- a/docs/source/torch.compiler_dynamic_shapes.rst
+++ b/docs/source/torch.compiler_dynamic_shapes.rst
@@ -32,7 +32,8 @@
   when guards are added and why.
 
 - If you know ahead of time something will be dynamic, you can skip the first
-  recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``.
+  recompile with ``torch._dynamo.mark_dynamic(tensor, dim)``. If you know ahead of time the
+  ``min`` and ``max`` values this dimension can take, you can specify ``torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)``.
 
 - If you say ``torch.compile(dynamic=False)``, we will turn off automatic
   dynamic shapes on recompiles and always recompile for each distinct size.
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 77748f7..a73de8b 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7123,6 +7123,20 @@
         with self.assertRaises(ConstraintViolationError):
             torch._dynamo.optimize("eager")(my_dyn_fn)(y)
 
+    # Translation validation changes the exception type, so don't run with it
+    @torch.fx.experimental._config.patch(translation_validation=False)
+    def test_mark_dynamic_with_ranges(self):
+        y = torch.randn([8, 3, 3])
+
+        def my_dyn_fn(x):
+            if x.shape[0] == 3:
+                return x.sin()
+            return x.cos()
+
+        torch._dynamo.mark_dynamic(y, 0, min=2, max=5)
+        with self.assertRaises(ConstraintViolationError):
+            torch._dynamo.optimize("eager")(my_dyn_fn)(y)
+
     def test_mark_static(self):
         counter = CompileCounter()
 
diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py
index 762fa2a..8c82bf5 100644
--- a/torch/_dynamo/decorators.py
+++ b/torch/_dynamo/decorators.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 import torch
@@ -168,20 +169,33 @@
 # Helper function to flatten a tensor subclass and apply a function to
 # all inner tensors that match the outer dim. Used to reduce duplication
 # across the various marking APIs.
-def _apply_func_to_inner_tensors_of_same_dim(func, t, *args):
+def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
     assert is_traceable_wrapper_subclass(t)
 
     attrs, ctx = t.__tensor_flatten__()
     for attr in attrs:
         inner = getattr(t, attr)
         if inner.dim() == t.dim():
-            func(inner, *args)
+            func(inner, *args, **kwargs)
+
+
+@dataclass(frozen=True)
+class _DimRange:
+    """
+    This represents a dimension of a tensor and the corresponding
+    min and max values it can take.  Don't create this
+    class directly; instead, use :func:`mark_dynamic`.
+    """
+
+    dim: int
+    min: int
+    max: int
 
 
 @forbid_in_graph
-def mark_dynamic(t, index):
+def mark_dynamic(t, index, *, min=None, max=None):
     """
-    Mark a tensor as having a dynamic dim.
+    Mark a tensor as having a dynamic dim and set a corresponding min and max range for the dim.
 
     [Note - on the state of mark_dynamic]
 
@@ -206,18 +220,22 @@
     if is_traceable_wrapper_subclass(t):
         # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
         # TODO: Make this configurable via a supported public API
-        _apply_func_to_inner_tensors_of_same_dim(mark_dynamic, t, index)
+        _apply_func_to_inner_tensors_of_same_dim(
+            mark_dynamic, t, index, min=min, max=max
+        )
 
     if isinstance(index, int):
         if not hasattr(t, "_dynamo_dynamic_indices"):
             t._dynamo_dynamic_indices = set()
+            t._dynamo_dynamic_range = set()
         # TODO(voz): Should we bounds check?
         t._dynamo_dynamic_indices.add(index)
+        t._dynamo_dynamic_range.add(_DimRange(index, min, max))
         return
 
     assert isinstance(index, (list, tuple))
     for i in index:
-        mark_dynamic(t, i)
+        mark_dynamic(t, i, min=min, max=max)
 
 
 @forbid_in_graph
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 246ae9e..25407c2 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -15,6 +15,8 @@
 import types
 from typing import List, NamedTuple, Optional, Union
 
+from torch.utils._sympy.value_ranges import ValueRanges
+
 try:
     import numpy as np
 except ModuleNotFoundError:
@@ -1779,7 +1781,24 @@
         constraint = dim2constraint.get(i)
         if constraint is None:
             if marked_dynamic and not config.allow_ignore_mark_dynamic:
-                constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+                if hasattr(e, "_dynamo_dynamic_range"):
+                    dim_range = [
+                        dr for dr in e._dynamo_dynamic_range if dr.dim == i
+                    ].pop()
+                    if dim_range.min is None and dim_range.max is None:
+                        constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+                    else:
+                        from torch.fx.experimental.symbolic_shapes import (
+                            StrictMinMaxConstraint,
+                        )
+
+                        constraint_dim = StrictMinMaxConstraint(
+                            vr=ValueRanges(lower=dim_range.min, upper=dim_range.max),
+                            warn_only=False,
+                        )
+                else:
+                    constraint_dim = RelaxedUnspecConstraint(warn_only=False)
+
             elif not marked_static and automatic_dynamic:
                 constraint_dim = RelaxedUnspecConstraint(warn_only=True)
             else: