dbr quant: make fqn during prepare op hook required (#69726)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69726
This is a cleanup, this variable was previously optional
but it always exists, because the only way an op hook
can run if there is a parent module with an `AutoQuantizationState`
object.
Test Plan:
```
python test/test_quantization.py TestQuantizeDBR
```
Reviewed By: albanD
Differential Revision: D33003472
Pulled By: vkuzo
fbshipit-source-id: de5769194808d42b025b848667815b4e3d73b6c6
diff --git a/torch/ao/quantization/_dbr/auto_trace.py b/torch/ao/quantization/_dbr/auto_trace.py
index 5ec2dba..a7ff043 100644
--- a/torch/ao/quantization/_dbr/auto_trace.py
+++ b/torch/ao/quantization/_dbr/auto_trace.py
@@ -205,9 +205,11 @@
global_disable_torch_function_override
global_disable_torch_function_override = True
+ # mypy ignore is used instead of assert because this
+ # runs on every forward and assert has a performance cost
args, kwargs = parent_qstate.op_prepare_before_hook(
cur_module, args, kwargs, first_call, qtensor_id,
- fqn, cur_module)
+ fqn, cur_module) # type: ignore[arg-type]
# original forward
output = orig_module_call(self, *args, **kwargs)
diff --git a/torch/ao/quantization/_dbr/quantization_state.py b/torch/ao/quantization/_dbr/quantization_state.py
index fa07d68..68917be 100644
--- a/torch/ao/quantization/_dbr/quantization_state.py
+++ b/torch/ao/quantization/_dbr/quantization_state.py
@@ -247,7 +247,7 @@
kwargs: Dict[str, Any],
first_call: bool,
qtensor_id: List[int],
- fqn: Optional[str],
+ fqn: str,
root_module: torch.nn.Module,
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""
@@ -613,8 +613,7 @@
arg_tensor_infos: List[Optional[QTensorInfo]],
func_output_dtype_type: FuncOutputDTypeType,
qtensor_id: List[int],
- # TODO(future PR): figure out if we can make fqn required
- fqn: Optional[str],
+ fqn: str,
) -> None:
"""
Runs the prepare hook during first_call for individual
@@ -644,8 +643,7 @@
# which will be converted to a quant later
# TODO(future PR): share these observers if multiple ops need
# this quant.
- # TODO(future PR): make fqn required and remove the mypy ignore
- qconfig = get_cur_qconfig(self.qconfig_dict, fqn, op) # type: ignore[arg-type]
+ qconfig = get_cur_qconfig(self.qconfig_dict, fqn, op)
if qconfig is None:
# If qconfig is None, we do not need any input observers
return
@@ -665,7 +663,7 @@
kwargs: Dict[str, Any],
first_call: bool,
qtensor_id: List[int],
- fqn: Optional[str],
+ fqn: str,
root_module: torch.nn.Module,
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""