|  | import operator | 
|  | from typing import Any, Callable, Dict, Tuple, Optional | 
|  |  | 
|  | import torch | 
|  | import torch.fx | 
|  | import torch.fx as fx | 
|  | from torch.fx import Transformer, Proxy | 
|  | from torch.fx.node import Argument, Target, Node, map_aggregate | 
|  | from torch.fx.operator_schemas import ( | 
|  | normalize_module, | 
|  | normalize_function, | 
|  | create_type_hint, | 
|  | ) | 
|  |  | 
|  | from .schema_type_annotation import AnnotateTypesWithSchema | 
|  |  | 
|  |  | 
class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will use it to disambiguate
    overloads. Otherwise, it will throw an error.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        """
        Args:
            module: the traced module whose call sites are normalized.
            normalize_to_only_use_kwargs: forwarded to `normalize_function` /
                `normalize_module`; when true, calls are rewritten to use
                keyword arguments exclusively.
        """
        super().__init__(module)
        # Maps every Proxy emitted into the new graph back to the original
        # node it was produced from.
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """
        Process a single node of the original graph.

        Collects the type of every argument of `n` and, for
        `call_function` nodes, hands those types to `call_function` so
        they can be used to pick an overload. All other node kinds are
        delegated to the parent class. For every node except `output`,
        the result is recorded in `node_map` and the original node's
        `meta` and `type` are carried over to the new node.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            # The type of a node-valued argument is recorded on that
            # argument's own node, not on the node that consumes it.
            if isinstance(arg, fx.Node):
                return arg.meta.get("type")
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        # Read kwarg types from the graph node itself: the values fetched
        # from the environment are the results of earlier nodes (Proxies,
        # see `node_map`), not `fx.Node`s, so they carry no "type" metadata.
        kwarg_types = {k: get_type(v) for k, v in n.kwargs.items()}
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            self.node_map[out] = n
            # NOTE: the metadata dict is shared with the original node,
            # not copied.
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """
        Emit a `call_function` node with normalized arguments.

        Args:
            target: the callable being invoked.
            args: positional arguments of the call.
            kwargs: keyword arguments of the call.
            arg_types: types of `args`, passed to `normalize_function`.
            kwarg_types: types of `kwargs`, passed to `normalize_function`.

        Returns:
            A proxy for the call using the normalized arguments, or the
            parent class's result for the original arguments when
            `normalize_function` returns nothing.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a `call_module` node with normalized arguments.

        `target` must be the string name of a submodule. When
        `normalize_module` returns nothing, the call is emitted with its
        original arguments.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)
|  |  | 
|  |  | 
class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Rewriting binary `torch` ops (e.g. torch.add) into the Python
       operator listed for them in `binary_magic_method_remap`
       (e.g. operator.add). The rewrite is applied only when the call
       has exactly two positional arguments and no keyword arguments;
       any other call is left untouched.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = NormalizeOperators(traced).transform()
    """

    # torch op -> Python operator that the call is rewritten to.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a `call_function` node, replacing `target` with its canonical
        Python operator when the call is a plain two-argument invocation.

        Args:
            target: the callable being invoked.
            args: positional arguments of the call.
            kwargs: keyword arguments of the call.

        Returns:
            The parent class's result for either the remapped or the
            original call.
        """
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)

        canonical = self.binary_magic_method_remap.get(target)
        # Keyword arguments (e.g. `alpha=` on torch.add/torch.sub or
        # `rounding_mode=` on torch.div) have no counterpart on the Python
        # operators; rewriting such a call would silently drop them and
        # change the result, so it is kept as-is.
        if canonical is None or len(args) != 2 or kwargs:
            return super().call_function(target, args, kwargs)

        lhs, rhs = args
        return super().call_function(
            target=canonical,
            args=(lhs, rhs),
            kwargs={},
        )