| import types |
| import math |
| import torch |
| from torch._six import inf |
| from bisect import bisect_right |
| from functools import partial |
| from .optimizer import Optimizer |
| |
| |
class _LRScheduler(object):
    """Common machinery for epoch-based learning-rate schedulers.

    A subclass supplies :meth:`get_lr`, which returns one learning rate per
    entry of ``optimizer.param_groups``; :meth:`step` advances the epoch
    counter and writes those rates back into the groups.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        last_epoch (int): The index of last epoch. Default: -1. With -1 the
            current ``lr`` of each group is recorded as its ``initial_lr``
            (unless one is already present); any other value requires every
            group to already carry an ``initial_lr`` entry.

    Raises:
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.
        KeyError: if ``last_epoch != -1`` and some param group has no
            ``initial_lr``.
    """

    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        resuming = last_epoch != -1
        for index, group in enumerate(optimizer.param_groups):
            if not resuming:
                # Fresh start: remember the configured lr as the baseline.
                group.setdefault('initial_lr', group['lr'])
            elif 'initial_lr' not in group:
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(index))

        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]

        # Push the rate for epoch ``last_epoch + 1`` into the optimizer now,
        # then rewind the counter so the next argument-less ``step()`` targets
        # that same epoch again.
        self.step(last_epoch + 1)
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        Every entry of ``self.__dict__`` is included except ``optimizer``.
        """
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the learning rate for every param group; subclasses override."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Move to ``epoch`` (default: ``last_epoch + 1``) and apply its rates."""
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch
        for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            group['lr'] = lr
| |
| |
class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``lr_lambda`` is a list/tuple whose length differs
            from the number of param groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        self.optimizer = optimizer
        if not isinstance(lr_lambda, (list, tuple)):
            # A single callable is shared by every parameter group.
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        self.last_epoch = last_epoch
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are callable objects
        and not if they are functions or lambdas.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            # Plain functions/lambdas are skipped (stored as None); for other
            # callables a copy of their instance attributes is saved.
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        The given ``state_dict`` is left unmodified, so the same dict (e.g. a
        loaded checkpoint) can be applied more than once.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: if ``state_dict`` has no ``'lr_lambdas'`` entry.
        """
        # Read (not pop) the entry: popping mutated the caller's dict and made a
        # second load of the same checkpoint fail with KeyError.
        lr_lambdas = state_dict['lr_lambdas']
        self.__dict__.update({key: value for key, value in state_dict.items()
                              if key != 'lr_lambdas'})

        for idx, fn_state in enumerate(lr_lambdas):
            if fn_state is not None:
                self.lr_lambdas[idx].__dict__.update(fn_state)

    def get_lr(self):
        """Return ``initial_lr * lr_lambda(last_epoch)`` for every param group."""
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
| |
| |
class StepLR(_LRScheduler):
    """Decay each parameter group's learning rate by ``gamma`` once every
    ``step_size`` epochs (closed form: ``initial_lr * gamma ** (epoch // step_size)``).
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Number of completed decay periods so far.
        periods = self.last_epoch // self.step_size
        decay = self.gamma ** periods
        return [initial * decay for initial in self.base_lrs]
| |
| |
class MultiStepLR(_LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma once the number of epoch reaches one of the milestones. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing. Any
            iterable is accepted; it is materialized into a list once.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``milestones`` is not sorted in non-decreasing order.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>     validate(...)
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        # Materialize once: iterating a one-shot iterable (e.g. a generator)
        # twice would make the sortedness check compare against an empty list.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # bisect_right counts the milestones <= last_epoch, i.e. how many
        # decays have been applied so far.
        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
| |
| |
class ExponentialLR(_LRScheduler):
    """Multiply each parameter group's learning rate by ``gamma`` every epoch.

    The rate at epoch ``t`` is ``initial_lr * gamma ** t``. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = self.gamma ** self.last_epoch
        return [initial * factor for initial in self.base_lrs]
| |
| |
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts. Because
    the closed form above is evaluated directly, epochs beyond ``T_max`` follow
    the cosine back up rather than staying at ``eta_min``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        eta_min = self.eta_min
        # Shared by all groups: goes from 2 at epoch 0 to 0 at epoch T_max.
        cos_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return [eta_min + (eta_max - eta_min) * cos_term / 2
                for eta_max in self.base_lrs]
| |
| |
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in `max`
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Raises:
        ValueError: if ``factor >= 1.0``, if ``min_lr`` is a list/tuple whose
            length differs from the number of param groups, or if ``mode`` /
            ``threshold_mode`` is not recognised.
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record a new metric value and reduce the lr if it has plateaued.

        Args:
            metrics: Latest value of the monitored quantity; must be
                accepted by ``float()``.
            epoch (int, optional): Epoch index to record. Defaults to
                ``last_epoch + 1``.
        """
        # Store a plain float so ``self.best`` (saved by ``state_dict``) never
        # keeps a caller-owned object such as a loss tensor alive.
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped at its ``min_lr``.

        A group is left unchanged when the decrease would not exceed ``eps``.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown period is still running."""
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        """Return True if ``a`` improves on ``best`` by more than the threshold."""
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            return a < best * rel_epsilon

        elif mode == 'min' and threshold_mode == 'abs':
            return a < best - threshold

        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the modes and bind ``self.is_better`` / ``self.mode_worse``.

        Raises:
            ValueError: if ``mode`` or ``threshold_mode`` is not recognised.
        """
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        """Return all scheduler attributes except ``optimizer`` and ``is_better``.

        ``is_better`` is a partial over a bound method, so it is rebuilt by
        :meth:`load_state_dict` instead of being saved.
        """
        return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and rebuild ``is_better``."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)