[quant][graphmode][fx] Custom module support (#44766)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44766
There might be modules that are not symbolically traceable, e.g. LSTM (since it has
input dependent control flows), to support quantization in these cases, user will provide
the corresponding observed and quantized version of the custom module, the observed
custom module with observers already inserted in the module and the quantized version will
have the corresponding ops quantized. And use
```
from torch.quantization import register_observed_custom_module_mapping
from torch.quantization import register_quantized_custom_module_mapping
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
```
to register the custom module mappings, we'll also need to define a custom delegate class
for symbolic trace in order to prevent the custom module from being traced:
```python
class CustomDelegate(DefaultDelegate):
def is_leaf_module(self, m):
return (m.__module__.startswith('torch.nn') and
not isinstance(m, torch.nn.Sequential)) or \
isinstance(m, CustomModule)
m = symbolic_trace(original_m, delegate_class=CustomDelegate)
```
Test Plan: Imported from OSS
Reviewed By: z-a-f
Differential Revision: D23723455
fbshipit-source-id: 50d666e29b94cbcbea5fb6bcc73b00cff87eb77a
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 3170bfb..fc4a735 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -20,6 +20,8 @@
quantize_static_fx,
quantize_dynamic_fx,
prepare_qat_fx,
+ register_observed_custom_module_mapping,
+ register_quantized_custom_module_mapping,
)
from torch.quantization import (
@@ -482,6 +484,140 @@
# Verify that loaded state dict produces same results.
self.assertEqual(quant(x), quant_2(x))
+ @skipIfNoFBGEMM
+ def test_custom_module_class(self):
+ class CustomModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(1, 1, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+ class ObservedCustomModule(torch.nn.Module):
+ def __init__(self, conv):
+ super().__init__()
+ self.conv = conv
+
+ def forward(self, x):
+ return self.conv(x)
+
+ @classmethod
+ def from_float(cls, float_module):
+ assert hasattr(float_module, 'qconfig')
+ observed = cls(float_module.conv)
+ observed.qconfig = float_module.qconfig
+ return observed
+
+ class QuantizedCustomModule(torch.nn.Module):
+ def __init__(self, conv):
+ super().__init__()
+ self.conv = conv
+
+ def forward(self, x):
+ return self.conv(x)
+
+ @classmethod
+ def from_observed(cls, observed_module):
+ assert hasattr(observed_module, 'qconfig')
+ assert hasattr(observed_module, 'activation_post_process')
+ observed_module.conv.activation_post_process = \
+ observed_module.activation_post_process
+ quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
+ return quantized
+
+ class DynamicallyQuantizedCustomModule(torch.nn.Module):
+ def __init__(self, conv):
+ super().__init__()
+ self.conv = conv
+
+ def forward(self, x):
+ return self.conv(x)
+
+ @classmethod
+ def from_observed(cls, observed_module):
+ assert hasattr(observed_module, 'qconfig')
+ assert hasattr(observed_module, 'activation_post_process')
+ quantized = cls(nnqd.Conv2d.from_float(observed_module.conv))
+ return quantized
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(1, 1, 1)
+ self.custom = CustomModule()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.custom(x)
+ return x
+
+ class RefM(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = torch.nn.Conv2d(1, 1, 1)
+ self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+ data = torch.randn(1, 1, 1, 1)
+ # instantiate M and RefM and align the parameters
+ original_m = M()
+ original_ref_m = RefM()
+ original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
+ original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
+ original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
+ original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
+
+ from torch.fx.symbolic_trace import Tracer
+
+ # define a custom tracer to not trace through the custom module
+
+ class CustomTracer(Tracer):
+ def is_leaf_module(self, m, module_qualified_name):
+ return (m.__module__.startswith('torch.nn') and
+ not isinstance(m, torch.nn.Sequential)) or \
+ isinstance(m, CustomModule)
+
+ # TODO: add other quant types after mixed mode support
+ for quant_type in [QuantType.STATIC]:
+ # register observed and quantized custom module classes
+ register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
+ register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
+
+ m = CustomTracer().trace(original_m).eval()
+ qconfig_dict = {'': default_qconfig}
+ # check prepared model
+ m = prepare_static_fx(m, qconfig_dict)
+ # calibration
+ m(data)
+ # all activation observers are inserted in the top level module
+ count_check = {
+ ns.call_module(torch.quantization.MinMaxObserver): 3
+ }
+ self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+
+ # check converted/quantized model
+ m = convert_static_fx(m)
+ count_check = {
+ ns.call_function(torch.quantize_per_tensor) : 1,
+ ns.call_module(nnq.Conv2d) : 1,
+ ns.call_method('dequantize') : 1,
+ }
+ self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+ res = m(data)
+
+ # quantize the reference model
+ ref_m = symbolic_trace(original_ref_m).eval()
+ ref_m = prepare_fx(ref_m, qconfig_dict)
+ ref_m(data)
+ ref_m = convert_fx(ref_m)
+ ref_res = ref_m(data)
+ self.assertEqual(res, ref_res)
+
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
"""
diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py
index fe1ced9..773a9a3 100644
--- a/torch/nn/quantized/modules/conv.py
+++ b/torch/nn/quantized/modules/conv.py
@@ -146,7 +146,7 @@
@classmethod
def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
- r"""Creates a qconv object and returns it.
+ r"""Creates a qconv object and returns it.
"""
if weight_post_process is None:
weight_post_process = mod.qconfig.weight()
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index ed908dd..3193c33 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -9,6 +9,7 @@
from .quantize_fx import *
from .quantization_mappings import *
from .fuser_method_mappings import *
+from .custom_module_class_mappings import *
def default_eval_fn(model, calib_data):
r"""
@@ -40,6 +41,11 @@
'get_compare_output_module_list',
'register_quantized_operator_mapping', 'get_quantized_operator',
'register_fuser_method', 'get_fuser_method',
+ 'register_observed_custom_module_mapping',
+ 'get_observed_custom_module_class',
+ 'register_quantized_custom_mdoule_mapping',
+ 'get_quantized_custom_module_class',
+ 'is_custom_module_class',
# Sub functions for `prepare` and `swap_module`
'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
'default_eval_fn', 'get_observer_dict',
diff --git a/torch/quantization/custom_module_class_mappings.py b/torch/quantization/custom_module_class_mappings.py
new file mode 100644
index 0000000..c622902
--- /dev/null
+++ b/torch/quantization/custom_module_class_mappings.py
@@ -0,0 +1,75 @@
+OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
+
+def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class):
+ """ Register a mapping from `float_custom_module_class` to
+ `observed_custom_module_class`
+ `observed_custom_module_class` will have a `from_float` classmethod,
+ which will return an observed custom module instance given
+ a float custom module instance.
+ This will be used in prepare step of post training static quantization or
+ quantization aware training
+ """
+ assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \
+ ' defined in observed custom module class'
+ OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
+ observed_custom_module_class
+
+def get_observed_custom_module_class(float_custom_module_class):
+ """ Get the corresponding observed module class for a given
+ float custom module.
+ """
+ observed_custom_module_class = \
+ OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
+ assert observed_custom_module_class is not None, \
+ 'Float Custom module class {}'.format(float_custom_module_class) + \
+ ' does not have a corresponding observed module class'
+ return observed_custom_module_class
+
+QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
+
+def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class):
+ """ Register a mapping from `float_custom_module_class` to `quantized_custom_module_class`
+ A quantized custom module class should accept quantized input and
+ return quantized output. (we can relax this condition in the
+ future if there is a need)
+ `quantized_custom_module_class` will have a `from_observed` classmethod,
+ which will return an quantized custom module instance given
+ a observed custom module instance.
+ This will be used in prepare step of post training static quantization or
+ quantization aware training
+ """
+ assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \
+ ' must be defined in quantized custom module class'
+ QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
+ quantized_custom_module_class
+
+def get_quantized_custom_module_class(float_custom_module_class):
+ """ Get the corresponding quantized module class for a given
+ float custom module.
+ """
+ quantized_custom_module_class = \
+ QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
+ assert quantized_custom_module_class is not None, \
+ 'Float Custom module class {}'.format(float_custom_module_class) + \
+ ' does not have a corresponding quantized module class'
+ return quantized_custom_module_class
+
+def is_custom_module_class(module_class):
+ """ Check if a given module class is a custom module class
+ """
+ return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \
+ module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS
+
+def mark_observed_custom_module(module, custom_module_class):
+ """ Mark a module as observed custom module, so that
+ it can be identified during convert step
+ """
+ module._is_observed_custom_module = True
+ module._FLOAT_MODULE = custom_module_class
+
+def is_observed_custom_module(module):
+ """ Check if a module is marked as observed custom module
+ or not
+ """
+ return hasattr(module, '_is_observed_custom_module') and \
+ module._is_observed_custom_module
diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py
index fa5a873..ab85c9a 100644
--- a/torch/quantization/fx/quantization_patterns.py
+++ b/torch/quantization/fx/quantization_patterns.py
@@ -6,6 +6,9 @@
get_static_quant_module_class,
get_quantized_operator,
)
+from ..custom_module_class_mappings import (
+ get_quantized_custom_module_class,
+)
from .pattern_utils import (
register_quant_pattern,
register_dynamic_quant_pattern,
@@ -507,6 +510,28 @@
quantizer.quantized_graph,
node, quantizer.activation_post_process_map[node.name])
+class CustomModuleQuantizeHandler(QuantizeHandler):
+ def convert(self, quantizer, node, load_arg, debug=False):
+ """ Convert a float custom module to quantized custom module
+ """
+ assert node.op == 'call_module'
+ observed_custom_module = quantizer.modules[node.target]
+ if node.name in quantizer.activation_post_process_map:
+ observed_custom_module.activation_post_process = \
+ quantizer.activation_post_process_map[node.name]
+ quantized_custom_module_class = \
+ get_quantized_custom_module_class(observed_custom_module._FLOAT_MODULE)
+ quantized_custom_module = \
+ quantized_custom_module_class.from_observed(observed_custom_module)
+ parent_name, name = _parent_name(node.target)
+ setattr(quantizer.modules[parent_name], name, quantized_custom_module)
+ # hardcoded the qunatized input to be None (take whatever is in the environemnt),
+ # we can extend this
+ # if there is a need, e.g. get the indexes of quantized inputs from some
+ # module attribute like module._QUANTIZED_INPUT_INDEXES
+ return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
+
# 2. Post Training Dynamic Quantizatoin Patterns
@register_dynamic_quant_pattern(torch.nn.Linear)
@register_dynamic_quant_pattern(torch.nn.functional.linear)
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 7967b4e..8d74225 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -18,6 +18,12 @@
from ..quantization_mappings import (
get_qat_module_mappings,
)
+from ..custom_module_class_mappings import (
+ is_custom_module_class,
+ get_observed_custom_module_class,
+ mark_observed_custom_module,
+ is_observed_custom_module,
+)
from ..quantize import _remove_qconfig
@@ -193,7 +199,6 @@
if not inplace:
model = copy.deepcopy(model)
self.is_dynamic_quant = is_dynamic_quant
- # TODO: allow user specified patterns
if self.is_dynamic_quant:
self.patterns = get_dynamic_quant_patterns()
else:
@@ -235,6 +240,8 @@
env[node.name] = observed_graph.node_copy(node, load_arg)
elif root_node is node:
env[node.name] = observed_graph.node_copy(node, load_arg)
+ if qconfig is None:
+ continue
def insert_observer(node, observer, device):
get_new_observer_name = get_new_attr_name_with_prefix(prefix)
@@ -246,10 +253,22 @@
if device:
getattr(model, observer_name).to(device)
+ if isinstance(obj, CustomModuleQuantizeHandler):
+ custom_module = self.modules[node.target]
+ observed_custom_module_class = \
+ get_observed_custom_module_class(type(custom_module))
+ observed_custom_module = \
+ observed_custom_module_class.from_float(custom_module)
+ mark_observed_custom_module(observed_custom_module, type(custom_module))
+ parent_name, name = _parent_name(node.target)
+ setattr(self.modules[parent_name], name, observed_custom_module)
+
# don't need to insert observer for output in dynamic quantization
if self.is_dynamic_quant:
continue
+ # inserting observers for output of observed module, or mark the output
+ # as observed
if isinstance(obj, CopyNode):
assert node.op in [
'call_module',
@@ -355,6 +374,7 @@
self.modules = dict(model.named_modules())
matches = self._find_matches(model.graph, self.modules, self.patterns)
+
quants = self._find_quants(model.graph, matches)
self.quantized_graph = Graph()
env = {}
@@ -619,6 +639,16 @@
all_matched.add(n.name)
# break after finding the first match
break
+
+ # add custom module instances to the match result
+ for node in graph.nodes:
+ if node.op == 'call_module' and \
+ (is_custom_module_class(type(self.modules[node.target])) or
+ is_observed_custom_module(self.modules[node.target])):
+ custom_module_qconfig = self.qconfig_map[node.name]
+ match_map[node.name] = (
+ node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig)
+
return match_map
def _find_quants(self, graph, matches):