| import builtins |
| import dataclasses |
| import inspect |
| import math |
| import sys |
| import weakref |
| from collections import defaultdict |
| from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union |
| |
| import torch |
| from torch._subclasses.fake_tensor import FakeTensor |
| from torch.utils._pytree import SUPPORTED_NODES |
| |
| from .exported_program import ExportedProgram |
| |
| if TYPE_CHECKING: |
| from sympy import Symbol |
| |
| from torch._guards import Source |
| |
| from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint |
| |
| __all__ = ["Constraint", "Dim", "dims", "dynamic_dim"] |
| |
| |
| class _Dim(type): |
| """ |
| Metaclass for :func:`Dim` types. |
| """ |
| |
| @staticmethod |
| def readable(name, min_, max_): |
| if min_ == 2: |
| min_ = None |
| if max_ == sys.maxsize - 1: |
| max_ = None |
| if min_ is None and max_ is None: |
| return f"Dim('{name}')" |
| if min_ is None: |
| return f"Dim('{name}', max={max_})" |
| if max_ is None: |
| return f"Dim('{name}', min={min_})" |
| return f"Dim('{name}', min={min_}, max={max_})" |
| |
| def __add__(cls, other): |
| # e.g., dim + 1 |
| if type(other) is not int: |
| raise NotImplementedError( |
| f"Attempted to add {other} to {cls.__name__}, where an integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x + other) |
| |
| def __radd__(cls, other): |
| return cls + other |
| |
| def __sub__(cls, other): |
| # e.g., dim - 1 |
| if type(other) is not int: |
| raise NotImplementedError( |
| f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x - other) |
| |
| def __rsub__(cls, other): |
| raise NotImplementedError( |
| f"Attempted to negate {cls.__name__}. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| |
| def __mul__(cls, other): |
| # e.g., dim * 2 |
| if type(other) is not int or other <= 0: |
| raise NotImplementedError( |
| f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. " |
| "(Only increasing linear operations with integer coefficients are supported.)" |
| ) |
| return cls._derive(lambda x: x * other) |
| |
| def __rmul__(cls, other): |
| return cls * other |
| |
| def _derived_name(cls, fn): |
| from sympy import sympify |
| |
| return str(fn(sympify(cls.__name__))) |
| |
| def _derive(cls, fn): |
| return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn}) |
| |
| |
| class _DerivedDim(_Dim): |
| """ |
| Metaclass for derived :func:`Dim` types. |
| |
| Currently we only support increasing linear expressions with integer coefficients. |
| In other words, a derived Dim can always be written in the form Ax + B, where |
| x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. |
| (In particular, the latter ensures that x < y => Ax + B < Ay + B.) |
| These restrictions on the form of derived Dims make the metatheory simpler: e.g., |
| it simplifies computing ranges for derived Dims, solving for underlying regular Dims, |
| deciding equalities between derived Dims, and so on. |
| |
| The function lambda x: Ax + B is expressed by `fn`, where x is a regular Dim called `root`. |
| The range of a derived Dim is computed by mapping `fn` over the range of its `root`. |
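| |
| For example (an illustrative sketch):: |
| |
| dim = Dim("dim", min=1, max=128) |
| derived = 2 * dim + 1   # _DerivedDim with root `dim` and fn: x -> 2*x + 1 |
| # derived.min == 3 and derived.max == 257, i.e., fn mapped over dim's range |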
| """ |
| |
| @property |
| def min(self): |
| # assume that self.fn is an increasing function |
| # TODO(avik): use sympy value range analysis instead? |
| from sympy import Integer |
| |
| _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] |
| assert _min_symint >= 0, ( |
| f"Expected derived min value of {self.__name__} to be >= 0. " |
| f"Please specify an appropriate min value for {self.root.__name__} " # type: ignore[attr-defined] |
| f"(currently {self.root.min})." # type: ignore[attr-defined] |
| ) |
| return int(_min_symint) |
| |
| @property |
| def max(self): |
| # assume that self.fn is an increasing function |
| # TODO(avik): use sympy value range analysis instead? |
| from sympy import Integer |
| |
| _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] |
| assert _max_symint <= sys.maxsize - 1, ( |
| f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " |
| f"Please specify an appropriate max value for {self.root.__name__} " # type: ignore[attr-defined] |
| f"(currently {self.root.max})." # type: ignore[attr-defined] |
| ) |
| return int(_max_symint) |
| |
| def _derive(self, fn): |
| # We support nesting, e.g., 2*dim + 1. |
| # This is implemented by composing operations on the same root. |
| # As a consequence, roots are always regular Dims (i.e., not derived Dims). |
| return _DerivedDim( |
| self._derived_name(fn), |
| (int,), |
| {"root": self.root, "fn": lambda x: fn(self.fn(x))}, # type: ignore[attr-defined] |
| ) |
| |
| |
| def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None): |
| """ |
| :func:`Dim` constructs a type analogous to a named symbolic integer with a range. |
| It can be used to describe multiple possible values of a dynamic tensor dimension. |
| Note that different dynamic dimensions of the same tensor, or of different tensors, |
| can be described by the same type. |
| |
| Args: |
| name (str): Human-readable name for debugging. |
| min (Optional[int]): Minimum possible value of given symbol (inclusive) |
| max (Optional[int]): Maximum possible value of given symbol (inclusive) |
| |
| Returns: |
| A type that can be used in dynamic shape specifications for tensors. |
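| |
| Example (a minimal sketch; the module, input names, and shapes are illustrative):: |
| |
| batch = Dim("batch", min=1, max=64) |
| # dimension 0 of the input named "x" may vary between 1 and 64 at runtime |
| ep = torch.export.export(mod, (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}) |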
| """ |
| _min = 0 if min is None else min |
| _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1) |
| assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" |
| dim = _Dim(name, (int,), {"min": _min, "max": _max}) |
| dim.__module__ = getattr( |
| inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__" |
| ) |
| return dim |
| |
| |
| def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None): |
| """ |
| Util to create multiple :func:`Dim` types. |
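| |
| For example (illustrative):: |
| |
| batch, seq = dims("batch", "seq", min=1, max=1024) |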
| """ |
| return tuple(Dim(name, min=min, max=max) for name in names) |
| |
| |
| @dataclasses.dataclass |
| class _ConstraintTarget: |
| """ |
| This represents input tensor dimensions. Don't create this |
| class directly; instead, use :func:`dynamic_dim`. |
| """ |
| |
| w_tensor: Any # weakref to torch.Tensor |
| # TODO: We don't need t_id; we can get it off of w_tensor |
| t_id: int |
| dim: int |
| |
| |
| class _ConstraintFactory(type): |
| """ |
| Metaclass that ensures a private constructor for :class:`_Constraint` |
| """ |
| |
| def __call__(cls, *args, **kwargs): |
| raise TypeError( |
| f"{cls.__module__}.{cls.__qualname__} has no public constructor. " |
| f"Please use torch.export.dynamic_dim() to create one" |
| ) |
| |
| def _create( |
| cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None |
| ): |
| return super().__call__( |
| w_tensor, t_id, dim, constraint_range, shared, debug_name |
| ) |
| |
| |
| def _create_constraint( |
| w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None |
| ): |
| return _Constraint._create( |
| w_tensor, t_id, dim, constraint_range, shared, debug_name |
| ) |
| |
| |
| @dataclasses.dataclass |
| class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): |
| """ |
| |
| .. warning:: |
| Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead. |
| |
| This represents constraints on input tensor dimensions, e.g., requiring |
| them to be fully polymorphic or within some range. |
| |
| """ |
| |
| # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>] |
| constraint_range: "StrictMinMaxConstraint" |
| # Represent that `constraint_range` is shared with another _ConstraintTarget, which |
| # typically arises because of a specified equality with another dynamic dimension. |
| shared: Optional[_ConstraintTarget] = None |
| debug_name: Optional[str] = None |
| |
| def _clone_with_range(self, lower=0, upper=math.inf): |
| # Import sympy locally |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| constraint_range = StrictMinMaxConstraint( |
| vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), |
| warn_only=False, |
| ) |
| return _create_constraint( |
| self.w_tensor, |
| self.t_id, |
| self.dim, |
| constraint_range, |
| self.shared, |
| self.debug_name, |
| ) |
| |
| def __ge__(self, lower): |
| return self._clone_with_range(lower=lower) |
| |
| def __gt__(self, lower): |
| return self._clone_with_range(lower=lower + 1) |
| |
| def __le__(self, upper): |
| return self._clone_with_range(upper=upper) |
| |
| def __lt__(self, upper): |
| return self._clone_with_range(upper=upper - 1) |
| |
| def __bool__(self): |
| # NOTE(avik): We do not support compound expressions like a <= x <= b. |
| # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), |
| # and moreover, enforces that any overload of __bool__ must return True or False. |
| # FWIW, sympy also raises TypeError in this case. |
| raise TypeError( |
| "Cannot determine truth value of _Constraint. " |
| "If you are trying to combine _Constraint's with logical connectives, " |
| "you can specify them separately instead." |
| ) |
| |
| @property |
| def serializable_spec(self): |
| # We need a serialization compatible format of the constraint so that it |
| # can be saved in the graph module w/o breaking the module serialization. |
| # The saved constraints will be used directly for the post-exporting pass |
| # that converts constraints to runtime assertions. The saved constraints |
| # will not be saved in the serialized module. |
| # TODO: A better way is needed. Currently we use 't_id' to map the constraint, |
| # which is not reliable |
| return { |
| "t_id": self.t_id, |
| "dim": self.dim, |
| "min": self.constraint_range.vr.lower, |
| "max": self.constraint_range.vr.upper, |
| } |
| |
| def __eq__(self, other): |
| if not isinstance(other, _Constraint): |
| raise TypeError( |
| "A dynamic dim can be specified equal only to another dynamic dim. " |
| f"Equality with {type(other)} is not supported." |
| ) |
| |
| # import sympy locally |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| |
| constraint_range = StrictMinMaxConstraint( |
| vr=self.constraint_range.vr & other.constraint_range.vr, |
| warn_only=False, |
| ) |
| if self.debug_name is None: |
| debug_name = other.debug_name |
| else: |
| assert other.debug_name is None or self.debug_name == other.debug_name |
| debug_name = self.debug_name |
| return _create_constraint( |
| self.w_tensor, |
| self.t_id, |
| self.dim, |
| constraint_range, |
| shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim), |
| debug_name=debug_name, |
| ) |
| |
| |
| @dataclasses.dataclass |
| class _PhantomRoot: |
| """ |
| This represents the root of a derived Dim where the root does not directly |
| specify the shape of any input dimension, but the derived Dim does. |
| |
| e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. |
| |
| The fields `name`, `constraint_range`, and `val` carried by a phantom root |
| help create a symbol for it. Any derived dims with this phantom root are |
| backed by expressions over this symbol. |
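| |
| For example (illustrative), with:: |
| |
| dim = Dim("dim") |
| dynamic_shapes = {"x": {0: 2 * dim}, "y": {0: dim + 1}} |
| |
| the root `dim` does not itself annotate any input dimension, so it is |
| represented by a phantom root when constraints are built. |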
| """ |
| |
| name: str |
| constraint_range: "StrictMinMaxConstraint" |
| val: int |
| |
| |
| @dataclasses.dataclass |
| class _DerivedConstraint(_ConstraintTarget): |
| """ |
| This represents a derived Dim, whose root is either a regular constraint target |
| (which directly specifies the shape of some input dimension) or a phantom root |
| (which does so indirectly). |
| """ |
| |
| # NOTE: This is not currently a subclass of _Constraint because we do not support |
| # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for |
| # legacy constraints based on `dynamic_dim`: equality can be expressed simply by |
| # reusing the same (derived or normal) `Dim`. |
| root: Union[_ConstraintTarget, _PhantomRoot] |
| fn: Callable |
| constraint_range: "StrictMinMaxConstraint" |
| debug_name: Optional[str] = None |
| |
| @property |
| def shared(self): |
| # Some code paths expect a union of _Constraint and _DerivedConstraint. |
| # Thus we expose a `shared` field that is always None. |
| # TODO(avik): clean this up |
| return None |
| |
| @property |
| def serializable_spec(self): |
| # same as _Constraint.serializable_spec |
| return { |
| "t_id": self.t_id, |
| "dim": self.dim, |
| "min": self.constraint_range.vr.lower, |
| "max": self.constraint_range.vr.upper, |
| } |
| |
| |
| Constraint = Union[_Constraint, _DerivedConstraint] |
| |
| |
| def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): |
| """ |
| .. warning:: |
| (This feature is DEPRECATED. See :func:`Dim` instead.) |
| |
| :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of |
| a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to |
| the ``constraints`` argument of :func:`export`. |
| |
| Args: |
| t (torch.Tensor): Example input tensor that has dynamic dimension size(s) |
| index (int): Index of dynamic dimension |
| |
| Returns: |
| A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so |
| that :func:`export` does not assume a static size for the specified tensor dimension, i.e., keeping it |
| dynamic as a symbolic size rather than specializing according to the size of the example tracing input. |
| |
| Specifically, :func:`dynamic_dim` can be used to express the following types of dynamism. |
| |
| - Size of a dimension is dynamic and unbounded:: |
| |
| t0 = torch.rand(2, 3) |
| t1 = torch.rand(3, 4) |
| |
| # First dimension of t0 can be dynamic size rather than always being static size 2 |
| constraints = [dynamic_dim(t0, 0)] |
| ep = export(fn, (t0, t1), constraints=constraints) |
| |
| - Size of a dimension is dynamic with a lower bound:: |
| |
| t0 = torch.rand(10, 3) |
| t1 = torch.rand(3, 4) |
| |
| # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive) |
| # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive) |
| constraints = [ |
| dynamic_dim(t0, 0) >= 5, |
| dynamic_dim(t1, 1) > 2, |
| ] |
| ep = export(fn, (t0, t1), constraints=constraints) |
| |
| - Size of a dimension is dynamic with an upper bound:: |
| |
| t0 = torch.rand(10, 3) |
| t1 = torch.rand(3, 4) |
| |
| # First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive) |
| # Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive) |
| constraints = [ |
| dynamic_dim(t0, 0) <= 16, |
| dynamic_dim(t1, 1) < 8, |
| ] |
| ep = export(fn, (t0, t1), constraints=constraints) |
| |
| - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension:: |
| |
| t0 = torch.rand(10, 3) |
| t1 = torch.rand(3, 4) |
| |
| # Sizes of second dimension of t0 and first dimension of t1 are always equal |
| constraints = [ |
| dynamic_dim(t0, 1) == dynamic_dim(t1, 0), |
| ] |
| ep = export(fn, (t0, t1), constraints=constraints) |
| |
| - Mix and match all types above as long as they do not express conflicting requirements |
| |
| """ |
| from torch._dynamo.exc import UserError, UserErrorType |
| |
| if not isinstance(t, torch.Tensor): |
| raise UserError( |
| UserErrorType.DYNAMIC_DIM, |
| f"Expected tensor as input to dynamic_dim but got {type(t)}", |
| ) |
| |
| if t.dim() < 1: |
| raise UserError( |
| UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic" |
| ) |
| |
| if index >= t.dim(): |
| raise UserError( |
| UserErrorType.DYNAMIC_DIM, |
| f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]" |
| f" but got {index}, which is out of bounds for the given tensor.", |
| ) |
| |
| # Import sympy locally |
| import sympy |
| |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| return _create_constraint( |
| weakref.ref(t), |
| id(t), |
| index, |
| StrictMinMaxConstraint( |
| vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False |
| ), |
| debug_name=debug_name, |
| ) |
| |
| |
| def _process_equalities( |
| constraint: Constraint, |
| get_sources: Callable[[int, int], List["Source"]], |
| shape_env: "ShapeEnv", |
| source_pairs: List[Tuple["Source", "Source"]], |
| derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]], |
| phantom_symbols: Dict[str, "Symbol"], |
| ): |
| """ |
| Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become |
| fields of `EqualityConstraint`) based on a given input `constraint`. |
| """ |
| |
| source, *other_sources = get_sources(constraint.t_id, constraint.dim) |
| # When t.size()[dim] maps to src0, src1, ..., srcN, we add |
| # constraints that make src0 "equal" to src1, ..., srcN. |
| source_pairs.extend((source, other_source) for other_source in other_sources) |
| if not isinstance(constraint, _DerivedConstraint): |
| if constraint.shared is not None: |
| # Moreover, when t.size()[dim] is specified equal to t'.size()[dim'] |
| # and t'.size()[dim'] maps to src1', ..., srcN', we add |
| # constraints that also make src0 "equal" to src1', ..., srcN'. |
| other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim) |
| source_pairs.extend( |
| (source, other_source) for other_source in other_sources |
| ) |
| else: |
| # branch based on the root of the _DerivedConstraint |
| if not isinstance(constraint.root, _PhantomRoot): |
| # either root points to an input source |
| root = get_sources(constraint.root.t_id, constraint.root.dim)[0] # type: ignore[assignment] |
| else: |
| # or root points to a phantom symbol |
| if constraint.root.name in phantom_symbols: |
| root = phantom_symbols[constraint.root.name] # type: ignore[assignment] |
| else: |
| # create a phantom symbol in the shape env based on the _PhantomRoot |
| root = shape_env.create_symbol( |
| val=constraint.root.val, |
| source=torch._dynamo.source.ConstantSource(constraint.root.name), |
| dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, |
| constraint_dim=constraint.root.constraint_range, |
| ) |
| phantom_symbols[constraint.root.name] = root # type: ignore[assignment] |
| |
| fn = constraint.fn |
| # A derived equality (source, root, fn) informally corresponds to source = fn(root). |
| # Here source describes an input and root might describe another input or a phantom symbol. |
| derived_equalities.append((source, root, fn)) |
| |
| |
| def _process_dynamic_shapes( |
| f: Callable, |
| args: Tuple[Any, ...], |
| kwargs: Optional[Dict[str, Any]] = None, |
| dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, |
| ) -> Optional[List[Constraint]]: |
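| # A sketch of the transformation done here (argument names and shapes are illustrative): |
| # given f(x) traced with x = torch.randn(4, 3) and |
| #     dynamic_shapes = {"x": {0: Dim("batch", min=1, max=64)}} |
| # this returns a list with a single _Constraint on dimension 0 of x carrying the |
| # range [1, 64]; derived Dims (e.g., 2*batch) produce _DerivedConstraints instead. |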
| from collections import defaultdict |
| from collections.abc import Mapping, Sequence |
| |
| from torch._dynamo.exc import UserError, UserErrorType |
| |
| if dynamic_shapes is None or len(dynamic_shapes) == 0: |
| return None |
| |
| kwargs = kwargs if kwargs is not None else {} |
| |
| def tree_zip(combined_args, dynamic_shapes): |
| if isinstance(combined_args, (tuple, list)): |
| if not isinstance(dynamic_shapes, Sequence): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, " |
| f"got {dynamic_shapes} instead", |
| ) |
| if len(combined_args) != len(dynamic_shapes): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected {dynamic_shapes} to have {len(combined_args)} items", |
| ) |
| for i, shape in enumerate(dynamic_shapes): |
| yield from tree_zip(combined_args[i], shape) |
| elif isinstance(combined_args, dict): |
| if not isinstance(dynamic_shapes, Mapping): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, " |
| f"got {dynamic_shapes} instead", |
| ) |
| if len(combined_args) != len(dynamic_shapes): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected {dynamic_shapes} to have {len(combined_args)} items", |
| ) |
| for k, shape in dynamic_shapes.items(): |
| yield from tree_zip(combined_args[k], shape) |
| elif type(combined_args) in SUPPORTED_NODES: |
| if not isinstance(dynamic_shapes, Sequence): |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected dynamic_shapes of a user-registered class (e.g., " |
| f"{type(combined_args)}) to be a Sequence that matches the " |
| f"flattened structure, but got {dynamic_shapes} instead", |
| ) |
| yield from tree_zip( |
| SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0], |
| dynamic_shapes, |
| ) |
| elif isinstance(combined_args, torch.Tensor): |
| yield (combined_args, dynamic_shapes) |
| else: |
| if dynamic_shapes is not None: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Expected dynamic_shapes of a {type(combined_args)} to be None, " |
| f"got {dynamic_shapes} instead", |
| ) |
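| |
| # Illustrative sketch of what tree_zip yields, assuming tensors t0, t1 and a |
| # matching dynamic_shapes spec: |
| #     combined_args  = {"x": t0, "y": [t1]} |
| #     dynamic_shapes = {"x": {0: batch}, "y": [{1: seq}]} |
| # yields (t0, {0: batch}) and then (t1, {1: seq}). |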
| |
| # map of Dim names representing input shape dimensions to constraints on them |
| symbols: Dict[str, List[Constraint]] = defaultdict(list) |
| # track roots that do not directly represent input shape dimensions |
| phantom_roots: Dict[str, _PhantomRoot] = {} |
| derived_constraints_with_phantom_root: List[_DerivedConstraint] = [] |
| |
| def to_constraint(dim, tensor, i): |
| import sympy |
| |
| from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint |
| from torch.utils._sympy.solve import try_solve |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| def root_value(): |
| # given tensor.shape[i] is the value of dim = fn(root), |
| # find the value of root |
| symbol = sympy.Symbol(dim.root.__name__, integer=True) |
| expr = dim.fn(symbol) |
| solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) |
| if solution is not None: |
| return int(solution[1]) # type: ignore[call-overload] |
| else: |
| raise UserError( # noqa: TRY200 |
| UserErrorType.CONSTRAINT_VIOLATION, |
| f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " |
| f"of the form {expr}, where {symbol} is an integer", |
| ) |
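| |
| # For example (illustrative): if dim = 2*x + 1 and tensor.shape[i] == 9, then |
| # root_value() solves 2*x + 1 == 9 and returns 4 as the value of the root x. |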
| |
| if isinstance(dim, _DerivedDim): |
| # generate a _DerivedConstraint where the root is: |
| # - either a _ConstraintTarget (if dim.root directly describes an input shape) |
| # - or a _PhantomRoot (otherwise) |
| dim_root = dim.root # type: ignore[attr-defined] |
| if dim_root.__name__ in symbols: |
| # root represents an input shape dimension |
| root_constraint = symbols[dim_root.__name__][0] |
| root = _ConstraintTarget( |
| root_constraint.w_tensor, |
| root_constraint.t_id, |
| root_constraint.dim, |
| ) |
| elif dim_root.__name__ not in phantom_roots: |
| # create a phantom root |
| root = _PhantomRoot( # type: ignore[assignment] |
| name=dim_root.__name__, |
| constraint_range=StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), |
| warn_only=False, |
| ), |
| val=root_value(), |
| ) |
| phantom_roots[dim_root.__name__] = root # type: ignore[assignment] |
| else: |
| root = phantom_roots[dim_root.__name__] # type: ignore[assignment] |
| constraint = _DerivedConstraint( |
| weakref.ref(tensor), |
| id(tensor), |
| i, |
| root, |
| dim.fn, # type: ignore[attr-defined] |
| StrictMinMaxConstraint( |
| vr=ValueRanges(lower=dim.min, upper=dim.max), |
| warn_only=False, |
| ), |
| debug_name=dim.__name__, |
| ) |
| if isinstance(root, _PhantomRoot): |
| # NOTE(avik): since we have not processed all inputs yet, we may replace this |
| # with a root that does represent an input shape dimension later (see below) |
| derived_constraints_with_phantom_root.append(constraint) |
| else: |
| constraint = dynamic_dim(tensor, i, debug_name=dim.__name__) |
| if dim.min != 0: |
| constraint = constraint >= dim.min |
| if dim.max != sys.maxsize - 1: |
| constraint = constraint <= dim.max |
| return constraint |
| |
| bounds: Dict[str, Tuple[int, int]] = {} |
| |
| def check_same_bounds(dim): |
| if dim.__name__ in symbols: |
| min_, max_ = bounds[dim.__name__] |
| if dim.min != min_ or dim.max != max_: |
| this_ = _Dim.readable(dim.__name__, min_, max_) |
| that_ = _Dim.readable(dim.__name__, dim.min, dim.max) |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Found different definitions {this_} and {that_} " |
| f"for the same symbolic dimension {dim}!", |
| ) |
| |
| else: |
| bounds[dim.__name__] = (dim.min, dim.max) |
| |
| def update_symbols(tensor, shape): |
| if isinstance(shape, dict): |
| for i, dim in shape.items(): |
| if isinstance(dim, _Dim): |
| check_same_bounds(dim) |
| constraint = to_constraint(dim, tensor, i) |
| symbols[dim.__name__].append(constraint) |
| else: |
| if dim is not None: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, " |
| "try None instead", |
| ) |
| elif isinstance(shape, (tuple, list)): |
| for i, dim in enumerate(shape): |
| if isinstance(dim, _Dim): |
| check_same_bounds(dim) |
| constraint = to_constraint(dim, tensor, i) |
| symbols[dim.__name__].append(constraint) |
| else: |
| if dim is not None: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, " |
| "try None instead", |
| ) |
| else: |
| if shape is not None: |
| raise UserError( |
| UserErrorType.INVALID_INPUT, |
| f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead", |
| ) |
| |
| import inspect |
| |
| if isinstance(f, ExportedProgram): |
| f = f.module() |
| signature = ( |
| inspect.signature(f.forward) |
| if isinstance(f, torch.nn.Module) |
| else inspect.signature(f) |
| ) |
| combined_args = signature.bind(*args, **kwargs).arguments |
| |
| # This means the user didn't specify dynamic shapes with argument names. |
| combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values()) # type: ignore[assignment] |
| for tensor, shape in tree_zip(combined_args, dynamic_shapes): |
| update_symbols(tensor, shape) |
| |
| constraints = [] |
| for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: |
| phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] |
| if phantom_root_name in symbols: |
| # We found an input shape dimension corresponding to this name, so we |
| # do not need a phantom symbol for it after all. |
| # NOTE(avik): Overall we want to maintain the invariant that roots that |
| # are phantom symbols are really "phantom," i.e., they cannot be represented |
| # by any input source. This is important when we are deciding derived equalities, |
| # since we can focus our attention exclusively on input sources: deciding |
| # derived equalities involving phantom symbols is, in comparison, trivial. |
| derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] |
| |
| for dynamic_dims in symbols.values(): |
| if all( |
| isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims |
| ): |
| constraints.extend(dynamic_dims) |
| else: |
| primary, *others = dynamic_dims |
| if others: |
| for other in others: |
| constraints.append(primary == other) # type: ignore[arg-type] |
| else: |
| constraints.append(primary) |
| |
| return constraints # type: ignore[return-value] |
| |
| |
| def _process_constraints( |
| fake_mode, |
| graph_module: torch.fx.GraphModule, |
| num_lifted_params_buffers: int, |
| example_inputs: List[torch.Tensor], |
| ) -> Dict: |
| """ |
| Process the constraints stored in the graph module to return something more readable. |
| |
| Args: |
| graph_module (torch.fx.GraphModule): GraphModule returned from |
| dynamo.export, which contains the "input_shape_constraints" and |
| "inline_constraints" metadata |
| |
| example_inputs: Flattened list of example inputs used to export the graph module |
| |
| Returns: |
| range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of |
| symbols (from SymInts) appearing in the fake tensors in |
| node.meta["val"] to their range constraints, which are a tuple |
| containing (lower, upper) constraints. |
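| |
| For example (illustrative), the returned mapping may look like:: |
| |
| {s0: ValueRanges(lower=5, upper=16)} |
| |
| where ``s0`` is the symbol backing a dynamic input dimension. |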
| """ |
| from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( |
| InputDim, |
| ) |
| |
| # Import sympy locally |
| from torch.fx.experimental.symbolic_shapes import SymInt |
| from torch.utils._sympy.value_ranges import ValueRanges |
| |
| input_shape_constraints = graph_module.meta.get("input_shape_constraints", []) |
| inline_constraints = graph_module.meta.get("inline_constraints", []) |
| |
| # Create dict mapping tensor_id to node names |
| tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list) |
| # Create dict mapping placeholder node names to their nodes |
| placeholder_nodes: Dict[str, torch.fx.Node] = {} |
| for i, node in enumerate(graph_module.graph.nodes): |
| if node.op != "placeholder": |
| # All placeholder nodes should be together at the beginning of the |
| # graph |
| break |
| if i >= num_lifted_params_buffers: |
| example_input = example_inputs[i - num_lifted_params_buffers] |
| tensor_id_to_nodes[id(example_input)].append(node.name) |
| placeholder_nodes[node.name] = node |
| |
| # Create dict mapping (node name, dim) to a list of range (lower, upper) |
| # constraints |
| multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list) |
| for constraint in input_shape_constraints: |
| for node in tensor_id_to_nodes[constraint["t_id"]]: |
| node_dim = InputDim(node, constraint["dim"]) |
| |
| # Accumulate range constraints |
| multi_range_constraints[node_dim].append( |
| ValueRanges(constraint["min"], constraint["max"]) |
| ) |
| |
| # Create dict mapping symbol to a singular range (lower, upper) |
| range_constraints: Dict[Any, ValueRanges] = {} |
| |
| # Add inline constraints to range_constraints |
| range_constraints = { |
| symbol: inline_constraints[symbol] for symbol in inline_constraints |
| } |
| |
| free_symbols: Set["Symbol"] = set() |
| # Add input range constraints to range_constraints |
| for input_dim, multi_range_constraint in multi_range_constraints.items(): # type: ignore[assignment] |
| # Simplify the range constraints into a single range constraint |
| # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10] |
| min_vals = [rc.lower for rc in multi_range_constraint] |
| max_vals = [rc.upper for rc in multi_range_constraint] |
| min_val = max(min_vals) # type: ignore[type-var] |
| max_val = min(max_vals) # type: ignore[type-var] |
| assert min_val <= max_val # type: ignore[operator] |
| |
| # Add input node range constraints |
| val = placeholder_nodes[input_dim.input_name].meta["val"] |
| assert isinstance(val, FakeTensor) |
| symint = val.shape[input_dim.dim] |
| assert isinstance( |
| symint, SymInt |
| ), f"Expected SymInt but got {symint}: {type(symint)}" |
| symbol = symint.node.expr |
| range_constraints[symbol] = ValueRanges(min_val, max_val) |
| free_symbols.update(symbol.free_symbols) |
| |
| for symbol in free_symbols: |
| if symbol not in range_constraints: |
| # Placeholders can have symbolic shapes that are derived expressions. |
| # The above code will record direct range constraints for them |
| # so that we can do runtime assertions. In addition, for serde checks |
| # we want to record range constraints for their root symbols. |
| range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol] |
| |
| return range_constraints |