Expose torch.export.dynamic_dim() API (#107635)
With updated doc
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107635
Approved by: https://github.com/avikchaudhuri
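
A minimal usage sketch of the newly exposed API (the function `fn` is illustrative and simply mirrors the docstring examples added below)::

    import torch
    from torch.export import export, dynamic_dim

    def fn(x, y):
        # Simple op whose output shape follows the (dynamic) first dimension of x
        return x @ y

    t0 = torch.rand(10, 3)
    t1 = torch.rand(3, 4)

    # Mark the first dimension of t0 as dynamic with an inclusive lower bound of 5
    constraints = [dynamic_dim(t0, 0) >= 5]
    ep = export(fn, (t0, t1), constraints=constraints)
    print(ep)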
diff --git a/docs/source/export.rst b/docs/source/export.rst
index a3a46c7..824fe94 100644
--- a/docs/source/export.rst
+++ b/docs/source/export.rst
@@ -8,6 +8,7 @@
.. automodule:: torch.export
.. autofunction:: export
+.. autofunction:: dynamic_dim
.. toctree::
:glob:
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 0b161fe9..22587ae 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -59,40 +59,7 @@
)
from .wrappers import _wrap_submodules
-# Note - [On Export Dynamic Dimension UX]
-#
-# After a lot of discussion, we have settled on a dynamic marking API
-# for export that meets the following constraints:
-# 1) Stateless
-# 2) Safe for numerous .export calls within a single process
-# 3) Simple to use
-# 4) Can be extended to constraints easily
-#
-# While the underlying API is still torch._dynamo.mark_dynamic, we offer a higher
-# level API that meets the constraints above.
-#
-# This API produces an object that is meant to be passed into torch._dynamo.export
-# constraints field. See docs on torch._dynamo.export for more details.
-#
-# Note - The output type and structure here is NOT BC and NOT A CONTRACT, we reserve
-# the right to change the output here at any time, and will do so as we extend the API.
-#
-# result = torch._dynamo.export(
-# my_model,
-# constraints=[
-# # if you do only dynamic_dim, this is sugar for
-# # -Inf <= dynamic_dim(blah, 0) <= Inf; we don’t otherwise
-# # permit direct int->bool conversion
-# dynamic_dim(blah, 0),
-# # operator overloading because it makes it clear whether
-# # or not you’re inclusive-exclusive range or not
-# 0 <= dynamic_dim(blah, 1) <= 100,
-# # NB: But we actually truncate ranges to be >= 2, because of
-# # 0/1 specialization
-# ]
-# )(
-# *sixtyfour_tensors,
-# )
+
def dynamic_dim(t: torch.Tensor, index: int):
if not isinstance(t, torch.Tensor):
raise UserError(
diff --git a/torch/export/__init__.py b/torch/export/__init__.py
index afa9e68..4ed6f00 100644
--- a/torch/export/__init__.py
+++ b/torch/export/__init__.py
@@ -4,10 +4,82 @@
__all__ = [
+ "dynamic_dim",
"export",
]
+def dynamic_dim(t: torch.Tensor, index: int):
+ """
+ `dynamic_dim` constructs a `Constraint` object that describes the dynamism of
+ a dimension `index` of tensor `t`. `Constraint` objects should be passed to the
+ `constraints` argument of `export()`.
+
+ Specifically, `dynamic_dim` can be used to express the following types of dynamism.
+
+ - Size of a dimension is dynamic and unbounded::
+
+ t0 = torch.rand(2, 3)
+ t1 = torch.rand(3, 4)
+
+ # First dimension of t0 can have a dynamic size rather than always being static size 2
+ constraints = [dynamic_dim(t0, 0)]
+ ep = export(fn, (t0, t1), constraints=constraints)
+
+ - Size of a dimension is dynamic with a lower bound::
+
+ t0 = torch.rand(10, 3)
+ t1 = torch.rand(3, 4)
+
+ # First dimension of t0 can have a dynamic size with a lower bound of 5 (inclusive)
+ # Second dimension of t1 can have a dynamic size with a lower bound of 2 (exclusive)
+ constraints = [
+ dynamic_dim(t0, 0) >= 5,
+ dynamic_dim(t1, 1) > 2,
+ ]
+ ep = export(fn, (t0, t1), constraints=constraints)
+
+ - Size of a dimension is dynamic with an upper bound::
+
+ t0 = torch.rand(10, 3)
+ t1 = torch.rand(3, 4)
+
+ # First dimension of t0 can have a dynamic size with an upper bound of 16 (inclusive)
+ # Second dimension of t1 can have a dynamic size with an upper bound of 8 (exclusive)
+ constraints = [
+ dynamic_dim(t0, 0) <= 16,
+ dynamic_dim(t1, 1) < 8,
+ ]
+ ep = export(fn, (t0, t1), constraints=constraints)
+
+ - Size of a dimension is dynamic and always equal to the size of another dynamic dimension::
+
+ t0 = torch.rand(10, 3)
+ t1 = torch.rand(3, 4)
+
+ # Sizes of the second dimension of t0 and the first dimension of t1 are always equal
+ constraints = [
+ dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
+ ]
+ ep = export(fn, (t0, t1), constraints=constraints)
+
+ - Mix and match all types above as long as they do not express conflicting requirements
+
+ Args:
+ t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
+ index (int): Index of dynamic dimension
+
+ Returns:
+ A `Constraint` object that describes shape dynamism. It can be passed to `export()` so
+ that `export()` does not assume a static size for the specified tensor, i.e. keeps it dynamic
+ as a symbolic size rather than specializing it according to the size of the example tracing input.
+
+ """
+ from torch._export import dynamic_dim
+
+ return dynamic_dim(t, index)
+
+
def export(
f: Callable,
args: Tuple[Any],
@@ -97,7 +169,7 @@
Note:
If you want to preserve dynamic branching behavior based on value or
shape of torch.Tensor in the generated graph, you will need to use
- `torch._export.dynamic_dim` to make a dimension of input tensor to be dynamic
+ `torch.export.dynamic_dim` to make a dimension of an input tensor dynamic
and rewrite the source code using control flow operations like
`torch.ops.higher_order.cond`.
@@ -114,7 +186,7 @@
Because static shape use cases are more dominant, `export()` chooses to
assume shapes are all static by default unless there are explicit user
instructions that say otherwise. Specifically, users can use
- `torch._export.dynamic_dim` to give a hint to `export()` about dynamism
+ `torch.export.dynamic_dim` to give a hint to `export()` about dynamism
and range of an input tensor dimension.
2. Dynamic Control Flow
@@ -142,7 +214,7 @@
- Assumptions on static shapes of input tensors are automatically validated without additional effort.
- Assumptions on dynamic shape of input tensors require explicit `Input Constraint`
- constructed with `torch._export.dynamic_dim` APIs
+ constructed with the `torch.export.dynamic_dim` API
- Assumptions on range of intermediate values require explicit `Inline Constraint`,
constructed use `constrain_as_size` and `constraint_as_value` APIs.
@@ -194,9 +266,9 @@
constraints: An optional list of constraints on the dynamic arguments
that specify their possible range of shapes. By default, shapes of
input torch.Tensors are assumed to be static. If an input torch.Tensor
- is expected to have dynamic shapes, please use `torch._export.dynamic_dim()`
+ is expected to have dynamic shapes, please use `torch.export.dynamic_dim()`
to define `Constraint` objects that specify the dynamics and the possible
- range of shapes. See torch._export.dynamic_dim() docstring for examples on
+ range of shapes. See torch.export.dynamic_dim() docstring for examples on
how to use it.
Returns:
diff --git a/torch/overrides.py b/torch/overrides.py
index 7bb7b3f..1fd26d5 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -147,6 +147,7 @@
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
+ torch.export.dynamic_dim,
torch.export.export,
torch.eye,
torch.fft.fftfreq,