Add OneCycleLR (#25324)

Summary:
Squash rebase of https://github.com/pytorch/pytorch/issues/21258
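
A minimal usage sketch of the new scheduler (illustrative only; the model,
loader, and hyperparameter values are arbitrary placeholders), exercising the
explicit total_steps mode with linear annealing:

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    model = torch.nn.Linear(10, 2)                     # placeholder model
    loader = [torch.randn(4, 10) for _ in range(100)]  # placeholder data loader
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = OneCycleLR(optimizer, max_lr=0.01, total_steps=len(loader),
                           anneal_strategy='linear')
    for batch in loader:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()  # the 1cycle policy steps once per batch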

ghstack-source-id: 7d3ce522ac4dd3050bc6c6bbda1eaaeb8bc4b2c1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25325

Differential Revision: D17095722

Pulled By: vincentqb

fbshipit-source-id: 7fe69b210924ee3b39223dd78122aea61267234a
diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index d6d89c9..7ca7725 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -167,3 +167,5 @@
     :members:
 .. autoclass:: torch.optim.lr_scheduler.CyclicLR
     :members:
+.. autoclass:: torch.optim.lr_scheduler.OneCycleLR
+    :members:
diff --git a/test/test_optim.py b/test/test_optim.py
index e334962..f66e53d 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -13,7 +13,7 @@
 from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
     ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
-    CyclicLR, CosineAnnealingWarmRestarts
+    CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
 from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
     skipIfRocm
 
@@ -1013,6 +1013,58 @@
             adam_opt = optim.Adam(self.net.parameters())
             scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
 
+    def test_onecycle_lr_invalid_anneal_strategy(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS")
+
+    def test_onecycle_lr_invalid_pct_start(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
+
+    def test_onecycle_lr_cannot_calculate_total_steps(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3)
+
+    def test_onecycle_lr_linear_annealing(self):
+        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10, anneal_strategy='linear')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+    def test_onecycle_lr_cosine_annealing(self):
+        def annealing_cos(start, end, pct):
+            cos_out = math.cos(math.pi * pct) + 1
+            return end + (start - end) / 2.0 * cos_out
+        lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0),
+                     annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0),
+                     annealing_cos(25, 0.5, 6 / 7.0), 0.5]
+        momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0),
+                           annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0),
+                           annealing_cos(1, 22, 6 / 7.0), 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10)
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+    def test_onecycle_lr_with_adam(self):
+        old_opt = self.opt
+        self.opt = optim.Adam(
+            [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
+            lr=0.05)
+
+        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10, anneal_strategy='linear')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
+        self.opt = old_opt  # set optimizer back to SGD
+
     def test_lambda_lr(self):
         epochs = 10
         self.opt.param_groups[0]['lr'] = 0.05
@@ -1206,13 +1258,16 @@
                                        msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                            epoch, target[epoch], param_group['lr']), delta=1e-5)
 
-    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
+    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False):
         for batch_num in range(batch_iterations):
             scheduler.step(batch_num)
             if verbose:
                 if 'momentum' in self.opt.param_groups[0].keys():
                     print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
                                                                self.opt.param_groups[0]['momentum']))
+                elif use_beta1 and 'betas' in self.opt.param_groups[0].keys():
+                    print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'],
+                                                            self.opt.param_groups[0]['betas'][0]))
                 else:
                     print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr']))
 
@@ -1222,7 +1277,12 @@
                     msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
                         batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)
 
-                if 'momentum' in param_group.keys():
+                if use_beta1 and 'betas' in param_group.keys():
+                    self.assertAlmostEqual(
+                        momentum_target[batch_num], param_group['betas'][0],
+                        msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format(
+                            batch_num, momentum_target[batch_num], param_group['betas'][0]), delta=1e-5)
+                elif 'momentum' in param_group.keys():
                     self.assertAlmostEqual(
                         momentum_target[batch_num], param_group['momentum'],
                         msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 6101e8c..ea18caa 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -759,3 +759,223 @@
         self.last_epoch = math.floor(epoch)
         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
             param_group['lr'] = lr
+
+class OneCycleLR(_LRScheduler):
+    r"""Sets the learning rate of each parameter group according to the
+    1cycle learning rate policy. The 1cycle policy anneals the learning
+    rate from an initial learning rate to some maximum learning rate and then
+    from that maximum learning rate to some minimum learning rate much lower
+    than the initial learning rate.
+    This policy was initially described in the paper `Super-Convergence:
+    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+
+    The 1cycle learning rate policy changes the learning rate after every batch.
+    `step` should be called after a batch has been used for training.
+
+    This scheduler is not chainable.
+
+    This class has two built-in annealing strategies:
+    "cos":
+        Cosine annealing
+    "linear":
+        Linear annealing
+
+    Note also that the total number of steps in the cycle can be determined in one
+    of two ways (listed in order of precedence):
+    1) A value for total_steps is explicitly provided.
+    2) A number of epochs (epochs) and a number of steps per epoch
+       (steps_per_epoch) are provided.
+       In this case, the number of total steps is inferred by
+       total_steps = epochs * steps_per_epoch
+    You must either provide a value for total_steps or provide a value for both
+    epochs and steps_per_epoch.
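+
+    For example (illustrative values): epochs=10 and steps_per_epoch=100 give
+    total_steps = 1000, so with the default pct_start=0.3 the learning rate
+    rises for roughly the first 300 steps and anneals for the remaining 700.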
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        total_steps (int): The total number of steps in the cycle. Note that
+            if a value is not provided here, then it must be inferred by
+            providing a value for epochs and steps_per_epoch.
+            Default: None
+        epochs (int): The number of epochs to train for. This is used along
+            with steps_per_epoch in order to infer the total number of steps in the cycle
+            if a value for total_steps is not provided.
+            Default: None
+        steps_per_epoch (int): The number of steps per epoch to train for. This is
+            used along with epochs in order to infer the total number of steps in the
+            cycle if a value for total_steps is not provided.
+            Default: None
+        pct_start (float): The percentage of the cycle (in number of steps) spent
+            increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy.
+            Default: 'cos'
+        cycle_momentum (bool): If ``True``, momentum is cycled inversely
+            to learning rate between 'base_momentum' and 'max_momentum'.
+            Default: True
+        base_momentum (float or list): Lower momentum boundaries in the cycle
+            for each parameter group. Note that momentum is cycled inversely
+            to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is 'max_momentum'
+            and learning rate is 'base_lr'.
+            Default: 0.95
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        last_epoch (int): The index of the last batch. This parameter is used when
+            resuming a training job. Since `step()` should be invoked after each
+            batch instead of after each epoch, this number represents the total
+            number of *batches* computed, not the total number of epochs computed.
+            When last_epoch=-1, the schedule is started from the beginning.
+            Default: -1
+
+    Example:
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
+        >>> for epoch in range(10):
+        >>>     for batch in data_loader:
+        >>>         train_batch(...)
+        >>>         scheduler.step()
+
+
+    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
+        https://arxiv.org/abs/1708.07120
+    """
+    def __init__(self,
+                 optimizer,
+                 max_lr,
+                 total_steps=None,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 cycle_momentum=True,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 div_factor=25.,
+                 final_div_factor=1e4,
+                 last_epoch=-1):
+
+        # Validate optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        # Validate total_steps
+        if total_steps is None and epochs is None and steps_per_epoch is None:
+            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
+        elif total_steps is not None:
+            if total_steps <= 0 or not isinstance(total_steps, int):
+                raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps))
+            self.total_steps = total_steps
+        else:
+            if epochs <= 0 or not isinstance(epochs, int):
+                raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs))
+            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+                raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch))
+            self.total_steps = epochs * steps_per_epoch
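+        # The learning rate rises for roughly the first pct_start fraction of
+        # total_steps and anneals back down for the remainder of the cycle.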
+        self.step_size_up = float(pct_start * self.total_steps) - 1
+        self.step_size_down = float(self.total_steps - self.step_size_up) - 1
+
+        # Validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))
+
+        # Validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
+        elif anneal_strategy == 'cos':
+            self.anneal_func = self._annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = self._annealing_linear
+
+        # Initialize learning rate variables
+        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
+        if last_epoch == -1:
+            for idx, group in enumerate(self.optimizer.param_groups):
+                group['lr'] = max_lrs[idx] / div_factor
+                group['max_lr'] = max_lrs[idx]
+                group['min_lr'] = group['lr'] / final_div_factor
+
+        # Initialize momentum variables
+        self.cycle_momentum = cycle_momentum
+        if self.cycle_momentum:
+            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
+                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
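+            # Adam-style optimizers expose momentum through beta1,
+            # so that value is cycled instead of 'momentum'.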
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+
+        super(OneCycleLR, self).__init__(optimizer, last_epoch)
+
+    def _format_param(self, name, optimizer, param):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        else:
+            return [param] * len(optimizer.param_groups)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        lrs = []
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
+                             .format(step_num + 1, self.total_steps))
+
+        for group in self.optimizer.param_groups:
+            if step_num <= self.step_size_up:
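+                # Ascending phase: anneal lr from initial_lr up to max_lr
+                # (momentum moves in the opposite direction).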
+                computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
+                if self.cycle_momentum:
+                    computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
+                                                         step_num / self.step_size_up)
+            else:
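+                # Descending phase: anneal lr from max_lr down to min_lr
+                # while momentum returns to max_momentum.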
+                down_step_num = step_num - self.step_size_up
+                computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
+                if self.cycle_momentum:
+                    computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
+                                                         down_step_num / self.step_size_down)
+
+            lrs.append(computed_lr)
+            if self.cycle_momentum:
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (computed_momentum, beta2)
+                else:
+                    group['momentum'] = computed_momentum
+
+        return lrs