Reduce memory usage for centered RMSprop (#24170)
Summary:
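Reduce GPU memory usage in the centered RMSprop update by using an in-place operation: the temporary tensor returned by `addcmul` is now square-rooted with `sqrt_()` instead of `sqrt()`, so no additional tensor is allocated for the result.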
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24170
Differential Revision: D16784495
Pulled By: vincentqb
fbshipit-source-id: 03820cdc9a3952b95b9af0f87d3a9bb0f21e9b4d
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index cbb28f7..caa23af 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -88,7 +88,7 @@
if group['centered']:
grad_avg = state['grad_avg']
grad_avg.mul_(alpha).add_(1 - alpha, grad)
- avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt().add_(group['eps'])
+ avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt_().add_(group['eps'])
else:
avg = square_avg.sqrt().add_(group['eps'])