| import torch | 
 | import torch.nn as nn | 
 | import torch.overrides | 
 | from torch.nn.modules.module import _addindent | 
 | from torch.package import PackageImporter, PackageExporter | 
 | import linecache | 
 | from typing import Type, Dict, List, Any, Union, Optional, Set | 
 | from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode | 
 | from ._compatibility import compatibility | 
 | from torch.package import Importer, sys_importer | 
 | import copy | 
 | import itertools | 
 | import sys | 
 | import traceback | 
 | from pathlib import Path | 
 | import os | 
 | import warnings | 
 |  | 
 | # Normal exec loses the source code, however we can work with | 
 | # the linecache module to recover it. | 
 | # Using _exec_with_source will add it to our local cache | 
 | # and then tools like TorchScript will be able to get source info. | 
 | class _EvalCacheLoader(object): | 
 |     def __init__(self): | 
 |         self.eval_cache = {} | 
 |         self.next_id = 0 | 
 |  | 
 |     def cache(self, src: str, globals: Dict[str, Any]): | 
 |         """Store the source in a private cache, and add a lazy entry in linecache | 
 |         that allows the source to be retrieved by 'filename'. | 
 |  | 
 |         Args: | 
 |             src (str): The module source to cache | 
 |             globals (dict): The module globals | 
 |  | 
 |         Returns: | 
 |             str: The cache key (and dummy filename) generated for src. | 
 |         """ | 
 |  | 
 |         key = self._get_key() | 
 |         self.eval_cache[key] = src | 
 |  | 
 |         # Don't mutate globals so that this loader is only used | 
 |         # to populate linecache, and doesn't interact with other modules | 
 |         # that might check `__loader__` | 
 |         globals_copy = globals.copy() | 
 |         globals_copy['__file__'] = key | 
 |         globals_copy['__name__'] = key | 
 |         globals_copy['__loader__'] = self | 
 |         linecache.lazycache(key, globals_copy) | 
 |  | 
 |         return key | 
 |  | 
 |     # Part of the loader protocol (PEP 302) | 
 |     # linecache will use this method when trying to find source code | 
 |     def get_source(self, module_name) -> Optional[str]: | 
 |         if module_name in self.eval_cache: | 
 |             return self.eval_cache[module_name] | 
 |         return None | 
 |  | 
 |     def _get_key(self): | 
 |         key = f'<eval_with_key>.{self.next_id}' | 
 |         self.next_id += 1 | 
 |         return key | 
 |  | 
 | _loader = _EvalCacheLoader() | 
 |  | 
 |  | 
 | def _exec_with_source(src: str, globals: Dict[str, Any]): | 
 |     key = _loader.cache(src, globals) | 
 |     exec(compile(src, key, 'exec'), globals) | 
 |  | 
 |  | 
 | def _forward_from_src(src: str, globals: Dict[str, Any]): | 
 |     # avoid mutating the passed in dict | 
 |     globals_copy = globals.copy() | 
 |     _exec_with_source(src, globals_copy) | 
 |     forward_fn = globals_copy['forward'] | 
 |     del globals_copy['forward'] | 
 |     return forward_fn | 
 |  | 
 |  | 
 | def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: | 
 |     if name in _custom_builtins: | 
 |         return _custom_builtins[name].import_str | 
 |     if _is_from_torch(name): | 
 |         return 'import torch' | 
 |     module_name, attr_name = importer.get_name(obj) | 
 |     return f'from {module_name} import {attr_name} as {name}' | 
 |  | 
 |  | 
 | def _format_import_block(globals: Dict[str, Any], importer: Importer): | 
 |     import_strs: Set[str] = set() | 
 |     for name, obj in globals.items(): | 
 |         import_strs.add(_format_import_statement(name, obj, importer)) | 
 |     return '\n'.join(import_strs) | 
 |  | 
 |  | 
 | @compatibility(is_backward_compatible=True) | 
 | def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module: | 
 |     # BC: attribute name was changed from `code` to `_code` to facilitate | 
 |     # making `code` into a property and adding a docstring to it | 
 |     fn_src = body.get('_code') or body['code'] | 
 |     forward = _forward_from_src(import_block + fn_src, {}) | 
 |     return _deserialize_graph_module(forward, body) | 
 |  | 
 |  | 
 | @compatibility(is_backward_compatible=True) | 
 | def reduce_package_graph_module( | 
 |     importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str | 
 | ) -> torch.nn.Module: | 
 |     forward = importer.import_module(generated_module_name).forward | 
 |     return _deserialize_graph_module(forward, body) | 
 |  | 
 | @compatibility(is_backward_compatible=True) | 
 | def reduce_deploy_graph_module( | 
 |     importer: PackageImporter, body: Dict[Any, Any], import_block: str | 
 | ) -> torch.nn.Module: | 
 |     ns = dict() | 
 |     ns["__builtins__"] = importer.patched_builtins | 
 |     fn_src = body.get('_code') | 
 |     assert fn_src is not None | 
 |     forward = _forward_from_src(import_block + fn_src, ns) | 
 |     return _deserialize_graph_module(forward, body) | 
 |  | 
 |  | 
 | def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module: | 
 |     """ | 
 |     Deserialize a GraphModule given the dictionary of the original module, | 
 |     using the code to reconstruct the graph. We delete the actual graph before | 
 |     saving the dictionary so that changes to the in-memory graph format do not | 
 |     get serialized. | 
 |     """ | 
 |     # We create a dummy class here because symbolic_trace pulls the forward() | 
 |     # function off of the class, rather than the instance | 
 |     class CodeOnlyModule(torch.nn.Module): | 
 |         def __init__(self, body): | 
 |             super().__init__() | 
 |             self.__dict__ = body | 
 |  | 
 |     # Try to retrieve the forward source in a backward-compatible way | 
 |     CodeOnlyModule.forward = forward | 
 |  | 
 |     tracer_cls = body.get('_tracer_cls') | 
 |     if tracer_cls is None: | 
 |         from ._symbolic_trace import Tracer | 
 |         tracer_cls = Tracer | 
 |  | 
 |     graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule') | 
 |  | 
 |     # This is a workaround for a mypy linter issue related to | 
 |     # passing base class as an argument - https://github.com/python/mypy/issues/5865. | 
 |     cls_tracer : Any = tracer_cls | 
 |  | 
 |     class KeepModules(cls_tracer): | 
 |         # we shouldn't trace into any of the submodules, | 
 |         # because they were not traced in the original GraphModule | 
 |         def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: | 
 |             return True | 
 |  | 
 |     com = CodeOnlyModule(body) | 
 |  | 
 |     tracer_extras = body.get('_tracer_extras', {}) | 
 |     graph = KeepModules().trace(com, **tracer_extras) | 
 |  | 
 |     # Manually set Tracer class on the reconstructed Graph, to avoid | 
 |     # referencing the private local subclass KeepModules. | 
 |     graph._tracer_cls = tracer_cls | 
 |     gm = GraphModule(com, graph, class_name=graphmodule_cls_name) | 
 |  | 
 |     # The GraphModule constructor only retains attributes referenced by the graph. | 
 |     # In this case, our goal is return a GraphModule as close to identical as the one | 
 |     # put into the package. If any additional attributes were present in body, | 
 |     # we should keep them. | 
 |     for k, v in body.items(): | 
 |         if not hasattr(gm, k): | 
 |             setattr(gm, k, v) | 
 |     return gm | 
 |  | 
 | # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' | 
 | # This installs empty Modules where none exist yet if they are subpaths of target | 
 | def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): | 
 |     *prefix, field = target.split('.') | 
 |     for item in prefix: | 
 |         f = getattr(from_module, item) | 
 |         t = getattr(to_module, item, None) | 
 |         if f is t: | 
 |             # we have already installed one of its parents | 
 |             # (e.g. target = root.linear.weight, but we have already installed root.linear) | 
 |             # once we install a parent, we no longer need to copy the children | 
 |             # since all the needed properties will already be present | 
 |             return | 
 |  | 
 |         if t is None: | 
 |             t = torch.nn.Module() | 
 |             setattr(to_module, item, t) | 
 |         from_module, to_module = f, t | 
 |  | 
 |     orig = getattr(from_module, field) | 
 |     # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. | 
 |     # So, we register it as a named buffer in the target module. | 
 |     if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): | 
 |         to_module.register_buffer(field, orig) | 
 |     else: | 
 |         setattr(to_module, field, orig) | 
 |  | 
 | # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module | 
 | # This installs empty Modules where none exist yet if they are subpaths of target | 
 | def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): | 
 |     *prefix, field = target.split('.') | 
 |     for item in prefix: | 
 |         t = getattr(to_module, item, None) | 
 |  | 
 |         if t is None: | 
 |             t = torch.nn.Module() | 
 |             setattr(to_module, item, t) | 
 |         to_module = t | 
 |  | 
 |     # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. | 
 |     # So, we register it as a named buffer in the target module. | 
 |     if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter): | 
 |         to_module.register_buffer(field, from_obj) | 
 |     else: | 
 |         setattr(to_module, field, from_obj) | 
 |  | 
 | class _WrappedCall: | 
 |     def __init__(self, cls, cls_call): | 
 |         self.cls = cls | 
 |         self.cls_call = cls_call | 
 |  | 
 |     # Previously, if an error occurred when valid | 
 |     # symbolically-traced code was run with an invalid input, the | 
 |     # user would see the source of the error as coming from | 
 |     # `File "<eval_with_key_N">`, where N is some number. We use | 
 |     # this function to generate a more informative error message. We | 
 |     # return the traceback itself, a message explaining that the | 
 |     # error occurred in a traced Module's generated forward | 
 |     # function, and five lines of context surrounding the faulty | 
 |     # line | 
 |     @staticmethod | 
 |     def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: | 
 |         # auxiliary variables (for readability) | 
 |         err_lineno = frame_summary.lineno | 
 |         assert err_lineno is not None | 
 |         line = frame_summary.line | 
 |         assert line is not None | 
 |         err_line_len = len(line) | 
 |         all_src_lines = linecache.getlines(frame_summary.filename) | 
 |  | 
 |         # constituent substrings of the error message | 
 |         tb_repr = traceback.format_exc() | 
 |         custom_msg = ("Call using an FX-traced Module, " | 
 |                       f"line {err_lineno} of the traced Module's " | 
 |                       "generated forward function:") | 
 |         before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) | 
 |         marker = "~" * err_line_len + "~~~ <--- HERE" | 
 |         err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) | 
 |  | 
 |         # joined message | 
 |         return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) | 
 |  | 
 |     def __call__(self, obj, *args, **kwargs): | 
 |         try: | 
 |             if self.cls_call is not None: | 
 |                 return self.cls_call(obj, *args, **kwargs) | 
 |             else: | 
 |                 return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc] | 
 |         except Exception as e: | 
 |             assert e.__traceback__ | 
 |             topmost_framesummary: traceback.FrameSummary = \ | 
 |                 traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]  # type: ignore[arg-type] | 
 |             if "eval_with_key" in topmost_framesummary.filename: | 
 |                 print(_WrappedCall._generate_error_message(topmost_framesummary), | 
 |                       file=sys.stderr) | 
 |                 raise e.with_traceback(None) | 
 |             else: | 
 |                 raise e | 
 |  | 
 | @compatibility(is_backward_compatible=True) | 
 | class GraphModule(torch.nn.Module): | 
 |     """ | 
 |     GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a | 
 |     ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated | 
 |     from that ``graph``. | 
 |  | 
 |     .. warning:: | 
 |  | 
 |         When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically | 
 |         regenerated. However, if you edit the contents of the ``graph`` without reassigning | 
 |         the ``graph`` attribute itself, you must call ``recompile()`` to update the generated | 
 |         code. | 
 |     """ | 
 |     def __new__(cls: 'Type[GraphModule]', *args, **kwargs): | 
 |         # each instance of a graph module needs its own forward method | 
 |         # so create a new singleton class for each instance. | 
 |         # it is a subclass of the user-defined class, the only difference | 
 |         # is an extra layer to install the forward method | 
 |  | 
 |         # address issue described at https://github.com/pytorch/pytorch/issues/63883 | 
 |         # in other words, traverse class hierarchy to fix the redundant class definition problem | 
 |         for t in cls.__mro__: | 
 |             c = t.__qualname__.split('.')[-1] | 
 |             if c != 'GraphModuleImpl': | 
 |                 cls = t | 
 |                 break | 
 |  | 
 |         class GraphModuleImpl(cls):  # type: ignore[misc, valid-type] | 
 |             pass | 
 |         return super().__new__(GraphModuleImpl) | 
 |  | 
 |     @compatibility(is_backward_compatible=True) | 
 |     def __init__(self, | 
 |                  root: Union[torch.nn.Module, Dict[str, Any]], | 
 |                  graph: Graph, | 
 |                  class_name: str = 'GraphModule'): | 
 |         """ | 
 |         Construct a GraphModule. | 
 |  | 
 |         Args: | 
 |  | 
 |             root (Union[torch.nn.Module, Dict[str, Any]): | 
 |                 ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. | 
 |                 In the case that ``root`` is a Module, any references to Module-based objects (via qualified | 
 |                 name) in the Graph's Nodes' ``target`` field will be copied over from the respective place | 
 |                 within ``root``'s Module hierarchy into the GraphModule's module hierarchy. | 
 |                 In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be | 
 |                 looked up directly in the dict's keys. The object mapped to by the Dict will be copied | 
 |                 over into the appropriate place within the GraphModule's module hierarchy. | 
 |  | 
 |             graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation | 
 |  | 
 |             class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all | 
 |                 error messages will report as originating from ``GraphModule``. It may be helpful to set this | 
 |                 to ``root``'s original name or a name that makes sense within the context of your transform. | 
 |         """ | 
 |         super().__init__() | 
 |         self.__class__.__name__ = class_name | 
 |         if isinstance(root, torch.nn.Module): | 
 |             if hasattr(root, 'training'): | 
 |                 self.training = root.training | 
 |             for node in graph.nodes: | 
 |                 if node.op in ['get_attr', 'call_module']: | 
 |                     assert isinstance(node.target, str) | 
 |                     _copy_attr(root, self, node.target) | 
 |         elif isinstance(root, dict): | 
 |             targets_to_copy = [] | 
 |             for node in graph.nodes: | 
 |                 if node.op in ['get_attr', 'call_module']: | 
 |                     assert isinstance(node.target, str) | 
 |                     if node.target not in root: | 
 |                         raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target + | 
 |                                            ' but that target was not provided in ``root``!') | 
 |                     targets_to_copy.append(node.target) | 
 |             # Sort targets in ascending order of the # of atoms. | 
 |             # This will ensure that less deeply nested attributes are assigned | 
 |             # before more deeply nested attributes. For example, foo.bar | 
 |             # will be assigned before foo.bar.baz. Otherwise, we might assign | 
 |             # the user-provided ``foo.bar`` and wipe out the previously-assigned | 
 |             # ``foo.bar.baz`` | 
 |             targets_to_copy.sort(key=lambda t: t.count('.')) | 
 |             for target_to_copy in targets_to_copy: | 
 |                 _assign_attr(root[target_to_copy], self, target_to_copy) | 
 |         else: | 
 |             raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!') | 
 |  | 
 |         self.graph = graph | 
 |  | 
 |         # Store the Tracer class responsible for creating a Graph separately as part of the | 
 |         # GraphModule state, except when the Tracer is defined in a local namespace. | 
 |         # Locally defined Tracers are not pickleable. This is needed because torch.package will | 
 |         # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer | 
 |         # to re-create the Graph during deserialization. | 
 |         self._tracer_cls = None | 
 |         if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__: | 
 |             self._tracer_cls = self.graph._tracer_cls | 
 |  | 
 |         self._tracer_extras = {} | 
 |         if self.graph._tracer_extras: | 
 |             self._tracer_extras = self.graph._tracer_extras | 
 |  | 
 |     # TorchScript breaks trying to compile the graph setter because of the | 
 |     # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 | 
 |     # | 
 |     # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway | 
 |     __jit_unused_properties__ = ['graph'] | 
 |  | 
 |     @property | 
 |     def graph(self) -> Graph: | 
 |         """ | 
 |         Return the ``Graph`` underlying this ``GraphModule`` | 
 |         """ | 
 |         return self._graph | 
 |  | 
 |     @graph.setter | 
 |     def graph(self, g : Graph) -> None: | 
 |         """ | 
 |         Set the underlying ``Graph`` for this ``GraphModule``. This will internally | 
 |         recompile the ``GraphModule`` so that the generated ``forward()`` function | 
 |         corresponds to ``g`` | 
 |         """ | 
 |         assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}' | 
 |         self._graph = g | 
 |         g.owning_module = self | 
 |         self.recompile() | 
 |  | 
 |     @compatibility(is_backward_compatible=False) | 
 |     def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"): | 
 |         """Dumps out module to ``folder`` with ``module_name`` so that it can be | 
 |         imported with ``from <folder> import <module_name>`` | 
 |  | 
 |         Args: | 
 |  | 
 |             folder (Union[str, os.PathLike]): The folder to write the code out to | 
 |  | 
 |             module_name (str): Top-level name to use for the ``Module`` while | 
 |                 writing out the code | 
 |         """ | 
 |         folder = Path(folder) | 
 |         Path(folder).mkdir(exist_ok=True) | 
 |         torch.save(self.state_dict(), folder / 'state_dict.pt') | 
 |         tab = " " * 4 | 
 |         custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()]) | 
 |         model_str = f""" | 
 | import torch | 
 | {custom_builtins} | 
 |  | 
 | from torch.nn import * | 
 | class {module_name}(torch.nn.Module): | 
 |     def __init__(self): | 
 |         super().__init__() | 
 | """ | 
 |  | 
 |         def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: | 
 |             safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d] | 
 |             if type(module) in safe_reprs: | 
 |                 return f"{module.__repr__()}" | 
 |             else: | 
 |                 return None | 
 |  | 
 |         blobified_modules = [] | 
 |         for module_name, module in self.named_children(): | 
 |             module_str = _gen_model_repr(module_name, module) | 
 |             if module_str is None: | 
 |                 module_file = folder / f'{module_name}.pt' | 
 |                 torch.save(module, module_file) | 
 |                 blobified_modules.append(module_name) | 
 |                 module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ') | 
 |                 module_str = f"torch.load(r'{module_file}') # {module_repr}" | 
 |             model_str += f"{tab*2}self.{module_name} = {module_str}\n" | 
 |  | 
 |         for buffer_name, buffer in self._buffers.items(): | 
 |             if buffer is None: | 
 |                 continue | 
 |             model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" | 
 |  | 
 |         for param_name, param in self._parameters.items(): | 
 |             if param is None: | 
 |                 continue | 
 |             model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" | 
 |  | 
 |         model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" | 
 |         model_str += f"{_addindent(self.code, 4)}\n" | 
 |  | 
 |         module_file = folder / 'module.py' | 
 |         module_file.write_text(model_str) | 
 |  | 
 |         init_file = folder / '__init__.py' | 
 |         init_file.write_text('from .module import *') | 
 |  | 
 |         if len(blobified_modules) > 0: | 
 |             warnings.warn("Was not able to save the following children modules as reprs -" | 
 |                           f"saved as pickled files instead: {blobified_modules}") | 
 |  | 
 |     @compatibility(is_backward_compatible=True) | 
 |     def add_submodule(self, target: str, m: torch.nn.Module) -> bool: | 
 |         """ | 
 |         Adds the given submodule to ``self``. | 
 |  | 
 |         This installs empty Modules where none exist yet if they are | 
 |         subpaths of ``target``. | 
 |  | 
 |         Args: | 
 |             target: The fully-qualified string name of the new submodule | 
 |                 (See example in ``nn.Module.get_submodule`` for how to | 
 |                 specify a fully-qualified string.) | 
 |             m: The submodule itself; the actual object we want to | 
 |                 install in the current Module | 
 |  | 
 |         Return: | 
 |             bool: Whether or not the submodule could be inserted. For | 
 |                 this method to return True, each object in the chain | 
 |                 denoted by ``target`` must either a) not exist yet, | 
 |                 or b) reference an ``nn.Module`` (not a parameter or | 
 |                 other attribute) | 
 |         """ | 
 |         *prefix, field = target.split('.') | 
 |         mod: torch.nn.Module = self | 
 |  | 
 |         for item in prefix: | 
 |  | 
 |             submod = getattr(mod, item, None) | 
 |  | 
 |             if submod is None: | 
 |                 submod = torch.nn.Module() | 
 |                 setattr(mod, item, submod) | 
 |  | 
 |             if not isinstance(submod, torch.nn.Module): | 
 |                 return False | 
 |  | 
 |             mod = submod | 
 |  | 
 |         mod.add_module(field, m) | 
 |         return True | 
 |  | 
 |     @compatibility(is_backward_compatible=True) | 
 |     def delete_submodule(self, target: str) -> bool: | 
 |         """ | 
 |         Deletes the given submodule from ``self``. | 
 |  | 
 |         The module will not be deleted if ``target`` is not a valid | 
 |         target. | 
 |  | 
 |         Args: | 
 |             target: The fully-qualified string name of the new submodule | 
 |                 (See example in ``nn.Module.get_submodule`` for how to | 
 |                 specify a fully-qualified string.) | 
 |  | 
 |         Returns: | 
 |             bool: Whether or not the target string referenced a | 
 |                 submodule we want to delete. A return value of ``False`` | 
 |                 means that the ``target`` was not a valid reference to | 
 |                 a submodule. | 
 |         """ | 
 |         atoms = target.split(".") | 
 |         path, target_submod = atoms[:-1], atoms[-1] | 
 |         mod: torch.nn.Module = self | 
 |  | 
 |         # Get the parent module | 
 |         for item in path: | 
 |  | 
 |             if not hasattr(mod, item): | 
 |                 return False | 
 |  | 
 |             mod = getattr(mod, item) | 
 |  | 
 |             if not isinstance(mod, torch.nn.Module): | 
 |                 return False | 
 |  | 
 |         if not hasattr(mod, target_submod): | 
 |             return False | 
 |  | 
 |         if not isinstance(getattr(mod, target_submod), torch.nn.Module): | 
 |             return False | 
 |  | 
 |         delattr(mod, target_submod) | 
 |         return True | 
 |  | 
 |     @compatibility(is_backward_compatible=True) | 
 |     def delete_all_unused_submodules(self) -> None: | 
 |         """ | 
 |         Deletes all unused submodules from ``self``. | 
 |  | 
 |         A Module is considered "used" if any one of the following is | 
 |         true: | 
 |         1. It has children that are used | 
 |         2. Its forward is called directly via a ``call_module`` node | 
 |         3. It has a non-Module attribute that is used from a | 
 |         ``get_attr`` node | 
 |  | 
 |         This method can be called to clean up an ``nn.Module`` without | 
 |         manually calling ``delete_submodule`` on each unused submodule. | 
 |         """ | 
 |         used: List[str] = [] | 
 |  | 
 |         for node in self.graph.nodes: | 
 |  | 
 |             if node.op == "call_module" or node.op == "get_attr": | 
 |  | 
 |                 # A list of strings representing the different parts | 
 |                 # of the path. For exmaple, `foo.bar.baz` gives us | 
 |                 # ["foo", "bar", "baz"] | 
 |                 fullpath = node.target.split(".") | 
 |  | 
 |                 # If we're looking at multiple parts of a path, join | 
 |                 # join them with a dot. Otherwise, return that single | 
 |                 # element without doing anything to it. | 
 |                 def join_fn(x: str, y: str) -> str: | 
 |                     return '.'.join([x, y] if y else [x]) | 
 |  | 
 |                 # Progressively collect all the names of intermediate | 
 |                 # modules. For example, if we have the target | 
 |                 # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and | 
 |                 # `foo.bar.baz` to the list. | 
 |                 for path in itertools.accumulate(fullpath, join_fn): | 
 |                     used.append(path) | 
 |  | 
 |                 # For a `call_module` node, also register all recursive submodules | 
 |                 # as used | 
 |                 if node.op == "call_module": | 
 |                     try: | 
 |                         submod = self.get_submodule(node.target) | 
 |  | 
 |                         for submod_name, _ in submod.named_modules(): | 
 |                             if submod_name != '': | 
 |                                 used.append('.'.join([node.target, submod_name])) | 
 |                     except AttributeError: | 
 |                         # Node referenced nonexistent submodule, don't need to | 
 |                         # worry about GCing anything | 
 |                         pass | 
 |  | 
 |         to_delete = [name for name, _ in self.named_modules() | 
 |                      if name not in used] | 
 |  | 
 |         for name in to_delete: | 
 |             self.delete_submodule(name) | 
 |  | 
 |     @property | 
 |     def code(self) -> str: | 
 |         """ | 
 |         Return the Python code generated from the ``Graph`` underlying this | 
 |         ``GraphModule``. | 
 |         """ | 
 |         if not hasattr(self, '_code'): | 
 |             raise RuntimeError('Code has not been generated! Please report a bug to PyTorch') | 
 |         return self._code | 
 |  | 
 |     @compatibility(is_backward_compatible=True) | 
 |     def recompile(self) -> PythonCode: | 
 |         """ | 
 |         Recompile this GraphModule from its ``graph`` attribute. This should be | 
 |         called after editing the contained ``graph``, otherwise the generated | 
 |         code of this ``GraphModule`` will be out of date. | 
 |         """ | 
 |         if isinstance(self._graph._codegen, _PyTreeCodeGen): | 
 |             self._in_spec = self._graph._codegen.pytree_info.in_spec | 
 |             self._out_spec = self._graph._codegen.pytree_info.out_spec | 
 |         python_code = self._graph.python_code(root_module='self') | 
 |         self._code = python_code.src | 
 |  | 
 |         cls = type(self) | 
 |         cls.forward = _forward_from_src(self._code, python_code.globals) | 
 |  | 
 |         # Determine whether this class explicitly defines a __call__ implementation | 
 |         # to wrap. If it does, save it in order to have wrapped_call invoke it. | 
 |         # If it does not, wrapped_call can use a dynamic call to super() instead. | 
 |         # In most cases, super().__call__ should be torch.nn.Module.__call__. | 
 |         # We do not want to hold a reference to Module.__call__ here; doing so will | 
 |         # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. | 
 |         cls_call = cls.__call__ if "__call__" in vars(cls) else None | 
 |  | 
 |         if '_wrapped_call' not in vars(cls): | 
 |             cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined] | 
 |  | 
 |         def call_wrapped(self, *args, **kwargs): | 
 |             return self._wrapped_call(self, *args, **kwargs) | 
 |  | 
 |         cls.__call__ = call_wrapped | 
 |  | 
 |         return python_code | 
 |  | 
 |     # Passing Tracer as argument allows subclasses extending fx.GraphModule | 
 |     # define their own Tracer (extending fx.Tracer). | 
 |     def __reduce_deploy__(self, importer: Importer): | 
 |         dict_without_graph = self.__dict__.copy() | 
 |         dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ | 
 |         del dict_without_graph['_graph'] | 
 |  | 
 |         python_code = self.recompile() | 
 |         import_block = _format_import_block(python_code.globals, importer) | 
 |         return (reduce_deploy_graph_module, (dict_without_graph, import_block)) | 
 |  | 
 |     def __reduce_package__(self, exporter: PackageExporter): | 
 |         dict_without_graph = self.__dict__.copy() | 
 |         dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__ | 
 |         del dict_without_graph['_graph'] | 
 |  | 
 |         generated_module_name = f'fx-generated._{exporter.get_unique_id()}' | 
 |         python_code = self.recompile() | 
 |         import_block = _format_import_block(python_code.globals, exporter.importer) | 
 |         module_code = import_block + self.code | 
 |         exporter.save_source_string(generated_module_name, module_code) | 
 |         return (reduce_package_graph_module, (dict_without_graph, generated_module_name)) | 
 |  | 
 |     def __reduce__(self): | 
 |         """ | 
 |         Serialization of GraphModule. We serialize only the generated code, not | 
 |         the underlying ``Graph``. This is because ``Graph`` does not have on-disk | 
 |         backward-compatibility guarantees, whereas Python source code does. | 
 |         On the deserialization side, we symbolically trace through the generated | 
 |         code to regenerate the underlying ``Graph`` | 
 |         """ | 
 |         dict_without_graph = self.__dict__.copy() | 
 |         python_code = self.recompile() | 
 |         import_block = _format_import_block(python_code.globals, sys_importer) | 
 |         del dict_without_graph['_graph'] | 
 |         return (reduce_graph_module, (dict_without_graph, import_block)) | 
 |  | 
 |     # because __reduce__ is defined for serialization, | 
 |     # we need to define deepcopy otherwise it will call __reduce__ | 
 |     # and cause symbolic tracing to occur every time we try to copy the object | 
 |     def __deepcopy__(self, memo): | 
 |         fake_mod = torch.nn.Module() | 
 |         fake_mod.__dict__ = copy.deepcopy(self.__dict__) | 
 |         return GraphModule(fake_mod, fake_mod.__dict__['_graph']) | 
 |  | 
 |     def __copy__(self): | 
 |         return GraphModule(self, self.graph) | 
 |  | 
 |     @compatibility(is_backward_compatible=False) | 
 |     def nested_str(self) -> str: | 
 |         """ | 
 |         Return the Python code generated for current GraphModule and its children GraphModules | 
 |         """ | 
 |         module_code = self.code | 
 |         module_code = module_code.lstrip('\n') | 
 |         module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code | 
 |         module_code = _addindent(module_code, 4) | 
 |  | 
 |         submodule_code_list = [""] | 
 |         for submodule in self.children(): | 
 |             if isinstance(submodule, GraphModule): | 
 |                 submodule_code_list.append(submodule.__nested_code()) | 
 |         submodule_code = "\n".join(submodule_code_list) | 
 |         submodule_code = _addindent(submodule_code, 4) | 
 |  | 
 |         return module_code + submodule_code | 
 |  | 
 |     def __str__(self) -> str: | 
 |         orig_str = super().__str__() | 
 |         return '\n'.join([orig_str, self._code]) | 
 |  | 
 |     def _replicate_for_data_parallel(self): | 
 |         new_gm = self.__copy__() | 
 |         new_gm._is_replica = True | 
 |         return new_gm | 
 |  | 
 | # workarounds for issues in __torch_function__ | 
 |  | 
 | # WAR for __torch_function__ not handling tensor lists, | 
 | # fix is in https://github.com/pytorch/pytorch/pull/34725 | 
 | # orig_cat = torch.cat | 
 | # def patched_cat(*args, **kwargs): | 
 | #     tensors = args[0] | 
 | #     for t in tensors: | 
 | #         if isinstance(t, Proxy): | 
 | #             return t.__torch_function__(patched_cat, (), args, kwargs) | 
 | #     return orig_cat(*args, **kwargs) | 
 | # patched_cat.__module__ = 'torch' | 
 | # patched_cat.__name__ = 'cat' | 
 | # torch.cat = patched_cat |