Handle complex optimization in Adamax by treating complex numbers as 2D real numbers (#80319)
This commit partially addresses #65711. Complex parameters, gradients, and optimizer state are reinterpreted as real tensors with a trailing dimension of size 2 (via torch.view_as_real), so Adamax's elementwise updates act independently on the real and imaginary parts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80319
Approved by: https://github.com/albanD
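
The core trick, sketched below: torch.view_as_real reinterprets a complex
tensor of shape (...) as a real tensor of shape (..., 2) whose last
dimension holds the real and imaginary parts, and it returns a view, so
in-place updates written through the view mutate the underlying complex
storage. Elementwise optimizer math therefore runs unchanged on the
viewed tensors. A minimal sketch (the names p and g are illustrative,
not from the patch):

    import torch

    # A complex parameter and a complex "gradient".
    p = torch.randn(4, dtype=torch.cfloat)
    g = torch.randn(4, dtype=torch.cfloat)

    # Real views of shape (4, 2): [..., 0] is the real part,
    # [..., 1] the imaginary part.
    p_r = torch.view_as_real(p)
    g_r = torch.view_as_real(g)

    # An in-place elementwise update through the view ...
    p_r.add_(g_r, alpha=-0.1)

    # ... is visible in the original complex tensor, since p_r
    # aliases p's storage.
    assert torch.equal(torch.view_as_complex(p_r), p)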
diff --git a/test/test_optim.py b/test/test_optim.py
index 2c5c107..a48a087 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -794,6 +794,7 @@
[weight, bias], lr=1e-1, weight_decay=1, maximize=maximize),
constructor_accepts_maximize=True
)
+ self._test_complex_2d(optimizer)
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
optimizer(None, lr=1e-2, betas=(0.0, 1.0))
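
For context, a minimal usage sketch of what the new test exercises:
Adamax stepping a complex parameter end to end (the shapes, lr, and toy
loss are illustrative and not taken from _test_complex_2d itself):

    import torch

    param = torch.randn(3, 3, dtype=torch.cfloat, requires_grad=True)
    opt = torch.optim.Adamax([param], lr=1e-1)

    # abs() maps the complex parameter to a real scalar loss.
    loss = param.abs().sum()
    loss.backward()
    opt.step()  # before this change, Adamax did not handle complex params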
diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py
index 8fe74bf..bb45764 100644
--- a/torch/optim/adamax.py
+++ b/torch/optim/adamax.py
@@ -215,6 +215,12 @@
if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)
+ if torch.is_complex(param):
+ param = torch.view_as_real(param)
+ grad = torch.view_as_real(grad)
+ exp_avg = torch.view_as_real(exp_avg)
+ exp_inf = torch.view_as_real(exp_inf)
+
# Update biased first moment estimate.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# Update the exponentially weighted infinity norm.
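
A point worth noting about the hunk above: only the local names are
rebound, yet the optimizer state is still updated correctly, because
torch.view_as_real returns a view that shares storage with the stored
complex tensors; the subsequent in-place mul_/add_ ops write through
that view. A small sketch of the aliasing (the state dict here is
illustrative):

    import torch

    state = {"exp_avg": torch.zeros(4, dtype=torch.cfloat)}
    exp_avg = state["exp_avg"]

    # Rebinding the local name to a real view copies nothing.
    exp_avg = torch.view_as_real(exp_avg)
    exp_avg.mul_(0.9).add_(torch.ones(4, 2), alpha=0.1)

    # The stored complex tensor saw the update: 0.1 + 0.1j everywhere.
    assert torch.allclose(state["exp_avg"].real, torch.full((4,), 0.1))
    assert torch.allclose(state["exp_avg"].imag, torch.full((4,), 0.1))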
@@ -249,6 +255,11 @@
if maximize:
grads = torch._foreach_neg(grads)
+ params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
+ grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
+ exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
+ exp_infs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_infs]
+
# Update steps
torch._foreach_add_(state_steps, 1)
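
The multi-tensor path applies the same view per list element, since a
single param group may mix real and complex tensors. After the
conversion, the torch._foreach_* ops (a private batched API whose ops
mirror the per-tensor ones) only ever see real tensors, and updates
written through the views land in the original complex params. A sketch
of the pattern, assuming that private API:

    import torch

    params = [torch.randn(2, dtype=torch.cfloat),  # complex param
              torch.randn(3)]                      # real param

    # Convert only the complex entries; real tensors pass through.
    views = [torch.view_as_real(p) if torch.is_complex(p) else p
             for p in params]

    # The batched in-place op mutates the complex tensor via its view.
    torch._foreach_mul_(views, 0.5)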