[quant][fx] Remove Standalone and CustomModule QuantizeHandler type checks in prepare (#75202)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75202

Instead of checking the handler type, we use methods on QuantizeHandler (is_custom_module() and is_standalone_module()) to check whether a module is a custom or standalone module. This change is not user facing.
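
At a glance, the change replaces isinstance checks with flag-backed methods; a condensed sketch (names taken from the diff below, call sites simplified):

    # before: dispatch on the handler subclass
    is_standalone_module = isinstance(qhandler, StandaloneModuleQuantizeHandler)

    # after: construct the base handler with a flag and query it via a method
    qhandler = QuantizeHandler(node, modules, is_standalone_module=True)
    is_standalone_module = qhandler.is_standalone_module()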

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps

Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D35379641

fbshipit-source-id: c2f970c7e27f74793fa67f8fd5a16a43525e35aa
(cherry picked from commit 251500f06359c9046dd9067543cc80be24ddee33)
diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py
index 38fb6e8..a1217ec 100644
--- a/torch/ao/quantization/fx/match_utils.py
+++ b/torch/ao/quantization/fx/match_utils.py
@@ -7,8 +7,6 @@
 from .quantization_types import Pattern
 from .quantization_patterns import (
     QuantizeHandler,
-    CustomModuleQuantizeHandler,
-    StandaloneModuleQuantizeHandler,
 )
 from ..qconfig import (
     QConfigAny,
@@ -198,7 +196,7 @@
            type(modules[node.target]) in custom_module_classes:
             custom_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
-                node, node, None, CustomModuleQuantizeHandler(node, modules),
+                node, node, None, QuantizeHandler(node, modules, is_custom_module=True),
                 custom_module_qconfig)
 
     def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
@@ -214,10 +212,10 @@
            (is_standalone_module(node.target, modules) or
                 is_observed_standalone_module(modules[node.target])):
             # add node to matched nodes
-            custom_module_qconfig = qconfig_map[node.name]
+            standalone_module_qconfig = qconfig_map[node.name]
             match_map[node.name] = (
                 node, node, None,
-                StandaloneModuleQuantizeHandler(node, modules),
-                custom_module_qconfig)
+                QuantizeHandler(node, modules, is_standalone_module=True),
+                standalone_module_qconfig)
 
     return match_map
diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py
index 72fcdcd..d5fee7a 100644
--- a/torch/ao/quantization/fx/prepare.py
+++ b/torch/ao/quantization/fx/prepare.py
@@ -30,8 +30,6 @@
 
 from .quantization_patterns import (
     QuantizeHandler,
-    CustomModuleQuantizeHandler,
-    StandaloneModuleQuantizeHandler,
 )
 
 from .quantization_types import (
@@ -488,8 +486,7 @@
     # default (no observer)
     new_arg = arg
 
-    is_standalone_module = qhandler is not None and \
-        isinstance(qhandler, StandaloneModuleQuantizeHandler)
+    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
     assert qconfig is not None
     if not is_standalone_module:
         # regular flow for most nodes, except standalone modules
@@ -713,8 +710,7 @@
     assert qconfig is not None
     assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'
 
-    is_standalone_module = qhandler is not None and \
-        isinstance(qhandler, StandaloneModuleQuantizeHandler)
+    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
 
     dtype = node_name_to_target_dtype[node.name]["output_activation_dtype"]
     should_insert_observer = dtype not in DO_NOT_OBS_DTYPE_LIST + [torch.float]
@@ -1251,7 +1247,7 @@
                                 if not maybe_make_input_output_share_observers(node, model, modules):
                                     remove_output_observer(node, model, modules)
 
-                            if isinstance(qhandler, CustomModuleQuantizeHandler):
+                            if qhandler is not None and qhandler.is_custom_module():
                                 swap_custom_module_to_observed(node, qconfig, modules, prepare_custom_config_dict)
 
                 else:  # output
@@ -1294,7 +1290,7 @@
     ) in matches.items():
         if qhandler is None:
             continue
-        elif not isinstance(qhandler, StandaloneModuleQuantizeHandler):
+        elif not qhandler.is_standalone_module():
             continue
 
         sm_qconfig_dict, sm_prepare_config_dict, sm_backend_config_dict = \
diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py
index 5d81dcf..9114664 100644
--- a/torch/ao/quantization/fx/quantization_patterns.py
+++ b/torch/ao/quantization/fx/quantization_patterns.py
@@ -25,7 +25,6 @@
 import operator
 from typing import Any, Callable, Dict, Optional
 
-# this is temporary, will be removed soon
 def _default_root_node_getter(node_pattern):
     if node_pattern is None:
         return node_pattern
@@ -47,7 +46,9 @@
             self,
             node_pattern: NodePattern,
             modules: Dict[str, torch.nn.Module],
-            root_node_getter: Callable = None):
+            root_node_getter: Callable = None,
+            is_custom_module=False,
+            is_standalone_module=False):
         """ Records pattern information in __init__, which will be used
         in convert
         """
@@ -56,6 +57,8 @@
         if root_node_getter is None:
             root_node_getter = _default_root_node_getter
         self.root_node = root_node_getter(node_pattern)
+        self.is_custom_module_ = is_custom_module
+        self.is_standalone_module_ = is_standalone_module
         self.num_tensor_args = 0
         # determine how many of the first two args are Tensors (versus scalars)
         # this distinguishes things like "x + y" from "x + 2" or "2 + x"
@@ -107,6 +110,12 @@
         """
         return qconfig.activation
 
+    def is_custom_module(self):
+        return self.is_custom_module_
+
+    def is_standalone_module(self):
+        return self.is_standalone_module_
+
 @register_quant_pattern(operator.sub)
 @register_quant_pattern(operator.mul)
 @register_quant_pattern(operator.truediv)
@@ -261,9 +270,6 @@
     def is_general_tensor_value_op(self) -> bool:
         return True
 
-class CustomModuleQuantizeHandler(QuantizeHandler):
-    pass
-
 @register_quant_pattern(torch.nn.Identity)
 @register_quant_pattern(torch.transpose)
 @register_quant_pattern(torch.repeat_interleave)
@@ -298,8 +304,10 @@
     def is_general_tensor_value_op(self) -> bool:
         return True
 
+# TODO: not used, can be removed after torch.quantization namespace is deprecated
+class CustomModuleQuantizeHandler(QuantizeHandler):
+    pass
+
+# TODO: not used, can be removed after torch.quantization namespace is deprecated
 class StandaloneModuleQuantizeHandler(QuantizeHandler):
-    """ Converts an observed standalone module to quantized standalone module
-    by calling convert_fx on the observed standalone module.
-    """
     pass