| # mypy: allow-untyped-defs |
| """ |
| This file does three things: |
| - Contains the definition of SymNode |
- Installs all the magic methods into SymBool, SymInt, and SymFloat at import time
- Does not depend on SymPy at import time

Because this file is imported from within torch/__init__.py, we do not want it
to depend on SymPy at import time, as loading SymPy is *very* slow.
| """ |
| |
| import builtins |
| import itertools |
| import logging |
| import math |
| import operator |
| import sys |
| from functools import lru_cache, update_wrapper |
| from typing import Optional, Type, TYPE_CHECKING, Union |
| |
| import torch |
| |
| # NB: The sym_* functions are used via getattr() and must be imported here. |
| from torch import ( # noqa: F401 |
| sym_float, |
| sym_ite, |
| sym_max, |
| sym_min, |
| sym_not, |
| SymBool, |
| SymFloat, |
| SymInt, |
| ) |
| |
| from torch.fx.experimental._sym_dispatch_mode import ( |
| handle_sym_dispatch, |
| sym_function_mode, |
| ) |
| |
| if TYPE_CHECKING: |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv |
| |
| log = logging.getLogger(__name__) |
| sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") |
| |
| |
| __all__ = ["SymNode", "method_to_operator", "magic_methods"] |
| |
| |
| from torch.types import py_sym_types as SymTypes |
| |
| |
| def _to_symtype(t): |
| if t is bool: |
| return SymBool |
| if t is int: |
| return SymInt |
| if t is float: |
| return SymFloat |
| return t |
| |
| |
| # TODO: An incomplete list |
| # 1. Set variables to be equal when we do equality |
| # 2. Specialize on 0/1 when we do subtraction |
| class SymNode: |
| """ |
| This is a type erased SymInt/SymFloat which we use to do actual operations. |
| End users don't touch this. Magic methods are NOT defined on this object. |
| """ |
| |
| def __init__( |
| self, |
| expr, |
| shape_env, |
| pytype, |
| hint: Optional[Union[int, float, bool]], |
| constant=None, |
| fx_node=None, |
| ): |
| self._expr = expr |
| self.shape_env = shape_env |
| self.pytype = pytype |
| |
| # What's the difference between hint and constant? |
| # |
| # - A constant is known to be invariant across invocations of the model; |
| # it will always be this value. We only really know this when we |
| # encounter an honest-to-goodness literal (when wrapping it into |
| # a SymNode, we set constant.) Most of the time, constant is None |
| # |
| # - A hint is a *particular* value from the particular run we are |
| # tracing, but it may vary the next time around. It's useful to |
| # keep this around, as if we need a concrete value from a SymNode, |
| # we will return the hint and guard on the expression that produced |
| # it giving the same hint next time around. The hint is not |
| # guaranteed to be set either: if you have an unbacked SymNode, |
| # there won't be any hint; it was the result of some tensor-dependent |
| # computation, but we don't know what it actually is because we |
| # haven't actually run the tensor computation. |
| # |
| # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) |
| # in hopes that we've learned enough about the unbacked symints to |
| # discharge the hint; otherwise, you're likely to just error out. |
| # |
| # (A previous version of this system had some optimizations to only |
| # recompute when it was possible we had learned enough about the |
| # unbacked symint that a hint was now possible, but as we added more |
| # potential refinements to unbacked symints this got harder to keep |
| # in sync, so we've deleted it for now.) |
| |
| def compute_hint(): |
| # This occasionally gets exercised by, e.g., |
| # convert_shape_to_symint. It's just a nicety so you don't HAVE |
| # to have a correct hint on hand when making a SymNode. |
| hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) |
| if hint is not None: |
| hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint |
| return hint |
| |
| if hint is not None: |
| assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( |
| "Cannot create SymNode of type " |
| f"{pytype} with incompatible hint of type {type(hint)}" |
| ) |
| if self.shape_env and self.shape_env._translation_validation_enabled: |
| # This is technically not TV, but this assert is expensive so |
| # let's only do it when we're already doing expensive things |
| computed_hint = compute_hint() |
| assert ( |
| hint == computed_hint |
| ), f"{hint} != {computed_hint} (for {self.expr})" |
| else: |
| hint = compute_hint() |
| self._hint = hint |
| self.constant: Optional[Union[int, float, bool]] = constant |
| |
| # Record the FX node of the current node if we are doing translation |
| # validation. They will be used for building the input assertions for |
| # the translation validation problem. |
| tx_validation_en = ( |
| self.shape_env and self.shape_env._translation_validation_enabled |
| ) |
| self.fx_node = tx_validation_en and fx_node |
| |
| def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": |
| return SymNode( |
| self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node |
| ) |
| |
| def _value_eq(self, other: "SymNode") -> bool: |
| # Purposely don't include the shape_env in the eq. |
| return ( |
| self._expr == other._expr |
| and self.pytype == other.pytype |
| and self._hint == other._hint |
| and self.constant == other.constant |
| and self.fx_node == other.fx_node |
| ) |
| |
| def _value_hash(self) -> int: |
| # Purposely don't include the shape_env in the hash. |
| return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) |
| |
| @property |
| def expr(self): |
| return self.shape_env.replace(self._expr) |
| |
| @property |
| def hint(self): |
| return self._hint |
| |
| def has_hint(self): |
| return self._hint is not None |
| |
| def require_hint(self, fallback=None): |
| if self._hint is None: |
| if fallback is not None: |
| return fallback |
| # NB: we expect this to raise |
| return self.shape_env.size_hint(self.expr) |
| return self._hint |
| |
| def maybe_as_int(self): |
| if self.expr.is_number: |
| return int(self.expr) |
| else: |
| return None |
| |
| # NB: This does conversions, not sure if this is good or not |
| def maybe_as_float(self): |
| import sympy |
| |
| if isinstance(self.expr, sympy.Float): |
| return float(self.expr) |
| else: |
| return None |
| |
| def maybe_as_bool(self): |
| import sympy |
| |
| if self.expr is sympy.true: |
| return True |
| elif self.expr is sympy.false: |
| return False |
| else: |
| return None |
| |
| def is_int(self): |
| return self.pytype is int |
| |
| def is_float(self): |
| return self.pytype is float |
| |
| def is_bool(self): |
| return self.pytype is bool |
| |
| def is_nested_int(self): |
| # Unbacked SymInts cannot be nested int today |
| return ( |
| self._hint is not None |
| and isinstance(self._hint, SymInt) |
| and self._hint.node.is_nested_int() |
| ) |
| |
| def wrap_int(self, num): |
| assert type(num) is int |
| import sympy |
| |
| return SymNode( |
| sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num |
| ) |
| |
| def wrap_float(self, num): |
| assert type(num) is float |
| import sympy |
| |
| return SymNode( |
| sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num |
| ) |
| |
| def wrap_bool(self, num): |
| assert type(num) is bool |
| import sympy |
| |
| return SymNode( |
| sympy.true if num else sympy.false, |
| self.shape_env, |
| bool, |
| num, |
| constant=num, |
| fx_node=num, |
| ) |
| |
| def clone(self): |
| return self |
| |
| def str(self): |
| return f"{self.expr}" |
| |
| def __str__(self): |
| return self.str() |
| |
| def __repr__(self): |
| rep = [ |
| f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", |
| ] |
| if self._hint is not None: |
| rep.append(f"hint={self._hint}") |
| if self.constant is not None: |
| rep.append(f"constant={self.constant}") |
| if self.fx_node is not None: |
| rep.append(f"fx_node={self.fx_node}") |
| return ", ".join(rep) + ")" |
| |
| def _graph_repr(self) -> builtins.str: |
| # Representation used by GraphModule to create a pythonic version of a graph |
| return self.str() |
| |
| # These methods call the metaprogrammed methods, they're hand written |
| # here so we get good stack traces |
| def abs(self) -> "SymNode": |
| return self._abs() # type: ignore[attr-defined] |
| |
| def pos(self) -> "SymNode": |
| return self._pos() # type: ignore[attr-defined] |
| |
| def round(self, ndigits=None) -> "SymNode": |
| return self._round(ndigits) # type: ignore[attr-defined] |
| |
| def trunc(self) -> "SymNode": |
| return self._trunc() # type: ignore[attr-defined] |
| |
| def add(self, other) -> "SymNode": |
| return self._add(other) # type: ignore[attr-defined] |
| |
| def sub(self, other) -> "SymNode": |
| return self._sub(other) # type: ignore[attr-defined] |
| |
| def mul(self, other) -> "SymNode": |
| return self._mul(other) # type: ignore[attr-defined] |
| |
| def mod(self, other) -> "SymNode": |
| return self._mod(other) # type: ignore[attr-defined] |
| |
| def float_pow(self, other) -> "SymNode": |
| return self._float_pow(other) # type: ignore[attr-defined] |
| |
| def pow_by_natural(self, other) -> "SymNode": |
| return self._pow_by_natural(other) # type: ignore[attr-defined] |
| |
| def and_(self, other) -> "SymNode": |
| return self._and_(other) # type: ignore[attr-defined] |
| |
| def or_(self, other) -> "SymNode": |
| return self._or_(other) # type: ignore[attr-defined] |
| |
| def float_truediv(self, other) -> "SymNode": |
| return self._float_truediv(other) # type: ignore[attr-defined] |
| |
| def int_truediv(self, other) -> "SymNode": |
| return self._int_truediv(other) # type: ignore[attr-defined] |
| |
| def int_floordiv(self, other) -> "SymNode": |
| return self._int_floordiv(other) # type: ignore[attr-defined] |
| |
| def lshift(self, other) -> "SymNode": |
| return self._lshift(other) # type: ignore[attr-defined] |
| |
| def rshift(self, other) -> "SymNode": |
| return self._rshift(other) # type: ignore[attr-defined] |
| |
| def sym_not(self) -> "SymNode": # noqa: F811 |
| return self._sym_not() # type: ignore[attr-defined] |
| |
| def eq(self, other) -> "SymNode": |
| return self._eq(other) # type: ignore[attr-defined] |
| |
| def ne(self, other) -> "SymNode": |
| return self._ne(other) # type: ignore[attr-defined] |
| |
| def gt(self, other) -> "SymNode": |
| return self._gt(other) # type: ignore[attr-defined] |
| |
| def lt(self, other) -> "SymNode": |
| return self._lt(other) # type: ignore[attr-defined] |
| |
| def le(self, other) -> "SymNode": |
| return self._le(other) # type: ignore[attr-defined] |
| |
| def ge(self, other) -> "SymNode": |
| return self._ge(other) # type: ignore[attr-defined] |
| |
| def floor(self) -> "SymNode": |
| return self._floor() # type: ignore[attr-defined] |
| |
| def is_integer(self) -> "SymNode": |
| return self._is_integer() # type: ignore[attr-defined] |
| |
| def sym_float(self) -> "SymNode": # noqa: F811 |
| return self._sym_float() # type: ignore[attr-defined] |
| |
| def sym_int(self) -> "SymNode": |
| return self._sym_int() # type: ignore[attr-defined] |
| |
| def ceil(self) -> "SymNode": |
| return self._ceil() # type: ignore[attr-defined] |
| |
| def neg(self) -> "SymNode": |
| return self._neg() # type: ignore[attr-defined] |
| |
| def sym_min(self, other) -> "SymNode": # noqa: F811 |
| return self._sym_min(other) # type: ignore[attr-defined] |
| |
| def sym_max(self, other) -> "SymNode": # noqa: F811 |
| return self._sym_max(other) # type: ignore[attr-defined] |
| |
| def sym_ite(self, then_val, else_val) -> "SymNode": |
| return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] |
| |
| def is_contiguous(self, sizes, strides) -> "SymNode": |
| return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] |
| |
| def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": |
| return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] |
| |
| def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": |
| return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] |
| |
| def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": |
| return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] |
| |
| def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": |
| return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] |
| |
| def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": |
| return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] |
| |
| # Make C++ happy |
| def sym_or(self, other): |
| return self.or_(other) |
| |
| def sym_and(self, other): |
| return self.and_(other) |
| |
| # There is no int_truediv available from C++ |
| def truediv(self, other): |
| return self.float_truediv(other) |
| |
| def floordiv(self, other) -> "SymNode": |
| return self.int_floordiv(other) |
| |
| # We didn't bind integer pow in C++ |
| def pow(self, other): |
| return self.float_pow(other) |
| |
| def is_non_overlapping_and_dense(self, sizes, strides): |
| return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] |
| |
| def int_(self): |
| return self.guard_int("", 0) # NB: uses Python backtrace |
| |
| # You can manually trigger a guard with this function |
| def guard_int(self, file, line): |
| # TODO: use the file/line for some useful diagnostic on why a |
| # guard occurred |
| r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) |
| try: |
| return int(r) |
| except Exception: |
| log.warning("Failed to convert to int: %s", r) |
| raise |
| |
| def guard_float(self, file, line): |
| # TODO: use the file/line for some useful diagnostic on why a |
| # guard occurred |
| r = self.shape_env.evaluate_expr( |
| self.expr, self.hint, fx_node=self.fx_node, expect_rational=False |
| ) |
| try: |
| return float(r) |
| except Exception: |
| log.warning("Failed to convert to float: %s", r) |
| raise |
| |
| def guard_bool(self, file, line): |
| # TODO: use the file/line for some useful diagnostic on why a |
| # guard occurred |
| r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) |
| try: |
| return bool(r) |
| except Exception: |
| log.warning("Failed to convert to bool: %s", r) |
| raise |
| |
| def expect_true(self, file, line): |
| from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols |
| |
| if ( |
| self.has_hint() |
| and not free_unbacked_symbols(self.expr) |
| and not self.shape_env.prefer_deferred_runtime_asserts_over_guards |
| ): |
| # OK to generate guards |
| return self.guard_bool(file, line) |
| # Generate a deferred runtime assert (this might actually end up doing |
| # a regular guard if we can!) |
| # TODO: file/line here is very important, because the assert has been |
| # deferred so you can't backtrace easily |
| return self.shape_env.defer_runtime_assert( |
| self.expr, f"{file}:{line}", fx_node=self.fx_node |
| ) |
| |
| def expect_size(self, file, line): |
| from torch.fx.experimental.symbolic_shapes import _advise_is_size |
| |
| b = self.ge(self.wrap_int(0)) |
| # Generate a deferred runtime assert |
| r = b.expect_true(file, line) |
| # Refine compile time range, but only if it's unbacked. |
| # If you refine range for hinted variables, you can end up making |
| # improper deductions since compile time reasoning may be |
| # incompatible with runtime reasoning. |
| if r and not self.has_hint(): |
| _advise_is_size(SymInt(self)) |
| return r |
| |
| def guard_size_oblivious(self, file, line): |
| """ |
| Like guard_bool, but if we encounter unbacked symbols, if those symbols |
| are size-like, we will treat them as >= 2 for the purposes of the analysis. |
| |
| This CHANGES the runtime semantics, but all size-oblivious sites have been |
| audited to ensure that the runtime semantics don't change in a material way. |
| Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping |
| an unbacked one size, or a tensor reporting as non-contiguous even if it's |
| contiguous if it would have been reported contiguous due to being empty. |
| """ |
| # TODO: use the file/line for some useful diagnostic on why a |
| # guard occurred |
| r = self.shape_env.evaluate_expr( |
| self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True |
| ) |
| try: |
| return bool(r) |
| except Exception: |
| log.warning("Failed to convert to bool: %s", r) |
| raise |
| |
| def bool_(self): |
| return self.guard_bool("", 0) |
| |
| def is_symbolic(self): |
| return True |
| |
| def nested_int(self): |
| return None |
| |
| def is_constant(self): |
| return False |
| |
| |
| # TODO: this probably needs the sizes-strides eval functions |
| METHOD_TO_OPERATOR = { |
| "pos": operator.pos, |
| "abs": operator.abs, |
| "add": operator.add, |
| "and": operator.and_, |
| "ceil": math.ceil, |
| "eq": operator.eq, |
| "floor": math.floor, |
| "trunc": math.trunc, |
| "int_floordiv": operator.floordiv, |
| "ge": operator.ge, |
| "gt": operator.gt, |
| "is_integer": lambda x: x.is_integer(), |
| "le": operator.le, |
| "lshift": operator.lshift, |
| "lt": operator.lt, |
| "mod": operator.mod, |
| "mul": operator.mul, |
| "ne": operator.ne, |
| "neg": operator.neg, |
| "or": operator.or_, |
| "float_pow": operator.pow, |
| "pow_by_natural": operator.pow, |
| "round": builtins.round, |
| "rshift": operator.rshift, |
| "sub": operator.sub, |
| "sym_float": sym_float, |
| "sym_ite": sym_ite, |
| "sym_max": sym_max, |
| "sym_min": sym_min, |
| "sym_not": sym_not, |
| "float_truediv": operator.truediv, |
| "int_truediv": operator.truediv, |
| } |
| |
| unary_magic_methods = { |
| "abs", |
| "sym_float", |
| "sym_int", |
| "ceil", |
| "floor", |
| "neg", |
| "sym_not", |
| "pos", |
| "trunc", |
| } |
| |
| |
| # Adding math ops: sqrt, cos, sin, ... |
| def _get_sym_node_fn(name): |
| def fn(self): |
| return getattr(self, f"_sym_{name}")() |
| |
| return fn |
| |
| |
| math_op_names = ( |
| "sqrt", |
| "cos", |
| "cosh", |
| "sin", |
| "sinh", |
| "tan", |
| "tanh", |
| "asin", |
| "acos", |
| "atan", |
| ) |
| for name in math_op_names: |
| sym_name = f"sym_{name}" |
| priv_sym_name = f"_{sym_name}" |
| setattr(SymNode, sym_name, _get_sym_node_fn(name)) |
| METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) |
| unary_magic_methods.add(sym_name) |
| __all__.append(sym_name) |
| |
| |
| # Unary methods that are not magic methods |
| unary_nonmagic_methods = { |
| "is_integer", |
| } |
| |
| unary_methods = unary_magic_methods | unary_nonmagic_methods |
| |
| # Most methods are only registered on SymInt and SymFloat |
# Some methods are only registered on SymBool
| only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} |
| # Methods that implicitly convert SymBool into SymInt |
| bool_becomes_int_magic_methods = {"add", "sub", "mul"} |
# Methods that are also on SymBool, in addition to SymInt and SymFloat
| also_bool_magic_methods = {"eq"} |
| bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods |
| |
| # Methods that are only for float |
| only_float_magic_methods = {"is_integer", "round", "sym_int"} |
| |
| |
| magic_methods_on_operator_with_trailing_underscore = {"and", "or"} |
| |
| |
| always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} |
| |
| for name in math_op_names: |
| sym_name = f"sym_{name}" |
| always_float_magic_methods.add(sym_name) |
| |
| |
| always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} |
| always_bool_magic_methods = { |
| "eq", |
| "ne", |
| "gt", |
| "lt", |
| "le", |
| "ge", |
| "and", |
| "or", |
| "sym_not", |
| "is_non_overlapping_and_dense", |
| "is_integer", |
| } |
| |
# Methods that have a `__foo__` as well as an `__rfoo__` variant. The helper
# functions below provide their sympy implementations; they are collected in
# reflectable_magic_methods further down.
| |
| |
| def _sympy_float_truediv(a, b): |
| from torch.utils._sympy.functions import FloatTrueDiv |
| |
| return FloatTrueDiv(a, b) |
| |
| |
| def _sympy_int_truediv(a, b): |
| from torch.utils._sympy.functions import IntTrueDiv |
| |
| return IntTrueDiv(a, b) |
| |
| |
| def _sympy_floordiv(a, b): |
| from torch.utils._sympy.functions import FloorDiv |
| |
| return FloorDiv(a, b) |
| |
| |
| def _sympy_mod(a, b): |
| from torch.utils._sympy.functions import Mod, PythonMod |
| |
| if a.is_nonnegative and b.is_nonnegative: |
| return Mod(a, b) |
| else: |
| return PythonMod(a, b) |
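
# Mod is chosen only when both operands are known nonnegative (which permits
# more aggressive simplification); PythonMod implements Python's sign
# convention, where the result takes the sign of the divisor, e.g.
# -7 % 3 == 2 and 7 % -3 == -2.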
| |
| |
| def _sympy_pow_by_natural(a, b): |
| from torch.utils._sympy.functions import PowByNatural |
| |
| return PowByNatural(a, b) |
| |
| |
| def _sympy_float_pow(a, b): |
| from torch.utils._sympy.functions import FloatPow |
| |
| return FloatPow(a, b) |
| |
| |
| def _sympy_and(a, b): |
| import sympy |
| |
| return sympy.And(a, b) |
| |
| |
| def _sympy_or(a, b): |
| import sympy |
| |
| return sympy.Or(a, b) |
| |
| |
| def _sympy_lshift(a, b): |
| from torch.utils._sympy.functions import LShift |
| |
| return LShift(a, b) |
| |
| |
| def _sympy_rshift(a, b): |
| from torch.utils._sympy.functions import RShift |
| |
| return RShift(a, b) |
| |
| |
| reflectable_magic_methods = { |
| "add": operator.add, |
| "sub": operator.sub, |
| "mul": operator.mul, |
| "mod": _sympy_mod, |
| "pow_by_natural": _sympy_pow_by_natural, |
| "float_pow": _sympy_float_pow, |
| "and": _sympy_and, |
| "or": _sympy_or, |
| "float_truediv": _sympy_float_truediv, |
| "int_truediv": _sympy_int_truediv, |
| "int_floordiv": _sympy_floordiv, |
| "lshift": _sympy_lshift, |
| "rshift": _sympy_rshift, |
| } |
| |
| |
| def _floor_ceil_helper(a, fn): |
| import sympy |
| |
| if isinstance(a, sympy.Mul): |
| aa = a.args |
| if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: |
| coef = sympy.Integer(aa[0]) |
| if aa[0] == coef: # structural equality test |
| return coef * aa[1] |
| if ( |
| isinstance(a, sympy.Float) |
| and a == sympy.Integer(a) |
| or isinstance(a, sympy.Integer) |
| ): |
| return sympy.Integer(a) |
| return fn(a) |
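
# Doctest-style sketch of the shortcut above (assuming a sympy session):
#   >>> import sympy
#   >>> s = sympy.Symbol("s", integer=True)
#   >>> _floor_ceil_helper(sympy.Float(3.0) * s, sympy.floor)
#   3*s
#   >>> _floor_ceil_helper(sympy.Float(2.0), sympy.ceiling)
#   2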
| |
| |
| def _sympy_floor(a): |
| from torch.utils._sympy.functions import FloorToInt |
| |
| return FloorToInt(a) |
| |
| |
| # NB: this is Python trunc semantics which returns an int. Do NOT use this to |
| # represent torch.trunc (which is float to float) |
| def _sympy_trunc(a): |
| from torch.utils._sympy.functions import TruncToInt |
| |
| return TruncToInt(a) |
| |
| |
| def _sympy_ceil(a): |
| from torch.utils._sympy.functions import CeilToInt |
| |
| return CeilToInt(a) |
| |
| |
| def _sympy_eq(a, b): |
| import sympy |
| |
| return sympy.Eq(a, b) |
| |
| |
| def _sympy_ne(a, b): |
| import sympy |
| |
| return sympy.Ne(a, b) |
| |
| |
| def _sympy_gt(a, b): |
| import sympy |
| |
| return sympy.Gt(a, b) |
| |
| |
| def _sympy_lt(a, b): |
| import sympy |
| |
| return sympy.Lt(a, b) |
| |
| |
| def _sympy_le(a, b): |
| import sympy |
| |
| return sympy.Le(a, b) |
| |
| |
| def _sympy_ge(a, b): |
| import sympy |
| |
| return sympy.Ge(a, b) |
| |
| |
| def _sympy_min(a, b): |
| import sympy |
| |
| return sympy.Min(a, b) |
| |
| |
| def _sympy_max(a, b): |
| import sympy |
| |
| return sympy.Max(a, b) |
| |
| |
| def _sympy_ite(a, t, f): |
| import sympy |
| |
| return sympy.Piecewise((t, a), (f, True)) |
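
# sympy.Piecewise((t, a), (f, True)) reads as "t if a else f": branches are
# tried in order, and (f, True) is the unconditional fallback.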
| |
| |
| current_module = sys.modules[__name__] |
| |
| |
| def _get_sym_math_fn(name): |
| def fn(a): |
| import torch.utils._sympy.functions |
| |
| return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) |
| |
| return fn |
| |
| |
| for name in math_op_names: |
| priv_sympy_name = f"_sympy_{name}" |
| fn = _get_sym_math_fn(name) |
| fn.__qualname__ = fn.__name__ = priv_sympy_name |
| setattr(current_module, priv_sympy_name, fn) |
| |
| del fn, name, priv_sympy_name # type: ignore[possibly-undefined] |
| |
| |
| def _sympy_abs(a): |
| import sympy |
| |
| return sympy.Abs(a) |
| |
| |
| def _sympy_round(number, ndigits=None): |
| from torch.utils._sympy.functions import RoundDecimal, RoundToInt |
| |
| if ndigits is None: |
| return RoundToInt(number) |
| else: |
| return RoundDecimal(number, ndigits) |
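
# This mirrors Python's round(): with ndigits=None the result is an int
# (round-half-to-even, so round(2.5) == 2); otherwise the input's type is
# kept, assuming RoundToInt/RoundDecimal follow the builtin's semantics.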
| |
| |
| def _sympy_sym_float(a): |
| from torch.utils._sympy.functions import ToFloat |
| |
| # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly |
| # reports that it is an integer |
| return ToFloat(a) |
| |
| |
| def _sympy_is_integer(a): |
| import sympy |
| |
| from torch.utils._sympy.functions import ToFloat |
| |
| return sympy.Eq(ToFloat(sympy.floor(a)), a) |
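
# e.g. a = 2.0 gives Eq(ToFloat(floor(2.0)), 2.0) -> true, while a = 2.5
# gives Eq(2.0, 2.5) -> false. ToFloat keeps the comparison float-vs-float
# instead of int-vs-float.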
| |
| |
| magic_methods = { |
| **reflectable_magic_methods, |
| "sym_not": operator.invert, |
| "pos": operator.pos, |
| "eq": _sympy_eq, |
| "ne": _sympy_ne, |
| "gt": _sympy_gt, |
| "lt": _sympy_lt, |
| "le": _sympy_le, |
| "ge": _sympy_ge, |
| "floor": _sympy_floor, |
| "trunc": _sympy_trunc, |
| "sym_float": _sympy_sym_float, |
| "ceil": _sympy_ceil, |
| "neg": operator.neg, |
| "sym_min": _sympy_min, |
| "sym_max": _sympy_max, |
| "sym_ite": _sympy_ite, |
| "abs": _sympy_abs, |
| "round": _sympy_round, |
| "is_integer": _sympy_is_integer, |
| } |
| |
| |
| for name in math_op_names: |
| sym_name = f"sym_{name}" |
| magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") |
| |
| del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] |
| |
| |
| def sympy_is_contiguous(sizes, strides): |
| dim = len(sizes) |
| return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) |
| |
| |
| def sympy_is_contiguous_generic(sizes, strides, dim_order): |
| import sympy |
| |
| dim = len(sizes) |
| |
| if len(dim_order) != dim: |
| return sympy.false |
| |
| is_contiguous = sympy.true |
| z = sympy.Integer(1) |
| # Contiguous if the strides make sense (or the dim is size 1) |
| for d in dim_order: |
| is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) |
| z *= sizes[d] |
| # OR if any size is zero |
| for d in range(dim): |
| is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) |
| return is_contiguous |
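
# Worked example: sizes=[2, 3], strides=[3, 1], dim_order=[1, 0]:
#   d=1: Eq(3, 1) | Eq(1, 1) -> true, then z = 3
#   d=0: Eq(2, 1) | Eq(3, 3) -> true, then z = 6
# No size is provably zero, so the whole expression simplifies to sympy.true.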
| |
| |
| # NB: There is a TODO in C++ to allow omitting the batch dim. If that |
| # happens you will need to refactor this |
| |
| |
| def sympy_is_channels_last_contiguous_2d(sizes, strides): |
| return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) |
| |
| |
| def sympy_is_channels_last_contiguous_3d(sizes, strides): |
| return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) |
| |
| |
| def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): |
| import sympy |
| |
| dim = len(sizes) |
| |
| if dim != len(dim_order): |
| return sympy.false |
| |
| m = sympy.Integer(0) |
| r = sympy.true |
| |
    # Special case for the trivial C dimension; default to NCHW.
| r &= sympy.Ne(strides[1], 0) |
| |
| for d in dim_order: |
| r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) |
        # Fall back to NCHW as the default layout for ambiguous cases: this
        # is the flaw of inferring memory_format implicitly from strides.
        # Consider an N111 tensor whose size-1 dimensions all share the same
        # stride; two cases could lead us here:
        # a. an N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
        # b. an N11W contiguous Tensor sliced on the W-dimension
        #    ([N,1,1,1]@[W,W,W,W])
| if d == 0: |
| r &= sympy.Ne(m, strides[1]) |
| # This is necessary to: |
| # 1. distinguish the memory_format of N1H1; |
| # [H, 1, 1, 1] channels_last stride |
| # [H, H, 1, 1] contiguous stride |
| # 2. permutation of 1C1W: |
| # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) |
| # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as |
| # channels_last |
| m = strides[d] * sympy.Max(sizes[d], 1) |
| |
| return r |
| |
| |
| def sympy_is_channels_last_strides_2d(sizes, strides): |
| return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) |
| |
| |
| def sympy_is_channels_last_strides_3d(sizes, strides): |
| return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) |
| |
| |
| def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): |
| from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator |
| |
| return IsNonOverlappingAndDenseIndicator(*sizes, *strides) |
| |
| |
| sizes_strides_methods = { |
| # TODO: These could also be done with indicators, maybe it is better |
| # for reasoning to do it that way |
| "is_contiguous": sympy_is_contiguous, |
| "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, |
| "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, |
| "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, |
| "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, |
| "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, |
| } |
| |
| alternate_impl_if_hinted_methods = { |
| "sym_min": builtins.min, |
| "sym_max": builtins.max, |
| } |
| |
| |
| def to_node(self, num): |
| if isinstance(num, SymTypes): |
| return num.node |
| elif type(num) is bool: |
| return self.wrap_bool(num) |
| elif type(num) is int: |
| return self.wrap_int(num) |
| elif type(num) is float: |
| return self.wrap_float(num) |
| else: |
| # NotImplemented is important so that Python tries the |
| # other magic method |
| return NotImplemented |
| |
| |
| def wrap_node(x): |
| # TODO: let C++ also take advantage of this |
| if isinstance(x, SymNode) and x.constant is not None: |
| return x.constant |
| if x.is_int(): |
| return SymInt(x) |
| elif x.is_float(): |
| return SymFloat(x) |
| elif x.is_bool(): |
| return SymBool(x) |
| else: |
| raise AssertionError(f"unrecognized return type {x}") |
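
# to_node and wrap_node are rough inverses: to_node(n, 3) wraps a Python
# literal into a constant SymNode via n.wrap_int, while wrap_node turns a
# SymNode back into the user-facing SymInt/SymFloat/SymBool (or a plain
# constant when one is recorded on the node).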
| |
| |
| def method_to_operator(method): |
| return METHOD_TO_OPERATOR[method] |
| |
| |
| def _make_node_magic(method, func): |
| func = lru_cache(256)(func) |
| |
| if method in magic_methods_on_operator_with_trailing_underscore: |
| method_attr = f"{method}_" |
| else: |
| method_attr = method |
| |
| def binary_magic_impl(self, other): |
| from torch.fx.experimental.symbolic_shapes import safe_expand |
| |
| op = method_to_operator(method) |
| |
| out_hint = None |
| if self.hint is not None and other.hint is not None: |
| out_hint = op(self.hint, other.hint) |
| |
| alternate_impl = alternate_impl_if_hinted_methods.get(method) |
| if alternate_impl and out_hint is not None: |
| return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) |
| |
| if sym_function_mode(): |
| return to_node( |
| self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) |
| ) |
| assert isinstance(other, SymNode) |
| try: |
| if method == "mod": |
| from torch.utils._sympy.functions import Mod, PythonMod |
| |
| # Special handling for mod that requires access to the value |
| # ranges |
| shape_env = self.shape_env |
| if ( |
| self.expr.is_nonnegative |
| or shape_env.bound_sympy(self.expr).lower >= 0 |
| ) and ( |
| other.expr.is_nonnegative |
| or shape_env.bound_sympy(other.expr).lower >= 0 |
| ): |
| out = Mod(self.expr, other.expr) |
| else: |
| out = PythonMod(self.expr, other.expr) |
| else: |
| # TODO: consider constant prop here |
| out = func(self.expr, other.expr) |
| except Exception: |
| log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) |
| raise |
| out = safe_expand(out) |
| sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) |
| pytype: Type |
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). Pow can
        # also return a float while both arguments are ints: 2**(-1). Also, max
        # and min do not type promote. To avoid having data-dependent control
        # flow here, we just set the type to float if one of the args is a
        # float. In case of a type mismatch, we assume that it will be detected
        # during evaluation.
| if method in always_float_magic_methods: |
| pytype = float |
| elif method in always_bool_magic_methods: |
| pytype = bool |
| elif self.pytype is float or other.pytype is float: |
| pytype = float |
| else: |
| pytype = self.pytype |
| |
| if ( |
| pytype is not None |
| and out_hint is not None |
| and not isinstance(out_hint, SymTypes) |
| ): |
| out_hint = pytype(out_hint) |
| |
| # Create a FX node that corresponds to the operation being applied to |
| # this node. |
| fx_node, _ = self.shape_env._create_fx_call_function( |
| op, (self.fx_node, other.fx_node) |
| ) |
| return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) |
| |
| def unary_magic_impl(self): |
| from torch.fx.experimental.symbolic_shapes import safe_expand |
| |
| op = method_to_operator(method) |
| if sym_function_mode(): |
| return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) |
| # TODO: consider constant prop here |
| expr = self.expr |
        if method == "floor" or method == "ceil":
| expr = self.shape_env._simplify_floor_div(expr) |
| |
| try: |
| out = func(expr) |
| except Exception: |
| log.warning("failed to eval %s(%s)", method, expr) |
| raise |
| sym_node_log.debug("%s %s -> %s", func, expr, out) |
| out_hint = None |
| if self.hint is not None: |
| out_hint = op(self.hint) |
| out = safe_expand(out) |
| pytype: Type |
| if method in always_int_magic_methods: |
| pytype = int |
| elif method in always_bool_magic_methods: |
| pytype = bool |
| elif method in always_float_magic_methods: |
| pytype = float |
| else: |
| pytype = self.pytype |
| |
| fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) |
| return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) |
| |
| if method in unary_methods: |
| setattr(SymNode, f"_{method_attr}", unary_magic_impl) |
| elif method == "sym_ite": |
| |
| def sym_ite_impl(pred_node, then_node, else_node): |
| from torch.fx.experimental.symbolic_shapes import safe_expand |
| |
| out_hint = then_node.hint if pred_node.hint else else_node.hint |
| if sym_function_mode(): |
| return to_node( |
| pred_node, |
| handle_sym_dispatch( |
| sym_ite, |
| ( |
| wrap_node(pred_node), |
| wrap_node(then_node), |
| wrap_node(else_node), |
| ), |
| {}, |
| ), |
| ) |
| |
| try: |
| out = func(pred_node.expr, then_node.expr, else_node.expr) |
| except Exception: |
| log.warning( |
| "failed to eval %s(%s, %s, %s)", |
| method, |
| pred_node.expr, |
| then_node.expr, |
| else_node.expr, |
| ) |
| raise |
| |
| out = safe_expand(out) |
| fx_node, _ = pred_node.shape_env._create_fx_call_function( |
| sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) |
| ) |
| return SymNode( |
| out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node |
| ) |
| |
| setattr(SymNode, f"_{method_attr}", sym_ite_impl) |
| elif method == "round": |
| |
| def round_impl(self, ndigits=None): |
| from torch.fx.experimental.symbolic_shapes import safe_expand |
| |
| op = builtins.round |
| if sym_function_mode(): |
| return to_node( |
| self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) |
| ) |
| |
| expr = self.expr |
| try: |
| out = func(expr, ndigits) |
| except Exception: |
| log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) |
| raise |
| |
| out = safe_expand(out) |
| |
| if ndigits is None: |
| pytype = int |
| else: |
| pytype = self.pytype |
| |
| out_hint = None |
| if self.hint is not None: |
| out_hint = op(self.hint, ndigits) |
| |
            # Internally, None is used as a sentinel to indicate that something
            # is not a node on an FX graph. At the same time, there is no way
            # to wrap a plain None into an FX node. Thus, there is no way to
            # pass None here without triggering asserts that check whether we
            # are mixing FX nodes with untracked arguments. The hack below
            # works because all round functions down the line take ndigits=None
            # as the default in their signatures.
            # TODO: Remove the args construction below if a different sentinel
            # is used by FX.
            # ezyang(May 2024): LOL
| args = [self.fx_node] |
| if ndigits is not None: |
| args.append(ndigits) |
| fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) |
| return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) |
| |
| setattr(SymNode, f"_{method_attr}", round_impl) |
| else: |
| setattr(SymNode, f"_{method_attr}", binary_magic_impl) |
| |
| |
| def _make_node_sizes_strides(method, func): |
| # NB: don't LRU cache, lots of arguments |
| |
| def sizes_strides_impl(self, sizes, strides): |
| op = getattr(sys.modules[__name__], method) |
| if sym_function_mode(): |
| return to_node( |
| self, |
| handle_sym_dispatch( |
| op, |
| ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), |
| {}, |
| ), |
| ) |
| size_exprs = [s.expr for s in sizes] |
| stride_exprs = [s.expr for s in strides] |
| try: |
| out = func(size_exprs, stride_exprs) |
| except Exception: |
| log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) |
| raise |
| # bool is never expandable |
| |
| size_hints = [] |
| out_hint = None |
| for s in sizes: |
| if s.hint is None: |
| break |
| size_hints.append(s.hint) |
| else: |
| stride_hints = [] |
| for s in strides: |
| if s.hint is None: |
| break |
| stride_hints.append(s.hint) |
| else: |
| out_hint = op(size_hints, stride_hints) |
| |
| # NB: This is the indicator function, not the actual bool! |
| pytype: Type |
| if method.endswith("_indicator"): |
| pytype = int |
| else: |
| pytype = bool |
| return SymNode(out, self.shape_env, pytype, out_hint) |
| |
| setattr(SymNode, f"_{method}", sizes_strides_impl) |
| |
| # TODO: This is technically hotpath, but in the ideal end state |
| # guards on this will resolve at a higher level so you never |
| # spend time in this code |
| def sizes_strides_user(sizes, strides): |
| import sympy |
| |
| from torch.fx.experimental.symbolic_shapes import ( |
| eval_is_non_overlapping_and_dense, |
| ) |
| |
| for a in itertools.chain(sizes, strides): |
| if isinstance(a, SymInt): |
| return wrap_node( |
| getattr(a.node, method)( |
| [to_node(a.node, b) for b in sizes], |
| [to_node(a.node, b) for b in strides], |
| ) |
| ) |
| if method == "is_non_overlapping_and_dense_indicator": |
| return eval_is_non_overlapping_and_dense(sizes, strides) |
| else: |
| # TODO: this is an awful implementation |
| return bool( |
| func( |
| [sympy.sympify(a) for a in sizes], |
| [sympy.sympify(a) for a in strides], |
| ) |
| ) |
| |
| # Skip for is_non_overlapping_and_dense_indicator |
| if not hasattr(sys.modules[__name__], method): |
| setattr(sys.modules[__name__], method, sizes_strides_user) |
| |
| |
| for method, func in magic_methods.items(): |
| _make_node_magic(method, func) |
| |
| for method, func in sizes_strides_methods.items(): |
| _make_node_sizes_strides(method, func) |
| |
| |
| def _make_user_magic(method, user_type): |
| # User magic takes care of wrapping the other operand into a node, |
| # so that our internal logic can assume everything is nodes |
| |
| if method in magic_methods_on_operator_with_trailing_underscore: |
| method_attr = f"sym_{method}" |
| else: |
| method_attr = method |
| |
| def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): |
| if isinstance(x, (int, float, bool)): |
| return x |
| if isinstance(x, SymBool): |
| return x.node.guard_bool("", 0) |
| raise AssertionError("expect to be called with constant SymBools") |
| |
| def is_constant(x): |
| if isinstance(x, (int, float, bool)): |
| return True |
| if isinstance(x, (SymInt, SymFloat, SymBool)): |
| return x.node.is_constant() |
| return False |
| |
| # Promotion rules for binary operations. NB: we preserve PYTHON semantics |
| # - if args are same type, do nothing |
| # - if one arg is float, promote other arg to float |
| # - nb: this applies to floordiv, even though output is integral |
| # (it's still float) |
| # - pow is funny business |
| # - if both ints |
| # - trigger a guard on exponent >= 0 |
| # - if non-negative, output is int |
| # - otherwise, output is float |
| # - otherwise, promote other arg to float |
| # - nb: complex is impossible to handle correctly lol, with |
| # negative base and integral float need to diverge semantics and |
| # just always return complex. Neener neener pretend this problem |
| # doesn't exist |
| # - equality is pain: Python does the fancy thing where it unpacks the |
| # mantissa from the float and then compares that against the int. |
| # Which means it is able to tell that |
| # 9007199254740993 != 9007199254740992. (rather than if the LHS was |
| # promoted to float, in which case it would have truncated to the RHS |
| # and subsequently been equal). We'll model this exactly by having |
| # special mixed type equality operations. Unfortunately, we need to |
| # do this for all comparison operations (maybe I'll only implement |
| # compare) |
| # - sym_ite mumble mumble really shouldn't allow mixed but whatever |
| |
| if method in bool_becomes_int_magic_methods: |
| |
| def promote(x): |
| """Implements True+True=2, which works in python but not sympy""" |
| if isinstance(x, SymBool): |
| return SymInt(x.node.wrap_int(int(x))) |
| return x |
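
        # This keeps Python's bool arithmetic working for SymBool operands:
        # the bool is converted to a constant 0/1 SymInt before the integer
        # op runs, so e.g. adding two true SymBools yields 2.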
| |
| else: |
| |
| def promote(x): |
| return x |
| |
| def promote2(self, other): |
| # TODO: Remove eq and other relations from this list. |
| # CPython has fancy implementations for these to get as much precision |
| # as possible instead of just promoting to float64 and praying, so we |
| # need to handle them specially too. |
| # Also, note that int_truediv doesn't go through this path: both |
| # arguments are "int" so there isn't any promotion |
| if method not in [ |
| "add", |
| "sub", |
| "mul", |
| "mod", |
| "float_pow", |
| "float_truediv", |
| "int_floordiv", |
| "sym_min", |
| "sym_max", |
| # TODO: remove these |
| "eq", |
| "ne", |
| "gt", |
| "lt", |
| "le", |
| "ge", |
| ]: |
| return self, other |
| f_self = isinstance(self, (float, torch.SymFloat)) |
| f_other = isinstance(other, (float, torch.SymFloat)) |
| if f_self or f_other: |
| if not f_self: |
| self = torch.sym_float(self) |
| if not f_other: |
| other = torch.sym_float(other) |
| return self, other |
| |
| # Before and after performing the operation, check if any operands are constant. |
| # If so, extract out the constant values first. If `self` itself is a |
| # constant, then "redispatch" by calling back into the operator. Sometimes |
| # this means that operations involving SymBool return plain bools. |
    # Alternatively, we could also rewrap into a constant SymBool (i.e. by
| # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that |
| # today for no particular reason. |
| def unary_magic_impl(self): |
| self = promote(self) |
| if is_constant(self): |
| return (method_to_operator(method))(get_constant(self)) |
| return wrap_node(getattr(self.node, method_attr)()) |
| |
| def binary_magic_impl(self, other): |
| if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): |
| return NotImplemented |
| sym_node_log.debug("MAGIC %s %s %s", method, self, other) |
| self = promote(self) |
| other = promote(other) |
| self, other = promote2(self, other) |
| if is_constant(self): |
| return (method_to_operator(method))(get_constant(self), other) |
| if is_constant(other): |
| other = get_constant(other) |
| other_node = to_node(self.node, other) |
| if other_node is NotImplemented: |
| return NotImplemented |
| ret = wrap_node(getattr(self.node, method_attr)(other_node)) |
| return get_constant(ret) if is_constant(ret) else ret |
| |
| def rbinary_magic_impl(self, other): |
| if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): |
| return NotImplemented |
| self = promote(self) |
| other = promote(other) |
| self, other = promote2(self, other) |
| if is_constant(self): |
| return (method_to_operator(method))(get_constant(self), other) |
| if is_constant(other): |
| other = get_constant(other) |
| other_node = to_node(self.node, other) |
| if other_node is NotImplemented: |
| return NotImplemented |
| ret = wrap_node(getattr(other_node, method_attr)(self.node)) |
| return get_constant(ret) if is_constant(ret) else ret |
| |
| if method in unary_magic_methods: |
| setattr(user_type, f"__{method}__", unary_magic_impl) |
| elif method in unary_nonmagic_methods: |
| orig = getattr(user_type, method) |
| setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) |
| elif method == "sym_ite": |
| |
| def sym_ite_magic_impl(pred, then_val, else_val): |
| pred_node = pred.node |
| then_node = to_node(pred_node, then_val) |
| else_node = to_node(pred_node, else_val) |
| if then_node is NotImplemented or else_node is NotImplemented: |
| return NotImplemented |
| assert ( |
| isinstance(then_node, SymNode) |
| and isinstance(else_node, SymNode) |
| and then_node.pytype == else_node.pytype |
| ) |
| ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) |
| return get_constant(ret) if ret.node.is_constant() else ret |
| |
| setattr(user_type, f"__{method}__", sym_ite_magic_impl) |
| elif method == "round": |
| |
| def round_magic_impl(self, ndigits=None): |
| if is_constant(self): |
| return builtins.round(get_constant(self), ndigits) |
| |
| return wrap_node(getattr(self.node, method)(ndigits)) |
| |
| setattr(user_type, f"__{method}__", round_magic_impl) |
| else: |
| setattr(user_type, f"__{method}__", binary_magic_impl) |
| if method in reflectable_magic_methods: |
| setattr(user_type, f"__r{method}__", rbinary_magic_impl) |
| |
| |
| for method, func in magic_methods.items(): # type: ignore[assignment] |
| if method in only_bool_magic_methods: |
| _make_user_magic(method, SymBool) |
| continue |
| if method in only_float_magic_methods: |
| _make_user_magic(method, SymFloat) |
| continue |
| if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: |
| _make_user_magic(method, SymBool) |
| _make_user_magic(method, SymInt) |
| _make_user_magic(method, SymFloat) |
| |
| del method |
| del func |