Revert "Annotate all InstructionTranslator (#131509)"
This reverts commit eafbd20f23746aa6b9090d989a4ccb059f45297e.
Reverted https://github.com/pytorch/pytorch/pull/131509 on behalf of https://github.com/clee2000 due to sorry need to revert this to revert something else, I think you only need to rebase and remerge ([comment](https://github.com/pytorch/pytorch/pull/131509#issuecomment-2249000843))
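
For context, the reverted PR had annotated the `tx` parameter of Dynamo variable-tracker methods as `"InstructionTranslator"`; this revert restores the untyped signatures, reinstates the `# type: ignore[arg-type]` comments at the call sites in symbolic_convert.py, and drops the now-unused `TYPE_CHECKING` imports. A minimal illustrative sketch of the signature shape before and after the revert (not itself part of the diff below):

    # Added by #131509 (now being reverted):
    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        ...

    # Restored by this revert (tx left untyped):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        ...
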
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 2c2e824..83c6400 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -244,7 +244,7 @@
fn_var = BuiltinVariable(fn)
@functools.wraps(fn)
- def impl(self: "InstructionTranslator", inst: Instruction):
+ def impl(self: "InstructionTranslatorBase", inst: Instruction):
self.push(fn_var.call_function(self, self.popn(nargs), {}))
return impl
@@ -441,18 +441,18 @@
self.jump(inst)
elif isinstance(value, UserDefinedObjectVariable):
try:
- x = value.var_getattr(self, "__bool__") # type: ignore[arg-type]
+ x = value.var_getattr(self, "__bool__")
except exc.ObservedException:
# if __bool__ is missing, trying __len__ to infer a truth value.
- x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
+ x = value.var_getattr(self, "__len__")
else:
if isinstance(x, GetAttrVariable):
# if __bool__ is missing, trying __len__ to infer a truth value.
- x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
+ x = value.var_getattr(self, "__len__")
# __bool__ or __len__ is function
if isinstance(x, UserMethodVariable):
- result = x.call_function(self, [], {}) # type: ignore[arg-type]
+ result = x.call_function(self, [], {})
if isinstance(result, ConstantVariable) and isinstance(
result.value, (bool, int)
):
@@ -771,7 +771,7 @@
inner_fn = fn.fn
if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
- self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
+ self.push(fn.call_function(self, args, kwargs))
def inline_user_function_return(self, fn, args, kwargs):
"""
@@ -1310,7 +1310,7 @@
if isinstance(val, variables.BuiltinVariable):
# Create the instance of the exception type
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
- val = val.call_function(self, [], {}) # type: ignore[arg-type]
+ val = val.call_function(self, [], {})
# Save the exception in a global data structure
self.exn_vt_stack.append(val)
@@ -1629,7 +1629,7 @@
def _load_attr(self, inst):
obj = self.pop()
result = BuiltinVariable(getattr).call_function(
- self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type]
+ self, [obj, ConstantVariable.create(inst.argval)], {}
)
self.push(result)
@@ -1655,7 +1655,7 @@
try:
BuiltinVariable(setattr).call_function(
- self, [obj, ConstantVariable.create(inst.argval), val], {} # type: ignore[arg-type]
+ self, [obj, ConstantVariable.create(inst.argval), val], {}
)
return
except Unsupported as e:
@@ -1681,7 +1681,7 @@
def DELETE_ATTR(self, inst):
obj = self.pop()
BuiltinVariable(delattr).call_function(
- self, [obj, ConstantVariable.create(inst.argval)], {} # type: ignore[arg-type]
+ self, [obj, ConstantVariable.create(inst.argval)], {}
)
def create_call_resume_at(self, offset):
@@ -1748,7 +1748,7 @@
def BUILD_MAP_UNPACK(self, inst):
items = self.popn(inst.argval)
# ensure everything is a dict
- items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type]
+ items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items]
result = {}
for x in items:
assert isinstance(x, ConstDictVariable)
@@ -1853,7 +1853,7 @@
def UNPACK_SEQUENCE(self, inst):
seq = self.pop()
if isinstance(seq, TensorVariable):
- val = seq.unpack_var_sequence(self, idxes=range(inst.argval)) # type: ignore[arg-type]
+ val = seq.unpack_var_sequence(self, idxes=range(inst.argval))
elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
# x, y = a.shape
proxy = getattr(seq.obj.as_proxy(), seq.name)
@@ -1940,11 +1940,11 @@
if isinstance(value, SymNodeVariable):
value = ConstantVariable.create(str(value.sym_num))
if (flags & 0x03) == 0x01:
- value = BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type]
+ value = BuiltinVariable(str).call_function(self, [value], {})
elif (flags & 0x03) == 0x02:
- value = BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type]
+ value = BuiltinVariable(repr).call_function(self, [value], {})
elif (flags & 0x03) == 0x03:
- value = BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type]
+ value = BuiltinVariable(ascii).call_function(self, [value], {})
fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")
@@ -2000,7 +2000,7 @@
obj.call_method(self, "extend", [v], {})
def LIST_TO_TUPLE(self, inst):
- self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
+ self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))
def DICT_MERGE(self, inst):
v = self.pop()
@@ -3184,7 +3184,7 @@
tos = self.stack[-1]
if not isinstance(tos, ListIteratorVariable):
self.pop()
- res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type]
+ res = BuiltinVariable(iter).call_function(self, [tos], {})
self.push(res)
def YIELD_FROM(self, inst):
diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py
index f8ea7e2..472b98c 100644
--- a/torch/_dynamo/variables/base.py
+++ b/torch/_dynamo/variables/base.py
@@ -2,10 +2,7 @@
import collections
from enum import Enum
-from typing import Any, Callable, Dict, List, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
+from typing import Any, Callable, Dict, List
from .. import variables
from ..current_scope_id import current_scope_id
@@ -231,11 +228,11 @@
return self.source.make_guard(fn)
raise NotImplementedError
- def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
+ def const_getattr(self, tx, name: str) -> Any:
"""getattr(self, name) returning a python constant"""
raise NotImplementedError
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
"""getattr(self, name) returning a new variable"""
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
@@ -295,14 +292,11 @@
def inspect_parameter_names(self) -> List[str]:
unimplemented(f"inspect_parameter_names: {self}")
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
unimplemented(f"hasattr {self.__class__.__name__} {name}")
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
unimplemented(f"call_function {self} {args} {kwargs}")
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index a8001f5..da72a4c 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -10,13 +10,10 @@
import types
from collections import defaultdict, OrderedDict
from collections.abc import KeysView, MutableMapping
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import torch
from torch import sym_float, sym_int
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, polyfill, variables
@@ -93,7 +90,7 @@
def _polyfill_call_impl(name):
"""Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}"""
- def call_fn(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_fn(self, tx, *args, **kwargs):
return tx.inline_user_function_return(
variables.UserFunctionVariable(fn), args, kwargs
)
@@ -863,7 +860,7 @@
return builtin_dipatch
- def _handle_insert_op_in_graph(self, tx: "InstructionTranslator", args, kwargs):
+ def _handle_insert_op_in_graph(self, tx, args, kwargs):
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
if kwargs and not self.tensor_args(*args, *kwargs.values()):
@@ -959,10 +956,7 @@
call_function_handler_cache = {}
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if kwargs:
kwargs = {k: v.realize() for k, v in kwargs.items()}
@@ -1008,7 +1002,7 @@
return super().call_method(tx, name, args, kwargs)
- def _call_int_float(self, tx: "InstructionTranslator", arg):
+ def _call_int_float(self, tx, arg):
# Handle cases like int(torch.seed())
# Also handle sym_float to sym_int cases
if isinstance(arg, (SymNodeVariable, variables.TensorVariable)):
@@ -1032,12 +1026,12 @@
call_int = _call_int_float
call_float = _call_int_float
- def call_str(self, tx: "InstructionTranslator", arg):
+ def call_str(self, tx, arg):
# Handle `str` on a user defined function
if isinstance(arg, (variables.UserFunctionVariable)):
return variables.ConstantVariable.create(value=str(arg.fn))
- def _call_min_max(self, tx: "InstructionTranslator", *args):
+ def _call_min_max(self, tx, *args):
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
# expand iterable
items = args[0].unpack_var_sequence(tx)
@@ -1047,14 +1041,14 @@
elif len(args) > 2:
return self._call_min_max_seq(tx, args)
- def _call_min_max_seq(self, tx: "InstructionTranslator", items):
+ def _call_min_max_seq(self, tx, items):
assert len(items) > 0
if len(items) == 1:
return items[0]
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items)
- def _call_min_max_binary(self, tx: "InstructionTranslator", a, b):
+ def _call_min_max_binary(self, tx, a, b):
if self.tensor_args(a, b):
if not isinstance(a, variables.TensorVariable):
a, b = b, a
@@ -1145,21 +1139,21 @@
call_min = _call_min_max
call_max = _call_min_max
- def call_abs(self, tx: "InstructionTranslator", arg: "VariableTracker"):
+ def call_abs(self, tx, arg: "VariableTracker"):
# Call arg.__abs__()
abs_method = BuiltinVariable(getattr).call_function(
tx, [arg, ConstantVariable.create("__abs__")], {}
)
return abs_method.call_function(tx, [], {})
- def call_pos(self, tx: "InstructionTranslator", arg: "VariableTracker"):
+ def call_pos(self, tx, arg: "VariableTracker"):
# Call arg.__pos__()
pos_method = BuiltinVariable(getattr).call_function(
tx, [arg, ConstantVariable.create("__pos__")], {}
)
return pos_method.call_function(tx, [], {})
- def call_index(self, tx: "InstructionTranslator", arg: "VariableTracker"):
+ def call_index(self, tx, arg: "VariableTracker"):
if isinstance(arg, variables.TensorVariable):
unimplemented("unsupported index(tensor)")
@@ -1167,14 +1161,14 @@
constant_value = operator.index(arg)
return variables.ConstantVariable.create(constant_value)
- def call_round(self, tx: "InstructionTranslator", arg, *args, **kwargs):
+ def call_round(self, tx, arg, *args, **kwargs):
# Call arg.__round__()
round_method = BuiltinVariable(getattr).call_function(
tx, [arg, ConstantVariable.create("__round__")], {}
)
return round_method.call_function(tx, args, kwargs)
- def call_range(self, tx: "InstructionTranslator", *args):
+ def call_range(self, tx, *args):
if check_unspec_or_constant_args(args, {}):
return variables.RangeVariable(args)
elif self._dynamic_args(*args):
@@ -1190,10 +1184,10 @@
isinstance(x, SymNodeVariable) for x in kwargs.values()
)
- def call_slice(self, tx: "InstructionTranslator", *args):
+ def call_slice(self, tx, *args):
return variables.SliceVariable(args)
- def _dyn_proxy(self, tx: "InstructionTranslator", *args, **kwargs):
+ def _dyn_proxy(self, tx, *args, **kwargs):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
@@ -1203,9 +1197,7 @@
),
)
- def _call_iter_tuple_list(
- self, tx: "InstructionTranslator", obj=None, *args, **kwargs
- ):
+ def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs)
@@ -1241,7 +1233,7 @@
mutable_local=MutableLocal(),
)
- def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs):
+ def call_iter(self, tx, obj, *args, **kwargs):
# Handle the case where we are iterating over a tuple, list or iterator
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs)
@@ -1255,7 +1247,7 @@
call_tuple = _call_iter_tuple_list
call_list = _call_iter_tuple_list
- def call_callable(self, tx: "InstructionTranslator", arg):
+ def call_callable(self, tx, arg):
from .functions import BaseUserFunctionVariable
from .nn_module import NNModuleVariable
@@ -1289,7 +1281,7 @@
unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}")
- def call_dict(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_dict(self, tx, *args, **kwargs):
return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs)
@staticmethod
@@ -1368,7 +1360,7 @@
)
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}")
- def call_set(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_set(self, tx, *args, **kwargs):
# Can we merge this implementation and call_dict's one?
assert not kwargs
if not args:
@@ -1394,7 +1386,7 @@
else:
unimplemented(f"set(): {args} {kwargs}")
- def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_zip(self, tx, *args, **kwargs):
if kwargs:
assert len(kwargs) == 1 and "strict" in kwargs
if all(x.has_unpack_var_sequence(tx) for x in args):
@@ -1408,7 +1400,7 @@
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)]
return variables.TupleVariable(items)
- def call_enumerate(self, tx: "InstructionTranslator", *args):
+ def call_enumerate(self, tx, *args):
if len(args) == 1:
start = 0
else:
@@ -1424,13 +1416,13 @@
]
return variables.TupleVariable(items)
- def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_len(self, tx, *args, **kwargs):
return args[0].call_method(tx, "__len__", args[1:], kwargs)
- def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs):
+ def call_getitem(self, tx, *args, **kwargs):
return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
- def call_isinstance(self, tx: "InstructionTranslator", arg, isinstance_type):
+ def call_isinstance(self, tx, arg, isinstance_type):
try:
arg_type = arg.python_type()
except NotImplementedError:
@@ -1491,7 +1483,7 @@
val = arg_type is isinstance_type
return variables.ConstantVariable.create(val)
- def call_issubclass(self, tx: "InstructionTranslator", left_ty, right_ty):
+ def call_issubclass(self, tx, left_ty, right_ty):
"""Checks if first arg is subclass of right arg"""
try:
left_ty_py = left_ty.as_python_constant()
@@ -1503,10 +1495,10 @@
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
- def call_super(self, tx: "InstructionTranslator", a, b):
+ def call_super(self, tx, a, b):
return variables.SuperVariable(a, b)
- def call_next(self, tx: "InstructionTranslator", arg: VariableTracker):
+ def call_next(self, tx, arg: VariableTracker):
try:
return arg.next_variable(tx)
except Unsupported as ex:
@@ -1515,20 +1507,20 @@
return arg.items[0]
raise
- def call_hasattr(self, tx: "InstructionTranslator", obj, attr):
+ def call_hasattr(self, tx, obj, attr):
if attr.is_python_constant():
name = attr.as_python_constant()
if isinstance(obj, variables.BuiltinVariable):
return variables.ConstantVariable(hasattr(obj.fn, name))
return obj.call_hasattr(tx, name)
- def call_map(self, tx: "InstructionTranslator", fn, *seqs):
+ def call_map(self, tx, fn, *seqs):
if all(seq.has_unpack_var_sequence(tx) for seq in seqs):
unpacked = [seq.unpack_var_sequence(tx) for seq in seqs]
items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)]
return variables.TupleVariable(items)
- def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL):
+ def call_sum(self, tx, seq, start=_SENTINEL):
# Special case for sum on tuple of floats and ints
if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all(
isinstance(x, variables.ConstantVariable)
@@ -1559,9 +1551,7 @@
{},
)
- def call_reduce(
- self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL
- ):
+ def call_reduce(self, tx, function, iterable, initial=_SENTINEL):
if iterable.has_unpack_var_sequence(tx):
items = iterable.unpack_var_sequence(tx)
if initial is self._SENTINEL:
@@ -1573,11 +1563,7 @@
return value
def call_getattr(
- self,
- tx: "InstructionTranslator",
- obj: VariableTracker,
- name_var: VariableTracker,
- default=None,
+ self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
):
from .. import trace_rules
from . import (
@@ -1692,11 +1678,7 @@
return GetAttrVariable(obj, name, **options)
def call_setattr(
- self,
- tx: "InstructionTranslator",
- obj: VariableTracker,
- name_var: VariableTracker,
- val: VariableTracker,
+ self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
):
if isinstance(
obj,
@@ -1819,15 +1801,10 @@
)
return ConstantVariable(None)
- def call_delattr(
- self,
- tx: "InstructionTranslator",
- obj: VariableTracker,
- name_var: VariableTracker,
- ):
+ def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker):
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable())
- def call_type(self, tx: "InstructionTranslator", obj: VariableTracker):
+ def call_type(self, tx, obj: VariableTracker):
from .builder import SourcelessBuilder, VariableBuilder
try:
@@ -1844,12 +1821,12 @@
else:
return VariableBuilder(tx, TypeSource(obj.source))(py_type)
- def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker):
+ def call_reversed(self, tx, obj: VariableTracker):
if obj.has_unpack_var_sequence(tx):
items = list(reversed(obj.unpack_var_sequence(tx)))
return variables.TupleVariable(items)
- def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs):
+ def call_sorted(self, tx, obj: VariableTracker, **kwargs):
if (
obj.has_unpack_var_sequence(tx)
and not isinstance(obj, variables.TensorVariable)
@@ -1876,14 +1853,14 @@
)
return variables.ListVariable(items)
- def call_chain(self, tx: "InstructionTranslator", *args):
+ def call_chain(self, tx, *args):
if all(obj.has_unpack_var_sequence(tx) for obj in args):
items = []
for obj in args:
items.extend(obj.unpack_var_sequence(tx))
return variables.TupleVariable(items)
- def call_islice(self, tx: "InstructionTranslator", iterable, *args):
+ def call_islice(self, tx, iterable, *args):
if iterable.has_unpack_var_sequence(tx) and all(
x.is_python_constant() for x in args
):
@@ -1893,7 +1870,7 @@
return variables.TupleVariable(items)
# neg is a constant fold function, so we only get here if constant fold is not valid
- def call_neg(self, tx: "InstructionTranslator", a):
+ def call_neg(self, tx, a):
if isinstance(a, SymNodeVariable):
return SymNodeVariable.create(
tx,
@@ -1903,11 +1880,11 @@
# None no-ops this handler and lets the driving function proceed
return None
- def call_format(self, tx: "InstructionTranslator", _format_string, *args, **kwargs):
+ def call_format(self, tx, _format_string, *args, **kwargs):
format_string = _format_string.as_python_constant()
return variables.StringFormatVariable.create(format_string, args, kwargs)
- def call_id(self, tx: "InstructionTranslator", *args):
+ def call_id(self, tx, *args):
if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
nn_mod_variable = args[0]
mod = tx.output.get_submodule(nn_mod_variable.module_key)
@@ -1921,10 +1898,10 @@
else:
unimplemented(f"call_id with args {args}")
- def call_deepcopy(self, tx: "InstructionTranslator", x):
+ def call_deepcopy(self, tx, x):
unimplemented(f"copy.deepcopy {repr(x)}")
- def _comparison_with_tensor(self, tx: "InstructionTranslator", left, right):
+ def _comparison_with_tensor(self, tx, left, right):
from .builder import wrap_fx_proxy_cls
from .tensor import supported_tensor_comparison_op_values
@@ -1965,7 +1942,7 @@
proxy,
)
- def _comparison_with_symnode(self, tx: "InstructionTranslator", left, right):
+ def _comparison_with_symnode(self, tx, left, right):
from .tensor import supported_tensor_comparison_op_values
op = self.fn
@@ -1982,7 +1959,7 @@
sym_num=None,
)
- def call_and_(self, tx: "InstructionTranslator", a, b):
+ def call_and_(self, tx, a, b):
# Rely on constant_handler
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
return None
@@ -2000,7 +1977,7 @@
return SetVariable(list(a.set_items & b.set_items))
# None no-ops this handler and lets the driving function proceed
- def call_or_(self, tx: "InstructionTranslator", a, b):
+ def call_or_(self, tx, a, b):
# Rely on constant_handler
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable):
return None
@@ -2019,7 +1996,7 @@
# None no-ops this handler and lets the driving function proceed
return None
- def call_not_(self, tx: "InstructionTranslator", a):
+ def call_not_(self, tx, a):
if isinstance(a, SymNodeVariable):
return SymNodeVariable.create(
tx,
@@ -2037,9 +2014,7 @@
return None
- def call_contains(
- self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker
- ):
+ def call_contains(self, tx, a: VariableTracker, b: VariableTracker):
return a.call_method(tx, "__contains__", [b], {})
call_all = _polyfill_call_impl("all")
diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py
index 11832a4..9f95afb 100644
--- a/torch/_dynamo/variables/constant.py
+++ b/torch/_dynamo/variables/constant.py
@@ -1,14 +1,11 @@
# mypy: ignore-errors
import operator
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import torch
from torch._dynamo.source import GetItemSource
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder, install_guard
@@ -119,7 +116,7 @@
except TypeError as e:
raise NotImplementedError from e
- def const_getattr(self, tx: "InstructionTranslator", name):
+ def const_getattr(self, tx, name):
if isinstance(self.value, type):
raise UserError(
UserErrorType.ANTI_PATTERN,
@@ -192,7 +189,7 @@
unimplemented(f"const method call {typestr(self.value)}.{name}")
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
result = hasattr(self.value, name)
return variables.ConstantVariable.create(result)
@@ -222,7 +219,7 @@
def as_python_constant(self):
return self.value
- def const_getattr(self, tx: "InstructionTranslator", name):
+ def const_getattr(self, tx, name):
member = getattr(self.value, name)
if callable(member):
raise NotImplementedError
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index c6d828e..9a378cf 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -3,12 +3,9 @@
import inspect
import sys
import warnings
-from typing import Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Callable, Dict, List, Optional
import torch._C
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._guards import Guard
from .. import variables
@@ -72,7 +69,7 @@
self.set_cleanup_hook(tx)
return variables.ConstantVariable.create(None)
- def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None):
+ def set_cleanup_hook(self, tx, fn=None):
if fn is None:
def fn():
@@ -81,7 +78,7 @@
self.state.cleanup_fn = fn
tx.output.add_cleanup_hook(self.state.cleanup)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup_assert()
return variables.ConstantVariable.create(None)
@@ -105,10 +102,7 @@
raise NotImplementedError("fn_name called on base")
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
assert len(args) == 1
if isinstance(args[0], NestedUserFunctionVariable):
@@ -144,7 +138,7 @@
from_exc=e,
)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
source = None if self.source is None else AttrSource(self.source, "__exit__")
try:
x = variables.UserMethodVariable(
@@ -199,7 +193,7 @@
)
return variables.ConstantVariable.create(None)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function",
@@ -243,7 +237,7 @@
)
return variables.ConstantVariable.create(jvp_level)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._jvp_decrement_nesting, (), {}
@@ -278,7 +272,7 @@
)
return variables.ConstantVariable.create(None)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function",
@@ -316,7 +310,7 @@
)
return variables.ConstantVariable.create(self.new_level)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function",
@@ -358,7 +352,7 @@
)
return variables.ConstantVariable.create(grad_level)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._grad_decrement_nesting, (), {}
@@ -429,7 +423,7 @@
)
return variables.ConstantVariable.create(vmap_level)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._vmap_decrement_nesting, (), {}
@@ -463,20 +457,17 @@
self._call_func(tx, self.target_values)
return variables.ConstantVariable.create(None)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self._call_func(tx, self.initial_values)
return variables.ConstantVariable.create(None)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
):
self._call_func(tx, self.initial_values) # undo eager initialization
return super().call_function(tx, args, kwargs)
- def _call_func(self, tx: "InstructionTranslator", values):
+ def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
# Coalesce grad mode mutations
@@ -515,7 +506,7 @@
)
self.target_values = target_values
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup_assert()
tx.output.create_node(
"call_function",
@@ -569,7 +560,7 @@
def enter(self, tx):
return variables.ConstantVariable.create(None)
- def _call_func(self, tx: "InstructionTranslator", values):
+ def _call_func(self, tx, values):
assert len(values) == 1
tx.output.set_torch_function_state(values[0])
@@ -601,7 +592,7 @@
def enter(self, tx):
return variables.ConstantVariable.create(None)
- def _call_func(self, tx: "InstructionTranslator", values):
+ def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
tx.output.create_node(
@@ -640,7 +631,7 @@
def enter(self, tx):
return variables.ConstantVariable.create(None)
- def _call_func(self, tx: "InstructionTranslator", values):
+ def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
if value is not None:
@@ -707,7 +698,7 @@
)
self.target_values = target_values
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self.state.cleanup_assert()
tx.output.create_node(
"call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
@@ -740,7 +731,7 @@
def enter(self, tx):
return variables.ConstantVariable.create(None)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
return variables.ConstantVariable.create(None)
def module_name(self):
@@ -804,7 +795,7 @@
self.set_stream(self.target_values[0].value)
self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
tx.output.create_proxy(
"call_function",
self.set_stream,
@@ -837,7 +828,7 @@
def enter(self, tx):
pass
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
from ..tensor_version_op import _unsafe_set_version_counter
return variables.TorchInGraphFunctionVariable(
@@ -874,20 +865,17 @@
self._call_func(tx, self.target_values)
return variables.ConstantVariable.create(None)
- def exit(self, tx: "InstructionTranslator", *args):
+ def exit(self, tx, *args):
self._call_func(tx, self.initial_values)
return variables.ConstantVariable.create(None)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
):
self._call_func(tx, self.initial_values) # undo eager initialization
return super().call_function(tx, args, kwargs)
- def _call_func(self, tx: "InstructionTranslator", values):
+ def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
if self.param_group_var.value._training_state != value:
@@ -1033,10 +1021,7 @@
self.target = target
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
assert not kwargs
return self.ctx.exit(tx, *args)
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py
index b2a838d..e06dca8 100644
--- a/torch/_dynamo/variables/dicts.py
+++ b/torch/_dynamo/variables/dicts.py
@@ -5,10 +5,7 @@
import functools
import inspect
import sys
-from typing import Dict, List, Optional, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
+from typing import Dict, List, Optional
from torch._subclasses.fake_tensor import is_fake
@@ -588,9 +585,7 @@
return mod is not None and issubclass(cls, mod.BaseOutput)
-def _call_hasattr_customobj(
- self, tx: "InstructionTranslator", name: str
-) -> "VariableTracker":
+def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
"""Shared method between DataClassVariable and CustomizedDictVariable where items are attrs"""
if tx.output.side_effects.is_attribute_mutation(self):
try:
@@ -809,7 +804,7 @@
unimplemented(f"custom dict: call_method unimplemented name={name}")
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
name_vt = ConstantVariable.create(name)
if name_vt in self:
return self.call_method(tx, "__getitem__", [name_vt], {})
@@ -864,12 +859,12 @@
self.obj = obj
assert self.is_matching_cls(type(obj))
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
from . import ConstantVariable
return ConstantVariable.create(getattr(self.obj, name))
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
return variables.ConstantVariable.create(hasattr(self.obj, name))
@@ -894,11 +889,7 @@
)
def call_method(
- self,
- tx: "InstructionTranslator",
- name,
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
):
if name == "__getitem__":
return self.call_getitem(tx, *args, **kwargs)
@@ -908,7 +899,7 @@
return self.call_contains(tx, *args, **kwargs)
unimplemented(f"sys.modules.{name}(*{args}, **{kwargs})")
- def _contains_helper(self, tx: "InstructionTranslator", key: VariableTracker):
+ def _contains_helper(self, tx, key: VariableTracker):
k = key.as_python_constant()
has_key = k in sys.modules
install_guard(
@@ -918,15 +909,12 @@
)
return k, has_key
- def call_contains(self, tx: "InstructionTranslator", key: VariableTracker):
+ def call_contains(self, tx, key: VariableTracker):
k, has_key = self._contains_helper(tx, key)
return ConstantVariable.create(value=has_key)
def call_get(
- self,
- tx: "InstructionTranslator",
- key: VariableTracker,
- default: Optional[VariableTracker] = None,
+ self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
):
from .builder import VariableBuilder
@@ -943,7 +931,7 @@
return ConstantVariable.create(value=None)
- def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker):
+ def call_getitem(self, tx, key: VariableTracker):
from .builder import VariableBuilder
k, has_key = self._contains_helper(tx, key)
diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py
index acd92c7..6816ea9 100644
--- a/torch/_dynamo/variables/distributed.py
+++ b/torch/_dynamo/variables/distributed.py
@@ -1,12 +1,9 @@
# mypy: ignore-errors
import functools
import inspect
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import torch
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from ...fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
@@ -91,7 +88,7 @@
return type(value) is _WorldMeta
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def var_getattr(self, tx, name: str) -> VariableTracker:
if name == "WORLD":
source = AttrSource(base=self.source, member="WORLD")
install_guard(source.make_guard(GuardBuilder.ID_MATCH))
@@ -114,10 +111,7 @@
return self.value
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if (
inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
@@ -148,7 +142,7 @@
def as_python_constant(self):
return self.value
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def var_getattr(self, tx, name: str) -> VariableTracker:
if name == "dim":
return ConstantVariable.create(self.value.dim)
return super().var_getattr(tx, name)
@@ -212,7 +206,7 @@
def as_python_constant(self):
return self.value
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def var_getattr(self, tx, name: str) -> VariableTracker:
if name == "ndim":
return ConstantVariable.create(self.value.ndim)
if name == "device_type":
@@ -275,7 +269,7 @@
return super().call_method(tx, name, args, kwargs)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if name == "group_name":
return variables.ConstantVariable.create(self.value.group_name)
if name in ["rank", "size"]:
@@ -381,7 +375,7 @@
return self._setup_hook(tx, name, *args, **kwargs)
return super().call_method(tx, name, args, kwargs)
- def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
+ def _setup_hook(self, tx, hook_method_name, args):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py
index 5e6ac82..c1be7d8 100644
--- a/torch/_dynamo/variables/functions.py
+++ b/torch/_dynamo/variables/functions.py
@@ -9,9 +9,6 @@
import torch
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
@@ -98,14 +95,11 @@
return self.get_code().co_name
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def call_hasattr(self, tx, name: str) -> VariableTracker:
result = False
try:
@@ -291,15 +285,12 @@
def export_freevars(self, parent, child):
pass
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def call_hasattr(self, tx, name: str) -> VariableTracker:
result = hasattr(self.fn, name)
return variables.ConstantVariable.create(result)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if self.is_constant:
return invoke_and_store_as_constant(
@@ -326,10 +317,7 @@
return types.MethodType
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
# rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
@@ -382,10 +370,7 @@
self.context = context
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
self.context.enter(tx)
result = super().call_function(tx, args, kwargs)
@@ -402,10 +387,7 @@
self.context = context
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
self.context.enter(tx)
result = super().call_function(tx, args, kwargs)
@@ -637,10 +619,7 @@
}
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
@@ -720,7 +699,7 @@
self.wrapper_obj = wrapper_obj
self.attr_to_trace = attr_to_trace
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if name == self.attr_to_trace:
val = getattr(self.wrapper_obj, self.attr_to_trace)
if self.source:
@@ -735,10 +714,7 @@
return super().var_getattr(tx, name)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return variables.UserFunctionVariable(polyfill.getattr_and_trace).call_function(
tx, [self, variables.ConstantVariable(self.attr_to_trace), *args], kwargs
@@ -803,10 +779,7 @@
return new_fn, _traceable_collectives_source(tx, new_fn)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# call_function must check any unsupported arguments and graph-break.
# It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
@@ -872,16 +845,13 @@
return self.as_python_constant()
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
merged_args = self.args + args
merged_kwargs = {**self.keywords, **kwargs}
return self.func.call_function(tx, merged_args, merged_kwargs)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def call_hasattr(self, tx, name: str) -> VariableTracker:
# functools.partial uses slots, so attributes are constant
return variables.ConstantVariable.create(
hasattr(functools.partial(identity), name)
@@ -984,10 +954,7 @@
dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return dynamo_triton_hopifier_singleton.call_triton_kernel(
self, args, kwargs, tx
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 7d108d0..c3a7a967 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -13,9 +13,6 @@
import torch.fx
import torch.nn
import torch.onnx.operators
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.base import VariableTracker
@@ -35,6 +32,9 @@
from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable
+if TYPE_CHECKING:
+ from torch._dynamo.symbolic_convert import InstructionTranslator
+
log = logging.getLogger(__name__)
@@ -601,10 +601,7 @@
unimplemented(f"HigherOrderOperator {value.__name__}")
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
unimplemented(f"HigherOrderOperator {self.value.__name__}")
@@ -614,10 +611,7 @@
reason="Cond doesn't work unless it is captured completely with torch.compile."
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ListVariable, TensorVariable
@@ -816,10 +810,7 @@
self.method_name = method_name
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from .builder import wrap_fx_proxy
@@ -845,10 +836,7 @@
reason="while_loop doesn't work unless it is captured completely with torch.compile."
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from . import TensorVariable
@@ -1002,10 +990,7 @@
reason="associative_scan must be captured completely with torch.compile."
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from .builder import SourcelessBuilder, wrap_fx_proxy
@@ -1108,10 +1093,7 @@
class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from . import TensorVariable
from .builder import wrap_fx_proxy_cls
@@ -1198,10 +1180,7 @@
class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
@@ -1249,10 +1228,7 @@
class FunctorchHigherOrderVariable(UserFunctionVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if not torch._dynamo.config.capture_func_transforms:
name = self.get_name()
@@ -1276,9 +1252,7 @@
class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
- def create_wrapped_node(
- self, tx: "InstructionTranslator", args, kwargs, description
- ):
+ def create_wrapped_node(self, tx, args, kwargs, description):
# See NOTE [HigherOrderOperator tracing design] for more details
(
@@ -1318,10 +1292,7 @@
return proxy_args, {}, example_value, body_r, treespec, body_gmod
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
# This flattens the kwargs into lifted args
p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node(
@@ -1344,10 +1315,7 @@
class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
@@ -1382,10 +1350,7 @@
reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile."
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
callable = args[0]
@@ -1446,10 +1411,7 @@
class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
from torch._higher_order_ops.wrap import TagActivationCheckpoint
from torch.utils.checkpoint import noop_context_fn
@@ -1512,10 +1474,7 @@
class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
@@ -1535,10 +1494,7 @@
class RunWithRNGStateHigherOrderVariable(TorchHigherOrderOperatorVariable):
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
@@ -1565,10 +1521,7 @@
"""
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
kwargs = dict(kwargs)
fn = kwargs.pop("fn")
@@ -1588,11 +1541,7 @@
return all_args
def create_wrapped_node(
- self,
- tx: "InstructionTranslator",
- query: "VariableTracker",
- fn: "VariableTracker",
- fn_name: str,
+ self, tx, query: "VariableTracker", fn: "VariableTracker", fn_name: str
):
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from .builder import SourcelessBuilder
@@ -1657,10 +1606,7 @@
return proxy_args
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .builder import wrap_fx_proxy
@@ -1738,10 +1684,7 @@
self.parent_source = parent_source
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
AutogradFunctionContextVariable,
diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py
index 5c312b6..cdf2b37 100644
--- a/torch/_dynamo/variables/iter.py
+++ b/torch/_dynamo/variables/iter.py
@@ -5,10 +5,7 @@
import itertools
import operator
-from typing import Dict, List, Optional, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
+from typing import Dict, List, Optional
from .. import polyfill, variables
from ..exc import ObservedUserStopIteration, unimplemented
@@ -32,10 +29,7 @@
return self.value
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if (
self.value is itertools.product
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index 0e3c32e..718f4c5 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -5,14 +5,11 @@
import inspect
import operator
import types
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Dict, List, Optional
import torch
import torch.fx
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from ..._guards import Source
from .. import polyfill, variables
@@ -308,7 +305,7 @@
codegen.foreach(self.items)
codegen.extend_output(create_call_function(3, False))
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
fields = ["start", "stop", "step"]
if name not in fields:
unimplemented(f"range.{name}")
@@ -438,7 +435,7 @@
else:
return super().call_method(tx, name, args, kwargs)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
if self.python_type() is not list:
return super().call_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr([], name))
@@ -527,7 +524,7 @@
) -> "VariableTracker":
return super().call_method(tx, name, args, kwargs)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
if self.python_type() is not tuple:
return super().call_hasattr(tx, name)
return variables.ConstantVariable.create(hasattr((), name))
@@ -659,7 +656,7 @@
return super().call_method(tx, name, args, kwargs)
- def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker):
+ def get_item_dyn(self, tx, arg: VariableTracker):
from .tensor import SymNodeVariable
if isinstance(arg, SymNodeVariable):
@@ -672,7 +669,7 @@
assert isinstance(index, (int, torch.SymInt))
return self.items[index]
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
return variables.ConstantVariable.create(hasattr(torch.Size, name))
@@ -712,7 +709,7 @@
+ create_call_function(1, False)
)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
def check_and_create_method():
method = inspect.getattr_static(self.tuple_cls, name, None)
if isinstance(method, classmethod):
@@ -736,7 +733,7 @@
return method
return self.items[fields.index(name)]
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
return variables.ConstantVariable.create(hasattr(self.tuple_cls, name))
@@ -777,7 +774,7 @@
codegen.foreach(self.items)
codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items)))
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
fields = ["start", "stop", "step"]
if name not in fields:
unimplemented(f"slice.{name}")
@@ -973,9 +970,6 @@
return super().call_method(tx, name, args, kwargs)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return self.call_method(tx, "__call__", args, kwargs)
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 4834f83..8f201f9 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -7,14 +7,11 @@
import re
import sys
import types
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import config, variables
from ..bytecode_transformation import (
add_push_null_call_function_ex,
@@ -66,7 +63,7 @@
else:
codegen.extend_output(create_call_function(1, False))
- def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
+ def _resolved_getattr_and_source(self, tx, name):
assert self.objvar, "1-arg super not implemented"
if self.specialized:
return getattr(self.typevar.as_python_constant(), name)
@@ -105,7 +102,7 @@
# TODO(jansel): there is a small chance this could trigger user code, prevent that
return getattr(super(search_type, type_to_use), name), source
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
# Check if getattr is a constant. If not, delay the actual work by
# wrapping the result in GetAttrVariable. Mostly super is called with a
# method, so most of the work is delayed to call_function.
@@ -239,7 +236,7 @@
def reconstruct(self, codegen):
raise NotImplementedError("comptime is special form")
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
@@ -250,10 +247,7 @@
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from ..comptime import ComptimeContext
@@ -343,7 +337,7 @@
super().__init__(**kwargs)
self.inspected = inspected
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
if name == "parameters":
return variables.ConstDictVariable(
{
@@ -381,7 +375,7 @@
super().__init__(**kwargs)
self.fn_cls = fn_cls
- def call_apply(self, tx: "InstructionTranslator", args, kwargs):
+ def call_apply(self, tx, args, kwargs):
requires_grad = False
def visit(node):
@@ -474,7 +468,7 @@
f"non-function or method in subclass of torch.autograd.Function: {fn}"
)
- def call_backward(self, tx: "InstructionTranslator", args, kwargs):
+ def call_backward(self, tx, args, kwargs):
fn = self.fn_cls.backward
self.source = AttrSource(self.source, "backward")
assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
@@ -484,7 +478,7 @@
tx, args, kwargs
)
- def call_function(self, tx: "InstructionTranslator", args, kwargs):
+ def call_function(self, tx, args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
def call_method(
@@ -636,7 +630,7 @@
self.saved_tensors.tensors.append(arg)
return variables.ConstantVariable.create(None)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if name == "save_for_backward":
return LambdaVariable(
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
@@ -702,10 +696,7 @@
self.fn = fn
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return self.fn(*args, **kwargs)
@@ -733,7 +724,7 @@
def as_proxy(self):
return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
- def const_getattr(self, tx: "InstructionTranslator", name):
+ def const_getattr(self, tx, name):
if not isinstance(self.obj, variables.NNModuleVariable):
raise NotImplementedError
step1 = tx.output.get_submodule(self.obj.module_key)
@@ -749,10 +740,7 @@
codegen.extend_output(codegen.create_load_attrs(self.name))
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
return self.obj.call_method(tx, self.name, args, kwargs)
@@ -821,10 +809,7 @@
self.method_wrapper = method_wrapper
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
args[0], variables.TensorVariable
@@ -847,7 +832,7 @@
super().__init__(**kwargs)
self.desc = desc
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if name == "__get__" and self.source:
from .builder import VariableBuilder
@@ -885,13 +870,13 @@
def __repr__(self):
return f"PythonModuleVariable({self.value})"
- def call_hasattr(self, tx: "InstructionTranslator", name):
+ def call_hasattr(self, tx, name):
if self.is_torch:
result = hasattr(self.value, name)
return variables.ConstantVariable.create(result)
return super().call_hasattr(tx, name)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
return tx.output.side_effects.load_attr(self, name)
@@ -972,10 +957,7 @@
return np_constant_collections_map.get(fn, None)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if not config.trace_numpy:
unimplemented(f"numpy.{self.value}()")
@@ -1144,7 +1126,7 @@
and obj in torch._dynamo.config.reorderable_logging_functions
)
- def call_function(self, tx: "InstructionTranslator", args, kwargs):
+ def call_function(self, tx, args, kwargs):
if tx.export:
# For export cases, we can just make debugging functions no-ops
return
@@ -1249,7 +1231,7 @@
unimplemented(f"{self._error_prefix}.{name}() -> {result}")
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def var_getattr(self, tx, name: str) -> VariableTracker:
result = getattr(self.value, name)
if isinstance(result, self.np_floating):
result = float(result)
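For readers skimming these hunks: what each file loses is the deferred-import annotation pattern. As a minimal, hypothetical sketch (illustrative only, not Dynamo's real module layout), the pattern looks like this — the import runs only under the type checker, and the quoted annotation is never evaluated at runtime, so no import cycle is introduced:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while mypy (or another checker) analyzes the file;
    # at runtime TYPE_CHECKING is False, so no circular import occurs.
    from heavy_module import Translator  # hypothetical module and class

class Tracker:
    def var_getattr(self, tx: "Translator", name: str) -> str:
        # The string annotation is stored as-is and never resolved at runtime.
        return f"{name} via {type(tx).__name__}"

print(Tracker().var_getattr(object(), "weight"))  # runs even though heavy_module is absent

The hunks in this revert undo exactly that: the TYPE_CHECKING import blocks are dropped and the tx parameters go back to being unannotated.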
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 40af1d4..fdc4a7a 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -5,13 +5,10 @@
import itertools
import types
from contextlib import contextmanager, nullcontext
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Any, Dict, List
import torch.nn
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from .. import trace_rules, variables
from ..exc import (
ObservedException,
@@ -146,9 +143,7 @@
def python_type(self):
return self.module_type
- def _wrap_submodule(
- self, tx: "InstructionTranslator", source, submod, *key_extra, **options
- ):
+ def _wrap_submodule(self, tx, source, submod, *key_extra, **options):
return
def unpack_var_sequence(self, tx):
@@ -183,7 +178,7 @@
)
return result
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
mod = tx.output.get_submodule(self.module_key)
result = hasattr(mod, name)
install_guard(
@@ -207,7 +202,7 @@
GenerationTracker.mark_class_dynamic(type(mod))
raise UnspecializeRestartAnalysis
- def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
+ def has_key_in_generic_dict(self, tx, key):
base = tx.output.get_submodule(self.module_key)
if object_has_getattribute(base):
@@ -236,7 +231,7 @@
tx, [variables.ConstantVariable.create(name)], {}
)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
from .builder import VariableBuilder
if self.source:
@@ -839,10 +834,7 @@
return super().unpack_var_sequence(tx)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
mod = self.value
# see comment on lazy module handling in NNModuleVariable.call_function for context
@@ -891,9 +883,7 @@
tx, [self] + list(args), kwargs
)
- def trace_supported_methods(
- self, tx: "InstructionTranslator", method, name, args, kwargs
- ):
+ def trace_supported_methods(self, tx, method, name, args, kwargs):
def get_kwargs(*names):
fn = getattr(self.value, name)
bound_args = inspect.signature(fn).bind(
@@ -1041,13 +1031,13 @@
return super().call_method(tx, name, args, kwargs)
- def getattr_helper(self, tx: "InstructionTranslator", field, name_vt):
+ def getattr_helper(self, tx, field, name_vt):
dict_vt = self.var_getattr(tx, field)
if isinstance(dict_vt, variables.ConstDictVariable):
return dict_vt.maybe_getitem_const(name_vt)
return None
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
# Allow skipping of empty hook dict guards on inbuilt nn modules
if name in (
"_backward_hooks",
@@ -1070,7 +1060,7 @@
return variables.ConstDictVariable({})
return super().var_getattr(tx, name)
- def manually_trace_nn_module_getattr(self, tx: "InstructionTranslator", name):
+ def manually_trace_nn_module_getattr(self, tx, name):
"""
Dynamo tracing of nn.Module __getattr__ can be expensive if the model
has deep submodule hierarchy. Since the __getattr__ is stable, we can
diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py
index 0b42162..62f09f9 100644
--- a/torch/_dynamo/variables/optimizer.py
+++ b/torch/_dynamo/variables/optimizer.py
@@ -4,9 +4,6 @@
from typing import Dict, List, TYPE_CHECKING
import torch
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch.utils._pytree import tree_map_only
from ..guards import GuardBuilder, install_guard
@@ -91,7 +88,7 @@
return super().call_method(tx, name, args, kwargs)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
# Note: this allows us to intercept the call in call_method
# in the typical case, we return a UserMethodVariable
# which will directly inline
@@ -287,7 +284,7 @@
):
self.tensor_to_source[v] = GetItemSource(p_state_source, k)
- def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
+ def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable"""
from ..decorators import mark_static_address
from .builder import VariableBuilder
@@ -315,9 +312,7 @@
result = builder(tensor_value)
return result
- def update_list_args(
- self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
- ):
+ def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
"""Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable):
diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py
index 2e9414d..2146329 100644
--- a/torch/_dynamo/variables/sdpa.py
+++ b/torch/_dynamo/variables/sdpa.py
@@ -2,11 +2,6 @@
from inspect import getattr_static
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker
@@ -62,7 +57,7 @@
def as_proxy(self):
return self.proxy
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
+ def var_getattr(self, tx, name: str) -> VariableTracker:
import torch._C
from ..source import AttrSource
from .builder import wrap_fx_proxy
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 9a0a244..467e78d 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -7,7 +7,7 @@
import textwrap
import types
import unittest
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import sympy
@@ -15,9 +15,6 @@
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
guard_scalar,
@@ -212,7 +209,7 @@
)
return props
- def dynamic_getattr(self, tx: "InstructionTranslator", name):
+ def dynamic_getattr(self, tx, name):
fake_val = self.proxy.node.meta["example_value"]
# For getattrs on tensors without sources,
# we can do better than the default (creating a GetAttrVariable)
@@ -336,7 +333,7 @@
tx, [self], {}
)
- def call_hasattr(self, tx: "InstructionTranslator", name):
+ def call_hasattr(self, tx, name):
from . import GetAttrVariable
from .builtin import BuiltinVariable
@@ -357,7 +354,7 @@
return ConstantVariable(ret_val)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
from . import UserDefinedClassVariable
if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops():
@@ -440,7 +437,7 @@
def has_unpack_var_sequence(self, tx):
return self.ndim > 0
- def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None):
+ def unpack_var_sequence(self, tx, idxes=None):
from .builder import wrap_fx_proxy_cls
if self.size:
@@ -1113,7 +1110,7 @@
**options,
)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
# NB: This INTENTIONALLY does not call super(), because there is
# no intrinsic reason ndarray properties are related to Tensor
# properties. The inheritance here is for implementation sharing.
@@ -1261,10 +1258,7 @@
super().__init__(*args, **kwargs)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
+ self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
if len(args) == 1 and isinstance(args[0], TensorVariable):
from .builder import VariableBuilder
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 2175da6..c2deef3 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -1,4 +1,3 @@
-# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
import inspect
@@ -6,16 +5,13 @@
import math
import re
-from typing import Dict, List, TYPE_CHECKING
+from typing import Dict, List
import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._logging import warning_once
from torch._streambase import _StreamBase
@@ -175,7 +171,7 @@
def as_python_constant(self):
return self.value
- def call_hasattr(self, tx: "InstructionTranslator", name):
+ def call_hasattr(self, tx, name):
result = hasattr(self.value, name)
return variables.ConstantVariable.create(result)
@@ -208,10 +204,7 @@
)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
DisabledSavedTensorsHooksVariable,
@@ -365,9 +358,7 @@
from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls
@register(*tracing_state_functions)
- def handle_tracing_state_functions(
- self, tx: "InstructionTranslator", *args, **kwargs
- ):
+ def handle_tracing_state_functions(self, tx, *args, **kwargs):
assert not args and not kwargs
# See: https://github.com/pytorch/pytorch/issues/110765
if self.value in (
@@ -380,9 +371,7 @@
return ConstantVariable.create(tracing_state_functions[self.value])
@register(torch.overrides.get_default_nowrap_functions.__wrapped__)
- def handle_get_default_nowrap_functions(
- self, tx: "InstructionTranslator", *args, **kwargs
- ):
+ def handle_get_default_nowrap_functions(self, tx, *args, **kwargs):
# [Note: __torch_function__] we return empty here because we restrict
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require implementing
@@ -392,13 +381,13 @@
)
@register(torch.ops.inductor.accumulate_grad_.default)
- def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_accumulate_grad_(self, tx, *args, **kwargs):
return tx.inline_user_function_return(
SourcelessBuilder.create(tx, polyfill.accumulate_grad), args, kwargs
)
@register(math.radians)
- def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_radians(self, tx, *args, **kwargs):
if not check_unspec_or_constant_args(args, kwargs):
# Use polyfill to convert math.radians(x) into math.pi * x / 180.0
return tx.inline_user_function_return(
@@ -406,7 +395,7 @@
)
@register(torch.is_tensor, torch.overrides.is_tensor_like)
- def handle_is_tensor(self, tx: "InstructionTranslator", arg):
+ def handle_is_tensor(self, tx, arg):
if isinstance(arg, TensorVariable) or (
self.value is torch.overrides.is_tensor_like
and isinstance(arg, UserDefinedObjectVariable)
@@ -420,7 +409,7 @@
torch.is_floating_point,
torch.is_complex,
)
- def handle_is_floating_point(self, tx: "InstructionTranslator", input):
+ def handle_is_floating_point(self, tx, input):
input_arg = input
if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
if self.value is torch.is_floating_point:
@@ -431,7 +420,7 @@
raise AssertionError(f"calling {self.value}")
@register(torch.numel)
- def handle_numel(self, tx: "InstructionTranslator", input):
+ def handle_numel(self, tx, input):
if isinstance(input, TensorVariable) and input.size is not None:
return ConstantVariable.create(product(input.size))
elif isinstance(input, TensorVariable):
@@ -439,7 +428,7 @@
return input.call_method(tx, "numel", [], {})
@register(*REWRITE_OPS_TO_TENSOR_SIZE_METHOD)
- def handle_tensor_size_rewrites(self, tx: "InstructionTranslator", input):
+ def handle_tensor_size_rewrites(self, tx, input):
assert isinstance(input, TensorVariable)
return input.call_method(tx, "size", [], {})
@@ -450,7 +439,7 @@
torch.nn.modules.utils._quadruple,
torch.nn.modules.utils._ntuple,
)
- def handle_ntuple(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_ntuple(self, tx, *args, **kwargs):
return self._call_ntuple(tx, args, kwargs)
@register(torch.is_grad_enabled)
@@ -459,9 +448,7 @@
return ConstantVariable.create(torch.is_grad_enabled())
@register(torch.use_deterministic_algorithms)
- def handle_use_deterministic_algorithms(
- self, tx: "InstructionTranslator", mode, warn_only=False
- ):
+ def handle_use_deterministic_algorithms(self, tx, mode, warn_only=False):
if warn_only and warn_only.as_python_constant():
unimplemented("torch.use_deterministic_algorithms(warn_only=True)")
return DeterministicAlgorithmsVariable.create(tx, mode.as_python_constant())
@@ -481,7 +468,7 @@
torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary,
)
- def handle_has_torch_function(self, tx: "InstructionTranslator", *args):
+ def handle_has_torch_function(self, tx, *args):
elems = (
args[0].unpack_var_sequence(tx)
if len(args) == 1 and isinstance(args[0], TupleVariable)
@@ -497,11 +484,11 @@
for _, device_interface in get_registered_device_interfaces()
)
)
- def handle_device_interface_stream(self, tx: "InstructionTranslator", stream):
+ def handle_device_interface_stream(self, tx, stream):
return StreamContextVariable.create(tx, stream)
@register(torch.from_numpy)
- def handle_from_numpy(self, tx: "InstructionTranslator", *args):
+ def handle_from_numpy(self, tx, *args):
if not config.trace_numpy:
unimplemented("torch.from_numpy. config.trace_numpy is False")
if not np:
@@ -518,13 +505,11 @@
)
@register(torch.jit.annotate)
- def handle_jit_annotate(self, tx: "InstructionTranslator", the_type, the_value):
+ def handle_jit_annotate(self, tx, the_type, the_value):
return the_value
@register(torch.backends.cudnn.is_acceptable)
- def handle_cudnn_is_acceptable(
- self, tx: "InstructionTranslator", tensor, *extra
- ):
+ def handle_cudnn_is_acceptable(self, tx, tensor, *extra):
# is_acceptable(tensor) returns true if
# (a) tensor dtype/device are supported by cudnn
# (b) cudnn is available
@@ -540,11 +525,11 @@
)
@register(torch.utils.hooks.BackwardHook)
- def handle_backward_hook(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_backward_hook(self, tx, *args, **kwargs):
return variables.BackwardHookVariable.create(tx, *args, **kwargs)
@register(torch.nn.Parameter)
- def handle_parameter(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_parameter(self, tx, *args, **kwargs):
return self.call_nn_parameter(tx, *args, **kwargs)
@register(torch.ops.aten.sym_size, torch.ops.aten.sym_size.int)
@@ -559,7 +544,7 @@
return self.call_method(tx, "stride", [dim], {})
@register(torch.addcdiv)
- def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_addcdiv(self, tx, *args, **kwargs):
if len(args) == 3 and "value" in kwargs and len(kwargs) == 1:
# decompose addcdiv into constituent ops, prevents a graph break due to converting
# value to a scalar
@@ -574,7 +559,7 @@
)
@register(torch._assert)
- def handle_assert(self, tx: "InstructionTranslator", condition, message):
+ def handle_assert(self, tx, condition, message):
if (condition.is_python_constant() and condition.as_python_constant()) or (
isinstance(condition, variables.SymNodeVariable)
and condition.evaluate_expr()
@@ -582,7 +567,7 @@
return ConstantVariable(None)
@register(SDPAParams)
- def handle_sdpa_params(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_sdpa_params(self, tx, *args, **kwargs):
return wrap_fx_proxy(
tx,
proxy=tx.output.create_proxy(
@@ -610,9 +595,7 @@
get_process_group_ranks,
_resolve_group_name_by_ranks_and_tag,
)
- def handle_constant_processgroup_functions(
- self, tx: "InstructionTranslator", *args
- ):
+ def handle_constant_processgroup_functions(self, tx, *args):
# because the input is a "ProcessGroupVariable", we'll be guarding on its
# ID_MATCH based on how it was constructed.
@@ -640,7 +623,7 @@
return SourcelessBuilder.create(tx, invocation_result)
@register(DTensor.from_local)
- def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_from_local(self, tx, *args, **kwargs):
# rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
# and rewrite args to have only proxyable args, then insert call_function
args_as_value = [x.as_python_constant() for x in args[1:]]
@@ -663,12 +646,7 @@
@register(torch.nested.nested_tensor)
def handle_nested_tensor(
- self,
- tx: "InstructionTranslator",
- tensor_list=None,
- *args,
- layout=None,
- **kwargs,
+ self, tx, tensor_list=None, *args, layout=None, **kwargs
):
from .lists import BaseListVariable
@@ -678,7 +656,7 @@
unimplemented("nested_tensor with non-list input")
@register(torch.nn.functional.one_hot)
- def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_one_hot(self, tx, *args, **kwargs):
if len(args) + len(kwargs) == 1 or (
len(args) == 2
and args[1].is_python_constant()
@@ -689,7 +667,7 @@
)
@register(torch.fx.experimental.symbolic_shapes.guard_size_oblivious)
- def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr):
+ def handle_guard_size_oblivious(self, tx, expr):
if isinstance(expr, SymNodeVariable):
# TODO: this probably should be folded somewhere else but I'm not sure where
# TODO: some of the other symbolic_shapes special tools can also get this treatment too
@@ -702,9 +680,7 @@
return expr
@register(torch._C._autograd._unsafe_set_version_counter)
- def handle_unsafe_set_version_counter(
- self, tx: "InstructionTranslator", *args, **kwargs
- ):
+ def handle_unsafe_set_version_counter(self, tx, *args, **kwargs):
from ..tensor_version_op import _unsafe_set_version_counter
return TorchInGraphFunctionVariable(
@@ -712,7 +688,7 @@
).call_function(tx, [*args], kwargs)
@register(torch.tensor)
- def handle_torch_tensor(self, tx: "InstructionTranslator", *args, **kwargs):
+ def handle_torch_tensor(self, tx, *args, **kwargs):
def check_any_unspec(x):
# NB: This includes UnspecializedPythonVariable
if isinstance(x, (TensorVariable, SymNodeVariable)):
@@ -741,10 +717,7 @@
return handlers
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import ConstantVariable, SymNodeVariable, TensorVariable
from .builder import wrap_fx_proxy
@@ -908,7 +881,7 @@
return tensor_variable
- def _call_ntuple(self, tx: "InstructionTranslator", args, kwargs):
+ def _call_ntuple(self, tx, args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
if self.value is torch.nn.modules.utils._ntuple:
count = args[0].as_python_constant()
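The torch.py hunks above all touch handlers defined with an @register(...) decorator. As a rough, hypothetical sketch of that dispatch idea (simplified, not the actual TorchInGraphFunctionVariable machinery), a decorator can map one or more keys to the same handler in a dict:

from typing import Callable, Dict

def make_registry():
    handlers: Dict[object, Callable] = {}

    def register(*keys):
        def decorator(fn):
            for key in keys:
                handlers[key] = fn  # every listed key dispatches to this handler
            return fn
        return decorator

    return handlers, register

handlers, register = make_registry()

@register("is_tensor", "is_tensor_like")
def handle_is_tensor(tx, arg):
    # tx stands in for the translator object; left unannotated, matching the reverted style.
    return f"checked {arg}"

print(handlers["is_tensor"](None, "x"))  # -> checked x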
diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py
index 75eac8a..f87f235 100644
--- a/torch/_dynamo/variables/torch_function.py
+++ b/torch/_dynamo/variables/torch_function.py
@@ -5,9 +5,6 @@
import torch.utils._pytree as pytree
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
-
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@@ -219,7 +216,7 @@
compile_id = tx.output.compile_id
return f"__subclass_{self.class_type.__name__}_{id(self.class_type)}_c{id}"
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
@@ -255,7 +252,7 @@
else:
return super().var_getattr(tx, name)
- def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
+ def call_torch_function(self, tx, fn, types, args, kwargs):
return call_torch_function(
tx,
self.class_type_var(tx),
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 55fa527..a9937fe 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -12,10 +12,7 @@
import types
import warnings
-from typing import Dict, Generic, List, TYPE_CHECKING
-
-if TYPE_CHECKING:
- from torch._dynamo.symbolic_convert import InstructionTranslator
+from typing import Dict, Generic, List
from ..bytecode_transformation import create_call_function
@@ -112,14 +109,14 @@
def can_constant_fold_through(self):
return self.value in self._constant_fold_classes()
- def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
+ def has_key_in_generic_dict(self, tx, key):
if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
return not isinstance(mutated_attr, variables.DeletedVariable)
return key in self.value.__dict__
- def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def var_getattr(self, tx, name: str) -> "VariableTracker":
from . import ConstantVariable, EnumVariable
from .builder import SourcelessBuilder, VariableBuilder
@@ -174,7 +171,7 @@
return super().var_getattr(tx, name)
- def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs):
+ def _call_cross_entropy_loss(self, tx, args, kwargs):
"""
functional: input, target, weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean',
label_smoothing=0.0
@@ -277,10 +274,7 @@
return super().call_method(tx, name, args, kwargs)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from ..side_effects import SideEffects
from .builder import SourcelessBuilder, wrap_fx_proxy
@@ -460,14 +454,14 @@
new_fn = new_fn.__func__
return new_fn in (object.__new__, Generic.__new__)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
if self.source:
source = AttrSource(self.source, name)
install_guard(source.make_guard(GuardBuilder.HASATTR))
return variables.ConstantVariable(hasattr(self.value, name))
return super().call_hasattr(tx, name)
- def const_getattr(self, tx: "InstructionTranslator", name):
+ def const_getattr(self, tx, name):
if name == "__name__":
return self.value.__name__
return super().const_getattr(tx, name)
@@ -521,7 +515,7 @@
return build_torch_function_fn(tx, self.value, self.source)
- def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
+ def call_torch_function(self, tx, fn, types, args, kwargs):
self.torch_function_check()
from .torch_function import _get_subclass_type_var, call_torch_function
@@ -658,7 +652,7 @@
return super().call_method(tx, name, args, kwargs)
- def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
+ def method_setattr_standard(self, tx, name, value):
try:
name = name.as_python_constant()
except NotImplementedError:
@@ -702,10 +696,7 @@
return False
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from .. import trace_rules
from .builder import VariableBuilder
@@ -830,7 +821,7 @@
subobj = inspect.getattr_static(self.value, name)
return subobj
- def has_key_in_generic_dict(self, tx: "InstructionTranslator", key):
+ def has_key_in_generic_dict(self, tx, key):
self._check_for_getattribute()
if tx.output.side_effects.has_pending_mutation_of_attr(self, key):
mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True)
@@ -843,7 +834,7 @@
torch.nn.Module.parameters,
)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
from .. import trace_rules
from . import ConstantVariable
@@ -991,7 +982,7 @@
options = {"source": source}
return variables.GetAttrVariable(self, name, **options)
- def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
+ def call_hasattr(self, tx, name: str) -> "VariableTracker":
if tx.output.side_effects.is_attribute_mutation(self):
try:
result = tx.output.side_effects.load_attr(self, name, deleted_ok=True)
@@ -1034,7 +1025,7 @@
else:
unimplemented("UserDefined with non-function __getattr__")
- def odict_getitem(self, tx: "InstructionTranslator", key):
+ def odict_getitem(self, tx, key):
from .builder import VariableBuilder
from .dicts import is_hashable
@@ -1083,10 +1074,7 @@
super().__init__(value, **kwargs)
def call_function(
- self,
- tx: "InstructionTranslator",
- args: "List[VariableTracker]",
- kwargs: "Dict[str, VariableTracker]",
+ self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
call_source = None
referent = self.value()
@@ -1114,7 +1102,7 @@
assert type(value) is KeyedJaggedTensor
super().__init__(value, **kwargs)
- def var_getattr(self, tx: "InstructionTranslator", name):
+ def var_getattr(self, tx, name):
if (
torch._dynamo.config.force_unspec_int_unbacked_size_like_on_torchrec_kjt
and self.source is not None
@@ -1145,7 +1133,7 @@
self.mutable_local = mutable_local
self.idx = idx
- def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs):
+ def call_method(self, tx, method_name, args, kwargs):
if method_name == "remove":
if self.idx != self.REMOVED:
tx.output.side_effects.remove_hook(self.idx)
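A closing note on the net effect: with the explicit tx annotations removed, these methods become untyped defs again. Under a # mypy: allow-untyped-defs directive (visible in the torch.py context above), mypy accepts that and treats the unannotated parameters as dynamically typed (Any). A tiny hypothetical sketch of the difference:

# mypy: allow-untyped-defs
class Annotated:
    def var_getattr(self, tx: object, name: str) -> str:
        return name  # annotated: mypy checks call sites against these types

class Untyped:
    def var_getattr(self, tx, name):
        return name  # untyped: mypy treats tx and name as Any and skips strict checks

print(Annotated().var_getattr(object(), "bias"), Untyped().var_getattr(object(), "bias"))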