commit | dd9ade6377ef480ad0c60d4dce2f9cfbfeee98e9 | [log] [tgz] |
---|---|---|
author | Aaron Gokaslan <aaronGokaslan@gmail.com> | Fri Mar 17 21:34:10 2023 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Fri Mar 17 21:34:14 2023 +0000 |
tree | cbeee588090ae39246a294e4e58100cddadec5ab | |
parent | 98a5cf090d6b460afd1b43e47cd52404e8e07c91 [diff] |
Remove unnecessary items() call in zero_grad (#97040) Micro-optimization to zero_grad() which is performance critical Pull Request resolved: https://github.com/pytorch/pytorch/pull/97040 Approved by: https://github.com/ezyang, https://github.com/albanD
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 27f9492..99954c2 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py
@@ -469,7 +469,7 @@ else: per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad) if foreach: - for _, per_dtype_grads in per_device_and_dtype_grads.items(): + for per_dtype_grads in per_device_and_dtype_grads.values(): for grads in per_dtype_grads.values(): torch._foreach_zero_(grads)