Add OneCycleLR (#25324)
Summary:
Squash rebase of https://github.com/pytorch/pytorch/issues/21258
ghstack-source-id: 7d3ce522ac4dd3050bc6c6bbda1eaaeb8bc4b2c1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25325
Differential Revision: D17095722
Pulled By: vincentqb
fbshipit-source-id: 7fe69b210924ee3b39223dd78122aea61267234a
diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index d6d89c9..7ca7725 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -167,3 +167,5 @@
:members:
.. autoclass:: torch.optim.lr_scheduler.CyclicLR
:members:
+.. autoclass:: torch.optim.lr_scheduler.OneCycleLR
+ :members:
diff --git a/test/test_optim.py b/test/test_optim.py
index e334962..f66e53d 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -13,7 +13,7 @@
from torch import sparse
from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
- CyclicLR, CosineAnnealingWarmRestarts
+ CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
skipIfRocm
@@ -1013,6 +1013,58 @@
adam_opt = optim.Adam(self.net.parameters())
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
+ def test_onecycle_lr_invalid_anneal_strategy(self):
+ with self.assertRaises(ValueError):
+ scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS")
+
+ def test_onecycle_lr_invalid_pct_start(self):
+ with self.assertRaises(ValueError):
+ scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
+
+ def test_onecycle_lr_cannot_calculate_total_steps(self):
+ with self.assertRaises(ValueError):
+ scheduler = OneCycleLR(self.opt, max_lr=1e-3)
+
+ def test_onecycle_lr_linear_annealing(self):
+ lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+ momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+ total_steps=10, anneal_strategy='linear')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+ def test_onecycle_lr_cosine_annealing(self):
+ def annealing_cos(start, end, pct):
+ cos_out = math.cos(math.pi * pct) + 1
+ return end + (start - end) / 2.0 * cos_out
+ lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0),
+ annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0),
+ annealing_cos(25, 0.5, 6 / 7.0), 0.5]
+ momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0),
+ annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0),
+ annealing_cos(1, 22, 6 / 7.0), 22]
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+ total_steps=10)
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+ def test_cycle_lr_with_adam(self):
+ old_opt = self.opt
+ self.opt = optim.Adam(
+ [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
+ lr=0.05)
+
+ lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+ momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+ lr_targets = [lr_target, lr_target]
+ momentum_targets = [momentum_target, momentum_target]
+ scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+ total_steps=10, anneal_strategy='linear')
+ self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
+ self.opt = old_opt # set optimizer back to SGD
+
def test_lambda_lr(self):
epochs = 10
self.opt.param_groups[0]['lr'] = 0.05
@@ -1206,13 +1258,16 @@
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)
- def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
+ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False):
for batch_num in range(batch_iterations):
scheduler.step(batch_num)
if verbose:
if 'momentum' in self.opt.param_groups[0].keys():
print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
self.opt.param_groups[0]['momentum']))
+ elif use_beta1 and 'betas' in self.opt.param_groups[0].keys():
+ print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'],
+ self.opt.param_groups[0]['betas'][0]))
else:
print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr']))
@@ -1222,7 +1277,12 @@
msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)
- if 'momentum' in param_group.keys():
+ if use_beta1 and 'betas' in param_group.keys():
+ self.assertAlmostEqual(
+ momentum_target[batch_num], param_group['betas'][0],
+ msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format(
+ batch_num, momentum_target[batch_num], param_group['betas'][0]), delta=1e-5)
+ elif 'momentum' in param_group.keys():
self.assertAlmostEqual(
momentum_target[batch_num], param_group['momentum'],
msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 6101e8c..ea18caa 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -759,3 +759,223 @@
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
+
+class OneCycleLR(_LRScheduler):
+ r"""Sets the learning rate of each parameter group according to the
+ 1cycle learning rate policy. The 1cycle policy anneals the learning
+ rate from an initial learning rate to some maximum learning rate and then
+ from that maximum learning rate to some minimum learning rate much lower
+ than the initial learning rate.
+ This policy was initially described in the paper `Super-Convergence:
+ Very Fast Training of Neural Networks Using Large Learning Rates`_.
+
+ The 1cycle learning rate policy changes the learning rate after every batch.
+ `step` should be called after a batch has been used for training.
+
+ This scheduler is not chainable.
+
+ This class has two built-in annealing strategies:
+ "cos":
+ Cosine annealing
+ "linear":
+ Linear annealing
+
+ Note also that the total number of steps in the cycle can be determined in one
+ of two ways (listed in order of precedence):
+ 1) A value for total_steps is explicitly provided.
+ 2) A number of epochs (epochs) and a number of steps per epoch
+ (steps_per_epoch) are provided.
+ In this case, the number of total steps is inferred by
+ total_steps = epochs * steps_per_epoch
+ You must either provide a value for total_steps or provide a value for both
+ epochs and steps_per_epoch.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ max_lr (float or list): Upper learning rate boundaries in the cycle
+ for each parameter group.
+ total_steps (int): The total number of steps in the cycle. Note that
+ if a value is provided here, then it must be inferred by providing
+ a value for epochs and steps_per_epoch.
+ Default: None
+ epochs (int): The number of epochs to train for. This is used along
+ with steps_per_epoch in order to infer the total number of steps in the cycle
+ if a value for total_steps is not provided.
+ Default: None
+ steps_per_epoch (int): The number of steps per epoch to train for. This is
+ used along with epochs in order to infer the total number of steps in the
+ cycle if a value for total_steps is not provided.
+ Default: None
+ pct_start (float): The percentage of the cycle (in number of steps) spent
+ increasing the learning rate.
+ Default: 0.3
+ anneal_strategy (str): {'cos', 'linear'}
+ Specifies the annealing strategy.
+ Default: 'cos'
+ cycle_momentum (bool): If ``True``, momentum is cycled inversely
+ to learning rate between 'base_momentum' and 'max_momentum'.
+ Default: True
+ base_momentum (float or list): Lower momentum boundaries in the cycle
+ for each parameter group. Note that momentum is cycled inversely
+ to learning rate; at the peak of a cycle, momentum is
+ 'base_momentum' and learning rate is 'max_lr'.
+ Default: 0.85
+ max_momentum (float or list): Upper momentum boundaries in the cycle
+ for each parameter group. Functionally,
+ it defines the cycle amplitude (max_momentum - base_momentum).
+ Note that momentum is cycled inversely
+ to learning rate; at the start of a cycle, momentum is 'max_momentum'
+ and learning rate is 'base_lr'
+ Default: 0.95
+ div_factor (float): Determines the initial learning rate via
+ initial_lr = max_lr/div_factor
+ Default: 25
+ final_div_factor (float): Determines the minimum learning rate via
+ min_lr = initial_lr/final_div_factor
+ Default: 1e4
+ last_epoch (int): The index of the last batch. This parameter is used when
+ resuming a training job. Since `step()` should be invoked after each
+ batch instead of after each epoch, this number represents the total
+ number of *batches* computed, not the total number of epochs computed.
+ When last_epoch=-1, the schedule is started from the beginning.
+ Default: -1
+
+ Example:
+ >>> data_loader = torch.utils.data.DataLoader(...)
+ >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+ >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
+ >>> for epoch in range(10):
+ >>> for batch in data_loader:
+ >>> train_batch(...)
+ >>> scheduler.step()
+
+
+ .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
+ https://arxiv.org/abs/1708.07120
+ """
+ def __init__(self,
+ optimizer,
+ max_lr,
+ total_steps=None,
+ epochs=None,
+ steps_per_epoch=None,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ cycle_momentum=True,
+ base_momentum=0.85,
+ max_momentum=0.95,
+ div_factor=25.,
+ final_div_factor=1e4,
+ last_epoch=-1):
+
+ # Validate optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+
+ # Validate total_steps
+ if total_steps is None and epochs is None and steps_per_epoch is None:
+ raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
+ elif total_steps is not None:
+ if total_steps <= 0 or not isinstance(total_steps, int):
+ raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps))
+ self.total_steps = total_steps
+ else:
+ if epochs <= 0 or not isinstance(epochs, int):
+ raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs))
+ if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+ raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch))
+ self.total_steps = epochs * steps_per_epoch
+ self.step_size_up = float(pct_start * self.total_steps) - 1
+ self.step_size_down = float(self.total_steps - self.step_size_up) - 1
+
+ # Validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+ raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
+
+ # Validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
+ elif anneal_strategy == 'cos':
+ self.anneal_func = self._annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = self._annealing_linear
+
+ # Initialize learning rate variables
+ max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
+ if last_epoch == -1:
+ for idx, group in enumerate(self.optimizer.param_groups):
+ group['lr'] = max_lrs[idx] / div_factor
+ group['max_lr'] = max_lrs[idx]
+ group['min_lr'] = group['lr'] / final_div_factor
+
+ # Initialize momentum variables
+ self.cycle_momentum = cycle_momentum
+ if self.cycle_momentum:
+ if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
+ raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+ self.use_beta1 = 'betas' in self.optimizer.defaults
+ max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+ base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+ if last_epoch == -1:
+ for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (m_momentum, beta2)
+ else:
+ group['momentum'] = m_momentum
+ group['max_momentum'] = m_momentum
+ group['base_momentum'] = b_momentum
+
+ super(OneCycleLR, self).__init__(optimizer, last_epoch)
+
+ def _format_param(self, name, optimizer, param):
+ """Return correctly formatted lr/momentum for each param group."""
+ if isinstance(param, (list, tuple)):
+ if len(param) != len(optimizer.param_groups):
+ raise ValueError("expected {} values for {}, got {}".format(
+ len(optimizer.param_groups), name, len(param)))
+ return param
+ else:
+ return [param] * len(optimizer.param_groups)
+
+ def _annealing_cos(self, start, end, pct):
+ "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ cos_out = math.cos(math.pi * pct) + 1
+ return end + (start - end) / 2.0 * cos_out
+
+ def _annealing_linear(self, start, end, pct):
+ "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ return (end - start) * pct + start
+
+ def get_lr(self):
+ lrs = []
+ step_num = self.last_epoch
+
+ if step_num > self.total_steps:
+ raise ValueError("Tried to step {} times. The specified number of total steps is {}"
+ .format(step_num + 1, self.total_steps))
+
+ for group in self.optimizer.param_groups:
+ if step_num <= self.step_size_up:
+ computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
+ if self.cycle_momentum:
+ computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
+ step_num / self.step_size_up)
+ else:
+ down_step_num = step_num - self.step_size_up
+ computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
+ if self.cycle_momentum:
+ computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
+ down_step_num / self.step_size_down)
+
+ lrs.append(computed_lr)
+ if self.cycle_momentum:
+ if self.use_beta1:
+ _, beta2 = group['betas']
+ group['betas'] = (computed_momentum, beta2)
+ else:
+ group['momentum'] = computed_momentum
+
+ return lrs