Migrate test_complex_optimizer to OptimizerInfo (#118160)
This PR does what it says and more.
1. We increase coverage by a LOT! Previously, complex was not tested for many configs: foreach + maximize at the same time, the fused impls, and other configs people had simply forgotten about.
2. I rearranged the maximize conditional and the _view_as_real call to preserve list-ness, which _view_as_real needs in order to function properly; I added a comment about this in the Files Changed. The new order also just makes more aesthetic sense (a short sketch of the reasoning follows this list).
3. Note that LBFGS and SparseAdam are skipped--they don't support complex, and now we know.
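
For item 2, a minimal sketch of why the order matters. `_view_as_real_sketch` below is a hypothetical stand-in for the private helper in torch/optim/optimizer.py, not its actual implementation; the point is only that the helper assigns real views back into the lists it is given, so it must see the original, mutable lists before maximize rebinds the grads:

```python
import torch

def _view_as_real_sketch(params, *state_and_grads):
    # Hypothetical stand-in: replace each complex tensor with its real view,
    # assigning back into the given lists in place.
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(p)
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])

params = [torch.randn(2, dtype=torch.complex64)]
grads = [torch.randn(2, dtype=torch.complex64)]

# New order: take the real views while `grads` is still the original,
# mutable list...
_view_as_real_sketch(params, grads)

# ...then negate for maximize. _foreach_neg hands back a fresh container,
# which is fine to rebind now that the views have already been taken.
grads = torch._foreach_neg(grads)
```

With the old order, the helper was handed the container returned by torch._foreach_neg instead of the original list, which is what the now-removed `grouped_grads = list(grouped_grads)` copies were compensating for; viewing first makes them unnecessary.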
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118160
Approved by: https://github.com/mikaylagawarecki
diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py
index 3a1ac3d..744b3d0 100644
--- a/test/optim/test_optim.py
+++ b/test/optim/test_optim.py
@@ -270,19 +270,6 @@
constructor_accepts_foreach,
)
- def _test_complex_optimizer(self, optimizer_constructor):
- complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
- real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
- complex_opt = optimizer_constructor(complex_param)
- real_opt = optimizer_constructor(real_param)
-
- for _ in range(3):
- complex_param.grad = torch.randn_like(complex_param)
- real_param.grad = torch.view_as_real(complex_param.grad)
- complex_opt.step()
- real_opt.step()
-
- self.assertEqual(torch.view_as_real(complex_param), real_param)
def _test_complex_2d(self, optimizer_constructor):
a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
@@ -398,40 +385,6 @@
multi_tensor=foreach,
)
- def test_sgd_complex(self):
- for foreach in (False, True):
- self._test_complex_optimizer(
- lambda param: SGD([param], lr=0.001, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: SGD([param], lr=0.001, momentum=1, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: SGD(
- [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
- )
- )
- self._test_complex_optimizer(
- lambda param: SGD(
- [param],
- lr=0.001,
- nesterov=True,
- momentum=1,
- weight_decay=1,
- foreach=foreach,
- )
- )
- self._test_complex_optimizer(
- lambda param: SGD(
- [param],
- lr=0.001,
- momentum=1,
- dampening=0.5,
- weight_decay=1,
- foreach=foreach,
- )
- )
-
def test_adam(self):
self._test_basic_cases(
@@ -603,15 +556,6 @@
)
- def test_adadelta_complex(self):
- # Handles https://github.com/pytorch/pytorch/issues/110606
- self.rel_tol = 2e-2
- for foreach in (False, True):
- self._test_complex_optimizer(lambda weight: Adadelta([weight], foreach=foreach))
- self._test_complex_optimizer(lambda weight: Adadelta([weight], rho=0.95, foreach=foreach))
- self._test_complex_optimizer(
- lambda weight: Adadelta([weight], rho=0.95, weight_decay=1, foreach=foreach)
- )
def test_nadam(self):
self._test_basic_cases(
@@ -640,28 +584,6 @@
)
- def test_nadam_complex(self):
- for foreach in (False, True):
- self._test_complex_optimizer(
- lambda param: NAdam([param], lr=1e-1, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: NAdam(
- [param],
- lr=1e-1,
- weight_decay=0.01,
- foreach=foreach,
- )
- )
- self._test_complex_optimizer(
- lambda param: NAdam(
- [param],
- lr=1e-1,
- momentum_decay=0.01,
- foreach=foreach,
- )
- )
-
def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
@@ -705,19 +627,6 @@
multi_tensor=foreach,
)
- def test_adagrad_complex(self):
- for foreach in (False, True):
- self._test_complex_optimizer(
- lambda param: Adagrad([param], lr=1e-1, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: Adagrad(
- [param],
- lr=1e-1,
- initial_accumulator_value=0.1,
- foreach=foreach,
- )
- )
def test_adamax(self):
self._test_complex_2d(Adamax)
@@ -748,29 +657,6 @@
)
- def test_radam_complex(self):
- for foreach in (False, True):
- self._test_complex_optimizer(
- lambda param: RAdam([param], lr=1e-1, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: RAdam(
- [param],
- lr=1e-1,
- weight_decay=0.01,
- foreach=foreach,
- )
- )
- self._test_complex_optimizer(
- lambda param: RAdam(
- [param],
- lr=1e-1,
- weight_decay=0.01,
- decoupled_weight_decay=True,
- foreach=foreach,
- )
- )
-
def test_rmsprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
@@ -783,40 +669,6 @@
self._test_complex_2d(
lambda param: RMSprop(param, maximize=True, foreach=foreach)
)
- self._test_complex_optimizer(
- lambda param: RMSprop([param], foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: RMSprop([param], centered=True, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: RMSprop([param], momentum=0.1, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: RMSprop([param], maximize=True, foreach=foreach)
- )
-
-
- def test_asgd(self):
- for foreach in (False, True):
- # Ref: https://github.com/pytorch/pytorch/issues/84560
- # self._test_complex_2d(optimizer)
- self._test_complex_optimizer(
- lambda params: ASGD([params], foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda params: ASGD([params], maximize=True, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda params: ASGD(
- [params], maximize=True, weight_decay=0.1, foreach=foreach
- )
- )
- self._test_complex_optimizer(
- lambda params: ASGD(
- [params], maximize=False, weight_decay=0.1, foreach=foreach
- )
- )
@skipIfRocm
@@ -824,14 +676,6 @@
def test_rprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
- self._test_complex_optimizer(
- lambda param: Rprop([param], lr=0.001, foreach=foreach)
- )
- self._test_complex_optimizer(
- lambda param: Rprop(
- [param], lr=0.001, maximize=True, foreach=foreach
- )
- )
def test_lbfgs_returns_consistent_type(self):
diff --git a/test/test_optim.py b/test/test_optim.py
index ea1beb8..73c0898 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -128,6 +128,30 @@
self.assertLess(closure().item(), initial_value)
+ @skipMPS
+ @optims(optim_db, dtypes=[torch.complex64])
+ def test_complex(self, device, dtype, optim_info):
+ optim_cls = optim_info.optim_cls
+ # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
+ all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
+ for optim_input in all_optim_inputs:
+ complex_params = [torch.randn(2, 3, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
+ real_params = [torch.view_as_real(p).detach().clone().requires_grad_(True) for p in complex_params]
+
+ complex_optimizer = optim_cls(complex_params, **optim_input.kwargs)
+ real_optimizer = optim_cls(real_params, **optim_input.kwargs)
+
+ for _ in range(3):
+ for (c, r) in zip(complex_params, real_params):
+ c.grad = torch.randn_like(c)
+ r.grad = torch.view_as_real(c.grad)
+ complex_optimizer.step()
+ real_optimizer.step()
+
+ for (c, r) in zip(complex_params, real_params):
+ self.assertEqual(torch.view_as_real(c), r)
+
+
def _test_derived_optimizers(self, device, dtype, optim_info, flag, reduced_precision=False, assert_step_dtype=None):
"""
Given a flag 'fused' or 'foreach', test for parity of optimizer state
diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py
index 1810082..ac16c13 100644
--- a/torch/optim/adadelta.py
+++ b/torch/optim/adadelta.py
@@ -286,12 +286,12 @@
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, acc_deltas])
for ((device_params, device_grads, device_square_avgs, device_acc_deltas), _) in grouped_tensors.values():
- if maximize:
- device_grads = torch._foreach_neg(device_grads)
-
if has_complex:
_view_as_real(device_params, device_grads, device_square_avgs, device_acc_deltas)
+ if maximize:
+ device_grads = torch._foreach_neg(device_grads)
+
if weight_decay != 0:
# Re-use the intermediate memory (device_grads) already allocated for maximize
if maximize:
diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py
index 3d5ef79..ce17e2d 100644
--- a/torch/optim/adagrad.py
+++ b/torch/optim/adagrad.py
@@ -344,13 +344,13 @@
)
continue
- if maximize:
- device_grads = torch._foreach_neg(device_grads)
-
# Handle complex parameters
if has_complex:
_view_as_real(device_params, device_grads, device_state_sums)
+ if maximize:
+ device_grads = torch._foreach_neg(device_grads)
+
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index ee0223f..15abef4 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -489,9 +489,6 @@
device_state_steps,
), _) in grouped_tensors.values():
- if maximize:
- device_grads = torch._foreach_neg(device_grads)
-
# Handle complex parameters
if has_complex:
if amsgrad:
@@ -499,6 +496,9 @@
else:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
+ if maximize:
+ device_grads = torch._foreach_neg(device_grads)
+
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index d914110..8f0ee6c 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -330,12 +330,12 @@
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_infs, state_steps])
for ((grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs, grouped_state_steps), _) in grouped_tensors.values():
- if maximize:
- grouped_grads = torch._foreach_neg(grouped_grads)
-
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs)
+ if maximize:
+ grouped_grads = torch._foreach_neg(grouped_grads)
+
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py
index 071e879..bbf5fe7 100644
--- a/torch/optim/adamw.py
+++ b/torch/optim/adamw.py
@@ -521,15 +521,15 @@
device_max_exp_avg_sqs,
device_state_steps,
), _) in grouped_tensors.values():
- if maximize:
- device_grads = torch._foreach_neg(device_grads)
-
if has_complex:
if amsgrad:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs, device_max_exp_avg_sqs)
else:
_view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)
+ if maximize:
+ device_grads = torch._foreach_neg(device_grads)
+
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py
index 5465d84..c65411a 100644
--- a/torch/optim/asgd.py
+++ b/torch/optim/asgd.py
@@ -313,13 +313,12 @@
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
for ((device, _), ((grouped_params, grouped_grads, grouped_axs, grouped_mus,
grouped_etas, grouped_state_steps), _)) in grouped_tensors.items():
- if maximize:
- grouped_grads = torch._foreach_neg(grouped_grads)
-
- grouped_grads = list(grouped_grads)
if has_complex:
_view_as_real(grouped_params, grouped_grads, grouped_axs)
+ if maximize:
+ grouped_grads = torch._foreach_neg(grouped_grads)
+
# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py
index bf7e0f7..62d28ae 100644
--- a/torch/optim/rmsprop.py
+++ b/torch/optim/rmsprop.py
@@ -336,6 +336,14 @@
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, square_avgs, grad_avgs, momentum_buffer_list])
for (((grouped_params, grouped_grads, grouped_square_avgs, grouped_grad_avgs,
grouped_momentum_buffer_list)), _) in grouped_tensors.values():
+ if has_complex:
+ state_and_grads = [grouped_grads, grouped_square_avgs]
+ if momentum > 0:
+ state_and_grads.append(grouped_momentum_buffer_list)
+ if centered:
+ state_and_grads.append(grouped_grad_avgs)
+ _view_as_real(grouped_params, *state_and_grads)
+
if maximize:
grouped_grads = torch._foreach_neg(grouped_grads)
@@ -346,16 +354,6 @@
else:
grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
- grouped_grads = list(grouped_grads)
-
- if has_complex:
- state_and_grads = [grouped_grads, grouped_square_avgs]
- if momentum > 0:
- state_and_grads.append(grouped_momentum_buffer_list)
- if centered:
- state_and_grads.append(grouped_grad_avgs)
- _view_as_real(grouped_params, *state_and_grads)
-
torch._foreach_mul_(grouped_square_avgs, alpha)
torch._foreach_addcmul_(grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha)
diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py
index 3534172..7b53b8c 100644
--- a/torch/testing/_internal/common_optimizers.py
+++ b/torch/testing/_internal/common_optimizers.py
@@ -263,9 +263,7 @@
def optim_inputs_func_adadelta(device=None):
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
- OptimizerInput(
- params=None, kwargs={"lr": 0.01}, desc="non-default lr"
- ), # TODO: Move out to testing in param_group?
+ OptimizerInput(params=None, kwargs={"lr": 0.01}, desc="non-default lr"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
@@ -275,8 +273,8 @@
desc="maximize",
),
OptimizerInput(
- params=None, kwargs={"rho": 0.95, "weight_decay": 0.1}, desc="rho"
- ), # TODO: Move out to testing in param_group?
+ params=None, kwargs={"rho": 0.95, "weight_decay": 0.9}, desc="rho"
+ ),
]
@@ -494,6 +492,7 @@
OptimizerInput(params=None, kwargs={}, desc="default"),
OptimizerInput(params=None, kwargs={"lr": 0.02}, desc="non-default lr"),
OptimizerInput(params=None, kwargs={"t0": 100}, desc="t0"),
+ OptimizerInput(params=None, kwargs={"maximize": True}, desc="maximize"),
OptimizerInput(
params=None, kwargs={"weight_decay": 0.1}, desc="nonzero weight_decay"
),
@@ -545,6 +544,21 @@
def optim_inputs_func_nadam(device=None):
cuda_supported_configs = [
OptimizerInput(params=None, kwargs={"capturable": True}, desc="capturable"),
+ OptimizerInput(
+ params=None,
+ kwargs={"weight_decay": 0.9, "momentum_decay": 6e-3, "capturable": True},
+ desc="weight_decay, capturable",
+ ),
+ OptimizerInput(
+ params=None,
+ kwargs={
+ "weight_decay": 0.9,
+ "momentum_decay": 6e-3,
+ "decoupled_weight_decay": True,
+ "capturable": True,
+ },
+ desc="decoupled_weight_decay, capturable",
+ ),
]
return [
OptimizerInput(params=None, kwargs={}, desc="default"),
@@ -1107,6 +1121,11 @@
"test_forloop_goes_right_direction_multigpu",
),
DecorateInfo(
+ skipIfTorchDynamo("Mismatched _foreach_addcdiv_ types, see #118159"),
+ "TestOptimRenewed",
+ "test_complex",
+ ),
+ DecorateInfo(
skipIfTorchDynamo(
"See https://github.com/pytorch/pytorch/issues/115607"
),
@@ -1319,6 +1338,11 @@
"TestOptimRenewed",
"test_forloop_goes_right_direction_multigpu",
),
+ DecorateInfo(
+ unittest.skip("Missing complex support, see #118148"),
+ "TestOptimRenewed",
+ "test_complex",
+ ),
),
),
OptimizerInfo(
@@ -1747,6 +1771,11 @@
"TestOptimRenewed",
"test_deepcopy_copies_all_public_attrs",
),
+ DecorateInfo(
+ unittest.skip("Missing complex support, see #118153"),
+ "TestOptimRenewed",
+ "test_complex",
+ ),
),
),
]