Revert D23385090: [quant][graphmode][fx] Add support for weight prepack folding
Test Plan: revert-hammer
Differential Revision: D23385090 (https://github.com/pytorch/pytorch/commit/ef08f92076806b9f17290876db7c5ea87c09873c)
Original commit changeset: 11341f0af525
fbshipit-source-id: fe2bcdc16106923a2cee99eb5cc0a1e9c14ad2c5
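Context: the reverted change folded weight-prepack ops out of the converted graph, i.e. it ran `torch.ops.quantized.*_prepack` once ahead of time and stored the result as a module attribute (prefix `_fx_pass_packed_weight_`) instead of re-packing the weight on every forward. A minimal sketch of that idea, not part of this patch and assuming an FBGEMM-capable build:

```python
# Sketch only, not part of this patch: illustrates the weight prepack folding
# idea that the reverted commit added.  Assumes the FBGEMM quantized backend.
import torch

torch.backends.quantized.engine = 'fbgemm'

float_weight = torch.rand(4, 8)
q_weight = torch.quantize_per_tensor(float_weight, scale=0.1, zero_point=0, dtype=torch.qint8)

# Without folding, the converted graph keeps a linear_prepack call node that
# runs on every forward.  Folding executes it once ahead of time ...
packed_weight = torch.ops.quantized.linear_prepack(q_weight, None)

# ... and stores the result on the module (the pass used names like
# `_fx_pass_packed_weight_0`), so the graph only loads the attribute and
# calls the quantized op:
x = torch.quantize_per_tensor(torch.rand(2, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
out = torch.ops.quantized.linear(x, packed_weight, 0.2, 0)
```

With the revert applied, the converted graph keeps the explicit `*_prepack` call nodes.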
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index e039177..b3de593 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -43,33 +43,32 @@
import unittest
class TestQuantizeFx(QuantizationTestCase):
- def _get_conv_linear_test_cases(self):
- ''' Returns a list of test cases, with format:
- is_dynamic, ModuleClass, module_constructor_inputs,
- inputs, quantized_node, weight_prepack_op
- '''
+ """ Unit tests for functionalities
+ """
+ @skipIfNoFBGEMM
+ def test_functional(self):
+ """ Test quantizing functional conv and linear
+ """
class Conv(torch.nn.Module):
- def __init__(self, weight):
+ def __init__(self):
super().__init__()
- self.weight = torch.nn.Parameter(weight)
self.stride = (1, 1)
self.padding = (0, 0)
self.dilation = (1, 1)
self.groups = 1
- def forward(self, x):
- return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
+ def forward(self, x, weight):
+ return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)
conv_input = torch.rand(1, 3, 224, 224)
conv_weight = torch.rand(3, 3, 3, 3)
class Linear(torch.nn.Module):
- def __init__(self, weight):
+ def __init__(self):
super().__init__()
- self.weight = torch.nn.Parameter(weight)
- def forward(self, x):
- return F.linear(x, self.weight)
+ def forward(self, x, weight):
+ return F.linear(x, weight)
linear_input = torch.rand(8, 5)
linear_weight = torch.rand(10, 5)
@@ -85,62 +84,17 @@
linear_module_input = torch.rand(8, 5)
tests = [
- (False, Conv, (conv_weight,), (conv_input,),
- ns.call_function(torch.ops.quantized.conv2d),
- ns.call_function(torch.ops.quantized.conv2d_prepack)),
- (True, Linear, (linear_weight,), (linear_input,),
- ns.call_function(torch.ops.quantized.linear_dynamic),
- ns.call_function(torch.ops.quantized.linear_prepack)),
- (False, Linear, (linear_weight,), (linear_input,),
- ns.call_function(torch.ops.quantized.linear),
- ns.call_function(torch.ops.quantized.linear_prepack)),
- (True, LinearModule, (), (linear_module_input,),
- ns.call_module(nnqd.Linear),
- None),
- (False, LinearModule, (), (linear_module_input,),
- ns.call_module(nnq.Linear),
- None),
+ (False, Conv, (conv_input, conv_weight), ns.call_function(torch.ops.quantized.conv2d)),
+ (True, Linear, (linear_input, linear_weight), ns.call_function(torch.ops.quantized.linear_dynamic)),
+ (False, Linear, (linear_input, linear_weight), ns.call_function(torch.ops.quantized.linear)),
+ (True, LinearModule, (linear_module_input,), ns.call_module(nnqd.Linear)),
+ (False, LinearModule, (linear_module_input,), ns.call_module(nnq.Linear)),
]
- return tests
- """
- Unit tests for functionalities
- """
- @skipIfNoFBGEMM
- def test_functional_no_debug(self):
- """ Test quantizing functional conv and linear
- """
- tests = self._get_conv_linear_test_cases()
- for (is_dynamic, ModuleClass, module_constructor_inputs,
- inputs, quantized_node, weight_prepack_node) in tests:
+ for is_dynamic, M, inputs, quantized_node in tests:
quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
- node_occurrence = dict()
- if weight_prepack_node:
- node_occurrence[weight_prepack_node] = 0
self.checkGraphModeFxOp(
- ModuleClass(*module_constructor_inputs),
- inputs, quant_type,
- expected_node=quantized_node,
- expected_node_occurrence=node_occurrence,
- debug=False)
-
- @skipIfNoFBGEMM
- def test_functional_debug(self):
- """ Test quantizing functional conv and linear with debug option
- """
- tests = self._get_conv_linear_test_cases()
- for (is_dynamic, ModuleClass, module_constructor_inputs,
- inputs, quantized_node, weight_prepack_node) in tests:
- quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
- node_occurrence = dict()
- if weight_prepack_node:
- node_occurrence[weight_prepack_node] = 1
- self.checkGraphModeFxOp(
- ModuleClass(*module_constructor_inputs),
- inputs, quant_type,
- expected_node=quantized_node,
- expected_node_occurrence=node_occurrence,
- debug=True)
+ M(), inputs, quant_type, quantized_node)
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index cc042d4..f818d38 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -52,23 +52,6 @@
def quantize(quantizer, node):
quantize_node(node, quantizer.activation_post_process_map[node.name])
-# Returns a function that can get a new attribute name for module with given prefix
-# for example,
-# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
-# >> new_name = get_new_observer_name(module)
-# new_name will be an unused attribute name on module, e.g. `_observer_1`
-def get_new_attr_name_with_prefix(prefix):
- def get_new_attr_name(module):
- def get_attr_name(i):
- return prefix + str(i)
- i = 0
- attr_name = get_attr_name(i)
- while hasattr(module, attr_name):
- i += 1
- attr_name = get_attr_name(i)
- return attr_name
- return get_new_attr_name
-
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv2d : [1],
@@ -265,14 +248,14 @@
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
- prepack_args = tuple([weight] + list(other_args))
+ prepack_args = [weight] + list(other_args)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
# construct conv input
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
- qconv_args = (conv_input, packed_weight, scale, zero_point)
+ qconv_args = [conv_input, packed_weight, scale, zero_point]
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
@@ -335,16 +318,12 @@
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
- # TODO: this code can be merged with dynamic linear code
- # linear args
- # (x, weight, bias, ...)
args = load_arg(quantized=[0, 1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
- # all args after bias, including bias
- other_args = load_arg(quantized=False)(self.linear_node.args[2:])
+ other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
@@ -353,7 +332,7 @@
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
- prepack_args = (weight, bias)
+ prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct linear input
@@ -361,7 +340,7 @@
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
- qlinear_args = (linear_input, packed_weight, scale, zero_point)
+ qlinear_args = [linear_input, packed_weight, scale, zero_point]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
@@ -583,14 +562,13 @@
return quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
else:
- # linear args:
- # (x, weight, bias)
- # quantize weight
- quantized_weight = load_arg(quantized=True)(self.linear_node.args[1])
- bias = None
- # all args after bias, including bias
- other_args = load_arg(quantized=False)(self.linear_node.args[2:])
+ # quantize and dequantize weight
+ args = load_arg(quantized=[1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
+ # pack weight
+ weight = load_arg(quantized=True)(self.linear_node.args[1])
+ bias = None
+ other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
@@ -599,23 +577,15 @@
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
- prepack_args = (quantized_weight, bias)
- # pack weight
+ prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct dynamic linear input
- non_quantized_input = load_arg(quantized=False)(self.linear_node.args[0])
- qdynamic_linear_args = (non_quantized_input, packed_weight)
+ linear_input = load_arg(quantized=False)(self.linear_node.args[0])
+ qdynamic_linear_args = [linear_input, packed_weight]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)
-
-# weight prepacking ops
-WEIGHT_PREPACK_OPS = {
- torch._ops.ops.quantized.linear_prepack,
- torch._ops.ops.quantized.conv2d_prepack,
-}
-
class Quantizer:
def __init__(self):
# mapping from matched node to activation_post_process
@@ -687,7 +657,16 @@
if node.name in observed:
continue
- get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
+ def get_new_observer_name(parent_module):
+ i = 0
+
+ def get_observer_name(i):
+ return 'activation_post_process_' + str(i)
+ observer_name = get_observer_name(i)
+ while hasattr(parent_module, observer_name):
+ i += 1
+ observer_name = get_observer_name(i)
+ return observer_name
root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
if root_node is None:
env[node.name] = observed_graph.node_copy(node, load_arg)
@@ -766,7 +745,7 @@
def prepare_dynamic(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True)
- def _convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
+ def convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
assert not inplace, 'inplace convert is not supported yet'
self.restore_state(observed)
self.is_dynamic_quant = is_dynamic_quant
@@ -932,7 +911,7 @@
if parent_name:
qparam_full_path = parent_name + '.' + qparam_full_path
inputs.append(self.quantized_graph.create_node('get_param', qparam_full_path))
- quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, tuple(inputs), {})
+ quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, inputs, {})
continue
# dequantize inputs for the node that are not quantized
env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
@@ -947,84 +926,6 @@
delattr(observed_root, n)
return GraphModule(observed_root, self.quantized_graph)
- # Trace back from the weight node until we hit getattr, reconstruct the graph module
- # with the traced nodes and run the graph module to pack the weight. then replace
- # the original chain of ops with the packed weight.
- def _fold_weight(self, quantized):
- def collect_nodes_to_fold(node):
- nodes = [node]
- frontier = [node]
- while frontier:
- node = frontier.pop()
- all_args = list(node.args) + list(node.kwargs.values())
- for arg in all_args:
- if not isinstance(arg, Node):
- continue
- if arg.op == 'placeholder':
- # hit input, can't fold in this case
- return None
- nodes.append(arg)
- if not (arg.op == 'call_function' and arg.target == getattr):
- frontier.append(arg)
- return nodes
-
- packed_weights = dict()
- # map from folded node name to the prepacked weight name
- folded_nodes = dict()
- # get packed weights
- for node in quantized.graph.nodes:
- if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
- nodes_to_fold = collect_nodes_to_fold(node)
- if nodes_to_fold is not None:
- # since we traced back from weight node to getattr
- nodes_to_fold.reverse()
- prepacking_graph = Graph()
- env = {}
-
- def load_arg(a):
- return map_arg(a, lambda node: env[node.name])
- for node_to_fold in nodes_to_fold:
- env[node_to_fold.name] = prepacking_graph.node_copy(node_to_fold, load_arg)
- folded_nodes[node_to_fold.name] = node
- prepacking_graph.output(load_arg(node.name))
- prepacking_module = GraphModule(quantized.root, prepacking_graph)
- packed_weight = prepacking_module()
- packed_weights[node.name] = packed_weight
-
- # remove folded nodes and replace the prepacking node with getattr
- folded_graph = Graph()
- env = {}
-
- def load_arg(a):
- return map_arg(a, lambda node: env[node.name])
- get_new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
- quantized_root = quantized.root
- quantized_graph = quantized.graph
- for node in quantized_graph.nodes:
- prepack_node = folded_nodes.get(node.name, None)
- if prepack_node is node:
- packed_weight = packed_weights[node.name]
- # add a prepacked attribute to root
- packed_weight_name = get_new_packed_weight_name(quantized_root)
- setattr(quantized_root, packed_weight_name, packed_weight)
- # replace prepack node with a getattr node
- env[node.name] = folded_graph.create_node(
- 'get_param', packed_weight_name, (), {})
- elif prepack_node is not None:
- # remove the folded node
- continue
- else:
- # copy other nodes
- env[node.name] = folded_graph.node_copy(node, load_arg)
- folded_graph.output(load_arg(quantized_graph.result))
- return GraphModule(quantized_root, folded_graph)
-
- def convert(self, observed, inplace=False, debug=False, is_dynamic=False):
- quantized = self._convert(observed, inplace, debug, is_dynamic)
- if not debug:
- quantized = self._fold_weight(quantized)
- return quantized
-
def _find_matches(self, graph, modules, patterns):
match_map = {} # node name -> (root_node, match_value?)
all_matched = set()
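For reference, the `get_new_attr_name_with_prefix` helper removed from `torch/quantization/fx/quantize.py` above produced fresh, unused attribute names on a module. A standalone sketch of its behavior (the module used here is only illustrative):

```python
import torch

# Standalone copy of the removed helper: returns a function that finds an
# attribute name with the given prefix that is not yet used on the module.
def get_new_attr_name_with_prefix(prefix):
    def get_new_attr_name(module):
        i = 0
        attr_name = prefix + str(i)
        while hasattr(module, attr_name):
            i += 1
            attr_name = prefix + str(i)
        return attr_name
    return get_new_attr_name

m = torch.nn.Linear(2, 2)
get_new_observer_name = get_new_attr_name_with_prefix('_observer_')
assert get_new_observer_name(m) == '_observer_0'
m._observer_0 = torch.nn.Identity()
assert get_new_observer_name(m) == '_observer_1'
```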
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 3a20104..137ed38 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -79,9 +79,6 @@
return self.op == other.op and self.target == other.target
- def __repr__(self):
- return repr(self.op) + " " + repr(self.target)
-
def test_only_eval_fn(model, calib_data):
r"""
Default evaluation function takes a torch.utils.data.Dataset or a list of
@@ -548,21 +545,15 @@
if expected_node_occurrence is not None:
for expected_node, occurrence in expected_node_occurrence.items():
- if occurrence != 0:
- self.assertTrue(
- expected_node in nodes_in_graph,
- 'Check failed for node:' + str(expected_node) +
- ' not found')
- self.assertTrue(
- nodes_in_graph[expected_node] == occurrence,
- 'Check failed for node:' + str(expected_node) +
- ' Expected occurrence:' + str(occurrence) +
- ' Found occurrence:' + str(nodes_in_graph[expected_node]))
- else:
- self.assertTrue(
- expected_node not in nodes_in_graph,
- 'Check failed for node:' + str(expected_node) +
- ' expected no occurrence but found')
+ self.assertTrue(
+ expected_node in nodes_in_graph,
+ 'Check failed for node:' + str(expected_node) +
+ ' not found')
+ self.assertTrue(
+ nodes_in_graph[expected_node] == occurrence,
+ 'Check failed for node:' + str(expected_node) +
+ ' Expected occurrence:' + str(occurrence) +
+ ' Found occurrence:' + str(nodes_in_graph[expected_node]))
if expected_node_list is not None:
cur_index = 0
@@ -595,8 +586,7 @@
expected_node=None,
expected_node_occurrence=None,
expected_node_list=None,
- debug=False,
- print_debug_info=False):
+ debug=False):
""" Quantizes model with graph mode quantization on fx and check if the
quantized model contains the quantized_node
@@ -645,7 +635,7 @@
self.assertEqual((result - result_debug).abs().max(), 0), \
'Expecting debug and non-debug option to produce identical result'
- if print_debug_info:
+ if debug:
print()
print('quant type:', quant_type)
print('original graph module:', type(model))
@@ -654,9 +644,8 @@
print('quantized graph module:', type(qgraph))
self.printGraphModule(qgraph)
print()
- qgraph_to_check = qgraph_debug if debug else qgraph
self.checkGraphModuleNodes(
- qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
+ qgraph, expected_node, expected_node_occurrence, expected_node_list)
# Below are a series of neural net models to use in testing quantization
# Single layer models