# blob: 8d7e12668d7f178bd4964f4a4390ccde4b3f3356 [file] [log] [blame]
import torch
import torch.utils._pytree as pytree
from typing import Set, Dict, List, Type, Optional, cast
import sys
import operator
import builtins
import math
import functools
import threading
from contextlib import contextmanager
from functools import lru_cache, partial
import traceback
import collections
import textwrap
from torch._subclasses.meta_utils import MetaConverter
from torch import SymInt, SymFloat
try:
import sympy # type: ignore[import]
from sympy.printing.precedence import precedence # type: ignore[import]
HAS_SYMPY = True
except ImportError:
HAS_SYMPY = False
# Shorthand alias for the `torch.ops.aten` operator namespace.
aten = torch.ops.aten # type: ignore[has-type]
# Names exported by a star-import of this module.
__all__ = [
    "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv",
    "SymDispatchMode", "sym_int", "sym_float", "FloorDiv", "guard_int", "wrap_node",
    "sym_sqrt",
]
# The currently active SymDispatchMode, or None when no mode has been entered.
# Written by SymDispatchMode.__enter__/__exit__ and temporarily swapped to the
# mode's `inner` by _handle_sym_dispatch while the mode's handler is running.
SYM_FUNCTION_MODE = None
# We don't bother with the metaclass as all of the dispatching logic happens
# entirely from Python
#
# Didn't bother with ancestors for now, unlikely to have multiple modes for
# symints right now
# SymDispatchMode gets invoked whenever an operation is processed on
# a PySymInt. When this occurs, you get called at __sym_dispatch__
# with the operation in question. This is symmetric to TorchDispatchMode
# but with some caveats:
#
# - In TorchDispatchMode, you get the same arguments as what a user
# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
# you get (a, b) as args to your call. In SymDispatchMode, if
# you call a + b (where a and b are SymInts), you will get
# (a.get_pyobj(), b.get_pyobj()) as your args (these are PySymInts)
#
# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
# So you have to manually call Tracer/create_node to write into
# the graph. See ProxySymDispatchMode for an example
#
class SymDispatchMode:
    """
    Context manager that installs itself as the active symbolic dispatch mode.

    Subclasses override ``__sym_dispatch__``; the base implementation raises
    ``NotImplementedError``.  Entering the mode remembers whichever mode was
    active before (as ``self.inner``) and makes this instance the module-level
    ``SYM_FUNCTION_MODE``; exiting puts ``self.inner`` back.  An instance can
    only be entered once: a second ``__enter__`` raises ``RuntimeError``.
    """

    def __sym_dispatch__(self, func, types, args, kwargs):
        """Handle one intercepted operation; must be overridden."""
        raise NotImplementedError()

    def __enter__(self):
        global SYM_FUNCTION_MODE
        # `inner` only exists once the mode has been entered, so its presence
        # marks an instance that has already been used.
        if hasattr(self, "inner"):
            raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version")
        self.inner = SYM_FUNCTION_MODE
        SYM_FUNCTION_MODE = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global SYM_FUNCTION_MODE
        # Reinstate whatever mode was active when we were entered.
        SYM_FUNCTION_MODE = self.inner
def has_symbolic_sizes_strides(elem):
    """Return the ``_has_symbolic_sizes_strides`` attribute of ``elem``."""
    return getattr(elem, "_has_symbolic_sizes_strides")
def create_contiguous(shape):
    """
    Return the strides of a contiguous (row-major) layout for ``shape``.

    The last dimension has stride 1 and the stride of dimension ``i`` is the
    product of the sizes of every dimension after ``i``; e.g. ``(2, 3, 4)``
    yields ``[12, 4, 1]``.  An empty shape yields an empty list so that the
    result always has one stride per dimension.

    Args:
        shape: a sliceable sequence of dimension sizes (elements only need to
            support multiplication).

    Returns:
        A new list with ``len(shape)`` strides.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    # Accumulate right-to-left over the *trailing* dimensions: the size of the
    # leading dimension never contributes to any stride.
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    strides.reverse()
    return strides
def _handle_sym_dispatch(func, args, kwargs):
global SYM_FUNCTION_MODE
mode = SYM_FUNCTION_MODE
assert mode
SYM_FUNCTION_MODE = mode.inner
try:
# TODO: properly compute types
types: List[Type] = []
return mode.__sym_dispatch__(func, types, args, kwargs)
finally:
SYM_FUNCTION_MODE = mode
def guard_int(a):
    """
    Return ``a`` as a concrete int.

    A ``SymInt`` is resolved through its node's ``guard_int``; anything else
    must already be an ``int`` (asserted) and is returned unchanged.
    """
    if not isinstance(a, SymInt):
        assert isinstance(a, int)
        return a
    return a.node.guard_int("", 0)  # NB: uses Python backtrace
def sym_float(a):
    """
    Convert ``a`` to a float-like value.

    A ``SymFloat`` is returned as is; an object exposing ``__sym_float__``
    is converted through that hook; everything else goes through ``float()``.
    """
    if isinstance(a, SymFloat):
        return a
    if hasattr(a, '__sym_float__'):
        return a.__sym_float__()
    return float(a)
# Drop-in replacement for math.sqrt
def sym_sqrt(a):
    """Return ``a.__sym_sqrt__()`` when ``a`` provides that hook, else ``math.sqrt(a)``."""
    return a.__sym_sqrt__() if hasattr(a, '__sym_sqrt__') else math.sqrt(a)
# Drop-in replacements for math.floor/ceil.  math.floor/ceil are directly
# usable themselves, but these wrappers have a more relaxed type signature
# for mypy (mypy requires SupportsFloat, which is too strict)
def sym_floor(a):
    """Return ``math.floor(a)``."""
    floored = math.floor(a)  # type: ignore[type]
    return floored
def sym_ceil(a):
    """Return ``math.ceil(a)``."""
    ceiled = math.ceil(a)  # type: ignore[type]
    return ceiled
def sym_int(a):
    """
    Convert ``a`` to an int-like value.

    A ``SymInt`` is returned as is.  A ``SymFloat`` is truncated towards zero
    (floor when positive, ceil otherwise).  Everything else goes through
    ``int()``.
    """
    if isinstance(a, SymInt):
        return a
    if not isinstance(a, SymFloat):
        return int(a)
    # Truncation towards zero: floor for positive values, ceil for the rest.
    if a > 0:
        return sym_floor(a)
    return sym_ceil(a)
def to_node(self, num):
    """
    Turn ``num`` into a node that can be combined with the node ``self``.

    ``SymInt``/``SymFloat`` yield their underlying ``node``; plain ints and
    floats are wrapped via ``self.wrap_int`` / ``self.wrap_float``.  Any other
    type yields ``NotImplemented``.
    """
    if isinstance(num, (SymInt, SymFloat)):
        return num.node
    # int is checked before float, matching the order of the wrappers.
    for py_type, wrapper_name in ((int, "wrap_int"), (float, "wrap_float")):
        if isinstance(num, py_type):
            return getattr(self, wrapper_name)(num)
    # NotImplemented is important so that Python tries the
    # other magic method
    return NotImplemented
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    Type-erased backing object for SymInt/SymFloat on which the actual
    operations are carried out.

    End users are not expected to touch this class, and it deliberately
    defines no arithmetic magic methods itself.
    """

    def __init__(self, expr, shape_env, pytype, constant=None, symbol=None):
        """
        Args:
            expr: the expression this node stands for.
            shape_env: environment used to rewrite and evaluate ``expr``.
            pytype: ``int`` or ``float``; selects what ``is_int``/``is_float`` report.
            constant: the plain Python value when the node wraps a literal, else None.
            symbol: the originating symbol for freshly created nodes, else None.
        """
        self.pytype = pytype
        self.shape_env = shape_env
        self._expr = expr
        self.constant = constant
        # Unlike expr, sympy.Symbol is guaranteed to either be a
        # symbol or its negation a symbol, and it never gets simplified into a
        # constant or another symbol.  This only exists for freshly
        # create_symint; intermediate values are None.  The usage of this
        # property is fairly short-lived: it lives long enough so that Dynamo
        # can get its hands on symbols and setup Source associations
        self.symbol: Optional[sympy.Expr] = symbol

    @property
    def expr(self):
        """The expression, refreshed through the shape env's replacements on every access."""
        self._update_expr()
        return self._expr

    def _update_expr(self):
        """Rewrite the stored expression with ``shape_env.replace``."""
        refreshed = self.shape_env.replace(self._expr)
        self._expr = refreshed

    def is_int(self):
        """True when this node's ``pytype`` is ``int``."""
        return self.pytype is int

    def is_float(self):
        """True when this node's ``pytype`` is ``float``."""
        return self.pytype is float

    def wrap_int(self, num):
        """Build a constant int node for ``num`` in this node's shape env."""
        assert isinstance(num, int)
        return SymNode(sympy.Integer(num), self.shape_env, int, constant=num)

    def wrap_float(self, num):
        """Build a constant float node for ``num`` in this node's shape env."""
        assert isinstance(num, float)
        return SymNode(sympy.Float(num), self.shape_env, float, constant=num)

    def clone(self):
        """Return this very node (no copy is made)."""
        return self

    def str(self):
        """Return the current expression formatted as text."""
        return format(self.expr)

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    # These methods are metaprogrammed in below
    def sym_int(self) -> "SymNode":
        ...

    def sym_float(self) -> "SymNode":
        ...

    # Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
    def int_(self):
        raise RuntimeError("Trying to extract a concrete int out of a symbolic int")

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        """Evaluate the expression via the shape env and return it as ``int``."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        evaluated = self.shape_env.evaluate_expr(self.expr)
        return int(evaluated)

    def guard_float(self, file, line):
        """Evaluate the expression via the shape env and return it as ``float``."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        evaluated = self.shape_env.evaluate_expr(self.expr)
        return float(evaluated)

    def bool_(self):
        """Evaluate the expression via the shape env and return its truth value."""
        env = self.shape_env
        rewritten = env.replace(self.expr)
        return bool(env.evaluate_expr(rewritten))
if HAS_SYMPY:
    class FloorDiv(sympy.Function):
        """
        Symbolic floor division, kept as its own function so that:
        1. Divisibility guards can simplify FloorDiv(a, b) to a / b.
        2. Expressions print nicely (compared to, say, representing a//b
           as (a - a % b) / b).
        """
        nargs = (2,)

        def _sympystr(self, printer):
            """Render as ``lhs//rhs``, parenthesizing operands of lower precedence."""
            operands = (self.args[0], self.args[1])
            rendered = [printer._print(operand) for operand in operands]
            for idx, operand in enumerate(operands):
                if precedence(operand) < precedence(sympy.div):
                    rendered[idx] = f"({rendered[idx]})"
            return f"{rendered[0]}//{rendered[1]}"

        @classmethod
        def eval(cls, base, divisor):
            """
            Simplify eagerly where possible; returning None leaves the
            FloorDiv application unevaluated.
            """
            if base == 0:
                return sympy.Integer(0)
            if divisor == 1:
                return base
            both_integers = isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer)
            if both_integers:
                return base // divisor
            if isinstance(base, FloorDiv):
                # (a // b) // c  ->  a // (b * c)
                inner_base, inner_divisor = base.args[0], base.args[1]
                return FloorDiv(inner_base, inner_divisor * divisor)
            common = sympy.gcd(base, divisor)
            if common != 1:
                # Cancel the common factor from both operands.
                return FloorDiv(
                    sympy.simplify(base / common), sympy.simplify(divisor / common)
                )
            return None
# Methods that have a `__foo__` as well as `__rfoo__`
# Each entry maps the method name to the function applied to the operand
# expressions.
reflectable_magic_methods = {
    'add': lambda a, b: a + b,
    'sub': lambda a, b: a - b,
    'mul': lambda a, b: a * b,
    'mod': lambda a, b: a % b,
    'pow': lambda a, b: a ** b,
    'truediv': lambda a, b: a / b,
    'floordiv': lambda a, b: FloorDiv(a, b),
}
# Every supported method (the reflectable ones plus comparisons, rounding,
# negation, min/max, and the sym_* conversions), keyed by method name.
magic_methods = {
    **reflectable_magic_methods,
    'eq': lambda a, b: sympy.Eq(a, b),
    'gt': lambda a, b: sympy.Gt(a, b),
    'lt': lambda a, b: sympy.Lt(a, b),
    'le': lambda a, b: sympy.Le(a, b),
    'ge': lambda a, b: sympy.Ge(a, b),
    'floor': lambda a: sympy.floor(a),
    'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, because it expects Python literals
    'ceil': lambda a: sympy.ceiling(a),
    'neg': lambda a: -a,
    'min': lambda a, b: sympy.Min(a, b),
    'max': lambda a, b: sympy.Max(a, b),
    'sym_sqrt': lambda a: sympy.sqrt(a),
}
# Methods taking a single operand; everything else in magic_methods is binary.
unary_magic_methods = {
    'sym_float',
    'ceil',
    'floor',
    'neg',
    'sym_sqrt',
}
# Where the callable handed to a SymDispatchMode is looked up for a method:
# `builtins`, `math`, or this module; any method not listed is looked up on
# `operator`.
magic_methods_on_builtins = {"min", "max"}
magic_methods_on_math = {"ceil", "floor"}
magic_methods_on_submodule = {"sym_float", "sym_sqrt"}
# Result-type overrides: methods whose result node is always float / always int.
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"}
always_int_magic_methods = {"ceil", "floor"}
# Comparison methods, whose result is conceptually a boolean.
always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"}
def wrap_node(x):
    """
    Wrap a node into the matching user-facing value.

    A ``SymNode`` carrying a constant is unwrapped to that constant.
    Otherwise an int node becomes a ``SymInt`` and a float node a
    ``SymFloat``; any other node raises ``AssertionError``.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    if x.is_int():
        return SymInt(x)
    if x.is_float():
        return SymFloat(x)
    raise AssertionError(f"unrecognized return type {x}")
def _make_node_magic(method, func):
    """
    Install ``method`` on ``SymNode``, implemented by ``func`` (which is
    wrapped in an LRU cache of size 256).

    Methods listed in ``unary_magic_methods`` get the one-operand
    implementation; all others get the two-operand one.  When a
    ``SymDispatchMode`` is active, the operation is forwarded to the mode
    instead of being computed here.
    """
    cached_func = lru_cache(256)(func)

    def binary_magic_impl(self, other):
        # Pick the plain-Python callable that identifies this operation.
        namespace = builtins if method in magic_methods_on_builtins else operator
        op = getattr(namespace, method)
        if SYM_FUNCTION_MODE:
            dispatched = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            assert isinstance(dispatched, (SymInt, SymFloat)), type(dispatched)
            return dispatched.node
        assert isinstance(other, SymNode)
        rhs = other.expr
        # TODO: consider constant prop here
        lhs = self.shape_env.replace(self.expr)
        rhs = self.shape_env.replace(rhs)
        result = sympy.expand(cached_func(lhs, rhs))
        # TODO: relational operators actually technically return a
        # PySymBool, this is a type error
        result_type: Type = float if method in always_float_magic_methods else self.pytype
        return SymNode(result, self.shape_env, result_type)

    def unary_magic_impl(self):
        if SYM_FUNCTION_MODE:
            if method in magic_methods_on_math:
                namespace = math
            elif method in magic_methods_on_submodule:
                namespace = sys.modules[__name__]
            else:
                namespace = operator
            op = getattr(namespace, method)
            dispatched = _handle_sym_dispatch(op, (wrap_node(self),), {})
            assert isinstance(dispatched, (SymInt, SymFloat)), type(dispatched)
            return dispatched.node
        # TODO: consider constant prop here
        operand = self.shape_env.replace(self.expr)
        result = sympy.expand(cached_func(operand))
        result_type: Type
        if method in always_int_magic_methods:
            result_type = int
        elif method in always_float_magic_methods:
            result_type = float
        else:
            result_type = self.pytype
        return SymNode(result, self.shape_env, result_type)

    impl = unary_magic_impl if method in unary_magic_methods else binary_magic_impl
    setattr(SymNode, method, impl)


# Register every supported operation on SymNode.
for method, func in magic_methods.items():
    _make_node_magic(method, func)
def _make_user_magic(method, user_type):
    """
    Install ``__{method}__`` (and, for reflectable methods,
    ``__r{method}__``) on ``user_type``, delegating to the same-named method
    of the underlying node.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    def unary_magic_impl(self):
        node_method = getattr(self.node, method)
        return wrap_node(node_method())

    def binary_magic_impl(self, other):
        rhs_node = to_node(self.node, other)
        if rhs_node is NotImplemented:
            return NotImplemented
        node_method = getattr(self.node, method)
        return wrap_node(node_method(rhs_node))

    def rbinary_magic_impl(self, other):
        # Reflected form: the converted `other` is the left-hand operand.
        lhs_node = to_node(self.node, other)
        if lhs_node is NotImplemented:
            return NotImplemented
        node_method = getattr(lhs_node, method)
        return wrap_node(node_method(self.node))

    forward_impl = unary_magic_impl if method in unary_magic_methods else binary_magic_impl
    setattr(user_type, f"__{method}__", forward_impl)
    if method in reflectable_magic_methods:
        setattr(user_type, f"__r{method}__", rbinary_magic_impl)


# Register every supported operation on both user-facing types, then drop the
# loop variables from the module namespace.
for method, func in magic_methods.items():
    for user_type in (SymInt, SymFloat):
        _make_user_magic(method, user_type)
del user_type
del method
del func
def _lru_cache(fn, maxsize=None):
"""
Wrapper around lru_cache that clears when new info about shapes has been
updated.
Use lru_cache if the output is always the same, regardless of the
constraints we know now (i.e. evaluate_expr)
Use _lru_cache otherwise.
"""
fn_cache = lru_cache(maxsize)(fn)
prior_key = None
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
nonlocal prior_key
if prior_key != self._get_key():
prior_key = self._get_key()
fn_cache.cache_clear()
return fn_cache(self, *args, **kwargs)
wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
return wrapper
class ShapeEnv(object):
    """
    Bookkeeping for symbolic shapes: the symbols that were created, their
    concrete hint values, the replacements learned between them, and the
    guards recorded while evaluating expressions against the hints.
    """

    def __init__(self):
        # List of (guard expression, formatted stack trace) pairs.
        self.guards = []
        # Maps symbolic ints to their original concrete values
        # Currently populated from tensors
        self.var_to_val: Dict["sympy.Symbol", "sympy.Integer"] = {}
        # Maps from sympy ints to expressions representing them
        # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
        self.replacements: Dict["sympy.Symbol", "sympy.Expr"] = {}  #
        # Set holds a % b expressions that evaluate to 0.
        self.divisible: Set["sympy.Expr"] = set()
        # Duck-shaping says that if two input tensors have the same size,
        # they get assigned the same symbolic variable
        self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)}
        self.tls = threading.local()
        # Set holds symbols which definitely are not 0 or 1.
        self.definitely_not_01: Set["sympy.Symbol"] = set()

    def _suppress_guards_tls(self):
        """Return the current thread's suppress-guards flag (False if never set)."""
        return getattr(self.tls, "suppress_guards", False)

    @contextmanager
    def suppress_guards(self):
        """
        Context manager under which ``evaluate_expr`` records no guards on the
        current thread.  The flag value seen on entry is restored on exit, so
        nested uses keep suppression active until the outermost one exits.
        """
        previous = self._suppress_guards_tls()
        self.tls.suppress_guards = True
        try:
            yield
        finally:
            self.tls.suppress_guards = previous

    def _get_key(self):
        """
        Defines the current "state" of the guards we've accumulated in this ShapeEnv.
        Determines when we need to invalidate our cache
        """
        return (len(self.replacements), len(self.divisible))

    def create_symbolic_sizes_strides_storage_offset(self, ex: torch.Tensor):
        """
        Returns symbolic sizes, symbolic strides and a symbolic storage offset
        for the given tensor, as a ``(sizes, strides, storage_offset)`` triple.
        We try our best to express stride in terms of the sizes, so as to not
        introduce new symbolic variables.
        """
        size = [self.create_symbol(i) for i in ex.size()]
        stride: List[Optional[sympy.Expr]] = [None] * len(size)
        # Strides of 0 and 1 are bound to the literal constants right away.
        for i, val in enumerate(ex.stride()):
            if val in (0, 1):
                stride[i] = sympy.Integer(val)
        while any(x is None for x in stride):
            # Concrete value of size*stride for each already-bound dimension,
            # mapped to the matching symbolic product.
            candidates = {
                ex.size(i) * ex.stride()[i]: size[i] * stride[i]
                for i in range(len(size))
                if stride[i] is not None and ex.stride()[i] >= 0
            }
            # iterate over unbound strides in sorted order
            val_list = sorted(
                [(ex.stride()[i], i) for i in range(len(stride)) if stride[i] is None]
            )
            for _, i in val_list:
                if stride[i] is None and ex.stride()[i] in candidates:
                    stride[i] = candidates[ex.stride()[i]]
                    candidates[ex.size(i) * ex.stride()[i]] = size[i] * stride[i]
            if any(x is None for x in stride):
                # bind the smallest unbound stride to a new variable
                val, i = min(
                    [
                        (ex.stride()[i], i)
                        for i in range(len(stride))
                        if stride[i] is None
                    ]
                )
                stride[i] = self.create_symbol(val)
        assert all(x is not None for x in stride)
        sym_size = [self.create_symintnode(i) for i in size]
        sym_stride = []
        for i, stride_expr in enumerate(stride):
            # NB: Don't duck size the stride; instead use the expression
            # we computed
            # TODO: We actually allocated an unnecessary extra symbol
            # here in the smallest unbound stride case, but it's not
            # a big deal because the non-0/1 symbol immediately
            # evaporates from its duck-sizing simplification
            # The symbol's recorded concrete value must be the stride of
            # *this* dimension, not a value left over from an earlier loop.
            s = self.create_symbol(ex.stride()[i], simplify=False)
            assert stride_expr is not None
            assert isinstance(s, sympy.Symbol)
            self.replacements[s] = stride_expr
            sym_stride.append(self.create_symintnode(s))
        sym_storage_offset = self.create_symintnode(self.create_symbol(ex.storage_offset()))
        return sym_size, sym_stride, sym_storage_offset

    def create_symintnode(self, sym: "sympy.Expr"):
        """
        Wrap ``sym`` (a symbol or the negation of one) in a ``SymInt`` whose
        node holds ``sym`` with the current replacements applied and remembers
        ``sym`` itself as its ``symbol``.
        """
        assert isinstance(sym, sympy.Symbol) or isinstance(-sym, sympy.Symbol)
        return SymInt(SymNode(sym.xreplace(self.replacements), self, int, symbol=sym))

    # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
    # but there may be a replacement that allows it to be immediately
    # simplified
    def create_symbol(self, val: int, *, simplify: bool = True) -> "sympy.Expr":
        """
        Allocate a fresh positive integer symbol whose concrete value is ``val``.

        A negative ``val`` yields the negation of a symbol created for ``-val``.
        With ``simplify=True`` a replacement is also registered that maps the
        new symbol to the shared per-value (duck-shaped) variable.

        Raises:
            RuntimeError: if sympy is not installed.
        """
        if not HAS_SYMPY:
            raise RuntimeError("Need sympy installed to create symbolic shapes")
        if val < 0:
            return -self.create_symbol(-val, simplify=simplify)
        symbol = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
        self.var_to_val[symbol] = sympy.Integer(val)
        if not simplify:
            return symbol
        # Now attempt to simplify this symbol
        # TODO: Create a guard whenever this happens
        # TODO: Do this duck sizing lazily later
        # This implements duck-shaping: input sizes that match are assigned
        # the same symint
        if val not in self.val_to_var:
            sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=True, integer=True)
            self.var_to_val[sympy_expr] = sympy.Integer(val)
            self.val_to_var[val] = sympy_expr
            # 0 and 1 are pre-populated in val_to_var, so any value reaching
            # this branch is neither.
            self.definitely_not_01.add(sympy_expr)
        self.replacements[symbol] = self.val_to_var[val]
        # Return the *symbol*; you're expected to apply the replacement to get
        # the simplified variable
        return symbol

    def evaluate_guards_for_args(self, *args):
        """
        Return whether every recorded guard holds for ``args``, by converting
        the tensors in ``args`` under a fresh ShapeEnv and substituting its
        concrete values into the guards.
        """
        new_env = ShapeEnv()
        # NB: This must be kept in sync with create_aot_dispatcher_function
        # and wrap_fake_symbolic
        meta_converter = MetaConverter()
        pytree.tree_map_only(torch.Tensor, partial(meta_converter, shape_env=new_env), args)
        return all(guard.xreplace(new_env.var_to_val) for guard, _ in self.guards)

    def get_guard_expr(self):
        """
        Returns a sympy expression representing all of the shape env guards.
        NOTE: Does not include implicit 0/1 or duck-shaping guards
        """
        return sympy.And(*[guard for guard, _ in self.guards])

    def get_nontrivial_guards(self):
        """Return the simplified guards that cannot be decided statically."""
        return [self.simplify(guard) for guard, _ in self.guards if self._maybe_evaluate_static(guard) is None]

    def format_guards(self, verbose=False):
        """
        Render the guards one per line; with ``verbose`` each guard is followed
        by the stack trace recorded when it was added.
        """
        def format_tb(tb):
            if not verbose:
                return ""
            return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
        return '\n'.join(f" - {guard}{format_tb(tb)}" for guard, tb in self.guards)

    def get_shape_groups(self):
        """Group replaced symbols by the expression they are replaced with."""
        shape_groups = collections.defaultdict(list)
        for k, v in self.replacements.items():
            shape_groups[v].append(k)
        return shape_groups

    @_lru_cache
    def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]":
        """
        Tries to evaluate expr without introducing guards.
        Returns the resulting expression when no free symbols remain, else None.
        """
        expr = self.simplify(expr)
        # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
        symbols = list(expr.free_symbols)
        new_shape_env = {
            k: sympy.Symbol(f"shape_{idx}", positive=True, integer=True) + 1
            for idx, k in enumerate(symbols)
            if k in self.definitely_not_01
        }
        new_expr = expr.xreplace(new_shape_env)
        floor_div_replace = {}
        for atom in new_expr.atoms(FloorDiv):
            floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1])
        new_expr = sympy.expand(new_expr.xreplace(floor_div_replace))
        if len(list(new_expr.free_symbols)) == 0:
            return new_expr
        return None

    @_lru_cache
    def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
        """Substitute every free symbol of ``expr`` by its representative and expand."""
        replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols}
        return sympy.expand(expr.xreplace(replacements))

    @_lru_cache
    def _update_divisible(self):
        """Drop divisibility facts that have no free symbols left after replacement."""
        new_divisible = set()
        for k in self.divisible:
            res = self.replace(k)
            if len(res.free_symbols) > 0:
                new_divisible.add(k)
        self.divisible = new_divisible

    @_lru_cache
    def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
        """
        Apply replacements to ``expr`` and turn each FloorDiv whose operands are
        known to divide evenly into a plain division.
        """
        expr = self.replace(expr)
        if expr.has(FloorDiv):
            self._update_divisible()
            div_replacements = {}
            for atom in expr.atoms(FloorDiv):
                base, divisor = atom.args
                if self.replace(base % divisor) in self.divisible:
                    div_replacements[atom] = base / divisor
            expr = expr.xreplace(div_replacements)
            expr = sympy.expand(expr)
        return expr

    @lru_cache(256)
    def size_hint(self, expr: "sympy.Expr"):
        """
        Gets a size hint for a given expression from the underlying shapes we had.
        Does not introduce a guard, so only use this when you can guarantee that
        your code is still valid for arbitrary shapes (such as optimization decisions)
        """
        result_expr = sympy.expand(expr).xreplace(self.var_to_val)
        assert len(result_expr.free_symbols) == 0, "Size hint has variables we don't have underlying values for"
        return result_expr

    @_lru_cache
    def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
        """
        Implements a DSU-like algorithm to find the variable that represents a
        Also handles transitive non-identity replacements.
        a: b + c
        c: d
        """
        if a not in self.replacements:
            return a
        res = self.replacements[a]
        cur_replace = {s: self._find(s) for s in res.free_symbols}
        # Store the resolved form back so later lookups are shorter.
        self.replacements[a] = self.replacements[a].xreplace(cur_replace)
        return self.replacements[a]

    @lru_cache(256)
    def _maybe_guard_eq(self, expr: "sympy.Eq") -> None:
        """
        Evaluates the result of an eq call. If true, uses information to
        simplify shapes (i.e. a == b or a % 5 == 0)
        """
        concrete_bool = bool(self.size_hint(expr))
        if not concrete_bool:
            return
        free = list(expr.free_symbols)
        assert len(free) > 0, "The expression should not be static by this point"
        # In case of really gnarly expression, we don't blow up
        if len(free) > 5:
            return
        # Solve for the symbol with the largest hint (ties broken by name).
        free = sorted(free, key=lambda x: (self.size_hint(x), x.name), reverse=True)  # type: ignore[attr-defined]
        lhs = expr.lhs
        rhs = expr.rhs
        try:
            solutions = sympy.solve(lhs - rhs, free[0], dict=True)
            if len(solutions) != 1:
                return
            solution = solutions[0][free[0]]
            # Only record the replacement when every sub-term is an integer.
            if all(t.is_integer for t in sympy.preorder_traversal(solution)):
                new_var = self._find(solution)
                self.replacements[cast(sympy.Symbol, free[0])] = new_var
        except NotImplementedError:
            if expr.has(sympy.Mod):
                mod_expr = tuple(expr.atoms(sympy.Mod))[0]
                try:
                    solutions = sympy.solve(lhs - rhs, mod_expr, dict=True)
                    if len(solutions) == 1 and solutions[0][mod_expr] == 0:
                        self.divisible.add(mod_expr)
                except NotImplementedError:
                    pass
            return

    @lru_cache(256)
    def evaluate_expr(self, expr: "sympy.Expr"):
        """
        Given an expression, evaluates it, adding guards if necessary
        """
        if len(expr.free_symbols) == 0:
            return expr
        expr = self.simplify(expr)
        static_expr = self._maybe_evaluate_static(expr)
        if static_expr is not None:
            return static_expr
        if isinstance(expr, sympy.Eq):
            self._maybe_guard_eq(expr)
        concrete_val = self.size_hint(expr)
        # TODO: optimize this; avoid formatting traces until we need them
        # NB: drop two frames; evaluate_expr and the Sym* function that
        # actually called us
        if not self._suppress_guards_tls():
            stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2]))
            if concrete_val is sympy.true:
                self.guards.append((expr, stack))
            elif concrete_val is sympy.false:
                self.guards.append((sympy.Not(expr), stack))
            else:
                self.guards.append((sympy.Eq(expr, concrete_val), stack))
        return concrete_val