dbr quant: stop overridding tensor getters (#70115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70115
This PR turns off DBR quant __torch_function__ overrides on
tensor attribute getters such as `x.dtype`. This should help
with making the debug logs more readable, and reduce framework
overhead.
Test Plan:
```
python test/test_quantization.py TestQuantizeDBR
```
Reviewed By: ejguan
Differential Revision: D33189544
Pulled By: vkuzo
fbshipit-source-id: e0d664bb6b76ca9e71c8a439ae985a0849312862
diff --git a/torch/ao/quantization/_dbr/auto_trace.py b/torch/ao/quantization/_dbr/auto_trace.py
index 01cc893..1c1a6a2 100644
--- a/torch/ao/quantization/_dbr/auto_trace.py
+++ b/torch/ao/quantization/_dbr/auto_trace.py
@@ -85,11 +85,14 @@
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
nonlocal global_disable_torch_function_override
- if global_disable_torch_function_override:
- return super().__torch_function__(func, types, args, kwargs)
-
- if func == torch.Tensor.__repr__:
+ if (
+ # global override means disable the override here
+ global_disable_torch_function_override or
# to prevent printing things from going into an infinite loop
+ func == torch.Tensor.__repr__ or
+ # we don't need to override getters in this framework
+ func.__name__ == '__get__'
+ ):
return super().__torch_function__(func, types, args, kwargs)
# if we are in a function, the current module is always a parent
@@ -351,11 +354,14 @@
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
nonlocal global_disable_torch_function_override
- if global_disable_torch_function_override:
- return super().__torch_function__(func, types, args, kwargs)
-
- # to prevent printing things from going into an infinite loop
- if func == torch.Tensor.__repr__:
+ if (
+ # global override means disable the override here
+ global_disable_torch_function_override or
+ # to prevent printing things from going into an infinite loop
+ func == torch.Tensor.__repr__ or
+ # we don't need to override getters in this framework
+ func.__name__ == '__get__'
+ ):
return super().__torch_function__(func, types, args, kwargs)
kwargs = kwargs if kwargs else {}
@@ -364,14 +370,13 @@
hook_type = get_torch_function_hook_type(parent_module, func)
if enable_logging:
- with torch._C.DisableTorchFunction():
- fqn_for_logging = module_id_to_fqn.get(
- id(parent_module), 'unknown') if parent_module else None
- logger.debug(
- f" fqn:{fqn_for_logging} _tf_ {func} " +
- f"hook_type {hook_type} " +
- # f"arg_types {[type(arg) for arg in args]}) " +
- f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}")
+ fqn_for_logging = module_id_to_fqn.get(
+ id(parent_module), 'unknown') if parent_module else None
+ logger.debug(
+ f" fqn:{fqn_for_logging} _tf_ {func} " +
+ f"hook_type {hook_type} " +
+ # f"arg_types {[type(arg) for arg in args]}) " +
+ f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]}")
if hook_type is HookType.OP_HOOKS:
qstate: AutoQuantizationState = parent_module._auto_quant_state # type: ignore[union-attr]
@@ -477,13 +482,12 @@
module_stack.append(self)
hook_type = get_module_hook_type(parent_module, cur_module)
if enable_logging:
- with torch._C.DisableTorchFunction():
- fqn_for_logging = module_id_to_fqn.get(id(self), None)
- logger.debug(
- f" fqn: {fqn_for_logging} " +
- f"_cl_ {type(self)} " +
- f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
- f"hook_type {hook_type}")
+ fqn_for_logging = module_id_to_fqn.get(id(self), None)
+ logger.debug(
+ f" fqn: {fqn_for_logging} " +
+ f"_cl_ {type(self)} " +
+ f"arg_dtypes {[arg.dtype if isinstance(arg, torch.Tensor) else None for arg in args]} " +
+ f"hook_type {hook_type}")
if hook_type is HookType.OP_HOOKS:
# before hooks
@@ -555,13 +559,12 @@
output = orig_module_call(self, *args, **kwargs)
if enable_logging:
- with torch._C.DisableTorchFunction():
- fqn_for_logging = module_id_to_fqn.get(id(self), None)
- logger.debug(
- f" fqn: {fqn_for_logging} " +
- f"_cl_ {type(self)} " +
- f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
- "end")
+ fqn_for_logging = module_id_to_fqn.get(id(self), None)
+ logger.debug(
+ f" fqn: {fqn_for_logging} " +
+ f"_cl_ {type(self)} " +
+ f"dtype {output.dtype if isinstance(output, torch.Tensor) else None} " +
+ "end")
return output
finally:
module_stack.pop()