[export] refactor _process_dynamic_shapes (#133391)
Sorryyyyy for another refactor. This splits `_process_dynamic_shapes` into 3 parts:
1. `_combine_args` - mostly the same thing
2. `_check_dynamic_shapes`, which is responsible for raising 99% of the UserErrors when the dynamic shapes spec is invalid (all except one UserError involving DerivedDims)
3. `_process_dynamic_shapes`, which for now does the same thing as before, minus the checks moved into 2. (see the sketch below)
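Roughly, call sites that previously went through `_process_dynamic_shapes` alone now run the three steps explicitly. A minimal sketch of the new call pattern (the wrapper function name here is just for illustration; the three helpers are the ones touched in this PR):

```python
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _combine_args,
    _process_dynamic_shapes,
)

def _produce_constraints(f, args, kwargs, dynamic_shapes):
    # 1. Flatten args/kwargs into a single {arg_name: value} dict.
    combined_args = _combine_args(f, args, kwargs)
    # 2. Validate the user-provided spec against the inputs;
    #    raises UserError if the spec is malformed.
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    # 3. Translate the (already-validated) spec into Constraint objects.
    return _process_dynamic_shapes(combined_args, dynamic_shapes)
```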
This refactor is helpful for the incoming automatic dynamic shapes work, because we're switching to `assume_static_by_default=False`, which is what `_dynamo.export` currently does. That means any unspecified dim gets allocated a symbol, in contrast to export today, which keeps unspecified dims static. Historically the static behavior has been desirable - export users don't want too much dynamism - so we want to change how the spec is translated into constraints to keep that behavior.
This means that when we switch over to automatic dynamic shapes, we want to plug something in between steps 2 and 3 that patches up the spec for `assume_static_by_default=False`: filling in static shapes for any unspecified dims, and potentially clearing out the auto-dynamic dims (since they're no-ops). Doing this between 2 and 3 keeps `_process_dynamic_shapes` semantically the same, since it's also used by `_dynamo.export`. A sketch of such a transform follows below.
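To make the plan concrete, here is a hedged sketch of what that in-between transform could look like. Nothing below is in this PR: the name `_fill_static_shapes` is made up, and the sketch only handles the dict-form spec over flat tensor args.

```python
import torch

def _fill_static_shapes(combined_args, dynamic_shapes):
    # Hypothetical step between _check_dynamic_shapes and _process_dynamic_shapes:
    # with assume_static_by_default=False, any dim the user left unspecified would
    # otherwise get a fresh symbol, so pin it to its concrete size instead.
    dynamic_shapes = dict(dynamic_shapes or {})
    for name, arg in combined_args.items():
        if isinstance(arg, torch.Tensor) and dynamic_shapes.get(name) is None:
            # An int in the spec means "static with exactly this value".
            dynamic_shapes[name] = {i: int(s) for i, s in enumerate(arg.shape)}
    return dynamic_shapes
```

It would slot in right after `_check_dynamic_shapes(combined_args, dynamic_shapes)` and right before `_process_dynamic_shapes(combined_args, dynamic_shapes)`, so the latter keeps its current `_dynamo.export` semantics untouched.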
We could do this without a refactor, plugging the transform in before `_process_dynamic_shapes`, but since that function was responsible for both spec checking and constraint production, moving spec checking to before we transform the spec helps guarantee we're raising errors on what the user actually specified, not on an internal export bug.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133391
Approved by: https://github.com/avikchaudhuri
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 815cb21..fc07797 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -2756,7 +2756,7 @@
x = torch.tensor([3])
y = torch.randn([8, 8, 6])
- example_inputs = [x, y]
+ example_inputs = (x, y)
dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)})
gm, _ = torch._dynamo.export(
f,
@@ -2766,7 +2766,7 @@
)(*example_inputs)
constraints = torch.export.dynamic_shapes._process_dynamic_shapes(
- f, example_inputs, dynamic_shapes=dynamic_shapes
+ {"x": x, "y": y}, dynamic_shapes=dynamic_shapes
)
self.assertEqual(
gm.meta["input_shape_constraints"],
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 0c0062d..28e5c16 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -55,7 +55,11 @@
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._utils_internal import justknobs_check, log_export_usage
-from torch.export.dynamic_shapes import _process_dynamic_shapes
+from torch.export.dynamic_shapes import (
+ _check_dynamic_shapes,
+ _combine_args,
+ _process_dynamic_shapes,
+)
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
@@ -1302,7 +1306,9 @@
_assume_static_by_default = assume_static_by_default
def inner(*args, **kwargs):
- constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes)
+ combined_args = _combine_args(_f, args, kwargs)
+ _check_dynamic_shapes(combined_args, dynamic_shapes)
+ constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
f = _f
assume_static_by_default = _assume_static_by_default
check_if_dynamo_supported()
diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py
index 9470d67..37a93b7 100644
--- a/torch/_export/non_strict_utils.py
+++ b/torch/_export/non_strict_utils.py
@@ -20,7 +20,12 @@
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
-from torch.export.dynamic_shapes import _tree_map_with_path
+from torch.export.dynamic_shapes import (
+ _check_dynamic_shapes,
+ _combine_args,
+ _process_dynamic_shapes,
+ _tree_map_with_path,
+)
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
@@ -120,10 +125,11 @@
# - output_graph.py fakifies inputs.
# - [post-tracing] guards.py processes input shape equalities.
- constraints = torch.export.dynamic_shapes._process_dynamic_shapes(
- nn_module, args, kwargs, dynamic_shapes, _is_torch_jit_trace=_is_torch_jit_trace
+ combined_args = _combine_args(
+ nn_module, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace
)
- constraints = constraints or []
+ _check_dynamic_shapes(combined_args, dynamic_shapes)
+ constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
t_constraints: Dict[int, Dict[int, Constraint]] = defaultdict(dict)
for constraint in constraints:
t_constraints[constraint.t_id][constraint.dim] = constraint
diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py
index 8bb0369..58d2470 100644
--- a/torch/export/dynamic_shapes.py
+++ b/torch/export/dynamic_shapes.py
@@ -796,17 +796,128 @@
return dynamic_shapes
-def _process_dynamic_shapes(
- f: Callable,
- args: Tuple[Any, ...],
- kwargs: Optional[Dict[str, Any]] = None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
- _is_torch_jit_trace=False,
-) -> Optional[List[Constraint]]:
+def _check_dynamic_shapes(
+ combined_args: Dict[str, Any],
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
+):
+ """
+ Checks the dynamic_shapes specification for correctness,
+ using combined args + kwargs as reference for inputs structure.
+ """
from torch._dynamo.exc import UserError, UserErrorType
if dynamic_shapes is None or len(dynamic_shapes) == 0:
- return None
+ return
+ if isinstance(dynamic_shapes, (tuple, list)):
+ combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
+
+ bounds: Dict[str, Tuple[int, int]] = {}
+
+ def check_same_bounds(dim):
+ if dim.__name__ in bounds:
+ min_, max_ = bounds[dim.__name__]
+ if dim.min != min_ or dim.max != max_:
+ this_ = _Dim.readable(dim.__name__, min_, max_)
+ that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
+ raise UserError(
+ UserErrorType.INVALID_INPUT,
+ f"Found different definitions {this_} and {that_} "
+ f"for the same symbolic dimension {dim}!",
+ )
+ else:
+ bounds[dim.__name__] = (dim.min, dim.max)
+
+ def check_symbols(path, tensor, shape):
+ if isinstance(shape, dict):
+ for i, dim in shape.items():
+ if isinstance(dim, _Dim):
+ check_same_bounds(dim)
+ elif not isinstance(dim, int) and dim is not None:
+ raise UserError(
+ UserErrorType.INVALID_INPUT,
+ f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
+ f"specified at `dynamic_shapes{keystr(path)}` "
+ f"(expected None, an int, or a Dim, but got {dim} instead)",
+ case_name="dynamic_shapes_validation",
+ )
+ elif isinstance(shape, (tuple, list)):
+ for i, dim in enumerate(shape):
+ if isinstance(dim, _Dim):
+ check_same_bounds(dim)
+ elif not isinstance(dim, int) and dim is not None:
+ raise UserError(
+ UserErrorType.INVALID_INPUT,
+ f"Unexpected dimension #{i} in input tensor shape {shape} "
+ f"specified at `dynamic_shapes{keystr(path)}` "
+ f"(expected None, an int, or a Dim, but got {dim} instead)",
+ case_name="dynamic_shapes_validation",
+ )
+ elif shape is not None:
+ raise UserError(
+ UserErrorType.INVALID_INPUT,
+ f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
+ f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
+ f" where each dimension is None, an int, or a Dim)",
+ case_name="dynamic_shapes_validation",
+ )
+
+ assert isinstance(dynamic_shapes, (dict, tuple, list))
+ if isinstance(dynamic_shapes, dict):
+ got_keys = list(dynamic_shapes.keys())
+ expected_arg_names = list(combined_args.keys())
+ if sorted(got_keys) != sorted(expected_arg_names):
+ msg = (
+ f"When `dynamic_shapes` is specified as a dict, its top-level keys "
+ f"must be the arg names {expected_arg_names} of `inputs`, but "
+ f"here they are {got_keys}. "
+ )
+ if (
+ len(combined_args) == 1
+ and expected_arg_names[0] not in got_keys
+ and isinstance(combined_args[expected_arg_names[0]], dict)
+ ):
+ msg += (
+ "Since here `inputs` is a list/tuple enclosing a single dict, "
+ "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
+ )
+ else:
+ msg += (
+ "Alternatively, you could also ignore arg names entirely "
+ "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
+ )
+ raise UserError(
+ UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
+ )
+
+ def check_shape(path, t, dynamic_shape):
+ if isinstance(t, torch.Tensor):
+ check_symbols(path, t, dynamic_shape)
+ else:
+ if dynamic_shape is not None:
+ rendered_path = keystr(path)
+ raise UserError(
+ UserErrorType.INVALID_INPUT,
+ f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "
+ f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",
+ case_name="dynamic_shapes_validation",
+ )
+
+ _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
+
+
+def _process_dynamic_shapes(
+ combined_args: Dict[str, Any],
+ dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
+) -> List[Constraint]:
+ """
+ Reads the dynamic_shapes specification and produces a list of constraints.
+ """
+ from torch._dynamo.exc import UserError, UserErrorType
+
+ if dynamic_shapes is None or len(dynamic_shapes) == 0:
+ return []
+ if isinstance(dynamic_shapes, (tuple, list)):
+ combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
# map of Dim names representing input shape dimensions to constraints on them
symbols: Dict[str, List[Constraint]] = defaultdict(list)
@@ -897,23 +1008,6 @@
constraint = constraint <= dim.max
return constraint
- bounds: Dict[str, Tuple[int, int]] = {}
-
- def check_same_bounds(dim):
- if dim.__name__ in symbols:
- min_, max_ = bounds[dim.__name__]
- if dim.min != min_ or dim.max != max_:
- this_ = _Dim.readable(dim.__name__, min_, max_)
- that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Found different definitions {this_} and {that_} "
- f"for the same symbolic dimension {dim}!",
- )
-
- else:
- bounds[dim.__name__] = (dim.min, dim.max)
-
def update_symbols(path, tensor, shape):
def _create_static_dim(tensor, i, value):
return _StaticDim(str(value), (int,), {"value": value})
@@ -923,98 +1017,21 @@
if isinstance(dim, (int, _Dim)):
if isinstance(dim, int):
dim = _create_static_dim(tensor, i, dim)
- check_same_bounds(dim)
constraint = to_constraint(dim, tensor, i)
symbols[dim.__name__].append(constraint)
- else:
- if dim is not None:
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Unexpected dimension mapped to index {i} in input tensor shape {shape} "
- f"specified at `dynamic_shapes{keystr(path)}` "
- f"(expected None, an int, or a Dim, but got {dim} instead)",
- case_name="dynamic_shapes_validation",
- )
elif isinstance(shape, (tuple, list)):
for i, dim in enumerate(shape):
if isinstance(dim, (int, _Dim)):
if isinstance(dim, int):
dim = _create_static_dim(tensor, i, dim)
- check_same_bounds(dim)
constraint = to_constraint(dim, tensor, i)
symbols[dim.__name__].append(constraint)
- else:
- if dim is not None:
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Unexpected dimension #{i} in input tensor shape {shape} "
- f"specified at `dynamic_shapes{keystr(path)}` "
- f"(expected None, an int, or a Dim, but got {dim} instead)",
- case_name="dynamic_shapes_validation",
- )
- else:
- if shape is not None:
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` "
- f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions,"
- f" where each dimension is None, an int, or a Dim)",
- case_name="dynamic_shapes_validation",
- )
- def assoc_shapes(combined_args, dynamic_shapes):
- def assoc_shape(path, t, dynamic_shape):
- if isinstance(t, torch.Tensor):
- update_symbols(path, t, dynamic_shape)
- else:
- if dynamic_shape is not None:
- rendered_path = keystr(path)
- raise UserError(
- UserErrorType.INVALID_INPUT,
- f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` "
- f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)",
- case_name="dynamic_shapes_validation",
- )
+ def assoc_shape(path, t, dynamic_shape):
+ if isinstance(t, torch.Tensor):
+ update_symbols(path, t, dynamic_shape)
- _tree_map_with_path(
- assoc_shape, combined_args, dynamic_shapes, tree_name="inputs"
- )
-
- combined_args = _combine_args(
- f, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace
- )
- if isinstance(dynamic_shapes, dict):
- got_keys = list(dynamic_shapes.keys())
- expected_arg_names = list(combined_args.keys())
- if sorted(got_keys) != sorted(expected_arg_names):
- # This error would be caught by `assoc_shapes` below, but we can give
- # a more helpful error message here.
- msg = (
- f"When `dynamic_shapes` is specified as a dict, its top-level keys "
- f"must be the arg names {expected_arg_names} of `inputs`, but "
- f"here they are {got_keys}. "
- )
- if (
- len(combined_args) == 1
- and expected_arg_names[0] not in got_keys
- and isinstance(combined_args[expected_arg_names[0]], dict)
- ):
- msg += (
- "Since here `inputs` is a list/tuple enclosing a single dict, "
- "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?"
- )
- else:
- msg += (
- "Alternatively, you could also ignore arg names entirely "
- "and specify `dynamic_shapes` as a list/tuple matching `inputs`."
- )
- raise UserError(
- UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation"
- )
- else:
- assert isinstance(dynamic_shapes, (tuple, list))
- combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
- assoc_shapes(combined_args, dynamic_shapes)
+ _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs")
constraints = []
for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: