dbr quant: refactor `get_func_output_obs_type` to only use `seen_op_info` (#68341)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68341

Before this PR, `get_func_output_obs_type` used information from the
incoming op and its arguments, which makes it hard to cache.

This PR refactors `get_func_output_obs_type` to only use information
collected during tracing. This will make it easier to make performance
improvements in a future PR.

Test Plan:
```
python test/test_quantization.py TestQuantizeDBR
```

Reviewed By: jerryzh168

Differential Revision: D32463755

Pulled By: vkuzo

fbshipit-source-id: 25a220de652f0285685d43aedf7392082104b26c
diff --git a/test/quantization/dbr/test_quantize_dbr.py b/test/quantization/dbr/test_quantize_dbr.py
index bf107fe..6def349 100644
--- a/test/quantization/dbr/test_quantize_dbr.py
+++ b/test/quantization/dbr/test_quantize_dbr.py
@@ -548,6 +548,20 @@
         qconfig = torch.quantization.default_qconfig
         self._test_auto_tracing(model_fp32, qconfig, (torch.randn(1, 1, 2, 2),))
 
+    def test_add_int32(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                x = x + x
+                return x
+
+        model_fp32 = M().eval()
+        qconfig = torch.quantization.default_qconfig
+        self._test_auto_tracing(
+            model_fp32, qconfig, (torch.ones(1, 1, 2, 2, dtype=torch.int32),),
+            # FX graph mode quantization does not automatically detect
+            # tensor inputs in non-float dtypes.
+            do_fx_comparison=False)
+
     def test_module_then_add(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/ao/quantization/_dbr/quantization_state.py b/torch/ao/quantization/_dbr/quantization_state.py
index 550f50e..7579a9a 100644
--- a/torch/ao/quantization/_dbr/quantization_state.py
+++ b/torch/ao/quantization/_dbr/quantization_state.py
@@ -277,8 +277,7 @@
         """
         assert self.cur_op_needs_hooks(op)
         seen_op_info = self._get_cur_seen_op_info()
-        func_output_obs_type = get_func_output_obs_type(
-            op, args, seen_op_info.op_packing_only_uses_module_attributes)
+        func_output_obs_type = get_func_output_obs_type(seen_op_info)
         if first_call:
             self._first_call_op_prepare_after_hook_adjust_subgraphs(
                 op, output, args, first_call, qtensor_id, root_module,
diff --git a/torch/ao/quantization/_dbr/utils.py b/torch/ao/quantization/_dbr/utils.py
index ac174ad..27a6f85 100644
--- a/torch/ao/quantization/_dbr/utils.py
+++ b/torch/ao/quantization/_dbr/utils.py
@@ -274,31 +274,37 @@
     REUSES_FIRST_INPUT_OBS = 2
 
 def get_func_output_obs_type(
-    op: Callable,
-    args: Tuple[Any, ...],
-    op_packing_only_uses_module_attributes: bool,
+    seen_op_info: SeenOpInfo,
 ) -> FuncOutputObsType:
-    if isinstance(op, torch.nn.Module):
+    op_type = seen_op_info.type
+    is_module = isinstance(op_type, type(torch.nn.Module))
+    if is_module:
         return FuncOutputObsType.NONE
 
     # check for ops which need packed weights but the weights are
     # coming from another function
-    if not op_packing_only_uses_module_attributes:
+    if not seen_op_info.op_packing_only_uses_module_attributes:
         return FuncOutputObsType.NONE
 
-    if op in add_and_mul_ops:
-        if len(args) > 0 and args[0].dtype in (torch.int32, torch.int64):
+    if op_type in add_and_mul_ops:
+        if (
+            len(seen_op_info.input_tensor_infos) > 0 and
+            seen_op_info.input_tensor_infos[0].inf_dtype in (torch.int32, torch.int64)
+        ):
             # this is handling ops on dtypes such as torch.int
             return FuncOutputObsType.NONE
         elif (
-            len(args) > 1 and
-            (not isinstance(args[1], torch.Tensor))
+            len(seen_op_info.input_tensor_infos) > 1 and
+            seen_op_info.input_tensor_infos[1] is None
         ):
             return FuncOutputObsType.REUSES_FIRST_INPUT_OBS
-    elif op in (torch.relu, F.relu):
+    elif op_type in (torch.relu, F.relu):
         return FuncOutputObsType.NONE
-    elif op == torch.cat:
-        if len(args[0]) > 0 and args[0][0].dtype in (torch.int32, torch.int64):
+    elif op_type == torch.cat:
+        if (
+            len(seen_op_info.input_tensor_infos) > 0 and
+            seen_op_info.input_tensor_infos[0].inf_dtype in (torch.int32, torch.int64)
+        ):
             return FuncOutputObsType.NONE
     return FuncOutputObsType.NEW_OBS