| import functools | 
 | import inspect | 
 | import itertools | 
 | import logging | 
 | import math | 
 | import operator | 
 | import types | 
 | from typing import Dict, List | 
 |  | 
 | import torch | 
 | from torch import sym_float, sym_int | 
 |  | 
 | from .. import config, variables | 
 | from ..allowed_functions import is_allowed | 
 | from ..exc import unimplemented, Unsupported, UserError, UserErrorType | 
 | from ..guards import GuardBuilder | 
 | from ..replay_record import DummyModule | 
 | from ..source import AttrSource, is_constant_source, SuperSource, TypeSource | 
 | from ..utils import ( | 
 |     check_constant_args, | 
 |     check_unspec_python_args, | 
 |     istype, | 
 |     proxy_args_kwargs, | 
 |     specialize_args_kwargs, | 
 | ) | 
 | from .base import MutableLocal, typestr, VariableTracker | 
 | from .constant import ConstantVariable, EnumVariable | 
 | from .dicts import ConstDictVariable | 
 | from .lists import ( | 
 |     BaseListVariable, | 
 |     ListIteratorVariable, | 
 |     ListVariable, | 
 |     SizeVariable, | 
 |     TupleIteratorVariable, | 
 |     TupleVariable, | 
 | ) | 
 | from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable | 
 | from .user_defined import UserDefinedVariable | 
 |  | 
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
 |  | 
 |  | 
 | class BuiltinVariable(VariableTracker): | 
 |     @staticmethod | 
 |     @functools.lru_cache(None) | 
 |     def _constant_fold_functions(): | 
 |         fns = { | 
 |             abs, | 
 |             all, | 
 |             any, | 
 |             bool, | 
 |             callable, | 
 |             chr, | 
 |             divmod, | 
 |             float, | 
 |             int, | 
 |             len, | 
 |             max, | 
 |             min, | 
 |             ord, | 
 |             pow, | 
 |             repr, | 
 |             round, | 
 |             set, | 
 |             str, | 
 |             str.format, | 
 |             sum, | 
 |             type, | 
 |             operator.pos, | 
 |             operator.neg, | 
 |             operator.not_, | 
 |             operator.invert, | 
 |             operator.pow, | 
 |             operator.mul, | 
 |             operator.matmul, | 
 |             operator.floordiv, | 
 |             operator.truediv, | 
 |             operator.mod, | 
 |             operator.add, | 
 |             operator.sub, | 
 |             operator.getitem, | 
 |             operator.lshift, | 
 |             operator.rshift, | 
 |             operator.and_, | 
 |             operator.or_, | 
 |             operator.xor, | 
 |             operator.ipow, | 
 |             operator.imul, | 
 |             operator.imatmul, | 
 |             operator.ifloordiv, | 
 |             operator.itruediv, | 
 |             operator.imod, | 
 |             operator.iadd, | 
 |             operator.isub, | 
 |             operator.ilshift, | 
 |             operator.irshift, | 
 |             operator.iand, | 
 |             operator.ixor, | 
 |             operator.ior, | 
 |             operator.index, | 
 |         } | 
 |         fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) | 
 |         return fns | 
 |  | 
    def can_constant_fold_through(self):
        """Return True if ``self.fn`` is a member of ``_constant_fold_functions()``."""
        return self.fn in self._constant_fold_functions()
 |  | 
 |     @staticmethod | 
 |     @functools.lru_cache(None) | 
 |     def _fx_graph_functions(): | 
 |         fns = { | 
 |             operator.pos, | 
 |             operator.neg, | 
 |             operator.not_, | 
 |             operator.invert, | 
 |             operator.pow, | 
 |             operator.mul, | 
 |             operator.matmul, | 
 |             operator.floordiv, | 
 |             operator.truediv, | 
 |             operator.mod, | 
 |             operator.add, | 
 |             operator.sub, | 
 |             operator.getitem, | 
 |             operator.lshift, | 
 |             operator.rshift, | 
 |             operator.and_, | 
 |             operator.or_, | 
 |             operator.xor, | 
 |             operator.ipow, | 
 |             operator.imul, | 
 |             operator.imatmul, | 
 |             operator.ifloordiv, | 
 |             operator.itruediv, | 
 |             operator.imod, | 
 |             operator.iadd, | 
 |             operator.isub, | 
 |             operator.ilshift, | 
 |             operator.irshift, | 
 |             operator.iand, | 
 |             operator.ixor, | 
 |             operator.ior, | 
 |         } | 
 |         return fns | 
 |  | 
 |     @staticmethod | 
 |     @functools.lru_cache(None) | 
 |     def _binops(): | 
 |         # function -> ([forward name, reverse name, in-place name], in-place op) | 
 |         fns = { | 
 |             operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), | 
 |             operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), | 
 |             operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), | 
 |             operator.truediv: ( | 
 |                 ["__truediv__", "__rtruediv__", "__itruediv__"], | 
 |                 operator.itruediv, | 
 |             ), | 
 |             operator.floordiv: ( | 
 |                 ["__floordiv__", "__rfloordiv__", "__ifloordiv__"], | 
 |                 operator.ifloordiv, | 
 |             ), | 
 |             operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), | 
 |             pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), | 
 |             operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), | 
 |             # NB: The follow binary operators are not supported for now, since the | 
 |             # corresponding magic methods aren't defined on SymInt / SymFloat: | 
 |             # operator.matmul | 
 |             # divmod | 
 |             # operator.lshift | 
 |             # operator.rshift | 
 |             # operator.and_ | 
 |             # operator.or_ | 
 |             # operator.xor | 
 |         } | 
 |         return fns | 
 |  | 
    @staticmethod
    @functools.lru_cache(None)
    def _binop_handlers():
        """Build the table of type-dispatched handlers for binary operators.

        Returns:
            A dict mapping an operator function to a list of
            ``((type0, type1), handler)`` entries. ``_find_binop_handler``
            scans each list in order and returns the first entry whose types
            match, so the order of registration below is the precedence order.
            The result is memoized by ``functools.lru_cache``.
        """
        # Multiple dispatch mechanism defining custom binop behavior for certain type
        # combinations. Handlers are attempted in order, and will be used if the type checks
        # match. They are expected to have the signature:
        # fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker

        # Override table contains: op_fn -> [list of handlers]
        op_handlers = {}
        for (
            op,
            (magic_method_names, in_place_op),
        ) in BuiltinVariable._binops().items():
            # NOTE: builtin pow and operator.pow both map to operator.ipow in
            # _binops(), so the operator.ipow list is re-initialised (and then
            # re-populated) on the later of those two iterations.
            op_handlers[op] = []
            op_handlers[in_place_op] = []

            forward_name, reverse_name, inplace_name = magic_method_names

            # User-defined args (highest precedence)
            # The names are bound as default arguments so each closure keeps
            # the values from its own loop iteration.
            def user_defined_handler(
                tx,
                a,
                b,
                options,
                forward_name=forward_name,
                reverse_name=reverse_name,
            ):
                # Manually handle reversing logic if needed (e.g. call __radd__)

                # TODO: If we expand this to handle tensor args, we need to manually
                # handle cases like this:
                #
                # class A(int):
                #     def __radd__(self, other):
                #         print("woof")
                # torch.randn(3) + A(3)
                #
                # In this example, A.__radd__() is not called -> nothing is printed, because
                # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
                # To be fully correct, we should not call A.__radd__() here, and there may be
                # other cases to reason about and add exceptions for.
                if isinstance(a, UserDefinedVariable):
                    return a.call_method(tx, forward_name, [b], {})
                else:
                    return b.call_method(tx, reverse_name, [a], {})

            op_handlers[op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_handler)
            )

            # In-place variant: always calls the in-place magic method on the
            # left operand ``a``, whichever operand is the UserDefinedVariable.
            def user_defined_inplace_handler(
                tx, a, b, options, forward_name=inplace_name
            ):
                return a.call_method(tx, forward_name, [b], {})

            op_handlers[in_place_op].append(
                ((UserDefinedVariable, VariableTracker), user_defined_inplace_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, UserDefinedVariable), user_defined_inplace_handler)
            )

            # Dynamic shape args
            # Emits a "call_function" proxy node for the out-of-place op ``fn``.
            def dynamic_handler(tx, a, b, options, fn=op):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function", fn, *proxy_args_kwargs([a, b], {})
                    ),
                    **options,
                )

            op_handlers[op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

            # NB: Prefer out-of-place op when calling in-place op to generate valid graph
            op_handlers[in_place_op].append(
                ((SymNodeVariable, VariableTracker), dynamic_handler)
            )
            op_handlers[in_place_op].append(
                ((VariableTracker, SymNodeVariable), dynamic_handler)
            )

        # Special cases - lower precedence but still prefer these over constant folding

        # List-like addition (e.g. [1, 2] + [3, 4])
        def tuple_add_handler(tx, a, b, options):
            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        def size_add_handler(tx, a, b, options):
            return SizeVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        list_like_addition_handlers = [
            # NB: Prefer the tuple-specific logic over base logic because of
            # some SizeVariable weirdness. Specifically, the tuple-specific logic
            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
            (
                (SizeVariable, SizeVariable),
                size_add_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: TupleVariable(
                    list(a.unpack_var_sequence(tx)) + b.items, **options
                ),
            ),
            (
                # Generic fallback: the result has the same class as the left operand.
                (BaseListVariable, BaseListVariable),
                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
            ),
        ]
        op_handlers[operator.add].extend(list_like_addition_handlers)

        def list_iadd_handler(tx, a, b, options):
            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                # Handler doesn't apply
                # (the caller treats a None result as "no handler matched")
                return None

            # Replace ``a`` with a new ListVariable holding the concatenated items.
            return tx.replace_all(
                a,
                ListVariable(
                    list(a.items) + list(b.unpack_var_sequence(tx)),
                    regen_guards=False,
                    **options,
                ),
            )

        list_like_iadd_handlers = [
            (
                (ListVariable, VariableTracker),
                list_iadd_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
        ]
        op_handlers[operator.iadd].extend(list_like_iadd_handlers)

        # List-like expansion (e.g. [1, 2, 3] * 3)
        def expand_list_like(tx, lst, const, options):
            return lst.__class__(
                items=lst.items * const.as_python_constant(),
                mutable_local=MutableLocal(),
                **options,
            )

        list_like_expansion_handlers = [
            ((ListVariable, ConstantVariable), expand_list_like),
            ((TupleVariable, ConstantVariable), expand_list_like),
            (
                # Constant on the left: swap the operands before expanding.
                (ConstantVariable, ListVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
        ]
        op_handlers[operator.mul].extend(list_like_expansion_handlers)

        return op_handlers
 |  | 
 |     @staticmethod | 
 |     def _find_binop_handler(op, a, b): | 
 |         handlers = BuiltinVariable._binop_handlers() | 
 |         if op not in handlers: | 
 |             return None | 
 |  | 
 |         # Return first handler that matches the type checks | 
 |         for (type1, type2), handler in handlers[op]: | 
 |             if isinstance(a, type1) and isinstance(b, type2): | 
 |                 return handler | 
 |  | 
 |         return None | 
 |  | 
    def can_insert_in_graph(self):
        """Return True if ``self.fn`` is a member of ``_fx_graph_functions()``."""
        return self.fn in self._fx_graph_functions()
 |  | 
    def __init__(self, fn, **kwargs):
        """Store the wrapped callable ``fn``; ``kwargs`` are forwarded to the base class."""
        super().__init__(**kwargs)
        self.fn = fn
 |  | 
 |     def __str__(self): | 
 |         if self.fn is None: | 
 |             name = "None" | 
 |         else: | 
 |             name = self.fn.__name__ | 
 |  | 
 |         return f"{self.__class__.__name__}({name})" | 
 |  | 
    def python_type(self):
        """Return the Python type of the wrapped callable ``self.fn``."""
        return type(self.fn)
 |  | 
    def as_python_constant(self):
        """Return the wrapped callable ``self.fn`` itself."""
        return self.fn
 |  | 
 |     def reconstruct(self, codegen): | 
 |         name = self.fn.__name__ | 
 |         assert self.fn.__module__ == "builtins" | 
 |         assert name not in codegen.tx.f_globals, "shadowed global" | 
 |         return [codegen.create_load_global(name, False, add=True)] | 
 |  | 
    def constant_args(self, *args, **kwargs):
        """Return ``check_constant_args(args, kwargs)`` for the given arguments."""
        return check_constant_args(args, kwargs)
 |  | 
 |     def tensor_args(self, *args, **kwargs): | 
 |         return any( | 
 |             isinstance(i, variables.TensorVariable) | 
 |             for i in itertools.chain(args, kwargs.values()) | 
 |         ) and not any( | 
 |             isinstance(i, variables.GetAttrVariable) | 
 |             for i in itertools.chain(args, kwargs.values()) | 
 |         ) | 
 |  | 
    def unspec_python_args(self, *args, **kwargs):
        """Return ``check_unspec_python_args(args, kwargs)`` for the given arguments."""
        return check_unspec_python_args(args, kwargs)
 |  | 
 |     @staticmethod | 
 |     def unwrap_unspec_args_kwargs(args, kwargs): | 
 |         unwrapped_args = [] | 
 |         unwrapped_kwargs = {} | 
 |         for x in args: | 
 |             if isinstance( | 
 |                 x, | 
 |                 (variables.UnspecializedPythonVariable,), | 
 |             ): | 
 |                 unwrapped_args.append(x.raw_value) | 
 |             else: | 
 |                 unwrapped_args.append(x.as_python_constant()) | 
 |         for k, v in kwargs: | 
 |             if isinstance( | 
 |                 x, | 
 |                 (variables.UnspecializedPythonVariable,), | 
 |             ): | 
 |                 unwrapped_kwargs.update({k: v.raw_value}) | 
 |             else: | 
 |                 unwrapped_kwargs.update({k: v.as_python_constant()}) | 
 |         return unwrapped_args, unwrapped_kwargs | 
 |  | 
 |     def call_function( | 
 |         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" | 
 |     ) -> "VariableTracker": | 
 |         from .builder import wrap_fx_proxy, wrap_fx_proxy_cls | 
 |  | 
 |         constant_args = check_constant_args(args, kwargs) | 
 |         tensor_args = self.tensor_args(*args, **kwargs) | 
 |         unspec_python_args = self.unspec_python_args(*args, **kwargs) | 
 |         options = VariableTracker.propagate(self, args, kwargs.values()) | 
 |         has_constant_handler = self.can_constant_fold_through() and ( | 
 |             constant_args or unspec_python_args | 
 |         ) | 
 |         assert isinstance(args, (list, tuple)) | 
 |         assert isinstance(kwargs, dict) | 
 |  | 
 |         if ( | 
 |             self.fn is operator.getitem | 
 |             and len(args) == 2 | 
 |             and isinstance(args[1], variables.TensorVariable) | 
 |             and args[1].dtype == torch.bool | 
 |             and not config.dynamic_shapes | 
 |         ): | 
 |             unimplemented("dynamic Tensor.__getitem__(bool[])") | 
 |  | 
 |         # args[0] is list and args[1] is unspec | 
 |         if self.fn is operator.getitem and not isinstance( | 
 |             args[0], variables.TensorVariable | 
 |         ): | 
 |             tensor_args = False | 
 |             args, kwargs = specialize_args_kwargs(tx, args, kwargs) | 
 |  | 
 |         if ( | 
 |             self.can_insert_in_graph() | 
 |             and tensor_args | 
 |             and not ( | 
 |                 self.fn is operator.getitem | 
 |                 and isinstance(args[0], ConstDictVariable) | 
 |                 and isinstance(args[1], variables.TensorVariable) | 
 |             ) | 
 |         ): | 
 |             try: | 
 |                 fn = self.fn | 
 |                 if self.fn is operator.iadd and isinstance( | 
 |                     args[0], variables.ConstantVariable | 
 |                 ): | 
 |                     # Work around weird bug in hf_T5 | 
 |                     fn, args = operator.add, [args[1], args[0]] | 
 |  | 
 |                 if self.fn is operator.getitem and isinstance(args[1], SymNodeVariable): | 
 |                     # Standard indexing will force specialization due to | 
 |                     # __index__.  Rewrite as a regular torch op which will | 
 |                     # trace fine | 
 |                     fn, args = torch.select, [ | 
 |                         args[0], | 
 |                         variables.ConstantVariable(0), | 
 |                         args[1], | 
 |                     ] | 
 |  | 
 |                 proxy = tx.output.create_proxy( | 
 |                     "call_function", | 
 |                     fn, | 
 |                     *proxy_args_kwargs(args, kwargs), | 
 |                 ) | 
 |                 if any(isinstance(arg, FakeItemVariable) for arg in args): | 
 |                     return wrap_fx_proxy_cls( | 
 |                         FakeItemVariable, | 
 |                         tx, | 
 |                         proxy, | 
 |                         **options, | 
 |                     ) | 
 |                 elif self.unspec_python_args(*args, **kwargs): | 
 |                     _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) | 
 |                     raw_value = self.fn(*_args, **_kwargs) | 
 |  | 
 |                     need_unwrap = any( | 
 |                         x.need_unwrap | 
 |                         for x in itertools.chain(args, kwargs.values()) | 
 |                         if isinstance(x, variables.UnspecializedPythonVariable) | 
 |                     ) | 
 |  | 
 |                     return wrap_fx_proxy_cls( | 
 |                         UnspecializedPythonVariable, | 
 |                         tx, | 
 |                         proxy, | 
 |                         raw_value=raw_value, | 
 |                         need_unwrap=need_unwrap, | 
 |                         **options, | 
 |                     ) | 
 |                 elif all(isinstance(x, SymNodeVariable) for x in args): | 
 |                     return SymNodeVariable.create(tx, proxy, None, **options) | 
 |                 else: | 
 |                     # Work around for vision_maskrcnn due to precision difference | 
 |                     # specialize the dividend when float divide by tensor | 
 |                     if self.fn is operator.truediv and isinstance( | 
 |                         args[0], variables.UnspecializedPythonVariable | 
 |                     ): | 
 |                         args[0] = args[0].convert_to_constant(tx) | 
 |                     return wrap_fx_proxy(tx, proxy, **options) | 
 |  | 
 |             except NotImplementedError: | 
 |                 unimplemented(f"partial tensor op: {self} {args} {kwargs}") | 
 |  | 
 |         # Handle cases like int(torch.seed()) | 
 |         # Also handle sym_float to sym_int cases | 
 |         if self.fn in (int, float) and isinstance(args[0], SymNodeVariable): | 
 |             fn_ = sym_int if self.fn is int else sym_float | 
 |             out = wrap_fx_proxy( | 
 |                 tx=tx, | 
 |                 proxy=tx.output.create_proxy( | 
 |                     "call_function", | 
 |                     fn_, | 
 |                     (args[0].as_proxy(),), | 
 |                     {}, | 
 |                 ), | 
 |                 **options, | 
 |             ) | 
 |             return out | 
 |  | 
 |         # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.) | 
 |         # NB: Tensor args are handled above and not here | 
 |         if len(kwargs) == 0 and len(args) == 2: | 
 |             # Try to find a handler for the arg types; otherwise, fall through to constant handler | 
 |             binop_handler = BuiltinVariable._find_binop_handler( | 
 |                 self.fn, args[0], args[1] | 
 |             ) | 
 |             if binop_handler: | 
 |                 res = binop_handler(tx, args[0], args[1], options) | 
 |                 if res is not None: | 
 |                     return res | 
 |  | 
 |         handler = getattr(self, f"call_{self.fn.__name__}", None) | 
 |         if handler: | 
 |             try: | 
 |                 inspect.signature(handler).bind(tx, *args, **kwargs) | 
 |             except TypeError as exc: | 
 |                 if not has_constant_handler: | 
 |                     log.warning( | 
 |                         "incorrect arg count %s %s and no constant handler", | 
 |                         handler, | 
 |                         exc, | 
 |                     ) | 
 |                 handler = None | 
 |  | 
 |         if handler: | 
 |             try: | 
 |                 result = handler(tx, *args, **kwargs) | 
 |                 if result is not None: | 
 |                     return result.add_options(options) | 
 |             except Unsupported as exc: | 
 |                 if not has_constant_handler: | 
 |                     raise | 
 |                 # Actually, we will handle this just fine | 
 |                 exc.remove_from_stats() | 
 |  | 
 |         if has_constant_handler: | 
 |             args, kwargs = specialize_args_kwargs(tx, args, kwargs) | 
 |             # constant fold | 
 |             return variables.ConstantVariable( | 
 |                 self.as_python_constant()( | 
 |                     *[x.as_python_constant() for x in args], | 
 |                     **{k: v.as_python_constant() for k, v in kwargs.items()}, | 
 |                 ), | 
 |                 **options, | 
 |             ) | 
 |  | 
 |         if self.fn is round: | 
 |             if len(args) > 0 and isinstance(args[0], SymNodeVariable): | 
 |                 raise UserError( | 
 |                     UserErrorType.STANDARD_LIBRARY, | 
 |                     "Calling round() on symbolic value is not supported. " | 
 |                     "You can use floor() to implement this functionality", | 
 |                 ) | 
 |         return super().call_function(tx, args, kwargs) | 
 |  | 
 |     def _call_min_max(self, tx, *args): | 
 |         if len(args) == 1 and args[0].has_unpack_var_sequence(tx): | 
 |             # expand iterable | 
 |             items = args[0].unpack_var_sequence(tx) | 
 |             return self._call_min_max_seq(tx, items) | 
 |         elif len(args) == 2: | 
 |             return self._call_min_max_binary(tx, args[0], args[1]) | 
 |         elif len(args) > 2: | 
 |             return self._call_min_max_seq(tx, args) | 
 |  | 
 |     def _call_min_max_seq(self, tx, items): | 
 |         assert len(items) > 0 | 
 |         if len(items) == 1: | 
 |             return items[0] | 
 |  | 
 |         return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) | 
 |  | 
    def _call_min_max_binary(self, tx, a, b):
        """Handle ``min(a, b)`` / ``max(a, b)`` for two argument trackers.

        Branches, in order:
        - ``self.tensor_args(a, b)`` is true: lower to ``torch.clamp`` (when
          ``b`` is a Python constant) or ``torch.maximum`` / ``torch.minimum``,
          then re-wrap the result as an unspecialized value when both operands
          are unspecialized or constant.
        - both operands are ``ConstantVariable``: compute the result directly.
        - either operand is a ``SymNodeVariable``: emit a ``call_function``
          proxy node for ``self.fn``.
        - otherwise: call ``unimplemented``.
        """
        if self.tensor_args(a, b):
            # Normalise so that ``a`` is the tensor operand.
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # result of an item call is a scalar convert to a tensor
            if isinstance(a, FakeItemVariable):
                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})

            # Dynamic input does not get resolved, rather, gets stored as call_function
            if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.fn,
                        *proxy_args_kwargs([a, b], {}),
                    ),
                    **VariableTracker.propagate(self, [a, b]),
                )

            # convert min/max to torch ops
            if b.is_python_constant():
                # max(a, c) is clamp(a, min=c); min(a, c) is clamp(a, max=c)
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = variables.TorchVariable(torch.clamp).call_function(
                    tx, [a], kwargs
                )
            else:
                fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                result = variables.TorchVariable(fn).call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any(isinstance(val, FakeItemVariable) for val in [a, b]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                # Compute the concrete Python result to carry as the raw value.
                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                need_unwrap = any(
                    x.need_unwrap
                    for x in [a, b]
                    if isinstance(x, variables.UnspecializedPythonVariable)
                )
                return variables.UnspecializedPythonVariable.from_tensor_variable(
                    result, raw_res, need_unwrap
                )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, variables.ConstantVariable
        ):
            # Both operands are constants: evaluate directly.
            if self.fn is max:
                return variables.ConstantVariable(max(a.value, b.value))
            else:
                return variables.ConstantVariable(min(a.value, b.value))
        elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
            # Symbolic operand without tensors: record the call as a proxy node.
            proxy = tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs([a, b], {})
            )
            return SymNodeVariable.create(tx, proxy, None)
        else:
            unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")
 |  | 
    # min and max share one implementation; it branches on ``self.fn`` where needed.
    call_min = _call_min_max
    call_max = _call_min_max
 |  | 
 |     def call_range(self, tx, *args): | 
 |         if self.unspec_python_args(*args) or self.constant_args(*args): | 
 |             args, _ = specialize_args_kwargs(tx, args, {}) | 
 |             return variables.RangeVariable(args) | 
 |         elif self._dynamic_args(*args): | 
 |  | 
 |             def guard_if_dyn(arg): | 
 |                 if isinstance(arg, SymNodeVariable): | 
 |                     return arg.evaluate_expr(tx.output) | 
 |                 elif isinstance(arg, ConstantVariable): | 
 |                     return arg.as_python_constant() | 
 |                 return arg | 
 |  | 
 |             args = [variables.ConstantVariable(guard_if_dyn(arg)) for arg in args] | 
 |             return variables.RangeVariable(args) | 
 |         # None no-ops this handler and lets the driving function proceed | 
 |         return None | 
 |  | 
 |     def _dynamic_args(self, *args, **kwargs): | 
 |         return any(isinstance(x, SymNodeVariable) for x in args) or any( | 
 |             isinstance(x, SymNodeVariable) for x in kwargs.values() | 
 |         ) | 
 |  | 
    def call_slice(self, tx, *args):
        """Wrap the positional arguments in a ``variables.SliceVariable``."""
        return variables.SliceVariable(args)
 |  | 
 |     def _dyn_proxy(self, tx, *args, **kwargs): | 
 |         from .builder import wrap_fx_proxy | 
 |  | 
 |         options = VariableTracker.propagate(self, args, kwargs.values()) | 
 |         return wrap_fx_proxy( | 
 |             tx, | 
 |             tx.output.create_proxy( | 
 |                 "call_function", self.fn, *proxy_args_kwargs(args, kwargs) | 
 |             ), | 
 |             **options, | 
 |         ) | 
 |  | 
 |     def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): | 
 |         if self._dynamic_args(*args, **kwargs): | 
 |             return self._dyn_proxy(tx, *args, **kwargs) | 
 |         cls = variables.BaseListVariable.cls_for(self.fn) | 
 |         if obj is None: | 
 |             return cls( | 
 |                 [], | 
 |                 mutable_local=MutableLocal(), | 
 |             ) | 
 |         elif obj.has_unpack_var_sequence(tx): | 
 |             guards = set() | 
 |             if obj.source and not is_constant_source(obj.source): | 
 |                 if isinstance(obj, TupleIteratorVariable): | 
 |                     guards.add(obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN)) | 
 |                 else: | 
 |                     guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH)) | 
 |             return cls( | 
 |                 list(obj.unpack_var_sequence(tx)), | 
 |                 mutable_local=MutableLocal(), | 
 |                 guards=guards, | 
 |             ).add_options(self, obj) | 
 |  | 
    # iter(), tuple() and list() share one handler; the handler picks the
    # concrete container class from self.fn.
    call_iter = _call_iter_tuple_list
    call_tuple = _call_iter_tuple_list
    call_list = _call_iter_tuple_list
 |  | 
 |     @staticmethod | 
 |     def is_supported_call_dict_arg(tx, arg): | 
 |         return ( | 
 |             arg is None | 
 |             or isinstance(arg, ConstDictVariable) | 
 |             or ( | 
 |                 isinstance( | 
 |                     arg, | 
 |                     ( | 
 |                         ListVariable, | 
 |                         TupleVariable, | 
 |                         ListIteratorVariable, | 
 |                     ), | 
 |                 ) | 
 |                 and all( | 
 |                     isinstance(x, (ListVariable, TupleVariable)) | 
 |                     and isinstance( | 
 |                         x.unpack_var_sequence(tx)[0], (ConstantVariable, EnumVariable) | 
 |                     ) | 
 |                     for x in arg.unpack_var_sequence(tx) | 
 |                 ) | 
 |             ) | 
 |         ) | 
 |  | 
 |     def call_callable(self, tx, arg): | 
 |         from .functions import BaseUserFunctionVariable | 
 |  | 
 |         if isinstance( | 
 |             arg, (variables.UserDefinedClassVariable, BaseUserFunctionVariable) | 
 |         ): | 
 |             return variables.ConstantVariable(True).add_options(arg) | 
 |  | 
 |     @staticmethod | 
 |     def call_dict_helper(tx, user_cls, arg, **options): | 
 |         if arg is None: | 
 |             return ConstDictVariable( | 
 |                 {}, user_cls, mutable_local=MutableLocal() | 
 |             ).add_options(options) | 
 |         elif isinstance(arg, variables.ConstDictVariable): | 
 |             return arg.clone( | 
 |                 user_cls=user_cls, mutable_local=MutableLocal() | 
 |             ).add_options(options) | 
 |         elif isinstance( | 
 |             arg, | 
 |             ( | 
 |                 ListVariable, | 
 |                 TupleVariable, | 
 |                 ListIteratorVariable, | 
 |             ), | 
 |         ): | 
 |             items = user_cls() | 
 |             for x in arg.unpack_var_sequence(tx): | 
 |                 k = x.unpack_var_sequence(tx)[0].as_python_constant() | 
 |                 v = x.unpack_var_sequence(tx)[1] | 
 |                 items.update({k: v}) | 
 |             return ConstDictVariable( | 
 |                 items, user_cls, mutable_local=MutableLocal() | 
 |             ).add_options(options) | 
 |         else: | 
 |             raise AssertionError("call_dict_helper with illegal arg") | 
 |  | 
 |     def call_dict(self, tx, *args, **kwargs): | 
 |         if not (args or kwargs): | 
 |             return self.call_dict_helper(tx, dict, None) | 
 |         elif ( | 
 |             not kwargs | 
 |             and len(args) == 1 | 
 |             and self.is_supported_call_dict_arg(tx, args[0]) | 
 |         ): | 
 |             return self.call_dict_helper(tx, dict, args[0]) | 
 |         elif not args and kwargs: | 
 |             return variables.ConstDictVariable( | 
 |                 dict(kwargs), user_cls=dict, mutable_local=MutableLocal() | 
 |             ) | 
 |         else: | 
 |             unimplemented(f"dict(): {args} {kwargs}") | 
 |  | 
 |     def call_zip(self, tx, *args): | 
 |         options = VariableTracker.propagate(self, args) | 
 |         if all(x.has_unpack_var_sequence(tx) for x in args): | 
 |             items = [ | 
 |                 variables.TupleVariable(list(item), **options) | 
 |                 for item in zip(*[arg.unpack_var_sequence(tx) for arg in args]) | 
 |             ] | 
 |             return variables.TupleVariable(items, **options) | 
 |  | 
 |     def call_enumerate(self, tx, *args): | 
 |         options = VariableTracker.propagate(self, args) | 
 |         if len(args) == 1: | 
 |             start = 0 | 
 |         else: | 
 |             assert len(args) == 2 | 
 |             assert isinstance(args[1], variables.ConstantVariable) | 
 |             start = args[1].as_python_constant() | 
 |         if args[0].has_unpack_var_sequence(tx): | 
 |             items = [ | 
 |                 variables.TupleVariable( | 
 |                     [variables.ConstantVariable(idx, **options), var], | 
 |                     **options, | 
 |                 ) | 
 |                 for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) | 
 |             ] | 
 |             return variables.TupleVariable(items, **options) | 
 |  | 
 |     def call_len(self, tx, *args, **kwargs): | 
 |         return args[0].call_method(tx, "__len__", args[1:], kwargs) | 
 |  | 
 |     def call_getitem(self, tx, *args, **kwargs): | 
 |         if self.unspec_python_args(*args, **kwargs): | 
 |             args, kwargs = specialize_args_kwargs(tx, args, kwargs) | 
 |         return args[0].call_method(tx, "__getitem__", args[1:], kwargs) | 
 |  | 
 |     def call_isinstance(self, tx, arg, isinstance_type): | 
 |         arg_type = arg.python_type() | 
 |  | 
 |         isinstance_type = isinstance_type.as_python_constant() | 
 |  | 
 |         if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: | 
 |             return variables.ConstantVariable(arg.call_isinstance(isinstance_type)) | 
 |         # UserDefinedObject with C extensions can have torch.Tensor attributes, | 
 |         # so break graph. | 
 |         if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( | 
 |             arg.value, types.MemberDescriptorType | 
 |         ): | 
 |             unimplemented( | 
 |                 f"isinstance called on UserDefinedClass {arg} {isinstance_type}" | 
 |             ) | 
 |         # handle __instancecheck__ defined in user class | 
 |         if ( | 
 |             isinstance(arg, variables.UserDefinedObjectVariable) | 
 |             and "__instancecheck__" in isinstance_type.__class__.__dict__ | 
 |         ): | 
 |             return variables.ConstantVariable( | 
 |                 isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) | 
 |             ) | 
 |  | 
 |         try: | 
 |             val = issubclass(arg_type, isinstance_type) | 
 |         except TypeError: | 
 |             val = arg_type is isinstance_type | 
 |         return variables.ConstantVariable(val) | 
 |  | 
 |     def call_super(self, tx, a, b): | 
 |         source = ( | 
 |             None | 
 |             if a.source is None or b.source is None | 
 |             else SuperSource(a.source, b.source) | 
 |         ) | 
 |         return variables.SuperVariable(a, b, source=source) | 
 |  | 
 |     def call_next(self, tx, arg): | 
 |         if isinstance(arg, variables.ListIteratorVariable): | 
 |             val, next_iter = arg.next_variables() | 
 |             tx.replace_all(arg, next_iter) | 
 |             return val | 
 |         elif isinstance(arg, variables.BaseListVariable): | 
 |             return arg.items[0].add_options(self, arg) | 
 |  | 
 |     def call_hasattr(self, tx, obj, attr): | 
 |         if attr.is_python_constant(): | 
 |             name = attr.as_python_constant() | 
 |             return obj.call_hasattr(tx, name).add_options(self, obj, attr) | 
 |  | 
 |     def call_map(self, tx, fn, seq): | 
 |         if seq.has_unpack_var_sequence(tx): | 
 |             items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)] | 
 |             return variables.TupleVariable(items).add_options(self, fn, seq) | 
 |  | 
 |     def call_sum(self, tx, seq, **kwargs): | 
 |         # Special case for sum on tuple of floats and ints | 
 |         if ( | 
 |             isinstance(seq, (variables.ListVariable, variables.TupleVariable)) | 
 |             and all( | 
 |                 isinstance(x, variables.ConstantVariable) | 
 |                 and isinstance(x.value, (int, float)) | 
 |                 for x in seq.items | 
 |             ) | 
 |             and not kwargs | 
 |         ): | 
 |             new_list = [x.value for x in seq.items] | 
 |             return variables.ConstantVariable(sum(new_list)) | 
 |         if seq.has_unpack_var_sequence(tx): | 
 |             start = kwargs.pop( | 
 |                 "start", variables.ConstantVariable(0) | 
 |             ).as_python_constant() | 
 |             assert not kwargs | 
 |             items = seq.unpack_var_sequence(tx)[start:] | 
 |             return BuiltinVariable(functools.reduce).call_function( | 
 |                 tx, | 
 |                 [ | 
 |                     BuiltinVariable(operator.add), | 
 |                     variables.TupleVariable(items), | 
 |                     variables.ConstantVariable(0).add_options(self, seq), | 
 |                 ], | 
 |                 {}, | 
 |             ) | 
 |  | 
 |     def call_reduce(self, tx, function, iterable, initializer=None): | 
 |         if iterable.has_unpack_var_sequence(tx): | 
 |             items = iterable.unpack_var_sequence(tx) | 
 |             if initializer is None: | 
 |                 value, items = items[0], items[1:] | 
 |             else: | 
 |                 value = initializer | 
 |             for element in items: | 
 |                 value = function.call_function(tx, [value, element], {}) | 
 |             return value | 
 |  | 
 |     def call_getattr( | 
 |         self, tx, obj: VariableTracker, name_var: VariableTracker, default=None | 
 |     ): | 
 |         from . import ( | 
 |             ConstantVariable, | 
 |             GetAttrVariable, | 
 |             PythonModuleVariable, | 
 |             TorchVariable, | 
 |             UserFunctionVariable, | 
 |         ) | 
 |         from .builder import VariableBuilder | 
 |  | 
 |         options = VariableTracker.propagate(self, obj, name_var) | 
 |         guards = options["guards"] | 
 |         name = name_var.as_python_constant() | 
 |  | 
 |         if not name_var.is_python_constant(): | 
 |             unimplemented("non-const getattr() name") | 
 |  | 
 |         if tx.output.side_effects.is_attribute_mutation(obj): | 
 |             try: | 
 |                 # re-read a pending side effect? | 
 |                 return tx.output.side_effects.load_attr(obj, name).add_options(options) | 
 |             except KeyError: | 
 |                 pass | 
 |  | 
 |         if default is not None: | 
 |             hasattr_var = self.call_hasattr(tx, obj, name_var) | 
 |             guards.update(hasattr_var.guards) | 
 |             assert hasattr_var.as_python_constant() in (True, False) | 
 |             if not hasattr_var.as_python_constant(): | 
 |                 return default.add_guards(guards) | 
 |  | 
 |         if obj.source: | 
 |             source = AttrSource(obj.source, name) | 
 |             options["source"] = source | 
 |         else: | 
 |             source = None | 
 |  | 
 |         if isinstance(obj, variables.NNModuleVariable): | 
 |             return obj.var_getattr(tx, name).add_options(options) | 
 |         elif isinstance(obj, variables.TensorVariable) and name == "grad": | 
 |             if source: | 
 |                 # We are going to be raising this tensor as grapharg. So, ensure | 
 |                 # that we have real grad value instead of fake tensor value. | 
 |                 # Walk through the inputs of the subgraph and find if we already | 
 |                 # have the original tensor stored in the graphargs. | 
 |                 for grapharg in tx.output.graphargs: | 
 |                     if grapharg.source == source.base: | 
 |                         example_value = grapharg.example.grad | 
 |                         return VariableBuilder(tx, source)(example_value).add_options( | 
 |                             options | 
 |                         ) | 
 |                 unimplemented("tensor grad") | 
 |             else: | 
 |                 unimplemented("tensor grad") | 
 |         elif isinstance( | 
 |             obj, | 
 |             ( | 
 |                 variables.TensorVariable, | 
 |                 variables.NamedTupleVariable, | 
 |                 variables.ConstantVariable, | 
 |                 variables.UserDefinedClassVariable, | 
 |                 variables.UserDefinedObjectVariable, | 
 |             ), | 
 |         ): | 
 |             try: | 
 |                 return ( | 
 |                     obj.var_getattr(tx, name).clone(source=source).add_options(options) | 
 |                 ) | 
 |             except NotImplementedError: | 
 |                 return GetAttrVariable(obj, name, **options) | 
 |         elif isinstance(obj, TorchVariable): | 
 |             member = getattr(obj.value, name) | 
 |             if is_allowed(member): | 
 |                 return TorchVariable(member, **options) | 
 |             elif ConstantVariable.is_literal(member): | 
 |                 return ConstantVariable(member, **options) | 
 |             else: | 
 |                 return VariableBuilder(tx, source)(member).add_guards(guards) | 
 |         elif isinstance(obj, (PythonModuleVariable, DummyModule)): | 
 |             member = obj.value.__dict__[name] | 
 |  | 
 |             if config.replay_record_enabled: | 
 |                 tx.exec_recorder.record_module_access(obj.value, name, member) | 
 |  | 
 |             return VariableBuilder(tx, source)(member).add_guards(guards) | 
 |         elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): | 
 |             return ConstantVariable( | 
 |                 getattr(obj.fn, name), **VariableTracker.propagate(obj) | 
 |             ) | 
 |         else: | 
 |             try: | 
 |                 return ( | 
 |                     obj.var_getattr(tx, name).clone(source=source).add_options(options) | 
 |                 ) | 
 |             except NotImplementedError: | 
 |                 return GetAttrVariable(obj, name, **options) | 
 |  | 
 |     def call_setattr( | 
 |         self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker | 
 |     ): | 
 |         if isinstance(obj, variables.DataClassVariable): | 
 |             return obj.call_method(tx, "__setattr__", [name_var, val], {}) | 
 |         elif ( | 
 |             tx.output.side_effects.is_attribute_mutation(obj) | 
 |             and name_var.is_python_constant() | 
 |         ): | 
 |             tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val) | 
 |             return val.add_options(self, obj, name_var) | 
 |         elif isinstance(obj, variables.UserDefinedObjectVariable): | 
 |             unimplemented( | 
 |                 f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}" | 
 |             ) | 
 |         elif isinstance(obj, variables.NNModuleVariable): | 
 |             obj.convert_to_unspecialized(tx) | 
 |  | 
 |     def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker): | 
 |         return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) | 
 |  | 
 |     def call_type(self, tx, obj: VariableTracker): | 
 |         from .builder import VariableBuilder | 
 |  | 
 |         try: | 
 |             py_type = obj.python_type() | 
 |         except NotImplementedError: | 
 |             py_type = None | 
 |  | 
 |         if istype(obj, variables.TupleVariable): | 
 |             return BuiltinVariable(py_type).add_options(self, obj) | 
 |  | 
 |         if py_type is not None and obj.source: | 
 |             return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options( | 
 |                 self, obj | 
 |             ) | 
 |  | 
 |         raise UserError( | 
 |             UserErrorType.ANTI_PATTERN, | 
 |             "Can't call type() on generated custom object. " | 
 |             "Please use __class__ instead", | 
 |         ) | 
 |  | 
 |     def call_reversed(self, tx, obj: VariableTracker): | 
 |         if obj.has_unpack_var_sequence(tx): | 
 |             items = list(reversed(obj.unpack_var_sequence(tx))) | 
 |             return variables.TupleVariable( | 
 |                 items, **VariableTracker.propagate(self, obj) | 
 |             ) | 
 |  | 
 |     def call_sorted(self, tx, obj: VariableTracker, **kwargs): | 
 |         if ( | 
 |             obj.has_unpack_var_sequence(tx) | 
 |             and not isinstance(obj, variables.TensorVariable) | 
 |             and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) | 
 |         ): | 
 |             function = kwargs.pop("key", None) | 
 |             reverse = kwargs.pop( | 
 |                 "reverse", ConstantVariable(False) | 
 |             ).as_python_constant() | 
 |             assert len(kwargs) == 0 | 
 |             if function: | 
 |                 items = sorted( | 
 |                     obj.unpack_var_sequence(tx), | 
 |                     key=lambda x: function.call_function( | 
 |                         tx, [x], {} | 
 |                     ).as_python_constant(), | 
 |                     reverse=reverse, | 
 |                 ) | 
 |             else: | 
 |                 items = sorted( | 
 |                     obj.unpack_var_sequence(tx), | 
 |                     key=lambda x: x.as_python_constant(), | 
 |                     reverse=reverse, | 
 |                 ) | 
 |             return variables.ListVariable(items, **VariableTracker.propagate(self, obj)) | 
 |  | 
 |     def call_chain(self, tx, *args): | 
 |         if all(obj.has_unpack_var_sequence(tx) for obj in args): | 
 |             items = [] | 
 |             for obj in args: | 
 |                 items.extend(obj.unpack_var_sequence(tx)) | 
 |             return variables.TupleVariable( | 
 |                 items, **VariableTracker.propagate(self, *args) | 
 |             ) | 
 |  | 
 |     def call_islice(self, tx, iterable, *args): | 
 |         if iterable.has_unpack_var_sequence(tx) and all( | 
 |             x.is_python_constant() for x in args | 
 |         ): | 
 |             const_args = [x.as_python_constant() for x in args] | 
 |             items = iterable.unpack_var_sequence(tx) | 
 |             items = list(itertools.islice(items, *const_args)) | 
 |             return variables.TupleVariable( | 
 |                 items, **VariableTracker.propagate(self, iterable, *args) | 
 |             ) | 
 |  | 
 |     # neg is a constant fold function, so we only get here if constant fold is not valid | 
 |     def call_neg(self, tx, a): | 
 |         if isinstance(a, SymNodeVariable): | 
 |             return SymNodeVariable.create( | 
 |                 tx, | 
 |                 (operator.neg)(a.as_proxy()), | 
 |                 sym_num=None, | 
 |             ) | 
 |         # None no-ops this handler and lets the driving function proceed | 
 |         return None | 
 |  | 
 |     def call_id(self, tx, *args): | 
 |         if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): | 
 |             nn_mod_variable = args[0] | 
 |             mod = tx.output.get_submodule(nn_mod_variable.module_key) | 
 |             return variables.ConstantVariable(id(mod)) | 
 |         else: | 
 |             unimplemented(f"call_id with args {args}") | 
 |  | 
 |     def _comparison(self, tx, left, right): | 
 |         """ | 
 |         Used to implement comparison operators for different types. | 
 |         For example, list1 < list2 is implemented differently from tensor1 < tensor2 | 
 |         """ | 
 |         from . import ( | 
 |             BaseListVariable, | 
 |             ConstantVariable, | 
 |             TensorVariable, | 
 |             UserFunctionVariable, | 
 |         ) | 
 |         from .lists import SizeVariable | 
 |         from .tensor import ( | 
 |             supported_const_comparison_ops, | 
 |             supported_tensor_comparison_ops, | 
 |         ) | 
 |  | 
 |         op = self.fn | 
 |  | 
 |         def _unimplemented(): | 
 |             unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}") | 
 |  | 
 |         if isinstance(left, UserFunctionVariable): | 
 |             if op not in supported_const_comparison_ops.values(): | 
 |                 _unimplemented() | 
 |             if not isinstance(right, UserFunctionVariable): | 
 |                 _unimplemented() | 
 |             return ConstantVariable(op(left.fn, right.fn)) | 
 |  | 
 |         # Note, we have a rare BaseListVariable subtype mismatch with valid comparison | 
 |         # x = torch.randn([3, 3]) | 
 |         # x.size() == (3, 3) # True | 
 |         # (3, 3) == x.size() # True | 
 |         if isinstance(left, (SizeVariable, TupleVariable)) and isinstance( | 
 |             right, (TupleVariable, SizeVariable) | 
 |         ): | 
 |             return BaseListVariable.list_compare(tx, op, left, right) | 
 |  | 
 |         if isinstance(left, BaseListVariable): | 
 |             if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses | 
 |                 _unimplemented() | 
 |             return BaseListVariable.list_compare(tx, op, left, right) | 
 |  | 
 |         if isinstance(left, TensorVariable): | 
 |             from .builder import wrap_fx_proxy | 
 |  | 
 |             if op not in supported_tensor_comparison_ops.values(): | 
 |                 _unimplemented() | 
 |             return wrap_fx_proxy( | 
 |                 tx, | 
 |                 op(left.as_proxy(), right.as_proxy()), | 
 |             ) | 
 |  | 
 |         if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable): | 
 |             if op not in supported_tensor_comparison_ops.values(): | 
 |                 _unimplemented() | 
 |  | 
 |             return SymNodeVariable.create( | 
 |                 tx, | 
 |                 op(left.as_proxy(), right.as_proxy()), | 
 |                 sym_num=None, | 
 |             ) | 
 |  | 
 |         if isinstance(left, ConstantVariable) and isinstance(right, ConstantVariable): | 
 |             return ConstantVariable(op(left.value, right.value)) | 
 |  | 
 |         _unimplemented() | 
 |  | 
 |     # and_ is a constant fold function, so we only get here if constant fold is not valid | 
 |     def call_and_(self, tx, a, b): | 
 |         if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( | 
 |             b, (SymNodeVariable, ConstantVariable) | 
 |         ): | 
 |             return SymNodeVariable.create( | 
 |                 tx, | 
 |                 tx.output.create_proxy( | 
 |                     "call_function", operator.and_, *proxy_args_kwargs([a, b], {}) | 
 |                 ), | 
 |                 sym_num=None, | 
 |             ) | 
 |         # None no-ops this handler and lets the driving function proceed | 
 |         return None | 
 |  | 
 |     # or_ is a constant fold function, so we only get here if constant fold is not valid | 
 |     def call_or_(self, tx, a, b): | 
 |         if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( | 
 |             b, (SymNodeVariable, ConstantVariable) | 
 |         ): | 
 |             return SymNodeVariable.create( | 
 |                 tx, | 
 |                 tx.output.create_proxy( | 
 |                     "call_function", operator.or_, *proxy_args_kwargs([a, b], {}) | 
 |                 ), | 
 |                 sym_num=None, | 
 |             ) | 
 |         # None no-ops this handler and lets the driving function proceed | 
 |         return None | 
 |  | 
 |     def call_not_(self, tx, a): | 
 |         if isinstance(a, SymNodeVariable): | 
 |             return SymNodeVariable.create( | 
 |                 tx, | 
 |                 tx.output.create_proxy( | 
 |                     "call_function", operator.not_, *proxy_args_kwargs([a], {}) | 
 |                 ), | 
 |                 sym_num=None, | 
 |             ) | 
 |         return None | 
 |  | 
    # Every comparison and identity operator shares _comparison, which reads
    # the actual operator from self.fn.
    call_eq = _comparison
    call_gt = _comparison
    call_lt = _comparison
    call_ge = _comparison
    call_le = _comparison
    call_ne = _comparison
    call_is_ = _comparison
    call_is_not = _comparison