Handle complex optimization in Adamax by treating complex numbers as 2D real numbers (#80319)

This commit partially addresses #65711

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80319
Approved by: https://github.com/albanD
diff --git a/test/test_optim.py b/test/test_optim.py
index 2c5c107..a48a087 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -794,6 +794,7 @@
                     [weight, bias], lr=1e-1, weight_decay=1, maximize=maximize),
                 constructor_accepts_maximize=True
             )
+            self._test_complex_2d(optimizer)
             with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
                 optimizer(None, lr=1e-2, betas=(0.0, 1.0))
 
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index 8fe74bf..bb45764 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -215,6 +215,12 @@
         if weight_decay != 0:
             grad = grad.add(param, alpha=weight_decay)
 
+        if torch.is_complex(param):
+            param = torch.view_as_real(param)
+            grad = torch.view_as_real(grad)
+            exp_avg = torch.view_as_real(exp_avg)
+            exp_inf = torch.view_as_real(exp_inf)
+
         # Update biased first moment estimate.
         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
         # Update the exponentially weighted infinity norm.
@@ -249,6 +255,11 @@
     if maximize:
         grads = torch._foreach_neg(grads)
 
+    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
+    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
+    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
+    exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]
+
     # Update steps
     torch._foreach_add_(state_steps, 1)