Multihooks should not keep tensor alive in closure (#102859)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102859
Approved by: https://github.com/albanD
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index a0530c4..dc657e1 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -366,6 +366,7 @@
return t.grad_fn
grad_fns = list(map(get_grad_fn, tensors))
+ len_tensors = len(tensors)
def get_inner_hook(idx):
def inner_hook(grad: torch.Tensor):
@@ -373,7 +374,7 @@
id = torch._C._current_graph_task_id()
assert id != -1, "expected this hook to be called inside a backward call"
count[id] = count.get(id, 0)
- buffer[id] = buffer.get(id, [None] * len(tensors))
+ buffer[id] = buffer.get(id, [None] * len_tensors)
if count[id] == 0:
# On the first call, compute the actual nb_calls and buffer