Expose torch.export.dynamic_dim() API (#107635)

With updated doc

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107635
Approved by: https://github.com/avikchaudhuri
diff --git a/docs/source/export.rst b/docs/source/export.rst
index a3a46c7..824fe94 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -8,6 +8,7 @@
 
 .. automodule:: torch.export
 .. autofunction:: export
+.. autofunction:: dynamic_dim
 
 .. toctree::
    :glob:
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 0b161fe9..22587ae 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -59,40 +59,7 @@
 )
 from .wrappers import _wrap_submodules
 
-# Note - [On Export Dynamic Dimension UX]
-#
-# After a lot of discussion, we have settled on a dynamic marking API
-# for export that meets the following constraints:
-# 1) Stateless
-# 2) Safe for numerous .export calls within a single process
-# 3) Simple to use
-# 4) Can be extended to constraints easily
-#
-# While the underlying API is still torch._dynamo.mark_dynamic, we offer a higher
-# level API that meets the constraints above.
-#
-# This API produces an object that is meant to be passed into torch._dynamo.export
-# constraints field. See docs on torch._dynamo.export for more details.
-#
-# Note - The output type and structure here is NOT BC and NOT A CONTRACT, we reserve
-# the right to change the output here at any time, and will do so as we extend the API.
-#
-# result = torch._dynamo.export(
-#     my_model,
-#     constraints=[
-#         # if you do only dynamic_dim, this is sugar for
-#         # -Inf <= dynamic_dim(blah, 0) <= Inf; we don’t otherwise
-#         # permit direct int->bool conversion
-#         dynamic_dim(blah, 0),
-#         # operator overloading because it makes it clear whether
-#         # or not you’re inclusive-exclusive range or not
-#         0 <= dynamic_dim(blah, 1) <= 100,
-#         # NB: But we actually truncate ranges to be >= 2, because of
-#         # 0/1 specialization
-#     ]
-# )(
-#     *sixtyfour_tensors,
-# )
+
 def dynamic_dim(t: torch.Tensor, index: int):
     if not isinstance(t, torch.Tensor):
         raise UserError(
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index afa9e68..4ed6f00 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -4,10 +4,82 @@
 
 
 __all__ = [
+    "dynamic_dim",
     "export",
 ]
 
 
+def dynamic_dim(t: torch.Tensor, index: int):
+    """
+    `dynamic_dim` constructs a `Constraint` object that describes the dynamism of
+    a dimension `index` of tensor `t`. `Constraint` objects should be passed to
+    `constraints` argument of `export()`.
+
+    Specifically `dynamic_dim` can be used to express following types of dynamism.
+
+    - Size of a dimension is dynamic and unbounded::
+
+        t0 = torch.rand(2, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can be dynamic size rather than always being static size 2
+        constraints = [dynamic_dim(t0, 0)]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic with a lower bound::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
+        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
+        constraints = [
+            dynamic_dim(t0, 0) >= 5,
+            dynamic_dim(t1, 1) > 2,
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic with an upper bound::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
+        # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
+        constraints = [
+            dynamic_dim(t0, 0) <= 16,
+            dynamic_dim(t1, 1) < 8,
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension::
+
+        t0 = torch.rand(10, 3)
+        t1 = torch.rand(3, 4)
+
+        # Sizes of second dimension of t0 and first dimension are always equal
+        constraints = [
+            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
+        ]
+        ep = export(fn, (t0, t1), constraints=constraints)
+
+    - Mix and match all types above as long as they do not express conflicting requirements
+
+    Args:
+        t (torch.Tensor): Example input tensor that have dynamic dimension size(s)
+        index (int): Index of dynamic dimension
+
+    Returns:
+        A `Constraint` object that describes shape dynamism. It can be passed to `export()` so
+        that `export()` does not assume static size of specified tensor, i.e. keeping it dynamic
+        as a symbolic size rather than specializing according to size of example tracing input.
+
+    """
+    from torch._export import dynamic_dim
+
+    return dynamic_dim(t, index)
+
+
 def export(
     f: Callable,
     args: Tuple[Any],
@@ -97,7 +169,7 @@
     Note:
     If you want to preserve dynamic branching behavior based on value or
     shape of torch.Tensor in the generated graph, you will need to use
-    `torch._export.dynamic_dim` to make a dimension of input tensor to be dynamic
+    `torch.export.dynamic_dim` to make a dimension of input tensor to be dynamic
     and rewrite the source code using control flow operations like
     `torch.ops.higher_order.cond`.
 
@@ -114,7 +186,7 @@
     Because static shape use cases are more dominant, `export()` chooses to
     assume shapes are all static by default unless there are explicit user
     instructions that say otherwise. Specifically, users can use
-    `torch._export.dynamic_dim` to give a hint to `export()` about dynamism
+    `torch.export.dynamic_dim` to give a hint to `export()` about dynamism
     and range of an input tensor dimension.
 
     2. Dynamic Control Flow
@@ -142,7 +214,7 @@
 
     - Assumptions on static shapes of input tensors are automatically validated without additional effort.
     - Assumptions on dynamic shape of input tensors require explicit `Input Constraint`
-      constructed with `torch._export.dynamic_dim` APIs
+      constructed with `torch.export.dynamic_dim` APIs
     - Assumptions on range of intermediate values require explicit `Inline Constraint`,
       constructed use `constrain_as_size` and `constraint_as_value` APIs.
 
@@ -194,9 +266,9 @@
         constraints: An optional list of constraints on the dynamic arguments
          that specify their possible range of shapes. By default, shapes of
          input torch.Tensors are assumed to be static. If an input torch.Tensor
-         is expected to have dynamic shapes, please use `torch._export.dynamic_dim()`
+         is expected to have dynamic shapes, please use `torch.export.dynamic_dim()`
          to define `Constraint` objects that specify the dynamics and the possible
-         range of shapes. See torch._export.dynamic_dim() docstring for examples on
+         range of shapes. See torch.export.dynamic_dim() docstring for examples on
          how to use it.
 
     Returns:
diff --git a/torch/overrides.py b/torch/overrides.py
index 7bb7b3f..1fd26d5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -147,6 +147,7 @@
         torch.empty_permuted,
         torch.empty_strided,
         torch.empty_quantized,
+        torch.export.dynamic_dim,
         torch.export.export,
         torch.eye,
         torch.fft.fftfreq,