[quant][be] Move some helper functions to the top level to reduce function length (#89246)
Summary:
As titled: move the `replace_observer_with_quantize_dequantize_node` and `replace_observer_or_dequant_stub_with_dequantize_node` helpers out of the conversion function in `torch/ao/quantization/fx/convert.py` to the top level of the module, renaming them with a leading underscore to mark them as private. This is a pure code move to reduce the length of the enclosing function; no functional change.
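For context, the hoisted `_replace_observer_with_quantize_dequantize_node` helper rewrites each observer call into a quantize/dequantize pair. Below is a minimal standalone sketch of the computation the rewritten graph performs, with illustrative `scale`/`zero_point` values (in practice these are read from the observer; they are not part of this PR):

```python
import torch

x = torch.randn(4)
scale, zero_point = 0.1, 0  # illustrative values; real ones come from the observer

# Before the pass: ... -> observer_0(x) -> ...
# After the pass, the graph computes a quantize/dequantize pair instead:
q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
dq = q.dequantize()
```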
Test Plan:
python test/test_quantization.py TestQuantizeFx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89246
Approved by: https://github.com/vkuzo
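For readers unfamiliar with torch.fx graph surgery, here is a self-contained sketch (hypothetical module `M`, not part of this PR) of the node-replacement mechanic both helpers rely on: re-route consumers with `replace_all_uses_with`, then drop the node with `erase_node`:

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)  # stand-in for an observer call we want to drop
        return y + 1

gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        # Re-route all consumers to the node's input, then remove the node,
        # mirroring the "just remove the observer" branch in convert.py.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # forward now computes x + 1 directly
```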
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 0c1249b..ca6ae61 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -88,6 +88,83 @@
"run_weight_observers",
]
+def _replace_observer_with_quantize_dequantize_node(
+ model: torch.nn.Module,
+ graph: Graph,
+ node: Node,
+ modules: Dict[str, torch.nn.Module],
+ node_name_to_scope: Dict[str, Tuple[str, type]],
+ node_name_to_qconfig: Dict[str, QConfigAny],
+ is_decomposed: bool) -> None:
+ """ Replace activation_post_process module call node with quantize and
+ dequantize node
+
+ Before:
+ ... -> observer_0(x) -> ...
+ After:
+ ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
+ """
+ assert modules is not None
+ assert isinstance(node.target, str)
+ module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
+ observer_module = modules[node.target]
+ maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
+    # Skip replacing observers with quant/dequant nodes if the qconfigs of all
+ # consumers and producers of this observer are None
+    skip_replacement = all(
+        has_none_qconfig(n, node_name_to_qconfig) for n in
+        list(node.args) + list(node.users.keys()))
+ if skip_replacement or maybe_quantize_node_info is None:
+        # didn't find a corresponding quantize op and info for the observer_module
+ # so we just remove the observer
+ with graph.inserting_before(node):
+ node.replace_all_uses_with(node.args[0])
+ graph.erase_node(node)
+ else:
+        # otherwise, we can convert the observer module call to a quantize/dequantize node pair
+ node_type, quantize_op, qparams = maybe_quantize_node_info
+        # replace the observer node with a quant/dequant node pair
+ with graph.inserting_before(node):
+ input_node = node.args[0]
+ quantize_op_inputs = [input_node]
+ for key, value in qparams.items():
+            # TODO: we can record in the qparams dict itself whether each
+            # value needs to be registered as an attribute
+ if key in ['_scale_', '_zero_point_']:
+ # For scale and zero_point values we register them as buffers in the root module.
+ # TODO: maybe need more complex attr name here
+ qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
+ quantize_op_inputs.append(qparam_node)
+ else:
+ # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
+ quantize_op_inputs.append(value)
+
+ quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+ if is_decomposed:
+            # use the same qparams as the quantize op
+ dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+ dequantized_node = graph.call_function(
+ torch.ops.quantized_decomposed.dequantize_per_tensor,
+ tuple(dq_inputs),
+ {}
+ )
+ else:
+ dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
+ node.replace_all_uses_with(dequantized_node)
+ graph.erase_node(node)
+
+# This is a temporary hack for custom modules; we may want to implement
+# this properly after the custom module class design is finalized.
+# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
+# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
+# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
+def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
+ call_custom_module_node = node.args[0]
+ assert isinstance(call_custom_module_node, Node), \
+ f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
+ node.replace_all_uses_with(call_custom_module_node)
+ graph.erase_node(node)
+ insert_dequantize_node(call_custom_module_node, graph)
def restore_state(
observed: torch.nn.Module
@@ -599,85 +676,6 @@
if node.op == 'placeholder':
graph_inputs.append(node.name)
- # TODO: move this outside of this function
- def replace_observer_with_quantize_dequantize_node(
- model: torch.nn.Module,
- graph: Graph,
- node: Node,
- modules: Dict[str, torch.nn.Module],
- node_name_to_scope: Dict[str, Tuple[str, type]],
- node_name_to_qconfig: Dict[str, QConfigAny],
- is_decomposed: bool) -> None:
- """ Replace activation_post_process module call node with quantize and
- dequantize node
-
- Before:
- ... -> observer_0(x) -> ...
- After:
- ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
- """
- assert modules is not None
- assert isinstance(node.target, str)
- module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
- observer_module = modules[node.target]
- maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed)
- # Skip replacing observers to quant/dequant nodes if the qconfigs of all
- # consumers and producers of this observer are None
- skip_replacement = all([
- has_none_qconfig(n, node_name_to_qconfig) for n in
- list(node.args) + list(node.users.keys())])
- if skip_replacement or maybe_quantize_node_info is None:
- # didn't find correponding quantize op and info for the observer_module
- # so we just remove the observer
- with graph.inserting_before(node):
- node.replace_all_uses_with(node.args[0])
- graph.erase_node(node)
- else:
- # otherwise, we can convert the observer moduel call to quantize/dequantize node
- node_type, quantize_op, qparams = maybe_quantize_node_info
- # replace observer node with quant - dequant node
- with graph.inserting_before(node):
- input_node = node.args[0]
- quantize_op_inputs = [input_node]
- for key, value in qparams.items():
- # TODO: we can add the information of whether a value needs to
- # be registered as an attribute in qparams dict itself
- if key in ['_scale_', '_zero_point_']:
- # For scale and zero_point values we register them as buffers in the root module.
- # TODO: maybe need more complex attr name here
- qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value)
- quantize_op_inputs.append(qparam_node)
- else:
- # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
- quantize_op_inputs.append(value)
-
- quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
- if is_decomposed:
- # use the same qparams from quantize op
- dq_inputs = [quantized_node] + quantize_op_inputs[1:]
- dequantized_node = graph.call_function(
- torch.ops.quantized_decomposed.dequantize_per_tensor,
- tuple(dq_inputs),
- {}
- )
- else:
- dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
- node.replace_all_uses_with(dequantized_node)
- graph.erase_node(node)
-
- # this is a temporary hack for custom module, we may want to implement
- # this properly after the custom module class design is finalized
- # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
- # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
- # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
- def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph):
- call_custom_module_node = node.args[0]
- assert isinstance(call_custom_module_node, Node), \
- f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
- node.replace_all_uses_with(call_custom_module_node)
- graph.erase_node(node)
- insert_dequantize_node(call_custom_module_node, graph)
-
# additional state to override inputs to be quantized, if specified
# by the user
placeholder_node_seen_cnt = 0
@@ -728,13 +726,13 @@
if _is_activation_post_process(mod):
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:
- replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+ _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
else:
- replace_observer_with_quantize_dequantize_node(
+ _replace_observer_with_quantize_dequantize_node(
model, model.graph, node, modules, node_name_to_scope,
node_name_to_qconfig, is_decomposed)
elif isinstance(mod, DeQuantStub):
- replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
+ _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
elif is_observed_standalone_module(mod):
convert_standalone_module(
node, modules, model, is_reference, backend_config)