Fixed increasing CPU overhead of `RemovableHandle.__init__` (#122847)
For some reason, if we construct `class Handle(RemovableHandle)` inside `register_multi_grad_hook`, then over time the call to `RemovableHandle.__init__` slows down more and more (when GC is disabled). This may be related to the class attribute `next_id: int = 0`. Python experts: please let me know if you have thoughts 😅
I am open to any suggestions on how we should deal with this `Handle` class. For now, I changed it to a private `_MultiHandle`.
<details>
<summary> Experiment Script </summary>
```python
import gc
import time

import torch

NUM_TENSORS = int(5e4)
ts = [torch.empty(1, requires_grad=True) for _ in range(NUM_TENSORS)]

def hook(grad) -> None:
    return

gc.disable()

times = []
for t in ts:
    start_time = time.time()
    torch.autograd.graph.register_multi_grad_hook([t], hook)
    end_time = time.time()
    times.append(end_time - start_time)
print([f"{t * 1e6:.3f} us" for t in times[1:6]])  # print first few times
print([f"{t * 1e6:.3f} us" for t in times[-5:]])  # print last few times

times = []
for t in ts:
    start_time = time.time()
    t.register_hook(hook)
    end_time = time.time()
    times.append(end_time - start_time)
print([f"{t * 1e6:.3f} us" for t in times[1:6]])  # print first few times
print([f"{t * 1e6:.3f} us" for t in times[-5:]])  # print last few times
```
</details>
<details>
<summary> Results </summary>
Before fix:
```
['23.603 us', '19.550 us', '15.497 us', '12.875 us', '13.828 us']
['327.110 us', '341.177 us', '329.733 us', '332.832 us', '341.177 us']
['318.050 us', '315.189 us', '319.719 us', '311.613 us', '308.990 us']
['374.317 us', '394.821 us', '350.714 us', '337.362 us', '331.402 us']
```
Calling `register_multi_grad_hook` slows down subsequent calls to both itself and `register_hook` (in fact, any call to `RemovableHandle.__init__`).
After fix:
```
['13.590 us', '9.060 us', '12.875 us', '7.153 us', '8.583 us']
['4.530 us', '5.245 us', '6.437 us', '4.768 us', '5.007 us']
['2.623 us', '1.907 us', '1.431 us', '1.669 us', '1.192 us']
['1.431 us', '1.431 us', '1.192 us', '1.192 us', '1.431 us']
```
</details>
Update: from @soulitzer
> Your suspicion about next_id is right. I think what is happening is that whenever a class attribute is set, it needs to invalidate some cached data for the subclasses one-by-one. https://github.com/python/cpython/blob/eefff682f09394fe4f18b7d7c6ac4c635caadd02/Objects/typeobject.c#L845
> And this PR fixes the issue by avoiding creating many subclasses dynamically. Changing next_id to something like List[int] or incrementing a global instead also fixes this.
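This cache-invalidation cost can be reproduced outside of autograd. The sketch below is a minimal, self-contained illustration (the `Base`/`Fresh` class names are made up, not from the PR): it keeps many dynamically created subclasses alive and then times a class-attribute assignment on the base class, which is what `RemovableHandle.next_id += 1` does on every `__init__`:

```python
import time

class Base:
    next_id: int = 0

# Keep many dynamically created subclasses alive, mimicking the old
# per-call `class Handle(RemovableHandle)` definition (with GC disabled,
# those classes were never collected).
subclasses = [type(f"Sub{i}", (Base,), {}) for i in range(50_000)]

# Writing a class attribute goes through type.__setattr__, which walks the
# live subclasses to invalidate their cached type data.
start = time.time()
Base.next_id += 1
slow = time.time() - start

class Fresh:
    next_id: int = 0

# Same assignment on a class with no subclasses stays cheap.
start = time.time()
Fresh.next_id += 1
fast = time.time() - start

print(f"with 50k subclasses: {slow * 1e6:.1f} us")
print(f"with no subclasses:  {fast * 1e6:.1f} us")
```

Per the suggestion above, mutating module-level state instead (e.g. a module-scope `itertools.count()`) would also avoid the problem, since `__init__` would never write to the type object at all.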
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122847
Approved by: https://github.com/soulitzer
ghstack dependencies: #122726
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 2aca5b3..19938c1 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -382,6 +382,23 @@
         torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
+class _MultiHandle(RemovableHandle):
+    handles: Tuple[RemovableHandle, ...]
+
+    def __init__(self, handles: Tuple[RemovableHandle, ...]):
+        self.handles = handles
+
+    def remove(self):
+        for handle in self.handles:
+            handle.remove()
+
+    def __getstate__(self):
+        return self.handles
+
+    def __setstate__(self, state):
+        self.handles = state
+
+
def register_multi_grad_hook(
tensors: Sequence[torch.Tensor],
fn: Union[
@@ -442,22 +459,6 @@
     if mode not in supported_modes:
         raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")
-    class Handle(RemovableHandle):
-        handles: Tuple[RemovableHandle, ...]
-
-        def __init__(self, handles: Tuple[RemovableHandle, ...]):
-            self.handles = handles
-
-        def remove(self):
-            for handle in self.handles:
-                handle.remove()
-
-        def __getstate__(self):
-            return self.handles
-
-        def __setstate__(self, state):
-            self.handles = state
-
     if mode == "all":
         count: Dict[int, int] = dict()
         nb_calls = None
@@ -516,7 +517,7 @@
         if tensor.requires_grad
     )
-    return Handle(handles)  # type: ignore[possibly-undefined]
+    return _MultiHandle(handles)  # type: ignore[possibly-undefined]
 # NOTE [Allow mutation on tensors saved for backward]