dbr quant: refactor `get_func_output_obs_type` to only use `seen_op_info` (#68341)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68341
Before this PR, `get_func_output_obs_type` used information from the
incoming op and its arguments, which made its result hard to cache.
This PR refactors `get_func_output_obs_type` to use only information
collected during tracing, which will make it easier to implement
performance improvements, such as caching, in a future PR.
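For context, here is a rough sketch of the kind of caching this refactor unlocks. It is illustrative only, not part of this PR: the `_obs_type_cache` dict and `get_func_output_obs_type_cached` helper are hypothetical, and it assumes each `SeenOpInfo` carries a stable per-op `idx` assigned at trace time.
```
from typing import Dict

from torch.ao.quantization._dbr.utils import (
    FuncOutputObsType,
    SeenOpInfo,
    get_func_output_obs_type,
)

# hypothetical cache, keyed by the trace-time index of each seen op
_obs_type_cache: Dict[int, FuncOutputObsType] = {}

def get_func_output_obs_type_cached(
    seen_op_info: SeenOpInfo,
) -> FuncOutputObsType:
    # After this PR the result is a pure function of the trace-time
    # record, so it can be computed once per seen op and then reused.
    key = seen_op_info.idx  # assumes SeenOpInfo has a stable `idx` field
    if key not in _obs_type_cache:
        _obs_type_cache[key] = get_func_output_obs_type(seen_op_info)
    return _obs_type_cache[key]
```
Before this change, the live `op` and `args` objects would have had to be part of the cache key, which is impractical for tensor arguments.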
Test Plan:
```
python test/test_quantization.py TestQuantizeDBR
```
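To target only the test added in this PR, standard unittest selection works (assuming the test name lands as written in the diff below):
```
python test/test_quantization.py TestQuantizeDBR.test_add_int32
```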
Reviewed By: jerryzh168
Differential Revision: D32463755
Pulled By: vkuzo
fbshipit-source-id: 25a220de652f0285685d43aedf7392082104b26c
diff --git a/test/quantization/dbr/test_quantize_dbr.py b/test/quantization/dbr/test_quantize_dbr.py
index bf107fe..6def349 100644
--- a/test/quantization/dbr/test_quantize_dbr.py
+++ b/test/quantization/dbr/test_quantize_dbr.py
@@ -548,6 +548,20 @@
         qconfig = torch.quantization.default_qconfig
         self._test_auto_tracing(model_fp32, qconfig, (torch.randn(1, 1, 2, 2),))
 
+    def test_add_int32(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                x = x + x
+                return x
+
+        model_fp32 = M().eval()
+        qconfig = torch.quantization.default_qconfig
+        self._test_auto_tracing(
+            model_fp32, qconfig, (torch.ones(1, 1, 2, 2, dtype=torch.int32),),
+            # FX graph mode quantization does not automatically detect
+            # tensor inputs in non-float dtypes.
+            do_fx_comparison=False)
+
     def test_module_then_add(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/ao/quantization/_dbr/quantization_state.py b/torch/ao/quantization/_dbr/quantization_state.py
index 550f50e..7579a9a 100644
--- a/torch/ao/quantization/_dbr/quantization_state.py
+++ b/torch/ao/quantization/_dbr/quantization_state.py
@@ -277,8 +277,7 @@
"""
assert self.cur_op_needs_hooks(op)
seen_op_info = self._get_cur_seen_op_info()
- func_output_obs_type = get_func_output_obs_type(
- op, args, seen_op_info.op_packing_only_uses_module_attributes)
+ func_output_obs_type = get_func_output_obs_type(seen_op_info)
if first_call:
self._first_call_op_prepare_after_hook_adjust_subgraphs(
op, output, args, first_call, qtensor_id, root_module,
diff --git a/torch/ao/quantization/_dbr/utils.py b/torch/ao/quantization/_dbr/utils.py
index ac174ad..27a6f85 100644
--- a/torch/ao/quantization/_dbr/utils.py
+++ b/torch/ao/quantization/_dbr/utils.py
@@ -274,31 +274,37 @@
     REUSES_FIRST_INPUT_OBS = 2
 
 def get_func_output_obs_type(
-    op: Callable,
-    args: Tuple[Any, ...],
-    op_packing_only_uses_module_attributes: bool,
+    seen_op_info: SeenOpInfo,
 ) -> FuncOutputObsType:
-    if isinstance(op, torch.nn.Module):
+    op_type = seen_op_info.type
+    is_module = isinstance(op_type, type(torch.nn.Module))
+    if is_module:
         return FuncOutputObsType.NONE
     # check for ops which need packed weights but the weights are
     # coming from another function
-    if not op_packing_only_uses_module_attributes:
+    if not seen_op_info.op_packing_only_uses_module_attributes:
         return FuncOutputObsType.NONE
-    if op in add_and_mul_ops:
-        if len(args) > 0 and args[0].dtype in (torch.int32, torch.int64):
+    if op_type in add_and_mul_ops:
+        if (
+            len(seen_op_info.input_tensor_infos) > 0 and
+            seen_op_info.input_tensor_infos[0].inf_dtype in (torch.int32, torch.int64)
+        ):
             # this is handling ops on dtypes such as torch.int
             return FuncOutputObsType.NONE
         elif (
-            len(args) > 1 and
-            (not isinstance(args[1], torch.Tensor))
+            len(seen_op_info.input_tensor_infos) > 1 and
+            seen_op_info.input_tensor_infos[1] is None
         ):
             return FuncOutputObsType.REUSES_FIRST_INPUT_OBS
-    elif op in (torch.relu, F.relu):
+    elif op_type in (torch.relu, F.relu):
         return FuncOutputObsType.NONE
-    elif op == torch.cat:
-        if len(args[0]) > 0 and args[0][0].dtype in (torch.int32, torch.int64):
+    elif op_type == torch.cat:
+        if (
+            len(seen_op_info.input_tensor_infos) > 0 and
+            seen_op_info.input_tensor_infos[0].inf_dtype in (torch.int32, torch.int64)
+        ):
             return FuncOutputObsType.NONE
     return FuncOutputObsType.NEW_OBS