import torch
import torch.utils._pytree as pytree
from typing import Dict, Any
try:
    import sympy  # type: ignore[import]
    HAS_SYMPY = True
except ImportError:
    HAS_SYMPY = False
aten = torch.ops.aten
__all__ = ["has_symbolic_sizes_strides", "create_contiguous", "is_symbolic_op", "handle_symbolic_op", "PySymInt", "ShapeEnv"]
def has_symbolic_sizes_strides(elem):
    return any(isinstance(i, torch._C.SymbolicIntNode) for i in elem.shape)
def create_contiguous(shape):
    # Row-major (contiguous) strides: stride[i] is the product of shape[i+1:].
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
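# For example, a contiguous tensor of shape [2, 3, 4] has strides [12, 4, 1]:
# create_contiguous([2, 3, 4]) == [12, 4, 1] (advancing the last index moves 1
# element, the middle index 4 elements, and the first index 3 * 4 = 12 elements).
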
def is_symbolic_op(func):
    return func in [aten.sym_size.default, aten.dim.default, aten.is_contiguous.default, aten.stride]
def handle_symbolic_op(func, args, kwargs):
    assert is_symbolic_op(func)
    if func == torch.ops.aten.sym_size.default:
        return None
    if func == torch.ops.aten.dim.default:
        return len(args[0].shape)
    # TODO: hack, need to make is_contiguous calls symbolic (probably through computing on symbolic strides)
    if func == torch.ops.aten.is_contiguous.default:
        return True
    # TODO: hack, we don't currently support symbolic strides properly
    if func == torch.ops.aten.stride:
        return create_contiguous(args[0].shape)
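# Illustrative sketch (the class below is hypothetical, not part of this module):
# a fake-tensor style __torch_dispatch__ would check for these metadata ops before
# doing any real dispatch, e.g.:
#
#   class _FakeSymbolicTensor(torch.Tensor):
#       @classmethod
#       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
#           kwargs = kwargs or {}
#           if is_symbolic_op(func):
#               return handle_symbolic_op(func, args, kwargs)
#           ...  # otherwise fall through to the usual fake-tensor handling
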
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class PySymInt(object):
    """
    PySymInt objects are the primary "symbolic shape" objects that flow through
    our program. They are what sit under FakeTensor, and contain our primary
    implementation of symbolic shapes.
    """
    def __init__(self, expr, shape_env):
        self.expr = expr
        self.shape_env = shape_env

    def wrap(self, num):
        return PySymInt(sympy.Integer(num), self.shape_env)

    def __str__(self):
        return f"PySymInt({self.expr})"

    # Today we error on calling int on a symbolic shape, as this is a very accessible footgun.
    # In the future we'll probably need some explicit way of allowing this
    def __int__(self):
        raise RuntimeError("Trying to extract a concrete int out of a symbolic int")

    def __bool__(self):
        return bool(self.shape_env.evaluate_expr(self.expr))
# Methods that have a `__foo__` as well as `__rfoo__`
reflectable_magic_methods = {
    'add': lambda a, b: a + b,
    'sub': lambda a, b: a - b,
    'mul': lambda a, b: a * b,
    'mod': lambda a, b: a % b,
}

magic_methods = {
    **reflectable_magic_methods,
    'eq': lambda a, b: sympy.Eq(a, b),
    'gt': lambda a, b: sympy.Gt(a, b),
    'lt': lambda a, b: sympy.Lt(a, b),
    'le': lambda a, b: sympy.Le(a, b),
    'ge': lambda a, b: sympy.Ge(a, b),
}
for method, _func in magic_methods.items():
    method_name = f'{method}'

    def _create_magic_impl(func):
        def magic_impl(self, other):
            if isinstance(other, PySymInt):
                other = other.expr
            return PySymInt(func(self.expr, other), self.shape_env)
        return magic_impl

    # this should be wrapped transparently into torch._C.SymbolicIntNode
    setattr(PySymInt, method_name, _create_magic_impl(_func))
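# For example (illustrative): given `s = PySymInt(sympy.Symbol("s0"), shape_env)`,
# `s.add(3)` returns a new PySymInt wrapping the sympy expression `s0 + 3`, and
# `s.eq(other)` wraps a `sympy.Eq(...)`; per the comment above, torch._C.SymbolicIntNode
# is what is expected to expose these methods as the usual int operators.
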
class ShapeEnv(object):
    def __init__(self):
        self.guards = []
        self.shape_env = {}

    def create_symint(self, name, val, shape_env=None):
        if not HAS_SYMPY:
            raise RuntimeError("Need sympy installed to create symbolic shapes")
        if shape_env is None:
            shape_env = self.shape_env
        # Currently we don't put 0/1 specialization in guards but perhaps we should
        if val == 0 or val == 1:
            return val
        sympy_expr = sympy.Symbol(name, positive=True)
        py_sym_int = PySymInt(sympy_expr, self)
        cpp_sym_int = torch._C.SymbolicIntNode.new_symint(py_sym_int)  # type: ignore[attr-defined]
        shape_env[sympy_expr] = val
        return cpp_sym_int
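    # For example (illustrative): with `env = ShapeEnv()`, `env.create_symint("s0", 5)`
    # returns a torch._C.SymbolicIntNode backed by the sympy symbol s0 (and records
    # {s0: 5} in env.shape_env), whereas `env.create_symint("s1", 1)` simply returns
    # the plain int 1, since sizes of 0 and 1 are specialized rather than made symbolic.
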
    def try_constantify(self, expr):
        # Simplifies assuming that shape vars > 1 (since we cache on 0/1 shape values)
        new_shape_env = {k: sympy.Symbol(f'shape_{idx}', positive=True) + 1 for idx, k in enumerate(self.shape_env.keys())}
        new_expr = expr.subs(new_shape_env)
        new_expr = new_expr.simplify()
        if not new_expr.free_symbols:
            return new_expr
        return None
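    # For example (illustrative): because every shape symbol is replaced by
    # `shape_i + 1` with `shape_i` positive, an expression such as `s0 > 1` can
    # simplify all the way down to `True` (no free symbols left), letting
    # evaluate_expr return it without recording a guard; an expression like
    # `s0 + 1` keeps its free symbol and falls through to the guard path.
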
    def create_shapes_for_args(self, args, shape_env=None):
        # Takes a pytree of args and returns a flat list: each tensor leaf is
        # replaced by its (symbolic) shape, non-tensor leaves pass through unchanged.
        arg_cnt = 0

        def create_shape(x):
            nonlocal arg_cnt
            if not isinstance(x, torch.Tensor):
                return x
            out_shape = [self.create_symint(f"s_{arg_cnt}^{idx}", sz, shape_env) for idx, sz in enumerate(x.shape)]
            arg_cnt += 1
            return out_shape
        return list(map(create_shape, pytree.tree_flatten(args)[0]))
    def evaluate_guards_for_args(self, *args):
        env: Dict[Any, Any] = {}
        _ = self.create_shapes_for_args(args, shape_env=env)
        return all(guard.subs(env) == value for guard, value in self.guards)
    def evaluate_expr(self, expr):
        const_expr = self.try_constantify(expr)
        if const_expr is not None:
            return const_expr
        expr = expr.simplify()
        concrete_val = expr.subs(self.shape_env)
        self.guards.append((expr, concrete_val))
        return concrete_val
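
# Illustrative usage sketch (assumes a torch build that exposes
# torch._C.SymbolicIntNode; the example sizes are made up):
#
#   shape_env = ShapeEnv()
#   sym_shapes = shape_env.create_shapes_for_args((torch.randn(4, 8),))
#   # Code traced with these symbolic sizes ends up calling evaluate_expr
#   # (e.g. via PySymInt.__bool__), which records (expr, concrete_value) guards.
#   # The recorded guards can later be re-checked against fresh example inputs:
#   shape_env.evaluate_guards_for_args(torch.randn(4, 8))  # True if the guards still hold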