| # mypy: allow-untyped-defs |
| import dataclasses |
| import inspect |
| import sys |
| from collections import defaultdict |
| from enum import auto, Enum |
| from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union |
| |
| import torch |
| from torch.utils._pytree import ( |
| _get_node_type, |
| BUILTIN_TYPES, |
| keystr, |
| LeafSpec, |
| MappingKey, |
| SequenceKey, |
| SUPPORTED_NODES, |
| tree_flatten, |
| tree_map_with_path, |
| ) |
| |
| from .exported_program import ExportedProgram |
| |
| |
| if TYPE_CHECKING: |
| from sympy import Symbol |
| |
| from torch._guards import Source |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint |
| |
__all__ = [
    "Constraint",
    "DIM",
    "Dim",
    "dims",
    "refine_dynamic_shapes_from_suggested_fixes",
    "ShapesCollection",
]
| |
| |
| class DIM(Enum): |
| """ |
    Enum for marking a tensor dimension as automatically dynamic (``AUTO``) or
    static (``STATIC``) in a ``dynamic_shapes`` spec.
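
    A minimal illustrative sketch of how these markers may appear in a
    ``dynamic_shapes`` spec (``x`` is a hypothetical 2-d input)::

        # dim 0 is exported as dynamic; dim 1 is kept static
        dynamic_shapes = {"x": {0: DIM.AUTO, 1: DIM.STATIC}}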
| """ |
| |
| STATIC = auto() |
| AUTO = auto() |
| |
| |
| class _Dim(type): |
| """ |
| Metaclass for :func:`Dim` types. |
| """ |
| |
| @staticmethod |
| def readable(name, min_, max_): |
| from torch.utils._sympy.numbers import int_oo |
| |
| if min_ == 2: |
| min_ = None |
| if max_ == int_oo: |
| max_ = None |
| if min_ is None and max_ is None: |
| return f"Dim('{name}')" |
| if min_ is None: |
| return f"Dim('{name}', max={max_})" |
| if max_ is None: |
| return f"Dim('{name}', min={min_})" |
| return f"Dim('{name}', min={min_}, max={max_})" |
| |
| def __add__(cls, other): |
| # e.g., dim + 1 |
| if type(other) is not int: |
| raise NotImplementedError( |
| f"Attempted to add {other} to {cls.__name__}, where an integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x + other) |
| |
| def __radd__(cls, other): |
| return cls + other |
| |
| def __sub__(cls, other): |
| # e.g., dim - 1 |
| if type(other) is not int: |
| raise NotImplementedError( |
| f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x - other) |
| |
| def __rsub__(cls, other): |
| raise NotImplementedError( |
| f"Attempted to negate {cls.__name__}. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| |
| def __mul__(cls, other): |
| # e.g., dim * 2 |
| if type(other) is not int or other <= 0: |
| raise NotImplementedError( |
| f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x * other) |
| |
| def __rmul__(cls, other): |
| return cls * other |
| |
| def _derived_name(cls, fn): |
| from sympy import sympify |
| |
| return str(fn(sympify(cls.__name__))) |
| |
| def _derive(cls, fn): |
| return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn}) |
| |
| |
| class _StaticDim(_Dim): |
| """ |
    Metaclass for static :func:`Dim` types.
| |
| This class is only for setting and checking static dim constraints, |
| and the user should never interact with it. |
| """ |
| |
| @property |
| def min(self): |
| return self.value # type: ignore[attr-defined] |
| |
| @property |
| def max(self): |
| return self.value # type: ignore[attr-defined] |
| |
| |
| class _DerivedDim(_Dim): |
| """ |
| Metaclass for derived :func:`Dim` types. |
| |
| Currently we only support increasing linear expressions with integer coefficients. |
| In other words, a derived Dim can always be written in the form Ax + B, where |
| x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. |
| (In particular, the latter ensures that x < y => Ax + B < Ay + B.) |
    These restrictions on the form of derived Dims make the metatheory simpler: e.g.,
| it simplifies computing ranges for derived Dims, solving for underlying regular Dims, |
| deciding equalities between derived Dims, and so on. |
| |
| The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`. |
| The range of a derived Dim is computed by mapping `fn` over the range of its `root`. |
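
    For example, an illustrative sketch::

        dx = Dim("dx", min=1, max=10)   # a regular (root) Dim
        dy = 2 * dx + 1                 # a derived Dim with root dx and fn(x) = 2*x + 1
        # dy.min == 2*1 + 1 == 3, dy.max == 2*10 + 1 == 21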
| """ |
| |
| @property |
| def min(self): |
| # assume that self.fn is an increasing function |
| # TODO(avik): use sympy value range analysis instead? |
| from sympy import Integer |
| |
| from torch.utils._sympy.numbers import int_oo |
| |
| if self.root.min is -int_oo: # type: ignore[attr-defined] |
            return -int_oo  # fn not needed because it is increasing
| |
| _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] |
| root = self.root # type: ignore[attr-defined] |
| assert _min_symint >= 0, ( |
| f"Expected derived min value of {self.__name__} to be >= 0. " |
| f"Please specify an appropriate min value for {root.__name__} " |
| f"(currently {root.min})." |
| ) |
| return int(_min_symint) |
| |
| @property |
| def max(self): |
| # assume that self.fn is an increasing function |
| # TODO(avik): use sympy value range analysis instead? |
| from sympy import Integer |
| |
| from torch.utils._sympy.numbers import int_oo |
| |
| if self.root.max is int_oo: # type: ignore[attr-defined] |
            return int_oo  # fn not needed because it is increasing
| |
| _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] |
| root = self.root # type: ignore[attr-defined] |
| assert _max_symint <= sys.maxsize - 1, ( |
| f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " |
| f"Please specify an appropriate max value for {root.__name__} " |
| f"(currently {root.max})." |
| ) |
| return int(_max_symint) |
| |
| def _derive(self, fn): |
| # We support nesting, e.g., 2*dim + 1. |
| # This is implemented by composing operations on the same root. |
| # As a consequence, roots are always regular Dims (i.e., not derived Dims). |
| return _DerivedDim( |
| self._derived_name(fn), |
| (int,), |
| {"root": self.root, "fn": lambda x: fn(self.fn(x))}, # type: ignore[attr-defined] |
| ) |
| |
| |
| def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): |
| """ |
| :func:`Dim` constructs a type analogous to a named symbolic integer with a range. |
| It can be used to describe multiple possible values of a dynamic tensor dimension. |
| Note that different dynamic dimensions of the same tensor, or of different tensors, |
| can be described by the same type. |
| |
| Args: |
| name (str): Human-readable name for debugging. |
| min (Optional[int]): Minimum possible value of given symbol (inclusive) |
| max (Optional[int]): Maximum possible value of given symbol (inclusive) |
| |
| Returns: |
| A type that can be used in dynamic shape specifications for tensors. |
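
    Example (an illustrative sketch; the module ``mod`` and input ``x`` are hypothetical)::

        batch = Dim("batch", min=1, max=128)
        # dim 0 of input "x" may take any value in [1, 128]; other dims stay static
        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: batch}})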
| """ |
| from torch.utils._sympy.numbers import int_oo |
| |
| _min = 0 if min is None else min |
| _max = int_oo if max is None else max |
| assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" |
| assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}" |
| dim = _Dim(name, (int,), {"min": _min, "max": _max}) |
| dim.__module__ = getattr( |
| inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__" |
| ) |
| return dim |
| |
| |
| def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): |
| """ |
    Utility to create multiple :func:`Dim` types.
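
    For example (illustrative), ``dx, dy = dims("dx", "dy", max=1024)`` is
    equivalent to ``dx = Dim("dx", max=1024); dy = Dim("dy", max=1024)``.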
| """ |
| return tuple(Dim(name, min=min, max=max) for name in names) |
| |
| |
| @dataclasses.dataclass |
| class _ConstraintTarget: |
| """ |
| This represents input tensor dimensions. |
| """ |
| |
| t_id: int |
| dim: int |
| |
| |
| @dataclasses.dataclass |
| class _Constraint(_ConstraintTarget): |
| """ |
| This represents a Dim describing a constraint target. |
| |
| `name` is the name of the Dim. |
| `constraint_range` contains the min/max bounds of the Dim. |
| """ |
| |
| name: str |
| constraint_range: "StrictMinMaxConstraint" |
| |
| def _clone_with_range(self, lower=0, upper=None): |
| # Import sympy locally |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| from torch.utils._sympy.numbers import int_oo |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| if upper is None: |
| upper = int_oo |
| |
| constraint_range = StrictMinMaxConstraint( |
| vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), |
| warn_only=False, |
| ) |
| return _Constraint( |
| self.t_id, |
| self.dim, |
| self.name, |
| constraint_range, |
| ) |
| |
| def __ge__(self, lower): |
| return self._clone_with_range(lower=lower) |
| |
| def __gt__(self, lower): |
| return self._clone_with_range(lower=lower + 1) |
| |
| def __le__(self, upper): |
| return self._clone_with_range(upper=upper) |
| |
| def __lt__(self, upper): |
| return self._clone_with_range(upper=upper - 1) |
| |
| def __bool__(self): |
| # NOTE(avik): We do not support compound expressions like a <= x <= b. |
| # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), |
| # and moreover, enforces that any overload of __bool__ must return True or False. |
| # FWIW, sympy also raises TypeError in this case. |
| raise TypeError( |
| "Cannot determine truth value of _Constraint. " |
| "If you are trying to combine _Constraint's with logical connectives, " |
| "you can specify them separately instead." |
| ) |
| |
| @property |
| def serializable_spec(self): |
        # We need a serialization-compatible format of the constraint so that it
        # can be saved in the graph module without breaking the module serialization.
        # The saved constraints will be used directly by the post-export pass that
        # converts constraints to runtime assertions; they will not be included in
        # the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable.
| return { |
| "t_id": self.t_id, |
| "dim": self.dim, |
| "min": self.constraint_range.vr.lower, |
| "max": self.constraint_range.vr.upper, |
| } |
| |
| |
| @dataclasses.dataclass |
| class _PhantomRoot: |
| """ |
| This represents the root of a derived Dim where the root does not directly |
| specify the shape of any input dimension, but the derived Dim does. |
| |
| e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. |
| |
| The fields `name`, `constraint_range`, and `val` carried by a phantom root |
| help create a symbol for it. Any derived dims with this phantom root are |
| backed by expressions over this symbol. |
| """ |
| |
| name: str |
| constraint_range: "StrictMinMaxConstraint" |
| val: int |
| |
| |
| @dataclasses.dataclass |
| class _DerivedConstraint(_ConstraintTarget): |
| """ |
| This represents a derived Dim, whose root is either a regular constraint target |
| (which directly specifies the shape of some input dimension) or a phantom root |
| (which does so indirectly). |
| |
| It can be thought of as a subclass of `_Constraint`, except that it does not |
| support <, <=, >, >= operations. |
| """ |
| |
| name: str |
| constraint_range: "StrictMinMaxConstraint" |
| root: Union[_ConstraintTarget, _PhantomRoot] |
| fn: Callable |
| |
| @property |
| def serializable_spec(self): |
| # same as _Constraint.serializable_spec |
| return { |
| "t_id": self.t_id, |
| "dim": self.dim, |
| "min": self.constraint_range.vr.lower, |
| "max": self.constraint_range.vr.upper, |
| } |
| |
| |
| Constraint = Union[_Constraint, _DerivedConstraint] |
| |
| |
| def _process_equalities( |
| constraint: Constraint, |
| get_sources: Callable[[int, int], List["Source"]], |
| shape_env: "ShapeEnv", |
| names: Dict[str, Tuple[int, int]], |
| source_pairs: List[Tuple["Source", "Source"]], |
| derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]], |
| phantom_symbols: Dict[str, "Symbol"], |
| ): |
| """ |
| Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become |
| fields of `EqualityConstraint`) based on a given input `constraint`. |
| """ |
| |
| sources = get_sources(constraint.t_id, constraint.dim) |
| if not sources: # empty sources due to unused shapes |
| return |
| |
| source, *other_sources = sources |
| # When t.size()[dim] maps to src0, src1, ..., srcN, we add |
| # constraints that make src0 "equal" to src1, ..., srcN. |
| source_pairs.extend((source, other_source) for other_source in other_sources) |
| if not isinstance(constraint, _DerivedConstraint): |
| if constraint.name in names: |
| shared_t_id, shared_dim = names[constraint.name] |
| other_sources = get_sources(shared_t_id, shared_dim) |
| source_pairs.extend( |
| (source, other_source) for other_source in other_sources |
| ) |
| else: |
| names[constraint.name] = (constraint.t_id, constraint.dim) |
| else: |
| # branch based on the root of the _DerivedConstraint |
| if not isinstance(constraint.root, _PhantomRoot): |
| # either root points to an input source |
| root = get_sources(constraint.root.t_id, constraint.root.dim)[0] # type: ignore[assignment] |
| else: |
| # or root points to a phantom symbol |
| if constraint.root.name in phantom_symbols: |
| root = phantom_symbols[constraint.root.name] # type: ignore[assignment] |
| else: |
| # create a phantom symbol in the shape env based on the _PhantomRoot |
| root = shape_env.create_symbol( |
| val=constraint.root.val, |
| source=torch._dynamo.source.ConstantSource(constraint.root.name), |
| dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, |
| constraint_dim=constraint.root.constraint_range, |
| ) |
| phantom_symbols[constraint.root.name] = root # type: ignore[assignment] |
| |
| fn = constraint.fn |
| # A derived equality (source, root, fn) informally corresponds to source = fn(root). |
| # Here source describes an input and root might describe another input or a phantom symbol. |
| derived_equalities.append((source, root, fn)) |
| |
| |
| def _tree_map_with_path( |
| func: Callable[..., Any], |
| tree: Any, |
| *dynamic_shapes: Any, |
| tree_name: Optional[str] = None, |
| ) -> Any: |
| """ |
| Customized tree_map for mapping pytrees to dynamic_shapes. |
| |
| For built-in types (e.g., standard collections) this behaves exactly like tree_map. |
| |
    On the other hand, for a user-defined class C registered with pytree, we cannot
    assume that a C containing tensors can be mapped to a C containing dynamic shapes
    (i.e., C may not be a polymorphic container). In that case we use the flattened
    form of C instead. Thus a C(**tensors) that flattens to (**tensors) will map to
    (**dynamic_shapes).
| |
| Args: |
| func: function to apply to each (int, float, str, bool, None, torch.Tensor) |
| tree: input pytree |
| dynamic_shapes: zero or more (typically one) dynamic_shapes to match |
| |
| Returns: |
| output pytree mapping func to each (int, float, str, bool, None, torch.Tensor) |
| """ |
| |
| def is_leaf(t): |
| # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types |
| # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types |
| # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES, |
| # as well as user-defined classes registered with pytree, which are. |
| return _get_node_type(t) not in BUILTIN_TYPES |
| |
| def f(path, t, *dynamic_shapes): |
| typ = _get_node_type(t) |
| # typ is not in BUILTIN_TYPES |
| if typ in SUPPORTED_NODES: |
| # thus typ is a user-defined class registered with pytree, |
| # in which case flatten and recurse |
| return tree_map_with_path( |
| f, |
| SUPPORTED_NODES[typ].flatten_fn(t)[0], |
| *dynamic_shapes, |
| is_leaf=is_leaf, |
| ) |
| else: |
| return func(path, t, *dynamic_shapes) |
| |
| try: |
| return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf) |
| except ValueError as e: |
| if "mismatch" in e.args[0]: |
| # When PyTree finds a structural mismatch between tree and dynamic_shapes, |
| # the error message is unfortunately quite horrible. Let's fix that. |
| assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes" |
| assert tree_name, "Must provide a tree_name when there might be a mismatch" |
| |
| def _key(type_, context, i): |
| # derive a PyTree key given the type, context, and child # of a TreeSpec |
| if type_ is dict: |
| return MappingKey(context[i]) |
| if type_ in (list, tuple): |
| assert context is None |
| return SequenceKey(i) |
| raise AssertionError(f"Did not expect type {type_}") |
| |
| def raise_mismatch_error(msg): |
| from torch._dynamo.exc import UserError, UserErrorType |
| |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}", |
| case_name="dynamic_shapes_validation", |
| ) |
| |
| def _compare(tree, dynamic_shapes, path): |
| # raise an error at the point where tree and dynamic_shapes differ, |
| # including the path to that point and the reason for the difference |
| rendered_path = keystr(path) |
| if isinstance(tree, LeafSpec): |
| return |
| if isinstance(dynamic_shapes, LeafSpec): |
| raise_mismatch_error( |
| f"`{tree_name}{rendered_path}` is a {tree.type}, " |
| f"but `dynamic_shapes{rendered_path}` is not" |
| ) |
| if tree.type != dynamic_shapes.type: |
| raise_mismatch_error( |
| f"`{tree_name}{rendered_path}` is a {tree.type}, " |
| f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" |
| ) |
| if len(tree.children_specs) != len(dynamic_shapes.children_specs): |
| raise_mismatch_error( |
| f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " |
| f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" |
| ) |
| if tree.type is dict: |
| # context, children could be out of order |
| if sorted(tree.context) != sorted(dynamic_shapes.context): |
| raise_mismatch_error( |
| f"`{tree_name}{rendered_path}` has keys {tree.context}, " |
| f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" |
| ) |
| _remap = dict( |
| zip(dynamic_shapes.context, dynamic_shapes.children_specs) |
| ) |
| dynamic_shapes_children_specs = [_remap[k] for k in tree.context] |
| else: |
| dynamic_shapes_children_specs = dynamic_shapes.children_specs |
| for i, (tree_, dynamic_shapes_) in enumerate( |
| zip(tree.children_specs, dynamic_shapes_children_specs) |
| ): |
| _compare( |
| tree_, |
| dynamic_shapes_, |
| path + [_key(tree.type, tree.context, i)], |
| ) |
| |
| _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) |
| for other_tree in dynamic_shapes: |
                _, other_tree_spec = tree_flatten(other_tree, is_leaf=is_leaf)
| _compare(tree_spec, other_tree_spec, []) |
| raise |
| |
| |
| def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> Dict[str, Any]: |
| # combine args and kwargs following the signature of f, as it happens |
| # in the body of f when called with *args, **kwargs |
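    # For example (illustrative): given def f(x, y=None), calling
    # _combine_args(f, (x_val,), {"y": y_val}) returns {"x": x_val, "y": y_val}.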
| if isinstance(f, ExportedProgram): |
| f = f.module() |
| if not _is_torch_jit_trace: |
| signature = ( |
| inspect.signature(f.forward) |
| if isinstance(f, torch.nn.Module) |
| else inspect.signature(f) |
| ) |
| kwargs = kwargs if kwargs is not None else {} |
| return signature.bind(*args, **kwargs).arguments |
| return args |
| |
| |
| class ShapesCollection: |
| """ |
| Builder for dynamic_shapes. |
| Used to assign dynamic shape specifications to tensors that appear in inputs. |
| |
    Example::

        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]},)

        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}
        # This is equivalent to the following (now auto-generated):
        # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
| """ |
| |
| def __init__(self): |
| self._shapes = {} |
| |
| def __setitem__(self, t, shape): |
| assert isinstance( |
| t, torch.Tensor |
| ), f"Cannot assign shape to non-tensor type {type(t)}" |
| # TODO(avik): check that shape is indeed a Shape |
| t_id = id(t) |
| if t_id in self._shapes: |
| _shape = self._shapes[t_id] |
| assert ( |
| shape == _shape |
| ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}" |
| else: |
            self._shapes[t_id] = shape
| |
| def __getitem__(self, t): |
| t_id = id(t) |
| if t_id in self._shapes: |
| return self._shapes[t_id] |
| else: |
| return None |
| |
| def __len__(self): |
| return len(self._shapes) |
| |
| def dynamic_shapes(self, m, args, kwargs=None): |
| """ |
        Generates a dynamic_shapes pytree spec for the inputs (args, kwargs) of m,
        based on the shapes previously assigned to tensors in this collection.
| """ |
| |
| t_ids = set() |
| |
| def find_shape(path, t): |
| t_id = id(t) |
| if t_id in self._shapes: |
| t_ids.add(t_id) |
| return self._shapes[t_id] |
| else: |
| return None |
| |
| combined_args = _combine_args(m, args, kwargs) |
| dynamic_shapes = _tree_map_with_path(find_shape, combined_args) |
| if any(t_id not in t_ids for t_id in self._shapes): |
| raise ValueError( |
| "Some tensors that were assigned shapes were not found in args. " |
| "Maybe such tensors were copied when passing them as args? " |
| "Maybe such tensors are contained in classes that were not registered with pytree?" |
| ) |
| return dynamic_shapes |
| |
| |
| def _check_dynamic_shapes( |
| combined_args: Dict[str, Any], |
| dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], |
| ): |
| """ |
| Checks the dynamic_shapes specification for correctness, |
    using the combined args + kwargs as a reference for the structure of the inputs.
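
    For example (an illustrative sketch), for a module taking a single 3-d tensor ``x``,
    both of the following are valid specs::

        {"x": (Dim("dx"), None, 32)}   # tuple form: one entry per dimension
        {"x": {0: Dim("dx"), 2: 32}}   # dict form: only the constrained indices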
| """ |
| from torch._dynamo.exc import UserError, UserErrorType |
| from torch._export.non_strict_utils import _flatten_dynamic_shapes |
| |
| if dynamic_shapes is None or len(dynamic_shapes) == 0: |
| return |
| if isinstance(dynamic_shapes, (tuple, list)): |
| combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] |
| |
| bounds: Dict[str, Tuple[int, int]] = {} |
| |
| def check_same_bounds(dim): |
| if dim.__name__ in bounds: |
| min_, max_ = bounds[dim.__name__] |
| if dim.min != min_ or dim.max != max_: |
| this_ = _Dim.readable(dim.__name__, min_, max_) |
| that_ = _Dim.readable(dim.__name__, dim.min, dim.max) |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Found different definitions {this_} and {that_} " |
| f"for the same symbolic dimension {dim}!", |
| ) |
| else: |
| bounds[dim.__name__] = (dim.min, dim.max) |
| |
| def check_symbols(path, tensor, shape): |
| if isinstance(shape, dict): |
| for i, dim in shape.items(): |
| if isinstance(dim, _Dim): |
| check_same_bounds(dim) |
| elif not (isinstance(dim, (int, DIM)) or dim is None): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " |
| f"specified at `dynamic_shapes{keystr(path)}` " |
| f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", |
| case_name="dynamic_shapes_validation", |
| ) |
| elif isinstance(shape, (tuple, list)): |
| for i, dim in enumerate(shape): |
| if isinstance(dim, _Dim): |
| check_same_bounds(dim) |
| elif not (isinstance(dim, (int, DIM)) or dim is None): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected dimension #{i} in input tensor shape {shape} " |
| f"specified at `dynamic_shapes{keystr(path)}` " |
| f"(expected None, an int, a Dim, DIM.AUTO, or DIM.STATIC, but got {dim} instead)", |
| case_name="dynamic_shapes_validation", |
| ) |
| elif shape is not None: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " |
| f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," |
| f" where each dimension is None, an int, a Dim, DIM.AUTO, or DIM.STATIC)", |
| case_name="dynamic_shapes_validation", |
| ) |
| |
| assert isinstance(dynamic_shapes, (dict, tuple, list)) |
| if isinstance(dynamic_shapes, dict): |
| got_keys = list(dynamic_shapes.keys()) |
| expected_arg_names = list(combined_args.keys()) |
| if sorted(got_keys) != sorted(expected_arg_names): |
| msg = ( |
| f"When `dynamic_shapes` is specified as a dict, its top-level keys " |
| f"must be the arg names {expected_arg_names} of `inputs`, but " |
| f"here they are {got_keys}. " |
| ) |
| if ( |
| len(combined_args) == 1 |
| and expected_arg_names[0] not in got_keys |
| and isinstance(combined_args[expected_arg_names[0]], dict) |
| ): |
| msg += ( |
| "Since here `inputs` is a list/tuple enclosing a single dict, " |
| "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" |
| ) |
| else: |
| msg += ( |
| "Alternatively, you could also ignore arg names entirely " |
| "and specify `dynamic_shapes` as a list/tuple matching `inputs`." |
| ) |
| raise UserError( |
| UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation" |
| ) |
| |
| def check_shape(path, t, dynamic_shape): |
| if isinstance(t, torch.Tensor): |
| check_symbols(path, t, dynamic_shape) |
| else: |
| if dynamic_shape is not None: |
| rendered_path = keystr(path) |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` " |
| f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)", |
| case_name="dynamic_shapes_validation", |
| ) |
| |
| _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") |
| |
    # raise an error if both DIM.AUTO and Dim/DerivedDim are specified in dynamic_shapes
| flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) |
| flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) |
| if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( |
| s == DIM.AUTO for s in flatter_dynamic_shapes |
| ): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| "Specifying both `DIM.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " |
| "and can easily lead to constraint violation errors or obscure errors in torch.export. Dim/DerivedDims " |
| "expect all equal or related dimensions to be specified, and does not yet compose well with `DIM.AUTO`. " |
| "We suggest using `DIM.AUTO` mixed with `None` for auto-dynamic + static shapes, plus torch._check(dim >= min), " |
| "torch._check(dim <= max) calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `None` " |
| "if you want to assert on the exact specification of your program's dynamic shapes behavior.", |
| case_name="dynamic_shapes_validation", |
| ) |
| |
| |
| def _transform_shapes_for_default_dynamic( |
| combined_args: Dict[str, Any], |
| dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], |
| ) -> Union[Dict[str, Any], Tuple[Any], List[Any], None]: |
| """ |
| In the long run this might not be needed, but this exists because export.export() and _dynamo.export() |
| historically have different semantics for how dynamic_shapes are specified, but go through the same |
| process of producing constraints, and now both use assume_static_by_default=False. |
| |
| For _dynamo.export(), the semantics for dynamic_shapes are: |
| - None: dynamic, allocated a symbol |
| - Dim/DerivedDim: a strict assertion on the min/max range for this symbol, and require a specification |
| for all dims governed by this symbol (i.e. relations, equality, linear relations, etc.) |
| |
| For export.export(), historically dynamism for unspecified dims has been undesirable, so the semantics are: |
| - DIM.AUTO: dynamic, allocated a symbol |
| - None/unspecified/DIM.STATIC: static |
| - Dim/DerivedDims: also a strict assertion |
| |
| To allow both APIs to follow the same process for producing constraints, this function converts dynamic_shapes |
| for export.export() to be compatible with _process_dynamic_shapes() and assume_static_by_default=False, turning them |
| into essentially what they'd look like for _dynamo.export(). |
| |
| An example conversion might look like, for a 3-d input tensor: |
| |
| input spec: { |
| 0: DIM.AUTO, |
| 1: None, # or DIM.STATIC |
| 2: Dim("dx"), |
| } |
| output spec: { |
| 0: None, # None: dynamic by default |
| 1: 32, # explicitly provide static shape |
| 2: Dim("dx"), # remains the same |
| } |
| """ |
| |
| def _tree_map_helper(tree, val): |
| """ |
| If the user generally specifies dynamic_shapes=None for a pytree input, |
| we'd like to convert this into a tree of Nones following the input spec, |
| so we can explicitly specify static dims for all tensor dimensions. |
        Non-builtin pytree types (e.g. custom dataclasses) create some difficulty;
        in that case the correct format is a list containing specs for each child attribute.
| """ |
| if (node_type := _get_node_type(tree)) not in SUPPORTED_NODES: # is_leaf |
| return val |
| flatten_fn = SUPPORTED_NODES[node_type].flatten_fn |
| child_pytrees, context = flatten_fn(tree) # flatten from whatever original type |
| unflatten_fn = SUPPORTED_NODES[ |
| node_type if node_type in BUILTIN_TYPES else list |
| ].unflatten_fn |
| children = [_tree_map_helper(child, val) for child in child_pytrees] |
| return unflatten_fn( |
| children, context |
| ) # unflatten into original type, or list if not built-in type |
| |
| if ( |
| dynamic_shapes is None or len(dynamic_shapes) == 0 |
| ): # create pytree structure of static dim |
| dynamic_shapes = _tree_map_helper(combined_args, None) |
| if isinstance(dynamic_shapes, (tuple, list)): |
| combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] |
| |
| def transform_shapes(path, tensor, shape): |
| def _marked_dynamic(tensor, i): |
| # TODO(pianpwk): deprecate mark_dynamic() usage for export |
| return i in getattr(tensor, "_dynamo_dynamic_indices", set()) |
| |
| out: Union[None, List[Any], Dict[int, Any]] = None |
| if isinstance(shape, dict): |
| out = {} |
| for i, val in enumerate(tensor.shape): |
| dim = shape.get(i, None) |
| if _marked_dynamic(tensor, i) or dim == DIM.AUTO: |
| # don't have to specify anything if dynamic |
| # None also works, since assume_static_by_default=False |
| continue |
| elif isinstance(dim, _Dim): |
| out[i] = dim |
| elif isinstance(dim, int): |
| # important that this is dim and not val, |
| # so we can raise error if user-specified dim != val |
| out[i] = dim |
| else: |
| # make explicitly static |
| assert dim is None or dim == DIM.STATIC |
| out[i] = val |
| elif isinstance(shape, (tuple, list)): |
| out = [] |
| for i, val in enumerate(tensor.shape): |
| dim = shape[i] |
| if _marked_dynamic(tensor, i) or dim == DIM.AUTO: |
| out.append(None) |
| elif isinstance(dim, _Dim): |
| out.append(dim) |
| elif isinstance(dim, int): |
| out.append(dim) |
| else: |
| assert dim is None or dim == DIM.STATIC |
| out.append(val) |
| out = type(shape)(out) # type: ignore[assignment] |
| else: |
| assert shape is None |
| if isinstance(tensor, torch.Tensor): |
| out = [] |
| for i, val in enumerate(tensor.shape): |
| out.append(None if _marked_dynamic(tensor, i) else val) |
| out = out or None |
| else: |
| out = None |
| return out |
| |
| def transform_shape(path, t, dynamic_shape): |
| if isinstance(t, torch.Tensor): |
| return transform_shapes(path, t, dynamic_shape) |
| |
| result = _tree_map_with_path( |
| transform_shape, combined_args, dynamic_shapes, tree_name="inputs" |
| ) |
| return result |
| |
| |
| def _process_dynamic_shapes( |
| combined_args: Dict[str, Any], |
| dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None], |
| ) -> List[Constraint]: |
| """ |
| Reads the dynamic_shapes specification and produces a list of constraints. |
| """ |
| from torch._dynamo.exc import UserError, UserErrorType |
| |
| if dynamic_shapes is None or len(dynamic_shapes) == 0: |
| # we run with dynamic by default, so no need to produce constraints |
| return [] |
| if isinstance(dynamic_shapes, (tuple, list)): |
| combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] |
| |
| # map of Dim names representing input shape dimensions to constraints on them |
| symbols: Dict[str, List[Constraint]] = defaultdict(list) |
| # track roots that do not directly represent input shape dimensions |
| phantom_roots: Dict[str, _PhantomRoot] = {} |
| derived_constraints_with_phantom_root: List[_DerivedConstraint] = [] |
| |
| def to_constraint(dim, tensor, i): |
| import sympy |
| |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| from torch.utils._sympy.solve import try_solve |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| def root_value(): |
| # given tensor.shape[i] is the value of dim = fn(root), |
| # find the value of root |
| symbol = sympy.Symbol(dim.root.__name__, integer=True) |
| expr = dim.fn(symbol) |
| solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) |
| if solution is not None: |
| return int(solution[1]) # type: ignore[call-overload] |
| else: |
| raise UserError( # noqa: B904 |
| UserErrorType.CONSTRAINT_VIOLATION, |
| f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " |
| f"of the form {expr}, where {symbol} is an integer", |
| ) |
| |
| if isinstance(dim, _DerivedDim): |
| # generate a _DerivedConstraint where the root is: |
| # - either a _ConstraintTarget (if dim.root directly describes an input shape) |
| # - or a _PhantomRoot (otherwise) |
| dim_root = dim.root # type: ignore[attr-defined] |
| if dim_root.__name__ in symbols: |
| # root represents an input shape dimension |
| root_constraint = symbols[dim_root.__name__][0] |
| root = _ConstraintTarget( |
| root_constraint.t_id, |
| root_constraint.dim, |
| ) |
| elif dim_root.__name__ not in phantom_roots: |
| # create a phantom root |
| root = _PhantomRoot( # type: ignore[assignment] |
| name=dim_root.__name__, |
| constraint_range=StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), |
| warn_only=False, |
| ), |
| val=root_value(), |
| ) |
| phantom_roots[dim_root.__name__] = root # type: ignore[assignment] |
| else: |
| root = phantom_roots[dim_root.__name__] # type: ignore[assignment] |
| constraint = _DerivedConstraint( |
| id(tensor), |
| i, |
| dim.__name__, |
| StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim.min, upper=dim.max), |
| warn_only=False, |
| ), |
| root, |
| dim.fn, # type: ignore[attr-defined] |
| ) |
| if isinstance(root, _PhantomRoot): |
| # NOTE(avik): since we have not processed all inputs yet, we may replace this |
| # with a root that does represent an input shape dimension later (see below) |
| derived_constraints_with_phantom_root.append(constraint) |
| elif isinstance(dim, _StaticDim): |
| constraint = _Constraint( # type: ignore[assignment] |
| id(tensor), |
| i, |
| dim.__name__, |
| StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] |
| ), |
| ) |
| else: |
| constraint = _Constraint( # type: ignore[assignment] |
| id(tensor), |
| i, |
| dim.__name__, |
| StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined] |
| ), |
| ) |
| return constraint |
| |
| def update_symbols(path, tensor, shape): |
| def _create_static_dim(tensor, i, value): |
| return _StaticDim(str(value), (int,), {"value": value}) |
| |
| if isinstance(shape, dict): |
| for i, dim in shape.items(): |
| if isinstance(dim, (int, _Dim)): |
| if isinstance(dim, int): |
| dim = _create_static_dim(tensor, i, dim) |
| constraint = to_constraint(dim, tensor, i) |
| symbols[dim.__name__].append(constraint) |
| elif isinstance(shape, (tuple, list)): |
| for i, dim in enumerate(shape): |
| if isinstance(dim, (int, _Dim)): |
| if isinstance(dim, int): |
| dim = _create_static_dim(tensor, i, dim) |
| constraint = to_constraint(dim, tensor, i) |
| symbols[dim.__name__].append(constraint) |
| |
| def assoc_shape(path, t, dynamic_shape): |
| if isinstance(t, torch.Tensor): |
| update_symbols(path, t, dynamic_shape) |
| |
| _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs") |
| |
| constraints = [] |
| for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: |
| phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] |
| if phantom_root_name in symbols: |
| # We found an input shape dimension corresponding to this name, so we |
| # do not need a phantom symbol for it after all. |
| # NOTE(avik): Overall we want to maintain the invariant that roots that |
| # are phantom symbols are really "phantom," i.e., they cannot be represented |
| # by any input source. This is important when we are deciding derived equalities, |
| # since we can focus our attention exclusively on input sources: deciding |
| # derived equalities involving phantom symbols are, in comparison, trivial. |
| derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] |
| |
| for dynamic_dims in symbols.values(): |
| constraints.extend(dynamic_dims) |
| |
| return constraints # type: ignore[return-value] |
| |
| |
| def _get_dim_name_mapping( |
| dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None] |
| ): |
| name_to_dim = {} |
| for dim in tree_flatten( |
| dynamic_shapes, |
| is_leaf=lambda x: isinstance(x, _Dim), |
| )[0]: |
| if isinstance(dim, (int, DIM)) or dim is None: |
| continue |
| name_to_dim[dim.__name__] = dim |
| if isinstance(dim, _DerivedDim): |
| name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] |
| return name_to_dim |
| |
| |
| def refine_dynamic_shapes_from_suggested_fixes( |
| msg: str, |
| dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]], |
| ) -> Union[Dict[str, Any], Tuple[Any], List[Any]]: |
| """ |
    Refines the given dynamic_shapes spec, given a ConstraintViolation error message
    and the original dynamic_shapes spec. Intended for working with export's suggested
    fixes and/or automatic dynamic shapes.
| |
    In most cases the behavior is straightforward: for suggested fixes that specialize or refine a Dim's range,
    or that suggest a derived relation, the new dynamic shapes spec is updated accordingly.
| |
| e.g. |
| Suggested fixes: |
| |
| dim = Dim('dim', min=3, max=6) -> this just refines the dim's range |
| dim = 4 -> this specializes to a constant |
| dy = dx + 1 -> dy was specified as an independent dim, but is actually tied to dx with this relation |
| |
| However, suggested fixes associated with derived dims can be more complicated. |
| For example, if a suggested fix is provided for a root dim, the new derived dim value is evaluated based on the root. |
| |
| e.g. |
| dx = Dim('dx') |
| dy = dx + 2 |
| dynamic_shapes = {"x": (dx,), "y": (dy,)} |
| |
| Suggested fixes: |
| |
| dx = 4 # specialization will lead to dy also specializing = 6 |
| dx = Dim('dx', max=6) # dy now has max = 8 |
| |
    Suggested fixes for derived dims can also be used to express divisibility constraints.
| This involves creating new root dims that aren't tied to a particular input shape. |
| In this case the root dims won't appear directly in the new spec, but as a root of |
| one of the dims. |
| |
| e.g. |
| Suggested fixes: |
| |
| _dx = Dim('_dx', max=1024) # this won't appear in the return result, but dx will |
| dx = 4*_dx # dx is now divisible by 4, with a max value of 4096 |
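
    A typical (illustrative) usage pattern is to retry export with the refined spec::

        try:
            ep = torch.export.export(mod, args, dynamic_shapes=dynamic_shapes)
        except torch._dynamo.exc.UserError as exc:
            new_shapes = refine_dynamic_shapes_from_suggested_fixes(
                exc.msg, dynamic_shapes
            )
            ep = torch.export.export(mod, args, dynamic_shapes=new_shapes)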
| """ |
| |
| import re |
| |
| import sympy |
| |
| from torch._dynamo.exc import UserError, UserErrorType |
| from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence |
| |
| try: |
| shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() |
| except Exception as exc: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", |
| ) from exc |
| |
| # build shape_fixes dictionary |
| shape_fixes = {} |
| for fix in shape_fixes_msg.split("\n"): |
| fix = fix.strip() |
| if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): |
| name = match.group(1) |
| _min, _max = None, None |
| if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): |
| _min = int(match_min.group(1)) |
| if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): |
| _max = int(match_max.group(1)) |
| shape_fixes[name] = Dim(name, min=_min, max=_max) |
| else: |
| name, expr = fix.split(" = ") |
| expr = sympy.sympify(expr) |
| if isinstance(expr, sympy.Number): |
| shape_fixes[name] = int(expr) # static, integer |
| else: |
| shape_fixes[name] = expr # relation or derived dim |
| |
| name_to_dim = _get_dim_name_mapping(dynamic_shapes) |
| |
| # track derived dim roots |
| roots: Set[str] = set() |
| for k, c in shape_fixes.items(): |
| assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr)) |
| if isinstance(c, sympy.Expr): # check dim/derived dim expression |
| assert _is_supported_equivalence(c) |
| shape_fixes[k] = c |
| roots.add(str(next(iter(c.free_symbols)))) |
| if isinstance(c, _DerivedDim): |
| roots.add(c.root.__name__) # type: ignore[attr-defined] |
| |
| # check keys are existing dims or new roots |
| for k, c in shape_fixes.items(): |
| assert k in name_to_dim or k in roots |
| |
| # cache so we don't produce multiple derived dim objects |
| derived_dim_cache: Dict[str, _DerivedDim] = {} |
| |
| def apply_fixes(path, dim, dummy): |
| if dim is None or isinstance(dim, int): # not dynamic |
| return dim |
| elif dim.__name__ in shape_fixes: # directly fix |
| fix = shape_fixes[dim.__name__] |
| if isinstance(fix, sympy.Expr): # now derived or related |
| if str(fix) in derived_dim_cache: |
| return derived_dim_cache[str(fix)] |
| else: |
| symbol = next(iter(fix.free_symbols)) |
| # try to locate symbol |
| if symbol.name in shape_fixes: # type: ignore[attr-defined] |
| root = shape_fixes[symbol.name] # type: ignore[attr-defined] |
| else: |
| assert symbol.name in name_to_dim # type: ignore[attr-defined] |
| root = name_to_dim[symbol.name] # type: ignore[attr-defined] |
| # figure out value of fix |
| modulus, remainder = sympy.polys.polytools.div(fix, symbol) |
| dim = root |
| if modulus != 1: |
| dim = int(modulus) * dim |
| if remainder != 0: |
| dim = dim + int(remainder) |
| derived_dim_cache[str(fix)] = dim |
| return dim |
| else: |
| return fix |
| elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] |
| if dim.__name__ in derived_dim_cache: |
| return derived_dim_cache[dim.__name__] |
| else: # evaluate new derived value based on root |
| _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] |
| derived_dim_cache[dim.__name__] = _dim |
| return _dim |
| return dim # unchanged dim |
| |
| return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes) |