Revert D23385090: [quant][graphmode][fx] Add support for weight prepack folding

Test Plan: revert-hammer

Differential Revision:
D23385090 (https://github.com/pytorch/pytorch/commit/ef08f92076806b9f17290876db7c5ea87c09873c)

Original commit changeset: 11341f0af525

fbshipit-source-id: fe2bcdc16106923a2cee99eb5cc0a1e9c14ad2c5
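
For reference, the change being reverted added a _fold_weight pass that traces each
quantized *_prepack call back through constant getattr nodes, runs that subgraph once
to produce a packed weight, stores the result as a module attribute, and replaces the
whole chain with a single get_param lookup (see the removed _fold_weight below). The
toy sketch below only illustrates that folding idea; it is not the reverted
implementation, and every name in it is hypothetical.

    # Minimal, self-contained sketch of constant-folding a "prepack" call whose
    # inputs are all module attributes. Hypothetical names; not PyTorch code.
    class ToyRoot:
        """Stand-in for the GraphModule root that holds attributes."""

    def prepack(weight, bias):
        # Stand-in for an op like torch.ops.quantized.linear_prepack.
        return ("packed", weight, bias)

    def fold_prepack(root, graph):
        """graph: list of (name, op, target, args); returns a folded copy."""
        folded, counter = [], 0
        for name, op, target, args in graph:
            if op == "call_function" and target is prepack and all(
                    a[0] == "get_attr" for a in args):
                # All inputs are constants: run the pack once, ahead of time ...
                packed = target(*(getattr(root, a[1]) for a in args))
                attr_name = "_packed_weight_%d" % counter
                counter += 1
                setattr(root, attr_name, packed)
                # ... and replace the call with a plain attribute fetch.
                folded.append((name, "get_attr", attr_name, ()))
            else:
                folded.append((name, op, target, args))
        return folded

    root = ToyRoot()
    root.w, root.b = [1.0, 2.0], 0.5
    graph = [("pw", "call_function", prepack,
              (("get_attr", "w"), ("get_attr", "b")))]
    print(fold_prepack(root, graph))  # [('pw', 'get_attr', '_packed_weight_0', ())]
    print(root._packed_weight_0)      # ('packed', [1.0, 2.0], 0.5)
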
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index e039177..b3de593 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -43,33 +43,32 @@
 import unittest
 
 class TestQuantizeFx(QuantizationTestCase):
-    def _get_conv_linear_test_cases(self):
-        ''' Returns a list of test cases, with format:
-        is_dynamic, ModuleClass, module_constructor_inputs,
-        inputs, quantized_node, weight_prepack_op
-        '''
+    """ Unit tests for functionalities
+    """
+    @skipIfNoFBGEMM
+    def test_functional(self):
+        """ Test quantizing functional conv and linear
+        """
         class Conv(torch.nn.Module):
-            def __init__(self, weight):
+            def __init__(self):
                 super().__init__()
-                self.weight = torch.nn.Parameter(weight)
                 self.stride = (1, 1)
                 self.padding = (0, 0)
                 self.dilation = (1, 1)
                 self.groups = 1
 
-            def forward(self, x):
-                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
+            def forward(self, x, weight):
+                return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)
 
         conv_input = torch.rand(1, 3, 224, 224)
         conv_weight = torch.rand(3, 3, 3, 3)
 
         class Linear(torch.nn.Module):
-            def __init__(self, weight):
+            def __init__(self):
                 super().__init__()
-                self.weight = torch.nn.Parameter(weight)
 
-            def forward(self, x):
-                return F.linear(x, self.weight)
+            def forward(self, x, weight):
+                return F.linear(x, weight)
 
         linear_input = torch.rand(8, 5)
         linear_weight = torch.rand(10, 5)
@@ -85,62 +84,17 @@
         linear_module_input = torch.rand(8, 5)
 
         tests = [
-            (False, Conv, (conv_weight,), (conv_input,),
-             ns.call_function(torch.ops.quantized.conv2d),
-             ns.call_function(torch.ops.quantized.conv2d_prepack)),
-            (True, Linear, (linear_weight,), (linear_input,),
-             ns.call_function(torch.ops.quantized.linear_dynamic),
-             ns.call_function(torch.ops.quantized.linear_prepack)),
-            (False, Linear, (linear_weight,), (linear_input,),
-             ns.call_function(torch.ops.quantized.linear),
-             ns.call_function(torch.ops.quantized.linear_prepack)),
-            (True, LinearModule, (), (linear_module_input,),
-             ns.call_module(nnqd.Linear),
-             None),
-            (False, LinearModule, (), (linear_module_input,),
-             ns.call_module(nnq.Linear),
-             None),
+            (False, Conv, (conv_input, conv_weight), ns.call_function(torch.ops.quantized.conv2d)),
+            (True, Linear, (linear_input, linear_weight), ns.call_function(torch.ops.quantized.linear_dynamic)),
+            (False, Linear, (linear_input, linear_weight), ns.call_function(torch.ops.quantized.linear)),
+            (True, LinearModule, (linear_module_input,), ns.call_module(nnqd.Linear)),
+            (False, LinearModule, (linear_module_input,), ns.call_module(nnq.Linear)),
         ]
-        return tests
 
-    """
-    Unit tests for functionalities
-    """
-    @skipIfNoFBGEMM
-    def test_functional_no_debug(self):
-        """ Test quantizing functional conv and linear
-        """
-        tests = self._get_conv_linear_test_cases()
-        for (is_dynamic, ModuleClass, module_constructor_inputs,
-             inputs, quantized_node, weight_prepack_node) in tests:
+        for is_dynamic, M, inputs, quantized_node in tests:
             quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
-            node_occurrence = dict()
-            if weight_prepack_node:
-                node_occurrence[weight_prepack_node] = 0
             self.checkGraphModeFxOp(
-                ModuleClass(*module_constructor_inputs),
-                inputs, quant_type,
-                expected_node=quantized_node,
-                expected_node_occurrence=node_occurrence,
-                debug=False)
-
-    @skipIfNoFBGEMM
-    def test_functional_debug(self):
-        """ Test quantizing functional conv and linear with debug option
-        """
-        tests = self._get_conv_linear_test_cases()
-        for (is_dynamic, ModuleClass, module_constructor_inputs,
-             inputs, quantized_node, weight_prepack_node) in tests:
-            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
-            node_occurrence = dict()
-            if weight_prepack_node:
-                node_occurrence[weight_prepack_node] = 1
-            self.checkGraphModeFxOp(
-                ModuleClass(*module_constructor_inputs),
-                inputs, quant_type,
-                expected_node=quantized_node,
-                expected_node_occurrence=node_occurrence,
-                debug=True)
+                M(), inputs, quant_type, quantized_node)
 
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index cc042d4..f818d38 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -52,23 +52,6 @@
 def quantize(quantizer, node):
     quantize_node(node, quantizer.activation_post_process_map[node.name])
 
-# Returns a function that can get a new attribute name for a module with the given prefix
-# for example,
-# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
-# >> new_name = get_new_observer_name(module)
-# new_name will be an unused attribute name on module, e.g. `_observer_1`
-def get_new_attr_name_with_prefix(prefix):
-    def get_new_attr_name(module):
-        def get_attr_name(i):
-            return prefix + str(i)
-        i = 0
-        attr_name = get_attr_name(i)
-        while hasattr(module, attr_name):
-            i += 1
-            attr_name = get_attr_name(i)
-        return attr_name
-    return get_new_attr_name
-
 # A dictionary for querying the weight index for a given op
 WEIGHT_INDEX_DICT = {
     torch.nn.functional.conv2d : [1],
@@ -265,14 +248,14 @@
                 # pack weight
                 weight = load_arg(quantized=True)(self.conv_node.args[1])
                 other_args = load_arg(quantized=False)(self.conv_node.args[2:])
-                prepack_args = tuple([weight] + list(other_args))
+                prepack_args = [weight] + list(other_args)
                 packed_weight = quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
                 # construct conv input
                 conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                 activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
                 scale, zero_point, _ = get_qparams(activation_post_process)
-                qconv_args = (conv_input, packed_weight, scale, zero_point)
+                qconv_args = [conv_input, packed_weight, scale, zero_point]
                 kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                 return quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
@@ -335,16 +318,12 @@
                     linear_out,
                     quantizer.activation_post_process_map[self.linear_node.name])
             else:
-                # TODO: this code can be merged with dynamic linear code
-                # linear args
-                # (x, weight, bias, ...)
                 args = load_arg(quantized=[0, 1])(self.linear_node.args)
                 kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                 # pack weight
                 weight = load_arg(quantized=True)(self.linear_node.args[1])
                 bias = None
-                # all args from bias onward (including bias)
-                other_args = load_arg(quantized=False)(self.linear_node.args[2:])
+                other_args = load_arg(quantized=False)(self.linear_node.args[1:])
                 if len(self.linear_node.args) > 2:
                     bias = load_arg(quantized=False)(self.linear_node.args[2])
                     other_args = other_args[1:]  # remove the bias argument
@@ -353,7 +332,7 @@
                         'expect bias provided as a keyword argument when it is not a positional argument'
                     bias = kwargs['bias']
                     kwargs.pop('bias')
-                prepack_args = (weight, bias)
+                prepack_args = [weight, bias]
                 packed_weight = quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                 # construct linear input
@@ -361,7 +340,7 @@
                 activation_post_process = \
                     quantizer.activation_post_process_map[self.linear_node.name]
                 scale, zero_point, _ = get_qparams(activation_post_process)
-                qlinear_args = (linear_input, packed_weight, scale, zero_point)
+                qlinear_args = [linear_input, packed_weight, scale, zero_point]
                 return quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
 
@@ -583,14 +562,13 @@
                 return quantizer.quantized_graph.create_node(
                     'call_function', torch.nn.functional.linear, args, kwargs)
             else:
-                # linear args:
-                # (x, weight, bias)
-                # quantize weight
-                quantized_weight = load_arg(quantized=True)(self.linear_node.args[1])
-                bias = None
-                # all args from bias onward (including bias)
-                other_args = load_arg(quantized=False)(self.linear_node.args[2:])
+                # quantize and dequantize weight
+                args = load_arg(quantized=[1])(self.linear_node.args)
                 kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
+                # pack weight
+                weight = load_arg(quantized=True)(self.linear_node.args[1])
+                bias = None
+                other_args = load_arg(quantized=False)(self.linear_node.args[1:])
                 if len(self.linear_node.args) > 2:
                     bias = load_arg(quantized=False)(self.linear_node.args[2])
                     other_args = other_args[1:]  # remove the bias argument
@@ -599,23 +577,15 @@
                         'expect bias provided as a keyword argument when it is not a positional argument'
                     bias = kwargs['bias']
                     kwargs.pop('bias')
-                prepack_args = (quantized_weight, bias)
-                # pack weight
+                prepack_args = [weight, bias]
                 packed_weight = quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                 # construct dynamic linear input
-                non_quantized_input = load_arg(quantized=False)(self.linear_node.args[0])
-                qdynamic_linear_args = (non_quantized_input, packed_weight)
+                linear_input = load_arg(quantized=False)(self.linear_node.args[0])
+                qdynamic_linear_args = [linear_input, packed_weight]
                 return quantizer.quantized_graph.create_node(
                     'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)
 
-
-# weight prepacking ops
-WEIGHT_PREPACK_OPS = {
-    torch._ops.ops.quantized.linear_prepack,
-    torch._ops.ops.quantized.conv2d_prepack,
-}
-
 class Quantizer:
     def __init__(self):
         # mapping from matched node to activation_post_process
@@ -687,7 +657,16 @@
             if node.name in observed:
                 continue
 
-            get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_')
+            def get_new_observer_name(parent_module):
+                i = 0
+
+                def get_observer_name(i):
+                    return 'activation_post_process_' + str(i)
+                observer_name = get_observer_name(i)
+                while hasattr(parent_module, observer_name):
+                    i += 1
+                    observer_name = get_observer_name(i)
+                return observer_name
             root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
             if root_node is None:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
@@ -766,7 +745,7 @@
     def prepare_dynamic(self, model, qconfig_dict, inplace=False):
         return self._prepare(model, qconfig_dict, inplace, is_dynamic_quant=True)
 
-    def _convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
+    def convert(self, observed, inplace=False, debug=False, is_dynamic_quant=False):
         assert not inplace, 'inplace convert is not supported yet'
         self.restore_state(observed)
         self.is_dynamic_quant = is_dynamic_quant
@@ -932,7 +911,7 @@
                         if parent_name:
                             qparam_full_path = parent_name + '.' + qparam_full_path
                         inputs.append(self.quantized_graph.create_node('get_param', qparam_full_path))
-                    quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, tuple(inputs), {})
+                    quant_env[node.name] = self.quantized_graph.create_node('call_function', quantize_op, inputs, {})
                     continue
             # dequantize inputs for the node that are not quantized
             env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
@@ -947,84 +926,6 @@
             delattr(observed_root, n)
         return GraphModule(observed_root, self.quantized_graph)
 
-    # Trace back from the weight node until we hit getattr, reconstruct the graph module
-    # with the traced nodes and run the graph module to pack the weight. Then replace
-    # the original chain of ops with the packed weight.
-    def _fold_weight(self, quantized):
-        def collect_nodes_to_fold(node):
-            nodes = [node]
-            frontier = [node]
-            while frontier:
-                node = frontier.pop()
-                all_args = list(node.args) + list(node.kwargs.values())
-                for arg in all_args:
-                    if not isinstance(arg, Node):
-                        continue
-                    if arg.op == 'placeholder':
-                        # hit input, can't fold in this case
-                        return None
-                    nodes.append(arg)
-                    if not (arg.op == 'call_function' and arg.target == getattr):
-                        frontier.append(arg)
-            return nodes
-
-        packed_weights = dict()
-        # map from folded node name to the prepacked weight name
-        folded_nodes = dict()
-        # get packed weights
-        for node in quantized.graph.nodes:
-            if node.op == 'call_function' and node.target in WEIGHT_PREPACK_OPS:
-                nodes_to_fold = collect_nodes_to_fold(node)
-                if nodes_to_fold is not None:
-                    # since we traced back from the weight node to getattr
-                    nodes_to_fold.reverse()
-                    prepacking_graph = Graph()
-                    env = {}
-
-                    def load_arg(a):
-                        return map_arg(a, lambda node: env[node.name])
-                    for node_to_fold in nodes_to_fold:
-                        env[node_to_fold.name] = prepacking_graph.node_copy(node_to_fold, load_arg)
-                        folded_nodes[node_to_fold.name] = node
-                    prepacking_graph.output(load_arg(node.name))
-                    prepacking_module = GraphModule(quantized.root, prepacking_graph)
-                    packed_weight = prepacking_module()
-                    packed_weights[node.name] = packed_weight
-
-        # remove folded nodes and replace the prepacking node with getattr
-        folded_graph = Graph()
-        env = {}
-
-        def load_arg(a):
-            return map_arg(a, lambda node: env[node.name])
-        get_new_packed_weight_name = get_new_attr_name_with_prefix('_fx_pass_packed_weight_')
-        quantized_root = quantized.root
-        quantized_graph = quantized.graph
-        for node in quantized_graph.nodes:
-            prepack_node = folded_nodes.get(node.name, None)
-            if prepack_node is node:
-                packed_weight = packed_weights[node.name]
-                # add a prepacked attribute to root
-                packed_weight_name = get_new_packed_weight_name(quantized_root)
-                setattr(quantized_root, packed_weight_name, packed_weight)
-                # replace prepack node with a getattr node
-                env[node.name] = folded_graph.create_node(
-                    'get_param', packed_weight_name, (), {})
-            elif prepack_node is not None:
-                # remove the folded node
-                continue
-            else:
-                # copy other nodes
-                env[node.name] = folded_graph.node_copy(node, load_arg)
-        folded_graph.output(load_arg(quantized_graph.result))
-        return GraphModule(quantized_root, folded_graph)
-
-    def convert(self, observed, inplace=False, debug=False, is_dynamic=False):
-        quantized = self._convert(observed, inplace, debug, is_dynamic)
-        if not debug:
-            quantized = self._fold_weight(quantized)
-        return quantized
-
     def _find_matches(self, graph, modules, patterns):
         match_map = {}  # node name -> (root_node, match_value?)
         all_matched = set()
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 3a20104..137ed38 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -79,9 +79,6 @@
 
         return self.op == other.op and self.target == other.target
 
-    def __repr__(self):
-        return repr(self.op) + " " + repr(self.target)
-
 def test_only_eval_fn(model, calib_data):
     r"""
     Default evaluation function takes a torch.utils.data.Dataset or a list of
@@ -548,21 +545,15 @@
 
         if expected_node_occurrence is not None:
             for expected_node, occurrence in expected_node_occurrence.items():
-                if occurrence != 0:
-                    self.assertTrue(
-                        expected_node in nodes_in_graph,
-                        'Check failed for node:' + str(expected_node) +
-                        ' not found')
-                    self.assertTrue(
-                        nodes_in_graph[expected_node] == occurrence,
-                        'Check failed for node:' + str(expected_node) +
-                        ' Expected occurrence:' + str(occurrence) +
-                        ' Found occurrence:' + str(nodes_in_graph[expected_node]))
-                else:
-                    self.assertTrue(
-                        expected_node not in nodes_in_graph,
-                        'Check failed for node:' + str(expected_node) +
-                        ' expected no occurrence but found')
+                self.assertTrue(
+                    expected_node in nodes_in_graph,
+                    'Check failed for node:' + str(expected_node) +
+                    ' not found')
+                self.assertTrue(
+                    nodes_in_graph[expected_node] == occurrence,
+                    'Check failed for node:' + str(expected_node) +
+                    ' Expected occurrence:' + str(occurrence) +
+                    ' Found occurrence:' + str(nodes_in_graph[expected_node]))
 
         if expected_node_list is not None:
             cur_index = 0
@@ -595,8 +586,7 @@
                            expected_node=None,
                            expected_node_occurrence=None,
                            expected_node_list=None,
-                           debug=False,
-                           print_debug_info=False):
+                           debug=False):
         """ Quantizes model with graph mode quantization on fx and check if the
         quantized model contains the quantized_node
 
@@ -645,7 +635,7 @@
         self.assertEqual((result - result_debug).abs().max(), 0,
             'Expecting debug and non-debug option to produce identical result')
 
-        if print_debug_info:
+        if debug:
             print()
             print('quant type:', quant_type)
             print('original graph module:', type(model))
@@ -654,9 +644,8 @@
             print('quantized graph module:', type(qgraph))
             self.printGraphModule(qgraph)
             print()
-        qgraph_to_check = qgraph_debug if debug else qgraph
         self.checkGraphModuleNodes(
-            qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list)
+            qgraph, expected_node, expected_node_occurrence, expected_node_list)
 
 # Below are a series of neural net models to use in testing quantization
 # Single layer models