dbr quant: extend qconfig_dict support to functions, part 1 (#69758)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69758

Extends DBR quant `qconfig_dict['object_type']` support to function types,
with the restriction that a parent module must have a qconfig.

A future PR will remove the restriction above (it exists only because of
some technical debt); the removal is split out to keep PR sizes small.
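
For illustration, a minimal usage sketch, condensed from the test added
in this PR (`_quantize_dbr` is the prototype define-by-run API touched
below; everything else is standard `torch.quantization`):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import _quantize_dbr

class M(nn.Module):
    def forward(self, x):
        x = x + x  # opted out via object_type below, stays as torch.add
        x = x * x  # covered by the global qconfig, gets quantized
        return x

# The global ('') entry gives the parent module a qconfig, which this
# PR still requires; the object_type entry opts torch.add out.
qconfig_dict = {
    '': torch.quantization.default_qconfig,
    'object_type': [(torch.add, None)],
}
example_args = (torch.randn(1, 1, 1, 1),)
mp = _quantize_dbr.prepare(M(), qconfig_dict, example_args)
mp(*example_args)  # calibrate
mq = _quantize_dbr.convert(mp)
```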

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR
```

Reviewed By: jerryzh168

Differential Revision: D33020217

Pulled By: vkuzo

fbshipit-source-id: ce8a8185f9c87d437e1319ff6f19e8f6adf41e02
diff --git a/test/quantization/dbr/test_quantize_dbr.py b/test/quantization/dbr/test_quantize_dbr.py
index 9ad3a95..3aea006 100644
--- a/test/quantization/dbr/test_quantize_dbr.py
+++ b/test/quantization/dbr/test_quantize_dbr.py
@@ -8,10 +8,12 @@
 import torch.nn.functional as F
 import torch.nn.intrinsic as nni
 import torch.nn.quantized as nnq
+toq = torch.ops.quantized
 from torch.testing._internal.common_quantization import (
     skipIfNoFBGEMM,
     skip_if_no_torchvision,
     QuantizationTestCase,
+    NodeSpec,
 )
 from torch.quantization import (
     ObserverBase,
@@ -1025,6 +1027,40 @@
         self.assertTrue(isinstance(mq[1], nn.Hardswish))
         self.assertTrue(isinstance(mq[2], nnq.Conv2d))
 
+    def test_qconfig_dict_object_type_function(self):
+        """
+        Verifies that the 'object_type' option of qconfig_dict works
+        on function types.
+        """
+        class M(nn.Module):
+            def forward(self, x):
+                x = x + x
+                x = x * x
+                return x
+
+        m = M()
+        # TODO(future PR): also implement global qconfig being None
+        # and individual functions having qconfigs
+        qconfig_dict = {
+            '': torch.quantization.default_qconfig,
+            'object_type': [
+                (torch.add, None),
+            ],
+        }
+        example_args = (torch.randn(1, 1, 1, 1),)
+        mp = _quantize_dbr.prepare(m, qconfig_dict, example_args)
+        mp(*example_args)
+        mq = _quantize_dbr.convert(mp)
+        mq(*example_args)
+        rewritten = mq.rewrite_for_scripting()
+        expected_occurrence = {
+            NodeSpec.call_function(torch.add): 1,
+            NodeSpec.call_function(toq.add): 0,
+            NodeSpec.call_function(toq.mul): 1,
+        }
+        self.checkGraphModuleNodes(
+            rewritten, expected_node_occurrence=expected_occurrence)
+
     def test_qconfig_dict_module_name(self):
         """
         Verifies that the 'module_name' option of qconfig_dict works
diff --git a/torch/ao/quantization/_dbr/qconfig_dict_utils.py b/torch/ao/quantization/_dbr/qconfig_dict_utils.py
new file mode 100644
index 0000000..68314a8
--- /dev/null
+++ b/torch/ao/quantization/_dbr/qconfig_dict_utils.py
@@ -0,0 +1,27 @@
+from typing import Dict, Any
+
+import torch
+
+TYPE_TO_REPLACEMENT_TYPE = {
+    torch.add: torch.Tensor.add,
+    torch.Tensor.add_: torch.Tensor.add,
+    torch.mul: torch.Tensor.mul,
+    torch.Tensor.mul_: torch.Tensor.mul,
+}
+
+def normalize_object_types(qconfig_dict: Dict[str, Any]) -> None:
+    """
+    This function looks for entries in `qconfig_dict['object_type']`
+    corresponding to PyTorch overrides of Python math functions
+    such as `torch.add` and `torch.mul`. If any of these functions are found,
+    it changes the type to the tensor variant of these functions.
+    This is needed because the tensor variant is what is expected
+    within the framework.
+    """
+    if 'object_type' not in qconfig_dict:
+        return
+
+    for idx, (target_type, qconfig) in enumerate(qconfig_dict['object_type']):
+        replacement_type = TYPE_TO_REPLACEMENT_TYPE.get(target_type, None)
+        if replacement_type is not None:
+            qconfig_dict['object_type'][idx] = (replacement_type, qconfig)
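
A quick sketch of the helper's effect, for reference (the mapping is the
`TYPE_TO_REPLACEMENT_TYPE` table above; normalization happens in place):

```python
import torch
from torch.ao.quantization._dbr.qconfig_dict_utils import normalize_object_types

qconfig_dict = {
    'object_type': [
        (torch.add, None),          # function variant
        (torch.Tensor.mul_, None),  # inplace variant
    ],
}
normalize_object_types(qconfig_dict)
# Both entries are rewritten to the Tensor method variants, which are
# what the framework matches against internally:
assert qconfig_dict['object_type'] == [
    (torch.Tensor.add, None),
    (torch.Tensor.mul, None),
]
```
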
diff --git a/torch/ao/quantization/_dbr/quantization_state.py b/torch/ao/quantization/_dbr/quantization_state.py
index 68917be..a1b14af 100644
--- a/torch/ao/quantization/_dbr/quantization_state.py
+++ b/torch/ao/quantization/_dbr/quantization_state.py
@@ -721,11 +721,12 @@
         if self.idx not in self.idx_to_seen_op_infos:
             op_type_is_module = isinstance(op, torch.nn.Module)
             op_type = type(op) if op_type_is_module else op
+            qconfig = get_cur_qconfig(self.qconfig_dict, fqn, op)
             self.idx_to_seen_op_infos[self.idx] = SeenOpInfo(
                 self.idx, op_type, op_type_is_module, fqn, arg_tensor_infos, [],
                 packable_tensor_idx_to_name, packable_nontensor_idx_to_arg,
                 packable_tensor_kwarg_name_to_name,
-                op_packing_only_uses_module_attributes)
+                op_packing_only_uses_module_attributes, qconfig)
 
         return args, kwargs
 
@@ -807,8 +808,12 @@
                 else:
                     dtype_to_use = torch.float
             else:
-                dtype_to_use = torch.quint8
-                # TODO(future PR): handle functions
+                qconfig = get_cur_qconfig(self.qconfig_dict, seen_op_info.fqn, op)
+                if qconfig is None:
+                    dtype_to_use = torch.float
+                else:
+                    dtype_to_use = qconfig.activation().dtype
+
         elif func_output_dtype_type == FuncOutputDTypeType.DTYPE_DEFAULT_BC_UNSUPPORTED_SYNTAX:
             dtype_to_use = torch.float
         else:
diff --git a/torch/ao/quantization/_dbr/utils.py b/torch/ao/quantization/_dbr/utils.py
index 11fb2dd..39e41e63 100644
--- a/torch/ao/quantization/_dbr/utils.py
+++ b/torch/ao/quantization/_dbr/utils.py
@@ -81,6 +81,8 @@
         # This is False if some packable args are results of other functions.
         # bool
         'op_packing_only_uses_module_attributes',
+        # QConfig for the op, can be None
+        'qconfig',
     ],
 )
 def seen_op_info_repr(self) -> str:
@@ -189,6 +191,9 @@
     if is_module:
         return FuncOutputObsType.NONE
 
+    if seen_op_info.qconfig is None:
+        return FuncOutputObsType.NONE
+
     # check for ops which need packed weights but the weights are
     # coming from another function
     if not seen_op_info.op_packing_only_uses_module_attributes:
@@ -221,6 +226,8 @@
     is_module = isinstance(op_type, type(torch.nn.Module))
     if is_module:
         return False
+    if seen_op_info.qconfig is None:
+        return False
     if op_type in add_and_mul_ops:
         # check if both arguments are tensors
         inputs = seen_op_info.input_tensor_infos
diff --git a/torch/ao/quantization/_quantize_dbr.py b/torch/ao/quantization/_quantize_dbr.py
index 983ced7..6e96427 100644
--- a/torch/ao/quantization/_quantize_dbr.py
+++ b/torch/ao/quantization/_quantize_dbr.py
@@ -2,6 +2,7 @@
 
 from ._dbr.auto_trace import add_auto_observation, add_auto_convert
 from ._dbr.fusion import get_module_fusion_fqns
+from ._dbr.qconfig_dict_utils import normalize_object_types
 
 from .qconfig_dict_utils import (
     get_flattened_qconfig_dict,
@@ -30,8 +31,8 @@
         assert qconfig_dict_option not in qconfig_dict, \
             f'{qconfig_dict_option} option of qconfig_dict is not ' + \
             'implemented yet in define-by-run quantization'
-    # TODO: assert for object_type + function non-existence
 
+    normalize_object_types(qconfig_dict)
     convert_dict_to_ordered_dict(qconfig_dict)
     flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_dict)
     torch.quantization.propagate_qconfig_(model, flattened_qconfig_dict)