| import types | 
 | import math | 
 | from torch._six import inf | 
 | from functools import wraps | 
 | import warnings | 
 | import weakref | 
 | from collections import Counter | 
 | from bisect import bisect_right | 
 |  | 
 | from .optimizer import Optimizer | 
 |  | 
 |  | 
 | EPOCH_DEPRECATION_WARNING = ( | 
 |     "The epoch parameter in `scheduler.step()` was not necessary and is being " | 
 |     "deprecated where possible. Please use `scheduler.step()` to step the " | 
 |     "scheduler. During the deprecation, if epoch is different from None, the " | 
 |     "closed form is used instead of the new chainable form, where available. " | 
 |     "Please open an issue if you are unable to replicate your use case: " | 
 |     "https://github.com/pytorch/pytorch/issues/new/choose." | 
 | ) | 
 |  | 
 | SAVE_STATE_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." | 
 |  | 
 | class _LRScheduler(object): | 
 |  | 
 |     def __init__(self, optimizer, last_epoch=-1): | 
 |  | 
 |         # Attach optimizer | 
 |         if not isinstance(optimizer, Optimizer): | 
 |             raise TypeError('{} is not an Optimizer'.format( | 
 |                 type(optimizer).__name__)) | 
 |         self.optimizer = optimizer | 
 |  | 
 |         # Initialize epoch and base learning rates | 
 |         if last_epoch == -1: | 
 |             for group in optimizer.param_groups: | 
 |                 group.setdefault('initial_lr', group['lr']) | 
 |         else: | 
 |             for i, group in enumerate(optimizer.param_groups): | 
 |                 if 'initial_lr' not in group: | 
 |                     raise KeyError("param 'initial_lr' is not specified " | 
 |                                    "in param_groups[{}] when resuming an optimizer".format(i)) | 
 |         self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) | 
 |         self.last_epoch = last_epoch | 
 |  | 
 |         # Following https://github.com/pytorch/pytorch/issues/20124 | 
 |         # We would like to ensure that `lr_scheduler.step()` is called after | 
 |         # `optimizer.step()` | 
 |         def with_counter(method): | 
 |             if getattr(method, '_with_counter', False): | 
 |                 # `optimizer.step()` has already been replaced, return. | 
 |                 return method | 
 |  | 
 |             # Keep a weak reference to the optimizer instance to prevent | 
 |             # cyclic references. | 
 |             instance_ref = weakref.ref(method.__self__) | 
 |             # Get the unbound method for the same purpose. | 
 |             func = method.__func__ | 
 |             cls = instance_ref().__class__ | 
 |             del method | 
 |  | 
 |             @wraps(func) | 
 |             def wrapper(*args, **kwargs): | 
 |                 instance = instance_ref() | 
 |                 instance._step_count += 1 | 
 |                 wrapped = func.__get__(instance, cls) | 
 |                 return wrapped(*args, **kwargs) | 
 |  | 
 |             # Note that the returned function here is no longer a bound method, | 
 |             # so attributes like `__func__` and `__self__` no longer exist. | 
 |             wrapper._with_counter = True | 
 |             return wrapper | 
 |  | 
 |         self.optimizer.step = with_counter(self.optimizer.step) | 
 |         self.optimizer._step_count = 0 | 
 |         self._step_count = 0 | 
 |  | 
 |         self.step() | 
 |  | 
 |     def state_dict(self): | 
 |         """Returns the state of the scheduler as a :class:`dict`. | 
 |  | 
 |         It contains an entry for every variable in self.__dict__ which | 
 |         is not the optimizer. | 
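
        A minimal sketch of saving both states together (assumes `optimizer`
        and `scheduler` are already constructed; `PATH` is a placeholder for
        a checkpoint path of your choosing):

        Example:
            >>> torch.save({'optimizer': optimizer.state_dict(),
            >>>             'scheduler': scheduler.state_dict()}, PATH)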
 |         """ | 
 |         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} | 
 |  | 
 |     def load_state_dict(self, state_dict): | 
 |         """Loads the schedulers state. | 
 |  | 
 |         Arguments: | 
 |             state_dict (dict): scheduler state. Should be an object returned | 
 |                 from a call to :meth:`state_dict`. | 
 |         """ | 
 |         self.__dict__.update(state_dict) | 
 |  | 
 |     def get_last_lr(self): | 
 |         """ Return last computed learning rate by current scheduler. | 
 |         """ | 
 |         return self._last_lr | 
 |  | 
 |     def get_lr(self): | 
 |         # Compute learning rate using chainable form of the scheduler | 
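        # For example, ExponentialLR below returns
        # `[group['lr'] * self.gamma for group in self.optimizer.param_groups]`,
        # i.e. each step scales the *current* lr rather than recomputing it from
        # `base_lrs`; the non-chainable form lives in `_get_closed_form_lr`.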
 |         raise NotImplementedError | 
 |  | 
 |     def step(self, epoch=None): | 
 |         # Raise a warning if old pattern is detected | 
 |         # https://github.com/pytorch/pytorch/issues/20124 | 
 |         if self._step_count == 1: | 
 |             if not hasattr(self.optimizer.step, "_with_counter"): | 
 |                 warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " | 
 |                               "initialization. Please, make sure to call `optimizer.step()` before " | 
 |                               "`lr_scheduler.step()`. See more details at " | 
 |                               "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) | 
 |  | 
            # Just check whether there were two lr_scheduler.step() calls before the first optimizer.step()
 |             elif self.optimizer._step_count < 1: | 
 |                 warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " | 
 |                               "In PyTorch 1.1.0 and later, you should call them in the opposite order: " | 
 |                               "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this " | 
 |                               "will result in PyTorch skipping the first value of the learning rate schedule. " | 
 |                               "See more details at " | 
 |                               "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) | 
 |         self._step_count += 1 | 
 |  | 
 |         class _enable_get_lr_call: | 
 |  | 
 |             def __init__(self, o): | 
 |                 self.o = o | 
 |  | 
 |             def __enter__(self): | 
 |                 self.o._get_lr_called_within_step = True | 
 |                 return self | 
 |  | 
 |             def __exit__(self, type, value, traceback): | 
 |                 self.o._get_lr_called_within_step = False | 
 |  | 
 |         with _enable_get_lr_call(self): | 
 |             if epoch is None: | 
 |                 self.last_epoch += 1 | 
 |                 values = self.get_lr() | 
 |             else: | 
 |                 warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) | 
 |                 self.last_epoch = epoch | 
 |                 if hasattr(self, "_get_closed_form_lr"): | 
 |                     values = self._get_closed_form_lr() | 
 |                 else: | 
 |                     values = self.get_lr() | 
 |  | 
 |         for param_group, lr in zip(self.optimizer.param_groups, values): | 
 |             param_group['lr'] = lr | 
 |  | 
 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups] | 
 |  | 
 |  | 
 | class LambdaLR(_LRScheduler): | 
 |     """Sets the learning rate of each parameter group to the initial lr | 
 |     times a given function. When last_epoch=-1, sets initial lr as lr. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         lr_lambda (function or list): A function which computes a multiplicative | 
 |             factor given an integer parameter epoch, or a list of such | 
 |             functions, one for each group in optimizer.param_groups. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
 |  | 
 |     Example: | 
 |         >>> # Assuming optimizer has two groups. | 
 |         >>> lambda1 = lambda epoch: epoch // 30 | 
 |         >>> lambda2 = lambda epoch: 0.95 ** epoch | 
 |         >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) | 
 |         >>> for epoch in range(100): | 
 |         >>>     train(...) | 
 |         >>>     validate(...) | 
 |         >>>     scheduler.step() | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, lr_lambda, last_epoch=-1): | 
 |         self.optimizer = optimizer | 
 |  | 
        if not isinstance(lr_lambda, (list, tuple)):
 |             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) | 
 |         else: | 
 |             if len(lr_lambda) != len(optimizer.param_groups): | 
 |                 raise ValueError("Expected {} lr_lambdas, but got {}".format( | 
 |                     len(optimizer.param_groups), len(lr_lambda))) | 
 |             self.lr_lambdas = list(lr_lambda) | 
 |         self.last_epoch = last_epoch | 
 |         super(LambdaLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def state_dict(self): | 
 |         """Returns the state of the scheduler as a :class:`dict`. | 
 |  | 
 |         It contains an entry for every variable in self.__dict__ which | 
 |         is not the optimizer. | 
 |         The learning rate lambda functions will only be saved if they are callable objects | 
 |         and not if they are functions or lambdas. | 
 |         """ | 
 |  | 
 |         warnings.warn(SAVE_STATE_WARNING, UserWarning) | 
 |         state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} | 
 |         state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) | 
 |  | 
 |         for idx, fn in enumerate(self.lr_lambdas): | 
 |             if not isinstance(fn, types.FunctionType): | 
 |                 state_dict['lr_lambdas'][idx] = fn.__dict__.copy() | 
 |  | 
 |         return state_dict | 
 |  | 
 |     def load_state_dict(self, state_dict): | 
 |         """Loads the schedulers state. | 
 |  | 
 |         Arguments: | 
 |             state_dict (dict): scheduler state. Should be an object returned | 
 |                 from a call to :meth:`state_dict`. | 
 |         """ | 
 |  | 
 |         warnings.warn(SAVE_STATE_WARNING, UserWarning) | 
 |         lr_lambdas = state_dict.pop('lr_lambdas') | 
 |         self.__dict__.update(state_dict) | 
 |  | 
 |         for idx, fn in enumerate(lr_lambdas): | 
 |             if fn is not None: | 
 |                 self.lr_lambdas[idx].__dict__.update(fn) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.") | 
 |  | 
 |         return [base_lr * lmbda(self.last_epoch) | 
 |                 for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] | 
 |  | 
 |  | 
 | class MultiplicativeLR(_LRScheduler): | 
 |     """Multiply the learning rate of each parameter group by the factor given | 
 |     in the specified function. When last_epoch=-1, sets initial lr as lr. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         lr_lambda (function or list): A function which computes a multiplicative | 
 |             factor given an integer parameter epoch, or a list of such | 
 |             functions, one for each group in optimizer.param_groups. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
 |  | 
 |     Example: | 
 |         >>> # Assuming optimizer has two groups. | 
 |         >>> lmbda = lambda epoch: 0.95 | 
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
 |         >>> for epoch in range(100): | 
 |         >>>     train(...) | 
 |         >>>     validate(...) | 
 |         >>>     scheduler.step() | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, lr_lambda, last_epoch=-1): | 
 |         self.optimizer = optimizer | 
 |  | 
        if not isinstance(lr_lambda, (list, tuple)):
 |             self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) | 
 |         else: | 
 |             if len(lr_lambda) != len(optimizer.param_groups): | 
 |                 raise ValueError("Expected {} lr_lambdas, but got {}".format( | 
 |                     len(optimizer.param_groups), len(lr_lambda))) | 
 |             self.lr_lambdas = list(lr_lambda) | 
 |         self.last_epoch = last_epoch | 
 |         super(MultiplicativeLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def state_dict(self): | 
 |         """Returns the state of the scheduler as a :class:`dict`. | 
 |  | 
 |         It contains an entry for every variable in self.__dict__ which | 
 |         is not the optimizer. | 
 |         The learning rate lambda functions will only be saved if they are callable objects | 
 |         and not if they are functions or lambdas. | 
 |         """ | 
 |         state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} | 
 |         state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) | 
 |  | 
 |         for idx, fn in enumerate(self.lr_lambdas): | 
 |             if not isinstance(fn, types.FunctionType): | 
 |                 state_dict['lr_lambdas'][idx] = fn.__dict__.copy() | 
 |  | 
 |         return state_dict | 
 |  | 
 |     def load_state_dict(self, state_dict): | 
 |         """Loads the schedulers state. | 
 |  | 
 |         Arguments: | 
 |             state_dict (dict): scheduler state. Should be an object returned | 
 |                 from a call to :meth:`state_dict`. | 
 |         """ | 
 |         lr_lambdas = state_dict.pop('lr_lambdas') | 
 |         self.__dict__.update(state_dict) | 
 |  | 
 |         for idx, fn in enumerate(lr_lambdas): | 
 |             if fn is not None: | 
 |                 self.lr_lambdas[idx].__dict__.update(fn) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         if self.last_epoch > 0: | 
 |             return [group['lr'] * lmbda(self.last_epoch) | 
 |                     for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] | 
 |         else: | 
 |             return [base_lr for base_lr in self.base_lrs] | 
 |  | 
 |  | 
 | class StepLR(_LRScheduler): | 
 |     """Decays the learning rate of each parameter group by gamma every | 
 |     step_size epochs. Notice that such decay can happen simultaneously with | 
 |     other changes to the learning rate from outside this scheduler. When | 
 |     last_epoch=-1, sets initial lr as lr. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         step_size (int): Period of learning rate decay. | 
 |         gamma (float): Multiplicative factor of learning rate decay. | 
 |             Default: 0.1. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
 |  | 
 |     Example: | 
 |         >>> # Assuming optimizer uses lr = 0.05 for all groups | 
 |         >>> # lr = 0.05     if epoch < 30 | 
 |         >>> # lr = 0.005    if 30 <= epoch < 60 | 
 |         >>> # lr = 0.0005   if 60 <= epoch < 90 | 
 |         >>> # ... | 
 |         >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) | 
 |         >>> for epoch in range(100): | 
 |         >>>     train(...) | 
 |         >>>     validate(...) | 
 |         >>>     scheduler.step() | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): | 
 |         self.step_size = step_size | 
 |         self.gamma = gamma | 
 |         super(StepLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): | 
 |             return [group['lr'] for group in self.optimizer.param_groups] | 
 |         return [group['lr'] * self.gamma | 
 |                 for group in self.optimizer.param_groups] | 
 |  | 
 |     def _get_closed_form_lr(self): | 
 |         return [base_lr * self.gamma ** (self.last_epoch // self.step_size) | 
 |                 for base_lr in self.base_lrs] | 
 |  | 
 |  | 
 | class MultiStepLR(_LRScheduler): | 
 |     """Decays the learning rate of each parameter group by gamma once the | 
    number of epochs reaches one of the milestones. Notice that such decay can
 |     happen simultaneously with other changes to the learning rate from outside | 
 |     this scheduler. When last_epoch=-1, sets initial lr as lr. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         milestones (list): List of epoch indices. Must be increasing. | 
 |         gamma (float): Multiplicative factor of learning rate decay. | 
 |             Default: 0.1. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
 |  | 
 |     Example: | 
 |         >>> # Assuming optimizer uses lr = 0.05 for all groups | 
 |         >>> # lr = 0.05     if epoch < 30 | 
 |         >>> # lr = 0.005    if 30 <= epoch < 80 | 
 |         >>> # lr = 0.0005   if epoch >= 80 | 
 |         >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) | 
 |         >>> for epoch in range(100): | 
 |         >>>     train(...) | 
 |         >>>     validate(...) | 
 |         >>>     scheduler.step() | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): | 
 |         self.milestones = Counter(milestones) | 
 |         self.gamma = gamma | 
 |         super(MultiStepLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         if self.last_epoch not in self.milestones: | 
 |             return [group['lr'] for group in self.optimizer.param_groups] | 
 |         return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] | 
 |                 for group in self.optimizer.param_groups] | 
 |  | 
    def _get_closed_form_lr(self):
        # self.milestones is a Counter; bisect over its sorted elements
        # (with multiplicity) so repeated milestones are counted correctly.
        milestones = list(sorted(self.milestones.elements()))
        return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
                for base_lr in self.base_lrs]
 |  | 
 |  | 
 | class ExponentialLR(_LRScheduler): | 
 |     """Decays the learning rate of each parameter group by gamma every epoch. | 
 |     When last_epoch=-1, sets initial lr as lr. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         gamma (float): Multiplicative factor of learning rate decay. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
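
    Example:
        >>> # Illustrative usage; assumes optimizer has already been constructed.
        >>> scheduler = ExponentialLR(optimizer, gamma=0.9)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()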
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, gamma, last_epoch=-1): | 
 |         self.gamma = gamma | 
 |         super(ExponentialLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         if self.last_epoch == 0: | 
 |             return self.base_lrs | 
 |         return [group['lr'] * self.gamma | 
 |                 for group in self.optimizer.param_groups] | 
 |  | 
 |     def _get_closed_form_lr(self): | 
 |         return [base_lr * self.gamma ** self.last_epoch | 
 |                 for base_lr in self.base_lrs] | 
 |  | 
 |  | 
 | class CosineAnnealingLR(_LRScheduler): | 
 |     r"""Set the learning rate of each parameter group using a cosine annealing | 
 |     schedule, where :math:`\eta_{max}` is set to the initial lr and | 
 |     :math:`T_{cur}` is the number of epochs since the last restart in SGDR: | 
 |  | 
 |     .. math:: | 
 |         \begin{aligned} | 
 |             \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 | 
 |             + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), | 
 |             & T_{cur} \neq (2k+1)T_{max}; \\ | 
 |             \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) | 
 |             \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), | 
 |             & T_{cur} = (2k+1)T_{max}. | 
 |         \end{aligned} | 
 |  | 
 |     When last_epoch=-1, sets initial lr as lr. Notice that because the schedule | 
 |     is defined recursively, the learning rate can be simultaneously modified | 
 |     outside this scheduler by other operators. If the learning rate is set | 
 |     solely by this scheduler, the learning rate at each step becomes: | 
 |  | 
 |     .. math:: | 
 |         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | 
 |         \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) | 
 |  | 
 |     It has been proposed in | 
 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only | 
 |     implements the cosine annealing part of SGDR, and not the restarts. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         T_max (int): Maximum number of iterations. | 
 |         eta_min (float): Minimum learning rate. Default: 0. | 
 |         last_epoch (int): The index of last epoch. Default: -1. | 
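
    Example:
        >>> # Illustrative usage; assumes optimizer has already been constructed.
        >>> scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()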
 |  | 
 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | 
 |         https://arxiv.org/abs/1608.03983 | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): | 
 |         self.T_max = T_max | 
 |         self.eta_min = eta_min | 
 |         super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         if self.last_epoch == 0: | 
 |             return self.base_lrs | 
 |         elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: | 
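            # This branch is the step right after the cosine reaches its trough,
            # i.e. the T_cur = (2k+1)T_max case of the recursive formula in the
            # class docstring.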
 |             return [group['lr'] + (base_lr - self.eta_min) * | 
 |                     (1 - math.cos(math.pi / self.T_max)) / 2 | 
 |                     for base_lr, group in | 
 |                     zip(self.base_lrs, self.optimizer.param_groups)] | 
 |         return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / | 
 |                 (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * | 
 |                 (group['lr'] - self.eta_min) + self.eta_min | 
 |                 for group in self.optimizer.param_groups] | 
 |  | 
 |     def _get_closed_form_lr(self): | 
 |         return [self.eta_min + (base_lr - self.eta_min) * | 
 |                 (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 | 
 |                 for base_lr in self.base_lrs] | 
 |  | 
 |  | 
 | class ReduceLROnPlateau(object): | 
 |     """Reduce learning rate when a metric has stopped improving. | 
 |     Models often benefit from reducing the learning rate by a factor | 
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, reduces the learning rate.
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         mode (str): One of `min`, `max`. In `min` mode, lr will | 
 |             be reduced when the quantity monitored has stopped | 
 |             decreasing; in `max` mode it will be reduced when the | 
 |             quantity monitored has stopped increasing. Default: 'min'. | 
 |         factor (float): Factor by which the learning rate will be | 
 |             reduced. new_lr = lr * factor. Default: 0.1. | 
 |         patience (int): Number of epochs with no improvement after | 
 |             which learning rate will be reduced. For example, if | 
 |             `patience = 2`, then we will ignore the first 2 epochs | 
 |             with no improvement, and will only decrease the LR after the | 
 |             3rd epoch if the loss still hasn't improved then. | 
 |             Default: 10. | 
 |         verbose (bool): If ``True``, prints a message to stdout for | 
 |             each update. Default: ``False``. | 
 |         threshold (float): Threshold for measuring the new optimum, | 
 |             to only focus on significant changes. Default: 1e-4. | 
 |         threshold_mode (str): One of `rel`, `abs`. In `rel` mode, | 
 |             dynamic_threshold = best * ( 1 + threshold ) in 'max' | 
 |             mode or best * ( 1 - threshold ) in `min` mode. | 
 |             In `abs` mode, dynamic_threshold = best + threshold in | 
 |             `max` mode or best - threshold in `min` mode. Default: 'rel'. | 
 |         cooldown (int): Number of epochs to wait before resuming | 
 |             normal operation after lr has been reduced. Default: 0. | 
 |         min_lr (float or list): A scalar or a list of scalars. A | 
 |             lower bound on the learning rate of all param groups | 
 |             or each group respectively. Default: 0. | 
 |         eps (float): Minimal decay applied to lr. If the difference | 
 |             between new and old lr is smaller than eps, the update is | 
 |             ignored. Default: 1e-8. | 
 |  | 
 |     Example: | 
 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) | 
 |         >>> scheduler = ReduceLROnPlateau(optimizer, 'min') | 
 |         >>> for epoch in range(10): | 
 |         >>>     train(...) | 
 |         >>>     val_loss = validate(...) | 
 |         >>>     # Note that step should be called after validate() | 
 |         >>>     scheduler.step(val_loss) | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, mode='min', factor=0.1, patience=10, | 
 |                  verbose=False, threshold=1e-4, threshold_mode='rel', | 
 |                  cooldown=0, min_lr=0, eps=1e-8): | 
 |  | 
 |         if factor >= 1.0: | 
 |             raise ValueError('Factor should be < 1.0.') | 
 |         self.factor = factor | 
 |  | 
 |         # Attach optimizer | 
 |         if not isinstance(optimizer, Optimizer): | 
 |             raise TypeError('{} is not an Optimizer'.format( | 
 |                 type(optimizer).__name__)) | 
 |         self.optimizer = optimizer | 
 |  | 
        if isinstance(min_lr, (list, tuple)):
 |             if len(min_lr) != len(optimizer.param_groups): | 
 |                 raise ValueError("expected {} min_lrs, got {}".format( | 
 |                     len(optimizer.param_groups), len(min_lr))) | 
 |             self.min_lrs = list(min_lr) | 
 |         else: | 
 |             self.min_lrs = [min_lr] * len(optimizer.param_groups) | 
 |  | 
 |         self.patience = patience | 
 |         self.verbose = verbose | 
 |         self.cooldown = cooldown | 
 |         self.cooldown_counter = 0 | 
 |         self.mode = mode | 
 |         self.threshold = threshold | 
 |         self.threshold_mode = threshold_mode | 
 |         self.best = None | 
 |         self.num_bad_epochs = None | 
        self.mode_worse = None  # the worst value for the chosen mode
 |         self.eps = eps | 
 |         self.last_epoch = 0 | 
 |         self._init_is_better(mode=mode, threshold=threshold, | 
 |                              threshold_mode=threshold_mode) | 
 |         self._reset() | 
 |  | 
 |     def _reset(self): | 
 |         """Resets num_bad_epochs counter and cooldown counter.""" | 
 |         self.best = self.mode_worse | 
 |         self.cooldown_counter = 0 | 
 |         self.num_bad_epochs = 0 | 
 |  | 
 |     def step(self, metrics, epoch=None): | 
 |         # convert `metrics` to float, in case it's a zero-dim Tensor | 
 |         current = float(metrics) | 
 |         if epoch is None: | 
 |             epoch = self.last_epoch + 1 | 
 |         else: | 
 |             warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) | 
 |         self.last_epoch = epoch | 
 |  | 
 |         if self.is_better(current, self.best): | 
 |             self.best = current | 
 |             self.num_bad_epochs = 0 | 
 |         else: | 
 |             self.num_bad_epochs += 1 | 
 |  | 
 |         if self.in_cooldown: | 
 |             self.cooldown_counter -= 1 | 
 |             self.num_bad_epochs = 0  # ignore any bad epochs in cooldown | 
 |  | 
 |         if self.num_bad_epochs > self.patience: | 
 |             self._reduce_lr(epoch) | 
 |             self.cooldown_counter = self.cooldown | 
 |             self.num_bad_epochs = 0 | 
 |  | 
 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups] | 
 |  | 
 |     def _reduce_lr(self, epoch): | 
 |         for i, param_group in enumerate(self.optimizer.param_groups): | 
 |             old_lr = float(param_group['lr']) | 
 |             new_lr = max(old_lr * self.factor, self.min_lrs[i]) | 
 |             if old_lr - new_lr > self.eps: | 
 |                 param_group['lr'] = new_lr | 
 |                 if self.verbose: | 
 |                     print('Epoch {:5d}: reducing learning rate' | 
 |                           ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) | 
 |  | 
 |     @property | 
 |     def in_cooldown(self): | 
 |         return self.cooldown_counter > 0 | 
 |  | 
 |     def is_better(self, a, best): | 
 |         if self.mode == 'min' and self.threshold_mode == 'rel': | 
 |             rel_epsilon = 1. - self.threshold | 
 |             return a < best * rel_epsilon | 
 |  | 
 |         elif self.mode == 'min' and self.threshold_mode == 'abs': | 
 |             return a < best - self.threshold | 
 |  | 
 |         elif self.mode == 'max' and self.threshold_mode == 'rel': | 
 |             rel_epsilon = self.threshold + 1. | 
 |             return a > best * rel_epsilon | 
 |  | 
        else:  # mode == 'max' and threshold_mode == 'abs':
 |             return a > best + self.threshold | 
 |  | 
 |     def _init_is_better(self, mode, threshold, threshold_mode): | 
 |         if mode not in {'min', 'max'}: | 
 |             raise ValueError('mode ' + mode + ' is unknown!') | 
 |         if threshold_mode not in {'rel', 'abs'}: | 
 |             raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') | 
 |  | 
 |         if mode == 'min': | 
 |             self.mode_worse = inf | 
 |         else:  # mode == 'max': | 
 |             self.mode_worse = -inf | 
 |  | 
 |         self.mode = mode | 
 |         self.threshold = threshold | 
 |         self.threshold_mode = threshold_mode | 
 |  | 
 |     def state_dict(self): | 
 |         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} | 
 |  | 
 |     def load_state_dict(self, state_dict): | 
 |         self.__dict__.update(state_dict) | 
 |         self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) | 
 |  | 
 |  | 
 | class CyclicLR(_LRScheduler): | 
 |     r"""Sets the learning rate of each parameter group according to | 
 |     cyclical learning rate policy (CLR). The policy cycles the learning | 
 |     rate between two boundaries with a constant frequency, as detailed in | 
 |     the paper `Cyclical Learning Rates for Training Neural Networks`_. | 
 |     The distance between the two boundaries can be scaled on a per-iteration | 
 |     or per-cycle basis. | 
 |  | 
 |     Cyclical learning rate policy changes the learning rate after every batch. | 
 |     `step` should be called after a batch has been used for training. | 
 |  | 
 |     This class has three built-in policies, as put forth in the paper: | 
 |  | 
 |     * "triangular": A basic triangular cycle without amplitude scaling. | 
 |     * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. | 
 |     * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` | 
 |       at each cycle iteration. | 
 |  | 
 |     This implementation was adapted from the github repo: `bckenstler/CLR`_ | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         base_lr (float or list): Initial learning rate which is the | 
 |             lower boundary in the cycle for each parameter group. | 
 |         max_lr (float or list): Upper learning rate boundaries in the cycle | 
 |             for each parameter group. Functionally, | 
 |             it defines the cycle amplitude (max_lr - base_lr). | 
 |             The lr at any cycle is the sum of base_lr | 
 |             and some scaling of the amplitude; therefore | 
 |             max_lr may not actually be reached depending on | 
 |             scaling function. | 
 |         step_size_up (int): Number of training iterations in the | 
 |             increasing half of a cycle. Default: 2000 | 
 |         step_size_down (int): Number of training iterations in the | 
 |             decreasing half of a cycle. If step_size_down is None, | 
 |             it is set to step_size_up. Default: None | 
 |         mode (str): One of {triangular, triangular2, exp_range}. | 
 |             Values correspond to policies detailed above. | 
 |             If scale_fn is not None, this argument is ignored. | 
 |             Default: 'triangular' | 
 |         gamma (float): Constant in 'exp_range' scaling function: | 
 |             gamma**(cycle iterations) | 
 |             Default: 1.0 | 
 |         scale_fn (function): Custom scaling policy defined by a single | 
 |             argument lambda function, where | 
 |             0 <= scale_fn(x) <= 1 for all x >= 0. | 
 |             If specified, then 'mode' is ignored. | 
 |             Default: None | 
 |         scale_mode (str): {'cycle', 'iterations'}. | 
 |             Defines whether scale_fn is evaluated on | 
 |             cycle number or cycle iterations (training | 
 |             iterations since start of cycle). | 
 |             Default: 'cycle' | 
 |         cycle_momentum (bool): If ``True``, momentum is cycled inversely | 
 |             to learning rate between 'base_momentum' and 'max_momentum'. | 
 |             Default: True | 
 |         base_momentum (float or list): Lower momentum boundaries in the cycle | 
 |             for each parameter group. Note that momentum is cycled inversely | 
 |             to learning rate; at the peak of a cycle, momentum is | 
 |             'base_momentum' and learning rate is 'max_lr'. | 
 |             Default: 0.8 | 
 |         max_momentum (float or list): Upper momentum boundaries in the cycle | 
 |             for each parameter group. Functionally, | 
 |             it defines the cycle amplitude (max_momentum - base_momentum). | 
 |             The momentum at any cycle is the difference of max_momentum | 
 |             and some scaling of the amplitude; therefore | 
 |             base_momentum may not actually be reached depending on | 
 |             scaling function. Note that momentum is cycled inversely | 
 |             to learning rate; at the start of a cycle, momentum is 'max_momentum' | 
            and learning rate is 'base_lr'.
 |             Default: 0.9 | 
 |         last_epoch (int): The index of the last batch. This parameter is used when | 
 |             resuming a training job. Since `step()` should be invoked after each | 
 |             batch instead of after each epoch, this number represents the total | 
 |             number of *batches* computed, not the total number of epochs computed. | 
 |             When last_epoch=-1, the schedule is started from the beginning. | 
 |             Default: -1 | 
 |  | 
 |     Example: | 
 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) | 
 |         >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) | 
 |         >>> data_loader = torch.utils.data.DataLoader(...) | 
 |         >>> for epoch in range(10): | 
 |         >>>     for batch in data_loader: | 
 |         >>>         train_batch(...) | 
 |         >>>         scheduler.step() | 
 |  | 
 |  | 
 |     .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 | 
 |     .. _bckenstler/CLR: https://github.com/bckenstler/CLR | 
 |     """ | 
 |  | 
 |     def __init__(self, | 
 |                  optimizer, | 
 |                  base_lr, | 
 |                  max_lr, | 
 |                  step_size_up=2000, | 
 |                  step_size_down=None, | 
 |                  mode='triangular', | 
 |                  gamma=1., | 
 |                  scale_fn=None, | 
 |                  scale_mode='cycle', | 
 |                  cycle_momentum=True, | 
 |                  base_momentum=0.8, | 
 |                  max_momentum=0.9, | 
 |                  last_epoch=-1): | 
 |  | 
 |         # Attach optimizer | 
 |         if not isinstance(optimizer, Optimizer): | 
 |             raise TypeError('{} is not an Optimizer'.format( | 
 |                 type(optimizer).__name__)) | 
 |         self.optimizer = optimizer | 
 |  | 
 |         base_lrs = self._format_param('base_lr', optimizer, base_lr) | 
 |         if last_epoch == -1: | 
 |             for lr, group in zip(base_lrs, optimizer.param_groups): | 
 |                 group['lr'] = lr | 
 |  | 
 |         self.max_lrs = self._format_param('max_lr', optimizer, max_lr) | 
 |  | 
 |         step_size_up = float(step_size_up) | 
 |         step_size_down = float(step_size_down) if step_size_down is not None else step_size_up | 
 |         self.total_size = step_size_up + step_size_down | 
 |         self.step_ratio = step_size_up / self.total_size | 
 |  | 
 |         if mode not in ['triangular', 'triangular2', 'exp_range'] \ | 
 |                 and scale_fn is None: | 
 |             raise ValueError('mode is invalid and scale_fn is None') | 
 |  | 
 |         self.mode = mode | 
 |         self.gamma = gamma | 
 |  | 
 |         if scale_fn is None: | 
 |             if self.mode == 'triangular': | 
 |                 self.scale_fn = self._triangular_scale_fn | 
 |                 self.scale_mode = 'cycle' | 
 |             elif self.mode == 'triangular2': | 
 |                 self.scale_fn = self._triangular2_scale_fn | 
 |                 self.scale_mode = 'cycle' | 
 |             elif self.mode == 'exp_range': | 
 |                 self.scale_fn = self._exp_range_scale_fn | 
 |                 self.scale_mode = 'iterations' | 
 |         else: | 
 |             self.scale_fn = scale_fn | 
 |             self.scale_mode = scale_mode | 
 |  | 
 |         self.cycle_momentum = cycle_momentum | 
 |         if cycle_momentum: | 
 |             if 'momentum' not in optimizer.defaults: | 
 |                 raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') | 
 |  | 
 |             base_momentums = self._format_param('base_momentum', optimizer, base_momentum) | 
 |             if last_epoch == -1: | 
 |                 for momentum, group in zip(base_momentums, optimizer.param_groups): | 
 |                     group['momentum'] = momentum | 
 |             self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups)) | 
 |             self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) | 
 |  | 
 |         super(CyclicLR, self).__init__(optimizer, last_epoch) | 
 |         self.base_lrs = base_lrs | 
 |  | 
 |     def _format_param(self, name, optimizer, param): | 
 |         """Return correctly formatted lr/momentum for each param group.""" | 
 |         if isinstance(param, (list, tuple)): | 
 |             if len(param) != len(optimizer.param_groups): | 
 |                 raise ValueError("expected {} values for {}, got {}".format( | 
 |                     len(optimizer.param_groups), name, len(param))) | 
 |             return param | 
 |         else: | 
 |             return [param] * len(optimizer.param_groups) | 
 |  | 
 |     def _triangular_scale_fn(self, x): | 
 |         return 1. | 
 |  | 
 |     def _triangular2_scale_fn(self, x): | 
 |         return 1 / (2. ** (x - 1)) | 
 |  | 
 |     def _exp_range_scale_fn(self, x): | 
 |         return self.gamma**(x) | 
 |  | 
 |     def get_lr(self): | 
 |         """Calculates the learning rate at batch index. This function treats | 
 |         `self.last_epoch` as the last batch index. | 
 |  | 
 |         If `self.cycle_momentum` is ``True``, this function has a side effect of | 
 |         updating the optimizer's momentum. | 
 |         """ | 
 |  | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         cycle = math.floor(1 + self.last_epoch / self.total_size) | 
 |         x = 1. + self.last_epoch / self.total_size - cycle | 
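        # `x` is the fractional position within the current cycle in [0, 1);
        # the first `step_ratio` fraction of the cycle is the increasing half.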
 |         if x <= self.step_ratio: | 
 |             scale_factor = x / self.step_ratio | 
 |         else: | 
 |             scale_factor = (x - 1) / (self.step_ratio - 1) | 
 |  | 
 |         lrs = [] | 
 |         for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): | 
 |             base_height = (max_lr - base_lr) * scale_factor | 
 |             if self.scale_mode == 'cycle': | 
 |                 lr = base_lr + base_height * self.scale_fn(cycle) | 
 |             else: | 
 |                 lr = base_lr + base_height * self.scale_fn(self.last_epoch) | 
 |             lrs.append(lr) | 
 |  | 
 |         if self.cycle_momentum: | 
 |             momentums = [] | 
 |             for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): | 
 |                 base_height = (max_momentum - base_momentum) * scale_factor | 
 |                 if self.scale_mode == 'cycle': | 
 |                     momentum = max_momentum - base_height * self.scale_fn(cycle) | 
 |                 else: | 
 |                     momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) | 
 |                 momentums.append(momentum) | 
 |             for param_group, momentum in zip(self.optimizer.param_groups, momentums): | 
 |                 param_group['momentum'] = momentum | 
 |  | 
 |         return lrs | 
 |  | 
 |  | 
 | class CosineAnnealingWarmRestarts(_LRScheduler): | 
 |     r"""Set the learning rate of each parameter group using a cosine annealing | 
 |     schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` | 
 |     is the number of epochs since the last restart and :math:`T_{i}` is the number | 
 |     of epochs between two warm restarts in SGDR: | 
 |  | 
 |     .. math:: | 
 |         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | 
 |         \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) | 
 |  | 
 |     When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. | 
 |     When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. | 
 |  | 
 |     It has been proposed in | 
 |     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         T_0 (int): Number of iterations for the first restart. | 
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
 |         eta_min (float, optional): Minimum learning rate. Default: 0. | 
 |         last_epoch (int, optional): The index of last epoch. Default: -1. | 
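
    Example:
        >>> # Illustrative epoch-wise usage; assumes optimizer already exists.
        >>> # See :meth:`step` for the per-batch form.
        >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()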
 |  | 
 |     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | 
 |         https://arxiv.org/abs/1608.03983 | 
 |     """ | 
 |  | 
 |     def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): | 
 |         if T_0 <= 0 or not isinstance(T_0, int): | 
 |             raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) | 
 |         if T_mult < 1 or not isinstance(T_mult, int): | 
 |             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) | 
 |         self.T_0 = T_0 | 
 |         self.T_i = T_0 | 
 |         self.T_mult = T_mult | 
 |         self.eta_min = eta_min | 
 |  | 
 |         super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) | 
 |  | 
 |         self.T_cur = self.last_epoch | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 | 
 |                 for base_lr in self.base_lrs] | 
 |  | 
 |     def step(self, epoch=None): | 
 |         """Step could be called after every batch update | 
 |  | 
 |         Example: | 
 |             >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) | 
 |             >>> iters = len(dataloader) | 
 |             >>> for epoch in range(20): | 
 |             >>>     for i, sample in enumerate(dataloader): | 
 |             >>>         inputs, labels = sample['inputs'], sample['labels'] | 
 |             >>>         optimizer.zero_grad() | 
 |             >>>         outputs = net(inputs) | 
 |             >>>         loss = criterion(outputs, labels) | 
 |             >>>         loss.backward() | 
 |             >>>         optimizer.step() | 
 |             >>>         scheduler.step(epoch + i / iters) | 
 |  | 
 |         This function can be called in an interleaved way. | 
 |  | 
 |         Example: | 
 |             >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) | 
 |             >>> for epoch in range(20): | 
 |             >>>     scheduler.step() | 
 |             >>> scheduler.step(26) | 
            >>> scheduler.step() # scheduler.step(27), instead of scheduler.step(20)
 |         """ | 
 |  | 
 |         if epoch is None and self.last_epoch < 0: | 
 |             epoch = 0 | 
 |  | 
 |         if epoch is None: | 
 |             epoch = self.last_epoch + 1 | 
 |             self.T_cur = self.T_cur + 1 | 
 |             if self.T_cur >= self.T_i: | 
 |                 self.T_cur = self.T_cur - self.T_i | 
 |                 self.T_i = self.T_i * self.T_mult | 
 |         else: | 
 |             if epoch < 0: | 
 |                 raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) | 
 |             if epoch >= self.T_0: | 
 |                 if self.T_mult == 1: | 
 |                     self.T_cur = epoch % self.T_0 | 
 |                 else: | 
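                    # Invert the geometric series T_0 + T_0*T_mult + ... to find
                    # the number of completed restart periods `n` at `epoch`;
                    # T_cur is then the offset into the current period.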
 |                     n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) | 
 |                     self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) | 
 |                     self.T_i = self.T_0 * self.T_mult ** (n) | 
 |             else: | 
 |                 self.T_i = self.T_0 | 
 |                 self.T_cur = epoch | 
 |         self.last_epoch = math.floor(epoch) | 
 |  | 
 |         class _enable_get_lr_call: | 
 |  | 
 |             def __init__(self, o): | 
 |                 self.o = o | 
 |  | 
 |             def __enter__(self): | 
 |                 self.o._get_lr_called_within_step = True | 
 |                 return self | 
 |  | 
            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
 |  | 
 |         with _enable_get_lr_call(self): | 
 |             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): | 
 |                 param_group['lr'] = lr | 
 |  | 
 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups] | 
 |  | 
 |  | 
 | class OneCycleLR(_LRScheduler): | 
 |     r"""Sets the learning rate of each parameter group according to the | 
 |     1cycle learning rate policy. The 1cycle policy anneals the learning | 
 |     rate from an initial learning rate to some maximum learning rate and then | 
 |     from that maximum learning rate to some minimum learning rate much lower | 
 |     than the initial learning rate. | 
 |     This policy was initially described in the paper `Super-Convergence: | 
 |     Very Fast Training of Neural Networks Using Large Learning Rates`_. | 
 |  | 
 |     The 1cycle learning rate policy changes the learning rate after every batch. | 
 |     `step` should be called after a batch has been used for training. | 
 |  | 
 |     This scheduler is not chainable. | 
 |  | 
 |     Note also that the total number of steps in the cycle can be determined in one | 
 |     of two ways (listed in order of precedence): | 
 |  | 
 |     #. A value for total_steps is explicitly provided. | 
 |     #. A number of epochs (epochs) and a number of steps per epoch | 
 |        (steps_per_epoch) are provided. | 
 |        In this case, the number of total steps is inferred by | 
 |        total_steps = epochs * steps_per_epoch | 
 |  | 
 |     You must either provide a value for total_steps or provide a value for both | 
 |     epochs and steps_per_epoch. | 
 |  | 
 |     Args: | 
 |         optimizer (Optimizer): Wrapped optimizer. | 
 |         max_lr (float or list): Upper learning rate boundaries in the cycle | 
 |             for each parameter group. | 
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
 |             Default: None | 
 |         epochs (int): The number of epochs to train for. This is used along | 
 |             with steps_per_epoch in order to infer the total number of steps in the cycle | 
 |             if a value for total_steps is not provided. | 
 |             Default: None | 
 |         steps_per_epoch (int): The number of steps per epoch to train for. This is | 
 |             used along with epochs in order to infer the total number of steps in the | 
 |             cycle if a value for total_steps is not provided. | 
 |             Default: None | 
 |         pct_start (float): The percentage of the cycle (in number of steps) spent | 
 |             increasing the learning rate. | 
 |             Default: 0.3 | 
 |         anneal_strategy (str): {'cos', 'linear'} | 
 |             Specifies the annealing strategy: "cos" for cosine annealing, "linear" for | 
 |             linear annealing. | 
 |             Default: 'cos' | 
 |         cycle_momentum (bool): If ``True``, momentum is cycled inversely | 
 |             to learning rate between 'base_momentum' and 'max_momentum'. | 
 |             Default: True | 
 |         base_momentum (float or list): Lower momentum boundaries in the cycle | 
 |             for each parameter group. Note that momentum is cycled inversely | 
 |             to learning rate; at the peak of a cycle, momentum is | 
 |             'base_momentum' and learning rate is 'max_lr'. | 
 |             Default: 0.85 | 
 |         max_momentum (float or list): Upper momentum boundaries in the cycle | 
 |             for each parameter group. Functionally, | 
 |             it defines the cycle amplitude (max_momentum - base_momentum). | 
 |             Note that momentum is cycled inversely | 
 |             to learning rate; at the start of a cycle, momentum is 'max_momentum' | 
            and learning rate is 'base_lr'.
 |             Default: 0.95 | 
 |         div_factor (float): Determines the initial learning rate via | 
 |             initial_lr = max_lr/div_factor | 
 |             Default: 25 | 
 |         final_div_factor (float): Determines the minimum learning rate via | 
 |             min_lr = initial_lr/final_div_factor | 
 |             Default: 1e4 | 
 |         last_epoch (int): The index of the last batch. This parameter is used when | 
 |             resuming a training job. Since `step()` should be invoked after each | 
 |             batch instead of after each epoch, this number represents the total | 
 |             number of *batches* computed, not the total number of epochs computed. | 
 |             When last_epoch=-1, the schedule is started from the beginning. | 
 |             Default: -1 | 
 |  | 
 |     Example: | 
 |         >>> data_loader = torch.utils.data.DataLoader(...) | 
 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) | 
 |         >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) | 
 |         >>> for epoch in range(10): | 
 |         >>>     for batch in data_loader: | 
 |         >>>         train_batch(...) | 
 |         >>>         scheduler.step() | 
 |  | 
 |  | 
 |     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: | 
 |         https://arxiv.org/abs/1708.07120 | 
 |     """ | 
 |     def __init__(self, | 
 |                  optimizer, | 
 |                  max_lr, | 
 |                  total_steps=None, | 
 |                  epochs=None, | 
 |                  steps_per_epoch=None, | 
 |                  pct_start=0.3, | 
 |                  anneal_strategy='cos', | 
 |                  cycle_momentum=True, | 
 |                  base_momentum=0.85, | 
 |                  max_momentum=0.95, | 
 |                  div_factor=25., | 
 |                  final_div_factor=1e4, | 
 |                  last_epoch=-1): | 
 |  | 
 |         # Validate optimizer | 
 |         if not isinstance(optimizer, Optimizer): | 
 |             raise TypeError('{} is not an Optimizer'.format( | 
 |                 type(optimizer).__name__)) | 
 |         self.optimizer = optimizer | 
 |  | 
 |         # Validate total_steps | 
 |         if total_steps is None and epochs is None and steps_per_epoch is None: | 
 |             raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") | 
 |         elif total_steps is not None: | 
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
            self.total_steps = total_steps
        else:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
 |             self.total_steps = epochs * steps_per_epoch | 
 |         self.step_size_up = float(pct_start * self.total_steps) - 1 | 
 |         self.step_size_down = float(self.total_steps - self.step_size_up) - 1 | 
 |  | 
 |         # Validate pct_start | 
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError("Expected float between 0 and 1 for pct_start, but got {}".format(pct_start))
 |  | 
 |         # Validate anneal_strategy | 
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
 |         elif anneal_strategy == 'cos': | 
 |             self.anneal_func = self._annealing_cos | 
 |         elif anneal_strategy == 'linear': | 
 |             self.anneal_func = self._annealing_linear | 
 |  | 
 |         # Initialize learning rate variables | 
 |         max_lrs = self._format_param('max_lr', self.optimizer, max_lr) | 
 |         if last_epoch == -1: | 
 |             for idx, group in enumerate(self.optimizer.param_groups): | 
 |                 group['initial_lr'] = max_lrs[idx] / div_factor | 
 |                 group['max_lr'] = max_lrs[idx] | 
 |                 group['min_lr'] = group['initial_lr'] / final_div_factor | 
 |  | 
 |         # Initialize momentum variables | 
 |         self.cycle_momentum = cycle_momentum | 
 |         if self.cycle_momentum: | 
 |             if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: | 
 |                 raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') | 
 |             self.use_beta1 = 'betas' in self.optimizer.defaults | 
 |             max_momentums = self._format_param('max_momentum', optimizer, max_momentum) | 
 |             base_momentums = self._format_param('base_momentum', optimizer, base_momentum) | 
 |             if last_epoch == -1: | 
 |                 for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): | 
 |                     if self.use_beta1: | 
 |                         _, beta2 = group['betas'] | 
 |                         group['betas'] = (m_momentum, beta2) | 
 |                     else: | 
 |                         group['momentum'] = m_momentum | 
 |                     group['max_momentum'] = m_momentum | 
 |                     group['base_momentum'] = b_momentum | 
 |  | 
 |         super(OneCycleLR, self).__init__(optimizer, last_epoch) | 
 |  | 
 |     def _format_param(self, name, optimizer, param): | 
 |         """Return correctly formatted lr/momentum for each param group.""" | 
 |         if isinstance(param, (list, tuple)): | 
 |             if len(param) != len(optimizer.param_groups): | 
 |                 raise ValueError("expected {} values for {}, got {}".format( | 
 |                     len(optimizer.param_groups), name, len(param))) | 
 |             return param | 
 |         else: | 
 |             return [param] * len(optimizer.param_groups) | 
 |  | 
 |     def _annealing_cos(self, start, end, pct): | 
 |         "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." | 
 |         cos_out = math.cos(math.pi * pct) + 1 | 
 |         return end + (start - end) / 2.0 * cos_out | 
 |  | 
 |     def _annealing_linear(self, start, end, pct): | 
 |         "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." | 
 |         return (end - start) * pct + start | 
 |  | 
 |     def get_lr(self): | 
 |         if not self._get_lr_called_within_step: | 
 |             warnings.warn("To get the last learning rate computed by the scheduler, " | 
 |                           "please use `get_last_lr()`.", UserWarning) | 
 |  | 
 |         lrs = [] | 
 |         step_num = self.last_epoch | 
 |  | 
 |         if step_num > self.total_steps: | 
 |             raise ValueError("Tried to step {} times. The specified number of total steps is {}" | 
 |                              .format(step_num + 1, self.total_steps)) | 
 |  | 
 |         for group in self.optimizer.param_groups: | 
 |             if step_num <= self.step_size_up: | 
 |                 computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up) | 
 |                 if self.cycle_momentum: | 
 |                     computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'], | 
 |                                                          step_num / self.step_size_up) | 
 |             else: | 
 |                 down_step_num = step_num - self.step_size_up | 
 |                 computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down) | 
 |                 if self.cycle_momentum: | 
 |                     computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'], | 
 |                                                          down_step_num / self.step_size_down) | 
 |  | 
 |             lrs.append(computed_lr) | 
 |             if self.cycle_momentum: | 
 |                 if self.use_beta1: | 
 |                     _, beta2 = group['betas'] | 
 |                     group['betas'] = (computed_momentum, beta2) | 
 |                 else: | 
 |                     group['momentum'] = computed_momentum | 
 |  | 
 |         return lrs |