[optim] rmsprop: handle complex params as independent real params (#83860)

Ref: #65711
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83860
Approved by: https://github.com/albanD
diff --git a/test/test_optim.py b/test/test_optim.py
index 40cc5f9..f58fa51 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -331,7 +331,7 @@
         optim1 = optimizer_constructor([a1])
         optim2 = optimizer_constructor([a1_real, a1_imag])
 
-        for i in range(10):
+        for _ in range(10):
             optim1.zero_grad()
             optim2.zero_grad()
             a2 = torch.complex(a1_real, a1_imag)
@@ -871,6 +871,14 @@
                     lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize),
                 constructor_accepts_maximize=True
             )
+            self._test_complex_2d(optimizer)
+            self._test_complex_2d(lambda param: optimizer(param, centered=True))
+            self._test_complex_2d(lambda param: optimizer(param, momentum=0.1))
+            self._test_complex_2d(lambda param: optimizer(param, maximize=True))
+            self._test_complex_optimizer(lambda param: optimizer([param]))
+            self._test_complex_optimizer(lambda param: optimizer([param], centered=True))
+            self._test_complex_optimizer(lambda param: optimizer([param], momentum=0.1))
+            self._test_complex_optimizer(lambda param: optimizer([param], maximize=True))
             with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
                 optimizer(None, lr=1e-2, momentum=-1.0)
 
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index 20d7323..22a5bd4 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -236,10 +236,18 @@
         if weight_decay != 0:
             grad = grad.add(param, alpha=weight_decay)
 
+        is_complex_param = torch.is_complex(param)
+        if is_complex_param:
+            param = torch.view_as_real(param)
+            grad = torch.view_as_real(grad)
+            square_avg = torch.view_as_real(square_avg)
+
         square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
 
         if centered:
             grad_avg = grad_avgs[i]
+            if is_complex_param:
+                grad_avg = torch.view_as_real(grad_avg)
             grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)
             avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
         else:
@@ -252,6 +260,8 @@
 
         if momentum > 0:
             buf = momentum_buffer_list[i]
+            if is_complex_param:
+                buf = torch.view_as_real(buf)
             buf.mul_(momentum).addcdiv_(grad, avg)
             param.add_(buf, alpha=-lr)
         else:
@@ -284,10 +294,18 @@
     if weight_decay != 0:
         torch._foreach_add_(grads, params, alpha=weight_decay)
 
+    def _view_complex_as_real(tensor_list):
+        return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list]
+
+    grads = _view_complex_as_real(grads)
+    params = _view_complex_as_real(params)
+    square_avgs = _view_complex_as_real(square_avgs)
+
     torch._foreach_mul_(square_avgs, alpha)
     torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - alpha)
 
     if centered:
+        grad_avgs = _view_complex_as_real(grad_avgs)
         torch._foreach_mul_(grad_avgs, alpha)
         torch._foreach_add_(grad_avgs, grads, alpha=1 - alpha)
         avg = torch._foreach_addcmul(square_avgs, grad_avgs, grad_avgs, value=-1)
@@ -298,6 +316,7 @@
         torch._foreach_add_(avg, eps)
 
     if momentum > 0:
+        momentum_buffer_list = _view_complex_as_real(momentum_buffer_list)
         torch._foreach_mul_(momentum_buffer_list, momentum)
         torch._foreach_addcdiv_(momentum_buffer_list, grads, avg)
         torch._foreach_add_(params, momentum_buffer_list, alpha=-lr)