|  | # Owner(s): ["module: optimizer"] | 
|  |  | 
|  | import warnings | 
|  | import math | 
|  | import unittest | 
|  | import functools | 
|  | import itertools | 
|  | from copy import deepcopy | 
|  |  | 
|  | import torch | 
|  | import torch.optim as optim | 
|  | import torch.nn.functional as F | 
|  | from torch.nn import Parameter | 
|  | from torch.optim import SGD | 
|  | from torch import sparse | 
|  | from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \ | 
|  | MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ | 
|  | LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler, PolynomialLR, \ | 
|  | EPOCH_DEPRECATION_WARNING | 
|  | from torch.optim.swa_utils import AveragedModel, SWALR, update_bn | 
|  | from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ | 
|  | parametrize, instantiate_parametrized_tests, gradcheck, skipIfRocm | 
|  | # load_tests from common_utils is used to automatically filter tests for | 
|  | # sharding on sandcastle. This line silences flake warnings | 
|  | load_tests = load_tests | 
|  |  | 
|  |  | 
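# The 2-D Rosenbrock function and its analytic gradient serve as a small,
# well-understood objective for the sparse-gradient tests below; its global
# minimum is at (1, 1), which _test_rosenbrock_sparse uses as the target.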
def rosenbrock(tensor):
    x, y = tensor
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def drosenbrock(tensor):
    x, y = tensor
    return torch.tensor((-400 * x * (y - x ** 2) - 2 * (1 - x), 200 * (y - x ** 2)))
|  |  | 
|  |  | 
|  | class TestOptim(TestCase): | 
|  | exact_dtype = True | 
|  |  | 
|  | def _test_rosenbrock_sparse(self, constructor, scheduler_constructors=None, | 
|  | sparse_only=False, maximize=False): | 
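        """Run cyclic coordinate descent on the Rosenbrock function with sparse
        (deliberately uncoalesced) gradients, optionally mirroring every step with a
        dense-gradient copy of the parameters, and check that the iterates move toward
        the minimum at (1, 1) (or that the objective did not decrease when maximizing).
        """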
|  | if scheduler_constructors is None: | 
|  | scheduler_constructors = [] | 
|  | params_t = torch.tensor([1.5, 1.5]) | 
|  |  | 
|  | params = Parameter(params_t) | 
|  | optimizer = constructor([params]) | 
|  | schedulers = [] | 
|  | for scheduler_constructor in scheduler_constructors: | 
|  | schedulers.append(scheduler_constructor(optimizer)) | 
|  |  | 
|  | if not sparse_only: | 
|  | params_c = Parameter(params_t.clone()) | 
|  | optimizer_c = constructor([params_c]) | 
|  |  | 
        solution = torch.tensor([1, 1])
        with torch.no_grad():
            initial_dist = params.dist(solution)
            # `params` shares storage with `params_t`, so capture the starting
            # objective value now for the `maximize` check at the end.
            initial_value = rosenbrock(params).item()
|  |  | 
|  | def eval(params, sparse_grad, w): | 
|  | # Depending on w, provide only the x or y gradient | 
|  | optimizer.zero_grad() | 
|  | loss = rosenbrock(params) | 
|  | loss.backward() | 
|  | grad = drosenbrock(params.data) | 
|  | # NB: We torture test the optimizer by returning an | 
|  | # uncoalesced sparse tensor | 
|  | if w: | 
|  | i = torch.LongTensor([[0, 0]]) | 
|  | x = grad[0] | 
|  | v = torch.tensor([x / 4., x - x / 4.]) | 
|  | else: | 
|  | i = torch.LongTensor([[1, 1]]) | 
|  | y = grad[1] | 
|  | v = torch.tensor([y - y / 4., y / 4.]) | 
|  | x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype) | 
|  | with torch.no_grad(): | 
|  | if sparse_grad: | 
|  | params.grad = x | 
|  | else: | 
|  | params.grad = x.to_dense() | 
|  | return loss | 
|  |  | 
|  | for i in range(2000): | 
|  | # Do cyclic coordinate descent | 
|  | w = i % 2 | 
|  | optimizer.step(functools.partial(eval, params, True, w)) | 
|  | for scheduler in schedulers: | 
|  | if isinstance(scheduler, ReduceLROnPlateau): | 
|  | scheduler.step(rosenbrock(params)) | 
|  | else: | 
|  | scheduler.step() | 
|  | if not sparse_only: | 
|  | optimizer_c.step(functools.partial(eval, params_c, False, w)) | 
|  | self.assertEqual(params, params_c) | 
|  |  | 
        if not maximize:
            self.assertLessEqual(params.data.dist(solution), initial_dist)
        else:
            self.assertGreaterEqual(rosenbrock(params), initial_value)
|  |  | 
|  | def _test_basic_cases_template(self, weight_tensor, bias_tensor, input_tensor, constructor, | 
|  | scheduler_constructors, constructor_accepts_maximize=True, constructor_accepts_foreach=False): | 
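        """Optimize a small least-squares objective, loss = ||W @ x + b||^2, for 200
        steps over every (maximize, foreach) combination the constructor supports,
        stepping any provided LR schedulers along the way, and check that the loss
        moved in the expected direction.
        """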
|  | maximize_options = set([False, constructor_accepts_maximize]) | 
|  | foreach_options = set([False, constructor_accepts_foreach]) | 
|  |  | 
|  | four_arg_constructor = constructor | 
|  | if constructor_accepts_maximize and constructor_accepts_foreach: | 
|  | pass | 
|  | elif constructor_accepts_maximize: | 
|  | def four_arg_constructor(weight, bias, maximize, foreach): | 
|  | self.assertFalse(foreach) | 
|  | return constructor(weight, bias, maximize) | 
|  | elif constructor_accepts_foreach: | 
|  | def four_arg_constructor(weight, bias, maximize, foreach): | 
|  | self.assertFalse(maximize) | 
|  | return constructor(weight, bias, foreach) | 
|  | else: | 
|  | def four_arg_constructor(weight, bias, maximize, foreach): | 
|  | self.assertFalse(maximize or foreach) | 
|  | return constructor(weight, bias) | 
|  |  | 
|  | for maximize, foreach in itertools.product(maximize_options, foreach_options): | 
|  | with torch.no_grad(): | 
|  | weight = Parameter(weight_tensor.clone().detach()) | 
|  | bias = Parameter(bias_tensor.clone().detach()) | 
|  | input = input_tensor.clone().detach().requires_grad_() | 
|  | optimizer = four_arg_constructor(weight, bias, maximize, foreach) | 
|  | schedulers = [] | 
|  | for scheduler_constructor in scheduler_constructors: | 
|  | schedulers.append(scheduler_constructor(optimizer)) | 
|  |  | 
            # Check that the optimizer can be printed (repr() should not raise)
|  | optimizer.__repr__() | 
|  |  | 
|  | def fn(): | 
|  | optimizer.zero_grad() | 
|  | y = weight.mv(input) | 
|  | if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device(): | 
|  | y = y.cuda(bias.get_device()) | 
|  | loss = (y + bias).pow(2).sum() | 
|  | loss.backward() | 
|  | return loss | 
|  |  | 
|  | initial_value = fn().item() | 
|  | for _ in range(200): | 
|  | for scheduler in schedulers: | 
|  | if isinstance(scheduler, ReduceLROnPlateau): | 
|  | val_loss = fn() | 
|  | scheduler.step(val_loss) | 
|  | else: | 
|  | scheduler.step() | 
|  | optimizer.step(fn) | 
|  | if maximize: | 
|  | self.assertGreater(fn().item(), initial_value) | 
|  | else: | 
|  | self.assertLess(fn().item(), initial_value) | 
|  |  | 
|  | def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None): | 
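        """Check state_dict round-tripping: an optimizer reloaded from a saved state
        dict must follow the same trajectory as the original, loading must not mutate
        the state dict, older state dicts missing newer keys ('maximize', 'foreach')
        or storing `step` as a plain number must still load, and, when CUDA is
        available, loading across a dtype/device change must work.
        """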
|  | weight = Parameter(weight) | 
|  | bias = Parameter(bias) | 
|  | with torch.no_grad(): | 
|  | input = input.clone().detach().requires_grad_() | 
|  |  | 
|  | def fn_base(optimizer, weight, bias): | 
|  | optimizer.zero_grad() | 
|  | i = input_cuda if weight.is_cuda else input | 
|  | loss = (weight.mv(i) + bias).pow(2).sum() | 
|  | loss.backward() | 
|  | return loss | 
|  |  | 
|  | optimizer = constructor(weight, bias) | 
|  | fn = functools.partial(fn_base, optimizer, weight, bias) | 
|  |  | 
|  | # Prime the optimizer | 
|  | for _i in range(20): | 
|  | optimizer.step(fn) | 
|  | # Clone the weights and construct new optimizer for them | 
|  | with torch.no_grad(): | 
|  | weight_c = Parameter(weight.clone().detach()) | 
|  | bias_c = Parameter(bias.clone().detach()) | 
|  | optimizer_c = constructor(weight_c, bias_c) | 
|  | fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c) | 
|  | # Load state dict | 
|  | state_dict = deepcopy(optimizer.state_dict()) | 
|  | state_dict_c = deepcopy(optimizer.state_dict()) | 
|  | optimizer_c.load_state_dict(state_dict_c) | 
|  | # Run both optimizations in parallel | 
|  | for _ in range(20): | 
|  | optimizer.step(fn) | 
|  | optimizer_c.step(fn_c) | 
|  | self.assertEqual(weight, weight_c) | 
|  | self.assertEqual(bias, bias_c) | 
|  | # Make sure state dict wasn't modified | 
|  | self.assertEqual(state_dict, state_dict_c) | 
|  | # Make sure state dict is deterministic with equal but not identical parameters | 
|  | self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict()) | 
|  | # Make sure repeated parameters have identical representation in state dict | 
|  | optimizer_c.param_groups.extend(optimizer_c.param_groups) | 
|  | self.assertEqual(optimizer.state_dict()['param_groups'][-1], | 
|  | optimizer_c.state_dict()['param_groups'][-1]) | 
|  |  | 
|  | # Make sure that optimizers that support maximize can load older models | 
|  | state_dict = optimizer.state_dict() | 
|  | if 'maximize' in state_dict['param_groups'][0]: | 
|  | for group in state_dict['param_groups']: | 
|  | del group['maximize'] | 
|  | optimizer.load_state_dict(state_dict) | 
|  | # Make sure we can still step | 
|  | optimizer.step() | 
|  | # Make sure that optimizers that support foreach can load older models | 
|  | state_dict = optimizer.state_dict() | 
|  | if 'foreach' in state_dict['param_groups'][0]: | 
|  | for group in state_dict['param_groups']: | 
|  | del group['foreach'] | 
|  | optimizer.load_state_dict(state_dict) | 
|  | # Make sure we can still step | 
|  | optimizer.step() | 
|  |  | 
        # Make sure that a state dict storing `step` as a plain number (not a tensor) still loads
|  | state_dict = optimizer.state_dict() | 
|  | if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']): | 
|  | for state in state_dict['state'].values(): | 
|  | state['step'] = state['step'].item() | 
|  | optimizer.load_state_dict(state_dict) | 
|  | optimizer.step() | 
|  |  | 
|  | # Check that state dict can be loaded even when we cast parameters | 
|  | # to a different type and move to a different device. | 
|  | if not torch.cuda.is_available(): | 
|  | return | 
|  |  | 
|  | with torch.no_grad(): | 
|  | input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda") | 
|  | weight_cuda = Parameter(weight.clone().detach().to(dtype=torch.float32, device="cuda")) | 
|  | bias_cuda = Parameter(bias.clone().detach().to(dtype=torch.float32, device="cuda")) | 
|  | optimizer_cuda = constructor(weight_cuda, bias_cuda) | 
|  | fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda) | 
|  |  | 
|  | state_dict = deepcopy(optimizer.state_dict()) | 
|  | state_dict_c = deepcopy(optimizer.state_dict()) | 
|  | optimizer_cuda.load_state_dict(state_dict_c) | 
|  |  | 
|  | # Make sure state dict wasn't modified | 
|  | self.assertEqual(state_dict, state_dict_c) | 
|  |  | 
|  | # Make sure that device of state['step'] is still CPU | 
|  | new_state_dict = optimizer_cuda.state_dict() | 
|  | if 'step' in state_dict['state'][0] and torch.is_tensor(state_dict['state'][0]['step']): | 
|  | for state in new_state_dict['state'].values(): | 
|  | self.assertEqual(state['step'].device.type, 'cpu') | 
|  |  | 
|  | for _i in range(20): | 
|  | optimizer.step(fn) | 
|  | optimizer_cuda.step(fn_cuda) | 
|  | self.assertEqual(weight, weight_cuda) | 
|  | self.assertEqual(bias, bias_cuda, atol=atol, rtol=rtol) | 
|  |  | 
|  | # validate deepcopy() copies all public attributes | 
|  | def getPublicAttr(obj): | 
|  | return set(k for k in obj.__dict__ if not k.startswith('_')) | 
|  | self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer))) | 
|  |  | 
|  | def _test_basic_cases(self, constructor, scheduler_constructors=None, | 
|  | ignore_multidevice=False, constructor_accepts_maximize=False, constructor_accepts_foreach=False, | 
|  | atol=None, rtol=None): | 
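        """Drive _test_state_dict and _test_basic_cases_template for the given optimizer
        constructor over contiguous, non-contiguous, CUDA, and (unless ignored)
        multi-device parameters.
        """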
|  | if scheduler_constructors is None: | 
|  | scheduler_constructors = [] | 
|  |  | 
|  | def make_two_arg_constructor(constructor, maximize: bool = False, foreach: bool = False): | 
|  | if constructor_accepts_maximize and constructor_accepts_foreach: | 
|  | return lambda weight, bias: constructor(weight, bias, maximize, foreach) | 
|  | if constructor_accepts_maximize: | 
|  | return lambda weight, bias: constructor(weight, bias, maximize) | 
|  | if constructor_accepts_foreach: | 
|  | return lambda weight, bias: constructor(weight, bias, foreach) | 
|  | return constructor | 
|  |  | 
|  | for maximize, foreach in itertools.product( | 
|  | set([False, constructor_accepts_maximize]), | 
|  | set([False, constructor_accepts_foreach]), | 
|  | ): | 
|  | self._test_state_dict( | 
|  | torch.randn(10, 5), | 
|  | torch.randn(10), | 
|  | torch.randn(5), | 
|  | make_two_arg_constructor(constructor, maximize, foreach), | 
|  | atol=atol, rtol=rtol | 
|  | ) | 
|  | self._test_basic_cases_template( | 
|  | torch.randn(10, 5), | 
|  | torch.randn(10), | 
|  | torch.randn(5), | 
|  | constructor, | 
|  | scheduler_constructors, | 
|  | constructor_accepts_maximize, | 
|  | constructor_accepts_foreach, | 
|  | ) | 
|  | # non-contiguous parameters | 
|  | self._test_basic_cases_template( | 
|  | torch.randn(10, 5, 2)[..., 0], | 
|  | torch.randn(10, 2)[..., 0], | 
|  | torch.randn(5), | 
|  | constructor, | 
|  | scheduler_constructors, | 
|  | constructor_accepts_maximize, | 
|  | constructor_accepts_foreach, | 
|  | ) | 
|  | # CUDA | 
|  | if not torch.cuda.is_available(): | 
|  | return | 
|  | self._test_basic_cases_template( | 
|  | torch.randn(10, 5).cuda(), | 
|  | torch.randn(10).cuda(), | 
|  | torch.randn(5).cuda(), | 
|  | constructor, | 
|  | scheduler_constructors, | 
|  | constructor_accepts_maximize, | 
|  | constructor_accepts_foreach, | 
|  | ) | 
|  | # Multi-GPU | 
|  | if not torch.cuda.device_count() > 1 or ignore_multidevice: | 
|  | return | 
|  | self._test_basic_cases_template( | 
|  | torch.randn(10, 5).cuda(0), | 
|  | torch.randn(10).cuda(1), | 
|  | torch.randn(5).cuda(0), | 
|  | constructor, | 
|  | scheduler_constructors, | 
|  | constructor_accepts_maximize, | 
|  | constructor_accepts_foreach, | 
|  | ) | 
|  |  | 
|  | def _test_complex_optimizer(self, optimizer_constructor): | 
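        """Step the optimizer on a complex parameter and on its real view with matching
        gradients, and check that the two stay numerically identical.
        """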
|  | complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True) | 
|  | real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_() | 
|  | complex_opt = optimizer_constructor(complex_param) | 
|  | real_opt = optimizer_constructor(real_param) | 
|  |  | 
|  | for _ in range(3): | 
|  | complex_param.grad = torch.randn_like(complex_param) | 
|  | real_param.grad = torch.view_as_real(complex_param.grad) | 
|  | complex_opt.step() | 
|  | real_opt.step() | 
|  |  | 
|  | self.assertEqual(torch.view_as_real(complex_param), real_param) | 
|  |  | 
|  | def _test_complex_2d(self, optimizer_constructor, f=None): | 
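        """Optimize a complex 2-vector and, separately, its real and imaginary parts as
        two real tensors; gradients and iterates must agree between the two runs.
        """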
|  | if f is None: | 
|  | f = rosenbrock | 
|  | a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True) | 
|  | a1_real = a1.real.clone().detach() | 
|  | a1_imag = a1.imag.clone().detach() | 
|  | a1_real.requires_grad_() | 
|  | a1_imag.requires_grad_() | 
|  | optim1 = optimizer_constructor([a1]) | 
|  | optim2 = optimizer_constructor([a1_real, a1_imag]) | 
|  |  | 
|  | for _ in range(10): | 
|  | optim1.zero_grad() | 
|  | optim2.zero_grad() | 
|  | a2 = torch.complex(a1_real, a1_imag) | 
|  | f(a1).backward() | 
|  | f(a2).backward() | 
|  |  | 
|  | self.assertEqual(a1.grad.real, a1_real.grad) | 
|  | self.assertEqual(a1.grad.imag, a1_imag.grad) | 
|  |  | 
|  | optim1.step() | 
|  | optim2.step() | 
|  | self.assertEqual(a1.real, a1_real) | 
|  | self.assertEqual(a1.imag, a1_imag) | 
|  |  | 
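    # Helpers that wrap (weight, bias) into per-parameter-group dicts so tests can
    # exercise group-specific hyperparameters (e.g. a different lr for the bias group).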
|  | def _build_params_dict(self, weight, bias, **kwargs): | 
|  | return [{'params': [weight]}, dict(params=[bias], **kwargs)] | 
|  |  | 
|  | def _build_params_dict_single(self, weight, bias, **kwargs): | 
|  | return [dict(params=bias, **kwargs)] | 
|  |  | 
|  | def test_sgd(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD( | 
|  | self._build_params_dict_single(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD( | 
|  | self._build_params_dict_single(weight, bias, lr=1e-2), maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.9, step_size=10), | 
|  | lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.9, step_size=10), | 
|  | lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.99, step_size=10), | 
|  | lambda opt: ExponentialLR(opt, gamma=0.99), | 
|  | lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: | 
|  | optim.SGD([weight, bias], lr=1e-3, momentum=0.5, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: | 
|  | optim.SGD([weight, bias], lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: | 
|  | optim.SGD([weight, bias], nesterov=True, lr=1e-3, momentum=0.5, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.SGD([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], | 
|  | constructor_accepts_maximize=True, constructor_accepts_foreach=True, | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): | 
|  | optim.SGD(None, lr=1e-2, momentum=-0.5) | 
|  |  | 
|  | def test_sgd_sparse(self): | 
|  | for foreach in (False, True): | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach) | 
|  | ) | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.SGD(params, lr=0.0048, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)] | 
|  | ) | 
|  |  | 
|  | def test_sgd_complex(self): | 
|  | for foreach in (False, True): | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.SGD([param], lr=0.001, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.SGD([param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.SGD([param], lr=0.001, nesterov=True, momentum=1, weight_decay=1, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.SGD([param], lr=0.001, momentum=1, dampening=0.5, weight_decay=1, foreach=foreach) | 
|  | ) | 
|  |  | 
|  | def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): | 
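        """Compare the reference single-tensor implementation against the derived
        implementation selected by `flag` ("foreach" or "fused"): after a few training
        iterations the parameters and the per-parameter optimizer state must match
        within a small tolerance.
        """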
|  | if not torch.cuda.is_available(): | 
|  | return | 
|  | assert flag in ("foreach", "fused") | 
|  |  | 
|  | kIterations = 4 | 
|  | device = 'cuda' | 
|  | for optimizer_constructor, params in optimizer_pairs_with_flags: | 
|  | res, state = [], [] | 
            for flag_value in (False, True):
|  | input = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device).reshape(3, 2) | 
|  |  | 
|  | torch.manual_seed(1) | 
|  | model = torch.nn.Sequential(torch.nn.Linear(2, 3), | 
|  | torch.nn.Sigmoid(), | 
|  | torch.nn.Linear(3, 1), | 
|  | torch.nn.Sigmoid()) | 
|  | model.to(dtype=torch.float64, device=device) | 
                params_with_flag = deepcopy(params)
                params_with_flag[flag] = flag_value
                optimizer = optimizer_constructor(model.parameters(), **params_with_flag)
|  |  | 
                for i in range(kIterations):
|  | optimizer.zero_grad() | 
|  | output = model(input) | 
|  | loss = output.sum() | 
|  | loss.backward() | 
|  |  | 
|  | # Test that step behaves as expected (a no-op) when grads are set to None | 
                    if i == 0:
|  | optimizer.zero_grad(set_to_none=True) | 
|  |  | 
|  | optimizer.step() | 
|  |  | 
|  | state.append(optimizer.state) | 
|  | res.append(model.parameters()) | 
|  |  | 
|  | st_state = state[0] | 
|  | mt_state = state[1] | 
|  | for st_p, mt_p in zip(res[0], res[1]): | 
|  | self.assertEqual(st_p, mt_p, atol=5e-5, rtol=0) | 
|  |  | 
|  | # check that optimizer states are the same | 
|  | st_p_state = st_state[st_p] | 
|  | mt_p_state = mt_state[mt_p] | 
|  |  | 
|  | for k in st_p_state: | 
|  | actual = mt_p_state[k] | 
|  | # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`, | 
|  | # `step` Tensor is 1D while usually it's 0D. | 
|  | if k == "step" and isinstance(actual, torch.Tensor) and actual.ndim == 1: | 
|  | actual = actual[0] | 
|  | self.assertEqual(st_p_state[k], actual, atol=5e-5, rtol=0) | 
|  |  | 
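    # The two tests below compare the default single-tensor implementations against
    # the `foreach` (multi-tensor) and `fused` implementations of the same optimizers,
    # parameter by parameter and state entry by state entry.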
|  | def test_multi_tensor_optimizers(self): | 
|  | optimizer_pairs_with_flags = [ | 
|  | (optim.Adam, dict(weight_decay=1., amsgrad=True)), | 
|  | (optim.Adam, dict(weight_decay=1., amsgrad=False)), | 
|  | (optim.Adam, dict(weight_decay=0., amsgrad=True)), | 
|  | (optim.Adam, dict(weight_decay=0., amsgrad=False)), | 
|  | (optim.AdamW, dict(weight_decay=1., amsgrad=True)), | 
|  | (optim.AdamW, dict(weight_decay=1., amsgrad=False)), | 
|  | (optim.AdamW, dict(weight_decay=0., amsgrad=True)), | 
|  | (optim.AdamW, dict(weight_decay=0., amsgrad=False)), | 
|  | (optim.NAdam, dict(weight_decay=0., momentum_decay=6e-3)), | 
|  | (optim.NAdam, dict(weight_decay=1., momentum_decay=6e-3)), | 
|  | (optim.NAdam, dict(weight_decay=0., momentum_decay=4e-3)), | 
|  | (optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)), | 
|  | (optim.SGD, dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True)), | 
|  | (optim.SGD, dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False)), | 
|  | (optim.RAdam, dict(weight_decay=0)), | 
|  | (optim.RAdam, dict(weight_decay=1)), | 
|  | (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)), | 
|  | (optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)), | 
|  | (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)), | 
|  | (optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)), | 
|  | (optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))), | 
|  | (optim.ASGD, dict(weight_decay=0)), | 
|  | (optim.ASGD, dict(weight_decay=1)), | 
|  | (optim.Adamax, dict(weight_decay=0)), | 
|  | (optim.Adamax, dict(weight_decay=1)), | 
|  | (optim.Adadelta, dict(weight_decay=0)), | 
|  | (optim.Adadelta, dict(weight_decay=1)), | 
|  | (optim.Adagrad, dict(weight_decay=0)), | 
|  | (optim.Adagrad, dict(weight_decay=1)), | 
|  | ] | 
|  | self._test_derived_optimizers(optimizer_pairs_with_flags, "foreach") | 
|  |  | 
|  | def test_fused_optimizers(self): | 
|  | optimizer_pairs_with_flags = [ | 
|  | (optim.Adam, dict(weight_decay=1., amsgrad=False)), | 
|  | (optim.Adam, dict(weight_decay=1., amsgrad=True)), | 
|  | (optim.Adam, dict(weight_decay=0., amsgrad=False)), | 
|  | (optim.Adam, dict(weight_decay=0., amsgrad=True)), | 
|  | ] | 
|  | self._test_derived_optimizers(optimizer_pairs_with_flags, "fused") | 
|  |  | 
|  | def test_adam(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | [weight, bias], lr=1e-3, weight_decay=0.1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ExponentialLR(opt, gamma=0.9)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), | 
|  | lambda opt: ExponentialLR(opt, gamma=0.9)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | [weight, bias], lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ExponentialLR(opt, gamma=0.9), | 
|  | lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.9, step_size=10), | 
|  | lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  |  | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_complex_2d(optim.Adam) | 
|  | self._test_complex_2d(functools.partial(optim.Adam, foreach=True)) | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): | 
|  | optim.Adam(None, lr=1e-2, betas=(1.0, 0.0)) | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): | 
|  | optim.Adam(None, lr=1e-2, weight_decay=-1) | 
|  |  | 
|  | def test_adamw(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.AdamW([weight, bias], lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.AdamW( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.AdamW( | 
|  | [weight, bias], lr=1e-3, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.AdamW( | 
|  | [weight, bias], lr=1e-3, weight_decay=1, amsgrad=True, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_complex_2d(optim.AdamW) | 
|  | self._test_complex_2d(functools.partial(optim.AdamW, foreach=True)) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): | 
|  | optim.AdamW(None, lr=1e-2, weight_decay=-1) | 
|  |  | 
|  | def test_sparse_adam(self): | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.SparseAdam(params, lr=4e-2), | 
|  | [], | 
|  | True | 
|  | ) | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True), | 
|  | [], | 
|  | True, | 
|  | True | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): | 
|  | optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0)) | 
|  | with self.assertRaisesRegex(ValueError, "SparseAdam requires dense parameter tensors"): | 
|  | optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)]) | 
|  | with self.assertRaisesRegex(ValueError, "SparseAdam requires dense parameter tensors"): | 
|  | optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}]) | 
|  |  | 
|  | # ROCm precision is too low to pass this test | 
|  | def test_adadelta(self): | 
|  | # Handles https://github.com/pytorch/pytorch/issues/69698 | 
|  | self.rel_tol = 4e-3 | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adadelta([weight, bias], maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adadelta( | 
|  | self._build_params_dict(weight, bias, rho=0.95), maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adadelta( | 
|  | self._build_params_dict(weight, bias, rho=0.95), maximize=maximize, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=0.9, step_size=10), | 
|  | lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adadelta( | 
|  | [weight, bias], weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"): | 
|  | optim.Adadelta(None, lr=1e-2, rho=1.1) | 
|  |  | 
|  | def test_adadelta_complex(self): | 
|  | # Handles https://github.com/pytorch/pytorch/issues/69698 | 
|  | self.rel_tol = 2e-2 | 
|  | for optimizer in [optim.Adadelta]: | 
|  | self._test_complex_optimizer( | 
|  | lambda weight: optimizer([weight]) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda weight: optimizer([weight], rho=0.95) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda weight: optimizer([weight], rho=0.95, weight_decay=1) | 
|  | ) | 
|  |  | 
|  | def test_nadam(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.NAdam([weight, bias], lr=1e-3, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.NAdam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.NAdam( | 
|  | [weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.NAdam( | 
|  | [weight, bias], lr=1e-3, weight_decay=0.1, momentum_decay=6e-3, foreach=foreach), | 
|  | [lambda opt: ExponentialLR(opt, gamma=0.9)], | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): | 
|  | optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0)) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"): | 
|  | optim.NAdam(None, lr=1e-2, momentum_decay=-0.2) | 
|  |  | 
|  | def test_adagrad(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adagrad([weight, bias], lr=1e-1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adagrad( | 
|  | [weight, bias], lr=1e-1, initial_accumulator_value=0.1, maximize=maximize, foreach=foreach, | 
|  | ), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adagrad( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-1, | 
|  | maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adagrad( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-1, | 
|  | maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adagrad( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-1, | 
|  | maximize=maximize, foreach=foreach), | 
|  | [lambda opt: ReduceLROnPlateau(opt), | 
|  | lambda opt: ExponentialLR(opt, gamma=0.99)], | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"): | 
|  | optim.Adagrad(None, lr=1e-2, lr_decay=-0.5) | 
|  |  | 
|  | def test_adagrad_sparse(self): | 
|  | for foreach in (False, True): | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach) | 
|  | ) | 
|  | self._test_rosenbrock_sparse( | 
|  | lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach), | 
|  | [lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), | 
|  | lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)] | 
|  | ) | 
|  |  | 
|  | def test_adagrad_complex(self): | 
|  | for foreach in (False, True): | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.Adagrad([param], lr=1e-1, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.Adagrad( | 
|  | [param], lr=1e-1, initial_accumulator_value=0.1, foreach=foreach, | 
|  | ) | 
|  | ) | 
|  |  | 
|  | def test_adamax(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adamax( | 
|  | [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adamax( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Adamax( | 
|  | [weight, bias], lr=1e-1, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_complex_2d(optim.Adamax) | 
|  | self._test_complex_2d(functools.partial(optim.Adamax, foreach=True)) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"): | 
|  | optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0)) | 
|  |  | 
|  | def test_radam(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.RAdam( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach), | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, foreach: optim.RAdam([weight, bias], lr=1e-3, foreach=foreach), | 
|  | [lambda opt: ExponentialLR(opt, gamma=0.9), lambda opt: ReduceLROnPlateau(opt)], | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"): | 
|  | optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0)) | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): | 
|  | optim.RAdam(None, lr=1e-2, weight_decay=-1) | 
|  |  | 
|  | def test_rmsprop(self): | 
|  | for foreach in (False, True): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | [weight, bias], lr=1e-2, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-3), | 
|  | lr=1e-2, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-3), | 
|  | lr=1e-2, centered=True, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-3), | 
|  | lr=1e-2, centered=True, momentum=0.1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-3), | 
|  | lr=1e-2, momentum=0.1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.RMSprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-3), | 
|  | lr=1e-2, momentum=0.1, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_complex_2d(lambda param: optim.RMSprop(param, foreach=foreach)) | 
|  | self._test_complex_2d(lambda param: optim.RMSprop(param, centered=True, foreach=foreach)) | 
|  | self._test_complex_2d(lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach)) | 
|  | self._test_complex_2d(lambda param: optim.RMSprop(param, maximize=True, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda param: optim.RMSprop([param], foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda param: optim.RMSprop([param], centered=True, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda param: optim.RMSprop([param], maximize=True, foreach=foreach)) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): | 
|  | optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach) | 
|  |  | 
|  | def test_asgd(self): | 
|  | for foreach in (False, True): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.ASGD( | 
|  | [weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.ASGD( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, t0=100, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.ASGD( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), | 
|  | lr=1e-3, weight_decay=1, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | # Ref: https://github.com/pytorch/pytorch/issues/84560 | 
|  | # self._test_complex_2d(optimizer) | 
|  | self._test_complex_optimizer(lambda params: optim.ASGD([params], foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=True, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=True, weight_decay=0.9, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda params: optim.ASGD([params], maximize=False, weight_decay=0.9, foreach=foreach)) | 
|  | self._test_complex_optimizer(lambda params: optim.ASGD([params], weight_decay=0.9, foreach=foreach)) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): | 
|  | optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach) | 
|  |  | 
|  | @skipIfRocm | 
|  | def test_rprop(self): | 
|  | is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6) | 
|  | for foreach in (False, True): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Rprop( | 
|  | [weight, bias], lr=2e-4, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias, maximize, foreach: optim.Rprop( | 
|  | self._build_params_dict(weight, bias, lr=1e-2), lr=2e-4, maximize=maximize, foreach=foreach), | 
|  | constructor_accepts_maximize=True, | 
|  | constructor_accepts_foreach=True, | 
|  | atol=4e-5 if is_cuda_sm86 else None, rtol=3e-5 if is_cuda_sm86 else None | 
|  | ) | 
|  | self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach)) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.Rprop([param], lr=0.001, foreach=foreach) | 
|  | ) | 
|  | self._test_complex_optimizer( | 
|  | lambda param: optim.Rprop([param], lr=0.001, maximize=True, foreach=foreach) | 
|  | ) | 
|  | with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): | 
|  | optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach) | 
|  |  | 
|  | def test_lbfgs(self): | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias: optim.LBFGS([weight, bias]), | 
|  | ignore_multidevice=True | 
|  | ) | 
|  | self._test_basic_cases( | 
|  | lambda weight, bias: optim.LBFGS([weight, bias], line_search_fn="strong_wolfe"), | 
|  | ignore_multidevice=True | 
|  | ) | 
|  |  | 
|  | @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") | 
|  | def test_lbfgs_return_type(self): | 
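        # With tolerance_grad=inf, LBFGS should take its early-exit path right after the
        # first closure evaluation, while -inf forces a full step; either way `step`
        # should return the closure's loss with the same type.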
|  | params = [torch.randn(10, 5), torch.randn(10)] | 
|  | opt1 = optim.LBFGS(params, 0.01, tolerance_grad=math.inf) | 
|  | opt2 = optim.LBFGS(params, 0.01, tolerance_grad=-math.inf) | 
|  |  | 
|  | def closure(): | 
|  | return torch.tensor([10]) | 
|  |  | 
|  | res1 = opt1.step(closure) | 
|  | res2 = opt2.step(closure) | 
|  | self.assertEqual(type(res1), type(res2)) | 
|  |  | 
|  | def test_invalid_param_type(self): | 
|  | with self.assertRaises(TypeError): | 
|  | optim.SGD(Parameter(torch.randn(5, 5)), lr=3) | 
|  |  | 
|  | def test_duplicate_params_in_param_group(self): | 
|  | param = Parameter(torch.randn(5, 5)) | 
|  | with warnings.catch_warnings(record=True) as w: | 
|  | warnings.simplefilter("always") | 
|  | optim.SGD([param, param], lr=0.1) | 
|  | self.assertEqual(len(w), 1) | 
|  | self.assertIn('a parameter group with duplicate parameters', str(w[0].message)) | 
|  |  | 
|  | def test_no_grad_for_all_params(self): | 
|  | params = [torch.randn(5, 5, requires_grad=False) for _ in range(2)] | 
|  |  | 
|  | optimizer_list = [ | 
|  | optim.Adadelta, | 
|  | optim.AdamW, | 
|  | optim.Adam, | 
|  | optim.Adagrad, | 
|  | optim.Adamax, | 
|  | optim.RMSprop, | 
|  | optim.SGD, | 
|  | optim.SparseAdam, | 
|  | optim.ASGD, | 
|  | ] | 
|  | for optim_ctr in optimizer_list: | 
|  | opt = optim_ctr(params, lr=0.1) | 
|  | # make sure step can still run even if | 
|  | # all params have no grad | 
|  | opt.step() | 
|  |  | 
    # Make sure `state_steps` is not incremented when `found_inf` indicates inf/NaN grads were found.
|  | def test_functional_fused_adam_with_foundinf(self): | 
|  | if not torch.cuda.is_available(): | 
|  | self.skipTest("CUDA is required.") | 
|  |  | 
|  | from torch.optim import adam | 
|  |  | 
|  | num_tensors = 5 | 
|  | for amsgrad in (False, True): | 
|  | params, grads, exp_avgs, exp_avg_sqs = [[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4)] | 
|  | max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else [] | 
|  | state_steps = [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)] | 
|  | grad_scale = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( | 
|  | torch.ones((1,), dtype=torch.float32, device="cuda")) | 
|  | found_inf = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( | 
|  | torch.ones((1,), dtype=torch.float32, device="cuda")) | 
|  |  | 
|  | adam.adam( | 
|  | params, | 
|  | grads, | 
|  | exp_avgs, | 
|  | exp_avg_sqs, | 
|  | max_exp_avg_sqs, | 
|  | state_steps, | 
|  | foreach=False, | 
|  | capturable=False, | 
|  | fused=True, | 
|  | amsgrad=amsgrad, | 
|  | beta1=0.9, | 
|  | beta2=0.99, | 
|  | lr=1e-2, | 
|  | weight_decay=.0, | 
|  | eps=1e-8, | 
|  | maximize=False, | 
|  | grad_scale=grad_scale, | 
|  | found_inf=found_inf, | 
|  | ) | 
|  |  | 
|  | self.assertEqual( | 
|  | state_steps, | 
|  | [torch.ones((1,), dtype=torch.float32, device="cuda") for _ in range(num_tensors)], | 
|  | ) | 
|  |  | 
|  | def test_empty_grad(self): | 
|  | optimizers = [torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam, torch.optim.AdamW, | 
|  | torch.optim.Adamax, torch.optim.ASGD, torch.optim.NAdam, torch.optim.RAdam, | 
|  | torch.optim.RMSprop, torch.optim.Rprop, torch.optim.SGD, torch.optim.SparseAdam] | 
|  |  | 
|  | for optimizer in optimizers: | 
|  | net = torch.nn.Embedding(5, 1, padding_idx=0, sparse=optimizer is torch.optim.SparseAdam) | 
            original_params = [param.detach().clone() for param in net.parameters()]
|  | # Simulate a batch that only indexes the embedding at padding_idx | 
|  | x = torch.tensor([[0, 0]]).int() | 
|  | y = torch.tensor([[[3.0], [4.0]]]) | 
|  | opt = optimizer(net.parameters(), lr=1e-5) | 
|  | torch.nn.MSELoss()(net.forward(x), y).backward() | 
|  |  | 
|  | opt.step() | 
|  |  | 
|  | for original_param, param in zip(original_params, net.parameters()): | 
|  | # assert that the parameters have not changed | 
|  | self.assertEqual(original_param, param) | 
|  |  | 
|  |  | 
|  |  | 
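# Minimal two-layer conv net used as the optimization target for the LR scheduler tests below.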
|  | class SchedulerTestNet(torch.nn.Module): | 
|  | def __init__(self): | 
        super().__init__()
|  | self.conv1 = torch.nn.Conv2d(1, 1, 1) | 
|  | self.conv2 = torch.nn.Conv2d(1, 1, 1) | 
|  |  | 
|  | def forward(self, x): | 
|  | return self.conv2(F.relu(self.conv1(x))) | 
|  |  | 
|  |  | 
|  | class LambdaLRTestObject: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  |  | 
|  | def __call__(self, epoch): | 
|  | return self.value * epoch | 
|  |  | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, self.__class__): | 
|  | return self.__dict__ == other.__dict__ | 
|  | else: | 
|  | return False | 
|  |  | 
|  |  | 
|  | class TestLRScheduler(TestCase): | 
|  | exact_dtype = True | 
|  |  | 
|  | def setUp(self): | 
        super().setUp()
|  | self.net = SchedulerTestNet() | 
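        # Two parameter groups with different learning rates (0.05 and 0.5), letting the
        # scheduler tests exercise per-group learning-rate schedules.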
|  | self.opt = SGD( | 
|  | [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], | 
|  | lr=0.05) | 
|  |  | 
|  | def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1): | 
|  | """This function swallows the epoch deprecation warning which is produced when we | 
        call `scheduler.step(epoch)` with a non-`None` value of `epoch`.
        Passing `epoch` is deprecated, and this function will need to be removed/updated when
        the schedulers no longer accept the parameter at all.
|  | """ | 
|  | self.assertEqual(len(w), num_warnings) | 
|  | for warning in w: | 
|  | self.assertEqual(len(warning.message.args), 1) | 
|  | self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING) | 
|  |  | 
|  | def test_error_when_getlr_has_epoch(self): | 
|  | class MultiStepLR(torch.optim.lr_scheduler.LRScheduler): | 
|  | def __init__(self, optimizer, gamma, milestones, last_epoch=-1): | 
|  | self.init_lr = [group['lr'] for group in optimizer.param_groups] | 
|  | self.gamma = gamma | 
|  | self.milestones = milestones | 
|  | super().__init__(optimizer, last_epoch) | 
|  |  | 
|  | def get_lr(self, step): | 
|  | global_step = self.last_epoch | 
|  | gamma_power = ([0] + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m])[-1] | 
|  | return [init_lr * (self.gamma ** gamma_power) for init_lr in self.init_lr] | 
|  |  | 
|  | optimizer = torch.optim.SGD([torch.rand(1)], lr=1) | 
|  |  | 
|  | with self.assertRaises(TypeError): | 
|  | scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20]) | 
|  |  | 
|  | def test_no_cyclic_references(self): | 
|  | import gc | 
|  | param = Parameter(torch.empty(10)) | 
|  | optim = SGD([param], lr=0.5) | 
|  | scheduler = LambdaLR(optim, lambda epoch: 1.0) | 
|  | del scheduler | 
|  |  | 
        # Prior to Python 3.7, local variables of the current function are still referenced by the current frame object.
|  | import sys | 
|  | if sys.version_info < (3, 7): | 
|  | import inspect | 
|  | referrers = gc.get_referrers(optim) | 
|  | self.assertTrue( | 
|  | len(referrers) == 1 and referrers[0] is inspect.currentframe(), | 
|  | "Optimizer should contain no cyclic references (except current frame)") | 
|  | del referrers | 
|  | else: | 
|  | self.assertTrue( | 
|  | len(gc.get_referrers(optim)) == 0, | 
|  | "Optimizer should contain no cyclic references") | 
|  |  | 
|  | gc.collect() | 
|  | del optim | 
|  | self.assertEqual( | 
|  | gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__") | 
|  |  | 
|  | def test_no_cyclic_references_in_step(self): | 
|  | import gc | 
|  | import weakref | 
|  |  | 
|  | def run(): | 
|  | param = torch.empty(10, requires_grad=True) | 
|  | optim = SGD(params=[param], lr=0.5) | 
|  | scheduler = LambdaLR(optim, lambda epoch: 1.0) | 
|  | param.sum().backward() | 
|  | optim.step() | 
|  | scheduler.step() | 
|  |  | 
|  | return weakref.ref(scheduler) | 
|  |  | 
        # To ensure that there are no reference cycles in the scheduler, temporarily
        # disable the garbage collector; otherwise gc would silently collect any
        # unreachable cycle and hide it from the weakref check below.
|  | gc.disable() | 
|  | ref = run() | 
|  | assert ref() is None | 
|  | gc.enable()  # restore | 
|  |  | 
|  | def test_old_pattern_warning(self): | 
|  | epochs = 35 | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def old_pattern(): | 
|  | for _ in range(epochs): | 
|  | scheduler.step() | 
|  | self.opt.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern) | 
|  |  | 
|  | def test_old_pattern_warning_with_arg(self): | 
|  | epochs = 35 | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def old_pattern2(): | 
|  | for _ in range(epochs): | 
|  | scheduler.step() | 
|  | self.opt.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) | 
|  |  | 
|  | def test_old_pattern_warning_resuming(self): | 
|  | epochs = 35 | 
|  | for i, group in enumerate(self.opt.param_groups): | 
|  | group['initial_lr'] = 0.01 | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def old_pattern(): | 
|  | for _ in range(epochs): | 
|  | scheduler.step() | 
|  | self.opt.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern) | 
|  |  | 
|  | def test_old_pattern_warning_resuming_with_arg(self): | 
|  | epochs = 35 | 
|  | for i, group in enumerate(self.opt.param_groups): | 
|  | group['initial_lr'] = 0.01 | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def old_pattern2(): | 
|  | for _ in range(epochs): | 
|  | scheduler.step() | 
|  | self.opt.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) | 
|  |  | 
|  | def test_old_pattern_warning_with_overridden_optim_step(self): | 
|  | epochs = 35 | 
|  | for i, group in enumerate(self.opt.param_groups): | 
|  | group['initial_lr'] = 0.01 | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | # emulate use-case with optimizer.step overridden | 
|  | import types | 
|  |  | 
|  | old_step = self.opt.step | 
|  |  | 
|  | def new_step(o, *args, **kwargs): | 
|  | retval = old_step(*args, **kwargs) | 
|  | return retval | 
|  |  | 
|  | self.opt.step = types.MethodType(new_step, self.opt) | 
|  |  | 
|  | def old_pattern2(): | 
|  | for _ in range(epochs): | 
|  | scheduler.step() | 
|  | self.opt.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2) | 
|  |  | 
|  | def test_new_pattern_no_warning(self): | 
|  | epochs = 35 | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | for _ in range(epochs): | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def test_new_pattern_no_warning_with_arg(self): | 
|  | epochs = 35 | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | for _ in range(epochs): | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | def test_new_pattern_no_warning_with_overridden_optim_step(self): | 
|  | epochs = 35 | 
|  | with warnings.catch_warnings(record=True) as ws: | 
|  | warnings.simplefilter("always")  # allow any warning to be raised | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self.assertTrue(len(ws) == 0, "No warning should be raised") | 
|  |  | 
|  | # emulate use-case with optimizer.step overridden | 
|  | import types | 
|  |  | 
|  | old_step = self.opt.step | 
|  |  | 
|  | def new_step(o, *args, **kwargs): | 
|  | retval = old_step(*args, **kwargs) | 
|  | return retval | 
|  |  | 
|  | self.opt.step = types.MethodType(new_step, self.opt) | 
|  |  | 
|  | def new_pattern(): | 
|  | for e in range(epochs): | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  |  | 
|  | self.assertWarnsRegex(UserWarning, r'`optimizer.step\(\)` has been overridden', new_pattern) | 
|  |  | 
|  | def _test_lr_is_constant_for_constant_epoch(self, scheduler): | 
|  | l = [] | 
|  |  | 
|  | for _ in range(10): | 
|  | scheduler.optimizer.step() | 
|  | with warnings.catch_warnings(record=True) as w: | 
|  | scheduler.step(2) | 
|  | self._check_warning_is_epoch_deprecation_warning(w) | 
|  |  | 
|  | l.append(self.opt.param_groups[0]['lr']) | 
|  | self.assertEqual(min(l), max(l)) | 
|  |  | 
|  | def test_step_lr_is_constant_for_constant_epoch(self): | 
|  | scheduler = StepLR(self.opt, 2) | 
|  | self._test_lr_is_constant_for_constant_epoch(scheduler) | 
|  |  | 
|  | def test_exponential_lr_is_constant_for_constant_epoch(self): | 
|  | scheduler = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test_lr_is_constant_for_constant_epoch(scheduler) | 
|  |  | 
|  | def test_constantlr_is_constant_for_constant_epoch(self): | 
|  | scheduler = ConstantLR(self.opt) | 
|  | self._test_lr_is_constant_for_constant_epoch(scheduler) | 
|  |  | 
|  | def test_linear_linearlr_is_constant_for_constant_epoch(self): | 
|  | scheduler = LinearLR(self.opt) | 
|  | self._test_lr_is_constant_for_constant_epoch(scheduler) | 
|  |  | 
|  | def test_polynomial_lr_is_constant_for_constant_epoch(self): | 
|  | scheduler = PolynomialLR(self.opt, power=0.9) | 
|  | self._test_lr_is_constant_for_constant_epoch(scheduler) | 
|  |  | 
|  | def test_step_lr(self): | 
|  | # lr = 0.05     if epoch < 3 |
|  | # lr = 0.005    if 3 <= epoch < 6 |
|  | # lr = 0.0005   if 6 <= epoch < 9 |
|  | # lr = 0.00005  if epoch >= 9 |
|  | epochs = 10 | 
|  | single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_get_last_lr_step_lr(self): | 
|  | from torch.nn import Parameter | 
|  | epochs = 10 | 
|  | optimizer = torch.optim.SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1) | 
|  | targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] | 
|  | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) | 
|  | self._test_get_last_lr(scheduler, targets, epochs) | 
|  |  | 
|  | def test_get_last_lr_multi_step_lr(self): | 
|  | # lr = 0.05     if epoch < 2 | 
|  | # lr = 0.005    if 2 <= epoch < 5 | 
|  | # lr = 0.0005   if 5 <= epoch < 9 | 
|  | # lr = 0.00005   if 9 <= epoch | 
|  | epochs = 10 | 
|  | single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | self._test_get_last_lr(scheduler, targets, epochs) | 
|  |  | 
|  | def test_multi_step_lr(self): | 
|  | # lr = 0.05     if epoch < 2 | 
|  | # lr = 0.005    if 2 <= epoch < 5 | 
|  | # lr = 0.0005   if 5 <= epoch < 9 |
|  | # lr = 0.00005   if epoch >= 9 | 
|  | epochs = 10 | 
|  | single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_multi_step_lr_with_epoch(self): | 
|  | # lr = 0.05     if epoch < 2 | 
|  | # lr = 0.005    if 2 <= epoch < 5 | 
|  | # lr = 0.0005   if 5 <= epoch < 9 |
|  | # lr = 0.00005   if epoch >= 9 | 
|  | epochs = 10 | 
|  | single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | self._test_with_epoch(scheduler, targets, epochs) | 
|  |  | 
|  | def test_get_last_lr_constantlr(self): | 
|  | # lr = 0.025     if epoch < 5 | 
|  | # lr = 0.05     if 5 <= epoch |
|  | epochs = 10 | 
|  | single_targets = [0.025] * 5 + [0.05] * 5 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) | 
|  | self._test_get_last_lr(scheduler, targets, epochs) | 
|  |  | 
|  | def test_get_last_lr_linearlr(self): | 
|  | # lr = 0.0125     if epoch == 0 |
|  | # lr = 0.016875   if epoch == 1 |
|  | # lr = 0.02125    if epoch == 2 |
|  | # lr = 0.025625   if epoch == 3 |
|  | # lr = 0.03       if 4 <= epoch |
|  | epochs = 10 | 
|  | start_factor = 1.0 / 4 | 
|  | end_factor = 3. / 5 | 
|  | iters = 4 | 
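|  | # LinearLR ramps the multiplicative factor linearly from start_factor to end_factor |
|  | # over total_iters steps and then holds it at end_factor, as computed below. |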
|  | interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] | 
|  | single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (epochs - iters) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) | 
|  | self._test_get_last_lr(scheduler, targets, epochs) | 
|  |  | 
|  | def test_constantlr(self): | 
|  | # lr = 0.025     if epoch < 5 | 
|  | # lr = 0.05     if 5 <= epoch |
|  | epochs = 10 | 
|  | single_targets = [0.025] * 5 + [0.05] * 5 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_linearlr(self): | 
|  | # lr = 0.025     if epoch == 0 | 
|  | # lr = 0.03125   if epoch == 1 | 
|  | # lr = 0.0375    if epoch == 2 | 
|  | # lr = 0.04375   if epoch == 3 | 
|  | # lr = 0.05      if 4 <= epoch |
|  | epochs = 10 | 
|  | start_factor = 1.0 / 2 | 
|  | iters = 4 | 
|  | interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)] | 
|  | single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_linearlr_start_factor_limits1(self): | 
|  | start_factor = 0. | 
|  | iters = 4 | 
|  | with self.assertRaises(ValueError): | 
|  | LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  |  | 
|  | def test_linearlr_start_factor_limits2(self): | 
|  | start_factor = 1.1 | 
|  | iters = 4 | 
|  | with self.assertRaises(ValueError): | 
|  | LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  |  | 
|  | def test_constantlr_with_epoch(self): | 
|  | # lr = 0.025     if epoch < 5 | 
|  | # lr = 0.05     if 5 <= epoch |
|  | epochs = 10 | 
|  | single_targets = [0.025] * 5 + [0.05] * 5 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) | 
|  | self._test_with_epoch(scheduler, targets, epochs) | 
|  |  | 
|  | def test_linearlr_with_epoch(self): | 
|  | # lr = 0.025     if epoch == 0 | 
|  | # lr = 0.03125   if epoch == 1 | 
|  | # lr = 0.0375    if epoch == 2 | 
|  | # lr = 0.04375   if epoch == 3 | 
|  | # lr = 0.05      if 4 <= epoch |
|  | epochs = 10 | 
|  | start_factor = 1.0 / 2 | 
|  | end_factor = 1. | 
|  | iters = 4 | 
|  | interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)] | 
|  | single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  | self._test_with_epoch(scheduler, targets, epochs) | 
|  |  | 
|  | def test_exp_lr(self): | 
|  | epochs = 10 | 
|  | single_targets = [0.05 * (0.9 ** x) for x in range(epochs)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_poly_lr(self): | 
|  | epochs = 10 | 
|  | power = 0.9 | 
|  | total_iters = 5 | 
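|  | # Closed form checked below: lr_t = base_lr * (1 - t / total_iters) ** power |
|  | # for t < total_iters, and 0 afterwards (base_lr is 0.05 here). |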
|  | single_targets = [(1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)] + [0.0] * (epochs - total_iters) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_cos_anneal_lr(self): | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
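|  | # Closed form checked below: eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2 |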
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_closed_form_step_lr(self): | 
|  | scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_linearlr(self): | 
|  | scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) | 
|  | closed_form_scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_constantlr(self): | 
|  | scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) | 
|  | closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_multi_step_lr(self): | 
|  | scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_exp_lr(self): | 
|  | scheduler = ExponentialLR(self.opt, gamma=0.9) | 
|  | closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_poly_lr(self): | 
|  | scheduler = PolynomialLR(self.opt, power=0.9) | 
|  | closed_form_scheduler = PolynomialLR(self.opt, power=0.9) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, 20) | 
|  |  | 
|  | def test_closed_form_cos_anneal_lr(self): | 
|  | eta_min = 1e-10 | 
|  | epochs = 20 | 
|  | T_max = 5 | 
|  | scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) | 
|  | closed_form_scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) | 
|  | self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) | 
|  |  | 
|  | def test_cos_anneal_lr_continue(self): | 
|  | eta_min = 0.1 | 
|  | T_max = 5 | 
|  | scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  | original_lrs = scheduler._last_lr | 
|  | new_scheduler = CosineAnnealingLR( | 
|  | self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0) | 
|  | new_lrs = new_scheduler._last_lr | 
|  | torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) | 
|  |  | 
|  | def test_reduce_lr_on_plateau1(self): | 
|  | epochs = 10 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 20] | 
|  | metrics = [10 - i * 0.0167 for i in range(20)] | 
|  | scheduler = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', | 
|  | threshold=0.01, patience=5, cooldown=5) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau2(self): | 
|  | epochs = 22 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] | 
|  | metrics = [10 - i * 0.0165 for i in range(22)] | 
|  | scheduler = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', | 
|  | mode='min', threshold=0.1) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau3(self): | 
|  | epochs = 22 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] | 
|  | metrics = [-0.8] * 2 + [-0.234] * 20 | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, | 
|  | threshold_mode='abs') | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau4(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 20] | 
|  | metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25 | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='max', patience=3, | 
|  | threshold_mode='rel', threshold=0.1) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau5(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] | 
|  | metrics = [1.5 * (1.005 ** i) for i in range(20)] | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', | 
|  | threshold=0.1, patience=5, cooldown=5) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau6(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 20] | 
|  | metrics = [1.5 * (0.85 ** i) for i in range(20)] | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', | 
|  | threshold=0.1) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau7(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] | 
|  | metrics = [1] * 7 + [0.6] + [0.5] * 12 | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='min', threshold_mode='rel', | 
|  | threshold=0.1, patience=5, cooldown=5) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau8(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] | 
|  | metrics = [1.5 * (1.005 ** i) for i in range(20)] | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='max', threshold_mode='rel', min_lr=[0.4, 0.3], | 
|  | threshold=0.1, patience=5, cooldown=5) | 
|  | self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) | 
|  |  | 
|  | def test_sequentiallr1(self): | 
|  | epochs = 19 | 
|  | schedulers = [None] * 2 | 
|  | targets = [[0.05, 0.04, 0.032] + [0.05 for x in range(4)] | 
|  | + [0.05 * 0.1 for x in range(4)] | 
|  | + [0.05 * 0.01 for x in range(4)] | 
|  | + [0.05 * 0.001 for x in range(4)]] | 
|  | milestones = [3] | 
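|  | # The targets above encode the switch at milestone 3: ExponentialLR (gamma=0.8) drives |
|  | # the first three epochs, then StepLR takes over from the base lr of 0.05 and decays |
|  | # it by 0.1 every 4 epochs. |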
|  | schedulers[0] = ExponentialLR(self.opt, gamma=0.8) | 
|  | schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4) | 
|  | scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_sequentiallr2(self): | 
|  | epochs = 13 | 
|  | schedulers = [None] * 2 | 
|  | targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9 ** x for x in range(10)]] | 
|  | milestones = [3] | 
|  | schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.9) | 
|  | scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_sequentiallr3(self): | 
|  | epochs = 12 | 
|  | schedulers = [None] * 3 | 
|  | targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032] | 
|  | + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]] | 
|  | milestones = [3, 6] | 
|  | schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.8) | 
|  | schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) | 
|  | scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_sequentiallr4(self): | 
|  | optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1) | 
|  | prev_lr = optimizer.param_groups[0]["lr"] | 
|  |  | 
|  | schedulers = [ | 
|  | torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1), | 
|  | torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1) | 
|  | ] | 
|  | scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones=[10]) | 
|  |  | 
|  | new_lr = optimizer.param_groups[0]["lr"] | 
|  |  | 
|  | # Ensure that wrapping multiple schedulers in SequentialLR does not affect the initial learning rate |
|  | self.assertEqual(prev_lr, new_lr) | 
|  |  | 
|  | def test_get_last_lr_sequentiallr(self): | 
|  | epochs = 12 | 
|  | milestones = [3, 6] | 
|  | schedulers = [None] * 3 | 
|  | schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.8) | 
|  | schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) | 
|  | scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) | 
|  | constant_lr_target = [0.005] * 3 | 
|  | exponential_lr_target = [0.05, 0.04, 0.032] | 
|  | step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] | 
|  | single_targets = constant_lr_target + exponential_lr_target + step_lr_target | 
|  | targets = [single_targets, [x * 10 for x in single_targets]] | 
|  | self._test_get_last_lr(scheduler, targets, epochs) | 
|  |  | 
|  | def test_chained_lr2_get_last_lr_before_step(self): | 
|  | schedulers = [ | 
|  | LinearLR(self.opt, start_factor=0.4, total_iters=3), | 
|  | MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) | 
|  | ] | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_chained_lr1(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 1 | 
|  | targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3] | 
|  | schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self._test([scheduler], targets, epochs) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_chained_lr2(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 1 | 
|  | targets = [[0.02, 0.03, 0.04] + [0.05] * 9] | 
|  | schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self._test([scheduler], targets, epochs) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_chained_lr3(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 2 | 
|  | targets = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3] | 
|  | schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) | 
|  | schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self._test([scheduler], targets, epochs) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_chained_lr4(self): | 
|  | epochs = 9 | 
|  | schedulers = [None] * 3 | 
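|  | # ChainedScheduler steps every wrapped scheduler each epoch, so the per-epoch factors |
|  | # compose: the 0.9**x exponential decay, the 0.2 constant factor for the first 4 epochs, |
|  | # and the 0.1 step decay every 3 epochs, as spelled out in the targets below. |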
|  | targets = [[0.05 * 0.2 * 0.9 ** x for x in range(3)] | 
|  | + [0.05 * 0.2 * 0.9 ** 3 * 0.1] | 
|  | + [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)] | 
|  | + [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]] | 
|  | schedulers[0] = ExponentialLR(self.opt, gamma=0.9) | 
|  | schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) | 
|  | schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self._test([scheduler], targets, epochs) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_chained_lr5(self): | 
|  | def poly_lr(lr: float): | 
|  | return [ | 
|  | (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters) | 
|  | ] + [0.0] * (epochs - total_iters) | 
|  |  | 
|  | schedulers = [None] * 2 | 
|  | epochs = 10 | 
|  | power = 0.9 | 
|  | total_iters = 5 | 
|  | const_factor = 0.1 | 
|  | single_targets = [x * const_factor for x in poly_lr(lr=0.05)] | 
|  | targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]] | 
|  | schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters) | 
|  | schedulers[1] = ConstantLR(self.opt, factor=const_factor) | 
|  | scheduler = ChainedScheduler(schedulers) | 
|  | self._test(scheduler, targets, epochs) | 
|  | self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) | 
|  |  | 
|  | def test_compound_step_and_multistep_lr(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]] | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_step_and_exp_lr(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [0.05 * (0.9 ** x) for x in range(3)] | 
|  | single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)] | 
|  | single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)] | 
|  | single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_exp_and_multistep_lr(self): | 
|  | epochs = 10 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [0.05 * (0.9 ** x) for x in range(2)] | 
|  | single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)] | 
|  | single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)] | 
|  | single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_exp_and_linearlr(self): | 
|  | epochs = 10 | 
|  | iters = 4 | 
|  | start_factor = 0.4 | 
|  | end_factor = 0.9 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [0.05 * (0.9 ** x) for x in range(11)] | 
|  | for i in range(iters): | 
|  | single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) | 
|  | for i in range(iters, 11): | 
|  | single_targets[i] *= end_factor | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.9) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_step_and_constantlr(self): | 
|  | epochs = 10 | 
|  | iters = 4 | 
|  | factor = 0.4 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [0.05 * 0.4] * 3 + [0.005 * 0.4] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3 | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_linearlr_and_multistep_lr(self): | 
|  | epochs = 10 | 
|  | iters = 4 | 
|  | start_factor = 0.4 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 | 
|  | for i in range(iters): | 
|  | single_targets[i] *= start_factor + i / iters * (1 - start_factor) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_cosanneal_and_step_lr(self): | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) | 
|  | schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_cosanneal_and_multistep_lr(self): | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] | 
|  | single_targets = [x * y for x, y in zip(single_targets, multipliers)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) | 
|  | schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_cosanneal_and_linearlr(self): | 
|  | epochs = 10 | 
|  | iters = 4 | 
|  | start_factor = 0.4 | 
|  | eta_min = 1e-10 | 
|  | schedulers = [None] * 2 | 
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | for i in range(iters): | 
|  | single_targets[i] *= start_factor + i / iters * (1 - start_factor) | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  | schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_cosanneal_and_exp_lr(self): | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | multipliers = [0.1 ** i for i in range(epochs)] | 
|  | single_targets = [x * y for x, y in zip(single_targets, multipliers)] | 
|  | targets = [single_targets, [x * epochs for x in single_targets]] | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.1) | 
|  | self._test(schedulers, targets, epochs) | 
|  |  | 
|  | def test_compound_reduce_lr_on_plateau1(self): | 
|  | epochs = 10 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | single_targets = [0.5] * 20 | 
|  | multipliers = [0.1 ** (i // 3) for i in range(20)] | 
|  | single_targets = [x * y for x, y in zip(multipliers, single_targets)] | 
|  | targets = [single_targets] | 
|  | targets = targets[1:]  # test runs step before checking lr | 
|  | metrics = [10 - i * 0.0167 for i in range(20)] | 
|  | schedulers = [None, None] | 
|  | schedulers[0] = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min', | 
|  | threshold=0.01, patience=5, cooldown=5) | 
|  | schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) | 
|  | self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) | 
|  |  | 
|  | def test_compound_reduce_lr_on_plateau2(self): | 
|  | epochs = 22 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 | 
|  | multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 | 
|  | single_targets = [x * y for x, y in zip(single_targets, multipliers)] | 
|  | targets = [single_targets] | 
|  | targets = targets[1:]  # test runs step before checking lr | 
|  | metrics = [10 - i * 0.0165 for i in range(22)] | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', | 
|  | mode='min', threshold=0.1) | 
|  | schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) | 
|  | self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) | 
|  |  | 
|  | def test_compound_reduce_lr_on_plateau3(self): | 
|  | epochs = 22 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 | 
|  | multipliers = [0.1 ** i for i in range(epochs)] | 
|  | single_targets = [x * y for x, y in zip(multipliers, single_targets)] | 
|  | targets = [single_targets] | 
|  | targets = targets[1:]  # test runs step before checking lr | 
|  | metrics = [-0.8] * 2 + [-0.234] * 20 | 
|  | schedulers = [None, None] | 
|  | schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5, | 
|  | threshold_mode='abs') | 
|  | schedulers[1] = ExponentialLR(self.opt, gamma=0.1) | 
|  | self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) | 
|  |  | 
|  | def test_compound_reduce_lr_on_plateau4(self): | 
|  | epochs = 20 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.05 | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
|  | single_targets = [eta_min + (0.05 - eta_min) * | 
|  | (1 + math.cos(math.pi * x / epochs)) / 2 | 
|  | for x in range(epochs)] | 
|  | targets = [single_targets] | 
|  | targets = targets[1:]  # test runs step before checking lr | 
|  | metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25 | 
|  | schedulers = [None, None] | 
|  | schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=3, | 
|  | threshold_mode='rel', threshold=0.1) | 
|  | schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) | 
|  | self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) | 
|  |  | 
|  | def test_compound_reduce_lr_on_plateau5(self): | 
|  | iters = 4 | 
|  | start_factor = 0.4 | 
|  | epochs = 22 | 
|  | for param_group in self.opt.param_groups: | 
|  | param_group['lr'] = 0.5 | 
|  | single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 | 
|  | multipliers = [1] * 22 | 
|  | for i in range(iters): | 
|  | multipliers[i] *= start_factor + i / iters * (1 - start_factor) | 
|  | single_targets = [x * y for x, y in zip(single_targets, multipliers)] | 
|  | targets = [single_targets] | 
|  | targets = targets[1:]  # test runs step before checking lr | 
|  | metrics = [10 - i * 0.0165 for i in range(22)] | 
|  | schedulers = [None] * 2 | 
|  | schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', | 
|  | mode='min', threshold=0.1) | 
|  | schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) | 
|  | self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) | 
|  |  | 
|  | def test_cycle_lr_invalid_mode(self): | 
|  | with self.assertRaises(ValueError): | 
|  | scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") | 
|  |  | 
|  | def test_cycle_lr_triangular_mode_one_lr(self): | 
|  | lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] | 
|  | momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=1, max_momentum=5, | 
|  | mode='triangular') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): | 
|  | lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_target = [self.opt.defaults['momentum']] * len(lr_target) | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, | 
|  | cycle_momentum=False, mode='triangular') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_triangular2_mode_one_lr(self): | 
|  | lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, | 
|  | 1, 1.25, 1.50, 1.75, 2.00, 1.75] | 
|  | momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, | 
|  | 3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=1, max_momentum=5, | 
|  | mode='triangular2') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_exp_range_mode_one_lr(self): | 
|  | base_lr, max_lr = 1, 5 | 
|  | diff_lr = max_lr - base_lr | 
|  | gamma = 0.9 | 
|  | xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] | 
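|  | # In exp_range mode the triangular position x is scaled by gamma**iteration, so the |
|  | # cycle amplitude decays exponentially; momentum mirrors the lr from the opposite bound. |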
|  | lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] | 
|  | momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=base_lr, | 
|  | max_lr=max_lr, step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr, | 
|  | mode='exp_range', gamma=gamma) | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_triangular_mode(self): | 
|  | lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] | 
|  | lr_target_2 = [x + 1 for x in lr_target_1] | 
|  | lr_targets = [lr_target_1, lr_target_2] | 
|  | momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] | 
|  | momentum_target_2 = [x + 1 for x in momentum_target_1] | 
|  | momentum_targets = [momentum_target_1, momentum_target_2] | 
|  | scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6], | 
|  | mode='triangular') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) | 
|  |  | 
|  | def test_cycle_lr_triangular2_mode(self): | 
|  | lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1, | 
|  | 1.25, 1.50, 1.75, 2.00, 1.75] | 
|  | lr_target_2 = [x + 2 for x in lr_target_1] | 
|  | lr_targets = [lr_target_1, lr_target_2] | 
|  | momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5, | 
|  | 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25] | 
|  | momentum_target_2 = [x + 2 for x in momentum_target_1] | 
|  | momentum_targets = [momentum_target_1, momentum_target_2] | 
|  | scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7], | 
|  | mode='triangular2') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) | 
|  |  | 
|  | def test_cycle_lr_exp_range_mode(self): | 
|  | base_lr_1, max_lr_1 = 1, 5 | 
|  | base_lr_2, max_lr_2 = 5, 12 | 
|  |  | 
|  | diff_lr_1 = max_lr_1 - base_lr_1 | 
|  | diff_lr_2 = max_lr_2 - base_lr_2 | 
|  |  | 
|  | gamma = 0.9 | 
|  | xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] | 
|  | lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] | 
|  | lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] | 
|  | lr_targets = [lr_target_1, lr_target_2] | 
|  | momentum_target_1 = [max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] | 
|  | momentum_target_2 = [max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] | 
|  | momentum_targets = [momentum_target_1, momentum_target_2] | 
|  | scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2], | 
|  | max_lr=[max_lr_1, max_lr_2], step_size_up=4, | 
|  | cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2], | 
|  | max_momentum=[max_lr_1, max_lr_2], | 
|  | mode='exp_range', gamma=gamma) | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) | 
|  |  | 
|  | def test_cycle_lr_triangular_mode_step_size_up_down(self): | 
|  | lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  |  | 
|  | scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, | 
|  | step_size_up=4, | 
|  | step_size_down=6, | 
|  | cycle_momentum=True, | 
|  | base_momentum=1, max_momentum=5, | 
|  | mode='triangular') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_triangular2_mode_step_size_up_down(self): | 
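|  | # Asymmetric cycle: 2 steps up, 6 steps down. triangular2 halves the peak amplitude |
|  | # each cycle, so the peaks over base_lr=1 go 5.0 -> 3.0 -> 2.0 in the target below. |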
|  | lr_base_target = ([ | 
|  | 1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3, | 
|  | 7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6, | 
|  | 8.0 / 6, 7.0 / 6 | 
|  | ]) | 
|  | momentum_base_target = ([ | 
|  | 5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3, | 
|  | 11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3, | 
|  | 29.0 / 6 | 
|  | ]) | 
|  | deltas = [2 * i for i in range(0, 2)] | 
|  | base_lrs = [1 + delta for delta in deltas] | 
|  | max_lrs = [5 + delta for delta in deltas] | 
|  | lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] | 
|  | momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas] | 
|  | scheduler = CyclicLR( | 
|  | self.opt, | 
|  | base_lr=base_lrs, | 
|  | max_lr=max_lrs, | 
|  | step_size_up=2, | 
|  | step_size_down=6, | 
|  | cycle_momentum=True, | 
|  | base_momentum=base_lrs, | 
|  | max_momentum=max_lrs, | 
|  | mode='triangular2') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target)) | 
|  |  | 
|  | def test_cycle_lr_exp_range_mode_step_size_up_down(self): | 
|  | base_lr, max_lr = 1, 5 | 
|  | diff_lr = max_lr - base_lr | 
|  | gamma = 0.9 | 
|  | xs = ([ | 
|  | 0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6, | 
|  | 4.0 / 6 | 
|  | ]) | 
|  | lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr, | 
|  | step_size_up=2, step_size_down=6, | 
|  | cycle_momentum=True, base_momentum=base_lr, | 
|  | max_momentum=max_lr, | 
|  | mode='exp_range', gamma=gamma) | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | def test_cycle_lr_with_momentumless_optimizer(self): | 
|  | # Note [Temporarily set optimizer to Adam] | 
|  | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | 
|  | # The TestLRScheduler object carries around an SGD optimizer to avoid having to | 
|  | # instantiate one for every test. This gets in the way for our very specific case | 
|  | # in which we need to use Adam (or really any optimizer that doesn't use momentum) | 
|  | # in order to test that the momentum bug in CyclicLR is fixed (the bug is described | 
|  | # in more detail in https://github.com/pytorch/pytorch/issues/19003 ). | 
|  | old_opt = self.opt | 
|  | self.opt = optim.Adam( | 
|  | [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], | 
|  | lr=0.05) | 
|  |  | 
|  | lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_target = [None] * len(lr_target) | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4, | 
|  | cycle_momentum=False, mode='triangular') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) | 
|  |  | 
|  | self.opt = old_opt  # set optimizer back to SGD | 
|  |  | 
|  | def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self): | 
|  | with self.assertRaises(ValueError): | 
|  | adam_opt = optim.Adam(self.net.parameters()) | 
|  | scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True) | 
|  |  | 
|  | def test_cycle_lr_removed_after_out_of_scope(self): | 
|  | import gc | 
|  | import weakref | 
|  | gc.disable() | 
|  |  | 
|  | def test(): | 
|  | adam_opt = optim.Adam(self.net.parameters()) | 
|  | scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) | 
|  | return weakref.ref(scheduler) | 
|  |  | 
|  | ref = test() | 
|  | assert ref() is None | 
|  | gc.enable() | 
|  |  | 
|  | def test_onecycle_lr_invalid_anneal_strategy(self): | 
|  | with self.assertRaises(ValueError): | 
|  | scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS") | 
|  |  | 
|  | def test_onecycle_lr_invalid_pct_start(self): | 
|  | with self.assertRaises(ValueError): | 
|  | scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1) | 
|  |  | 
|  | def test_onecycle_lr_cannot_calculate_total_steps(self): | 
|  | with self.assertRaises(ValueError): | 
|  | scheduler = OneCycleLR(self.opt, max_lr=1e-3) | 
|  |  | 
|  | def test_onecycle_lr_linear_annealing(self): | 
|  | lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] | 
|  | momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
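|  | # Assuming the default div_factor of 25, the initial lr is max_lr / 25 = 1.0 and the |
|  | # final lr is 1.0 / final_div_factor = 0.5, matching the endpoints of lr_target. |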
|  | scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, | 
|  | total_steps=10, anneal_strategy='linear') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) | 
|  |  | 
|  | def test_onecycle_lr_linear_annealing_three_phases(self): | 
|  | lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25] | 
|  | momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = OneCycleLR(self.opt, max_lr=25, div_factor=25, | 
|  | base_momentum=1, max_momentum=22, | 
|  | total_steps=10, anneal_strategy='linear', | 
|  | pct_start=0.4, final_div_factor=4, | 
|  | three_phase=True) | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) | 
|  |  | 
|  | def test_onecycle_lr_cosine_annealing(self): | 
|  | def annealing_cos(start, end, pct): | 
|  | cos_out = math.cos(math.pi * pct) + 1 | 
|  | return end + (start - end) / 2.0 * cos_out | 
|  | lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0), | 
|  | annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0), | 
|  | annealing_cos(25, 0.5, 6 / 7.0), 0.5] | 
|  | momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0), | 
|  | annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0), | 
|  | annealing_cos(1, 22, 6 / 7.0), 22] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, | 
|  | total_steps=10) | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) | 
|  |  | 
|  | def test_cycle_lr_with_adam(self): | 
|  | old_opt = self.opt | 
|  | self.opt = optim.Adam( | 
|  | [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], | 
|  | lr=0.05) | 
|  |  | 
|  | lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] | 
|  | momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] | 
|  | lr_targets = [lr_target, lr_target] | 
|  | momentum_targets = [momentum_target, momentum_target] | 
|  | scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22, | 
|  | total_steps=10, anneal_strategy='linear') | 
|  | self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True) | 
|  | self.opt = old_opt  # set optimizer back to SGD | 
|  |  | 
|  | def test_lambda_lr(self): | 
|  | epochs = 10 | 
|  | self.opt.param_groups[0]['lr'] = 0.05 | 
|  | self.opt.param_groups[1]['lr'] = 0.4 | 
|  | targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]] | 
|  | scheduler = LambdaLR(self.opt, | 
|  | lr_lambda=[lambda x1: 0.9 ** x1, lambda x2: 0.8 ** x2]) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | def test_multiplicative_lr(self): | 
|  | epochs = 10 | 
|  | self.opt.param_groups[0]['lr'] = 0.05 | 
|  | self.opt.param_groups[1]['lr'] = 0.4 | 
|  | targets = [[0.05 * (0.9 ** x) for x in range(epochs)], [0.4 * (0.8 ** x) for x in range(epochs)]] | 
|  | scheduler = MultiplicativeLR(self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]) | 
|  | self._test(scheduler, targets, epochs) | 
|  |  | 
|  | @parametrize("T_mult", [1, 2, 4]) | 
|  | def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): | 
|  | iters = 100 | 
|  | eta_min = 1e-10 | 
|  | T_i = 10 | 
|  | T_cur = 0 | 
|  | targets = [[0.05], [0.5]] | 
|  | scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min) | 
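|  | # Reproduce the warm-restart bookkeeping: T_cur counts epochs since the last restart |
|  | # and, once it reaches the current period T_i, the period is multiplied by T_mult. |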
|  | for _ in range(1, iters, 1): | 
|  | T_cur += 1 | 
|  | if T_cur >= T_i: | 
|  | T_cur = T_cur - T_i | 
|  | T_i = int(T_mult) * T_i | 
|  | targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | self._test(scheduler, targets, iters) | 
|  |  | 
|  | def test_CosineAnnealingWarmRestarts_lr2(self): | 
|  | iters = 30 | 
|  | eta_min = 1e-10 | 
|  | T_mults = [1, 2, 4] | 
|  | for T_mult in T_mults: | 
|  | T_i = 10 | 
|  | T_cur = 0 | 
|  | targets = [[0.05], [0.5]] | 
|  | scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min) | 
|  | for _ in torch.arange(0.1, iters, 0.1): | 
|  | T_cur = round(T_cur + 0.1, 1) | 
|  | if T_cur >= T_i: | 
|  | T_cur = T_cur - T_i | 
|  | T_i = int(T_mult) * T_i | 
|  | targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters) | 
|  |  | 
|  | def test_CosineAnnealingWarmRestarts_lr3(self): | 
|  | epochs_for_T_mults = [[0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], | 
|  | [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], | 
|  | [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50]] | 
|  | T_curs_for_T_mults = [[1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], | 
|  | [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], | 
|  | [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10]] | 
|  | T_is_for_T_mults = [[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], | 
|  | [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], | 
|  | [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90]] | 
|  | eta_min = 1e-10 | 
|  | T_mults = [1, 2, 3] | 
|  | for epochs, T_mult, T_curs, T_is in zip(epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults): | 
|  | targets = [[0.05], [0.5]] | 
|  | scheduler = CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min) | 
|  | for T_cur, T_i in zip(T_curs, T_is): | 
|  | targets[0] += [eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | targets[1] += [eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2] | 
|  | self._test_interleaved_CosineAnnealingWarmRestarts(scheduler, targets, epochs) | 
|  |  | 
|  | def test_swalr_no_anneal(self): | 
|  | epochs, swa_start, swa_lr = 10, 5, 0.01 | 
|  | initial_lrs = [group['lr'] for group in self.opt.param_groups] | 
|  | targets = [[lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) | 
|  | for lr in initial_lrs] | 
|  | swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) | 
|  | self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) | 
|  |  | 
|  | def test_swalr_cosine_anneal_after_multiplicative(self): | 
|  | # same swa_lr for different param_groups | 
|  | epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5 | 
|  | mult_factor = 0.9 | 
|  | scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) | 
|  | swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr) | 
|  |  | 
|  | def anneal_coef(t): | 
|  | if t + 1 >= anneal_epochs: | 
|  | return 0. | 
|  | return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 | 
|  |  | 
|  | initial_lrs = [group['lr'] for group in self.opt.param_groups] | 
|  | targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)] | 
|  | for lr in initial_lrs] | 
|  | swa_epochs = epochs - swa_start - 1 | 
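|  | # During the SWA phase the lr is a convex combination of the last pre-SWA lr and swa_lr; |
|  | # the cosine weight above reaches 0 after anneal_epochs, at which point lr == swa_lr. |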
|  | targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)] | 
|  | for lrs in targets_before_swa] | 
|  |  | 
|  | self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) | 
|  |  | 
|  | def test_swalr_linear_anneal_after_multiplicative(self): | 
|  | # separate swa_lr for different param_groups | 
|  | epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4 | 
|  | mult_factor = 0.9 | 
|  | scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) | 
|  | swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, | 
|  | anneal_strategy="linear", swa_lr=swa_lrs) | 
|  |  | 
|  | def anneal_coef(t): | 
|  | if t + 1 >= anneal_epochs: | 
|  | return 0. | 
|  | return 1 - (t + 1) / anneal_epochs | 
|  |  | 
|  | initial_lrs = [group['lr'] for group in self.opt.param_groups] | 
|  | targets_before_swa = [[lr * mult_factor**i for i in range(swa_start + 1)] | 
|  | for lr in initial_lrs] | 
|  | swa_epochs = epochs - swa_start - 1 | 
|  | targets = [lrs + [lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) for t in range(swa_epochs)] | 
|  | for lrs, swa_lr in zip(targets_before_swa, swa_lrs)] | 
|  |  | 
|  | self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) | 
|  |  | 
|  | def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): | 
|  | for epoch in range(epochs): | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[epoch], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) | 
|  | if epoch >= swa_start: | 
|  | self.opt.step() | 
|  | swa_scheduler.step() | 
|  | elif scheduler is not None: | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  |  | 
|  | def test_swalr_hypers(self): | 
|  | # Test that SWALR raises errors for incorrect hyper-parameters | 
|  | with self.assertRaisesRegex(ValueError, "anneal_strategy must"): | 
|  | swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.) | 
|  |  | 
|  | with self.assertRaisesRegex(ValueError, "anneal_epochs must"): | 
|  | swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.) | 
|  | with self.assertRaisesRegex(ValueError, "anneal_epochs must"): | 
|  | swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.) | 
|  | with self.assertRaisesRegex(ValueError, "swa_lr must"): | 
|  | swa_scheduler = SWALR(self.opt, swa_lr=[1., 0.1, 0.01]) | 
|  |  | 
|  | def test_step_lr_state_dict(self): | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: StepLR(self.opt, gamma=0.1, step_size=3), | 
|  | lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1)) | 
|  |  | 
|  | def test_multi_step_lr_state_dict(self): | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]), | 
|  | lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6])) | 
|  |  | 
|  | def test_exp_step_lr_state_dict(self): | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: ExponentialLR(self.opt, gamma=0.1), | 
|  | lambda: ExponentialLR(self.opt, gamma=0.01)) | 
|  |  | 
|  | def test_cosine_lr_state_dict(self): | 
|  | epochs = 10 | 
|  | eta_min = 1e-10 | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min), | 
|  | lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2), | 
|  | epochs=epochs) | 
|  |  | 
|  | def test_reduce_lr_on_plateau_state_dict(self): | 
|  | scheduler = ReduceLROnPlateau(self.opt, mode='min', factor=0.1, patience=2) | 
|  | for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]: | 
|  | scheduler.step(score) | 
|  | scheduler_copy = ReduceLROnPlateau(self.opt, mode='max', factor=0.5, patience=10) | 
|  | scheduler_copy.load_state_dict(scheduler.state_dict()) | 
|  | for key in scheduler.__dict__.keys(): | 
|  | if key not in {'optimizer', 'is_better'}: | 
|  | self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) | 
|  |  | 
|  | def test_lambda_lr_state_dict_fn(self): | 
|  | scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x) | 
|  | state = scheduler.state_dict() | 
|  | self.assertIsNone(state['lr_lambdas'][0]) | 
|  |  | 
|  | scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) | 
|  | scheduler_copy.load_state_dict(state) | 
|  | for key in scheduler.__dict__.keys(): | 
|  | if key not in {'optimizer', 'lr_lambdas'}: | 
|  | self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) | 
|  |  | 
|  | def test_lambda_lr_state_dict_obj(self): | 
|  | scheduler = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(10)) | 
|  | state = scheduler.state_dict() | 
|  | self.assertIsNotNone(state['lr_lambdas'][0]) | 
|  |  | 
|  | scheduler_copy = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(-1)) | 
|  | scheduler_copy.load_state_dict(state) | 
|  | for key in scheduler.__dict__.keys(): | 
|  | if key not in {'optimizer'}: | 
|  | self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) | 
|  |  | 
|  | def test_CosineAnnealingWarmRestarts_lr_state_dict(self): | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2), | 
|  | lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100)) | 
|  |  | 
|  | def test_swa_lr_state_dict(self): | 
|  | self._check_scheduler_state_dict( | 
|  | lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), | 
|  | lambda: SWALR(self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.)) | 
|  |  | 
|  | def _check_scheduler_state_dict(self, constr, constr2, epochs=10): | 
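|  | # Run the first scheduler for `epochs` steps, load its state_dict into a |
|  | # differently configured scheduler of the same class, and check that every |
|  | # attribute except the optimizer reference matches, as does get_last_lr(). |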
|  | scheduler = constr() | 
|  | for _ in range(epochs): | 
|  | scheduler.optimizer.step() | 
|  | scheduler.step() | 
|  | scheduler_copy = constr2() | 
|  | scheduler_copy.load_state_dict(scheduler.state_dict()) | 
|  | for key in scheduler.__dict__.keys(): | 
|  | if key != 'optimizer': | 
|  | self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) | 
|  | self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) | 
|  |  | 
|  | def _test_get_last_lr(self, schedulers, targets, epochs=10): | 
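|  | # get_last_lr() is queried before stepping, so at epoch `e` it should equal |
|  | # targets[...][e] (the initial lr for epoch 0). |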
|  | if isinstance(schedulers, LRScheduler): | 
|  | schedulers = [schedulers] | 
|  | optimizers = {scheduler.optimizer for scheduler in schedulers} | 
|  | for epoch in range(epochs): | 
|  | result = [scheduler.get_last_lr() for scheduler in schedulers] | 
|  | [optimizer.step() for optimizer in optimizers] | 
|  | [scheduler.step() for scheduler in schedulers] | 
|  | target = [[t[epoch] for t in targets]] * len(schedulers) | 
|  | for t, r in zip(target, result): | 
|  | self.assertEqual(t, r, |
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, t, r), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test_with_epoch(self, schedulers, targets, epochs=10): | 
|  | if isinstance(schedulers, LRScheduler): | 
|  | schedulers = [schedulers] | 
|  | optimizers = {scheduler.optimizer for scheduler in schedulers} | 
|  | for epoch in range(epochs): | 
|  | [optimizer.step() for optimizer in optimizers] | 
|  | with warnings.catch_warnings(record=True) as w: | 
|  | [scheduler.step(epoch) for scheduler in schedulers]  # step before assert: skip initial lr | 
|  | self._check_warning_is_epoch_deprecation_warning(w, num_warnings=len(schedulers)) | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[epoch], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test(self, schedulers, targets, epochs=10): | 
|  | if isinstance(schedulers, LRScheduler): | 
|  | schedulers = [schedulers] | 
|  | for epoch in range(epochs): | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[epoch], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) | 
|  | [scheduler.step() for scheduler in schedulers] | 
|  |  | 
|  | def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): | 
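|  | # CosineAnnealingWarmRestarts supports fractional epochs; step in increments |
|  | # of 0.1 and check the lr against the target computed for that same epoch. |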
|  | for index, epoch in enumerate(torch.arange(0, epochs, 0.1)): | 
|  | epoch = round(epoch.item(), 1) | 
|  | scheduler.step(epoch) | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[index], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[index], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs): | 
|  | for index, epoch in enumerate(epochs): | 
|  | scheduler.step(epoch) | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[index], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[index], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): | 
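|  | # Build per-epoch targets by driving the closed-form scheduler with explicit |
|  | # epoch arguments, reset the optimizer via setUp(), then check that the |
|  | # ordinary step()-based scheduler reproduces the same lrs. |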
|  | self.setUp() | 
|  | targets = [] | 
|  | for epoch in range(epochs): | 
|  | closed_form_scheduler.optimizer.step() | 
|  | with warnings.catch_warnings(record=True) as w: | 
|  | closed_form_scheduler.step(epoch) | 
|  | self._check_warning_is_epoch_deprecation_warning(w) | 
|  | targets.append([group['lr'] for group in self.opt.param_groups]) | 
|  | self.setUp() | 
|  | for epoch in range(epochs): | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  | for i, param_group in enumerate(self.opt.param_groups): | 
|  | self.assertEqual(targets[epoch][i], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, targets[epoch][i], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, verbose=False): | 
|  | if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)): |
|  | schedulers = [schedulers] | 
|  | for epoch in range(epochs): | 
|  | self.opt.step() | 
|  | for scheduler in schedulers: | 
|  | if isinstance(scheduler, ReduceLROnPlateau): | 
|  | scheduler.step(metrics[epoch]) | 
|  | else: | 
|  | scheduler.step() | 
|  | if verbose: | 
|  | print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr'])) | 
|  | for param_group, target in zip(self.opt.param_groups, targets): | 
|  | self.assertEqual(target[epoch], param_group['lr'], | 
|  | msg='LR is wrong in epoch {}: expected {}, got {}'.format( | 
|  | epoch, target[epoch], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False): | 
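|  | # Per-batch driver: check lr and momentum (or beta1 when use_beta1 is set, as |
|  | # with Adam-style optimizers) against the target schedules, then step the |
|  | # optimizer and scheduler. |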
|  | for batch_num in range(batch_iterations): | 
|  | if verbose: | 
|  | if 'momentum' in self.opt.param_groups[0].keys(): | 
|  | print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'], | 
|  | self.opt.param_groups[0]['momentum'])) | 
|  | elif use_beta1 and 'betas' in self.opt.param_groups[0].keys(): | 
|  | print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'], | 
|  | self.opt.param_groups[0]['betas'][0])) | 
|  | else: | 
|  | print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr'])) | 
|  |  | 
|  | for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets): | 
|  | self.assertEqual( | 
|  | lr_target[batch_num], param_group['lr'], | 
|  | msg='LR is wrong in batch_num {}: expected {}, got {}'.format( | 
|  | batch_num, lr_target[batch_num], param_group['lr']), atol=1e-5, rtol=0) | 
|  |  | 
|  | if use_beta1 and 'betas' in param_group.keys(): | 
|  | self.assertEqual( | 
|  | momentum_target[batch_num], param_group['betas'][0], | 
|  | msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format( | 
|  | batch_num, momentum_target[batch_num], param_group['betas'][0]), atol=1e-5, rtol=0) | 
|  | elif 'momentum' in param_group.keys(): | 
|  | self.assertEqual( | 
|  | momentum_target[batch_num], param_group['momentum'], | 
|  | msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format( | 
|  | batch_num, momentum_target[batch_num], param_group['momentum']), atol=1e-5, rtol=0) | 
|  | self.opt.step() | 
|  | scheduler.step() | 
|  |  | 
|  | def test_cosine_then_cyclic(self): | 
|  | # https://github.com/pytorch/pytorch/issues/21965 | 
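|  | # Regression test: after handing control from CosineAnnealingLR to CyclicLR, |
|  | # the final lr must not exceed CyclicLR's max_lr, i.e. it must not jump back |
|  | # towards the optimizer's original lr of 0.5. |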
|  |  | 
|  | max_lr = 0.3 | 
|  | base_lr = 0.1 | 
|  | optim_lr = 0.5 | 
|  |  | 
|  | model = torch.nn.Linear(2, 1) | 
|  | optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr) | 
|  | lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1) | 
|  | lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( | 
|  | optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 | 
|  | ) | 
|  |  | 
|  | for i in range(40): | 
|  | optimizer.step() | 
|  | if i <= lr_scheduler_1.T_max: | 
|  | lr_scheduler_1.step() | 
|  | else: | 
|  | lr_scheduler_2.step() | 
|  | last_lr = optimizer.param_groups[0]["lr"] | 
|  |  | 
|  | self.assertLessEqual(last_lr, max_lr) | 
|  |  | 
|  |  | 
|  | class SWATestDNN(torch.nn.Module): | 
|  | def __init__(self, input_features): | 
|  | super(SWATestDNN, self).__init__() | 
|  | self.n_features = 100 | 
|  | self.fc1 = torch.nn.Linear(input_features, self.n_features) | 
|  | self.bn = torch.nn.BatchNorm1d(self.n_features) | 
|  |  | 
|  | def compute_preactivation(self, x): | 
|  | return self.fc1(x) | 
|  |  | 
|  | def forward(self, x): | 
|  | x = self.fc1(x) | 
|  | x = self.bn(x) | 
|  | return x | 
|  |  | 
|  |  | 
|  | class SWATestCNN(torch.nn.Module): | 
|  | def __init__(self, input_channels): | 
|  | super(SWATestCNN, self).__init__() | 
|  | self.n_features = 10 | 
|  | self.conv1 = torch.nn.Conv2d(input_channels, self.n_features, kernel_size=3, padding=1) | 
|  | self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3) | 
|  |  | 
|  | def compute_preactivation(self, x): | 
|  | return self.conv1(x) | 
|  |  | 
|  | def forward(self, x): | 
|  | x = self.conv1(x) | 
|  | x = self.bn(x) | 
|  | return x | 
|  |  | 
|  |  | 
|  | class TestSWAUtils(TestCase): | 
|  |  | 
|  | def _test_averaged_model(self, net_device, swa_device): | 
|  | dnn = torch.nn.Sequential( | 
|  | torch.nn.Conv2d(1, 5, kernel_size=3), | 
|  | torch.nn.ReLU(), | 
|  | torch.nn.MaxPool2d(kernel_size=2), | 
|  | torch.nn.BatchNorm2d(5, momentum=0.3), | 
|  | torch.nn.Conv2d(5, 2, kernel_size=3), | 
|  | torch.nn.ReLU(), | 
|  | torch.nn.Linear(5, 5), | 
|  | torch.nn.ReLU(), | 
|  | torch.nn.Linear(5, 10) | 
|  | ).to(net_device) | 
|  |  | 
|  | averaged_dnn = AveragedModel(dnn, device=swa_device) | 
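|  | # AveragedModel's default avg_fn keeps an equal-weight running mean, so after |
|  | # n_updates calls to update_parameters the SWA parameters should equal the |
|  | # plain average of the observed parameters, accumulated by hand in |
|  | # averaged_params. |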
|  | averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] | 
|  | n_updates = 10 | 
|  | for i in range(n_updates): | 
|  | for p, p_avg in zip(dnn.parameters(), averaged_params): | 
|  | p.detach().add_(torch.randn_like(p)) | 
|  | p_avg += p.detach() / n_updates | 
|  | averaged_dnn.update_parameters(dnn) | 
|  |  | 
|  | for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): | 
|  | self.assertEqual(p_avg, p_swa) | 
|  | # Check that AveragedModel is on the correct device | 
|  | self.assertTrue(p_swa.device == swa_device) | 
|  | self.assertTrue(p.device == net_device) | 
|  | self.assertTrue(averaged_dnn.n_averaged.device == swa_device) | 
|  |  | 
|  | def test_averaged_model_all_devices(self): | 
|  | cpu = torch.device("cpu") | 
|  | self._test_averaged_model(cpu, cpu) | 
|  | if torch.cuda.is_available(): | 
|  | cuda = torch.device(0) | 
|  | self._test_averaged_model(cuda, cpu) | 
|  | self._test_averaged_model(cpu, cuda) | 
|  | self._test_averaged_model(cuda, cuda) | 
|  |  | 
|  | def test_averaged_model_mixed_device(self): | 
|  | if not torch.cuda.is_available(): | 
|  | return | 
|  | dnn = torch.nn.Sequential( | 
|  | torch.nn.Conv2d(1, 5, kernel_size=3), | 
|  | torch.nn.Linear(5, 10) | 
|  | ) | 
|  | dnn[0].cuda() | 
|  | dnn[1].cpu() | 
|  | averaged_dnn = AveragedModel(dnn) | 
|  | averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] | 
|  | n_updates = 10 | 
|  | for i in range(n_updates): | 
|  | for p, p_avg in zip(dnn.parameters(), averaged_params): | 
|  | p.detach().add_(torch.randn_like(p)) | 
|  | p_avg += p.detach() / n_updates | 
|  | averaged_dnn.update_parameters(dnn) | 
|  |  | 
|  | for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): | 
|  | self.assertEqual(p_avg, p_swa) | 
|  | # Check that AveragedModel is on the correct device | 
|  | self.assertTrue(p_avg.device == p_swa.device) | 
|  |  | 
|  | def test_averaged_model_state_dict(self): | 
|  | dnn = torch.nn.Sequential( | 
|  | torch.nn.Conv2d(1, 5, kernel_size=3), | 
|  | torch.nn.Linear(5, 10) | 
|  | ) | 
|  | averaged_dnn = AveragedModel(dnn) | 
|  | averaged_dnn2 = AveragedModel(dnn) | 
|  | n_updates = 10 | 
|  | for i in range(n_updates): | 
|  | for p in dnn.parameters(): | 
|  | p.detach().add_(torch.randn_like(p)) | 
|  | averaged_dnn.update_parameters(dnn) | 
|  | averaged_dnn2.load_state_dict(averaged_dnn.state_dict()) | 
|  | for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()): | 
|  | self.assertEqual(p_swa, p_swa2) | 
|  | self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged) | 
|  |  | 
|  | def test_averaged_model_exponential(self): | 
|  | # Test AveragedModel with EMA as avg_fn | 
|  | dnn = torch.nn.Sequential( | 
|  | torch.nn.Conv2d(1, 5, kernel_size=3), | 
|  | torch.nn.BatchNorm2d(5, momentum=0.3), | 
|  | torch.nn.Linear(5, 10) | 
|  | ) | 
|  | alpha = 0.9 | 
|  |  | 
|  | def avg_fn(p_avg, p, n_avg): | 
|  | return alpha * p_avg + (1 - alpha) * p | 
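|  | # avg_fn implements an exponential moving average. AveragedModel copies the |
|  | # model parameters on the very first update, hence the special-casing of |
|  | # i == 0 when building the expected values below. |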
|  | averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn) | 
|  | averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] | 
|  | n_updates = 10 | 
|  | for i in range(n_updates): | 
|  | updated_averaged_params = [] | 
|  | for p, p_avg in zip(dnn.parameters(), averaged_params): | 
|  | p.detach().add_(torch.randn_like(p)) | 
|  | if i == 0: | 
|  | updated_averaged_params.append(p.clone()) | 
|  | else: | 
|  | updated_averaged_params.append((p_avg * alpha + | 
|  | p * (1 - alpha)).clone()) | 
|  | for b in dnn.buffers(): | 
|  | if b.size() != torch.Size([]): | 
|  | b.detach_().add_(torch.randn_like(b)) | 
|  |  | 
|  | averaged_dnn.update_parameters(dnn) | 
|  | averaged_params = updated_averaged_params | 
|  |  | 
|  | for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): | 
|  | self.assertEqual(p_avg, p_swa) | 
|  | for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()): | 
|  | self.assertEqual(b_avg, b_swa) | 
|  |  | 
|  | def test_averaged_model_exponential_buffers(self): | 
|  | # Test AveragedModel with EMA as avg_fn and use_buffers as True. | 
|  | dnn = torch.nn.Sequential( | 
|  | torch.nn.Conv2d(1, 5, kernel_size=3), | 
|  | torch.nn.BatchNorm2d(5, momentum=0.3), | 
|  | torch.nn.Linear(5, 10) | 
|  | ) | 
|  | alpha = 0.9 | 
|  |  | 
|  | def avg_fn(p_avg, p, n_avg): | 
|  | return alpha * p_avg + (1 - alpha) * p | 
|  | averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True) | 
|  | dnn_params = itertools.chain(dnn.parameters(), dnn.buffers()) | 
|  | averaged_params = [torch.zeros_like(param) for param in dnn_params | 
|  | if param.size() != torch.Size([])] | 
|  | n_updates = 10 | 
|  | for i in range(n_updates): | 
|  | updated_averaged_params = [] | 
|  | for p, p_avg in zip(dnn_params, averaged_params): | 
|  | if p.size() == torch.Size([]): | 
|  | continue | 
|  | p.detach().add_(torch.randn_like(p)) | 
|  | if i == 0: | 
|  | updated_averaged_params.append(p.clone()) | 
|  | else: | 
|  | updated_averaged_params.append((p_avg * alpha + | 
|  | p * (1 - alpha)).clone()) | 
|  | averaged_dnn.update_parameters(dnn) | 
|  | averaged_params = updated_averaged_params | 
|  |  | 
|  | for p_avg, p_swa in zip( | 
|  | averaged_params, itertools.chain(averaged_dnn.module.parameters(), averaged_dnn.module.buffers())): | 
|  | self.assertEqual(p_avg, p_swa) | 
|  |  | 
|  | def _test_update_bn(self, dnn, dl_x, dl_xy, cuda): | 
|  |  | 
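|  | # Accumulate the dataset-level mean and variance of the BN input (the fc/conv |
|  | # preactivations) by hand, then check that update_bn recovers them in |
|  | # bn.running_mean / bn.running_var (the variance only up to a loose tolerance). |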
|  | preactivation_sum = torch.zeros(dnn.n_features) | 
|  | preactivation_squared_sum = torch.zeros(dnn.n_features) | 
|  | if cuda: | 
|  | preactivation_sum = preactivation_sum.cuda() | 
|  | preactivation_squared_sum = preactivation_squared_sum.cuda() | 
|  | total_num = 0 | 
|  | for x in dl_x: | 
|  | x = x[0] | 
|  | if cuda: | 
|  | x = x.cuda() | 
|  |  | 
|  | dnn.forward(x) | 
|  | preactivations = dnn.compute_preactivation(x) | 
|  | if len(preactivations.shape) == 4: | 
|  | preactivations = preactivations.transpose(1, 3) | 
|  | preactivations = preactivations.contiguous().view(-1, dnn.n_features) | 
|  | total_num += preactivations.shape[0] | 
|  |  | 
|  | preactivation_sum += torch.sum(preactivations, dim=0) | 
|  | preactivation_squared_sum += torch.sum(preactivations**2, dim=0) | 
|  |  | 
|  | preactivation_mean = preactivation_sum / total_num | 
|  | preactivation_var = preactivation_squared_sum / total_num | 
|  | preactivation_var = preactivation_var - preactivation_mean**2 | 
|  |  | 
|  | update_bn(dl_xy, dnn, device=x.device) | 
|  | self.assertEqual(preactivation_mean, dnn.bn.running_mean) | 
|  | self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) | 
|  |  | 
|  | def _reset_bn(module): | 
|  | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): |
|  | module.running_mean = torch.zeros_like(module.running_mean) | 
|  | module.running_var = torch.ones_like(module.running_var) | 
|  | # reset batch norm and run update_bn again | 
|  | dnn.apply(_reset_bn) | 
|  | update_bn(dl_xy, dnn, device=x.device) | 
|  | self.assertEqual(preactivation_mean, dnn.bn.running_mean) | 
|  | self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) | 
|  | # reset again and run update_bn with the inputs-only dl_x loader instead of dl_xy |
|  | dnn.apply(_reset_bn) | 
|  | update_bn(dl_x, dnn, device=x.device) | 
|  | self.assertEqual(preactivation_mean, dnn.bn.running_mean) | 
|  | self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) | 
|  |  | 
|  | def test_update_bn_dnn(self): | 
|  | # Test update_bn for a fully-connected network with BatchNorm1d | 
|  | objects, input_features = 100, 5 | 
|  | x = torch.rand(objects, input_features) | 
|  | y = torch.rand(objects) | 
|  | ds_x = torch.utils.data.TensorDataset(x) | 
|  | ds_xy = torch.utils.data.TensorDataset(x, y) | 
|  | dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) | 
|  | dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) | 
|  | dnn = SWATestDNN(input_features=input_features) | 
|  | dnn.train() | 
|  | self._test_update_bn(dnn, dl_x, dl_xy, False) | 
|  | if torch.cuda.is_available(): | 
|  | dnn = SWATestDNN(input_features=input_features) | 
|  | dnn.train() | 
|  | self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True) | 
|  | self.assertTrue(dnn.training) | 
|  |  | 
|  | def test_update_bn_cnn(self): | 
|  | # Test update_bn for a convolutional network with BatchNorm2d |
|  | objects = 100 | 
|  | input_channels = 3 | 
|  | height, width = 5, 5 | 
|  | x = torch.rand(objects, input_channels, height, width) | 
|  | y = torch.rand(objects) | 
|  | ds_x = torch.utils.data.TensorDataset(x) | 
|  | ds_xy = torch.utils.data.TensorDataset(x, y) | 
|  | dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) | 
|  | dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) | 
|  | dnn = SWATestCNN(input_channels=input_channels) | 
|  | dnn.train() | 
|  | self._test_update_bn(dnn, dl_x, dl_xy, False) | 
|  | if torch.cuda.is_available(): | 
|  | dnn = SWATestCNN(input_channels=input_channels) | 
|  | dnn.train() | 
|  | self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True) | 
|  | self.assertTrue(dnn.training) | 
|  |  | 
|  | def test_bn_update_eval_momentum(self): | 
|  | # check that update_bn preserves eval mode | 
|  | objects = 100 | 
|  | input_channels = 3 | 
|  | height, width = 5, 5 | 
|  | x = torch.rand(objects, input_channels, height, width) | 
|  | ds_x = torch.utils.data.TensorDataset(x) | 
|  | dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) | 
|  | dnn = SWATestCNN(input_channels=input_channels) | 
|  | dnn.eval() | 
|  | update_bn(dl_x, dnn) | 
|  | self.assertFalse(dnn.training) | 
|  |  | 
|  | # check that momentum is preserved | 
|  | self.assertEqual(dnn.bn.momentum, 0.3) | 
|  |  | 
|  |  | 
|  | instantiate_parametrized_tests(TestLRScheduler) | 
|  |  | 
|  |  | 
|  | def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored): | 
|  | # `*ignored` receives the values of `opt_differentiable_state`. Passing them |
|  | # explicitly lets `gradcheck` track the state tensors as function inputs, |
|  | # since it cannot unpack the values inside the `opt_differentiable_state` |
|  | # dict on its own. |
|  | p = p.clone() | 
|  | p.grad = grad | 
|  | opt_differentiable_state = { | 
|  | k: v.clone() if isinstance(v, torch.Tensor) else v | 
|  | for k, v in opt_differentiable_state.items() | 
|  | } | 
|  | opt = opt_class([p], **kwargs) | 
|  | opt.state[p].update(opt_differentiable_state) | 
|  | opt.step() | 
|  | return (p,) + tuple( | 
|  | v for v in opt.state[p].values() if isinstance(v, torch.Tensor) and v.requires_grad) | 
|  |  | 
|  |  | 
|  | class TestDifferentiableOptimizer(TestCase): | 
|  |  | 
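|  | # Each test below seeds a float64 parameter, gradient, and optimizer state, |
|  | # and uses gradcheck to verify that a single step() of the optimizer |
|  | # (constructed with differentiable=True) is differentiable w.r.t. those inputs. |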
|  | def test_sgd(self): | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state = {'momentum_buffer': mbuff} | 
|  | gradcheck(_diff_fn, (p, grad, state, torch.optim.SGD, {'lr': 0.9, 'differentiable': True}, *state.values())) | 
|  |  | 
|  | def test_adam(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['max_exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.Adam, | 
|  | {'lr': 0.9, 'differentiable': True, 'amsgrad': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_rmsprop(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['step'] = 0 | 
|  | state['square_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['momentum_buffer'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # Scale grad_avg down: with centered=True the denominator uses |
|  | # sqrt(square_avg - grad_avg**2), which can go negative (NaN) for large grad_avg |
|  | state['grad_avg'] = 1e-2 * torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.RMSprop, | 
|  | {'lr': 0.9, 'maximize': True, 'momentum': 0.9, 'differentiable': True, 'centered': True, 'weight_decay': 0.1}, | 
|  | *state.values())) | 
|  |  | 
|  | def test_adadelta(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['square_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['acc_delta'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.Adadelta, | 
|  | {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_adagrad(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['sum'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.Adagrad, | 
|  | {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_adamax(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['exp_inf'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.Adamax, | 
|  | {'lr': 0.9, 'weight_decay': 0.1, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_asgd(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step`, `eta`, and `mu` are not continuous variables (even though we define them as floats) |
|  | # and so they shouldn't require gradients. |
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['eta'] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) | 
|  | state['mu'] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) | 
|  | state['ax'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.ASGD, | 
|  | {'lr': 0.9, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_rprop(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['prev'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['step_size'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.Rprop, | 
|  | {'lr': 0.9, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  |
|  | def test_adamw(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['max_exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.AdamW, | 
|  | {'lr': 0.9, 'differentiable': True, 'amsgrad': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_nadam(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['mu_product'] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.NAdam, | 
|  | {'lr': 0.9, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  | def test_radam(self): | 
|  | state = {} | 
|  | p = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | grad = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | # `step` is not a continuous variable (even though we define it as a float) | 
|  | # and so it shouldn't require gradients. | 
|  | state['step'] = torch.tensor(10., requires_grad=False, dtype=torch.float64) | 
|  | state['exp_avg'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  | state['exp_avg_sq'] = torch.rand(10, requires_grad=True, dtype=torch.float64) | 
|  |  | 
|  | gradcheck( | 
|  | _diff_fn, | 
|  | (p, grad, state, torch.optim.RAdam, | 
|  | {'lr': 0.9, 'differentiable': True}, *state.values()) | 
|  | ) | 
|  |  | 
|  |  | 
|  | if __name__ == '__main__': | 
|  | run_tests() |