Migrate test_complex_optimizer to OptimizerInfo (#118160)

This PR does what it says and more.

1. We increase coverage by a LOT! Previously, complex was not tested for many many configs, including foreach + maximize at the same time. Or the fused impls. Or just random configs people forgot about.
2. I rearranged the maximize conditional and the _view_as_real to preserve list-ness. This is needed for _view_as_real to function properly, I did add a comment in the Files Changed. This new order also just...makes more aesthetic sense.
3. Note that LBFGS and SparseAdam are skipped--they don't support complex and now we know.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 3a1ac3d..744b3d0 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -270,19 +270,6 @@
             constructor_accepts_foreach,
         )
 
-    def _test_complex_optimizer(self, optimizer_constructor):
-        complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
-        real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
-        complex_opt = optimizer_constructor(complex_param)
-        real_opt = optimizer_constructor(real_param)
-
-        for _ in range(3):
-            complex_param.grad = torch.randn_like(complex_param)
-            real_param.grad = torch.view_as_real(complex_param.grad)
-            complex_opt.step()
-            real_opt.step()
-
-            self.assertEqual(torch.view_as_real(complex_param), real_param)
 
     def _test_complex_2d(self, optimizer_constructor):
         a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
@@ -398,40 +385,6 @@
                 multi_tensor=foreach,
             )
 
-    def test_sgd_complex(self):
-        for foreach in (False, True):
-            self._test_complex_optimizer(
-                lambda param: SGD([param], lr=0.001, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: SGD(
-                    [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
-                )
-            )
-            self._test_complex_optimizer(
-                lambda param: SGD(
-                    [param],
-                    lr=0.001,
-                    nesterov=True,
-                    momentum=1,
-                    weight_decay=1,
-                    foreach=foreach,
-                )
-            )
-            self._test_complex_optimizer(
-                lambda param: SGD(
-                    [param],
-                    lr=0.001,
-                    momentum=1,
-                    dampening=0.5,
-                    weight_decay=1,
-                    foreach=foreach,
-                )
-            )
-
 
     def test_adam(self):
         self._test_basic_cases(
@@ -603,15 +556,6 @@
         )
 
 
-    def test_adadelta_complex(self):
-        # Handles https://github.com/pytorch/pytorch/issues/110606
-        self.rel_tol = 2e-2
-        for foreach in (False, True):
-            self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach))
-            self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach))
-            self._test_complex_optimizer(
-                lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach)
-            )
 
     def test_nadam(self):
         self._test_basic_cases(
@@ -640,28 +584,6 @@
         )
 
 
-    def test_nadam_complex(self):
-        for foreach in (False, True):
-            self._test_complex_optimizer(
-                lambda param: NAdam([param], lr=1e-1, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: NAdam(
-                    [param],
-                    lr=1e-1,
-                    weight_decay=0.01,
-                    foreach=foreach,
-                )
-            )
-            self._test_complex_optimizer(
-                lambda param: NAdam(
-                    [param],
-                    lr=1e-1,
-                    momentum_decay=0.01,
-                    foreach=foreach,
-                )
-            )
-
     def test_adagrad(self):
         self._test_basic_cases(
             lambda weight, bias, maximize, foreach: Adagrad(
@@ -705,19 +627,6 @@
                 multi_tensor=foreach,
             )
 
-    def test_adagrad_complex(self):
-        for foreach in (False, True):
-            self._test_complex_optimizer(
-                lambda param: Adagrad([param], lr=1e-1, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: Adagrad(
-                    [param],
-                    lr=1e-1,
-                    initial_accumulator_value=0.1,
-                    foreach=foreach,
-                )
-            )
 
     def test_adamax(self):
         self._test_complex_2d(Adamax)
@@ -748,29 +657,6 @@
         )
 
 
-    def test_radam_complex(self):
-        for foreach in (False, True):
-            self._test_complex_optimizer(
-                lambda param: RAdam([param], lr=1e-1, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: RAdam(
-                    [param],
-                    lr=1e-1,
-                    weight_decay=0.01,
-                    foreach=foreach,
-                )
-            )
-            self._test_complex_optimizer(
-                lambda param: RAdam(
-                    [param],
-                    lr=1e-1,
-                    weight_decay=0.01,
-                    decoupled_weight_decay=True,
-                    foreach=foreach,
-                )
-            )
-
     def test_rmsprop(self):
         for foreach in (False, True):
             self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
@@ -783,40 +669,6 @@
             self._test_complex_2d(
                 lambda param: RMSprop(param, maximize=True, foreach=foreach)
             )
-            self._test_complex_optimizer(
-                lambda param: RMSprop([param], foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: RMSprop([param], centered=True, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: RMSprop([param], momentum=0.1, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: RMSprop([param], maximize=True, foreach=foreach)
-            )
-
-
-    def test_asgd(self):
-        for foreach in (False, True):
-            # Ref: https://github.com/pytorch/pytorch/issues/84560
-            # self._test_complex_2d(optimizer)
-            self._test_complex_optimizer(
-                lambda params: ASGD([params], foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda params: ASGD([params], maximize=True, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda params: ASGD(
-                    [params], maximize=True, weight_decay=0.1, foreach=foreach
-                )
-            )
-            self._test_complex_optimizer(
-                lambda params: ASGD(
-                    [params], maximize=False, weight_decay=0.1, foreach=foreach
-                )
-            )
 
 
     @skipIfRocm
@@ -824,14 +676,6 @@
     def test_rprop(self):
         for foreach in (False, True):
             self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
-            self._test_complex_optimizer(
-                lambda param: Rprop([param], lr=0.001, foreach=foreach)
-            )
-            self._test_complex_optimizer(
-                lambda param: Rprop(
-                    [param], lr=0.001, maximize=True, foreach=foreach
-                )
-            )
 
 
     def test_lbfgs_returns_consistent_type(self):
diff --git a/test/test_optim.py b/test/test_optim.py
index ea1beb8..73c0898 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -128,6 +128,30 @@
                 self.assertLess(closure().item(), initial_value)
 
 
+    @skipMPS
+    @optims(optim_db, dtypes=[torch.complex64])
+    def test_complex(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
+        for optim_input in all_optim_inputs:
+            complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
+            real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
+
+            complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
+            real_optimizer = optim_cls(real_params, **optim_input.kwargs)
+
+            for _ in range(3):
+                for (c, r) in zip(complex_params, real_params):
+                    c.grad = torch.randn_like(c)
+                    r.grad = torch.view_as_real(c.grad)
+                complex_optimizer.step()
+                real_optimizer.step()
+
+                for (c, r) in zip(complex_params, real_params):
+                    self.assertEqual(torch.view_as_real(c), r)
+
+
     def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
         """
         Given a flag 'fused' or 'foreach', test for parity of optimizer state
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index 1810082..ac16c13 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -286,12 +286,12 @@
 
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
     for ((device_params, device_grads, device_square_avgs, device_acc_deltas), _) in grouped_tensors.values():
-        if maximize:
-            device_grads = torch._foreach_neg(device_grads)
-
         if has_complex:
             _view_as_real(device_params, device_grads, device_square_avgs, device_acc_deltas)
 
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
         if weight_decay != 0:
             # Re-use the intermediate memory (device_grads) already allocated for maximize
             if maximize:
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 3d5ef79..ce17e2d 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -344,13 +344,13 @@
             )
             continue
 
-        if maximize:
-            device_grads = torch._foreach_neg(device_grads)
-
         # Handle complex parameters
         if has_complex:
             _view_as_real(device_params, device_grads, device_state_sums)
 
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index ee0223f..15abef4 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -489,9 +489,6 @@
         device_state_steps,
     ), _) in grouped_tensors.values():
 
-        if maximize:
-            device_grads = torch._foreach_neg(device_grads)
-
         # Handle complex parameters
         if has_complex:
             if amsgrad:
@@ -499,6 +496,9 @@
             else:
                 _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
 
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index d914110..8f0ee6c 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -330,12 +330,12 @@
 
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
     for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
-        if maximize:
-            grouped_grads = torch._foreach_neg(grouped_grads)
-
         if has_complex:
             _view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs)
 
+        if maximize:
+            grouped_grads = torch._foreach_neg(grouped_grads)
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 071e879..bbf5fe7 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -521,15 +521,15 @@
         device_max_exp_avg_sqs,
         device_state_steps,
     ), _) in grouped_tensors.values():
-        if maximize:
-            device_grads = torch._foreach_neg(device_grads)
-
         if has_complex:
             if amsgrad:
                 _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
             else:
                 _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
 
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index 5465d84..c65411a 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -313,13 +313,12 @@
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
     for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
          grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
-        if maximize:
-            grouped_grads = torch._foreach_neg(grouped_grads)
-
-        grouped_grads = list(grouped_grads)
         if has_complex:
             _view_as_real(grouped_params, grouped_grads, grouped_axs)
 
+        if maximize:
+            grouped_grads = torch._foreach_neg(grouped_grads)
+
         # Update steps
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index bf7e0f7..62d28ae 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -336,6 +336,14 @@
     grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
     for (((grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
          grouped_momentum_buffer_list)), _) in grouped_tensors.values():
+        if has_complex:
+            state_and_grads = [grouped_grads, grouped_square_avgs]
+            if momentum > 0:
+                state_and_grads.append(grouped_momentum_buffer_list)
+            if centered:
+                state_and_grads.append(grouped_grad_avgs)
+            _view_as_real(grouped_params, *state_and_grads)
+
         if maximize:
             grouped_grads = torch._foreach_neg(grouped_grads)
 
@@ -346,16 +354,6 @@
             else:
                 grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
 
-        grouped_grads = list(grouped_grads)
-
-        if has_complex:
-            state_and_grads = [grouped_grads, grouped_square_avgs]
-            if momentum > 0:
-                state_and_grads.append(grouped_momentum_buffer_list)
-            if centered:
-                state_and_grads.append(grouped_grad_avgs)
-            _view_as_real(grouped_params, *state_and_grads)
-
         torch._foreach_mul_(grouped_square_avgs, alpha)
         torch._foreach_addcmul_(grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha)
 
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 3534172..7b53b8c 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -263,9 +263,7 @@
 def optim_inputs_func_adadelta(device=None):
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
-        OptimizerInput(
-            params=None, kwargs={"lr": 0.01}, desc="non-default lr"
-        ),  # TODO: Move out to testing in param_group?
+        OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
@@ -275,8 +273,8 @@
             desc="maximize",
         ),
         OptimizerInput(
-            params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho"
-        ),  # TODO: Move out to testing in param_group?
+            params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
+        ),
     ]
 
 
@@ -494,6 +492,7 @@
         OptimizerInput(params=None, kwargs={}, desc="default"),
         OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
         OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
+        OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
         OptimizerInput(
             params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
         ),
@@ -545,6 +544,21 @@
 def optim_inputs_func_nadam(device=None):
     cuda_supported_configs = [
         OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+        OptimizerInput(
+            params=None,
+            kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
+            desc="weight_decay, capturable",
+        ),
+        OptimizerInput(
+            params=None,
+            kwargs={
+                "weight_decay": 0.9,
+                "momentum_decay": 6e-3,
+                "decoupled_weight_decay": True,
+                "capturable": True,
+            },
+            desc="decoupled_weight_decay, capturable",
+        ),
     ]
     return [
         OptimizerInput(params=None, kwargs={}, desc="default"),
@@ -1107,6 +1121,11 @@
                 "test_forloop_goes_right_direction_multigpu",
             ),
             DecorateInfo(
+                skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
+                "TestOptimRenewed",
+                "test_complex",
+            ),
+            DecorateInfo(
                 skipIfTorchDynamo(
                     "See https://github.com/pytorch/pytorch/issues/115607"
                 ),
@@ -1319,6 +1338,11 @@
                 "TestOptimRenewed",
                 "test_forloop_goes_right_direction_multigpu",
             ),
+            DecorateInfo(
+                unittest.skip("Missing complex support, see #118148"),
+                "TestOptimRenewed",
+                "test_complex",
+            ),
         ),
     ),
     OptimizerInfo(
@@ -1747,6 +1771,11 @@
                 "TestOptimRenewed",
                 "test_deepcopy_copies_all_public_attrs",
             ),
+            DecorateInfo(
+                unittest.skip("Missing complex support, see #118153"),
+                "TestOptimRenewed",
+                "test_complex",
+            ),
         ),
     ),
 ]