[optim] rmsprop: handle complex params as independent real params (#83860)
Ref: #65711
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83860
Approved by: https://github.com/albanD
diff --git a/test/test_optim.py b/test/test_optim.py
index 40cc5f9..f58fa51 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -331,7 +331,7 @@
optim1 = optimizer_constructor([a1])
optim2 = optimizer_constructor([a1_real, a1_imag])
- for i in range(10):
+ for _ in range(10):
optim1.zero_grad()
optim2.zero_grad()
a2 = torch.complex(a1_real, a1_imag)
@@ -871,6 +871,14 @@
lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize),
constructor_accepts_maximize=True
)
+ self._test_complex_2d(optimizer)
+ self._test_complex_2d(lambda param: optimizer(param, centered=True))
+ self._test_complex_2d(lambda param: optimizer(param, momentum=0.1))
+ self._test_complex_2d(lambda param: optimizer(param, maximize=True))
+ self._test_complex_optimizer(lambda param: optimizer([param]))
+ self._test_complex_optimizer(lambda param: optimizer([param], centered=True))
+ self._test_complex_optimizer(lambda param: optimizer([param], momentum=0.1))
+ self._test_complex_optimizer(lambda param: optimizer([param], maximize=True))
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
optimizer(None, lr=1e-2, momentum=-1.0)
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 20d7323..22a5bd4 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -236,10 +236,18 @@
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
+ is_complex_param = torch.is_complex(param)
+ if is_complex_param:
+ param = torch.view_as_real(param)
+ grad = torch.view_as_real(grad)
+ square_avg = torch.view_as_real(square_avg)
+
square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
if centered:
grad_avg = grad_avgs[i]
+ if is_complex_param:
+ grad_avg = torch.view_as_real(grad_avg)
grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
else:
@@ -252,6 +260,8 @@
if momentum > 0:
buf = momentum_buffer_list[i]
+ if is_complex_param:
+ buf = torch.view_as_real(buf)
buf.mul_(momentum).addcdiv_(grad, avg)
param.add_(buf, alpha=-lr)
else:
@@ -284,10 +294,18 @@
if weight_decay != 0:
torch._foreach_add_(grads, params, alpha=weight_decay)
+ def _view_complex_as_real(tensor_list):
+ return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list]
+
+ grads = _view_complex_as_real(grads)
+ params = _view_complex_as_real(params)
+ square_avgs = _view_complex_as_real(square_avgs)
+
torch._foreach_mul_(square_avgs, alpha)
torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - alpha)
if centered:
+ grad_avgs = _view_complex_as_real(grad_avgs)
torch._foreach_mul_(grad_avgs, alpha)
torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
avg = torch._foreach_addcmul(square_avgs, grad_avgs, grad_avgs, value=-1)
@@ -298,6 +316,7 @@
torch._foreach_add_(avg, eps)
if momentum > 0:
+ momentum_buffer_list = _view_complex_as_real(momentum_buffer_list)
torch._foreach_mul_(momentum_buffer_list, momentum)
torch._foreach_addcdiv_(momentum_buffer_list, grads, avg)
torch._foreach_add_(params, momentum_buffer_list, alpha=-lr)