# blob: 7b55033ecba67e7d302da03978bee090f8c3b62f [file] [log] [blame]
import torch
import copy
from torch.fx import GraphModule
class ObservedGraphModule(GraphModule):
    """GraphModule that keeps quantization-specific attributes of its root.

    The attributes named by ``get_preserved_attr_names`` are read from
    ``root`` at construction time and re-attached to the new instance.
    """

    def get_preserved_attr_names(self):
        """Return the names of the attributes carried over from ``root``."""
        return ['_activation_post_process_map',
                '_patterns',
                '_qconfig_map']

    def __init__(self, root, graph):
        """Build the module from ``root`` and ``graph``.

        Args:
            root: object providing every attribute listed by
                ``get_preserved_attr_names``; forwarded to
                ``GraphModule.__init__``.
            graph: graph forwarded to ``GraphModule.__init__``.

        Raises:
            AttributeError: if ``root`` lacks one of the preserved attributes.
        """
        # Capture the values up front, then re-attach them once the parent
        # constructor has run.
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.get_preserved_attr_names()
        }
        super().__init__(root, graph)
        for attr, value in preserved_attrs.items():
            setattr(self, attr, value)

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        """Return a deep copy that shares no state with ``self``.

        ``memo`` is threaded through both ``copy.deepcopy`` calls so that
        objects reachable from ``__dict__`` and from ``graph`` are copied only
        once, and so that the copy receives its own graph rather than the
        original's.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph, memo))
def mark_observed_module(module):
    """Wrap ``module`` in an ``ObservedGraphModule`` built on its own graph."""
    graph = module.graph
    observed = ObservedGraphModule(module, graph)
    return observed
def is_observed_module(module):
    """Tell whether ``module`` is an ``ObservedGraphModule`` instance."""
    observed = isinstance(module, ObservedGraphModule)
    return observed
class ObservedStandaloneGraphModule(ObservedGraphModule):
    """ObservedGraphModule variant that preserves two extra attributes.

    In addition to the parent's preserved attributes it carries over
    ``_standalone_module_observed_input_idxs`` and ``_output_is_observed``.
    """

    def get_preserved_attr_names(self):
        """Return the parent's preserved names plus the standalone-only ones."""
        return super().get_preserved_attr_names() + [
            '_standalone_module_observed_input_idxs',
            '_output_is_observed',
        ]

    def __deepcopy__(self, memo):
        """Return a deep copy that shares no state with ``self``.

        ``memo`` is threaded through both ``copy.deepcopy`` calls so that
        objects reachable from ``__dict__`` and from ``graph`` are copied only
        once, and so that the copy receives its own graph rather than the
        original's.
        """
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
        return ObservedStandaloneGraphModule(
            fake_mod, copy.deepcopy(self.graph, memo))
def mark_observed_standalone_module(module):
    """Wrap ``module`` in an ``ObservedStandaloneGraphModule`` on its own graph."""
    graph = module.graph
    observed = ObservedStandaloneGraphModule(module, graph)
    return observed
def is_observed_standalone_module(module):
    """Tell whether ``module`` is an ``ObservedStandaloneGraphModule`` instance."""
    observed = isinstance(module, ObservedStandaloneGraphModule)
    return observed