blob: 1a2af9af2eeb8505cad40ebecd24b188e485a765 [file] [log] [blame]
import collections
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List
import torch._C
from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import unimplemented
from ..source import AttrSource, ODictGetItemSource
from ..utils import (
check_constant_args,
HAS_NUMPY_TORCH_INTEROP,
identity,
proxy_args_kwargs,
)
from .base import MutableLocal, VariableTracker
from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import UserDefinedObjectVariable
class SuperVariable(VariableTracker):
    """Models ``super(typevar, objvar)`` (or 1-arg ``super(typevar)``) while tracing.

    ``typevar`` is the class argument of ``super`` and ``objvar`` the bound
    object (``None`` for the 1-argument form, which ``const_getattr`` rejects).
    When ``specialized`` is true, attributes are read directly from
    ``typevar`` instead of going through ``super(...)``.
    """

    def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
        super().__init__(**kwargs)
        self.typevar = typevar
        self.objvar = objvar
        self.specialized = specialized  # directly get attr from self.typevar if true

    def reconstruct(self, codegen):
        # Re-emit ``super(typevar)`` or ``super(typevar, objvar)`` as bytecode.
        codegen(variables.BuiltinVariable(super))
        codegen(self.typevar)
        if self.objvar is not None:
            codegen(self.objvar)
            return create_call_function(2, True)
        else:
            return create_call_function(1, True)

    def const_getattr(self, tx, name):
        """Resolve ``super(...).name`` to a concrete Python value.

        ``tx`` is accepted for interface compatibility; it is not used here.
        """
        assert self.objvar, "1-arg super not implemented"
        if self.specialized:
            return getattr(self.typevar.as_python_constant(), name)
        search_type = self.typevar.as_python_constant()

        # We default to the python type of the object. However, if this is
        # a `type` or subclass of `type`, then the original object represents
        # the user defined type.
        type_to_use = self.objvar.python_type()
        if issubclass(type_to_use, type):
            type_to_use = self.objvar.value

        # TODO(jansel): there is a small chance this could trigger user code, prevent that
        return getattr(super(search_type, type_to_use), name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace ``super(...).name(*args, **kwargs)``.

        The target is resolved with ``const_getattr`` and dispatched on its kind:

        - ``object.__init__``: returns a LambdaVariable wrapping ``identity``.
        - ``torch.nn.Module.__init__`` with no arguments on a newly tracked
          object: recorded as the ``__call_nn_module_init`` side effect.
        - plain functions: traced with ``objvar`` prepended to ``args``.
        - bound methods: traced with ``objvar`` as the bound instance.
        - ``OrderedDict.__getitem__`` with one constant key on a sourced
          object: read through an ``ODictGetItemSource``.

        Anything else goes through ``unimplemented``.
        """
        options = VariableTracker.propagate(
            self, args, kwargs.values(), self.objvar, self.typevar
        )
        # const_getattr takes the instruction translator as its ``tx`` argument.
        inner_fn = self.const_getattr(tx, name)
        source = None if self.source is None else AttrSource(self.source, name)
        if inner_fn is object.__init__:
            return LambdaVariable(identity, **options)
        elif inner_fn is torch.nn.Module.__init__:
            objvar = self.objvar
            from ..side_effects import AttributeMutationNew

            if (
                isinstance(objvar, variables.UserDefinedObjectVariable)
                and isinstance(objvar.mutable_local, AttributeMutationNew)
                and not (args or kwargs)
            ):
                tx.output.guards.update(options.get("guards", set()))
                tx.output.side_effects.store_attr(
                    objvar, "__call_nn_module_init", variables.ConstantVariable(True)
                )
                return variables.ConstantVariable(None)
            else:
                unimplemented("super() nn.Module.__init__")
        elif isinstance(inner_fn, types.FunctionType):
            return variables.UserFunctionVariable(
                inner_fn, source=source, **options
            ).call_function(tx, [self.objvar] + args, kwargs)
        elif isinstance(inner_fn, types.MethodType):
            return variables.UserMethodVariable(
                inner_fn.__func__, self.objvar, source=source, **options
            ).call_function(tx, args, kwargs)
        elif (
            inner_fn is collections.OrderedDict.__getitem__
            and isinstance(self.objvar, variables.UserDefinedObjectVariable)
            and self.objvar.source
            and len(args) == 1
            and len(kwargs) == 0
            and args[0].is_python_constant()
        ):
            from .builder import VariableBuilder

            key = args[0].as_python_constant()
            return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
                collections.OrderedDict.__getitem__(self.objvar.value, key)
            )
        else:
            unimplemented(f"non-function or method super: {inner_fn}")
class UnknownVariable(VariableTracker):
    """
    A tracked value of unknown kind: it could be anything!
    """
class DelayGraphBreakVariable(UnknownVariable):
    """
    Dummy stack entry whose only purpose is to postpone the graph break until
    the CALL_FUNCTION instruction.
    """
class ComptimeVariable(VariableTracker):
    """
    Special form that runs arbitrary user code at Dynamo compile time
    (through ``comptime``) instead of tracing it.
    """

    def reconstruct(self, codegen):
        # comptime has no runtime counterpart, so it can never be rebuilt.
        raise NotImplementedError("comptime is special form")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from ..comptime import comptime

        # Attribute access backs convenience accessors such as comptime.print_graph.
        from .functions import UserFunctionVariable

        target = getattr(comptime, name)
        return UserFunctionVariable(target, source=AttrSource(self.source, name))

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Run the single function argument right now with a ComptimeContext."""
        from ..comptime import ComptimeContext

        # TODO: support an expression form as well
        assert not kwargs
        assert len(args) == 1
        user_fn = args[0]

        if isinstance(user_fn, UserFunctionVariable):
            user_fn.get_function()(ComptimeContext(tx))
        elif isinstance(user_fn, NestedUserFunctionVariable):
            # Rebuild a real function object from the nested function's parts;
            # the closure has to be bound by hand, so none is allowed.
            fn_code = user_fn.get_code()
            assert not user_fn.closure, (
                "comptime function must not have free variables, "
                f"but these variables were free: {fn_code.co_freevars}"
            )
            fn_globals = user_fn.f_globals
            fn_name = user_fn.fn_name.as_python_constant()
            fn_defaults = tuple(user_fn.defaults.items) if user_fn.defaults else None
            # Free variables could be promoted to ComptimeVar automatically
            # (tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)), but
            # that is confusing when the runtime value of a free variable IS
            # available, so an empty closure is used instead.
            rebuilt = types.FunctionType(
                fn_code, fn_globals, fn_name, fn_defaults, tuple()
            )
            rebuilt(ComptimeContext(tx))
        else:
            raise RuntimeError(f"unsupported argument to comptime: {type(user_fn)}")

        return variables.ConstantVariable(None)
class ClosureVariable(UnknownVariable):
    """Refers to a closure variable by ``name``; reconstructs as a closure load."""

    def __init__(self, name, **kwargs):
        super().__init__(**kwargs)
        self.name = name

    def reconstruct(self, codegen):
        load_instr = codegen.create_load_closure(self.name)
        return [load_instr]
class NewCellVariable(VariableTracker):
    """Tracker for a newly created cell; holds no state beyond the base options."""

    def __init__(self, **kwargs):
        # Only keyword options are accepted; they all go to VariableTracker.
        super().__init__(**kwargs)
class NewGlobalVariable(VariableTracker):
    """Tracker for a new global; holds no state beyond the base options."""

    def __init__(self, **kwargs):
        # Only keyword options are accepted; they all go to VariableTracker.
        super().__init__(**kwargs)
class InspectSignatureVariable(VariableTracker):
    """Stands in for the result of ``inspect.signature(...)``."""

    @staticmethod
    def create(callable, **kwargs):
        # Only the plain one-argument form is modelled; keyword arguments
        # are rejected through unimplemented().
        if kwargs:
            unimplemented(f"inspect.signature with {kwargs}")
        signature_var = InspectSignatureVariable(callable)
        return signature_var

    def __init__(self, inspected, **kwargs):
        super().__init__(**kwargs)
        # The tracked callable whose signature is being inspected.
        self.inspected = inspected
def produce_trampoline_autograd_fwd(fn_cls):
    """Return a plain function that forwards each call to ``fn_cls.forward``.

    ``forward`` is looked up on ``fn_cls`` at call time, not when the
    trampoline is built. The returned function carries an ``_origin``
    attribute pointing back at this factory.
    """

    def trampoline_autograd_fwd(*call_args, **call_kwargs):
        return fn_cls.forward(*call_args, **call_kwargs)

    setattr(trampoline_autograd_fwd, "_origin", produce_trampoline_autograd_fwd)
    return trampoline_autograd_fwd
def produce_trampoline_autograd_bwd(fn_cls):
    """Return a plain function that forwards each call to ``fn_cls.backward``.

    ``backward`` is looked up on ``fn_cls`` at call time, not when the
    trampoline is built. The returned function carries an ``_origin``
    attribute pointing back at this factory.
    """

    def trampoline_autograd_bwd(*call_args, **call_kwargs):
        return fn_cls.backward(*call_args, **call_kwargs)

    setattr(trampoline_autograd_bwd, "_origin", produce_trampoline_autograd_bwd)
    return trampoline_autograd_bwd
def produce_trampoline_autograd_apply(fn_cls):
    """Return a plain function that forwards each call to ``fn_cls.apply``.

    ``apply`` is looked up on ``fn_cls`` at call time, not when the
    trampoline is built. The returned function carries an ``_origin``
    attribute pointing back at this factory.
    """

    def trampoline_autograd_apply(*call_args, **call_kwargs):
        return fn_cls.apply(*call_args, **call_kwargs)

    setattr(trampoline_autograd_apply, "_origin", produce_trampoline_autograd_apply)
    return trampoline_autograd_apply
class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass"""

    def __init__(self, fn_cls, **kwargs):
        super().__init__(**kwargs)
        # The torch.autograd.Function subclass being modelled.
        self.fn_cls = fn_cls

    def call_apply(self, tx, args, kwargs):
        """Trace ``fn_cls.apply(*args, **kwargs)``.

        A fresh context variable is created and prepended to ``args`` so that
        forward/backward receive it first. The speculative path runs only when
        all three conditions hold:

        - some input may need grad;
        - grad mode is enabled;
        - ``torch._dynamo.config.capture_autograd_function`` is set.

        On that path, forward and backward are speculatively traced for
        soundness, and ``apply`` is then recorded through a
        TorchHigherOrderOperatorVariable (see the NOTE below). Otherwise
        ``fn_cls.forward`` is traced directly as a user function or method.
        Unsupported configurations go through ``unimplemented``.
        """
        requires_grad = False

        def visit(node):
            # Flag the call if any (possibly nested) input could need grad:
            # a tensor whose requires_grad is True or unknown (not False), or
            # an nn.Module in training mode.
            nonlocal requires_grad
            if isinstance(node, variables.TensorVariable):
                if node.requires_grad is not False:
                    requires_grad = True
            if isinstance(node, variables.NNModuleVariable):
                if node.is_training(tx):
                    requires_grad = True
            return node

        VariableTracker.apply(visit, (args, kwargs))

        # forward/backward take the ctx as their first positional argument.
        ctx = AutogradFunctionContextVariable.create(tx)
        args = [ctx, *args]

        if (
            requires_grad
            and torch.is_grad_enabled()
            and torch._dynamo.config.capture_autograd_function
        ):
            # Note - this is the same check used in autograd/function.py, except inverted.
            # If we want to support functorch transforms here, we will need to enable this.
            if (
                self.fn_cls.setup_context
                != torch.autograd.function._SingleLevelFunction.setup_context
            ):
                unimplemented(
                    "NYI - autograd.Function with custom setup_context method"
                )

            # Custom vjp/jvp overrides are not supported yet.
            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented("NYI - User defind vjp")

            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented("NYI - User defind jvp")

            from .torch import (
                safe_or_raise_always_restore,
                TorchHigherOrderOperatorVariable,
            )

            trampoline_autograd_apply = produce_trampoline_autograd_apply(self.fn_cls)
            trampoline_autograd_fwd = produce_trampoline_autograd_fwd(self.fn_cls)
            trampoline_autograd_bwd = produce_trampoline_autograd_bwd(self.fn_cls)

            # NOTE [On Tracing autograd.Function w/ grad]
            # The complex system described here revolves around the soundness evaluation of an autograd.Function in
            # PyTorch. The system follows a well-defined strategy for tracing, which involves three key steps: tracing
            # forward, tracing backward, and if both are sound the potential recording of an "apply" operation into the
            # graph. We trace forward, and evaluate soundness. Soundness, in this context, refers to the absence of side
            # effects, the avoidance of lifting new arguments into the trace, the production of a single tensor output,
            # and a limited input scope confined to contexts, tensors, and constants. If the forward trace is sound,
            # we install any guards accumulated from tracing. If not, we graph break. We trace backward, and evaluate
            # for soundness, same as forward, except with more strictness. We enable a strict mode on the tx, and
            # reject certain ops when running under this strict mode. If the backward trace is sound, we discard the
            # trace by restoring. Otherwise, we raise.

            # if both the forward and backward traces are sound, we write the autograd function's apply into the graph.

            # For tracing forward and backward, we use UserFunctionVariable. Although it does not directly contribute
            # to soundness evaluation, it plus a GlobalSource makes sure we can produce valid guards,
            # and that we can inline properly here. Inlining is required in order to be able to ensure that the
            # soundness evaluation works as described above.
            graph_checkpoint, checkpoint = tx.output.graph, tx.copy_graphstate()

            # Sources for forward/backward are built from the class's module and name.
            module_source = AttrSource(
                tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
            )
            higher_order_autograd_fn = TorchHigherOrderOperatorVariable(
                trampoline_autograd_fwd, source=AttrSource(module_source, "forward")
            )
            speculated_fwd_result = higher_order_autograd_fn.call_function(
                tx, args, kwargs
            )

            # backward is traced with (ctx, forward-output) and always restored
            # afterwards (per the NOTE above: discarded if sound, raise otherwise).
            bwd_args = [ctx, speculated_fwd_result]
            safe_or_raise_always_restore(
                tx,
                graph_checkpoint,
                checkpoint,
                TorchHigherOrderOperatorVariable(
                    trampoline_autograd_bwd,
                    source=AttrSource(module_source, "backward"),
                ),
                bwd_args,
            )
            # If fwd and backward are sound, we want apply in the graph.
            # And we don't want backwards for the obvious reasons.
            # apply() is called without the ctx that was prepended above.
            args = args[1:]
            return TorchHigherOrderOperatorVariable(
                trampoline_autograd_apply
            ).call_function(tx, args, kwargs)

        # No-grad path: trace forward(ctx, *args) directly.
        options = VariableTracker.propagate(self, args, kwargs.values())
        # NOTE(review): self.source presumably refers to the Function subclass
        # itself (call_function reuses it unchanged), in which case ``__class__``
        # would resolve to its metaclass -- confirm this source path is intended.
        options["source"] = AttrSource(AttrSource(self.source, "__class__"), "forward")
        fn = self.fn_cls.forward
        if isinstance(fn, types.FunctionType):
            return variables.UserFunctionVariable(fn, **options).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            return variables.UserMethodVariable(
                fn.__func__, variables.UserDefinedClassVariable(self.fn_cls), **options
            ).call_function(tx, args, kwargs)
        else:
            unimplemented(
                f"non-function or method in subclass of torch.autograd.Function: {fn}"
            )

    def call_function(self, tx, args, kwargs):
        # Calling the Function class yields another AutogradFunctionVariable
        # for the same class and source; the call arguments only contribute options.
        options = VariableTracker.propagate(self, args, kwargs.values())
        return AutogradFunctionVariable(self.fn_cls, source=self.source, **options)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ):
        """Inline ``fn_cls.forward`` or ``fn_cls.backward``; other names are unsupported.

        ``backward`` is inlined under ``tx.strict_translation_mode()``.
        """
        if name not in ["backward", "forward"]:
            unimplemented(f"Unsupported method: {name}")

        if name == "backward":
            with tx.strict_translation_mode():
                return tx.inline_call(
                    tx, UserFunctionVariable(self.fn_cls.backward), args, kwargs
                )
        return tx.inline_call(
            tx, UserFunctionVariable(self.fn_cls.forward), args, kwargs
        )
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py
    """

    def __init__(self, value, value_type=None, inference=False, **kwargs):
        super().__init__(value=value, value_type=value_type, **kwargs)
        # When True, save_for_backward is not reported to side_effects.
        self.inference = inference

    @staticmethod
    def create(tx):
        """Create a new tracked FunctionCtx, with a matching
        ``FunctionCtx()`` call_function node in the output graph."""
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(AutogradFunctionContextVariable, inference=True),
            {},
        )
        proxy = tx.output.create_proxy(
            "call_function", torch.autograd.function.FunctionCtx, tuple(), {}
        )
        proxy.node.meta["example_value"] = out.value
        out.proxy = proxy
        return out

    def as_proxy(self):
        # Only set for contexts built through create().
        return self.proxy

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Support ``ctx.save_for_backward(*args)``; any other method is unsupported.

        Saved arguments are remembered in ``_saved_tensors``, and every argument
        backed by an FX proxy is tagged with ``saved_tensor_marked``.
        """
        if name != "save_for_backward":
            unimplemented(f"autograd.Function context method: {name}")

        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        options = VariableTracker.propagate(self, args, kwargs.values())
        if not hasattr(self, "_saved_tensors"):
            self._saved_tensors = []
        for arg in args:
            # as_proxy can return constant values or other non proxy values.
            # Call it once: for some variables every call builds a new proxy.
            arg_proxy = arg.as_proxy()
            if isinstance(arg_proxy, torch.fx.Proxy):
                arg_proxy.node.meta["saved_tensor_marked"] = True
            self._saved_tensors.append(arg)
        return variables.ConstantVariable(None, **options)

    def var_getattr(self, tx, name):
        if name == "save_for_backward":
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            ).add_options(self)
        if name == "saved_tensors":
            # If save_for_backward was never traced, _saved_tensors does not
            # exist yet; report an empty tuple instead of raising AttributeError.
            return variables.TupleVariable(list(getattr(self, "_saved_tensors", ())))
        return super().var_getattr(tx, name)
class LambdaVariable(VariableTracker):
    """Wraps a Python callable that is invoked at trace time on VariableTrackers."""

    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        produced = self.fn(*args, **kwargs)
        # Merge this variable's options into whatever the callable produced.
        return produced.add_options(self)
class GetAttrVariable(VariableTracker):
    """Represents the attribute access ``obj.name`` on a tracked value ``obj``."""

    def __init__(self, obj, name, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(obj, VariableTracker)
        assert isinstance(name, str)
        self.obj = obj
        self.name = name

    def __str__(self):
        return f"{self.__class__.__name__}({self.obj}, {self.name})"

    @staticmethod
    def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
        # Attribute access on an FX proxy; returns whatever Proxy's getattr yields.
        return getattr(base_proxy, attr)

    def as_proxy(self):
        return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)

    def const_getattr(self, tx, name):
        """Statically resolve ``obj.<self.name>.<name>``.

        Only supported when ``obj`` is an NNModuleVariable and each attribute
        is present in the ``__dict__`` of the object it is read from;
        otherwise raises NotImplementedError.
        """
        if not isinstance(self.obj, variables.NNModuleVariable):
            raise NotImplementedError()
        step1 = tx.output.get_submodule(self.obj.module_key)
        if self.name not in step1.__dict__:
            raise NotImplementedError()
        # getattr_static does not run descriptors or __getattr__.
        step2 = inspect.getattr_static(step1, self.name)
        if name not in step2.__dict__:
            raise NotImplementedError()
        return inspect.getattr_static(step2, name)

    def reconstruct(self, codegen):
        # Load obj, then load the attribute from it.
        codegen(self.obj)
        return codegen.create_load_attrs(self.name)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Trace a call of ``obj.name(*args, **kwargs)``.

        Special cases, in order:

        - ``super().__torch_function__`` resolving to ``torch.Tensor``'s version:
          the underlying torch function / method call is recorded directly
          in the graph.
        - ``AutogradFunctionVariable.apply``: delegated to ``call_apply``.
        - ``SomeClass.method(instance, ...)`` on a user-defined class: routed
          through a specialized SuperVariable.

        Otherwise the call is forwarded to ``obj.call_method``.
        """
        from .builder import wrap_fx_proxy

        # This variable is True when it corresponds to user code such as
        #
        #   super().__torch_function__(...)
        #
        # and the super().__torch_function__ attribute resolves
        # to torch.Tensor.__torch_function__.
        # NOTE(review): assumes objvar wraps a class (its .value has __mro__).
        is_original_tensor_torch_function = (
            self.name == "__torch_function__"
            and isinstance(self.obj, SuperVariable)
            # for now, only support one level of inheritance
            and len(self.obj.objvar.value.__mro__) > 1
            and self.obj.objvar.value.__mro__[1] == torch.Tensor
        )
        if is_original_tensor_torch_function:
            # Instead of tracing inside torch.Tensor.__torch_function__,
            # record the `call_function` or `call_method` call into the graph.
            from . import TorchVariable

            # Assumes args follow __torch_function__(func, types, args, kwargs):
            # args[0] is func, args[2] the positional args, args[3] the kwargs.
            original_torch_or_getattr_variable = args[0]
            new_args = args[2].items
            new_kwargs = args[3].items
            options = VariableTracker.propagate(self, new_args, new_kwargs.values())

            # Disable __torch_function__ here to prevent the clone of the
            # example tensor from going into the override.
            with torch._C.DisableTorchFunctionSubclass():
                if isinstance(args[0], TorchVariable):
                    # func is a torch function: emit call_function.
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_function",
                            original_torch_or_getattr_variable.value,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                elif isinstance(args[0], GetAttrVariable):
                    # func is a method reference: emit call_method by name.
                    return wrap_fx_proxy(
                        tx=tx,
                        proxy=tx.output.create_proxy(
                            "call_method",
                            original_torch_or_getattr_variable.name,
                            *proxy_args_kwargs(new_args, new_kwargs),
                        ),
                        **options,
                    )
                else:
                    unimplemented(
                        f"GetAttrVariable.call_function original __torch_function__ {args}"
                    )

        if isinstance(self.obj, AutogradFunctionVariable) and self.name == "apply":
            return self.obj.call_apply(tx, args, kwargs).add_options(self)
        # calling parent class's non classmethod from child class
        # https://github.com/pytorch/pytorch/issues/90558
        elif (
            isinstance(self.obj, variables.UserDefinedClassVariable)
            and len(args) > 0
            and issubclass(args[0].python_type(), self.obj.value)
        ):
            # args[0] plays the role of ``self``; look the attribute up directly on the class.
            return SuperVariable(self.obj, args[0], True).call_method(
                tx, self.name, args[1:], kwargs
            )
        return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Fold ``len(<signature>.parameters)`` to a constant; defer everything else."""
        if (
            name == "__len__"
            and isinstance(self.obj, InspectSignatureVariable)
            and self.name == "parameters"
        ):
            return variables.ConstantVariable(
                self.obj.inspected.num_parameters(),
                **VariableTracker.propagate(self, self.obj, self.obj.inspected),
            )
        return super().call_method(tx, name, args, kwargs)
class PythonModuleVariable(VariableTracker):
    """Tracks a Python module object held in ``value``."""

    def __init__(self, value: types.ModuleType, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        # The wrapped value is always reported as a module.
        module_type = types.ModuleType
        return module_type
class SkipFilesVariable(VariableTracker):
    """Wraps a callable that lives in Dynamo's skip files.

    ``call_function`` handles a few allow-listed callables specially:

    - ``collections.OrderedDict`` construction;
    - constant-folding of ``collections.namedtuple``;
    - ``itertools.product``;
    - ``functools.wraps``.

    Everything else is reported through ``unimplemented``.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        # Functions whose all-constant calls are evaluated eagerly, mapped to
        # the VariableTracker class used to wrap the result.
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builtin import BuiltinVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Allowlist a few popular classes(e.g, collections.OrderedDict) calls in skip files.
        # call_dict_helper only receives the positional argument, so keyword
        # forms such as OrderedDict(a=1) must not take this path (they would
        # silently lose their entries).
        elif (
            self.value is collections.OrderedDict
            and not kwargs
            and (
                len(args) == 0
                or (
                    len(args) == 1
                    and BuiltinVariable.is_supported_call_dict_arg(tx, args[0])
                )
            )
        ):
            return BuiltinVariable.call_dict_helper(
                tx,
                collections.OrderedDict,
                None if len(args) == 0 else args[0],
                **options,
            )
        # Fold through the functions(e.g, collections.namedtuple)
        # that inputs & outputs are all python constants
        elif self.value in self.fold_through_function_to_wrapper() and check_constant_args(
            args, kwargs
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal(), **options
            )
        elif (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            # Expand the cartesian product eagerly into a list iterator of tuples.
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item), **options)
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(
                items, mutable_local=MutableLocal(), **options
            )
        elif (
            self.value is functools.wraps
            and not kwargs
            and len(args) == 1
            and args[0].source
        ):

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    return fn.clone(wraps_source=args[0].source)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(wraps, **options)
        else:
            # Not every callable defines __name__/__qualname__ (e.g. partial
            # objects); fall back to repr so the error report itself cannot fail.
            try:
                path = inspect.getfile(self.value)
            except TypeError:
                path = f"Builtin {getattr(self.value, '__name__', repr(self.value))}"
            qualname = getattr(self.value, "__qualname__", repr(self.value))
            unimplemented(f"call_function {qualname} in skip_files {path}")
class TypingVariable(VariableTracker):
    """Wraps a ``typing`` object; supports only subscription with a constant."""

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        is_subscript = name == "__getitem__" and len(args) == 1
        if is_subscript:
            # e.g. List[int]: evaluate the subscription eagerly.
            key = args[0].as_python_constant()
            subscripted = self.value[key]
            return variables.ConstantVariable(
                subscripted,
                **VariableTracker.propagate(self, args),
            )
        unimplemented("typing")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value
class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*` for better error messages.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Record a call to the torch_np function with the same name as ``value``.

        Requires ``config.numpy_ndarray_as_tensor`` and numpy/torch interop
        support; the result is wrapped as a NumpyNdarrayVariable.
        """
        if not config.numpy_ndarray_as_tensor or not HAS_NUMPY_TORCH_INTEROP:
            unimplemented(f"numpy.{self.value}()")

        import torch_np

        from .builder import wrap_fx_proxy_cls
        from .tensor import NumpyNdarrayVariable

        options = VariableTracker.propagate([[self]], [args], [list(kwargs.values())])

        # lookup method name in torch_np
        np_name = self.value.__name__
        if not hasattr(torch_np, np_name):
            unimplemented(f"Can't find numpy function {self.value} in torch_np")

        torch_np_fn = getattr(torch_np, np_name)
        graph_proxy = tx.output.create_proxy(
            "call_function",
            torch_np_fn,
            *proxy_args_kwargs(args, kwargs),
        )
        return wrap_fx_proxy_cls(
            target_cls=NumpyNdarrayVariable,
            tx=tx,
            proxy=graph_proxy,
            example_value=None,
            **options,
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # No method calls on numpy objects are supported.
        unimplemented("numpy")

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value
# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
class NullVariable(VariableTracker):
    """Placeholder for a NULL entry on the value stack."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "NullVariable"

    def reconstruct(self, codegen):
        # Rebuilding a NULL relies on PUSH_NULL, which this code only emits on 3.11+.
        if sys.version_info >= (3, 11):
            return [create_instruction("PUSH_NULL")]
        unimplemented("cannot reconstruct NullVariable in < Python 3.11")
class DeletedVariable(VariableTracker):
    """Marker standing for a deleted value; used to implement delattr()."""