blob: 73334a759e0fb47ab726460844f0fc4062de7bb8 [file] [log] [blame]
import dataclasses
import enum
import logging
import weakref
from typing import Callable, List, NamedTuple, Optional
# TODO(voz): Stolen pattern, not sure why this is the case,
# but mypy complains.
try:
import sympy # type: ignore[import]
except ImportError:
logging.warning("No sympy found")
"""
torch._guards is the definitional source of truth for general purpose guard structures.
An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""
class GuardSource(enum.Enum):
    """Where the value checked by a guard is looked up from.

    Only the LOCAL* and GLOBAL* members map to a namespace (see ``select``);
    the remaining members have no namespace to select.
    """

    LOCAL = 0
    GLOBAL = 1
    LOCAL_NN_MODULE = 2
    GLOBAL_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6

    def select(self, locals_, globals_):
        """Return ``locals_`` for local sources and ``globals_`` for global ones.

        Raises:
            NotImplementedError: for CONSTANT, RANDOM_VALUE and SHAPE_ENV,
                which are neither local nor global.
        """
        if self.is_local():
            return locals_
        from_globals = self is GuardSource.GLOBAL or self is GuardSource.GLOBAL_NN_MODULE
        if not from_globals:
            raise NotImplementedError()
        return globals_

    def is_nn_module(self) -> bool:
        """True for the two *_NN_MODULE sources (local or global)."""
        return self is GuardSource.LOCAL_NN_MODULE or self is GuardSource.GLOBAL_NN_MODULE

    def is_local(self):
        """True for LOCAL and LOCAL_NN_MODULE."""
        return self is GuardSource.LOCAL_NN_MODULE or self is GuardSource.LOCAL
"""
Base class for a "GuardBuilder" role.
The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder, but it is kept to avoid a lot of renames and to preserve the original
reference to torchdynamo's GuardBuilder.
Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.
There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""
class GuardBuilderBase:
    """Empty base for the scope object handed to a Guard's ``create_fn``.

    Intentionally has no members so this module stays free of guard
    installation details.
    """
class ShapeGuard(NamedTuple):
    """A shape guard: a sympy expression plus an associated ``stack`` string.

    NOTE(review): ``stack`` is presumably a formatted stack trace of where the
    guard was recorded; confirm against call sites.
    """

    # String (forward-reference) annotation: typing.NamedTuple does not
    # evaluate it, so defining this class does not require sympy to be
    # importable. The sympy import at the top of the file is optional and
    # only logs a warning when it fails, leaving the name ``sympy`` unbound.
    expr: "sympy.Expr"
    stack: str
@dataclasses.dataclass
class Guard:
    """A named guard whose check is produced by ``create_fn``.

    ``create`` invokes ``create_fn`` with the builder chosen by
    ``source.select`` and this guard. The trailing "export only" fields are
    filled in later via ``set_export_info``.
    """

    # The name of a Guard specifies what exactly it is the guard is guarding
    # on. The meaning of the name is dependent on the create_fn; you must
    # look at the use-site inside create_fn to know what name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on. This evaluation
    # typically happens in GuardBuilder.eval. In these cases, name is
    # typically produced by Source.name() (not to be confused with
    # GuardSource)--morally, we could have stored a Source here.
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless. Example create_fns that are like this include
    # GRAD_MODE and SYMBOL_MATCH.
    name: str
    source: GuardSource
    create_fn: Callable[[GuardBuilderBase, "Guard"], None]
    is_volatile: bool = False

    # Export only. These values are written to at time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    def __hash__(self):
        # Only the identifying fields are hashed. The export-only fields are
        # mutated after construction (set_export_info), so they are left out.
        return hash((self.name, self.source, id(self.create_fn)))

    def sort_key(self):
        """Ordering key used by ``__lt__``.

        Orders by source value (-1 when source is unset), then name length,
        then name, then the first source line of ``create_fn``.
        """
        return (
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.create_fn.__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround for a Python weakref bug.

        `obj_weakref` is the instance returned by `weakref.ref`;
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raise error: KeyError: '__name__'

        Non-weakref values are formatted with plain ``str``.
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __str__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.create_fn.__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)},
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def create(self, local_builder: GuardBuilderBase, global_builder: GuardBuilderBase):
        """Invoke ``create_fn`` with the builder selected by ``self.source``."""
        return self.create_fn(self.source.select(local_builder, global_builder), self)

    def is_nn_module(self):
        return self.source.is_nn_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        """Record export metadata for this guard.

        May be called repeatedly. ``guard_type`` is appended and ``code_list``
        is concatenated onto what was recorded before. ``guarded_class`` and
        ``obj_weakref`` must match any previously recorded value (or none may
        have been recorded yet); this is checked with ``assert``.
        """
        if not self.guard_types:
            self.guard_types = list()
        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            # Store a copy, not the caller's list: the stored list is extended
            # in place on later calls, which would otherwise mutate the caller's.
            self.code_list = None if code_list is None else list(code_list)
        else:
            self.code_list.extend(code_list)

        assert self.obj_weakref in (
            obj_weakref,
            None,
        ), "Guarded object must be identical, or None"
        self.obj_weakref = obj_weakref