| import torch |
| from torch.fx._symbolic_trace import Tracer |
| from torch.fx.node import Target, Node, Argument |
| from torch.nn.intrinsic import _FusedModule |
| from typing import List, Callable, Tuple, Any, Dict, Optional |
| |
# Names exported by ``from <this module> import *``. Only the tracer is
# listed; Scope and ScopeContextManager are not part of the star-export.
__all__ = ["QuantizationTracer"]
| |
class Scope:
    """Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path: Qualified name of the module within the root module
            (``""`` for the root module itself).
        module_type: Type of the module. The quantization tracer uses
            ``None`` for the top-level module.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type

    def __repr__(self) -> str:
        # Readable form for debugging node -> scope mappings.
        return (
            f"{type(self).__name__}(module_path={self.module_path!r}, "
            f"module_type={self.module_type!r})"
        )
| |
| |
class ScopeContextManager:
    """A context manager to track the Scope of Node during symbolic tracing.

    When entering (the ``with`` statement for) a module's forward call, the
    shared ``scope`` is overwritten with that module's path and type; on exit,
    the values that were in ``scope`` at entry time are restored, so nested
    module calls unwind correctly.

    The shared scope is only modified inside ``__enter__``, so constructing a
    manager without entering it leaves ``scope`` untouched.

    Args:
        scope: The Scope object shared with the tracer; updated in place.
        current_module: The module being entered; only its ``type`` is recorded.
        current_module_path: Qualified name of ``current_module``.
    """

    def __init__(
        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
    ):
        super().__init__()
        self.scope = scope
        self.current_module_path = current_module_path
        self.current_module_type = type(current_module)
        # Kept for backward compatibility with code reading these attributes;
        # they are re-captured in __enter__ so the values restored on exit are
        # the ones in effect when the block was actually entered.
        self.prev_module_path = scope.module_path
        self.prev_module_type = scope.module_type

    def __enter__(self):
        # Snapshot the enclosing scope, then switch to the current module.
        self.prev_module_path = self.scope.module_path
        self.prev_module_type = self.scope.module_type
        self.scope.module_path = self.current_module_path
        self.scope.module_type = self.current_module_type
        return

    def __exit__(self, *args):
        # Restore the enclosing scope. Returning None means exceptions raised
        # inside the ``with`` body are not suppressed.
        self.scope.module_path = self.prev_module_path
        self.scope.module_type = self.prev_module_type
        return
| |
class QuantizationTracer(Tracer):
    """FX ``Tracer`` used for quantization.

    On top of the base ``Tracer`` it:

    * treats as leaf modules: built-in ``torch.nn`` / ``torch.ao.nn`` modules
      (except ``torch.nn.Sequential``), modules whose qualified name is in
      ``skipped_module_names``, modules whose exact type is in
      ``skipped_module_classes``, and fused modules (``_FusedModule``);
    * records, for every Node it creates, the ``(module_path, module_type)``
      of the module being traced at that moment in ``node_name_to_scope``.

    Args:
        skipped_module_names: Qualified module names to treat as leaves.
        skipped_module_classes: Module classes to treat as leaves. Matched
            against the exact ``type`` of a module, not via ``isinstance``.
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_names = skipped_module_names
        self.skipped_module_classes = skipped_module_classes
        # NB: initialize the module_type of the top-level module to None.
        # We assume users won't configure the model by the type of the
        # top-level module, since "" can be used for the global config.
        # This can be changed if there is a use case that configures
        # qconfig using the top-level module type.
        self.scope = Scope("", None)
        # Node.name -> (module_path, module_type) of the module being traced
        # when the Node was created; module_type is None for the top level.
        self.node_name_to_scope: Dict[str, Tuple[str, Optional[type]]] = {}
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True if ``m`` should be treated as a leaf module when tracing."""
        # Built-in torch.nn / torch.ao.nn modules are leaves, except
        # Sequential containers.
        is_builtin_leaf = m.__module__.startswith(
            ("torch.nn", "torch.ao.nn")
        ) and not isinstance(m, torch.nn.Sequential)
        return (
            is_builtin_leaf
            or module_qualified_name in self.skipped_module_names
            or type(m) in self.skipped_module_classes
            or isinstance(m, _FusedModule)
        )

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """Trace a call to ``m`` with ``self.scope`` set to ``m``'s path and type.

        The previous scope is restored by ScopeContextManager when the call
        finishes.
        """
        module_qualified_name = self.path_of_module(m)
        with ScopeContextManager(self.scope, m, module_qualified_name):
            return super().call_module(m, forward, args, kwargs)

    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """Create a Node via the base Tracer and record the current scope for it."""
        node = super().create_node(kind, target, args, kwargs, name, type_expr)
        self.node_name_to_scope[node.name] = (
            self.scope.module_path,
            self.scope.module_type,
        )
        return node