Adds a `maximize` flag to Adam (#68164)
Summary:
Solves the next most important use case in https://github.com/pytorch/pytorch/issues/68052.
I have kept the style as close to that in SGD as seemed reasonable, given the slight differences in their internal implementations.
All feedback welcome!
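A minimal usage sketch (the toy objective and variable names are illustrative, not part of this PR): with `maximize=True`, `opt.step()` ascends on the value that was backpropagated instead of descending.

```python
import torch

w = torch.randn(3, requires_grad=True)
# New keyword-only flag added by this PR; defaults to False.
opt = torch.optim.Adam([w], lr=1e-2, maximize=True)

for _ in range(100):
    opt.zero_grad()
    score = -(w ** 2).sum()   # toy objective that peaks at w == 0
    score.backward()
    opt.step()                # ascends on `score` instead of descending
```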
cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68164
Reviewed By: VitalyFedyunin
Differential Revision: D32994129
Pulled By: albanD
fbshipit-source-id: 65c57c3f3dbbd3e3e5338d51def54482503e8850
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 73b6646..9a2f3cb 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -689,10 +689,7 @@
sharded_optimizer.load_state_dict(sharded_optim_state_dict)
check_step()
- for opt in [torch.optim.Adam]:
- check_optimizer_equivalence(opt, maximize=False)
-
- for opt in [torch.optim.SGD]:
+ for opt in [torch.optim.Adam, torch.optim.SGD]:
for maximize in (True, False):
check_optimizer_equivalence(opt, maximize=maximize)
diff --git a/test/test_optim.py b/test/test_optim.py
index 1ba7db8..b515aea 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -480,58 +480,68 @@
def test_adam(self):
for optimizer in [optim.Adam, optim_mt.Adam]:
self._test_basic_cases(
- lambda weight, bias: optimizer([weight, bias], lr=1e-3)
+ lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, maximize=maximize),
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer(
+ lambda weight, bias, maximize: optimizer(
+ self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize),
+ constructor_accepts_maximize=True
+ )
+ self._test_basic_cases(
+ lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
+ constructor_accepts_maximize=True
+ )
+ self._test_basic_cases(
+ lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize),
+ constructor_accepts_maximize=True
+ )
+ self._test_basic_cases(
+ lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3)
+ lr=1e-3, amsgrad=True, maximize=maximize),
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True)
- )
- self._test_basic_cases(
- lambda weight, bias: optimizer([weight, bias], lr=1e-3, weight_decay=0.1)
- )
- self._test_basic_cases(
- lambda weight, bias: optimizer(
+ lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3, amsgrad=True)
+ lr=1e-3, maximize=maximize),
+ [lambda opt: ExponentialLR(opt, gamma=0.9)],
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer(
+ lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3),
- [lambda opt: ExponentialLR(opt, gamma=0.9)]
+ lr=1e-3, maximize=maximize),
+ [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer(
+ lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3),
- [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)]
+ lr=1e-3, maximize=maximize),
+ [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer(
- self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3),
- [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)]
- )
- self._test_basic_cases(
- lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
+ lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
- lambda opt: ExponentialLR(opt, gamma=0.9)]
+ lambda opt: ExponentialLR(opt, gamma=0.9)],
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
+ lambda weight, bias, maximize: optimizer([weight, bias], lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: ExponentialLR(opt, gamma=0.9),
- lambda opt: ReduceLROnPlateau(opt)]
+ lambda opt: ReduceLROnPlateau(opt)],
+ constructor_accepts_maximize=True
)
self._test_basic_cases(
- lambda weight, bias: optimizer(
+ lambda weight, bias, maximize: optimizer(
self._build_params_dict(weight, bias, lr=1e-2),
- lr=1e-3, amsgrad=True),
+ lr=1e-3, amsgrad=True, maximize=maximize),
[lambda opt: StepLR(opt, gamma=0.9, step_size=10),
- lambda opt: ReduceLROnPlateau(opt)]
+ lambda opt: ReduceLROnPlateau(opt)],
+ constructor_accepts_maximize=True
)
with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
optimizer(None, lr=1e-2, betas=(1.0, 0.0))
diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py
index 32f30a7..2c4b5f7 100644
--- a/torch/distributed/optim/functional_adam.py
+++ b/torch/distributed/optim/functional_adam.py
@@ -23,6 +23,7 @@
eps: float = 1e-8,
weight_decay: float = 0.0,
amsgrad: bool = False,
+ maximize: bool = False,
_allow_empty_param_list: bool = False,
):
if not 0.0 <= lr:
@@ -44,6 +45,7 @@
"weight_decay": weight_decay,
}
self.amsgrad = amsgrad
+ self.maximize = maximize
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
if len(params) == 0 and not _allow_empty_param_list:
@@ -96,6 +98,7 @@
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
+ maximize=self.maximize,
beta1=self.defaults['beta1'],
beta2=self.defaults['beta2'],
lr=self.defaults['lr'],
@@ -156,6 +159,7 @@
max_exp_avg_sqs,
state_steps,
amsgrad=self.amsgrad,
+ maximize=self.maximize,
beta1=self.defaults['beta1'],
beta2=self.defaults['beta2'],
lr=self.defaults['lr'],
diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py
index 10ad321..78daf41 100644
--- a/torch/optim/_functional.py
+++ b/torch/optim/_functional.py
@@ -73,7 +73,8 @@
beta2: float,
lr: float,
weight_decay: float,
- eps: float):
+ eps: float,
+ maximize: bool):
r"""Functional API that performs Adam algorithm computation.
See :class:`~torch.optim.Adam` for details.
@@ -81,7 +82,7 @@
for i, param in enumerate(params):
- grad = grads[i]
+ grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
@@ -103,11 +104,11 @@
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+
step_size = lr / bias_correction1
-
param.addcdiv_(exp_avg, denom, value=-step_size)
-
def adamw(params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
diff --git a/torch/optim/_multi_tensor/adam.py b/torch/optim/_multi_tensor/adam.py
index c4a111f..feab86c 100644
--- a/torch/optim/_multi_tensor/adam.py
+++ b/torch/optim/_multi_tensor/adam.py
@@ -32,7 +32,7 @@
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False):
+ weight_decay=0, amsgrad=False, *, maximize: bool = False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -44,7 +44,7 @@
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay, amsgrad=amsgrad)
+ weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
@@ -75,6 +75,7 @@
max_exp_avg_sq = []
params_with_grad = []
+
for p in group['params']:
if p.grad is not None:
if p.grad.is_sparse:
@@ -82,6 +83,9 @@
params_with_grad.append(p)
grads.append(p.grad)
+ if group['maximize']:
+ grads = torch._foreach_neg(tuple(grads))
+
for p in params_with_grad:
state = self.state[p]
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index ea2ceaf..f5bdd7a 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -11,12 +11,17 @@
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
- &\hspace{13mm} \lambda \text{ (weight decay)}, \: amsgrad \\
+ &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
+ \: \textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
- &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
+
+ &\hspace{5mm} \textbf{if} \: \textit{maximize}: \\
+ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
+ &\hspace{5mm} \textbf{else} \\
+ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
@@ -50,6 +55,8 @@
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
+ maximize (bool, optional): maximize the objective with respect to the params,
+ instead of minimizing (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
@@ -58,7 +65,7 @@
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False):
+ weight_decay=0, amsgrad=False, *, maximize: bool = False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -70,7 +77,7 @@
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay, amsgrad=amsgrad)
+ weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize)
super(Adam, self).__init__(params, defaults)
def __setstate__(self, state):
@@ -141,5 +148,6 @@
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
- eps=group['eps'])
+ eps=group['eps'],
+ maximize=group['maximize'])
return loss
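As a sanity check of the semantics (an illustrative sketch, not one of the PR's tests): since the flag only negates the gradients before the update, `Adam(..., maximize=True)` stepping on an objective should track a plain `Adam(...)` stepping on the negated objective.

```python
import torch

torch.manual_seed(0)
w_max = torch.randn(4, requires_grad=True)
w_min = w_max.detach().clone().requires_grad_(True)

opt_max = torch.optim.Adam([w_max], lr=1e-2, maximize=True)
opt_min = torch.optim.Adam([w_min], lr=1e-2)

for _ in range(5):
    opt_max.zero_grad()
    w_max.sin().sum().backward()        # maximize sin(w).sum()
    opt_max.step()

    opt_min.zero_grad()
    (-w_min.sin().sum()).backward()     # minimize its negation
    opt_min.step()

print(torch.allclose(w_max, w_min))     # expected: True
```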