import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from typing import List, Callable, Tuple, Any, Dict, Optional
__all__ = [
"QuantizationTracer",
]
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
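
    A minimal usage sketch (``some_submodule`` stands in for an actual
    submodule instance and ``"sub"`` for its qualified path)::

        scope = Scope("", None)
        with ScopeContextManager(scope, some_submodule, "sub"):
            # inside the block, scope describes some_submodule
            assert scope.module_path == "sub"
            assert scope.module_type is type(some_submodule)
        # on exit, the previous ("", None) scope is restored
        assert scope.module_path == ""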
"""
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)
def __enter__(self):
return
def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return
class QuantizationTracer(Tracer):
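    """ Tracer used by FX graph mode quantization.

    Compared to the default ``Tracer``, it treats standard ``torch.nn`` /
    ``torch.ao.nn`` modules (except ``Sequential``), fused modules, and
    user-specified skipped modules as leaf modules, and records the scope
    (module path and module type) of every created node in
    ``node_name_to_scope``.

    A minimal usage sketch (``MyModel`` is a stand-in for an actual
    ``torch.nn.Module``)::

        tracer = QuantizationTracer(skipped_module_names=[], skipped_module_classes=[])
        graph = tracer.trace(MyModel())
        # node_name_to_scope maps node names to (module_path, module_type)
        print(tracer.node_name_to_scope)
    """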
def __init__(
self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
):
super().__init__()
self.skipped_module_names = skipped_module_names
self.skipped_module_classes = skipped_module_classes
        # NB: initialize the module_type of the top-level module to None.
        # We assume users won't configure qconfig by the type of the top-level
        # module here, since they can use "" for the global config instead.
        # We can change this if a use case arises that configures qconfig
        # using the top-level module type.
self.scope = Scope("", None)
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
self.record_stack_traces = True
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
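        """ Return True if ``m`` should be treated as a leaf (i.e. not traced into):
        standard ``torch.nn`` / ``torch.ao.nn`` modules (except ``Sequential``),
        modules whose qualified name or type was explicitly skipped, and fused modules.
        """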
return (
(
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn"))
and not isinstance(m, torch.nn.Sequential)
)
or module_qualified_name in self.skipped_module_names
or type(m) in self.skipped_module_classes
or isinstance(m, _FusedModule)
)
def call_module(
self,
m: torch.nn.Module,
forward: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Any:
module_qualified_name = self.path_of_module(m)
        # Create a scope with the information of the current module;
        # the previous scope is restored automatically upon exit
with ScopeContextManager(self.scope, m, module_qualified_name):
return super().call_module(m, forward, args, kwargs)
def create_node(
self,
kind: str,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: Optional[str] = None,
type_expr: Optional[Any] = None,
) -> Node:
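        """ Create the node as usual, then record the current scope
        (module path and module type) for it in ``node_name_to_scope``.
        """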
node = super().create_node(kind, target, args, kwargs, name, type_expr)
self.node_name_to_scope[node.name] = (
self.scope.module_path,
self.scope.module_type,
)
return node