Add beta1 support to CyclicLR momentum (#113548)
Fixes #73910
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113548
Approved by: https://github.com/janeyx99
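
With this change, `CyclicLR(..., cycle_momentum=True)` also accepts optimizers such as Adam that expose beta1 through a `betas` tuple instead of a `momentum` hyperparameter; the first beta is cycled in place of momentum and the second beta is left untouched. A minimal usage sketch (the model and hyperparameter values below are illustrative, not taken from this PR):

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CyclicLR

model = torch.nn.Linear(10, 1)
optimizer = Adam(model.parameters(), lr=1e-2)

# Before this PR this raised ValueError for Adam; now beta1 is cycled
# between base_momentum and max_momentum instead of 'momentum'.
scheduler = CyclicLR(
    optimizer,
    base_lr=1e-3,
    max_lr=1e-2,
    step_size_up=5,
    cycle_momentum=True,
    base_momentum=0.8,
    max_momentum=0.9,
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
    beta1, beta2 = optimizer.param_groups[0]["betas"]
    # beta1 moves inversely to the learning rate; beta2 is never modified.
    print(f"lr={optimizer.param_groups[0]['lr']:.4f}  beta1={beta1:.3f}  beta2={beta2}")
```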
diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py
index 20c5028..02199e4 100644
--- a/test/optim/test_lrscheduler.py
+++ b/test/optim/test_lrscheduler.py
@@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F
from torch.nn import Parameter
-from torch.optim import Adam, SGD
+from torch.optim import Adam, SGD, Rprop
from torch.optim.lr_scheduler import (
LambdaLR,
MultiplicativeLR,
@@ -1510,8 +1510,12 @@
def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
with self.assertRaises(ValueError):
- adam_opt = Adam(self.net.parameters())
- scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+ rprop_opt = Rprop(self.net.parameters())
+ scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+
+ def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
+ adam_opt = Adam(self.net.parameters())
+ scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
def test_cycle_lr_removed_after_out_of_scope(self):
import gc
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index df659b6..a681142 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -1268,15 +1268,20 @@
self.cycle_momentum = cycle_momentum
if cycle_momentum:
- if 'momentum' not in optimizer.defaults:
- raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+ if 'momentum' not in optimizer.defaults and 'betas' not in optimizer.defaults:
+ raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
- base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
- if last_epoch == -1:
- for momentum, group in zip(base_momentums, optimizer.param_groups):
- group['momentum'] = momentum
- self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
+ self.use_beta1 = 'betas' in self.optimizer.defaults
+ self.base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+ if last_epoch == -1:
+ for m_momentum, b_momentum, group in zip(self.max_momentums, self.base_momentums, optimizer.param_groups):
+ if self.use_beta1:
+ group['betas'] = (m_momentum, *group['betas'][1:])
+ else:
+ group['momentum'] = m_momentum
+ group['max_momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
super().__init__(optimizer, last_epoch, verbose)
self.base_lrs = base_lrs
@@ -1359,7 +1364,10 @@
momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
momentums.append(momentum)
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
- param_group['momentum'] = momentum
+ if self.use_beta1:
+ param_group['betas'] = (momentum, *param_group['betas'][1:])
+ else:
+ param_group['momentum'] = momentum
return lrs
@@ -1721,7 +1729,7 @@
self.cycle_momentum = cycle_momentum
if self.cycle_momentum:
if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
- raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+ raise ValueError('optimizer must support momentum or beta1 with `cycle_momentum` option enabled')
self.use_beta1 = 'betas' in self.optimizer.defaults
max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
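
For comparison, OneCycleLR already handled `betas`-based optimizers via `self.use_beta1`; this PR only aligns its error message with the new CyclicLR wording. A small sketch of that pre-existing behavior (the model and step count are illustrative):

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(4, 1)
opt = Adam(model.parameters(), lr=1e-2)
sched = OneCycleLR(opt, max_lr=1e-2, total_steps=10, cycle_momentum=True)

for _ in range(10):
    opt.step()
    sched.step()
    # beta1 falls toward base_momentum while the lr warms up,
    # then climbs back to max_momentum as the lr anneals.
    print(opt.param_groups[0]["betas"][0])
```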