# blob: 05e65013a14af7f1aed34436b8f2f095c6caae76 [file] [log] [blame]
import collections
import dataclasses
import functools
import inspect
import sys
from typing import Any, Dict, List, Optional
import torch
import torch.fx
from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..eval_frame import skip_code
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name, istensor, iter_contains
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
    """Tracks a dict-like object whose keys are known while tracing.

    ``items`` maps raw Python keys (never ``VariableTracker`` objects, see
    the assertion in ``__init__``; tensor keys are recognised through
    ``istensor``) to the tracked values.  ``user_cls`` is the class reported
    by ``python_type``; ``collections.OrderedDict`` is special-cased when the
    dict is rebuilt in ``reconstruct``.
    """

    def __init__(self, items, user_cls, **kwargs):
        super().__init__(**kwargs)
        # All the keys are constants
        assert not any(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.user_cls = user_cls

    def as_proxy(self):
        """Return a plain dict with every value replaced by ``value.as_proxy()``."""
        return {k: v.as_proxy() for k, v in self.items.items()}

    def as_python_constant(self):
        """Return a plain dict with every value replaced by its Python constant."""
        return {k: v.as_python_constant() for k, v in self.items.items()}

    def python_type(self):
        """Return the dict class this variable stands for."""
        return self.user_cls

    def reconstruct(self, codegen):
        """Emit the instructions that rebuild this dict.

        Key/value loads are pushed through ``codegen``; the trailing
        ``BUILD_MAP`` (plus the call to ``OrderedDict`` when applicable) is
        returned as a list of instructions.
        """
        # instructions to load collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    codegen.create_load_python_module(collections, True),
                    codegen.create_load_attr("OrderedDict"),
                ]
            )
        # instructions to build the dict keys and values
        for key in self.items.keys():
            if istensor(key):
                # Tensor keys are loaded via the global named by
                # global_key_name(key), followed by a zero-argument call.
                codegen.append_output(
                    codegen.create_load_global(global_key_name(key), True, add=True)
                )
                codegen.extend_output(create_call_function(0, False))
            else:
                codegen.append_output(codegen.create_load_const(key))
            codegen(self.items[key])
        # BUILD_MAP and calling collections.OrderedDict if necessary
        if self.user_cls is collections.OrderedDict:
            return [
                create_instruction("BUILD_MAP", arg=len(self.items)),
                *create_call_function(1, False),
            ]
        # BUILD_MAP only if user_cls is dict
        else:
            return [create_instruction("BUILD_MAP", arg=len(self.items))]

    def getitem_const(self, arg: VariableTracker):
        """Look up the value stored under the key that ``arg`` represents.

        A missing key surfaces as the ``KeyError`` raised by the underlying
        dict lookup.
        """
        return self.items[ConstDictVariable.get_key(arg)]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle a dict method call made on the tracked dict.

        The branches below are tried in order, so the missing-key handlers
        for ``pop``/``get`` must stay ahead of the mutating ``pop`` branch.
        Mutating operations are only handled when ``self.mutable_local`` is
        set; they build a new items dict and publish it via
        ``tx.replace_all``.  Anything unmatched is delegated to the base
        class.
        """
        from . import (
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
        )

        val = self.items

        if name == "__getitem__":
            assert len(args) == 1

            return self.getitem_const(args[0])

        elif name == "items":
            assert not (args or kwargs)
            # A tuple of (key, value) tuples, keys converted back to trackers.
            return TupleVariable(
                [
                    TupleVariable(
                        items=[
                            ConstDictVariable._key_to_var(
                                tx,
                                k,
                            ),
                            v,
                        ],
                    )
                    for k, v in val.items()
                ],
            )
        elif name == "keys":
            assert not (args or kwargs)
            return SetVariable(
                items=[
                    ConstDictVariable._key_to_var(
                        tx,
                        k,
                    )
                    for k in val.keys()
                ],
                mutable_local=MutableLocal(),
            )

        elif name == "values":
            assert not (args or kwargs)
            return TupleVariable(list(val.values()))
        elif name == "copy":
            assert not (args or kwargs)
            # The copy gets its own MutableLocal.
            return self.modifed(self.items.copy(), mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(self.items))
        elif (
            name == "__setitem__"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            assert not kwargs and len(args) == 2
            k = ConstDictVariable.get_key(args[0])

            if istensor(k):
                # Register the tensor key under the same global name that
                # reconstruct() loads it from.
                tx.store_global_weakref(global_key_name(k), k)
            newval = dict(val)
            newval[k] = args[1]

            return tx.replace_all(
                self,
                self.modifed(newval),
            )
        elif (
            name in ("pop", "get")
            and len(args) == 2
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            # missing item, return the default value
            return args[1]
        elif (
            name == "get"
            and len(args) == 1
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            # get() on a missing key without an explicit default yields None.
            return ConstantVariable(None)
        elif (
            name == "pop"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            # The two-argument missing-key case was handled above; a missing
            # key with no default reaches here and raises KeyError from
            # dict.pop.
            newval = dict(val)
            result = newval.pop(ConstDictVariable.get_key(args[0]))
            tx.replace_all(self, self.modifed(newval))
            return result
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(args[0], ConstDictVariable)
            and self.mutable_local
        ):
            newval = dict(val)
            newval.update(args[0].items)
            newval.update(kwargs)  # all keys in kwargs are valid (`str`s)
            result = self.modifed(newval)
            return tx.replace_all(self, result)
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    ListVariable,
                    TupleVariable,
                    ListIteratorVariable,
                ),
            )
            and self.mutable_local
        ):
            # update() from a sequence of (key, value) pairs.
            newval = dict(val)
            for x in args[0].unpack_var_sequence(tx):
                k, v = x.unpack_var_sequence(tx)
                assert ConstDictVariable.is_valid_key(k)
                newval[ConstDictVariable.get_key(k)] = v
            newval.update(kwargs)  # all keys in kwargs are valid (`str`s)
            result = self.modifed(newval)
            return tx.replace_all(self, result)
        elif (
            name in ("get", "__getattr__")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) in self.items
        ):
            # Key is present: return the stored value.
            return self.items[ConstDictVariable.get_key(args[0])]
        elif (
            name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0])
        ):
            return ConstantVariable.create(
                ConstDictVariable.get_key(args[0]) in self.items
            )
        else:
            return super().call_method(tx, name, args, kwargs)

    # NOTE: the misspelled name ("modifed") is what the callers in this file
    # use, so it is part of the interface.
    def modifed(self, items, **options):
        """a copy of self with different items"""
        return self.clone(items=items, **options)

    def unpack_var_sequence(self, tx):
        """Return the keys, each converted back to a tracker, as a list."""
        val = self.items
        result = [ConstDictVariable._key_to_var(tx, k) for k in val.keys()]
        return result

    @classmethod
    def get_key(cls, arg: VariableTracker):
        """Convert a tracker into the raw key used in ``self.items``.

        A ``TensorVariable`` with a ``specialized_value`` maps to that value;
        everything else maps to ``arg.as_python_constant()``.
        """
        if isinstance(arg, TensorVariable) and arg.specialized_value is not None:
            return arg.specialized_value
        else:
            return arg.as_python_constant()

    @classmethod
    def is_valid_key(cls, key):
        """Return whether ``key`` is a tracker that may be used as a dict key.

        Accepted: Python constants, tensor variables carrying a
        ``specialized_value``, and ``ConstantVariable``s whose type is
        ``torch.dtype``.
        """
        return (
            key.is_python_constant()
            or (isinstance(key, TensorVariable) and key.specialized_value is not None)
            or (isinstance(key, ConstantVariable) and key.python_type() is torch.dtype)
        )

    @classmethod
    def _key_to_var(cls, tx, key, **options):
        """Wrap a raw key back into a tracker (inverse direction of ``get_key``).

        Tensor keys are rebuilt through ``VariableBuilder`` with a
        ``GlobalWeakRefSource``; all other keys must be literals and become
        ``ConstantVariable``s.
        """
        from .builder import VariableBuilder

        if istensor(key):
            return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
        else:
            assert ConstantVariable.is_literal(key)
            return ConstantVariable.create(key, **options)
class DefaultDictVariable(ConstDictVariable):
    """``ConstDictVariable`` specialisation for ``collections.defaultdict``.

    ``default_factory`` is the tracked callable used to create the value for
    a key that is looked up via ``__getitem__`` but not present yet.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        """Report whether this defaultdict can be treated as a constant."""
        # Return false for unsupported defaults. This ensures that a bad handler
        # path is not taken in BuiltinVariable for getitem.
        factory_unsupported = self.default_factory not in [list, tuple, dict]
        if factory_unsupported and not self.items:
            return False
        return super().is_python_constant()

    @staticmethod
    def is_supported_arg(arg):
        """Return whether ``arg`` is accepted as a default factory."""
        if not isinstance(arg, variables.BuiltinVariable):
            # Non-builtins are accepted only when they are user functions.
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)
        return arg.fn in [list, tuple, dict]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with defaultdict semantics; defer the rest."""
        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        key = ConstDictVariable.get_key(args[0])
        if key in self.items:
            return self.getitem_const(args[0])

        if self.default_factory is None:
            raise KeyError(f"{key}")

        # Missing key: create the default, store it and publish the new dict.
        if istensor(key):
            tx.store_global_weakref(global_key_name(key), key)
        updated_items = dict(self.items)
        default_var = self.default_factory.call_function(tx, [], {})
        updated_items[key] = default_var
        tx.replace_all(self, self.modifed(updated_items))
        return default_var
class SetVariable(VariableTracker):
    """Tracks a Python ``set`` whose members are ``VariableTracker`` objects.

    The members are stored in the list ``self.items``; uniqueness is enforced
    by ``_add`` through the ``SetElement`` wrapper defined below.
    """

    @dataclasses.dataclass
    class SetElement:
        """Hashable wrapper pairing a tracker with the value used to compare it."""

        vt: VariableTracker
        underlying_value: Any

        def __hash__(self) -> int:
            return hash(self.underlying_value)

        def __eq__(self, other: object) -> bool:
            if not isinstance(other, SetVariable.SetElement):
                return False
            if isinstance(self.vt, variables.TensorVariable):
                # Tensors are compared by identity of the underlying value.
                return self.underlying_value is other.underlying_value
            else:
                return self.underlying_value == other.underlying_value

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Note - the set is still backed by a list; set behavior over the
        # contents is enforced by _add().
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)

        self.items = []
        self._add(items)

    def as_proxy(self):
        """Return the proxies of all members as a list."""
        return [x.as_proxy() for x in self.items]

    def python_type(self):
        return set

    def reconstruct(self, codegen):
        """Emit the loads for ``set`` and the members; return the trailing
        ``BUILD_SET`` + call instructions."""
        codegen.load_import_from("builtins", "set")
        codegen.foreach(self.items)
        return [
            create_instruction("BUILD_SET", arg=len(self.items))
        ] + create_call_function(1, True)

    # Note - this is only used for producing a set
    def _as_set_element(self, vt):
        """Wrap ``vt`` in a ``SetElement``.

        Tensor variables use the ``example_value`` stored in their proxy
        node's metadata, constants use their value, and method wrappers use
        their Python constant.  Any other tracker type is reported through
        ``unimplemented``.
        """
        from .base import VariableTracker
        from .misc import MethodWrapperVariable
        from .tensor import TensorVariable

        assert isinstance(vt, VariableTracker)

        if isinstance(vt, TensorVariable):
            fake_tensor = vt.as_proxy().node.meta.get("example_value")
            if fake_tensor is None:
                unimplemented(
                    "Cannot check Tensor object identity without its fake value"
                )
            return SetVariable.SetElement(vt, fake_tensor)
        if isinstance(vt, ConstantVariable):
            return SetVariable.SetElement(vt, vt.value)
        if isinstance(vt, MethodWrapperVariable):
            return SetVariable.SetElement(vt, vt.as_python_constant())

        unimplemented(f"Sets with {type(vt)} NYI")

    @property
    def _underlying_items(self):
        """Return the current members as a Python set of ``SetElement``s."""
        underlying_items = set()
        for current_item in self.items:
            # Wrap first: the set holds SetElement objects, so testing the raw
            # tracker for membership could never detect a duplicate.
            element = self._as_set_element(current_item)
            assert (
                element not in underlying_items
            ), "Items modeling set invariant violated"
            underlying_items.add(element)
        return underlying_items

    def _add(self, item):
        """Add one tracker, or every tracker in a list/set, to ``self.items``.

        Elements that are already present are not appended; instead a
        duplicate-input guard is installed for each existing element with the
        same hash when ``make_dupe_guard`` produces one.  Mutates and returns
        ``self.items``.
        """
        underlying_items = self._underlying_items

        if isinstance(item, (list, set)):
            items_to_add = item
        else:
            items_to_add = [item]

        for item_to_add in items_to_add:
            set_element = self._as_set_element(item_to_add)
            if set_element not in underlying_items:
                underlying_items.add(set_element)
                self.items.append(set_element.vt)
            else:
                for e in underlying_items:
                    if hash(set_element) == hash(e):
                        alias_guard = make_dupe_guard(
                            e.vt.source, set_element.vt.source
                        )
                        if alias_guard:
                            install_guard(e.vt.source.make_guard(alias_guard))

        return self.items

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Handle ``add``, ``pop``, ``__len__`` and ``__contains__``; defer the rest."""
        # Somewhat duplicative of CommonListMethodsVariable - but better than to violate substitution
        # principles and end up with things like direct item access attempts on a set, or
        # getitem sources.
        if name == "add" and args and self.mutable_local:
            assert not kwargs
            item = args[0]
            result = SetVariable(
                self._add(item),
                mutable_local=self.mutable_local,
            )
            tx.replace_all(self, result)
            return ConstantVariable.create(None)
        elif name == "pop" and self.mutable_local:
            assert not kwargs
            assert not args
            items = list(self.items)
            # Removes the last element of the backing list.
            result = items.pop()
            tx.replace_all(
                self,
                SetVariable(items),
            )
            return result
        elif name == "__len__":
            return ConstantVariable.create(len(self.items))
        elif name == "__contains__":
            assert len(args) == 1
            assert not kwargs
            return iter_contains(self.items, args[0], tx, check_tensor_identity=True)
        else:
            return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")

    def as_python_constant(self):
        """Return a real ``set`` built from the members' Python constants."""
        return self.python_type()([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        """Return a copy of the member list."""
        return list(self.items)
def _is_matching_transformers_cls(cls) -> bool:
    """Return True when ``cls`` subclasses ``ModelOutput`` from an already
    imported ``transformers.file_utils``; the module is never imported here."""
    file_utils = sys.modules.get("transformers.file_utils")
    if file_utils is None:
        return False
    return issubclass(cls, file_utils.ModelOutput)
def _is_matching_diffusers_cls(cls) -> bool:
    """Return True when ``cls`` subclasses ``BaseOutput`` from an already
    imported ``diffusers.utils``; the module is never imported here."""
    diffusers_utils = sys.modules.get("diffusers.utils")
    if diffusers_utils is None:
        return False
    return issubclass(cls, diffusers_utils.BaseOutput)
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Mark every callable defined on ModelOutput / BaseOutput with
        ``skip_code``.  Runs at most once thanks to ``lru_cache``; a missing
        library is silently ignored."""
        try:
            from transformers.file_utils import ModelOutput

            for obj in ModelOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            # transformers is optional
            pass

        try:
            from diffusers.utils import BaseOutput

            for obj in BaseOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            # diffusers is optional
            pass

    @staticmethod
    def is_matching_cls(cls):
        """Return whether ``cls`` is a transformers/diffusers output class."""
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        """Return whether ``type(obj)`` passes ``is_matching_cls``."""
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a variable for a traced ``user_cls(*args, **kwargs)`` call.

        The arguments are bound against the signature of ``user_cls`` (with
        defaults applied) and must cover exactly its dataclass fields.
        Non-tracker values are either wrapped as constants
        (``include_none``) or required to be ``None`` and left out.
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = {}
        for key in keys:
            val = bound.arguments[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key] = variables.ConstantVariable.create(val)
                else:
                    assert val is None, f"unexpected {val}"

        if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
            unimplemented("DataClassVariable iterator constructor")
            # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Build a variable from an existing object ``obj``.

        Each dataclass field present on ``obj`` is wrapped with a new builder
        whose source is the corresponding attribute; ``None`` fields are
        wrapped but not stored unless ``include_none`` is set.
        """
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        items = {}
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                # The builder is invoked even for values that end up excluded,
                # exactly as before; only the unused bookkeeping list is gone.
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[key] = var
        return cls(items, user_cls)

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    def reconstruct(self, codegen):
        """Emit ``user_cls(**items)``: load the class, push the values and
        return the keyword-call instructions."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return codegen.create_call_function_kw(len(keys), keys, True)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` (string key or positional index),
        ``to_tuple`` and ``__setattr__`` (treated as ``__setitem__``);
        everything else goes to ``ConstDictVariable.call_method``."""
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            index = args[0].as_python_constant()
            if isinstance(index, str):
                return self.items[index]
            else:
                # Non-string index: index into the tuple of values instead.
                return self.call_method(tx, "to_tuple", [], {}).call_method(
                    tx, "__getitem__", args, kwargs
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()))
        elif name == "__setattr__":
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name``: stored item first, then (when ``None``
        values are excluded) the dataclass field default, else the base class."""
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable.create(name)], {}
            )
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable.create(defaults[name])
        # The result must be returned; previously it was dropped and the
        # method yielded None.
        return super().var_getattr(tx, name)
class CustomizedDictVariable(ConstDictVariable):
    """``ConstDictVariable`` for user-defined dict subclasses accepted by
    ``is_matching_cls``."""

    @staticmethod
    def is_matching_cls(cls):
        """Return whether ``cls`` can be modelled by this variable."""
        # True if using default OrderedDict.__init__ and did not implement __post_init__
        if (
            issubclass(cls, collections.OrderedDict)
            and cls.__init__ is collections.OrderedDict.__init__
            and not hasattr(cls, "__post_init__")
        ):
            return True
        # hack for HF usecase:
        #   assume dataclass annotation for ModelOutput subclass
        #   assume self.create is AA to ModelOutput.__post_init__
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        """Return whether ``type(obj)`` passes ``is_matching_cls``."""
        return cls.is_matching_cls(type(obj))

    # called from user_defined.py
    # when is_matching_cls(cls) is true
    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a variable for a traced ``user_cls(*args, **kwargs)`` call.

        Supported call shapes: no arguments, dataclass-style construction,
        keyword-only construction, and a single ``ConstDictVariable``
        positional argument.  Anything else is reported via ``unimplemented``.
        """
        # avoid tracing when returning ModelOutput from forward func
        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
            if hasattr(user_cls, attr_name):
                fn = getattr(user_cls, attr_name)
                assert callable(fn), f"expect callable attr {attr_name}"
                if hasattr(fn, "__code__"):
                    skip_code(fn.__code__)

        if not args and not kwargs:
            # CustomDict() init with empty arguments
            raw_items = {}
        elif dataclasses.is_dataclass(user_cls):
            # @dataclass CustomDict(a=1, b=2)
            bound = inspect.signature(user_cls).bind(*args, **kwargs)
            bound.apply_defaults()
            raw_items = bound.arguments
        elif not args:
            # CustomDict(a=1, b=2) in the general (non-dataclass) case.
            raw_items = dict(kwargs)
        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
            # CustomDict({'a': 1, 'b': 2})
            raw_items = args[0].items
        else:
            unimplemented("custom dict init with args/kwargs unimplemented")

        items = {}
        for key in raw_items.keys():
            val = raw_items[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            elif variables.ConstantVariable.is_literal(val):
                items[key] = variables.ConstantVariable.create(val)
            else:
                unimplemented("expect VariableTracker or ConstantVariable.is_literal")

        return cls(items, user_cls, **options)

    # called from builder.py
    @classmethod
    def wrap(cls, builder, obj):
        raise NotImplementedError()

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError()

    # 'RETURN_VALUE triggered compile'
    # called from torch/_dynamo/codegen.py
    def reconstruct(self, codegen):
        """Emit ``user_cls(**items)``: load the class, push the values and
        return the keyword-call instructions."""
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return codegen.create_call_function_kw(len(keys), keys, True)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch a method call on the custom dict.

        Methods inherited unchanged from ``dict``/``OrderedDict`` go to
        ``ConstDictVariable.call_method``; a small set of user-overridden
        methods is inlined; everything else is reported via ``unimplemented``.
        """
        fn = getattr(self.user_cls, name)
        source = None if self.source is None else AttrSource(self.source, name)

        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
            dict,
            collections.OrderedDict,
        ):
            # for python dict method without overridden
            return super().call_method(tx, name, args, kwargs)
        elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
            # for user overridden method
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=source),
                [self] + list(args),
                kwargs,
            )

        # Pass one formatted message, like every other unimplemented() call
        # in this file, instead of a printf-style extra argument.
        unimplemented(f"custom dict: call_method unimplemented name={name}")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name`` from the stored items, else the base class."""
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable.create(name)], {}
            )
        # The result must be returned; previously it was dropped and the
        # method yielded None.
        return super().var_getattr(tx, name)
@functools.lru_cache(None)
def _install_PretrainedConfig_patch():
    """Replace ``PretrainedConfig.__eq__`` with a ``__dict__`` comparison.

    Cached with ``lru_cache`` so the patch is installed at most once.
    """
    import transformers

    # We need to monkeypatch transformers here, sadly.
    # TODO(voz): Upstream to transformers lib

    def _dynamo_overriden_transformers_eq(self, other):
        # Objects without a __dict__ never compare equal.
        return hasattr(other, "__dict__") and self.__dict__ == other.__dict__

    config_cls = transformers.configuration_utils.PretrainedConfig
    config_cls.__eq__ = _dynamo_overriden_transformers_eq
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig
    """

    @staticmethod
    def is_matching_cls(cls):
        """Return whether ``cls`` subclasses ``PretrainedConfig`` from an
        already imported ``transformers.configuration_utils``."""
        config_utils = sys.modules.get("transformers.configuration_utils")
        if config_utils is None or not issubclass(cls, config_utils.PretrainedConfig):
            return False
        # Lazily install monkeypatch the first time we see it in dynamo
        _install_PretrainedConfig_patch()
        return True

    @classmethod
    def is_matching_object(cls, obj):
        """Return whether ``type(obj)`` passes ``is_matching_cls``."""
        obj_type = type(obj)
        return cls.is_matching_cls(obj_type)

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Read ``name`` from the wrapped config and wrap it as a constant."""
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable.create(attr_value)

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Evaluate ``hasattr`` on the wrapped config and wrap the result."""
        found = hasattr(self.obj, name)
        return variables.ConstantVariable.create(found)
class PythonSysModulesVariable(VariableTracker):
    """Special case for sys.modules.

    Without this we will guard on the exact set of modules imported in the
    lifetime of the python program.
    """

    def python_type(self):
        return dict

    def reconstruct(self, codegen):
        """Emit the instructions that load ``sys.modules``.

        This is a regular instance method: the former ``@staticmethod``
        decorator combined with an explicit ``self`` parameter made
        ``instance.reconstruct(codegen)`` fail with a missing-argument
        ``TypeError``.  All output is pushed through ``codegen``; an empty
        list is returned so the result has the same shape as the instruction
        lists returned by the other ``reconstruct`` methods in this file.
        """
        codegen.extend_output(
            [
                codegen.create_load_python_module(sys, True),
                codegen.create_load_attr("modules"),
            ]
        )
        return []

    def call_method(
        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ):
        """Handle ``__getitem__``, ``get`` and ``__contains__`` with targeted
        guards; any other method falls back to a variable built from the real
        ``sys.modules`` dict."""
        from .builder import VariableBuilder

        if name == "__getitem__":
            return self.call_getitem(tx, *args, **kwargs)
        elif name == "get":
            return self.call_get(tx, *args, **kwargs)
        elif name == "__contains__":
            return self.call_contains(tx, *args, **kwargs)

        # Fallback to dict implementation
        real_dict = VariableBuilder(tx, self.source)(sys.modules)
        return real_dict.call_method(tx, name, args, kwargs)

    def _contains_helper(self, tx, key: VariableTracker):
        """Return ``(raw_key, present)`` for ``key`` and install a
        ``DICT_CONTAINS`` guard pinning the observed presence/absence."""
        k = ConstDictVariable.get_key(key)
        has_key = k in sys.modules
        install_guard(
            self.make_guard(
                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
            )
        )
        return k, has_key

    def call_contains(self, tx, key: VariableTracker):
        """Model ``key in sys.modules``."""
        _, has_key = self._contains_helper(tx, key)
        return ConstantVariable.create(value=has_key)

    def call_get(
        self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
    ):
        """Model ``sys.modules.get(key, default)``."""
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)

        if has_key:
            return VariableBuilder(
                tx,
                GetItemSource(self.source, k),
            )(sys.modules[k])

        if default is not None:
            return default

        return ConstantVariable.create(value=None)

    def call_getitem(self, tx, key: VariableTracker):
        """Model ``sys.modules[key]``; a missing key raises the ``KeyError``
        of the real lookup (after the guard has been installed)."""
        from .builder import VariableBuilder

        k, _ = self._contains_helper(tx, key)
        return VariableBuilder(
            tx,
            GetItemSource(self.source, k),
        )(sys.modules[k])