| # Owner(s): ["module: optimizer"] |
| |
| import warnings |
| import math |
| import unittest |
| import functools |
| import itertools |
| from copy import deepcopy |
| |
| import torch |
| import torch.optim as optim |
| import torch.nn.functional as F |
| from torch.nn import Parameter |
| from torch.optim import Adam, SGD, Optimizer |
| from torch import sparse |
| from torch.optim.lr_scheduler import ( |
| LambdaLR, |
| MultiplicativeLR, |
| SequentialLR, |
| StepLR, |
| MultiStepLR, |
| ConstantLR, |
| LinearLR, |
| ExponentialLR, |
| CosineAnnealingLR, |
| ReduceLROnPlateau, |
| LRScheduler, |
| CyclicLR, |
| CosineAnnealingWarmRestarts, |
| OneCycleLR, |
| ChainedScheduler, |
| PolynomialLR, |
| EPOCH_DEPRECATION_WARNING, |
| ) |
| from torch.optim.swa_utils import AveragedModel, SWALR, update_bn |
| from torch.testing._internal.common_utils import ( |
| TestCase, |
| run_tests, |
| TEST_WITH_UBSAN, |
| load_tests, |
| parametrize, |
| instantiate_parametrized_tests, |
| gradcheck, |
| skipIfRocm, |
| skipIfTorchDynamo |
| ) |
| from typing import Dict, Any, Tuple |
| from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook |
| |
| # load_tests from common_utils is used to automatically filter tests for |
| # sharding on sandcastle. This line silences flake warnings |
| load_tests = load_tests |
| |
| |
| def rosenbrock(tensor): |
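    """2D Rosenbrock function; its global minimum of 0 is at (x, y) = (1, 1)."""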
| x, y = tensor |
| return (1 - x) ** 2 + 100 * (y - x**2) ** 2 |
| |
| |
| def drosenbrock(tensor): |
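    """Analytic gradient of ``rosenbrock``: returns (df/dx, df/dy) as a tensor."""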
| x, y = tensor |
| return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))) |
| |
| |
| class TestOptim(TestCase): |
| exact_dtype = True |
| |
| def _test_rosenbrock_sparse( |
| self, |
| constructor, |
| scheduler_constructors=None, |
| sparse_only=False, |
| maximize=False, |
| ): |
| if scheduler_constructors is None: |
| scheduler_constructors = [] |
| params_t = torch.tensor([1.5, 1.5]) |
| |
| params = Parameter(params_t) |
| optimizer = constructor([params]) |
| schedulers = [] |
| for scheduler_constructor in scheduler_constructors: |
| schedulers.append(scheduler_constructor(optimizer)) |
| |
| if not sparse_only: |
| params_c = Parameter(params_t.clone()) |
| optimizer_c = constructor([params_c]) |
| |
| solution = torch.tensor([1, 1]) |
| with torch.no_grad(): |
| initial_dist = params.dist(solution) |
| |
| def eval(params, sparse_grad, w): |
| # Depending on w, provide only the x or y gradient |
| optimizer.zero_grad() |
| loss = rosenbrock(params) |
| loss.backward() |
| grad = drosenbrock(params.data) |
| # NB: We torture test the optimizer by returning an |
| # uncoalesced sparse tensor |
| if w: |
| i = torch.LongTensor([[0, 0]]) |
| x = grad[0] |
| v = torch.tensor([x / 4.0, x - x / 4.0]) |
| else: |
| i = torch.LongTensor([[1, 1]]) |
| y = grad[1] |
| v = torch.tensor([y - y / 4.0, y / 4.0]) |
| x = sparse.DoubleTensor(i, v, torch.Size([2])).to(dtype=v.dtype) |
| with torch.no_grad(): |
| if sparse_grad: |
| params.grad = x |
| else: |
| params.grad = x.to_dense() |
| return loss |
| |
| for i in range(2000): |
| # Do cyclic coordinate descent |
| w = i % 2 |
| optimizer.step(functools.partial(eval, params, True, w)) |
| for scheduler in schedulers: |
| if isinstance(scheduler, ReduceLROnPlateau): |
| scheduler.step(rosenbrock(params)) |
| else: |
| scheduler.step() |
| if not sparse_only: |
| optimizer_c.step(functools.partial(eval, params_c, False, w)) |
| self.assertEqual(params, params_c) |
| |
| if not maximize: |
| self.assertLessEqual(params.data.dist(solution), initial_dist) |
| else: |
| self.assertGreaterEqual(rosenbrock(params), rosenbrock(params_t)) |
| |
| def _test_basic_cases_template( |
| self, |
| weight_tensor, |
| bias_tensor, |
| input_tensor, |
| constructor, |
| scheduler_constructors, |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=False, |
| ): |
| maximize_options = set([False, constructor_accepts_maximize]) |
| foreach_options = set([False, constructor_accepts_foreach]) |
| |
| four_arg_constructor = constructor |
| if constructor_accepts_maximize and constructor_accepts_foreach: |
| pass |
| elif constructor_accepts_maximize: |
| |
| def four_arg_constructor(weight, bias, maximize, foreach): |
| self.assertFalse(foreach) |
| return constructor(weight, bias, maximize) |
| |
| elif constructor_accepts_foreach: |
| |
| def four_arg_constructor(weight, bias, maximize, foreach): |
| self.assertFalse(maximize) |
| return constructor(weight, bias, foreach) |
| |
| else: |
| |
| def four_arg_constructor(weight, bias, maximize, foreach): |
| self.assertFalse(maximize or foreach) |
| return constructor(weight, bias) |
| |
| for maximize, foreach in itertools.product(maximize_options, foreach_options): |
| with torch.no_grad(): |
| weight = Parameter(weight_tensor.clone().detach()) |
| bias = Parameter(bias_tensor.clone().detach()) |
| input = input_tensor.clone().detach().requires_grad_() |
| optimizer = four_arg_constructor(weight, bias, maximize, foreach) |
| schedulers = [] |
| for scheduler_constructor in scheduler_constructors: |
| schedulers.append(scheduler_constructor(optimizer)) |
| |
            # make sure the optimizer can be rendered as a string (repr must not raise)
| optimizer.__repr__() |
| |
| def fn(): |
| optimizer.zero_grad() |
| y = weight.mv(input) |
| if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device(): |
| y = y.cuda(bias.get_device()) |
| loss = (y + bias).pow(2).sum() |
| loss.backward() |
| return loss |
| |
| initial_value = fn().item() |
| for _ in range(200): |
| for scheduler in schedulers: |
| if isinstance(scheduler, ReduceLROnPlateau): |
| val_loss = fn() |
| scheduler.step(val_loss) |
| else: |
| scheduler.step() |
| optimizer.step(fn) |
| if maximize: |
| self.assertGreater(fn().item(), initial_value) |
| else: |
| self.assertLess(fn().item(), initial_value) |
| |
| def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None): |
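        """Check that ``load_state_dict`` reproduces optimizer behavior exactly:
        a fresh optimizer loaded from a primed optimizer's state dict must take
        identical steps, the loaded state dict must not be mutated, older state
        dicts (missing ``maximize``/``foreach``, or with a non-tensor ``step``)
        must still load, and reloading onto CUDA (when available) must keep
        ``step`` on the CPU.
        """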
| weight = Parameter(weight) |
| bias = Parameter(bias) |
| with torch.no_grad(): |
| input = input.clone().detach().requires_grad_() |
| |
| def fn_base(optimizer, weight, bias): |
| optimizer.zero_grad() |
| i = input_cuda if weight.is_cuda else input |
| loss = (weight.mv(i) + bias).pow(2).sum() |
| loss.backward() |
| return loss |
| |
| optimizer = constructor(weight, bias) |
| fn = functools.partial(fn_base, optimizer, weight, bias) |
| |
| # Prime the optimizer |
| for _i in range(20): |
| optimizer.step(fn) |
| # Clone the weights and construct new optimizer for them |
| with torch.no_grad(): |
| weight_c = Parameter(weight.clone().detach()) |
| bias_c = Parameter(bias.clone().detach()) |
| optimizer_c = constructor(weight_c, bias_c) |
| fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c) |
| # Load state dict |
| state_dict = deepcopy(optimizer.state_dict()) |
| state_dict_c = deepcopy(optimizer.state_dict()) |
| optimizer_c.load_state_dict(state_dict_c) |
| # Run both optimizations in parallel |
| for _ in range(20): |
| optimizer.step(fn) |
| optimizer_c.step(fn_c) |
| self.assertEqual(weight, weight_c) |
| self.assertEqual(bias, bias_c) |
| # Make sure state dict wasn't modified |
| self.assertEqual(state_dict, state_dict_c) |
| # Make sure state dict is deterministic with equal but not identical parameters |
| self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict()) |
| # Make sure repeated parameters have identical representation in state dict |
| optimizer_c.param_groups.extend(optimizer_c.param_groups) |
| self.assertEqual( |
| optimizer.state_dict()["param_groups"][-1], |
| optimizer_c.state_dict()["param_groups"][-1], |
| ) |
| |
| # Make sure that optimizers that support maximize can load older models |
| state_dict = optimizer.state_dict() |
| if "maximize" in state_dict["param_groups"][0]: |
| for group in state_dict["param_groups"]: |
| del group["maximize"] |
| optimizer.load_state_dict(state_dict) |
| # Make sure we can still step |
| optimizer.step() |
| # Make sure that optimizers that support foreach can load older models |
| state_dict = optimizer.state_dict() |
| if "foreach" in state_dict["param_groups"][0]: |
| for group in state_dict["param_groups"]: |
| del group["foreach"] |
| optimizer.load_state_dict(state_dict) |
| # Make sure we can still step |
| optimizer.step() |
| |
| # Make sure that loading optimizers with step not wrapped in tensor can work |
| state_dict = optimizer.state_dict() |
| if "step" in state_dict["state"][0] and torch.is_tensor( |
| state_dict["state"][0]["step"] |
| ): |
| for state in state_dict["state"].values(): |
| state["step"] = state["step"].item() |
| optimizer.load_state_dict(state_dict) |
| optimizer.step() |
| |
| # Check that state dict can be loaded even when we cast parameters |
| # to a different type and move to a different device. |
| if not torch.cuda.is_available(): |
| return |
| |
| with torch.no_grad(): |
| input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda") |
| weight_cuda = Parameter( |
| weight.clone().detach().to(dtype=torch.float32, device="cuda") |
| ) |
| bias_cuda = Parameter( |
| bias.clone().detach().to(dtype=torch.float32, device="cuda") |
| ) |
| optimizer_cuda = constructor(weight_cuda, bias_cuda) |
| fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda) |
| |
| state_dict = deepcopy(optimizer.state_dict()) |
| state_dict_c = deepcopy(optimizer.state_dict()) |
| optimizer_cuda.load_state_dict(state_dict_c) |
| |
| # Make sure state dict wasn't modified |
| self.assertEqual(state_dict, state_dict_c) |
| |
| # Make sure that device of state['step'] is still CPU |
| new_state_dict = optimizer_cuda.state_dict() |
| if "step" in state_dict["state"][0] and torch.is_tensor( |
| state_dict["state"][0]["step"] |
| ): |
| for state in new_state_dict["state"].values(): |
| self.assertEqual(state["step"].device.type, "cpu") |
| |
| for _i in range(20): |
| optimizer.step(fn) |
| optimizer_cuda.step(fn_cuda) |
| self.assertEqual(weight, weight_cuda) |
| self.assertEqual(bias, bias_cuda, atol=atol, rtol=rtol) |
| |
| # validate deepcopy() copies all public attributes |
| def getPublicAttr(obj): |
| return set(k for k in obj.__dict__ if not k.startswith("_")) |
| |
| self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer))) |
| |
| def _test_basic_cases( |
| self, |
| constructor, |
| scheduler_constructors=None, |
| ignore_multidevice=False, |
| constructor_accepts_maximize=False, |
| constructor_accepts_foreach=False, |
| atol=None, |
| rtol=None, |
| ): |
| if scheduler_constructors is None: |
| scheduler_constructors = [] |
| |
| def make_two_arg_constructor( |
| constructor, maximize: bool = False, foreach: bool = False |
| ): |
| if constructor_accepts_maximize and constructor_accepts_foreach: |
| return lambda weight, bias: constructor(weight, bias, maximize, foreach) |
| if constructor_accepts_maximize: |
| return lambda weight, bias: constructor(weight, bias, maximize) |
| if constructor_accepts_foreach: |
| return lambda weight, bias: constructor(weight, bias, foreach) |
| return constructor |
| |
| for maximize, foreach in itertools.product( |
| set([False, constructor_accepts_maximize]), |
| set([False, constructor_accepts_foreach]), |
| ): |
| self._test_state_dict( |
| torch.randn(10, 5), |
| torch.randn(10), |
| torch.randn(5), |
| make_two_arg_constructor(constructor, maximize, foreach), |
| atol=atol, |
| rtol=rtol, |
| ) |
| self._test_basic_cases_template( |
| torch.randn(10, 5), |
| torch.randn(10), |
| torch.randn(5), |
| constructor, |
| scheduler_constructors, |
| constructor_accepts_maximize, |
| constructor_accepts_foreach, |
| ) |
| # non-contiguous parameters |
| self._test_basic_cases_template( |
| torch.randn(10, 5, 2)[..., 0], |
| torch.randn(10, 2)[..., 0], |
| torch.randn(5), |
| constructor, |
| scheduler_constructors, |
| constructor_accepts_maximize, |
| constructor_accepts_foreach, |
| ) |
| # CUDA |
| if not torch.cuda.is_available(): |
| return |
| self._test_basic_cases_template( |
| torch.randn(10, 5).cuda(), |
| torch.randn(10).cuda(), |
| torch.randn(5).cuda(), |
| constructor, |
| scheduler_constructors, |
| constructor_accepts_maximize, |
| constructor_accepts_foreach, |
| ) |
| # Multi-GPU |
| if not torch.cuda.device_count() > 1 or ignore_multidevice: |
| return |
| self._test_basic_cases_template( |
| torch.randn(10, 5).cuda(0), |
| torch.randn(10).cuda(1), |
| torch.randn(5).cuda(0), |
| constructor, |
| scheduler_constructors, |
| constructor_accepts_maximize, |
| constructor_accepts_foreach, |
| ) |
| |
| def _test_complex_optimizer(self, optimizer_constructor): |
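        """Verify that stepping a complex parameter matches stepping its
        ``torch.view_as_real`` counterpart when both see the same random gradients.
        """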
| complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True) |
| real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_() |
| complex_opt = optimizer_constructor(complex_param) |
| real_opt = optimizer_constructor(real_param) |
| |
| for _ in range(3): |
| complex_param.grad = torch.randn_like(complex_param) |
| real_param.grad = torch.view_as_real(complex_param.grad) |
| complex_opt.step() |
| real_opt.step() |
| |
| self.assertEqual(torch.view_as_real(complex_param), real_param) |
| |
| def _test_complex_2d(self, optimizer_constructor, f=None): |
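        """Optimize a single complex parameter under ``f`` (default: ``rosenbrock``)
        and, in parallel, its real and imaginary parts as two separate real
        parameters; gradients and iterates must stay equal at every step.
        """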
| if f is None: |
| f = rosenbrock |
| a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True) |
| a1_real = a1.real.clone().detach() |
| a1_imag = a1.imag.clone().detach() |
| a1_real.requires_grad_() |
| a1_imag.requires_grad_() |
| optim1 = optimizer_constructor([a1]) |
| optim2 = optimizer_constructor([a1_real, a1_imag]) |
| |
| for _ in range(10): |
| optim1.zero_grad() |
| optim2.zero_grad() |
| a2 = torch.complex(a1_real, a1_imag) |
| f(a1).backward() |
| f(a2).backward() |
| |
| self.assertEqual(a1.grad.real, a1_real.grad) |
| self.assertEqual(a1.grad.imag, a1_imag.grad) |
| |
| optim1.step() |
| optim2.step() |
| self.assertEqual(a1.real, a1_real) |
| self.assertEqual(a1.imag, a1_imag) |
| |
| def _build_params_dict(self, weight, bias, **kwargs): |
| return [{"params": [weight]}, dict(params=[bias], **kwargs)] |
| |
| def _build_params_dict_single(self, weight, bias, **kwargs): |
| return [dict(params=bias, **kwargs)] |
| |
| def test_sgd(self): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| self._build_params_dict_single(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| self._build_params_dict_single(weight, bias, lr=1e-2), |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [lambda opt: StepLR(opt, gamma=0.9, step_size=10)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [ |
| lambda opt: LinearLR( |
| opt, start_factor=0.4, end_factor=0.8, total_iters=4 |
| ) |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [ |
| lambda opt: StepLR(opt, gamma=0.9, step_size=10), |
| lambda opt: LinearLR( |
| opt, start_factor=0.4, end_factor=0.6, total_iters=4 |
| ), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [ |
| lambda opt: StepLR(opt, gamma=0.9, step_size=10), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [ |
| lambda opt: StepLR(opt, gamma=0.99, step_size=10), |
| lambda opt: ExponentialLR(opt, gamma=0.99), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], |
| lr=1e-3, |
| momentum=0.5, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], |
| lr=1e-3, |
| momentum=0.5, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], |
| nesterov=True, |
| lr=1e-3, |
| momentum=0.5, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.SGD( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| [lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"): |
| optim.SGD(None, lr=1e-2, momentum=-0.5) |
| |
| def test_sgd_sparse(self): |
| for foreach in (False, True): |
| self._test_rosenbrock_sparse( |
| lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach) |
| ) |
| self._test_rosenbrock_sparse( |
| lambda params: optim.SGD(params, lr=0.0048, foreach=foreach), |
| [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)], |
| ) |
| |
| def test_sgd_complex(self): |
| for foreach in (False, True): |
| self._test_complex_optimizer( |
| lambda param: optim.SGD([param], lr=0.001, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.SGD( |
| [param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach |
| ) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.SGD( |
| [param], |
| lr=0.001, |
| nesterov=True, |
| momentum=1, |
| weight_decay=1, |
| foreach=foreach, |
| ) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.SGD( |
| [param], |
| lr=0.001, |
| momentum=1, |
| dampening=0.5, |
| weight_decay=1, |
| foreach=foreach, |
| ) |
| ) |
| |
| def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag): |
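        """For each (optimizer, kwargs) pair, run the default (``flag=False``) and the
        foreach/fused (``flag=True``) implementation on the same small model with the
        same seed, then check that the resulting parameters and per-parameter state
        agree within a small tolerance. Returns early when CUDA is unavailable.
        """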
| if not torch.cuda.is_available(): |
| return |
| assert flag in ("foreach", "fused") |
| |
| kIterations = 4 |
| device = "cuda" |
| for optimizer_constructor, params in optimizer_pairs_with_flags: |
| res, state = [], [] |
            for flag_value in (False, True):
| input = torch.tensor( |
| [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device |
| ).reshape(3, 2) |
| |
| torch.manual_seed(1) |
| model = torch.nn.Sequential( |
| torch.nn.Linear(2, 3), |
| torch.nn.Sigmoid(), |
| torch.nn.Linear(3, 1), |
| torch.nn.Sigmoid(), |
| ) |
| model.to(dtype=torch.float64, device=device) |
                params_with_flag = deepcopy(params)
                params_with_flag[flag] = flag_value
                optimizer = optimizer_constructor(
                    model.parameters(), **params_with_flag
                )
| |
                for i in range(kIterations):
| optimizer.zero_grad() |
| output = model(input) |
| loss = output.sum() |
| loss.backward() |
| |
| # Test that step behaves as expected (a no-op) when grads are set to None |
                    if i == 0:
| optimizer.zero_grad(set_to_none=True) |
| |
| optimizer.step() |
| |
| state.append(optimizer.state) |
| res.append(model.parameters()) |
| |
| st_state = state[0] |
| mt_state = state[1] |
| for st_p, mt_p in zip(res[0], res[1]): |
| self.assertEqual(st_p, mt_p, atol=5e-5, rtol=0) |
| |
| # check that optimizer states are the same |
| st_p_state = st_state[st_p] |
| mt_p_state = mt_state[mt_p] |
| |
| for k in st_p_state: |
| actual = mt_p_state[k] |
| # If `torch.optim.Adam` is `__init__`ed with either `fused=True` or `capturable=True`, |
| # `step` Tensor is 1D while usually it's 0D. |
| if ( |
| k == "step" |
| and isinstance(actual, torch.Tensor) |
| and actual.ndim == 1 |
| ): |
| actual = actual[0] |
| self.assertEqual(st_p_state[k], actual, atol=5e-5, rtol=0) |
| |
| def test_multi_tensor_optimizers(self): |
| optimizer_pairs_with_flags = [ |
| (optim.Adam, dict(weight_decay=1.0, amsgrad=True)), |
| (optim.Adam, dict(weight_decay=1.0, amsgrad=False)), |
| (optim.Adam, dict(weight_decay=0.0, amsgrad=True)), |
| (optim.Adam, dict(weight_decay=0.0, amsgrad=False)), |
| (optim.AdamW, dict(weight_decay=1.0, amsgrad=True)), |
| (optim.AdamW, dict(weight_decay=1.0, amsgrad=False)), |
| (optim.AdamW, dict(weight_decay=0.0, amsgrad=True)), |
| (optim.AdamW, dict(weight_decay=0.0, amsgrad=False)), |
| (optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)), |
| (optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)), |
| (optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)), |
| (optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)), |
| ( |
| optim.SGD, |
| dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True), |
| ), |
| ( |
| optim.SGD, |
| dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False), |
| ), |
| (optim.RAdam, dict(weight_decay=0)), |
| (optim.RAdam, dict(weight_decay=1)), |
| (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)), |
| (optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)), |
| (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)), |
| (optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)), |
| (optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))), |
| (optim.ASGD, dict(weight_decay=0)), |
| (optim.ASGD, dict(weight_decay=1)), |
| (optim.Adamax, dict(weight_decay=0)), |
| (optim.Adamax, dict(weight_decay=1)), |
| (optim.Adadelta, dict(weight_decay=0)), |
| (optim.Adadelta, dict(weight_decay=1)), |
| (optim.Adagrad, dict(weight_decay=0)), |
| (optim.Adagrad, dict(weight_decay=1)), |
| ] |
| self._test_derived_optimizers(optimizer_pairs_with_flags, "foreach") |
| |
| def test_fused_optimizers(self): |
| optimizer_pairs_with_flags = [ |
| (optim.Adam, dict(weight_decay=1.0, amsgrad=False)), |
| (optim.Adam, dict(weight_decay=1.0, amsgrad=True)), |
| (optim.Adam, dict(weight_decay=0.0, amsgrad=False)), |
| (optim.Adam, dict(weight_decay=0.0, amsgrad=True)), |
| ] |
| self._test_derived_optimizers(optimizer_pairs_with_flags, "fused") |
| |
| def test_adam(self): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| [weight, bias], |
| lr=1e-3, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| [weight, bias], |
| lr=1e-3, |
| weight_decay=0.1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [lambda opt: ExponentialLR(opt, gamma=0.9)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| [weight, bias], |
| lr=1e-3, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [ |
| lambda opt: ConstantLR(opt, factor=0.4, total_iters=4), |
| lambda opt: ExponentialLR(opt, gamma=0.9), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| [weight, bias], |
| lr=1e-3, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [ |
| lambda opt: ExponentialLR(opt, gamma=0.9), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [ |
| lambda opt: StepLR(opt, gamma=0.9, step_size=10), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adam( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_complex_2d(optim.Adam) |
| self._test_complex_2d(functools.partial(optim.Adam, foreach=True)) |
| |
| with self.assertRaisesRegex( |
| ValueError, "Invalid beta parameter at index 0: 1.0" |
| ): |
| optim.Adam(None, lr=1e-2, betas=(1.0, 0.0)) |
| |
| with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): |
| optim.Adam(None, lr=1e-2, weight_decay=-1) |
| |
| def test_adamw(self): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.AdamW( |
| [weight, bias], lr=1e-3, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.AdamW( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.AdamW( |
| [weight, bias], |
| lr=1e-3, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.AdamW( |
| [weight, bias], |
| lr=1e-3, |
| weight_decay=1, |
| amsgrad=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_complex_2d(optim.AdamW) |
| self._test_complex_2d(functools.partial(optim.AdamW, foreach=True)) |
| with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): |
| optim.AdamW(None, lr=1e-2, weight_decay=-1) |
| |
| def test_sparse_adam(self): |
| self._test_rosenbrock_sparse( |
| lambda params: optim.SparseAdam(params, lr=4e-2), [], True |
| ) |
| self._test_rosenbrock_sparse( |
| lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True), |
| [], |
| True, |
| True, |
| ) |
| with self.assertRaisesRegex( |
| ValueError, "Invalid beta parameter at index 0: 1.0" |
| ): |
| optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0)) |
| with self.assertRaisesRegex( |
| ValueError, "SparseAdam requires dense parameter tensors" |
| ): |
| optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)]) |
| with self.assertRaisesRegex( |
| ValueError, "SparseAdam requires dense parameter tensors" |
| ): |
| optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}]) |
| |
| # ROCm precision is too low to pass this test |
| def test_adadelta(self): |
| # Handles https://github.com/pytorch/pytorch/issues/69698 |
| self.rel_tol = 4e-3 |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adadelta( |
| [weight, bias], maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adadelta( |
| self._build_params_dict(weight, bias, rho=0.95), |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adadelta( |
| self._build_params_dict(weight, bias, rho=0.95), |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [ |
| lambda opt: StepLR(opt, gamma=0.9, step_size=10), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adadelta( |
| [weight, bias], weight_decay=1, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"): |
| optim.Adadelta(None, lr=1e-2, rho=1.1) |
| |
| def test_adadelta_complex(self): |
| # Handles https://github.com/pytorch/pytorch/issues/69698 |
| self.rel_tol = 2e-2 |
| for optimizer in [optim.Adadelta]: |
| self._test_complex_optimizer(lambda weight: optimizer([weight])) |
| self._test_complex_optimizer(lambda weight: optimizer([weight], rho=0.95)) |
| self._test_complex_optimizer( |
| lambda weight: optimizer([weight], rho=0.95, weight_decay=1) |
| ) |
| |
| def test_nadam(self): |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.NAdam( |
| [weight, bias], lr=1e-3, foreach=foreach |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.NAdam( |
| self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.NAdam( |
| [weight, bias], |
| lr=1e-3, |
| weight_decay=0.1, |
| momentum_decay=6e-3, |
| foreach=foreach, |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.NAdam( |
| [weight, bias], |
| lr=1e-3, |
| weight_decay=0.1, |
| momentum_decay=6e-3, |
| foreach=foreach, |
| ), |
| [lambda opt: ExponentialLR(opt, gamma=0.9)], |
| constructor_accepts_foreach=True, |
| ) |
| with self.assertRaisesRegex( |
| ValueError, "Invalid beta parameter at index 0: 1.0" |
| ): |
| optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0)) |
| with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"): |
| optim.NAdam(None, lr=1e-2, momentum_decay=-0.2) |
| |
| def test_adagrad(self): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adagrad( |
| [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adagrad( |
| [weight, bias], |
| lr=1e-1, |
| initial_accumulator_value=0.1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adagrad( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adagrad( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [lambda opt: ReduceLROnPlateau(opt)], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adagrad( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| [ |
| lambda opt: ReduceLROnPlateau(opt), |
| lambda opt: ExponentialLR(opt, gamma=0.99), |
| ], |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"): |
| optim.Adagrad(None, lr=1e-2, lr_decay=-0.5) |
| |
| def test_adagrad_sparse(self): |
| for foreach in (False, True): |
| self._test_rosenbrock_sparse( |
| lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach) |
| ) |
| self._test_rosenbrock_sparse( |
| lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach), |
| [ |
| lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500), |
| lambda opt: ReduceLROnPlateau(opt, threshold=1e-4), |
| ], |
| ) |
| |
| def test_adagrad_complex(self): |
| for foreach in (False, True): |
| self._test_complex_optimizer( |
| lambda param: optim.Adagrad([param], lr=1e-1, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.Adagrad( |
| [param], |
| lr=1e-1, |
| initial_accumulator_value=0.1, |
| foreach=foreach, |
| ) |
| ) |
| |
| def test_adamax(self): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adamax( |
| [weight, bias], lr=1e-1, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adamax( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Adamax( |
| [weight, bias], |
| lr=1e-1, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_complex_2d(optim.Adamax) |
| self._test_complex_2d(functools.partial(optim.Adamax, foreach=True)) |
| with self.assertRaisesRegex( |
| ValueError, "Invalid beta parameter at index 1: 1.0" |
| ): |
| optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0)) |
| |
| def test_radam(self): |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.RAdam( |
| [weight, bias], lr=1e-3, foreach=foreach |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.RAdam( |
| self._build_params_dict(weight, bias, lr=1e-2), lr=1e-3, foreach=foreach |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.RAdam( |
| [weight, bias], lr=1e-3, weight_decay=0.1, foreach=foreach |
| ), |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, foreach: optim.RAdam( |
| [weight, bias], lr=1e-3, foreach=foreach |
| ), |
| [ |
| lambda opt: ExponentialLR(opt, gamma=0.9), |
| lambda opt: ReduceLROnPlateau(opt), |
| ], |
| constructor_accepts_foreach=True, |
| ) |
| with self.assertRaisesRegex( |
| ValueError, "Invalid beta parameter at index 0: 1.0" |
| ): |
| optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0)) |
| |
| with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"): |
| optim.RAdam(None, lr=1e-2, weight_decay=-1) |
| |
| def test_rmsprop(self): |
| for foreach in (False, True): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| [weight, bias], lr=1e-2, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| self._build_params_dict(weight, bias, lr=1e-3), |
| lr=1e-2, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| self._build_params_dict(weight, bias, lr=1e-3), |
| lr=1e-2, |
| centered=True, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| self._build_params_dict(weight, bias, lr=1e-3), |
| lr=1e-2, |
| centered=True, |
| momentum=0.1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| self._build_params_dict(weight, bias, lr=1e-3), |
| lr=1e-2, |
| momentum=0.1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.RMSprop( |
| self._build_params_dict(weight, bias, lr=1e-3), |
| lr=1e-2, |
| momentum=0.1, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_complex_2d(lambda param: optim.RMSprop(param, foreach=foreach)) |
| self._test_complex_2d( |
| lambda param: optim.RMSprop(param, centered=True, foreach=foreach) |
| ) |
| self._test_complex_2d( |
| lambda param: optim.RMSprop(param, momentum=0.1, foreach=foreach) |
| ) |
| self._test_complex_2d( |
| lambda param: optim.RMSprop(param, maximize=True, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.RMSprop([param], foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.RMSprop([param], centered=True, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.RMSprop([param], momentum=0.1, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.RMSprop([param], maximize=True, foreach=foreach) |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): |
| optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach) |
| |
| def test_asgd(self): |
| for foreach in (False, True): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.ASGD( |
| [weight, bias], lr=1e-3, t0=100, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.ASGD( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| t0=100, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.ASGD( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=1e-3, |
| weight_decay=1, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| # Ref: https://github.com/pytorch/pytorch/issues/84560 |
| # self._test_complex_2d(optimizer) |
| self._test_complex_optimizer( |
| lambda params: optim.ASGD([params], foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda params: optim.ASGD([params], maximize=True, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda params: optim.ASGD( |
| [params], maximize=True, weight_decay=0.9, foreach=foreach |
| ) |
| ) |
| self._test_complex_optimizer( |
| lambda params: optim.ASGD( |
| [params], maximize=False, weight_decay=0.9, foreach=foreach |
| ) |
| ) |
| self._test_complex_optimizer( |
| lambda params: optim.ASGD([params], weight_decay=0.9, foreach=foreach) |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"): |
| optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach) |
| |
| @skipIfRocm |
| def test_rprop(self): |
| is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability( |
| 0 |
| ) == (8, 6) |
| for foreach in (False, True): |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Rprop( |
| [weight, bias], lr=2e-4, maximize=maximize, foreach=foreach |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| ) |
| self._test_basic_cases( |
| lambda weight, bias, maximize, foreach: optim.Rprop( |
| self._build_params_dict(weight, bias, lr=1e-2), |
| lr=2e-4, |
| maximize=maximize, |
| foreach=foreach, |
| ), |
| constructor_accepts_maximize=True, |
| constructor_accepts_foreach=True, |
| atol=4e-5 if is_cuda_sm86 else None, |
| rtol=3e-5 if is_cuda_sm86 else None, |
| ) |
| self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach)) |
| self._test_complex_optimizer( |
| lambda param: optim.Rprop([param], lr=0.001, foreach=foreach) |
| ) |
| self._test_complex_optimizer( |
| lambda param: optim.Rprop( |
| [param], lr=0.001, maximize=True, foreach=foreach |
| ) |
| ) |
| with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): |
| optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach) |
| |
| def test_lbfgs(self): |
| self._test_basic_cases( |
| lambda weight, bias: optim.LBFGS([weight, bias]), ignore_multidevice=True |
| ) |
| self._test_basic_cases( |
| lambda weight, bias: optim.LBFGS( |
| [weight, bias], line_search_fn="strong_wolfe" |
| ), |
| ignore_multidevice=True, |
| ) |
| |
| @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN") |
| def test_lbfgs_return_type(self): |
| params = [torch.randn(10, 5), torch.randn(10)] |
| opt1 = optim.LBFGS(params, 0.01, tolerance_grad=math.inf) |
| opt2 = optim.LBFGS(params, 0.01, tolerance_grad=-math.inf) |
| |
| def closure(): |
| return torch.tensor([10]) |
| |
| res1 = opt1.step(closure) |
| res2 = opt2.step(closure) |
| self.assertEqual(type(res1), type(res2)) |
| |
| def test_invalid_param_type(self): |
| with self.assertRaises(TypeError): |
| optim.SGD(Parameter(torch.randn(5, 5)), lr=3) |
| |
| def test_duplicate_params_in_param_group(self): |
| param = Parameter(torch.randn(5, 5)) |
| with warnings.catch_warnings(record=True) as w: |
| warnings.simplefilter("always") |
| optim.SGD([param, param], lr=0.1) |
| self.assertEqual(len(w), 1) |
| self.assertIn( |
| "a parameter group with duplicate parameters", str(w[0].message) |
| ) |
| |
| def test_no_grad_for_all_params(self): |
| params = [torch.randn(5, 5, requires_grad=False) for _ in range(2)] |
| |
| optimizer_list = [ |
| optim.Adadelta, |
| optim.AdamW, |
| optim.Adam, |
| optim.Adagrad, |
| optim.Adamax, |
| optim.RMSprop, |
| optim.SGD, |
| optim.SparseAdam, |
| optim.ASGD, |
| ] |
| for optim_ctr in optimizer_list: |
| opt = optim_ctr(params, lr=0.1) |
| # make sure step can still run even if |
| # all params have no grad |
| opt.step() |
| |
    # make sure that `state_steps` is correctly updated, or left unchanged, depending on `found_inf`.
| def test_functional_fused_adam_with_foundinf(self): |
| if not torch.cuda.is_available(): |
| self.skipTest("CUDA is required.") |
| |
| from torch.optim import adam |
| |
| num_tensors = 5 |
| for amsgrad in (False, True): |
| params, grads, exp_avgs, exp_avg_sqs = [ |
| [torch.ones((1,), device="cuda") for _ in range(num_tensors)] |
| for _ in range(4) |
| ] |
| max_exp_avg_sqs = ( |
| [torch.ones((1,), device="cuda") for _ in range(num_tensors)] |
| if amsgrad |
| else [] |
| ) |
| state_steps = [ |
| torch.ones((1,), dtype=torch.float32, device="cuda") |
| for _ in range(num_tensors) |
| ] |
| grad_scale = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( |
| torch.ones((1,), dtype=torch.float32, device="cuda") |
| ) |
| found_inf = torch.cuda.amp.grad_scaler._MultiDeviceReplicator( |
| torch.ones((1,), dtype=torch.float32, device="cuda") |
| ) |
| |
| adam.adam( |
| params, |
| grads, |
| exp_avgs, |
| exp_avg_sqs, |
| max_exp_avg_sqs, |
| state_steps, |
| foreach=False, |
| capturable=False, |
| fused=True, |
| amsgrad=amsgrad, |
| beta1=0.9, |
| beta2=0.99, |
| lr=1e-2, |
| weight_decay=0.0, |
| eps=1e-8, |
| maximize=False, |
| grad_scale=grad_scale, |
| found_inf=found_inf, |
| ) |
| |
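            # found_inf is all ones here, so every entry of state_steps must still
            # equal its initial value of 1 after the call (the update is skipped).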
| self.assertEqual( |
| state_steps, |
| [ |
| torch.ones((1,), dtype=torch.float32, device="cuda") |
| for _ in range(num_tensors) |
| ], |
| ) |
| |
| def test_empty_grad(self): |
| optimizers = [ |
| torch.optim.Adadelta, |
| torch.optim.Adagrad, |
| torch.optim.Adam, |
| torch.optim.AdamW, |
| torch.optim.Adamax, |
| torch.optim.ASGD, |
| torch.optim.NAdam, |
| torch.optim.RAdam, |
| torch.optim.RMSprop, |
| torch.optim.Rprop, |
| torch.optim.SGD, |
| torch.optim.SparseAdam, |
| ] |
| |
| for optimizer in optimizers: |
| net = torch.nn.Embedding( |
| 5, 1, padding_idx=0, sparse=optimizer is torch.optim.SparseAdam |
| ) |
| original_params = (param.detach().clone() for param in net.parameters()) |
| # Simulate a batch that only indexes the embedding at padding_idx |
| x = torch.tensor([[0, 0]]).int() |
| y = torch.tensor([[[3.0], [4.0]]]) |
| opt = optimizer(net.parameters(), lr=1e-5) |
| torch.nn.MSELoss()(net.forward(x), y).backward() |
| |
| opt.step() |
| |
| for original_param, param in zip(original_params, net.parameters()): |
| # assert that the parameters have not changed |
| self.assertEqual(original_param, param) |
| |
| @skipIfTorchDynamo() |
| def test_post_hook(self): |
| def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data += 2 |
| |
| params = [torch.Tensor([1, 1])] |
| opt = SGD(params, lr=0.001) |
| data = 2 |
| hook_handle = opt.register_step_post_hook(post_hook) |
| |
| opt.step() |
| opt.step() |
        # check if post hooks were registered
| self.assertEqual(data, 6) |
| |
| # remove handles, take step and verify that hook is no longer registered |
| hook_handle.remove() |
| |
| opt.step() |
| self.assertEqual(data, 6) |
| |
| @skipIfTorchDynamo() |
| def test_pre_hook(self): |
| def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data += 2 |
| |
| params = [torch.Tensor([1, 1])] |
| opt = SGD(params, lr=0.001) |
| data = 5 |
| hook_handle = opt.register_step_pre_hook(pre_hook) |
| |
| opt.step() |
| opt.step() |
| # check if pre hooks were registered |
| self.assertEqual(data, 9) |
| |
| # remove handles, take step and verify that hook is no longer registered |
| hook_handle.remove() |
| |
| opt.step() |
| self.assertEqual(data, 9) |
| |
| @skipIfTorchDynamo() |
| def test_pre_and_post_hook(self): |
| def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data.append(0) |
| |
| def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data.append(5) |
| |
| def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data.append(1) |
| |
| def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): |
| nonlocal data |
| data.append(2) |
| |
| params = [torch.Tensor([1, 1])] |
| opt1 = SGD(params, lr=0.001) |
| opt2 = Adam(params, lr=0.01) |
| data = [] |
| |
| # register global hooks to both optimizers |
| global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook) |
| global_post_handle = register_optimizer_step_post_hook(global_post_hook) |
| |
| # register local hooks |
| first_pre_handle = opt1.register_step_pre_hook(local_pre_hook) |
| first_post_handle = opt1.register_step_post_hook(local_post_hook) |
| second_pre_handle = opt2.register_step_pre_hook(local_pre_hook) |
| second_post_handle = opt2.register_step_post_hook(local_post_hook) |
| |
| opt1.step() |
| self.assertListEqual(data, [0, 1, 2, 5]) |
| opt2.step() |
| self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5]) |
| opt1.step() |
| self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5]) |
| |
| # remove all hooks |
| global_pre_handle.remove() |
| global_post_handle.remove() |
| first_pre_handle.remove() |
| first_post_handle.remove() |
| second_pre_handle.remove() |
| second_post_handle.remove() |
| |
| opt1.step() |
| opt2.step() |
| self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5]) |
| |
| |
| class SchedulerTestNet(torch.nn.Module): |
| def __init__(self): |
| super(SchedulerTestNet, self).__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 1, 1) |
| |
| def forward(self, x): |
| return self.conv2(F.relu(self.conv1(x))) |
| |
| |
| class LambdaLRTestObject: |
| def __init__(self, value): |
| self.value = value |
| |
| def __call__(self, epoch): |
| return self.value * epoch |
| |
| def __eq__(self, other): |
| if isinstance(other, self.__class__): |
| return self.__dict__ == other.__dict__ |
| else: |
| return False |
| |
| |
| class TestLRScheduler(TestCase): |
| exact_dtype = True |
| |
| def setUp(self): |
| super(TestLRScheduler, self).setUp() |
| self.net = SchedulerTestNet() |
| self.opt = SGD( |
| [ |
| {"params": self.net.conv1.parameters()}, |
| {"params": self.net.conv2.parameters(), "lr": 0.5}, |
| ], |
| lr=0.05, |
| ) |
| |
    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
        """Swallow the epoch deprecation warning produced when ``scheduler.step(epoch)``
        is called with a non-``None`` epoch. Passing an epoch is deprecated, so this
        helper will need to be removed or updated once the schedulers no longer
        accept the parameter at all.
        """
| self.assertEqual(len(w), num_warnings) |
| for warning in w: |
| self.assertEqual(len(warning.message.args), 1) |
| self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING) |
| |
| def test_error_when_getlr_has_epoch(self): |
| class MultiStepLR(torch.optim.lr_scheduler.LRScheduler): |
| def __init__(self, optimizer, gamma, milestones, last_epoch=-1): |
| self.init_lr = [group["lr"] for group in optimizer.param_groups] |
| self.gamma = gamma |
| self.milestones = milestones |
| super().__init__(optimizer, last_epoch) |
| |
| def get_lr(self, step): |
| global_step = self.last_epoch |
| gamma_power = ( |
| [0] |
| + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m] |
| )[-1] |
| return [ |
| init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr |
| ] |
| |
| optimizer = torch.optim.SGD([torch.rand(1)], lr=1) |
| |
| with self.assertRaises(TypeError): |
| scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20]) |
| |
| @skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames") |
| def test_no_cyclic_references(self): |
| import gc |
| |
| param = Parameter(torch.empty(10)) |
| optim = SGD([param], lr=0.5) |
| scheduler = LambdaLR(optim, lambda epoch: 1.0) |
| del scheduler |
| |
        # Prior to Python 3.7, local variables in a function are referenced by the current frame.
| import sys |
| |
| if sys.version_info < (3, 7): |
| import inspect |
| |
| referrers = gc.get_referrers(optim) |
| self.assertTrue( |
| len(referrers) == 1 and referrers[0] is inspect.currentframe(), |
| "Optimizer should contain no cyclic references (except current frame)", |
| ) |
| del referrers |
| else: |
| self.assertTrue( |
| len(gc.get_referrers(optim)) == 0, |
| "Optimizer should contain no cyclic references", |
| ) |
| |
| gc.collect() |
| del optim |
| self.assertEqual( |
| gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__" |
| ) |
| |
| @skipIfTorchDynamo("Torchdynamo keeps references to optim in the guards and the stack of the graph break frames") |
| def test_no_cyclic_references_in_step(self): |
| import gc |
| import weakref |
| |
| def run(): |
| param = torch.empty(10, requires_grad=True) |
| optim = SGD(params=[param], lr=0.5) |
| scheduler = LambdaLR(optim, lambda epoch: 1.0) |
| param.sum().backward() |
| optim.step() |
| scheduler.step() |
| |
| return weakref.ref(scheduler) |
| |
        # To ensure that there are no reference cycles in the scheduler, we turn off
        # the garbage collector; otherwise gc would automatically collect the
        # unreachable objects and hide a cycle.
| gc.disable() |
| ref = run() |
| |
| assert ref() is None |
| gc.enable() # restore |
| |
| def test_old_pattern_warning(self): |
| epochs = 35 |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def old_pattern(): |
| for _ in range(epochs): |
| scheduler.step() |
| self.opt.step() |
| |
| self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) |
| |
| def test_old_pattern_warning_with_arg(self): |
| epochs = 35 |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def old_pattern2(): |
| for _ in range(epochs): |
| scheduler.step() |
| self.opt.step() |
| |
| self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) |
| |
| def test_old_pattern_warning_resuming(self): |
| epochs = 35 |
| for i, group in enumerate(self.opt.param_groups): |
| group["initial_lr"] = 0.01 |
| |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def old_pattern(): |
| for _ in range(epochs): |
| scheduler.step() |
| self.opt.step() |
| |
| self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern) |
| |
| def test_old_pattern_warning_resuming_with_arg(self): |
| epochs = 35 |
| for i, group in enumerate(self.opt.param_groups): |
| group["initial_lr"] = 0.01 |
| |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def old_pattern2(): |
| for _ in range(epochs): |
| scheduler.step() |
| self.opt.step() |
| |
| self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) |
| |
| def test_old_pattern_warning_with_overridden_optim_step(self): |
| epochs = 35 |
| for i, group in enumerate(self.opt.param_groups): |
| group["initial_lr"] = 0.01 |
| |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| # emulate use-case with optimizer.step overridden |
| import types |
| |
| old_step = self.opt.step |
| |
| def new_step(o, *args, **kwargs): |
| retval = old_step(*args, **kwargs) |
| return retval |
| |
| self.opt.step = types.MethodType(new_step, self.opt) |
| |
| def old_pattern2(): |
| for _ in range(epochs): |
| scheduler.step() |
| self.opt.step() |
| |
| self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2) |
| |
| def test_new_pattern_no_warning(self): |
| epochs = 35 |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| for _ in range(epochs): |
| self.opt.step() |
| scheduler.step() |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def test_new_pattern_no_warning_with_arg(self): |
| epochs = 35 |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| for _ in range(epochs): |
| self.opt.step() |
| scheduler.step() |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| def test_new_pattern_no_warning_with_overridden_optim_step(self): |
| epochs = 35 |
| with warnings.catch_warnings(record=True) as ws: |
| warnings.simplefilter("always") # allow any warning to be raised |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self.assertTrue(len(ws) == 0, "No warning should be raised") |
| |
| # emulate use-case with optimizer.step overridden |
| import types |
| |
| old_step = self.opt.step |
| |
| def new_step(o, *args, **kwargs): |
| retval = old_step(*args, **kwargs) |
| return retval |
| |
| self.opt.step = types.MethodType(new_step, self.opt) |
| |
| def new_pattern(): |
| for e in range(epochs): |
| self.opt.step() |
| scheduler.step() |
| |
| self.assertWarnsRegex( |
| UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern |
| ) |
| |
| def _test_lr_is_constant_for_constant_epoch(self, scheduler): |
| l = [] |
| |
| for _ in range(10): |
| scheduler.optimizer.step() |
| with warnings.catch_warnings(record=True) as w: |
| scheduler.step(2) |
| self._check_warning_is_epoch_deprecation_warning(w) |
| |
| l.append(self.opt.param_groups[0]["lr"]) |
| self.assertEqual(min(l), max(l)) |
| |
| def test_step_lr_is_constant_for_constant_epoch(self): |
| scheduler = StepLR(self.opt, 2) |
| self._test_lr_is_constant_for_constant_epoch(scheduler) |
| |
| def test_exponential_lr_is_constant_for_constant_epoch(self): |
| scheduler = ExponentialLR(self.opt, gamma=0.9) |
| self._test_lr_is_constant_for_constant_epoch(scheduler) |
| |
| def test_constantlr_is_constant_for_constant_epoch(self): |
| scheduler = ConstantLR(self.opt) |
| self._test_lr_is_constant_for_constant_epoch(scheduler) |
| |
| def test_linear_linearlr_is_constant_for_constant_epoch(self): |
| scheduler = LinearLR(self.opt) |
| self._test_lr_is_constant_for_constant_epoch(scheduler) |
| |
| def test_polynomial_lr_is_constant_for_constant_epoch(self): |
| scheduler = PolynomialLR(self.opt, power=0.9) |
| self._test_lr_is_constant_for_constant_epoch(scheduler) |
| |
| def test_step_lr(self): |
| # lr = 0.05 if epoch < 3 |
|         # lr = 0.005 if 3 <= epoch < 6 |
|         # lr = 0.0005 if 6 <= epoch < 9 |
|         # lr = 0.00005 if epoch >= 9 |
| epochs = 10 |
| single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self._test(scheduler, targets, epochs) |
| |
| def test_get_last_lr_step_lr(self): |
| from torch.nn import Parameter |
| |
| epochs = 10 |
| optimizer = torch.optim.SGD( |
| [Parameter(torch.randn(2, 2, requires_grad=True))], 0.1 |
| ) |
| targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]] |
| scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1) |
| self._test_get_last_lr(scheduler, targets, epochs) |
| |
| def test_get_last_lr_multi_step_lr(self): |
| # lr = 0.05 if epoch < 2 |
| # lr = 0.005 if 2 <= epoch < 5 |
| # lr = 0.0005 if 5 <= epoch < 9 |
| # lr = 0.00005 if 9 <= epoch |
| epochs = 10 |
| single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| self._test_get_last_lr(scheduler, targets, epochs) |
| |
| def test_multi_step_lr(self): |
| # lr = 0.05 if epoch < 2 |
| # lr = 0.005 if 2 <= epoch < 5 |
|         # lr = 0.0005 if 5 <= epoch < 9 |
| # lr = 0.00005 if epoch >= 9 |
| epochs = 10 |
| single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| self._test(scheduler, targets, epochs) |
| |
| def test_multi_step_lr_with_epoch(self): |
| # lr = 0.05 if epoch < 2 |
| # lr = 0.005 if 2 <= epoch < 5 |
|         # lr = 0.0005 if 5 <= epoch < 9 |
| # lr = 0.00005 if epoch >= 9 |
| epochs = 10 |
| single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| self._test_with_epoch(scheduler, targets, epochs) |
| |
| def test_get_last_lr_constantlr(self): |
| # lr = 0.025 if epoch < 5 |
|         # lr = 0.05 if 5 <= epoch |
| epochs = 10 |
| single_targets = [0.025] * 5 + [0.05] * 5 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) |
| self._test_get_last_lr(scheduler, targets, epochs) |
| |
| def test_get_last_lr_linearlr(self): |
|         # lr = 0.0125 if epoch == 0 |
|         # lr = 0.016875 if epoch == 1 |
|         # lr = 0.02125 if epoch == 2 |
|         # lr = 0.025625 if epoch == 3 |
|         # lr = 0.03 if 4 <= epoch |
| epochs = 10 |
| start_factor = 1.0 / 4 |
| end_factor = 3.0 / 5 |
| iters = 4 |
| interpolation = [ |
| start_factor + i * (end_factor - start_factor) / iters for i in range(iters) |
| ] |
| single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * ( |
| epochs - iters |
| ) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = LinearLR( |
| self.opt, |
| start_factor=start_factor, |
| end_factor=end_factor, |
| total_iters=iters, |
| ) |
| self._test_get_last_lr(scheduler, targets, epochs) |
| |
| def test_constantlr(self): |
| # lr = 0.025 if epoch < 5 |
|         # lr = 0.05 if 5 <= epoch |
| epochs = 10 |
| single_targets = [0.025] * 5 + [0.05] * 5 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) |
| self._test(scheduler, targets, epochs) |
| |
| def test_linearlr(self): |
| # lr = 0.025 if epoch == 0 |
| # lr = 0.03125 if epoch == 1 |
| # lr = 0.0375 if epoch == 2 |
| # lr = 0.04375 if epoch == 3 |
|         # lr = 0.05 if 4 <= epoch |
| epochs = 10 |
| start_factor = 1.0 / 2 |
| iters = 4 |
| interpolation = [ |
| start_factor + i * (1 - start_factor) / iters for i in range(iters) |
| ] |
| single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| self._test(scheduler, targets, epochs) |
| |
| def test_linearlr_start_factor_limits1(self): |
| start_factor = 0.0 |
| iters = 4 |
| with self.assertRaises(ValueError): |
| LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| |
| def test_linearlr_start_factor_limits2(self): |
| start_factor = 1.1 |
| iters = 4 |
| with self.assertRaises(ValueError): |
| LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| |
| def test_constantlr_with_epoch(self): |
| # lr = 0.025 if epoch < 5 |
|         # lr = 0.05 if 5 <= epoch |
| epochs = 10 |
| single_targets = [0.025] * 5 + [0.05] * 5 |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5) |
| self._test_with_epoch(scheduler, targets, epochs) |
| |
| def test_linearlr_with_epoch(self): |
| # lr = 0.025 if epoch == 0 |
| # lr = 0.03125 if epoch == 1 |
| # lr = 0.0375 if epoch == 2 |
| # lr = 0.04375 if epoch == 3 |
|         # lr = 0.05 if 4 <= epoch |
| epochs = 10 |
| start_factor = 1.0 / 2 |
| end_factor = 1.0 |
| iters = 4 |
| interpolation = [ |
| start_factor + i * (end_factor - start_factor) / iters for i in range(iters) |
| ] |
| single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| self._test_with_epoch(scheduler, targets, epochs) |
| |
| def test_exp_lr(self): |
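|         # ExponentialLR multiplies the previous lr by gamma each epoch, so lr_t = 0.05 * 0.9**t. |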
| epochs = 10 |
| single_targets = [0.05 * (0.9**x) for x in range(epochs)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = ExponentialLR(self.opt, gamma=0.9) |
| self._test(scheduler, targets, epochs) |
| |
| def test_poly_lr(self): |
| epochs = 10 |
| power = 0.9 |
| total_iters = 5 |
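|         # PolynomialLR: lr_t = 0.05 * (1 - t / total_iters)**power while t < total_iters, then 0 afterwards. |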
| single_targets = [ |
| (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters) |
| ] + [0.0] * (epochs - total_iters) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters) |
| self._test(scheduler, targets, epochs) |
| |
| def test_cos_anneal_lr(self): |
| epochs = 10 |
| eta_min = 1e-10 |
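|         # Closed form of cosine annealing: lr_t = eta_min + (0.05 - eta_min) * (1 + cos(pi * t / T_max)) / 2. |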
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) |
| self._test(scheduler, targets, epochs) |
| |
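|     # The test_closed_form_* tests below step one scheduler epoch by epoch and a second, |
|     # identically configured scheduler via the (deprecated) explicit epoch argument, |
|     # asserting both produce the same lr sequence (see _test_against_closed_form). |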
| def test_closed_form_step_lr(self): |
| scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_linearlr(self): |
| scheduler = LinearLR( |
| self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 |
| ) |
| closed_form_scheduler = LinearLR( |
| self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4 |
| ) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_constantlr(self): |
| scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) |
| closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_multi_step_lr(self): |
| scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_exp_lr(self): |
| scheduler = ExponentialLR(self.opt, gamma=0.9) |
| closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_poly_lr(self): |
| scheduler = PolynomialLR(self.opt, power=0.9) |
| closed_form_scheduler = PolynomialLR(self.opt, power=0.9) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, 20) |
| |
| def test_closed_form_cos_anneal_lr(self): |
| eta_min = 1e-10 |
| epochs = 20 |
| T_max = 5 |
| scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) |
| closed_form_scheduler = CosineAnnealingLR( |
| self.opt, T_max=T_max, eta_min=eta_min |
| ) |
| self._test_against_closed_form(scheduler, closed_form_scheduler, epochs) |
| |
| def test_cos_anneal_lr_continue(self): |
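|         # A scheduler recreated with last_epoch=0 should resume at the same lr the original reached after one step. |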
| eta_min = 0.1 |
| T_max = 5 |
| scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min) |
| self.opt.step() |
| scheduler.step() |
| original_lrs = scheduler._last_lr |
| new_scheduler = CosineAnnealingLR( |
| self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0 |
| ) |
| new_lrs = new_scheduler._last_lr |
| torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5) |
| |
| def test_reduce_lr_on_plateau1(self): |
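|         # The metric improves by 0.0167 per step, which beats the absolute threshold of 0.01, so the lr is never reduced. |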
| epochs = 10 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 20] |
| metrics = [10 - i * 0.0167 for i in range(20)] |
| scheduler = ReduceLROnPlateau( |
| self.opt, |
| threshold_mode="abs", |
| mode="min", |
| threshold=0.01, |
| patience=5, |
| cooldown=5, |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau2(self): |
| epochs = 22 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2] |
| metrics = [10 - i * 0.0165 for i in range(22)] |
| scheduler = ReduceLROnPlateau( |
| self.opt, |
| patience=5, |
| cooldown=0, |
| threshold_mode="abs", |
| mode="min", |
| threshold=0.1, |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau3(self): |
| epochs = 22 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4] |
| metrics = [-0.8] * 2 + [-0.234] * 20 |
| scheduler = ReduceLROnPlateau( |
| self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau4(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 20] |
| metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 |
| scheduler = ReduceLROnPlateau( |
| self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau5(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] |
| metrics = [1.5 * (1.005**i) for i in range(20)] |
| scheduler = ReduceLROnPlateau( |
| self.opt, |
| mode="max", |
| threshold_mode="rel", |
| threshold=0.1, |
| patience=5, |
| cooldown=5, |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau6(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 20] |
| metrics = [1.5 * (0.85**i) for i in range(20)] |
| scheduler = ReduceLROnPlateau( |
| self.opt, mode="min", threshold_mode="rel", threshold=0.1 |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau7(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] |
| metrics = [1] * 7 + [0.6] + [0.5] * 12 |
| scheduler = ReduceLROnPlateau( |
| self.opt, |
| mode="min", |
| threshold_mode="rel", |
| threshold=0.1, |
| patience=5, |
| cooldown=5, |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_reduce_lr_on_plateau8(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14] |
| metrics = [1.5 * (1.005**i) for i in range(20)] |
| scheduler = ReduceLROnPlateau( |
| self.opt, |
| mode="max", |
| threshold_mode="rel", |
| min_lr=[0.4, 0.3], |
| threshold=0.1, |
| patience=5, |
| cooldown=5, |
| ) |
| self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs) |
| |
| def test_sequentiallr1(self): |
| epochs = 19 |
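|         # ExponentialLR(gamma=0.8) drives the first 3 epochs; after the milestone at |
|         # epoch 3, StepLR(gamma=0.1, step_size=4) takes over from the base lr of 0.05. |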
| schedulers = [None] * 2 |
| targets = [ |
| [0.05, 0.04, 0.032] |
| + [0.05 for x in range(4)] |
| + [0.05 * 0.1 for x in range(4)] |
| + [0.05 * 0.01 for x in range(4)] |
| + [0.05 * 0.001 for x in range(4)] |
| ] |
| milestones = [3] |
| schedulers[0] = ExponentialLR(self.opt, gamma=0.8) |
| schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4) |
| scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) |
| self._test(scheduler, targets, epochs) |
| |
| def test_sequentiallr2(self): |
| epochs = 13 |
| schedulers = [None] * 2 |
| targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]] |
| milestones = [3] |
| schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.9) |
| scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) |
| self._test(scheduler, targets, epochs) |
| |
| def test_sequentiallr3(self): |
| epochs = 12 |
| schedulers = [None] * 3 |
| targets = [ |
| [0.005, 0.005, 0.005] |
| + [0.05, 0.04, 0.032] |
| + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] |
| ] |
| milestones = [3, 6] |
| schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.8) |
| schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) |
| scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) |
| self._test(scheduler, targets, epochs) |
| |
| def test_sequentiallr4(self): |
| optimizer = torch.optim.SGD([torch.tensor(0.5)], lr=0.1) |
| prev_lr = optimizer.param_groups[0]["lr"] |
| |
| schedulers = [ |
| torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1), |
| torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1), |
| ] |
| scheduler = torch.optim.lr_scheduler.SequentialLR( |
| optimizer, schedulers, milestones=[10] |
| ) |
| |
| new_lr = optimizer.param_groups[0]["lr"] |
| |
|         # Ensure that using multiple schedulers does not change the initial learning rate |
| self.assertEqual(prev_lr, new_lr) |
| |
| def test_get_last_lr_sequentiallr(self): |
| epochs = 12 |
| milestones = [3, 6] |
| schedulers = [None] * 3 |
| schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.8) |
| schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2) |
| scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) |
| constant_lr_target = [0.005] * 3 |
| exponential_lr_target = [0.05, 0.04, 0.032] |
| step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005] |
| single_targets = constant_lr_target + exponential_lr_target + step_lr_target |
| targets = [single_targets, [x * 10 for x in single_targets]] |
| self._test_get_last_lr(scheduler, targets, epochs) |
| |
| def test_chained_lr2_get_last_lr_before_step(self): |
| schedulers = [ |
| LinearLR(self.opt, start_factor=0.4, total_iters=3), |
| MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1), |
| ] |
| scheduler = ChainedScheduler(schedulers) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_chained_lr1(self): |
| epochs = 10 |
| schedulers = [None] * 1 |
| targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3] |
| schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) |
| scheduler = ChainedScheduler(schedulers) |
| self._test([scheduler], targets, epochs) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_chained_lr2(self): |
| epochs = 10 |
| schedulers = [None] * 1 |
| targets = [[0.02, 0.03, 0.04] + [0.05] * 9] |
| schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) |
| scheduler = ChainedScheduler(schedulers) |
| self._test([scheduler], targets, epochs) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_chained_lr3(self): |
| epochs = 10 |
| schedulers = [None] * 2 |
| targets = [ |
| [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3 |
| ] |
| schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3) |
| schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1) |
| scheduler = ChainedScheduler(schedulers) |
| self._test([scheduler], targets, epochs) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_chained_lr4(self): |
| epochs = 9 |
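|         # ChainedScheduler applies every chained scheduler each epoch, so the lr is the |
|         # product of the exponential decay, the 0.2 constant factor (first 4 iters only), |
|         # and the step decay applied every 3 epochs. |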
| schedulers = [None] * 3 |
| targets = [ |
| [0.05 * 0.2 * 0.9**x for x in range(3)] |
| + [0.05 * 0.2 * 0.9**3 * 0.1] |
| + [0.05 * 0.9**x * 0.1 for x in range(4, 6)] |
| + [0.05 * 0.9**x * 0.01 for x in range(6, 9)] |
| ] |
| schedulers[0] = ExponentialLR(self.opt, gamma=0.9) |
| schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4) |
| schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3) |
| scheduler = ChainedScheduler(schedulers) |
| self._test([scheduler], targets, epochs) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_chained_lr5(self): |
| def poly_lr(lr: float): |
| return [ |
| (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters) |
| ] + [0.0] * (epochs - total_iters) |
| |
| schedulers = [None] * 2 |
| epochs = 10 |
| power = 0.9 |
| total_iters = 5 |
| const_factor = 0.1 |
| single_targets = [x * const_factor for x in poly_lr(lr=0.05)] |
| targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]] |
| schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters) |
| schedulers[1] = ConstantLR(self.opt, factor=const_factor) |
| scheduler = ChainedScheduler(schedulers) |
| self._test(scheduler, targets, epochs) |
| self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr()) |
| |
| def test_compound_step_and_multistep_lr(self): |
| epochs = 10 |
| schedulers = [None] * 2 |
| schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) |
| schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]] |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_step_and_exp_lr(self): |
| epochs = 10 |
| schedulers = [None] * 2 |
| single_targets = [0.05 * (0.9**x) for x in range(3)] |
| single_targets += [0.005 * (0.9**x) for x in range(3, 6)] |
| single_targets += [0.0005 * (0.9**x) for x in range(6, 9)] |
| single_targets += [0.00005 * (0.9**x) for x in range(9, 12)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.9) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_exp_and_multistep_lr(self): |
| epochs = 10 |
| schedulers = [None] * 2 |
| single_targets = [0.05 * (0.9**x) for x in range(2)] |
| single_targets += [0.005 * (0.9**x) for x in range(2, 5)] |
| single_targets += [0.0005 * (0.9**x) for x in range(5, 9)] |
| single_targets += [0.00005 * (0.9**x) for x in range(9, 11)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.9) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_exp_and_linearlr(self): |
| epochs = 10 |
| iters = 4 |
| start_factor = 0.4 |
| end_factor = 0.9 |
| schedulers = [None] * 2 |
| single_targets = [0.05 * (0.9**x) for x in range(11)] |
| for i in range(iters): |
| single_targets[i] *= start_factor + i / iters * (end_factor - start_factor) |
| for i in range(iters, 11): |
| single_targets[i] *= end_factor |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = LinearLR( |
| self.opt, |
| start_factor=start_factor, |
| end_factor=end_factor, |
| total_iters=iters, |
| ) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.9) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_step_and_constantlr(self): |
| epochs = 10 |
| iters = 4 |
| factor = 0.4 |
| schedulers = [None] * 2 |
| single_targets = ( |
| [0.05 * 0.4] * 3 |
| + [0.005 * 0.4] |
| + [0.005] * 2 |
| + [0.0005] * 3 |
| + [0.00005] * 3 |
| ) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) |
| schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_linearlr_and_multistep_lr(self): |
| epochs = 10 |
| iters = 4 |
| start_factor = 0.4 |
| schedulers = [None] * 2 |
| single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 |
| for i in range(iters): |
| single_targets[i] *= start_factor + i / iters * (1 - start_factor) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_cosanneal_and_step_lr(self): |
| epochs = 10 |
| eta_min = 1e-10 |
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers = [None] * 2 |
| schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) |
| schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_cosanneal_and_multistep_lr(self): |
| epochs = 10 |
| eta_min = 1e-10 |
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001] |
| single_targets = [x * y for x, y in zip(single_targets, multipliers)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers = [None] * 2 |
| schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) |
| schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_cosanneal_and_linearlr(self): |
| epochs = 10 |
| iters = 4 |
| start_factor = 0.4 |
| eta_min = 1e-10 |
| schedulers = [None] * 2 |
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| for i in range(iters): |
| single_targets[i] *= start_factor + i / iters * (1 - start_factor) |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_cosanneal_and_exp_lr(self): |
| epochs = 10 |
| eta_min = 1e-10 |
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| multipliers = [0.1**i for i in range(epochs)] |
| single_targets = [x * y for x, y in zip(single_targets, multipliers)] |
| targets = [single_targets, [x * epochs for x in single_targets]] |
| schedulers = [None] * 2 |
| schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.1) |
| self._test(schedulers, targets, epochs) |
| |
| def test_compound_reduce_lr_on_plateau1(self): |
| epochs = 10 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| single_targets = [0.5] * 20 |
| multipliers = [0.1 ** (i // 3) for i in range(20)] |
| single_targets = [x * y for x, y in zip(multipliers, single_targets)] |
| targets = [single_targets] |
| targets = targets[1:] # test runs step before checking lr |
| metrics = [10 - i * 0.0167 for i in range(20)] |
| schedulers = [None, None] |
| schedulers[0] = ReduceLROnPlateau( |
| self.opt, |
| threshold_mode="abs", |
| mode="min", |
| threshold=0.01, |
| patience=5, |
| cooldown=5, |
| ) |
| schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3) |
| self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) |
| |
| def test_compound_reduce_lr_on_plateau2(self): |
| epochs = 22 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 |
| multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10 |
| single_targets = [x * y for x, y in zip(single_targets, multipliers)] |
| targets = [single_targets] |
| targets = targets[1:] # test runs step before checking lr |
| metrics = [10 - i * 0.0165 for i in range(22)] |
| schedulers = [None] * 2 |
| schedulers[0] = ReduceLROnPlateau( |
| self.opt, |
| patience=5, |
| cooldown=0, |
| threshold_mode="abs", |
| mode="min", |
| threshold=0.1, |
| ) |
| schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12]) |
| self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) |
| |
| def test_compound_reduce_lr_on_plateau3(self): |
| epochs = 22 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4 |
| multipliers = [0.1**i for i in range(epochs)] |
| single_targets = [x * y for x, y in zip(multipliers, single_targets)] |
| targets = [single_targets] |
| targets = targets[1:] # test runs step before checking lr |
| metrics = [-0.8] * 2 + [-0.234] * 20 |
| schedulers = [None, None] |
| schedulers[0] = ReduceLROnPlateau( |
| self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs" |
| ) |
| schedulers[1] = ExponentialLR(self.opt, gamma=0.1) |
| self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) |
| |
| def test_compound_reduce_lr_on_plateau4(self): |
| epochs = 20 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.05 |
| epochs = 10 |
| eta_min = 1e-10 |
| single_targets = [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2 |
| for x in range(epochs) |
| ] |
| targets = [single_targets] |
| targets = targets[1:] # test runs step before checking lr |
| metrics = [1.5 * (1.025**i) for i in range(20)] # 1.025 > 1.1**0.25 |
| schedulers = [None, None] |
| schedulers[0] = ReduceLROnPlateau( |
| self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1 |
| ) |
| schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) |
| self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) |
| |
| def test_compound_reduce_lr_on_plateau5(self): |
| iters = 4 |
| start_factor = 0.4 |
| epochs = 22 |
| for param_group in self.opt.param_groups: |
| param_group["lr"] = 0.5 |
| single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 |
| multipliers = [1] * 22 |
| for i in range(iters): |
| multipliers[i] *= start_factor + i / iters * (1 - start_factor) |
| single_targets = [x * y for x, y in zip(single_targets, multipliers)] |
| targets = [single_targets] |
| targets = targets[1:] # test runs step before checking lr |
| metrics = [10 - i * 0.0165 for i in range(22)] |
| schedulers = [None] * 2 |
| schedulers[0] = ReduceLROnPlateau( |
| self.opt, |
| patience=5, |
| cooldown=0, |
| threshold_mode="abs", |
| mode="min", |
| threshold=0.1, |
| ) |
| schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters) |
| self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) |
| |
| def test_cycle_lr_invalid_mode(self): |
| with self.assertRaises(ValueError): |
| scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") |
| |
| def test_cycle_lr_triangular_mode_one_lr(self): |
| lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] |
| momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=1, |
| max_lr=5, |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=1, |
| max_momentum=5, |
| mode="triangular", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_triangular_mode_one_lr_no_momentum(self): |
| lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] |
| lr_targets = [lr_target, lr_target] |
| momentum_target = [self.opt.defaults["momentum"]] * len(lr_target) |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=1, |
| max_lr=5, |
| step_size_up=4, |
| cycle_momentum=False, |
| mode="triangular", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_triangular2_mode_one_lr(self): |
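|         # In triangular2 mode the cycle amplitude is halved after each full cycle, so with base_lr=1 and max_lr=5 the peaks decay 5 -> 3 -> 2. |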
| lr_target = [ |
| 1, |
| 2, |
| 3, |
| 4, |
| 5, |
| 4, |
| 3, |
| 2, |
| 1, |
| 1.5, |
| 2.0, |
| 2.5, |
| 3.0, |
| 2.5, |
| 2.0, |
| 1.5, |
| 1, |
| 1.25, |
| 1.50, |
| 1.75, |
| 2.00, |
| 1.75, |
| ] |
| momentum_target = [ |
| 5.0, |
| 4.0, |
| 3.0, |
| 2.0, |
| 1.0, |
| 2.0, |
| 3.0, |
| 4.0, |
| 5.0, |
| 4.5, |
| 4.0, |
| 3.5, |
| 3.0, |
| 3.5, |
| 4.0, |
| 4.5, |
| 5.0, |
| 4.75, |
| 4.5, |
| 4.25, |
| 4.0, |
| 4.25, |
| ] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=1, |
| max_lr=5, |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=1, |
| max_momentum=5, |
| mode="triangular2", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_exp_range_mode_one_lr(self): |
| base_lr, max_lr = 1, 5 |
| diff_lr = max_lr - base_lr |
| gamma = 0.9 |
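|         # exp_range scales the triangular amplitude by gamma**iteration: lr_i = base_lr + x_i * (max_lr - base_lr) * gamma**i. |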
| xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] |
| lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] |
| momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=base_lr, |
| max_lr=max_lr, |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=base_lr, |
| max_momentum=max_lr, |
| mode="exp_range", |
| gamma=gamma, |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_triangular_mode(self): |
| lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] |
| lr_target_2 = [x + 1 for x in lr_target_1] |
| lr_targets = [lr_target_1, lr_target_2] |
| momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3] |
| momentum_target_2 = [x + 1 for x in momentum_target_1] |
| momentum_targets = [momentum_target_1, momentum_target_2] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=[1, 2], |
| max_lr=[5, 6], |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=[1, 2], |
| max_momentum=[5, 6], |
| mode="triangular", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) |
| |
| def test_cycle_lr_triangular2_mode(self): |
| lr_target_1 = [ |
| 1, |
| 2, |
| 3, |
| 4, |
| 5, |
| 4, |
| 3, |
| 2, |
| 1, |
| 1.5, |
| 2.0, |
| 2.5, |
| 3.0, |
| 2.5, |
| 2.0, |
| 1.5, |
| 1, |
| 1.25, |
| 1.50, |
| 1.75, |
| 2.00, |
| 1.75, |
| ] |
| lr_target_2 = [x + 2 for x in lr_target_1] |
| lr_targets = [lr_target_1, lr_target_2] |
| momentum_target_1 = [ |
| 5.0, |
| 4.0, |
| 3.0, |
| 2.0, |
| 1.0, |
| 2.0, |
| 3.0, |
| 4.0, |
| 5.0, |
| 4.5, |
| 4.0, |
| 3.5, |
| 3.0, |
| 3.5, |
| 4.0, |
| 4.5, |
| 5.0, |
| 4.75, |
| 4.5, |
| 4.25, |
| 4.0, |
| 4.25, |
| ] |
| momentum_target_2 = [x + 2 for x in momentum_target_1] |
| momentum_targets = [momentum_target_1, momentum_target_2] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=[1, 3], |
| max_lr=[5, 7], |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=[1, 3], |
| max_momentum=[5, 7], |
| mode="triangular2", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) |
| |
| def test_cycle_lr_exp_range_mode(self): |
| base_lr_1, max_lr_1 = 1, 5 |
| base_lr_2, max_lr_2 = 5, 12 |
| |
| diff_lr_1 = max_lr_1 - base_lr_1 |
| diff_lr_2 = max_lr_2 - base_lr_2 |
| |
| gamma = 0.9 |
| xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1] |
| lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)] |
| lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)] |
| lr_targets = [lr_target_1, lr_target_2] |
| momentum_target_1 = [ |
| max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs) |
| ] |
| momentum_target_2 = [ |
| max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs) |
| ] |
| momentum_targets = [momentum_target_1, momentum_target_2] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=[base_lr_1, base_lr_2], |
| max_lr=[max_lr_1, max_lr_2], |
| step_size_up=4, |
| cycle_momentum=True, |
| base_momentum=[base_lr_1, base_lr_2], |
| max_momentum=[max_lr_1, max_lr_2], |
| mode="exp_range", |
| gamma=gamma, |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1)) |
| |
| def test_cycle_lr_triangular_mode_step_size_up_down(self): |
| lr_target = [ |
| 1.0, |
| 2.0, |
| 3.0, |
| 4.0, |
| 5.0, |
| 13.0 / 3, |
| 11.0 / 3, |
| 9.0 / 3, |
| 7.0 / 3, |
| 5.0 / 3, |
| 1.0, |
| ] |
| lr_targets = [lr_target, lr_target] |
| momentum_target = [ |
| 5.0, |
| 4.0, |
| 3.0, |
| 2.0, |
| 1.0, |
| 5.0 / 3, |
| 7.0 / 3, |
| 3.0, |
| 11.0 / 3, |
| 13.0 / 3, |
| 5.0, |
| ] |
| momentum_targets = [momentum_target, momentum_target] |
| |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=1, |
| max_lr=5, |
| step_size_up=4, |
| step_size_down=6, |
| cycle_momentum=True, |
| base_momentum=1, |
| max_momentum=5, |
| mode="triangular", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_triangular2_mode_step_size_up_down(self): |
| lr_base_target = [ |
| 1.0, |
| 3.0, |
| 5.0, |
| 13.0 / 3, |
| 11.0 / 3, |
| 9.0 / 3, |
| 7.0 / 3, |
| 5.0 / 3, |
| 1.0, |
| 2.0, |
| 3.0, |
| 8.0 / 3, |
| 7.0 / 3, |
| 6.0 / 3, |
| 5.0 / 3, |
| 4.0 / 3, |
| 1.0, |
| 3.0 / 2, |
| 2.0, |
| 11.0 / 6, |
| 10.0 / 6, |
| 9.0 / 6, |
| 8.0 / 6, |
| 7.0 / 6, |
| ] |
| momentum_base_target = [ |
| 5.0, |
| 3.0, |
| 1.0, |
| 5.0 / 3, |
| 7.0 / 3, |
| 3.0, |
| 11.0 / 3, |
| 13.0 / 3, |
| 5.0, |
| 4.0, |
| 3.0, |
| 10.0 / 3, |
| 11.0 / 3, |
| 4.0, |
| 13.0 / 3, |
| 14.0 / 3, |
| 5.0, |
| 4.5, |
| 4.0, |
| 25.0 / 6, |
| 13.0 / 3, |
| 4.5, |
| 14.0 / 3, |
| 29.0 / 6, |
| ] |
| deltas = [2 * i for i in range(0, 2)] |
| base_lrs = [1 + delta for delta in deltas] |
| max_lrs = [5 + delta for delta in deltas] |
| lr_targets = [[x + delta for x in lr_base_target] for delta in deltas] |
| momentum_targets = [ |
| [x + delta for x in momentum_base_target] for delta in deltas |
| ] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=base_lrs, |
| max_lr=max_lrs, |
| step_size_up=2, |
| step_size_down=6, |
| cycle_momentum=True, |
| base_momentum=base_lrs, |
| max_momentum=max_lrs, |
| mode="triangular2", |
| ) |
| self._test_cycle_lr( |
| scheduler, lr_targets, momentum_targets, len(lr_base_target) |
| ) |
| |
| def test_cycle_lr_exp_range_mode_step_size_up_down(self): |
| base_lr, max_lr = 1, 5 |
| diff_lr = max_lr - base_lr |
| gamma = 0.9 |
| xs = [ |
| 0.0, |
| 0.5, |
| 1.0, |
| 5.0 / 6, |
| 4.0 / 6, |
| 3.0 / 6, |
| 2.0 / 6, |
| 1.0 / 6, |
| 0.0, |
| 0.5, |
| 1.0, |
| 5.0 / 6, |
| 4.0 / 6, |
| ] |
| lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)] |
| lr_targets = [lr_target, lr_target] |
| momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=base_lr, |
| max_lr=max_lr, |
| step_size_up=2, |
| step_size_down=6, |
| cycle_momentum=True, |
| base_momentum=base_lr, |
| max_momentum=max_lr, |
| mode="exp_range", |
| gamma=gamma, |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| def test_cycle_lr_with_momentumless_optimizer(self): |
| # Note [Temporarily set optimizer to Adam] |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # The TestLRScheduler object carries around an SGD optimizer to avoid having to |
|         # instantiate one for every test. This gets in the way of our very specific case |
| # in which we need to use Adam (or really any optimizer that doesn't use momentum) |
| # in order to test that the momentum bug in CyclicLR is fixed (the bug is described |
| # in more detail in https://github.com/pytorch/pytorch/issues/19003 ). |
| old_opt = self.opt |
| self.opt = optim.Adam( |
| [ |
| {"params": self.net.conv1.parameters()}, |
| {"params": self.net.conv2.parameters(), "lr": 0.5}, |
| ], |
| lr=0.05, |
| ) |
| |
| lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3] |
| lr_targets = [lr_target, lr_target] |
| momentum_target = [None] * len(lr_target) |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = CyclicLR( |
| self.opt, |
| base_lr=1, |
| max_lr=5, |
| step_size_up=4, |
| cycle_momentum=False, |
| mode="triangular", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target)) |
| |
| self.opt = old_opt # set optimizer back to SGD |
| |
| def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self): |
| with self.assertRaises(ValueError): |
| adam_opt = optim.Adam(self.net.parameters()) |
| scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True) |
| |
| def test_cycle_lr_removed_after_out_of_scope(self): |
| import gc |
| import weakref |
| |
| gc.disable() |
| |
| def test(): |
| adam_opt = optim.Adam(self.net.parameters()) |
| scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False) |
| return weakref.ref(scheduler) |
| |
| ref = test() |
| assert ref() is None |
| gc.enable() |
| |
| def test_onecycle_lr_invalid_anneal_strategy(self): |
| with self.assertRaises(ValueError): |
| scheduler = OneCycleLR( |
| self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS" |
| ) |
| |
| def test_onecycle_lr_invalid_pct_start(self): |
| with self.assertRaises(ValueError): |
| scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1) |
| |
| def test_onecycle_lr_cannot_calculate_total_steps(self): |
| with self.assertRaises(ValueError): |
| scheduler = OneCycleLR(self.opt, max_lr=1e-3) |
| |
| def test_onecycle_lr_linear_annealing(self): |
| lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] |
| momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = OneCycleLR( |
| self.opt, |
| max_lr=25, |
| final_div_factor=2, |
| base_momentum=1, |
| max_momentum=22, |
| total_steps=10, |
| anneal_strategy="linear", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) |
| |
| def test_onecycle_lr_linear_annealing_three_phases(self): |
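|         # With three_phase=True a final phase anneals the lr from the initial lr |
|         # (max_lr / div_factor = 1) down to 1 / final_div_factor = 0.25. |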
| lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25] |
| momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = OneCycleLR( |
| self.opt, |
| max_lr=25, |
| div_factor=25, |
| base_momentum=1, |
| max_momentum=22, |
| total_steps=10, |
| anneal_strategy="linear", |
| pct_start=0.4, |
| final_div_factor=4, |
| three_phase=True, |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) |
| |
| def test_onecycle_lr_cosine_annealing(self): |
| def annealing_cos(start, end, pct): |
| cos_out = math.cos(math.pi * pct) + 1 |
| return end + (start - end) / 2.0 * cos_out |
| |
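|         # With the default "cos" strategy the lr rises from max_lr / div_factor (= 1 with |
|         # the default div_factor of 25) up to max_lr, then anneals down to |
|         # max_lr / div_factor / final_div_factor (= 0.5) following annealing_cos. |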
| lr_target = [ |
| 1, |
| 13, |
| 25, |
| annealing_cos(25, 0.5, 1 / 7.0), |
| annealing_cos(25, 0.5, 2 / 7.0), |
| annealing_cos(25, 0.5, 3 / 7.0), |
| annealing_cos(25, 0.5, 4 / 7.0), |
| annealing_cos(25, 0.5, 5 / 7.0), |
| annealing_cos(25, 0.5, 6 / 7.0), |
| 0.5, |
| ] |
| momentum_target = [ |
| 22, |
| 11.5, |
| 1, |
| annealing_cos(1, 22, 1 / 7.0), |
| annealing_cos(1, 22, 2 / 7.0), |
| annealing_cos(1, 22, 3 / 7.0), |
| annealing_cos(1, 22, 4 / 7.0), |
| annealing_cos(1, 22, 5 / 7.0), |
| annealing_cos(1, 22, 6 / 7.0), |
| 22, |
| ] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = OneCycleLR( |
| self.opt, |
| max_lr=25, |
| final_div_factor=2, |
| base_momentum=1, |
| max_momentum=22, |
| total_steps=10, |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10) |
| |
| def test_cycle_lr_with_adam(self): |
| old_opt = self.opt |
| self.opt = optim.Adam( |
| [ |
| {"params": self.net.conv1.parameters()}, |
| {"params": self.net.conv2.parameters(), "lr": 0.5}, |
| ], |
| lr=0.05, |
| ) |
| |
| lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] |
| momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22] |
| lr_targets = [lr_target, lr_target] |
| momentum_targets = [momentum_target, momentum_target] |
| scheduler = OneCycleLR( |
| self.opt, |
| max_lr=25, |
| final_div_factor=2, |
| base_momentum=1, |
| max_momentum=22, |
| total_steps=10, |
| anneal_strategy="linear", |
| ) |
| self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True) |
| self.opt = old_opt # set optimizer back to SGD |
| |
| def test_lambda_lr(self): |
| epochs = 10 |
| self.opt.param_groups[0]["lr"] = 0.05 |
| self.opt.param_groups[1]["lr"] = 0.4 |
| targets = [ |
| [0.05 * (0.9**x) for x in range(epochs)], |
| [0.4 * (0.8**x) for x in range(epochs)], |
| ] |
| scheduler = LambdaLR( |
| self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2] |
| ) |
| self._test(scheduler, targets, epochs) |
| |
| def test_multiplicative_lr(self): |
| epochs = 10 |
| self.opt.param_groups[0]["lr"] = 0.05 |
| self.opt.param_groups[1]["lr"] = 0.4 |
| targets = [ |
| [0.05 * (0.9**x) for x in range(epochs)], |
| [0.4 * (0.8**x) for x in range(epochs)], |
| ] |
| scheduler = MultiplicativeLR( |
| self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8] |
| ) |
| self._test(scheduler, targets, epochs) |
| |
| @parametrize("T_mult", [1, 2, 4]) |
| def test_CosineAnnealingWarmRestarts_lr1(self, T_mult): |
| iters = 100 |
| eta_min = 1e-10 |
| T_i = 10 |
| T_cur = 0 |
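|         # At every warm restart T_cur wraps back to 0 and the period grows to T_mult * T_i; |
|         # within a period the lr follows the usual cosine annealing closed form. |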
| targets = [[0.05], [0.5]] |
| scheduler = CosineAnnealingWarmRestarts( |
| self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min |
| ) |
| for _ in range(1, iters, 1): |
| T_cur += 1 |
| if T_cur >= T_i: |
| T_cur = T_cur - T_i |
| T_i = int(T_mult) * T_i |
| targets[0] += [ |
| eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| targets[1] += [ |
| eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| self._test(scheduler, targets, iters) |
| |
| def test_CosineAnnealingWarmRestarts_lr2(self): |
| iters = 30 |
| eta_min = 1e-10 |
| T_mults = [1, 2, 4] |
| for T_mult in T_mults: |
| T_i = 10 |
| T_cur = 0 |
| targets = [[0.05], [0.5]] |
| scheduler = CosineAnnealingWarmRestarts( |
| self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min |
| ) |
| for _ in torch.arange(0.1, iters, 0.1): |
| T_cur = round(T_cur + 0.1, 1) |
| if T_cur >= T_i: |
| T_cur = T_cur - T_i |
| T_i = int(T_mult) * T_i |
| targets[0] += [ |
| eta_min |
| + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| targets[1] += [ |
| eta_min |
| + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters) |
| |
| def test_CosineAnnealingWarmRestarts_lr3(self): |
| epochs_for_T_mults = [ |
| [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13], |
| [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3], |
| [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50], |
| ] |
| T_curs_for_T_mults = [ |
| [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3], |
| [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3], |
| [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10], |
| ] |
| T_is_for_T_mults = [ |
| [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10], |
| [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10], |
| [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90], |
| ] |
| eta_min = 1e-10 |
| T_mults = [1, 2, 3] |
| for epochs, T_mult, T_curs, T_is in zip( |
| epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults |
| ): |
| targets = [[0.05], [0.5]] |
| scheduler = CosineAnnealingWarmRestarts( |
| self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min |
| ) |
| for T_cur, T_i in zip(T_curs, T_is): |
| targets[0] += [ |
| eta_min |
| + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| targets[1] += [ |
| eta_min |
| + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2 |
| ] |
| self._test_interleaved_CosineAnnealingWarmRestarts( |
| scheduler, targets, epochs |
| ) |
| |
| def test_swalr_no_anneal(self): |
| epochs, swa_start, swa_lr = 10, 5, 0.01 |
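|         # With anneal_epochs=1 the lr jumps straight to swa_lr on the first SWALR step after swa_start. |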
| initial_lrs = [group["lr"] for group in self.opt.param_groups] |
| targets = [ |
| [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1) |
| for lr in initial_lrs |
| ] |
| swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr) |
| self._test_swalr(swa_scheduler, None, targets, swa_start, epochs) |
| |
| def test_swalr_cosine_anneal_after_multiplicative(self): |
| # same swa_lr for different param_groups |
| epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5 |
| mult_factor = 0.9 |
| scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) |
| swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr) |
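|         # During the SWA phase the lr is interpolated between the last pre-SWA lr and swa_lr |
|         # with a cosine coefficient: lr_t = prev_lr * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)). |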
| |
| def anneal_coef(t): |
| if t + 1 >= anneal_epochs: |
| return 0.0 |
| return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2 |
| |
| initial_lrs = [group["lr"] for group in self.opt.param_groups] |
| targets_before_swa = [ |
| [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs |
| ] |
| swa_epochs = epochs - swa_start - 1 |
| targets = [ |
| lrs |
| + [ |
| lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) |
| for t in range(swa_epochs) |
| ] |
| for lrs in targets_before_swa |
| ] |
| |
| self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) |
| |
| def test_swalr_linear_anneal_after_multiplicative(self): |
| # separate swa_lr for different param_groups |
| epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4 |
| mult_factor = 0.9 |
| scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor) |
| swa_scheduler = SWALR( |
| self.opt, |
| anneal_epochs=anneal_epochs, |
| anneal_strategy="linear", |
| swa_lr=swa_lrs, |
| ) |
| |
| def anneal_coef(t): |
| if t + 1 >= anneal_epochs: |
| return 0.0 |
| return 1 - (t + 1) / anneal_epochs |
| |
| initial_lrs = [group["lr"] for group in self.opt.param_groups] |
| targets_before_swa = [ |
| [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs |
| ] |
| swa_epochs = epochs - swa_start - 1 |
| targets = [ |
| lrs |
| + [ |
| lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t)) |
| for t in range(swa_epochs) |
| ] |
| for lrs, swa_lr in zip(targets_before_swa, swa_lrs) |
| ] |
| |
| self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs) |
| |
| def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs): |
| for epoch in range(epochs): |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[epoch], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[epoch], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| if epoch >= swa_start: |
| self.opt.step() |
| swa_scheduler.step() |
| elif scheduler is not None: |
| self.opt.step() |
| scheduler.step() |
| |
| def test_swalr_hypers(self): |
| # Test that SWALR raises errors for incorrect hyper-parameters |
| with self.assertRaisesRegex(ValueError, "anneal_strategy must"): |
| swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0) |
| |
| with self.assertRaisesRegex(ValueError, "anneal_epochs must"): |
| swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0) |
| with self.assertRaisesRegex(ValueError, "anneal_epochs must"): |
| swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0) |
| with self.assertRaisesRegex(ValueError, "swa_lr must"): |
| swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01]) |
| |
| def test_step_lr_state_dict(self): |
| self._check_scheduler_state_dict( |
| lambda: StepLR(self.opt, gamma=0.1, step_size=3), |
| lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1), |
| ) |
| |
| def test_multi_step_lr_state_dict(self): |
| self._check_scheduler_state_dict( |
| lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]), |
| lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]), |
| ) |
| |
| def test_exp_step_lr_state_dict(self): |
| self._check_scheduler_state_dict( |
| lambda: ExponentialLR(self.opt, gamma=0.1), |
| lambda: ExponentialLR(self.opt, gamma=0.01), |
| ) |
| |
| def test_cosine_lr_state_dict(self): |
| epochs = 10 |
| eta_min = 1e-10 |
| self._check_scheduler_state_dict( |
| lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min), |
| lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2), |
| epochs=epochs, |
| ) |
| |
| def test_reduce_lr_on_plateau_state_dict(self): |
| scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2) |
| for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]: |
| scheduler.step(score) |
| scheduler_copy = ReduceLROnPlateau( |
| self.opt, mode="max", factor=0.5, patience=10 |
| ) |
| scheduler_copy.load_state_dict(scheduler.state_dict()) |
| for key in scheduler.__dict__.keys(): |
| if key not in {"optimizer", "is_better"}: |
| self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) |
| |
| def test_lambda_lr_state_dict_fn(self): |
| scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x) |
| state = scheduler.state_dict() |
| self.assertIsNone(state["lr_lambdas"][0]) |
| |
| scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x) |
| scheduler_copy.load_state_dict(state) |
| for key in scheduler.__dict__.keys(): |
| if key not in {"optimizer", "lr_lambdas"}: |
| self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) |
| |
| def test_lambda_lr_state_dict_obj(self): |
| scheduler = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(10)) |
| state = scheduler.state_dict() |
| self.assertIsNotNone(state["lr_lambdas"][0]) |
| |
| scheduler_copy = LambdaLR(self.opt, lr_lambda=LambdaLRTestObject(-1)) |
| scheduler_copy.load_state_dict(state) |
| for key in scheduler.__dict__.keys(): |
| if key not in {"optimizer"}: |
| self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) |
| |
| def test_CosineAnnealingWarmRestarts_lr_state_dict(self): |
| self._check_scheduler_state_dict( |
| lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2), |
| lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100), |
| ) |
| |
| def test_swa_lr_state_dict(self): |
| self._check_scheduler_state_dict( |
| lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5), |
| lambda: SWALR( |
| self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0 |
| ), |
| ) |
| |
| def _check_scheduler_state_dict(self, constr, constr2, epochs=10): |
| scheduler = constr() |
| for _ in range(epochs): |
| scheduler.optimizer.step() |
| scheduler.step() |
| scheduler_copy = constr2() |
| scheduler_copy.load_state_dict(scheduler.state_dict()) |
| for key in scheduler.__dict__.keys(): |
| if key != "optimizer": |
| self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key]) |
| self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr()) |
| |
| def _test_get_last_lr(self, schedulers, targets, epochs=10): |
| if isinstance(schedulers, LRScheduler): |
| schedulers = [schedulers] |
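        # De-duplicate optimizers so each one is stepped exactly once per epoch.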
| optimizers = {scheduler.optimizer for scheduler in schedulers} |
| for epoch in range(epochs): |
| result = [scheduler.get_last_lr() for scheduler in schedulers] |
| [optimizer.step() for optimizer in optimizers] |
| [scheduler.step() for scheduler in schedulers] |
| target = [[t[epoch] for t in targets]] * len(schedulers) |
            for t, r in zip(target, result):
                self.assertEqual(
                    t,
                    r,
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, t, r
                    ),
                    atol=1e-5,
                    rtol=0,
                )
| |
| def _test_with_epoch(self, schedulers, targets, epochs=10): |
| if isinstance(schedulers, LRScheduler): |
| schedulers = [schedulers] |
| optimizers = {scheduler.optimizer for scheduler in schedulers} |
| for epoch in range(epochs): |
| [optimizer.step() for optimizer in optimizers] |
| with warnings.catch_warnings(record=True) as w: |
| [ |
| scheduler.step(epoch) for scheduler in schedulers |
| ] # step before assert: skip initial lr |
| self._check_warning_is_epoch_deprecation_warning( |
| w, num_warnings=len(schedulers) |
| ) |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[epoch], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[epoch], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| def _test(self, schedulers, targets, epochs=10): |
| if isinstance(schedulers, LRScheduler): |
| schedulers = [schedulers] |
| for epoch in range(epochs): |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[epoch], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[epoch], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| [scheduler.step() for scheduler in schedulers] |
| |
| def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10): |
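        # Step with fractional epoch values (0.0, 0.1, 0.2, ...) to exercise the
        # closed-form warm-restart schedule between integer epochs.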
| for index, epoch in enumerate(torch.arange(0, epochs, 0.1)): |
| epoch = round(epoch.item(), 1) |
| scheduler.step(epoch) |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[index], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[index], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| def _test_interleaved_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs): |
| for index, epoch in enumerate(epochs): |
| scheduler.step(epoch) |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[index], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[index], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10): |
| self.setUp() |
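        # First pass: record the LRs produced by the closed-form scheduler
        # stepped with explicit epochs.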
| targets = [] |
| for epoch in range(epochs): |
| closed_form_scheduler.optimizer.step() |
| with warnings.catch_warnings(record=True) as w: |
| closed_form_scheduler.step(epoch) |
| self._check_warning_is_epoch_deprecation_warning(w) |
| targets.append([group["lr"] for group in self.opt.param_groups]) |
| self.setUp() |
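        # Second pass: the regular scheduler, stepped implicitly, must reproduce
        # the recorded LRs.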
| for epoch in range(epochs): |
| self.opt.step() |
| scheduler.step() |
| for i, param_group in enumerate(self.opt.param_groups): |
| self.assertEqual( |
| targets[epoch][i], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, targets[epoch][i], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| def _test_reduce_lr_on_plateau( |
| self, schedulers, targets, metrics, epochs=10, verbose=False |
| ): |
        if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)):
| schedulers = [schedulers] |
| for epoch in range(epochs): |
| self.opt.step() |
| for scheduler in schedulers: |
| if isinstance(scheduler, ReduceLROnPlateau): |
| scheduler.step(metrics[epoch]) |
| else: |
| scheduler.step() |
| if verbose: |
| print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"])) |
| for param_group, target in zip(self.opt.param_groups, targets): |
| self.assertEqual( |
| target[epoch], |
| param_group["lr"], |
| msg="LR is wrong in epoch {}: expected {}, got {}".format( |
| epoch, target[epoch], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| def _test_cycle_lr( |
| self, |
| scheduler, |
| lr_targets, |
| momentum_targets, |
| batch_iterations, |
| verbose=False, |
| use_beta1=False, |
| ): |
| for batch_num in range(batch_iterations): |
| if verbose: |
| if "momentum" in self.opt.param_groups[0].keys(): |
| print( |
| "batch{}:\tlr={},momentum={}".format( |
| batch_num, |
| self.opt.param_groups[0]["lr"], |
| self.opt.param_groups[0]["momentum"], |
| ) |
| ) |
| elif use_beta1 and "betas" in self.opt.param_groups[0].keys(): |
| print( |
| "batch{}:\tlr={},beta1={}".format( |
| batch_num, |
| self.opt.param_groups[0]["lr"], |
| self.opt.param_groups[0]["betas"][0], |
| ) |
| ) |
| else: |
| print( |
| "batch{}:\tlr={}".format( |
| batch_num, self.opt.param_groups[0]["lr"] |
| ) |
| ) |
| |
| for param_group, lr_target, momentum_target in zip( |
| self.opt.param_groups, lr_targets, momentum_targets |
| ): |
| self.assertEqual( |
| lr_target[batch_num], |
| param_group["lr"], |
| msg="LR is wrong in batch_num {}: expected {}, got {}".format( |
| batch_num, lr_target[batch_num], param_group["lr"] |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| |
| if use_beta1 and "betas" in param_group.keys(): |
| self.assertEqual( |
| momentum_target[batch_num], |
| param_group["betas"][0], |
| msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format( |
| batch_num, |
| momentum_target[batch_num], |
| param_group["betas"][0], |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| elif "momentum" in param_group.keys(): |
| self.assertEqual( |
| momentum_target[batch_num], |
| param_group["momentum"], |
| msg="Momentum is wrong in batch_num {}: expected {}, got {}".format( |
| batch_num, |
| momentum_target[batch_num], |
| param_group["momentum"], |
| ), |
| atol=1e-5, |
| rtol=0, |
| ) |
| self.opt.step() |
| scheduler.step() |
| |
| def test_cosine_then_cyclic(self): |
| # https://github.com/pytorch/pytorch/issues/21965 |
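        # Anneal with CosineAnnealingLR first, then hand off to CyclicLR;
        # the final LR must not exceed max_lr.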
| |
| max_lr = 0.3 |
| base_lr = 0.1 |
| optim_lr = 0.5 |
| |
| model = torch.nn.Linear(2, 1) |
| optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr) |
| lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR( |
| optimizer, T_max=20, eta_min=0.1 |
| ) |
| lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( |
| optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 |
| ) |
| |
| for i in range(40): |
| optimizer.step() |
| if i <= lr_scheduler_1.T_max: |
| lr_scheduler_1.step() |
| else: |
| lr_scheduler_2.step() |
| last_lr = optimizer.param_groups[0]["lr"] |
| |
| self.assertLessEqual(last_lr, max_lr) |
| |
| |
| class SWATestDNN(torch.nn.Module): |
| def __init__(self, input_features): |
        super().__init__()
| self.n_features = 100 |
| self.fc1 = torch.nn.Linear(input_features, self.n_features) |
| self.bn = torch.nn.BatchNorm1d(self.n_features) |
| |
| def compute_preactivation(self, x): |
| return self.fc1(x) |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.bn(x) |
| return x |
| |
| |
| class SWATestCNN(torch.nn.Module): |
| def __init__(self, input_channels): |
        super().__init__()
| self.n_features = 10 |
| self.conv1 = torch.nn.Conv2d( |
| input_channels, self.n_features, kernel_size=3, padding=1 |
| ) |
| self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3) |
| |
| def compute_preactivation(self, x): |
| return self.conv1(x) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.bn(x) |
| return x |
| |
| |
| class TestSWAUtils(TestCase): |
| def _test_averaged_model(self, net_device, swa_device): |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.ReLU(), |
| torch.nn.MaxPool2d(kernel_size=2), |
| torch.nn.BatchNorm2d(5, momentum=0.3), |
| torch.nn.Conv2d(5, 2, kernel_size=3), |
| torch.nn.ReLU(), |
| torch.nn.Linear(5, 5), |
| torch.nn.ReLU(), |
| torch.nn.Linear(5, 10), |
| ).to(net_device) |
| |
| averaged_dnn = AveragedModel(dnn, device=swa_device) |
| averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] |
| n_updates = 10 |
| for i in range(n_updates): |
| for p, p_avg in zip(dnn.parameters(), averaged_params): |
| p.detach().add_(torch.randn_like(p)) |
| p_avg += p.detach() / n_updates |
| averaged_dnn.update_parameters(dnn) |
| |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| # Check that AveragedModel is on the correct device |
| self.assertTrue(p_swa.device == swa_device) |
| self.assertTrue(p.device == net_device) |
| self.assertTrue(averaged_dnn.n_averaged.device == swa_device) |
| |
| def test_averaged_model_all_devices(self): |
| cpu = torch.device("cpu") |
| self._test_averaged_model(cpu, cpu) |
| if torch.cuda.is_available(): |
| cuda = torch.device(0) |
| self._test_averaged_model(cuda, cpu) |
| self._test_averaged_model(cpu, cuda) |
| self._test_averaged_model(cuda, cuda) |
| |
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_averaged_model_mixed_device(self):
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) |
| ) |
| dnn[0].cuda() |
| dnn[1].cpu() |
| averaged_dnn = AveragedModel(dnn) |
| averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] |
| n_updates = 10 |
| for i in range(n_updates): |
| for p, p_avg in zip(dnn.parameters(), averaged_params): |
| p.detach().add_(torch.randn_like(p)) |
| p_avg += p.detach() / n_updates |
| averaged_dnn.update_parameters(dnn) |
| |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| # Check that AveragedModel is on the correct device |
| self.assertTrue(p_avg.device == p_swa.device) |
| |
| def test_averaged_model_state_dict(self): |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10) |
| ) |
| averaged_dnn = AveragedModel(dnn) |
| averaged_dnn2 = AveragedModel(dnn) |
| n_updates = 10 |
| for i in range(n_updates): |
| for p in dnn.parameters(): |
| p.detach().add_(torch.randn_like(p)) |
| averaged_dnn.update_parameters(dnn) |
| averaged_dnn2.load_state_dict(averaged_dnn.state_dict()) |
| for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()): |
| self.assertEqual(p_swa, p_swa2) |
| self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged) |
| |
| def test_averaged_model_exponential(self): |
| # Test AveragedModel with EMA as avg_fn |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.BatchNorm2d(5, momentum=0.3), |
| torch.nn.Linear(5, 10), |
| ) |
| alpha = 0.9 |
| |
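        # Exponential moving average: new_avg = alpha * old_avg + (1 - alpha) * new_value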
| def avg_fn(p_avg, p, n_avg): |
| return alpha * p_avg + (1 - alpha) * p |
| |
| averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn) |
| averaged_params = [torch.zeros_like(param) for param in dnn.parameters()] |
| n_updates = 10 |
| for i in range(n_updates): |
| updated_averaged_params = [] |
| for p, p_avg in zip(dnn.parameters(), averaged_params): |
| p.detach().add_(torch.randn_like(p)) |
| if i == 0: |
| updated_averaged_params.append(p.clone()) |
| else: |
| updated_averaged_params.append( |
| (p_avg * alpha + p * (1 - alpha)).clone() |
| ) |
| for b in dnn.buffers(): |
| if b.size() != torch.Size([]): |
| b.detach_().add_(torch.randn_like(b)) |
| |
| averaged_dnn.update_parameters(dnn) |
| averaged_params = updated_averaged_params |
| |
| for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()): |
| self.assertEqual(p_avg, p_swa) |
| for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()): |
| self.assertEqual(b_avg, b_swa) |
| |
| def test_averaged_model_exponential_buffers(self): |
| # Test AveragedModel with EMA as avg_fn and use_buffers as True. |
| dnn = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 5, kernel_size=3), |
| torch.nn.BatchNorm2d(5, momentum=0.3), |
| torch.nn.Linear(5, 10), |
| ) |
| alpha = 0.9 |
| |
| def avg_fn(p_avg, p, n_avg): |
| return alpha * p_avg + (1 - alpha) * p |
| |
| averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=True) |
        # Materialize the chain so it can be re-iterated on every update below;
        # a bare itertools.chain would already be exhausted after building averaged_params.
        dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
| averaged_params = [ |
| torch.zeros_like(param) |
| for param in dnn_params |
| if param.size() != torch.Size([]) |
| ] |
| n_updates = 10 |
| for i in range(n_updates): |
| updated_averaged_params = [] |
| for p, p_avg in zip(dnn_params, averaged_params): |
| if p.size() == torch.Size([]): |
| continue |
| p.detach().add_(torch.randn_like(p)) |
| if i == 0: |
| updated_averaged_params.append(p.clone()) |
| else: |
| updated_averaged_params.append( |
| (p_avg * alpha + p * (1 - alpha)).clone() |
| ) |
| averaged_dnn.update_parameters(dnn) |
| averaged_params = updated_averaged_params |
| |
| for p_avg, p_swa in zip( |
| averaged_params, |
| itertools.chain( |
| averaged_dnn.module.parameters(), averaged_dnn.module.buffers() |
| ), |
| ): |
| self.assertEqual(p_avg, p_swa) |
| |
| def _test_update_bn(self, dnn, dl_x, dl_xy, cuda): |
| |
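        # Accumulate the empirical mean and variance of the BN-layer inputs over
        # the whole loader; update_bn should recover these statistics in the
        # running buffers.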
| preactivation_sum = torch.zeros(dnn.n_features) |
| preactivation_squared_sum = torch.zeros(dnn.n_features) |
| if cuda: |
| preactivation_sum = preactivation_sum.cuda() |
| preactivation_squared_sum = preactivation_squared_sum.cuda() |
| total_num = 0 |
| for x in dl_x: |
| x = x[0] |
| if cuda: |
| x = x.cuda() |
| |
| dnn.forward(x) |
| preactivations = dnn.compute_preactivation(x) |
| if len(preactivations.shape) == 4: |
| preactivations = preactivations.transpose(1, 3) |
| preactivations = preactivations.contiguous().view(-1, dnn.n_features) |
| total_num += preactivations.shape[0] |
| |
| preactivation_sum += torch.sum(preactivations, dim=0) |
| preactivation_squared_sum += torch.sum(preactivations**2, dim=0) |
| |
| preactivation_mean = preactivation_sum / total_num |
| preactivation_var = preactivation_squared_sum / total_num |
| preactivation_var = preactivation_var - preactivation_mean**2 |
| |
| update_bn(dl_xy, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
| |
| def _reset_bn(module): |
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
| module.running_mean = torch.zeros_like(module.running_mean) |
| module.running_var = torch.ones_like(module.running_var) |
| |
| # reset batch norm and run update_bn again |
| dnn.apply(_reset_bn) |
| update_bn(dl_xy, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
| # using the dl_x loader instead of dl_xy |
| dnn.apply(_reset_bn) |
| update_bn(dl_x, dnn, device=x.device) |
| self.assertEqual(preactivation_mean, dnn.bn.running_mean) |
| self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0) |
| |
| def test_update_bn_dnn(self): |
| # Test update_bn for a fully-connected network with BatchNorm1d |
| objects, input_features = 100, 5 |
| x = torch.rand(objects, input_features) |
| y = torch.rand(objects) |
| ds_x = torch.utils.data.TensorDataset(x) |
| ds_xy = torch.utils.data.TensorDataset(x, y) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) |
| dnn = SWATestDNN(input_features=input_features) |
| dnn.train() |
| self._test_update_bn(dnn, dl_x, dl_xy, False) |
| if torch.cuda.is_available(): |
| dnn = SWATestDNN(input_features=input_features) |
| dnn.train() |
| self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True) |
| self.assertTrue(dnn.training) |
| |
| def test_update_bn_cnn(self): |
| # Test update_bn for convolutional network and BatchNorm2d |
| objects = 100 |
| input_channels = 3 |
| height, width = 5, 5 |
| x = torch.rand(objects, input_channels, height, width) |
| y = torch.rand(objects) |
| ds_x = torch.utils.data.TensorDataset(x) |
| ds_xy = torch.utils.data.TensorDataset(x, y) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True) |
| dnn = SWATestCNN(input_channels=input_channels) |
| dnn.train() |
| self._test_update_bn(dnn, dl_x, dl_xy, False) |
| if torch.cuda.is_available(): |
| dnn = SWATestCNN(input_channels=input_channels) |
| dnn.train() |
| self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True) |
| self.assertTrue(dnn.training) |
| |
| def test_bn_update_eval_momentum(self): |
| # check that update_bn preserves eval mode |
| objects = 100 |
| input_channels = 3 |
| height, width = 5, 5 |
| x = torch.rand(objects, input_channels, height, width) |
| ds_x = torch.utils.data.TensorDataset(x) |
| dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True) |
| dnn = SWATestCNN(input_channels=input_channels) |
| dnn.eval() |
| update_bn(dl_x, dnn) |
| self.assertFalse(dnn.training) |
| |
| # check that momentum is preserved |
| self.assertEqual(dnn.bn.momentum, 0.3) |
| |
| |
| instantiate_parametrized_tests(TestLRScheduler) |
| |
| |
| def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored): |
    # `ignored` holds the values from `opt_differentiable_state`; we pass them as
    # explicit arguments so that `gradcheck` tracks the state tensors as function
    # inputs, since it cannot unpack the values inside the
    # `opt_differentiable_state` dict on its own.
| p = p.clone() |
| p.grad = grad |
| opt_differentiable_state = { |
| k: v.clone() if isinstance(v, torch.Tensor) else v |
| for k, v in opt_differentiable_state.items() |
| } |
| opt = opt_class([p], **kwargs) |
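    # Seed the optimizer with the (possibly requires_grad) state tensors so the
    # step is differentiated through them.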
| opt.state[p].update(opt_differentiable_state) |
| opt.step() |
| return (p,) + tuple( |
| v |
| for v in opt.state[p].values() |
| if isinstance(v, torch.Tensor) and v.requires_grad |
| ) |
| |
| |
| class TestDifferentiableOptimizer(TestCase): |
| def test_sgd(self): |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state = {"momentum_buffer": mbuff} |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.SGD, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adam(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["max_exp_avg_sq"] = torch.rand( |
| 10, requires_grad=True, dtype=torch.float64 |
| ) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.Adam, |
| {"lr": 0.9, "differentiable": True, "amsgrad": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_rmsprop(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["step"] = 0 |
| state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["momentum_buffer"] = torch.rand( |
| 10, requires_grad=True, dtype=torch.float64 |
| ) |
        # Keep grad_avg small: with centered=True the update divides by
        # sqrt(square_avg - grad_avg**2), so large values can drive that term
        # negative and produce NaNs.
| state["grad_avg"] = 1e-2 * torch.rand( |
| 10, requires_grad=True, dtype=torch.float64 |
| ) |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.RMSprop, |
| { |
| "lr": 0.9, |
| "maximize": True, |
| "momentum": 0.9, |
| "differentiable": True, |
| "centered": True, |
| "weight_decay": 0.1, |
| }, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adadelta(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.Adadelta, |
| {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adagrad(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.Adagrad, |
| {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adamax(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.Adamax, |
| {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_asgd(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
        # `step`, `eta`, and `mu` are not continuous variables (even though we
        # define them as floats) and so they shouldn't require gradients.
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) |
| state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) |
| state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.ASGD, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_rprop(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.Rprop, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adamw(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["max_exp_avg_sq"] = torch.rand( |
| 10, requires_grad=True, dtype=torch.float64 |
| ) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.AdamW, |
| {"lr": 0.9, "differentiable": True, "amsgrad": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_nadam(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.NAdam, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_radam(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| torch.optim.RAdam, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |