|  | from .graph_module import GraphModule | 
|  | from .graph import Graph | 
|  | from .node import Argument, Node, Target, map_arg, map_aggregate | 
|  | from .proxy import Proxy | 
|  | from ._symbolic_trace import Tracer | 
|  | from ._compatibility import compatibility | 
|  | from . import config | 
|  | import torch.fx.traceback as fx_traceback | 
|  | import torch | 
|  | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union | 
|  | import inspect | 
|  | from contextlib import contextmanager | 
|  | from torch.hub import tqdm | 
|  |  | 
# Public API exported by ``from <this module> import *``.
__all__ = ["Interpreter", "Transformer"]
|  |  | 
@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overrideable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
        assert isinstance(module, GraphModule)
        self.module = module
        self.submodules = dict(self.module.named_modules())
        # Maps each already-executed Node to its concrete value.
        self.env : Dict[Node, Any] = {}
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        # When True, exceptions raised while running a node are annotated
        # with the node's formatted text and its original stack trace.
        self.extra_traceback = True

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            # user node -> nodes whose value can be freed once `user` has run
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.module.graph.nodes):
                # The lambdas are invoked immediately by map_arg, so binding
                # the loop variable `node` late is not a problem here.
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter. Note that this dict object is used directly as ``self.env``, so
                when ``garbage_collect_values`` is enabled, entries may be deleted from it
                during execution.
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module (``None`` if the
            graph has no ``output`` node).
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.module.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)
        pbar = tqdm(total=len(self.module.graph.nodes),
                    desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}",
                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)

        try:
            for node in self.module.graph.nodes:
                pbar.update(1)
                if node in self.env:
                    # Short circuit if we have this value. This could
                    # be used, for example, for partial evaluation
                    # where the caller has pre-populated `env` with
                    # values for a subset of the program.
                    continue

                try:
                    self.env[node] = self.run_node(node)
                except Exception as e:
                    if self.extra_traceback:
                        msg = f"While executing {node.format_node()}"
                        msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
                        msg += f"\nOriginal traceback:\n{node.stack_trace}"
                        e.args = (msg,) + e.args[1:]
                        # KeyError is surfaced as RuntimeError — presumably because
                        # KeyError's str() repr-quotes its argument, which would
                        # mangle the multi-line message built above.
                        if isinstance(e, KeyError):
                            raise RuntimeError(*e.args) from e
                    raise

                if self.garbage_collect_values:
                    # `node` was the last user of these values; free them now.
                    for to_delete in self.user_to_last_uses.get(node, []):
                        del self.env[to_delete]

                if node.op == 'output':
                    output_val = self.env[node]
                    return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
        finally:
            # Release the progress bar on every exit path: the early return at
            # the output node, falling off the end, or a propagating exception.
            pbar.close()

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result.  This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter.  This ensures that input tensors are promptly deallocated.
        """
        # Bind placeholder nodes, in graph order, to the boxed arguments and
        # then clear the caller's list so it no longer references the inputs.
        args_iter = iter(args_list)
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = next(args_iter)
        args_list.clear()
        return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        # Make `node` the current node for fx_traceback while the body runs.
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            # Dispatch on the opcode name, e.g. 'call_function' -> self.call_function
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.

        Raises:
            RuntimeError: If no positional argument remains and the
                placeholder has no default value in ``args``.
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                # Fall back to the placeholder's default value, if it has one.
                if len(args) > 0:
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The value returned by the module invocation
        """
        # Look up the submodule by its qualified name, then invoke it.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Returns:
            Any: The value of the attribute.

        Raises:
            RuntimeError: If any component of ``target`` does not exist. The
                message names the qualified path up to and including the
                missing component.
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Include the missing atom itself: slicing to `[:i]` would drop
                # it (and report an empty name when the first atom is missing).
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Returns:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.

        Returns:
            Argument: ``args`` with every ``Node`` replaced by its value from ``self.env``.

        Raises:
            RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)
|  |  | 
@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        # The output graph being built; it reuses the source graph's codegen
        # object so code generation for the result matches the original.
        self.new_graph = Graph()
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            # A Tracer that records directly into `graph` instead of creating
            # its own, and treats every submodule as a leaf so module calls
            # are recorded as `call_module` nodes rather than traced through.
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph
                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]

            def is_leaf_module(self, _, __) -> bool:
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: A proxy wrapping the newly inserted ``placeholder`` node.
        """
        assert isinstance(target, str)
        # A positional arg on a placeholder node carries its default value;
        # with none present, mark the new placeholder as having no default.
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Proxy: A proxy wrapping the newly inserted ``get_attr`` node.
        """
        assert isinstance(target, str)
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node symbolically by routing it through
        ``self.tracer.call_module``, so the tracer's leaf-module policy
        decides how the call is recorded in the output graph.

        Args:
            target (Target): Qualified name of the submodule within ``self.module``.
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: Whatever ``self.tracer.call_module`` returns for this call.
        """
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node symbolically by recording a
        ``call_function`` node in the output graph via ``self.tracer``.

        Args:
            target (Target): The callable to record.
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The proxy returned by ``self.tracer.create_proxy``.
        """
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.

        The source graph is interpreted symbolically (with input/output
        processing disabled and node metadata preserved). If the result is
        not ``None``, an ``output`` node carrying it — with every ``Proxy``
        unwrapped to its ``Node`` — is appended to ``self.new_graph``.

        Returns:
            GraphModule: A new module rooted at ``self.module`` whose graph is
            ``self.new_graph``.
        """
        with fx_traceback.preserve_node_meta():
            result = super().run(enable_io_processing=False)
        if result is not None:
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                return a.node if isinstance(a, Proxy) else a
            self.new_graph.output(map_aggregate(result, strip_proxy))
        return GraphModule(self.module, self.new_graph)