# Owner(s): ["module: optimizer"]
import math
import unittest
import functools
import itertools
from copy import deepcopy
import torch
from torch.nn import Parameter
from torch.optim import (
Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, LBFGS, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam, Optimizer
)
from torch.optim.lr_scheduler import (
StepLR,
ConstantLR,
LinearLR,
ExponentialLR,
ReduceLROnPlateau,
PolynomialLR,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
gradcheck,
skipIfRocm,
skipIfTorchDynamo
)
from torch.testing._internal.common_cuda import TEST_CUDA
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from unittest.mock import patch
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
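# 2-D Rosenbrock function f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2; its global minimum is at (1, 1).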
def rosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return (1 - x) ** 2 + 100 * (y - x**2) ** 2
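# Analytic gradient of the Rosenbrock function: (df/dx, df/dy).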
def drosenbrock(tensor):
assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
x, y = tensor
return torch.tensor((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2)))
@skipIfTorchDynamo("This is a TEMPORARY stopgap, see https://github.com/pytorch/pytorch/issues/103322")
class TestOptim(TestCase):
exact_dtype = True
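# Runs cyclic coordinate descent on the Rosenbrock function, feeding the optimizer sparse
# gradients (and, unless sparse_only, dense gradients to a cloned copy for comparison), and
# checks that the parameters approach the minimum at (1, 1), or, when maximize=True, that
# the Rosenbrock value has not decreased.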
def _test_rosenbrock_sparse(
self,
constructor,
scheduler_constructors=None,
sparse_only=False,
maximize=False,
multi_tensor=False
):
if scheduler_constructors is None:
scheduler_constructors = []
# The Rosenbrock tests require each param to be a tensor holding exactly 2 values
if multi_tensor:
params_t = [torch.tensor([1.5, 1.5]), torch.tensor([1.5, 1.5], dtype=torch.float64)]
else:
params_t = [torch.tensor([1.5, 1.5])]
params = [Parameter(param_t) for param_t in params_t]
optimizer = constructor(params)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
if not sparse_only:
params_c = [Parameter(param_t.clone()) for param_t in params_t]
optimizer_c = constructor(params_c)
solution = torch.tensor([1, 1])
with torch.no_grad():
initial_dist = sum([param.dist(solution) for param in params])
def get_grad(param, sparse_grad):
grad = drosenbrock(param)
# NB: We torture test the optimizer by returning an
# uncoalesced sparse tensor
# Depending on w, provide only the x or y gradient
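# e.g. for w=1 the x-gradient is split across a repeated index: indices [[0, 0]] with
# values [x/4, x - x/4], which coalesce back to the full value x at position 0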
if sparse_grad:
if w:
i = torch.LongTensor([[0, 0]])
x = grad[0]
v = torch.tensor([x / 4.0, x - x / 4.0])
else:
i = torch.LongTensor([[1, 1]])
y = grad[1]
v = torch.tensor([y - y / 4.0, y / 4.0])
grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
else:
if w:
grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
else:
grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
return grad_out
def eval(params, sparse_grad, w):
optimizer.zero_grad()
if multi_tensor:
loss = sum(rosenbrock(param) for param in params)
else:
loss = rosenbrock(params[0])
loss.backward()
grads_out = [get_grad(param, sparse_grad) for param in params]
with torch.no_grad():
params[0].grad = grads_out[0]
if multi_tensor:
params[1].grad = grads_out[1].to(dtype=torch.float64)
return loss
for i in range(2000):
# Do cyclic coordinate descent
w = i % 2
optimizer.step(functools.partial(eval, params, True, w))
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(rosenbrock(params[0]))
else:
scheduler.step()
if not sparse_only:
optimizer_c.step(functools.partial(eval, params_c, False, w))
# Tolerance is increased due to floating point error from the different
# code path taken in the dense case: x vs. x - x / 4.0 + x / 4.0
self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)
if not maximize:
self.assertLessEqual(
sum([param.dist(solution) for param in params]),
initial_dist
)
else:
self.assertGreaterEqual(
sum([rosenbrock(param) for param in params]),
sum([rosenbrock(param_t) for param_t in params_t]),
)
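# Trains a small weight/bias model for 200 steps under every supported combination of
# maximize/foreach, stepping any provided LR schedulers along the way, and verifies that
# the loss improves (or worsens when maximize=True).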
def _test_basic_cases_template(
self,
weight_tensor,
bias_tensor,
input_tensor,
constructor,
scheduler_constructors,
constructor_accepts_maximize=True,
constructor_accepts_foreach=False,
):
maximize_options = {False, constructor_accepts_maximize}
foreach_options = {False, constructor_accepts_foreach}
four_arg_constructor = constructor
if constructor_accepts_maximize and constructor_accepts_foreach:
pass
elif constructor_accepts_maximize:
def four_arg_constructor(weight, bias, maximize, foreach): # noqa: F811
self.assertFalse(foreach)
return constructor(weight, bias, maximize)
elif constructor_accepts_foreach:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize)
return constructor(weight, bias, foreach)
else:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize or foreach)
return constructor(weight, bias)
for maximize, foreach in itertools.product(maximize_options, foreach_options):
with torch.no_grad():
weight = Parameter(weight_tensor.clone().detach())
bias = Parameter(bias_tensor.clone().detach())
input = input_tensor.clone().detach().requires_grad_()
optimizer = four_arg_constructor(weight, bias, maximize, foreach)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _ in range(200):
optimizer.step(fn)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
val_loss = fn()
scheduler.step(val_loss)
else:
scheduler.step()
if maximize:
self.assertGreater(fn().item(), initial_value)
else:
self.assertLess(fn().item(), initial_value)
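# Runs the template above with contiguous, non-contiguous, CUDA, and multi-GPU parameters.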
def _test_basic_cases(
self,
constructor,
scheduler_constructors=None,
constructor_accepts_maximize=False,
constructor_accepts_foreach=False,
):
if scheduler_constructors is None:
scheduler_constructors = []
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# non-contiguous parameters
self._test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# CUDA
if not torch.cuda.is_available():
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(),
torch.randn(10).cuda(),
torch.randn(5).cuda(),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# Multi-GPU
if not torch.cuda.device_count() > 1:
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(1),
torch.randn(5).cuda(0),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
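# Optimizes a complex parameter and, in parallel, its real and imaginary parts as two real
# parameters, checking that both paths produce identical gradients and updates.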
def _test_complex_2d(self, optimizer_constructor):
a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
a1_real = a1.real.clone().detach()
a1_imag = a1.imag.clone().detach()
a1_real.requires_grad_()
a1_imag.requires_grad_()
optim1 = optimizer_constructor([a1])
optim2 = optimizer_constructor([a1_real, a1_imag])
for _ in range(10):
optim1.zero_grad()
optim2.zero_grad()
a2 = torch.complex(a1_real, a1_imag)
rosenbrock(a1).abs().backward()
rosenbrock(a2).abs().backward()
self.assertEqual(a1.grad.real, a1_real.grad)
self.assertEqual(a1.grad.imag, a1_imag.grad)
optim1.step()
optim2.step()
self.assertEqual(a1.real, a1_real)
self.assertEqual(a1.imag, a1_imag)
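# Builds two param groups: the weight with default options and the bias with per-group overrides.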
def _build_params_dict(self, weight, bias, **kwargs):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_sgd_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
lambda params: SGD(params, lr=4.8e-3, foreach=foreach),
multi_tensor=foreach,
)
self._test_rosenbrock_sparse(
lambda params: SGD(params, lr=0.0048, foreach=foreach),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
multi_tensor=foreach,
)
def test_adam(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
lambda opt: ExponentialLR(opt, gamma=0.9),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
[weight, bias],
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
amsgrad=True,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adam(
self._build_params_dict(weight, bias, lr=1e-2),
lr=torch.tensor(1e-3),
maximize=maximize,
foreach=False, # foreach for lr tensors tested in multi configs
),
[lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adam_complex(self):
for foreach in (False, True):
self._test_complex_2d(functools.partial(Adam, foreach=foreach))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, amsgrad=True))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, weight_decay=0.2))
self._test_complex_2d(functools.partial(Adam, foreach=foreach, weight_decay=0.2, amsgrad=True))
self._test_complex_2d(Adam)
self._test_complex_2d(functools.partial(
Adam, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True,
))
def test_adamw(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: AdamW(
[weight, bias],
lr=torch.tensor(1e-3),
weight_decay=1,
amsgrad=True,
maximize=maximize,
foreach=False, # foreach for lr tensors tested in multi configs
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adamw_complex(self):
self._test_complex_2d(AdamW)
self._test_complex_2d(functools.partial(
AdamW, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True,
))
for foreach in (False, True):
self._test_complex_2d(functools.partial(AdamW, foreach=foreach))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, amsgrad=True))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, weight_decay=0.2))
self._test_complex_2d(functools.partial(AdamW, foreach=foreach, weight_decay=0.2, amsgrad=True))
def test_sparse_adam(self):
self._test_rosenbrock_sparse(
lambda params: SparseAdam(params, lr=4e-2), [], True
)
self._test_rosenbrock_sparse(
lambda params: SparseAdam(params, lr=4e-2, maximize=True),
scheduler_constructors=[],
sparse_only=True,
maximize=True,
)
# ROCm precision is too low to pass this test
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adadelta(
self._build_params_dict(weight, bias, rho=0.95),
maximize=maximize,
foreach=foreach,
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_nadam(self):
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
# NAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: NAdam(
[weight, bias],
lr=1e-3,
weight_decay=0.1,
momentum_decay=6e-3,
decoupled_weight_decay=True,
foreach=foreach,
),
[lambda opt: ExponentialLR(opt, gamma=0.9)],
constructor_accepts_foreach=True,
)
def test_adagrad(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[lambda opt: ReduceLROnPlateau(opt)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: Adagrad(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-1,
maximize=maximize,
foreach=foreach,
),
[
lambda opt: ReduceLROnPlateau(opt),
lambda opt: ExponentialLR(opt, gamma=0.99),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
def test_adagrad_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
lambda params: Adagrad(params, lr=1e-1, foreach=foreach),
multi_tensor=foreach,
)
self._test_rosenbrock_sparse(
lambda params: Adagrad(params, lr=0.1, foreach=foreach),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
],
multi_tensor=foreach,
)
def test_adamax(self):
self._test_complex_2d(Adamax)
self._test_complex_2d(functools.partial(Adamax, foreach=True))
def test_radam(self):
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
# RAdamW tests
self._test_basic_cases(
lambda weight, bias, foreach: RAdam(
[weight, bias], lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True, foreach=foreach
),
[
lambda opt: ExponentialLR(opt, gamma=0.9),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_foreach=True,
)
def test_rmsprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: RMSprop(param, foreach=foreach))
self._test_complex_2d(
lambda param: RMSprop(param, centered=True, foreach=foreach)
)
self._test_complex_2d(
lambda param: RMSprop(param, momentum=0.1, foreach=foreach)
)
self._test_complex_2d(
lambda param: RMSprop(param, maximize=True, foreach=foreach)
)
@skipIfRocm
@skipIfTorchDynamo()
def test_rprop(self):
for foreach in (False, True):
self._test_complex_2d(lambda param: Rprop(param, foreach=foreach))
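# LBFGS.step returns the closure's loss; the return type should be the same whether or not
# the tolerance_grad early-exit path is taken.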
def test_lbfgs_returns_consistent_type(self):
params = [torch.randn(10, 5), torch.randn(10)]
opt1 = LBFGS(params, 0.01, tolerance_grad=math.inf)
opt2 = LBFGS(params, 0.01, tolerance_grad=-math.inf)
def closure():
return torch.tensor([10])
res1 = opt1.step(closure)
res2 = opt2.step(closure)
self.assertEqual(type(res1), type(res2))
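# With found_inf set to a nonzero value (as a grad scaler sets it after detecting inf/NaN
# grads), the fused implementations must skip the update: params and state_steps stay unchanged.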
def test_fused_optimizer_does_not_step_if_foundinf(self):
if not torch.cuda.is_available():
self.skipTest("CUDA is required.")
from torch.optim import adam, adamw, sgd
num_tensors = 5
for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw), (False, True), (False, True)):
params, grads, exp_avgs, exp_avg_sqs = (
[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(4))
prev_params = [t.clone().detach() for t in params]
max_exp_avg_sqs = [torch.ones((1,), device="cuda") for _ in range(num_tensors)] if amsgrad else []
state_steps = [torch.ones((), dtype=torch.float32, device="cuda") for _ in range(num_tensors)]
grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
found_inf = torch.ones((), dtype=torch.float32, device="cuda")
functional_optim(
params,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
foreach=False,
capturable=False,
fused=True,
amsgrad=amsgrad,
beta1=0.9,
beta2=0.99,
lr=1e-2,
weight_decay=0.0,
eps=1e-8,
maximize=False,
grad_scale=grad_scale,
found_inf=found_inf,
)
self.assertEqual(
state_steps,
[
torch.ones((), dtype=torch.float32, device="cuda")
for _ in range(num_tensors)
],
)
self.assertEqual(params, prev_params)
else:
for momentum in (0.0, 0.1):
params, d_p_list, momentum_buffer_list = (
[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(3))
if momentum == 0.0:
momentum_buffer_list = [None for _ in range(num_tensors)]
prev_params = [t.clone().detach() for t in params]
grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
found_inf = torch.ones((), dtype=torch.float32, device="cuda")
sgd.sgd(
params,
d_p_list,
momentum_buffer_list,
has_sparse_grad=False,
foreach=False,
fused=True,
grad_scale=grad_scale,
found_inf=found_inf,
weight_decay=0.0,
momentum=momentum,
lr=0.01,
dampening=0.0,
nesterov=False,
maximize=False,
)
self.assertEqual(params, prev_params)
@skipIfTorchDynamo()
def test_post_hook(self):
def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data += 2
params = [torch.Tensor([1, 1])]
opt = SGD(params, lr=0.001)
data = 2
hook_handle = opt.register_step_post_hook(post_hook)
opt.step()
opt.step()
# check that the post hook ran on both steps
self.assertEqual(data, 6)
# remove the handle, take a step, and verify that the hook no longer fires
hook_handle.remove()
opt.step()
self.assertEqual(data, 6)
@skipIfTorchDynamo()
def test_pre_hook(self):
def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data += 2
params = [torch.Tensor([1, 1])]
opt = SGD(params, lr=0.001)
data = 5
hook_handle = opt.register_step_pre_hook(pre_hook)
opt.step()
opt.step()
# check that the pre hook ran on both steps
self.assertEqual(data, 9)
# remove the handle, take a step, and verify that the hook no longer fires
hook_handle.remove()
opt.step()
self.assertEqual(data, 9)
@skipIfTorchDynamo()
def test_pre_and_post_hook(self):
def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(0)
def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(5)
def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(1)
def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
nonlocal data
data.append(2)
params = [torch.Tensor([1, 1])]
opt1 = SGD(params, lr=0.001)
opt2 = Adam(params, lr=0.01)
data = []
# register global hooks, which fire for every optimizer
global_pre_handle = register_optimizer_step_pre_hook(global_pre_hook)
global_post_handle = register_optimizer_step_post_hook(global_post_hook)
# register local hooks
first_pre_handle = opt1.register_step_pre_hook(local_pre_hook)
first_post_handle = opt1.register_step_post_hook(local_post_hook)
second_pre_handle = opt2.register_step_pre_hook(local_pre_hook)
second_post_handle = opt2.register_step_post_hook(local_post_hook)
opt1.step()
self.assertListEqual(data, [0, 1, 2, 5])
opt2.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
opt1.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
# remove all hooks
global_pre_handle.remove()
global_post_handle.remove()
first_pre_handle.remove()
first_post_handle.remove()
second_pre_handle.remove()
second_post_handle.remove()
opt1.step()
opt2.step()
self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
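# Hooks shared by the state_dict / load_state_dict hook tests below.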
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
optimizer.state["test"] = 1
@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
if "test" in state_dict["state"]:
state_dict["state"].pop("test")
state_dict["ran_state_dict_pre_hook"] = True
else:
state_dict["ran_state_dict_pre_hook"] = False
return state_dict
@staticmethod
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
state_dict["param_groups"][0]["lr"] = 0.002
@staticmethod
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
# The typical use case for returning a state dict from the pre-hook is to modify it substantially.
# Simulate that by returning a deep copy and verifying that my_state_dict is the dict actually loaded
my_state_dict = deepcopy(state_dict)
my_state_dict["param_groups"][0]["lr"] = 0.003
return my_state_dict
@staticmethod
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
optimizer.state["ran_load_state_dict_post_hook"] = True
def test_state_dict_pre_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["state"]["test"], 1)
def test_state_dict_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertEqual(state_dict["ran_state_dict_pre_hook"], False)
def test_state_dict_pre_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
opt.register_state_dict_post_hook(self._state_dict_post_hook)
state_dict = opt.state_dict()
self.assertFalse("test" in state_dict["state"])
self.assertEqual(state_dict["ran_state_dict_pre_hook"], True)
def test_load_state_dict_pre_hook_and_prepend(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
state_dict = opt.state_dict()
# usually one would load the state dict into a fresh optimizer instance; reusing the same one is equivalent for this test
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
opt.load_state_dict(state_dict)
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
opt.load_state_dict(state_dict)
# If prepend were False, lr would be 0.003; since prepend is True, hook1 runs last and overrides it back to 0.002
self.assertEqual(opt.param_groups[0]["lr"], 0.002)
def test_load_state_dict_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
opt.load_state_dict(opt.state_dict())
self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
def test_load_state_dict_pre_post_hook(self):
param = torch.rand(2, 3, requires_grad=True)
param.grad = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
opt.load_state_dict(opt.state_dict())
self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
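# Clones the parameter and its differentiable state, runs a single optimizer step, and returns
# every tensor that requires grad so gradcheck can differentiate through the step.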
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
# `ignored` collects the values of `opt_differentiable_state`; they are passed separately so
# that `gradcheck` can track the state tensors as function inputs, because it cannot unpack
# the values inside the `opt_differentiable_state` dict on its own
p = p.clone()
p.grad = grad
opt_differentiable_state = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in opt_differentiable_state.items()
}
opt = opt_class([p], **kwargs)
opt.state[p].update(opt_differentiable_state)
opt.step()
return (p,) + tuple(
v
for v in opt.state[p].values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
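# Uses gradcheck to verify that one step of each optimizer (with differentiable=True) is
# differentiable with respect to the parameter, its gradient, and the optimizer state.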
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
def test_sgd(self):
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
state = {"momentum_buffer": mbuff}
gradcheck(
_diff_fn,
(
p,
grad,
state,
SGD,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_adam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adam,
{"lr": 0.9, "differentiable": True, "amsgrad": True},
*state.values(),
),
)
def test_rmsprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["step"] = 0
state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["momentum_buffer"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
# Keep grad_avg small: with centered=True, large values can make square_avg - grad_avg**2 negative and the sqrt produce NaNs
state["grad_avg"] = 1e-2 * torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RMSprop,
{
"lr": 0.9,
"maximize": True,
"momentum": 0.9,
"differentiable": True,
"centered": True,
"weight_decay": 0.1,
},
*state.values(),
),
)
def test_adadelta(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adadelta,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
def test_adagrad(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adagrad,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
def test_adamax(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Adamax,
{"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
*state.values(),
),
)
@skipIfTorchDynamo("The inplace mu update fails with dynamo, "
"since this is only happening when differentiable is enabled, skipping for now")
def test_asgd(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step`, `eta`, and `mu` are not continuous variables (even though we define them as floats)
# and so they shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64)
state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64)
state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
ASGD,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_rprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
Rprop,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
def test_adamw(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["max_exp_avg_sq"] = torch.rand(
10, requires_grad=True, dtype=torch.float64
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
AdamW,
{"lr": 0.9, "differentiable": True, "amsgrad": True},
*state.values(),
),
)
def test_nadam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
NAdam,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
NAdam,
{"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
*state.values(),
),
)
def test_radam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
# `step` is not a continuous variable (even though we define it as a float)
# and so it shouldn't require gradients.
state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RAdam,
{"lr": 0.9, "differentiable": True},
*state.values(),
),
)
gradcheck(
_diff_fn,
(
p,
grad,
state,
RAdam,
{"lr": 0.9, "weight_decay": 0.1, "decoupled_weight_decay": True, "differentiable": True},
*state.values(),
),
)
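# When foreach is left unset and the params live on CUDA, the optimizers should dispatch to
# their multi-tensor (foreach) implementations by default.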
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_defaults_changed_to_foreach(self):
from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
asgd, adamax, adadelta, adagrad)
multi_optims = ((Adam, adam, "_multi_tensor_adam"),
(AdamW, adamw, "_multi_tensor_adamw"),
(NAdam, nadam, "_multi_tensor_nadam"),
(SGD, sgd, "_multi_tensor_sgd"),
(RAdam, radam, "_multi_tensor_radam"),
(RMSprop, rmsprop, "_multi_tensor_rmsprop"),
(Rprop, rprop, "_multi_tensor_rprop"),
(ASGD, asgd, "_multi_tensor_asgd"),
(Adamax, adamax, "_multi_tensor_adamax"),
(Adadelta, adadelta, "_multi_tensor_adadelta"),
(Adagrad, adagrad, "_multi_tensor_adagrad"),)
model = torch.nn.Linear(5, 5)
model.to(dtype=torch.float64, device="cuda")
input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
for opt, mod, func in multi_optims:
defaults = {}
if opt == SGD:
defaults["lr"] = 1e-2
optimizer = opt(model.parameters(), **defaults)
optimizer.zero_grad()
output = model(input)
loss = output.sum()
loss.backward()
with patch.object(mod, func) as mocked_foreach_impl:
optimizer.step()
self.assertTrue(mocked_foreach_impl.called)
if __name__ == "__main__":
print("These tests should be run through test/test_optim.py instead")