blob: 22120aa0841d8584918612f6d9086c02180af6e6 [file] [log] [blame]
# Owner(s): ["module: optimizer"]
import math
import unittest
import functools
import itertools
from copy import deepcopy
import torch
import torch.optim as optim
from torch.nn import Parameter
from torch.optim import Adam, SGD, Optimizer
from torch.optim.lr_scheduler import (
StepLR,
ConstantLR,
LinearLR,
ExponentialLR,
ReduceLROnPlateau,
PolynomialLR,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
gradcheck,
skipIfRocm,
skipIfTorchDynamo
)
from torch._dynamo import disable as disable_dynamo
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_device_type import largeTensorTest
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
from unittest.mock import patch
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. The self-assignment below exists only so that the
# imported name counts as used; this line silences flake warnings.
load_tests = load_tests
def rosenbrock(tensor):
    """Evaluate the two-variable Rosenbrock function at ``tensor``.

    ``tensor`` must hold exactly two scalars ``(x, y)``; the value returned is
    ``(1 - x)**2 + 100 * (y - x**2)**2``.
    """
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    offset_term = (1 - x) ** 2
    valley_term = 100 * (y - x**2) ** 2
    return offset_term + valley_term
def drosenbrock(tensor):
    """Return the analytic gradient of the Rosenbrock function at ``tensor``.

    ``tensor`` must hold exactly two scalars ``(x, y)``. The result is a new
    two-element tensor ``(d/dx, d/dy)`` built with ``torch.tensor``.
    """
    assert tensor.size() == torch.Size([2]), f"Requires tensor with 2 scalars but got {tensor.size()}"
    x, y = tensor
    # (y - x**2) appears in both partial derivatives.
    residual = y - x**2
    grad_x = -400 * x * residual - 2 * (1 - x)
    grad_y = 200 * residual
    return torch.tensor((grad_x, grad_y))
@skipIfTorchDynamo("This is a TEMPORARY stopgap, see https://github.com/pytorch/pytorch/issues/103322")
class TestOptim(TestCase):
exact_dtype = True
def _test_rosenbrock_sparse(
    self,
    constructor,
    scheduler_constructors=None,
    sparse_only=False,
    maximize=False,
    multi_tensor=False
):
    """Optimize the Rosenbrock function with sparse gradients and check progress.

    Runs 2000 steps of cyclic coordinate descent: each step supplies only the
    x or only the y component of the gradient, encoded as an uncoalesced
    sparse tensor. Unless ``sparse_only`` is set, a second optimizer is fed
    the equivalent dense gradient and both parameter sets must agree (within
    tolerance) after every step.

    Args:
        constructor: callable mapping a list of parameters to an optimizer.
        scheduler_constructors: optional callables mapping the sparse
            optimizer to an LR scheduler; each is stepped once per iteration.
        sparse_only: when True, skip the dense reference optimizer.
        maximize: when True, assert that the summed objective did not
            decrease; otherwise assert that the summed distance to the
            minimum ``(1, 1)`` did not grow.
        multi_tensor: when True, also optimize a second, float64 parameter.
    """
    if scheduler_constructors is None:
        scheduler_constructors = []
    # For rosenbrock tests, it is mandated that the param is a tensor with 2 numbers
    params_t = [torch.tensor([1.5, 1.5])]
    if multi_tensor:
        params_t.append(torch.tensor([1.5, 1.5], dtype=torch.float64))
    params = [Parameter(param_t) for param_t in params_t]
    optimizer = constructor(params)
    schedulers = [
        scheduler_constructor(optimizer)
        for scheduler_constructor in scheduler_constructors
    ]

    if not sparse_only:
        params_c = [Parameter(param_t.clone()) for param_t in params_t]
        optimizer_c = constructor(params_c)

    solution = torch.tensor([1, 1])
    with torch.no_grad():
        initial_dist = sum([param.dist(solution) for param in params])
        # Record the starting objective NOW. Each Parameter above shares its
        # storage with the matching entry of params_t, so evaluating
        # rosenbrock(param_t) after the optimization loop would just
        # re-evaluate the optimized values and make the maximize check vacuous.
        initial_value = sum([rosenbrock(param_t) for param_t in params_t])

    def get_grad(param, sparse_grad, w):
        grad = drosenbrock(param)
        # NB: We torture test the optimizer by returning an
        # uncoalesced sparse tensor

        # Depending on w, provide only the x or y gradient
        if sparse_grad:
            if w:
                i = torch.LongTensor([[0, 0]])
                x = grad[0]
                v = torch.tensor([x / 4.0, x - x / 4.0])
            else:
                i = torch.LongTensor([[1, 1]])
                y = grad[1]
                v = torch.tensor([y - y / 4.0, y / 4.0])
            grad_out = torch.sparse_coo_tensor(i, v, (2,), dtype=v.dtype)
        else:
            if w:
                grad_out = torch.tensor([grad[0], 0], dtype=param.dtype)
            else:
                grad_out = torch.tensor([0, grad[1]], dtype=param.dtype)
        return grad_out

    def closure(params, sparse_grad, w):
        # Closure handed to Optimizer.step: compute the loss, then replace the
        # autograd gradients with the hand-built single-coordinate gradients.
        optimizer.zero_grad()
        if multi_tensor:
            loss = sum(rosenbrock(param) for param in params)
        else:
            loss = rosenbrock(params[0])
        loss.backward()

        grads_out = [get_grad(param, sparse_grad, w) for param in params]
        with torch.no_grad():
            params[0].grad = grads_out[0]
            if multi_tensor:
                params[1].grad = grads_out[1].to(dtype=torch.float64)
        return loss

    for i in range(2000):
        # Do cyclic coordinate descent
        w = i % 2
        optimizer.step(functools.partial(closure, params, True, w))
        for scheduler in schedulers:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(rosenbrock(params[0]))
            else:
                scheduler.step()
        if not sparse_only:
            optimizer_c.step(functools.partial(closure, params_c, False, w))
            # Tolerance is increased due to floating point error from different
            # code path for dense case: x v.s. x - x / 4.0 + x / 4.0
            self.assertEqual(params, params_c, atol=5e-6, rtol=5e-6)

    if not maximize:
        self.assertLessEqual(
            sum([param.dist(solution) for param in params]),
            initial_dist
        )
    else:
        self.assertGreaterEqual(
            sum([rosenbrock(param) for param in params]),
            initial_value,
        )
def _test_basic_cases_template(
self,
weight_tensor,
bias_tensor,
input_tensor,
constructor,
scheduler_constructors,
constructor_accepts_maximize=True,
constructor_accepts_foreach=False,
):
maximize_options = {False, constructor_accepts_maximize}
foreach_options = {False, constructor_accepts_foreach}
four_arg_constructor = constructor
if constructor_accepts_maximize and constructor_accepts_foreach:
pass
elif constructor_accepts_maximize:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(foreach)
return constructor(weight, bias, maximize)
elif constructor_accepts_foreach:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize)
return constructor(weight, bias, foreach)
else:
def four_arg_constructor(weight, bias, maximize, foreach):
self.assertFalse(maximize or foreach)
return constructor(weight, bias)
for maximize, foreach in itertools.product(maximize_options, foreach_options):
with torch.no_grad():
weight = Parameter(weight_tensor.clone().detach())
bias = Parameter(bias_tensor.clone().detach())
input = input_tensor.clone().detach().requires_grad_()
optimizer = four_arg_constructor(weight, bias, maximize, foreach)
schedulers = []
for scheduler_constructor in scheduler_constructors:
schedulers.append(scheduler_constructor(optimizer))
# to check if the optimizer can be printed as a string
optimizer.__repr__()
def fn():
optimizer.zero_grad()
y = weight.mv(input)
if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
y = y.cuda(bias.get_device())
loss = (y + bias).pow(2).sum()
loss.backward()
return loss
initial_value = fn().item()
for _ in range(200):
optimizer.step(fn)
for scheduler in schedulers:
if isinstance(scheduler, ReduceLROnPlateau):
val_loss = fn()
scheduler.step(val_loss)
else:
scheduler.step()
if maximize:
self.assertGreater(fn().item(), initial_value)
else:
self.assertLess(fn().item(), initial_value)
# Note: disable dynamo on this function
# This allows us to continue running actual logic of the optimizer
# tests in dynamo without tracing this test code which has a lot of unsupported
# behavior
@disable_dynamo(recursive=False)
def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
    """Exercise ``state_dict`` / ``load_state_dict`` round-trips of an optimizer.

    The optimizer built by ``constructor(weight, bias)`` is primed for 20
    steps, its state is loaded into a second optimizer over cloned
    parameters, and both are stepped in lockstep and compared. The test then
    checks loading state dicts with the "maximize" / "foreach" keys removed,
    loading a state whose "step" entries are plain numbers, and (when CUDA is
    available) loading into float32 CUDA parameters. Finally it verifies that
    ``deepcopy`` keeps every public attribute of the optimizer.

    Args:
        weight, bias: tensors wrapped into ``Parameter`` objects.
        input: tensor used as the right-hand side of ``weight.mv``.
        constructor: callable ``(weight, bias) -> optimizer``.
        atol, rtol: tolerances forwarded only to the CPU-vs-CUDA bias check.

    NOTE(review): the source this was taken from had lost its indentation;
    the nesting below (in particular which assertions sit inside the
    stepping loops) was reconstructed from context -- confirm against
    upstream.
    """
    weight = Parameter(weight)
    bias = Parameter(bias)
    with torch.no_grad():
        input = input.clone().detach().requires_grad_()

    # Note: Disable dynamo on this function
    # This avoids a bug where input_cuda is not detected in the environment
    # because it currently is not defined in the local environment. Unable to repro
    # anywhere else however and this is test code that we don't need to spend
    # time getting dynamo to trace unless the issue repros in real models.
    @disable_dynamo(recursive=False)
    def fn_base(optimizer, weight, bias):
        optimizer.zero_grad()
        # input_cuda is only assigned further down, in the CUDA section; it is
        # looked up lazily, and only when the weight lives on a CUDA device.
        i = input_cuda if weight.is_cuda else input
        loss = (weight.mv(i) + bias).pow(2).sum()
        loss.backward()
        return loss

    optimizer = constructor(weight, bias)
    fn = functools.partial(fn_base, optimizer, weight, bias)

    # Prime the optimizer
    for _i in range(20):
        optimizer.step(fn)
    # Clone the weights and construct new optimizer for them
    with torch.no_grad():
        weight_c = Parameter(weight.clone().detach())
        bias_c = Parameter(bias.clone().detach())
    optimizer_c = constructor(weight_c, bias_c)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict_c)
    # Run both optimizers in parallel
    for _ in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        self.assertEqual(weight, weight_c)
        self.assertEqual(bias, bias_c)
    # Make sure state dict is deterministic with equal but not identical parameters
    self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
    # Make sure repeated parameters have identical representation in state dict
    optimizer_c.param_groups.extend(optimizer_c.param_groups)
    self.assertEqual(
        optimizer.state_dict()["param_groups"][-1],
        optimizer_c.state_dict()["param_groups"][-1],
    )

    # Make sure that optimizers that support maximize can load older models
    old_state_dict = deepcopy(optimizer.state_dict())
    state_dict_no_maximize = deepcopy(optimizer.state_dict())
    if "maximize" in state_dict_no_maximize["param_groups"][0]:
        for group in state_dict_no_maximize["param_groups"]:
            del group["maximize"]
    optimizer.load_state_dict(state_dict_no_maximize)
    # Make sure we can still step
    optimizer.step()
    # Undo these changes before proceeding!
    optimizer.load_state_dict(old_state_dict)
    # Make sure that optimizers that support foreach can load older models
    state_dict_no_foreach = deepcopy(optimizer.state_dict())
    if "foreach" in state_dict_no_foreach["param_groups"][0]:
        for group in state_dict_no_foreach["param_groups"]:
            del group["foreach"]
    optimizer.load_state_dict(state_dict_no_foreach)
    # Make sure we can still step
    optimizer.step()
    # Undo these changes before proceeding!
    optimizer.load_state_dict(old_state_dict)

    # Make sure that loading optimizers with step not wrapped in tensor can work
    state_dict = optimizer.state_dict()
    if "step" in state_dict["state"][0] and torch.is_tensor(
        state_dict["state"][0]["step"]
    ):
        for state in state_dict["state"].values():
            state["step"] = state["step"].item()
        optimizer.load_state_dict(state_dict)
        optimizer.step()

    # Check that state dict can be loaded even when we cast parameters
    # to a different type and move to a different device.
    if not torch.cuda.is_available():
        return

    with torch.no_grad():
        input_cuda = input.clone().detach().to(dtype=torch.float32, device="cuda")
        weight_cuda = Parameter(
            weight.clone().detach().to(dtype=torch.float32, device="cuda")
        )
        bias_cuda = Parameter(
            bias.clone().detach().to(dtype=torch.float32, device="cuda")
        )
    optimizer_cuda = constructor(weight_cuda, bias_cuda)
    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)

    state_dict = deepcopy(optimizer.state_dict())
    state_dict_c = deepcopy(optimizer.state_dict())
    optimizer_cuda.load_state_dict(state_dict_c)

    # Make sure state_dict_c isn't modified by merely calling load_state_dict
    self.assertEqual(state_dict, state_dict_c)

    # Make sure that device of state['step'] is still CPU
    new_state_dict = optimizer_cuda.state_dict()
    if "step" in state_dict["state"][0] and torch.is_tensor(
        state_dict["state"][0]["step"]
    ):
        for state in new_state_dict["state"].values():
            self.assertEqual(state["step"].device.type, "cpu")

    for _ in range(20):
        optimizer.step(fn)
        optimizer_cuda.step(fn_cuda)
        self.assertEqual(weight, weight_cuda)
        self.assertEqual(bias, bias_cuda, atol=atol, rtol=rtol)

    # validate deepcopy() copies all public attributes
    def getPublicAttr(obj):
        return {k for k in obj.__dict__ if not k.startswith("_")}

    self.assertEqual(getPublicAttr(optimizer), getPublicAttr(deepcopy(optimizer)))
def _test_basic_cases(
self,
constructor,
scheduler_constructors=None,
ignore_multidevice=False,
constructor_accepts_maximize=False,
constructor_accepts_foreach=False,
atol=None,
rtol=None,
):
if scheduler_constructors is None:
scheduler_constructors = []
def make_two_arg_constructor(
constructor, maximize: bool, foreach: bool
):
if constructor_accepts_maximize and constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, maximize, foreach)
if constructor_accepts_maximize:
return lambda weight, bias: constructor(weight, bias, maximize)
if constructor_accepts_foreach:
return lambda weight, bias: constructor(weight, bias, foreach)
return constructor
for maximize, foreach in itertools.product(
{False, constructor_accepts_maximize},
{False, constructor_accepts_foreach},
):
self._test_state_dict(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
make_two_arg_constructor(constructor, maximize, foreach),
atol=atol,
rtol=rtol,
)
self._test_basic_cases_template(
torch.randn(10, 5),
torch.randn(10),
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# non-contiguous parameters
self._test_basic_cases_template(
torch.randn(10, 5, 2)[..., 0],
torch.randn(10, 2)[..., 0],
torch.randn(5),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# CUDA
if not torch.cuda.is_available():
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(),
torch.randn(10).cuda(),
torch.randn(5).cuda(),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
# Multi-GPU
if not torch.cuda.device_count() > 1 or ignore_multidevice:
return
self._test_basic_cases_template(
torch.randn(10, 5).cuda(0),
torch.randn(10).cuda(1),
torch.randn(5).cuda(0),
constructor,
scheduler_constructors,
constructor_accepts_maximize,
constructor_accepts_foreach,
)
def _test_complex_optimizer(self, optimizer_constructor):
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
complex_opt = optimizer_constructor(complex_param)
real_opt = optimizer_constructor(real_param)
for _ in range(3):
complex_param.grad = torch.randn_like(complex_param)
real_param.grad = torch.view_as_real(complex_param.grad)
complex_opt.step()
real_opt.step()
self.assertEqual(torch.view_as_real(complex_param), real_param)
def _test_complex_2d(self, optimizer_constructor):
    """Compare optimizing a complex pair with optimizing its real/imag parts.

    One optimizer owns a two-element complex64 tensor; a second owns detached
    copies of its real and imaginary parts. For ten steps both minimize
    ``|rosenbrock(.)|``; gradients and updated values must agree component
    by component on every step.

    Args:
        optimizer_constructor: callable mapping a list of parameters to an
            optimizer.
    """
    a1 = torch.randn(2, dtype=torch.complex64, requires_grad=True)
    a1_real = a1.real.clone().detach().requires_grad_()
    a1_imag = a1.imag.clone().detach().requires_grad_()
    optim1 = optimizer_constructor([a1])
    optim2 = optimizer_constructor([a1_real, a1_imag])

    for _ in range(10):
        optim1.zero_grad()
        optim2.zero_grad()
        # Re-assemble the complex value from the two real leaves each step.
        a2 = torch.complex(a1_real, a1_imag)
        for point in (a1, a2):
            rosenbrock(point).abs().backward()

        self.assertEqual(a1.grad.real, a1_real.grad)
        self.assertEqual(a1.grad.imag, a1_imag.grad)

        optim1.step()
        optim2.step()
        self.assertEqual(a1.real, a1_real)
        self.assertEqual(a1.imag, a1_imag)
def _build_params_dict(self, weight, bias, **kwargs):
return [{"params": [weight]}, dict(params=[bias], **kwargs)]
def _build_params_dict_single(self, weight, bias, **kwargs):
return [dict(params=bias, **kwargs)]
def test_sgd(self):
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
self._build_params_dict(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
self._build_params_dict_single(weight, bias, lr=1e-2),
lr=1e-3,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
self._build_params_dict_single(weight, bias, lr=1e-2),
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.9, step_size=10)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.8, total_iters=4
)
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[lambda opt: PolynomialLR(opt, power=0.9, total_iters=4)],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
scheduler_constructors=[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: LinearLR(
opt, start_factor=0.4, end_factor=0.6, total_iters=4
),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.9, step_size=10),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias], lr=1e-3, maximize=maximize, foreach=foreach
),
[
lambda opt: StepLR(opt, gamma=0.99, step_size=10),
lambda opt: ExponentialLR(opt, gamma=0.99),
lambda opt: ReduceLROnPlateau(opt),
],
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias],
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
self._test_basic_cases(
lambda weight, bias, maximize, foreach: optim.SGD(
[weight, bias],
nesterov=True,
lr=1e-3,
momentum=0.5,
weight_decay=1,
maximize=maximize,
foreach=foreach,
),
constructor_accepts_maximize=True,
constructor_accepts_foreach=True,
)
with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
optim.SGD(None, lr=1e-2, momentum=-0.5)
def test_sgd_sparse(self):
for foreach in (False, True):
self._test_rosenbrock_sparse(
lambda params: optim.SGD(params, lr=4.8e-3, foreach=foreach),
multi_tensor=foreach,
)
self._test_rosenbrock_sparse(
lambda params: optim.SGD(params, lr=0.0048, foreach=foreach),
scheduler_constructors=[lambda opt: StepLR(opt, gamma=0.99999, step_size=300)],
multi_tensor=foreach,
)
def test_sgd_complex(self):
for foreach in (False, True):
self._test_complex_optimizer(
lambda param: optim.SGD([param], lr=0.001, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: optim.SGD([param], lr=0.001, momentum=1, foreach=foreach)
)
self._test_complex_optimizer(
lambda param: optim.SGD(
[param], lr=0.001, momentum=1, weight_decay=1, foreach=foreach
)
)
self._test_complex_optimizer(
lambda param: optim.SGD(
[param],
lr=0.001,
nesterov=True,
momentum=1,
weight_decay=1,
foreach=foreach,
)
)
self._test_complex_optimizer(
lambda param: optim.SGD(
[param],
lr=0.001,
momentum=1,
dampening=0.5,
weight_decay=1,
foreach=foreach,
)
)
def _test_derived_optimizers_varying_tensors(self, optimizer_with_kwargs, kwarg):
    """Compare single-tensor and ``kwarg``-enabled optimizers on mixed params.

    For each ``(optimizer_constructor, kwargs)`` pair the optimizer is run
    twice on clones of the same parameters and gradients -- once with
    ``kwarg`` set to False and once set to True -- and both the resulting
    parameters and the per-parameter optimizer state must match within
    five times the default per-dtype tolerances.

    Args:
        optimizer_with_kwargs: iterable of ``(optimizer class, kwargs dict)``.
        kwarg: name of the flag to toggle, either "foreach" or "fused".

    Returns immediately when CUDA is unavailable. The body addresses both
    ``cuda:0`` and ``cuda:1``.
    """
    if not torch.cuda.is_available():
        return
    assert kwarg in ("foreach", "fused")

    # Specifically test that inputting params of different dtypes and devices
    # is handled equivalently on the foreach and fused implementations as the
    # single tensor implementations. We need multiple GPUs (vs just a CPU and
    # GPU) because fused adam only works on GPUs. (Thus we only run the tests
    # that call into this helper when TEST_MULTIGPU.)
    params = [
        torch.rand(2, 3, dtype=torch.float64, device='cuda:0', requires_grad=True),
        torch.rand(2, 3, dtype=torch.float32, device='cuda:0', requires_grad=True),
        torch.rand(2, 3, dtype=torch.float16, device='cuda:0', requires_grad=True),
        torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:0', requires_grad=True),
        torch.rand(2, 3, dtype=torch.float64, device='cuda:1', requires_grad=True),
        torch.rand(2, 3, dtype=torch.float32, device='cuda:1', requires_grad=True),
        torch.rand(2, 3, dtype=torch.float16, device='cuda:1', requires_grad=True),
        torch.rand(2, 3, dtype=torch.bfloat16, device='cuda:1', requires_grad=True),
        torch.randint(1024, (2, 3), dtype=torch.int64, device='cuda:1', requires_grad=False),
    ]

    # Give every trainable tensor a random gradient of matching dtype/device.
    for p in params:
        if p.requires_grad:
            p.grad = torch.rand_like(p, device=p.device, dtype=p.dtype)

    # The foreach comparison runs 7 steps, the fused comparison a single step.
    kIterations = 7 if kwarg == "foreach" else 1
    for optimizer_constructor, kwargs in optimizer_with_kwargs:
        # Index 0 holds the flag-disabled run, index 1 the flag-enabled run.
        res, state = [], []
        for enabled in (False, True):
            kwargs_clone = deepcopy(kwargs)
            if optimizer_constructor.__name__ == "ASGD" and kwarg == "foreach" and not enabled:
                # single tensor ASGD does not support capturable
                kwargs_clone["capturable"] = False
            kwargs_clone[kwarg] = enabled

            params_clone = []
            for p in params:
                p_clone = p.clone().detach()
                if p.requires_grad:
                    p_clone.requires_grad = True
                    p_clone.grad = p.grad.clone().detach()
                    # NOTE(review): the source had lost its indentation; this
                    # append is assumed to sit inside the requires_grad branch
                    # (so the int64 tensor is not optimized) -- confirm
                    # against upstream.
                    params_clone.append(p_clone)

            optimizer = optimizer_constructor(params_clone, **kwargs_clone)
            for _ in range(kIterations):
                optimizer.step()

            state.append(optimizer.state)
            res.append(params_clone)

        st_state = state[0]
        mt_state = state[1]
        for st_p, mt_p in zip(res[0], res[1]):
            # Increasing the tolerance as we are collating lots of ops together for optimizers and
            # the designated tolerances are for single op only.
            single_rtol, single_atol = torch.testing._comparison.get_tolerances(mt_p.dtype, rtol=None, atol=None)
            rtol = 5 * single_rtol
            atol = 5 * single_atol

            self.assertEqual(st_p, mt_p, rtol=rtol, atol=atol)

            # check that optimizer states are the same
            st_p_state = st_state[st_p]
            mt_p_state = mt_state[mt_p]

            for k in st_p_state:
                actual = mt_p_state[k]
                self.assertEqual(st_p_state[k], actual, rtol=rtol, atol=atol)
def _test_derived_optimizers(self, optimizer_pairs_with_flags, flag):
if not torch.cuda.is_available():
return
assert flag in ("foreach", "fused")
# why 7? iteration 7 is where we start to see differences for RAdam
# params interacting with the small eps value, because that's right
# after rho_t becomes greater than 5 in step 6.
kIterations = 7
device = "cuda"
for optimizer_constructor, params in optimizer_pairs_with_flags:
res, state = [], []
for flag_value in (False, True):
input = torch.tensor(
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=torch.float64, device=device
).reshape(3, 2)
torch.manual_seed(1)
model = torch.nn.Sequential(
torch.nn.Linear(2, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
)
model.to(dtype=torch.float64, device=device)
params_with_flags = deepcopy(params)
if optimizer_constructor.__name__ == "ASGD" and flag == "foreach" and not flag_value:
# single tensor ASGD does not support capturable
params_with_flags["capturable"] = False
params_with_flags[flag] = flag_value
# foreach/fused optimizers should be tested with a param_groups['params'] with
# zero_size tensor as its last param.
# ref: https://github.com/pytorch/pytorch/issues/100701
empty_params = [torch.empty((), device=device, dtype=torch.float64)]
optimizer = optimizer_constructor(
list(model.parameters()) + empty_params, **params_with_flags
)
for i in range(kIterations):
optimizer.zero_grad()
output = model(input)
loss = output.sum()
loss.backward()
# Test that step behaves as expected (a no-op) when grads are set to None
if i == 0:
optimizer.zero_grad(set_to_none=True)
optimizer.step()
state.append(optimizer.state)
res.append(model.parameters())
st_state = state[0]
mt_state = state[1]
for st_p, mt_p in zip(res[0], res[1]):
self.assertEqual(st_p, mt_p)
# check that optimizer states are the same
st_p_state = st_state[st_p]
mt_p_state = mt_state[mt_p]
for k in st_p_state:
self.assertEqual(st_p_state[k], mt_p_state[k])
def _test_foreach_memory(self, optimizer_pairs_with_flags):
if not torch.cuda.is_available():
return
device = "cuda"
nparams = 10
for optimizer_constructor, kwargs in optimizer_pairs_with_flags:
max_mems = []
for flag_value in (False, True):
kwargs_with_flags = deepcopy(kwargs)
if optimizer_constructor.__name__ == "ASGD" and kwargs_with_flags.get("capturable", False) and not flag_value:
# single tensor ASGD does not support capturable
kwargs_with_flags["capturable"] = False
kwargs_with_flags["foreach"] = flag_value
# The 128 is critical here! Our CUDACachingAllocator allocates in blocks of 512,
# meaning any tensor that occupies <512 bytes of memory will allocate a whole
# 512 bytes anyway. We use 128 (since datasize would be 4 bytes) so that param
# is size 512 exactly, making our later calculations for intermediate_size easy.
param = torch.rand(128, device=device)
params = [torch.rand_like(param) for _ in range(nparams)]
optimizer = optimizer_constructor(
params, **kwargs_with_flags
)
for p in params:
p.grad = torch.rand_like(p)
optimizer.step()
import gc
gc.collect()
torch.cuda.reset_peak_memory_stats()
optimizer.step()
gc.collect()
max_mems.append(torch.cuda.max_memory_allocated())
st_max_mem, mt_max_mem = max_mems
intermediate_size = nparams * param.nelement() * param.element_size()
nintermediates = 1 # we expect a budget of 1 intermediate most of the time
if (('capturable' in kwargs_with_flags and kwargs_with_flags['capturable']) or
optimizer_constructor.__name__ in ["Adadelta", "ASGD"]):
# with capturable in Adam(W), we have 2 extra intermediates for the bias_corrections
# with Adadelta, we have 2 extra for (acc_delta + eps) and (square_avg + eps)
# ASGD allocates axs, 2x mus, 2x etas, and grads at the same time
nintermediates = 3
if optimizer_constructor.__name__ == "NAdam":
# with capturable in NAdam, we have 3 extra intermediates for the
# bias_correction, mus, and mu_nexts
nintermediates = 5
elif optimizer_constructor.__name__ in ["NAdam", "Adagrad", "RMSprop"]:
# NAdam uses two intermediates at the same time (grads & exp_avg_sq_sqrt)
# Adagrad uses std and grads at the same time
# RMSprop uses avg and grads
nintermediates = 2
self.assertLessEqual(mt_max_mem, st_max_mem + intermediate_size * nintermediates)
@property
def _multi_tensor_optimizer_configs(self):
return [
(optim.Adam, dict(weight_decay=1.0, amsgrad=False)),
(optim.Adam, dict(weight_decay=0.0, amsgrad=True)),
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, maximize=True)),
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, maximize=True)),
(optim.Adam, dict(weight_decay=0.0, amsgrad=False, capturable=True, maximize=True)),
(optim.Adam, dict(weight_decay=1.0, amsgrad=True, capturable=True, maximize=True)),
(
optim.Adam,
dict(lr=torch.tensor(.001), weight_decay=1.0, amsgrad=True,
capturable=True, maximize=True)
),
(optim.AdamW, dict(weight_decay=1.0, amsgrad=False)),
(optim.AdamW, dict(weight_decay=0.0, amsgrad=True)),
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True, maximize=True)),
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False, maximize=True)),
(optim.AdamW, dict(weight_decay=1.0, amsgrad=True, capturable=True, maximize=True)),
(optim.AdamW, dict(weight_decay=0.0, amsgrad=False, capturable=True, maximize=True)),
(
optim.AdamW,
dict(lr=torch.tensor(.001), weight_decay=0.0, amsgrad=False,
capturable=True, maximize=True)
),
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)),
(optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)),
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)),
(optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3, capturable=True)),
(optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3, capturable=True)),
(optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3, decoupled_weight_decay=True)),
(
optim.NAdam,
dict(weight_decay=0.01, momentum_decay=4e-3, decoupled_weight_decay=True),
),
(
optim.NAdam,
dict(weight_decay=0.01, momentum_decay=4e-3,
decoupled_weight_decay=True, capturable=True),
),
(
optim.SGD,
dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True),
),
(
optim.SGD,
dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False),
),
(
optim.SGD,
dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True, maximize=True),
),
(
optim.SGD,
dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False, maximize=True),
),
(optim.RAdam, dict(weight_decay=0, eps=1e-6)),
(optim.RAdam, dict(weight_decay=0)),
(optim.RAdam, dict(weight_decay=1, eps=1e-6)),
(optim.RAdam, dict(weight_decay=1)),
(optim.RAdam, dict(weight_decay=0, decoupled_weight_decay=True)),
(optim.RAdam, dict(weight_decay=1, decoupled_weight_decay=True)),
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
(optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
(optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
(optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
(optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
(optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50), maximize=True)),
(optim.ASGD, dict(weight_decay=0)),
(optim.ASGD, dict(weight_decay=1)),
(optim.ASGD, dict(weight_decay=0, maximize=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True)),
(optim.ASGD, dict(weight_decay=0, capturable=True)),
(optim.ASGD, dict(weight_decay=1, capturable=True)),
(optim.ASGD, dict(weight_decay=0, maximize=True, capturable=True)),
(optim.ASGD, dict(weight_decay=1, maximize=True, capturable=True)),
(optim.Adamax, dict(weight_decay=0)),
(optim.Adamax, dict(weight_decay=1)),
(optim.Adamax, dict(weight_decay=0, maximize=True)),
(optim.Adamax, dict(weight_decay=1, maximize=True)),
(optim.Adadelta, dict(weight_decay=0)),
(optim.Adadelta, dict(weight_decay=1)),
(optim.Adadelta, dict(weight_decay=0, maximize=True)),
(optim.Adadelta, dict(weight_decay=1, maximize=True)),
(optim.Adagrad, dict(weight_decay=0)),
(optim.Adagrad, dict(weight_decay=1)),
(optim.Adagrad, dict(weight_decay=0, maximize=True)),
(optim.Adagrad, dict(weight_decay=1, maximize=True)),
]
def test_multi_tensor_optimizers(self):
self._test_derived_optimizers(self._multi_tensor_optimizer_configs, "foreach")
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_multi_tensor_optimizers_with_varying_tensors(self):
    """Run the mixed dtype/device comparison for every multi-tensor config."""
    configs = self._multi_tensor_optimizer_configs
    self._test_derived_optimizers_varying_tensors(configs, "foreach")
@unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
@largeTensorTest("72GB", "cuda")
@skipIfRocm
def test_multi_tensor_optimizers_with_large_tensors(self):
    """Take one ``foreach`` step on a 2**32-element float16 CUDA parameter.

    Every multi-tensor config except Adamax is exercised; the test only
    requires that ``step()`` completes.
    """
    for optimizer_ctor, optimizer_params in self._multi_tensor_optimizer_configs:
        # note(crcrpar): H100 wasn't sufficient for Adamax, surprisingly
        if optimizer_ctor == optim.Adamax:
            continue
        huge_param = torch.ones(2 ** 32, device="cuda", dtype=torch.float16)
        huge_param.grad = torch.zeros_like(huge_param)
        optimizer = optimizer_ctor([huge_param], foreach=True, **optimizer_params)
        optimizer.step()
def test_peak_mem_multi_tensor_optimizers(self):
configs = [
(o, d) for (o, d) in self._multi_tensor_optimizer_configs if o.__name__ in [
"Adadelta", "Adagrad", "Adamax", "Adam", "AdamW", "ASGD", "NAdam",
"RAdam", "RMSprop", "RProp", "SGD"
]
]
self._test_foreach_memory(configs)
@property
def _fused_optimizer_configs(self):
return tuple(itertools.product(
(optim.Adam, optim.AdamW),
(
dict(weight_decay=1., lr=torch.tensor(0.001), amsgrad=False, capturable=True, maximize=True),
dict(weight_decay=1., amsgrad=False, capturable=True, maximize=True),
dict(weight_decay=1., amsgrad=False, maximize=True),
dict(weight_decay=1., amsgrad=True),
dict(weight_decay=0., amsgrad=False),
dict(weight_decay=0., amsgrad=True, capturable=True, maximize=True),
dict(weight_decay=0., amsgrad=True, maximize=True),
),
))
def test_fused_optimizers(self):
self._test_derived_optimizers(self._fused_optimizer_configs, "fused")
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_fused_optimizers_with_varying_tensors(self):
    """Run the mixed dtype/device comparison for every fused config."""
    configs = self._fused_optimizer_configs
    self._test_derived_optimizers_varying_tensors(configs, "fused")
@unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
@largeTensorTest("64GB", "cuda")
@skipIfRocm
def test_fused_optimizers_with_large_tensors(self):
    """Take one ``fused`` step on a single 2**32-element float16 CUDA parameter per config."""
    for ctor, kwargs in self._fused_optimizer_configs:
        big_params = [torch.ones(2 ** 32, device="cuda", dtype=torch.float16)]
        # An all-zero gradient is enough: the test only checks that the step runs.
        big_params[0].grad = torch.zeros_like(big_params[0])
        fused_optimizer = ctor(big_params, fused=True, **kwargs)
        fused_optimizer.step()
def test_adam(self):
    """Exercise ``optim.Adam`` through the basic-case and complex-2d harnesses.

    Covers flat parameter lists and parameter groups built by
    ``_build_params_dict``, ``amsgrad``, ``weight_decay``, a tensor ``lr``,
    several LR-scheduler combinations, and the constructor's argument
    validation errors.
    """
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def adam(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach) that
        # forwards maximize/foreach plus the fixed ``options`` to optim.Adam.
        def construct(weight, bias, maximize, foreach):
            return optim.Adam(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    # Without LR schedulers.
    run_basic(adam(flat, lr=1e-3))
    run_basic(adam(grouped, lr=1e-3))
    run_basic(adam(flat, lr=1e-3, amsgrad=True))
    run_basic(adam(flat, lr=1e-3, weight_decay=0.1))
    run_basic(adam(grouped, lr=1e-3, amsgrad=True))

    # With a single LR scheduler.
    run_basic(adam(grouped, lr=1e-3), [lambda opt: ExponentialLR(opt, gamma=0.9)])
    run_basic(
        adam(grouped, lr=1e-3),
        [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)],
    )
    run_basic(
        adam(grouped, lr=1e-3),
        [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)],
    )

    # With two LR schedulers.
    run_basic(
        adam(flat, lr=1e-3, amsgrad=True),
        [
            lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
            lambda opt: ExponentialLR(opt, gamma=0.9),
        ],
    )
    run_basic(
        adam(flat, lr=1e-3, amsgrad=True),
        [
            lambda opt: ExponentialLR(opt, gamma=0.9),
            lambda opt: ReduceLROnPlateau(opt),
        ],
    )
    run_basic(
        adam(grouped, lr=1e-3, amsgrad=True),
        [
            lambda opt: StepLR(opt, gamma=0.9, step_size=10),
            lambda opt: ReduceLROnPlateau(opt),
        ],
    )
    run_basic(
        adam(grouped, lr=1e-3),
        [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
    )

    # Tensor lr: a fresh tensor is built on every construction, and the
    # harness-supplied foreach value is deliberately ignored.
    run_basic(
        lambda weight, bias, maximize, foreach: optim.Adam(
            grouped(weight, bias),
            lr=torch.tensor(1e-3),
            maximize=maximize,
            foreach=False,  # foreach for lr tensors tested in multi configs
        ),
        [lambda opt: PolynomialLR(opt, total_iters=4, power=0.9)],
    )

    self._test_complex_2d(optim.Adam)
    for extra in (
        dict(foreach=False),
        dict(foreach=False, amsgrad=True),
        dict(weight_decay=0.2),
        dict(weight_decay=0.2, amsgrad=True),
    ):
        self._test_complex_2d(functools.partial(optim.Adam, **extra))
    self._test_complex_2d(
        functools.partial(optim.Adam, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True)
    )

    # Constructor argument validation.
    with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
        optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))

    with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
        optim.Adam(None, lr=1e-2, weight_decay=-1)

    with self.assertRaisesRegex(
        ValueError, "lr as a Tensor is not supported for capturable=False and foreach=True"
    ):
        optim.Adam(None, lr=torch.tensor(0.001), foreach=True)
def test_adamw(self):
    """Exercise ``optim.AdamW`` through the basic-case and complex-2d harnesses.

    Covers flat parameter lists, parameter groups built by
    ``_build_params_dict``, ``weight_decay``, ``amsgrad``, a tensor ``lr``,
    and the constructor's argument validation errors.
    """
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def adamw(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach) that
        # forwards maximize/foreach plus the fixed ``options`` to optim.AdamW.
        def construct(weight, bias, maximize, foreach):
            return optim.AdamW(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    run_basic(adamw(flat, lr=1e-3))
    run_basic(adamw(grouped, lr=1e-3))
    run_basic(adamw(flat, lr=1e-3, weight_decay=1))
    run_basic(adamw(flat, lr=1e-3, weight_decay=1, amsgrad=True))

    # Tensor lr: a fresh tensor is built on every construction, and the
    # harness-supplied foreach value is deliberately ignored.
    run_basic(
        lambda weight, bias, maximize, foreach: optim.AdamW(
            [weight, bias],
            lr=torch.tensor(1e-3),
            weight_decay=1,
            amsgrad=True,
            maximize=maximize,
            foreach=False,  # foreach for lr tensors tested in multi configs
        )
    )

    self._test_complex_2d(optim.AdamW)
    for extra in (
        dict(foreach=False),
        dict(foreach=False, amsgrad=True),
        dict(weight_decay=0.2),
        dict(weight_decay=0.2, amsgrad=True),
    ):
        self._test_complex_2d(functools.partial(optim.AdamW, **extra))
    self._test_complex_2d(
        functools.partial(optim.AdamW, lr=torch.tensor(.001), weight_decay=0.2, amsgrad=True)
    )

    # Constructor argument validation.
    with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
        optim.AdamW(None, lr=1e-2, weight_decay=-1)

    with self.assertRaisesRegex(
        ValueError, "lr as a Tensor is not supported for capturable=False and foreach=True"
    ):
        optim.AdamW(None, lr=torch.tensor(0.001), foreach=True)
def test_sparse_adam(self):
    """Run the sparse-only rosenbrock check on SparseAdam and verify its constructor errors."""
    self._test_rosenbrock_sparse(
        lambda params: optim.SparseAdam(params, lr=4e-2),
        scheduler_constructors=[],
        sparse_only=True,
    )
    self._test_rosenbrock_sparse(
        lambda params: optim.SparseAdam(params, lr=4e-2, maximize=True),
        scheduler_constructors=[],
        sparse_only=True,
        maximize=True,
    )

    with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
        optim.SparseAdam(None, lr=1e-2, betas=(1.0, 0.0))

    # Sparse parameters are rejected both as a flat list and inside a param group.
    dense_only_error = "SparseAdam requires dense parameter tensors"
    with self.assertRaisesRegex(ValueError, dense_only_error):
        optim.SparseAdam([torch.zeros(3, layout=torch.sparse_coo)])
    with self.assertRaisesRegex(ValueError, dense_only_error):
        optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])
# ROCm precision is too low to pass this test
# NOTE(review): no ROCm skip decorator is applied to this test, so the remark
# above may be stale -- confirm.
def test_adadelta(self):
    """Exercise ``optim.Adadelta`` through the basic-case harness and check ``rho`` validation."""
    # Handles https://github.com/pytorch/pytorch/issues/69698
    self.rel_tol = 4e-3

    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, rho=0.95)

    def adadelta(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach).
        def construct(weight, bias, maximize, foreach):
            return optim.Adadelta(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    run_basic(adadelta(flat))
    run_basic(adadelta(grouped))
    run_basic(
        adadelta(grouped),
        [
            lambda opt: StepLR(opt, gamma=0.9, step_size=10),
            lambda opt: ReduceLROnPlateau(opt),
        ],
    )
    run_basic(adadelta(flat, weight_decay=1))

    with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
        optim.Adadelta(None, lr=1e-2, rho=1.1)
def test_adadelta_complex(self):
    """Run the complex-optimizer check on Adadelta variants, with and without ``foreach``."""
    # Handles https://github.com/pytorch/pytorch/issues/110606
    self.rel_tol = 2e-2

    def adadelta_for(**options):
        return lambda weight: optim.Adadelta([weight], **options)

    variants = ({}, {"rho": 0.95}, {"rho": 0.95, "weight_decay": 1})
    for foreach in (False, True):
        for extra in variants:
            self._test_complex_optimizer(adadelta_for(foreach=foreach, **extra))
def test_nadam(self):
    """Exercise ``optim.NAdam`` (including decoupled weight decay) and its argument validation."""
    run_basic = functools.partial(self._test_basic_cases, constructor_accepts_foreach=True)

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def nadam(make_params, **options):
        # Returns a constructor taking (weight, bias, foreach).
        def construct(weight, bias, foreach):
            return optim.NAdam(make_params(weight, bias), foreach=foreach, **options)

        return construct

    decayed = dict(lr=1e-3, weight_decay=0.1, momentum_decay=6e-3)

    run_basic(nadam(grouped, lr=1e-3))
    run_basic(nadam(flat, lr=1e-3))
    run_basic(nadam(flat, **decayed))
    run_basic(nadam(flat, **decayed), [lambda opt: ExponentialLR(opt, gamma=0.9)])

    # NAdamW tests
    run_basic(nadam(flat, decoupled_weight_decay=True, **decayed))
    run_basic(
        nadam(flat, decoupled_weight_decay=True, **decayed),
        [lambda opt: ExponentialLR(opt, gamma=0.9)],
    )

    with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
        optim.NAdam(None, lr=1e-2, betas=(1.0, 0.0))
    with self.assertRaisesRegex(ValueError, "Invalid momentum_decay value: -0.2"):
        optim.NAdam(None, lr=1e-2, momentum_decay=-0.2)
def test_nadam_complex(self):
    """Run the complex-optimizer check on NAdam variants, with and without ``foreach``."""

    def nadam_for(**options):
        return lambda param: optim.NAdam([param], **options)

    variants = ({}, {"weight_decay": 0.01}, {"momentum_decay": 0.01})
    for foreach in (False, True):
        for extra in variants:
            self._test_complex_optimizer(nadam_for(lr=1e-1, foreach=foreach, **extra))
def test_adagrad(self):
    """Exercise ``optim.Adagrad`` through the basic-case harness and check ``lr_decay`` validation."""
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def adagrad(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach).
        def construct(weight, bias, maximize, foreach):
            return optim.Adagrad(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    run_basic(adagrad(flat, lr=1e-1))
    run_basic(adagrad(flat, lr=1e-1, initial_accumulator_value=0.1))
    run_basic(adagrad(grouped, lr=1e-1))
    run_basic(adagrad(grouped, lr=1e-1), [lambda opt: ReduceLROnPlateau(opt)])
    run_basic(
        adagrad(grouped, lr=1e-1),
        [
            lambda opt: ReduceLROnPlateau(opt),
            lambda opt: ExponentialLR(opt, gamma=0.99),
        ],
    )

    with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
        optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)
def test_adagrad_sparse(self):
    """Run the rosenbrock sparse check on Adagrad, with and without ``foreach`` and schedulers."""
    for foreach in (False, True):
        # ``multi_tensor`` follows ``foreach`` in both calls.
        self._test_rosenbrock_sparse(
            lambda params: optim.Adagrad(params, lr=1e-1, foreach=foreach),
            multi_tensor=foreach,
        )

        schedulers = [
            lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
            lambda opt: ReduceLROnPlateau(opt, threshold=1e-4),
        ]
        self._test_rosenbrock_sparse(
            lambda params: optim.Adagrad(params, lr=0.1, foreach=foreach),
            scheduler_constructors=schedulers,
            multi_tensor=foreach,
        )
def test_adagrad_complex(self):
    """Run the complex-optimizer check on Adagrad variants, with and without ``foreach``."""

    def adagrad_for(**options):
        return lambda param: optim.Adagrad([param], **options)

    variants = ({}, {"initial_accumulator_value": 0.1})
    for foreach in (False, True):
        for extra in variants:
            self._test_complex_optimizer(adagrad_for(lr=1e-1, foreach=foreach, **extra))
def test_adamax(self):
    """Exercise ``optim.Adamax`` through the basic-case and complex-2d harnesses and check beta validation."""
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def adamax(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach).
        def construct(weight, bias, maximize, foreach):
            return optim.Adamax(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    run_basic(adamax(flat, lr=1e-1))
    run_basic(adamax(grouped, lr=1e-1))
    run_basic(adamax(flat, lr=1e-1, weight_decay=1))

    self._test_complex_2d(optim.Adamax)
    self._test_complex_2d(functools.partial(optim.Adamax, foreach=True))

    with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 1: 1.0"):
        optim.Adamax(None, lr=1e-2, betas=(0.0, 1.0))
def test_radam(self):
    """Exercise ``optim.RAdam`` (including decoupled weight decay) and its argument validation."""
    run_basic = functools.partial(self._test_basic_cases, constructor_accepts_foreach=True)

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def radam(make_params, **options):
        # Returns a constructor taking (weight, bias, foreach).
        def construct(weight, bias, foreach):
            return optim.RAdam(make_params(weight, bias), foreach=foreach, **options)

        return construct

    run_basic(radam(flat, lr=1e-3))
    run_basic(radam(grouped, lr=1e-3))
    run_basic(radam(flat, lr=1e-3, weight_decay=0.1))
    run_basic(
        radam(flat, lr=1e-3),
        [
            lambda opt: ExponentialLR(opt, gamma=0.9),
            lambda opt: ReduceLROnPlateau(opt),
        ],
    )

    # RAdamW tests
    run_basic(radam(flat, lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True))
    run_basic(
        radam(flat, lr=1e-3, weight_decay=0.1, decoupled_weight_decay=True),
        [
            lambda opt: ExponentialLR(opt, gamma=0.9),
            lambda opt: ReduceLROnPlateau(opt),
        ],
    )

    with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
        optim.RAdam(None, lr=1e-2, betas=(1.0, 0.0))
    with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
        optim.RAdam(None, lr=1e-2, weight_decay=-1)
def test_radam_complex(self):
    """Run the complex-optimizer check on RAdam variants, with and without ``foreach``."""

    def radam_for(**options):
        return lambda param: optim.RAdam([param], **options)

    variants = (
        {},
        {"weight_decay": 0.01},
        {"weight_decay": 0.01, "decoupled_weight_decay": True},
    )
    for foreach in (False, True):
        for extra in variants:
            self._test_complex_optimizer(radam_for(lr=1e-1, foreach=foreach, **extra))
def test_rmsprop(self):
    """Exercise ``optim.RMSprop`` through the basic, complex-2d and complex-optimizer harnesses."""
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-3)

    def rmsprop(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach). Its own
        # ``foreach`` parameter is the one forwarded, not the loop variable below.
        def construct(weight, bias, maximize, foreach):
            return optim.RMSprop(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    def on_params(**options):
        return lambda param: optim.RMSprop(param, **options)

    def on_single(**options):
        return lambda param: optim.RMSprop([param], **options)

    complex_variants = ({}, {"centered": True}, {"momentum": 0.1}, {"maximize": True})

    for foreach in (False, True):
        run_basic(rmsprop(flat, lr=1e-2))
        run_basic(rmsprop(grouped, lr=1e-2))
        run_basic(rmsprop(grouped, lr=1e-2, centered=True))
        run_basic(rmsprop(grouped, lr=1e-2, centered=True, momentum=0.1))
        run_basic(rmsprop(grouped, lr=1e-2, momentum=0.1))
        run_basic(rmsprop(grouped, lr=1e-2, momentum=0.1, weight_decay=1))

        for extra in complex_variants:
            self._test_complex_2d(on_params(foreach=foreach, **extra))
        for extra in complex_variants:
            self._test_complex_optimizer(on_single(foreach=foreach, **extra))

        with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"):
            optim.RMSprop(None, lr=1e-2, momentum=-1.0, foreach=foreach)
def test_asgd(self):
    """Exercise ``optim.ASGD`` through the basic and complex-optimizer harnesses and check validation."""
    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def asgd(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach). Its own
        # ``foreach`` parameter is the one forwarded, not the loop variable below.
        def construct(weight, bias, maximize, foreach):
            return optim.ASGD(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    def on_single(**options):
        return lambda params: optim.ASGD([params], **options)

    complex_variants = (
        {},
        {"maximize": True},
        {"maximize": True, "weight_decay": 0.9},
        {"maximize": False, "weight_decay": 0.9},
    )

    for foreach in (False, True):
        run_basic(asgd(flat, lr=1e-3, t0=100))
        run_basic(asgd(grouped, lr=1e-3, t0=100))
        run_basic(asgd(grouped, lr=1e-3, weight_decay=1))

        # Ref: https://github.com/pytorch/pytorch/issues/84560
        # self._test_complex_2d(optimizer)
        for extra in complex_variants:
            self._test_complex_optimizer(on_single(foreach=foreach, **extra))

        with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -0.5"):
            optim.ASGD(None, lr=1e-2, weight_decay=-0.5, foreach=foreach)
@skipIfRocm
@skipIfTorchDynamo()
def test_rprop(self):
    """Exercise ``optim.Rprop`` through the basic, complex-2d and complex-optimizer harnesses."""
    # Explicit tolerances are passed to one basic case only when device 0
    # reports compute capability (8, 6); otherwise None is passed.
    is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6)
    atol = 4e-5 if is_cuda_sm86 else None
    rtol = 3e-5 if is_cuda_sm86 else None

    run_basic = functools.partial(
        self._test_basic_cases,
        constructor_accepts_maximize=True,
        constructor_accepts_foreach=True,
    )

    def flat(weight, bias):
        return [weight, bias]

    def grouped(weight, bias):
        return self._build_params_dict(weight, bias, lr=1e-2)

    def rprop(make_params, **options):
        # Returns a constructor taking (weight, bias, maximize, foreach). Its own
        # ``foreach`` parameter is the one forwarded, not the loop variable below.
        def construct(weight, bias, maximize, foreach):
            return optim.Rprop(
                make_params(weight, bias), maximize=maximize, foreach=foreach, **options
            )

        return construct

    def on_single(**options):
        return lambda param: optim.Rprop([param], **options)

    for foreach in (False, True):
        run_basic(rprop(flat, lr=2e-4))
        run_basic(rprop(grouped, lr=2e-4), atol=atol, rtol=rtol)

        self._test_complex_2d(lambda param: optim.Rprop(param, foreach=foreach))
        self._test_complex_optimizer(on_single(lr=0.001, foreach=foreach))
        self._test_complex_optimizer(on_single(lr=0.001, maximize=True, foreach=foreach))

        with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"):
            optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5), foreach=foreach)
def test_lbfgs(self):
    """Run the basic-case harness on LBFGS, with and without the strong-Wolfe line search."""

    def plain(weight, bias):
        return optim.LBFGS([weight, bias])

    def strong_wolfe(weight, bias):
        return optim.LBFGS([weight, bias], line_search_fn="strong_wolfe")

    self._test_basic_cases(plain, ignore_multidevice=True)
    self._test_basic_cases(strong_wolfe, ignore_multidevice=True)
def test_lbfgs_returns_consistent_type(self):
    """Check that ``LBFGS.step`` returns the same type for +inf and -inf ``tolerance_grad``."""
    params = [torch.randn(10, 5), torch.randn(10)]
    loose = optim.LBFGS(params, 0.01, tolerance_grad=math.inf)
    strict = optim.LBFGS(params, 0.01, tolerance_grad=-math.inf)

    def closure():
        return torch.tensor([10])

    loose_result = loose.step(closure)
    strict_result = strict.step(closure)
    self.assertEqual(type(loose_result), type(strict_result))
def test_invalid_param_type(self):
    """Check that passing a bare Parameter instead of an iterable raises ``TypeError``."""
    expected = 'params argument given to the optimizer should be an iterable of Tensors or dicts'
    with self.assertRaisesRegex(TypeError, expected):
        optim.LBFGS(Parameter(torch.randn(5, 5)))
def test_duplicate_params_in_one_param_group(self):
    """Check that listing the same parameter twice in one group emits a ``UserWarning``."""
    shared = Parameter(torch.randn(1))
    duplicated = [shared, shared]
    with self.assertWarnsOnceRegex(UserWarning, '.*a parameter group with duplicate parameters.*'):
        optim.Adamax(duplicated, lr=0.01)
def test_duplicate_params_across_param_groups(self):
    """Check that the same parameter in two param groups raises ``ValueError``."""
    shared = Parameter(torch.randn(1))
    expected = 'some parameters appear in more than one parameter group'
    with self.assertRaisesRegex(ValueError, expected):
        optim.Adadelta([{'params': shared}, {'params': shared}])
def test_step_is_noop_when_params_have_no_grad(self):
    """Check that ``step`` leaves parameters untouched when none of them has a gradient.

    The parameters are created with ``requires_grad=False`` and never get a
    ``.grad``. They are compared against their initial values after every
    optimizer's step, so a failure names the optimizer that modified them.
    """
    params = [torch.randn(2, 3, requires_grad=False) for _ in range(2)]
    old_params = [p.clone().detach() for p in params]

    def closure():
        return torch.tensor([1])

    optimizer_list = [
        optim.Adadelta,
        optim.AdamW,
        optim.Adam,
        optim.RAdam,
        optim.NAdam,
        optim.Adagrad,
        optim.Adamax,
        optim.RMSprop,
        optim.SGD,
        optim.SparseAdam,
        optim.ASGD,
        optim.LBFGS
    ]

    for optim_ctr in optimizer_list:
        opt = optim_ctr(params, lr=0.1)
        opt.step(closure)
        # Checked per optimizer so a regression is attributed to its source.
        self.assertEqual(
            old_params,
            params,
            msg=f"{optim_ctr.__name__}.step() modified params that have no grad",
        )
def test_step_is_noop_for_empty_grads(self):
    """Check that ``step`` leaves a parameter untouched when its gradient is all zeros / empty sparse."""
    optimizer_classes = [
        optim.Adadelta,
        optim.AdamW,
        optim.Adam,
        optim.RAdam,
        optim.NAdam,
        optim.Adagrad,
        optim.Adamax,
        optim.RMSprop,
        optim.SGD,
        optim.SparseAdam,
        optim.ASGD,
        optim.LBFGS
    ]
    param = torch.randn(5, 1, requires_grad=True)
    original = param.clone().detach()

    def closure():
        return torch.tensor([1])

    for optimizer_cls in optimizer_classes:
        opt = optimizer_cls([param], lr=1e-5)
        param.grad = torch.zeros_like(param)
        if optimizer_cls is optim.SparseAdam:
            # Intentionally construct a multidimensional empty v for the sparse grad
            # Single dim v passes the test while multidim correctly repros the issue
            # https://github.com/pytorch/pytorch/issues/82486
            indices = torch.empty(1, 0)
            values = torch.empty(0, 1)
            param.grad = torch.sparse_coo_tensor(indices, values, (5, 1))
        opt.step(closure)
        self.assertEqual(original, param)
def test_fused_optimizer_does_not_step_if_foundinf(self):
    """Check that the fused functional Adam/AdamW leave params and step counts unchanged when ``found_inf`` is 1."""
    if not torch.cuda.is_available():
        self.skipTest("CUDA is required.")

    from torch.optim import adam, adamw

    num_tensors = 5

    def ones_list(shape=(1,), **kwargs):
        # ``num_tensors`` independent CUDA tensors of ones.
        return [torch.ones(shape, device="cuda", **kwargs) for _ in range(num_tensors)]

    cases = itertools.product((adam.adam, adamw.adamw), (False, True), (False, True))
    for functional_optim, amsgrad, no_grad_scale in cases:
        params = ones_list()
        grads = ones_list()
        exp_avgs = ones_list()
        exp_avg_sqs = ones_list()
        prev_params = [t.clone().detach() for t in params]
        max_exp_avg_sqs = ones_list() if amsgrad else []
        state_steps = ones_list((), dtype=torch.float32)
        if no_grad_scale:
            grad_scale = None
        else:
            grad_scale = torch.ones((1,), dtype=torch.float32, device="cuda")
        found_inf = torch.ones((), dtype=torch.float32, device="cuda")

        functional_optim(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            foreach=False,
            capturable=False,
            fused=True,
            amsgrad=amsgrad,
            beta1=0.9,
            beta2=0.99,
            lr=1e-2,
            weight_decay=0.0,
            eps=1e-8,
            maximize=False,
            grad_scale=grad_scale,
            found_inf=found_inf,
        )

        # Both the step counters and the parameters must still hold their initial values.
        self.assertEqual(state_steps, ones_list((), dtype=torch.float32))
        self.assertEqual(params, prev_params)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required.")
def test_fused_optimizer_load_state_dict(self):
    """Load a CPU state dict flagged fused/capturable into a CUDA Adam/AdamW and take a step."""
    # NOTE: This SIMULATES a fused/capturable optimizer with state moved to CPU, issue 103256
    # How do we get there? Users typically create CUDA models on fused optimizers and then
    # store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
    # Since this is a unit test, it is more expedient to simulate what the state_dict
    # would look like, which is basically CPU tensors with fused/capturable flag = True.
    for optimC, kwarg in itertools.product((Adam, optim.AdamW), ("fused", "capturable")):
        cpu_param = torch.tensor([0.1, 0.2], dtype=torch.float32, device="cpu")
        cpu_optimizer = optimC([cpu_param])
        cpu_optimizer.zero_grad()
        cpu_param.grad = torch.rand_like(cpu_param)
        cpu_optimizer.step()
        cpu_state = deepcopy(cpu_optimizer.state_dict())
        cpu_state["param_groups"][0][kwarg] = True

        # load
        cuda_param = cpu_param.clone().detach().to(device="cuda")
        cuda_optimizer = optimC([cuda_param], **{kwarg: True})
        cuda_optimizer.load_state_dict(cpu_state)
        cuda_optimizer.zero_grad()
        cuda_param.grad = torch.rand_like(cuda_param)
        cuda_optimizer.step()
@skipIfTorchDynamo()
def test_post_hook(self):
    """Check that a step post-hook fires once per ``step`` and stops firing after removal."""
    data = 2

    def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        nonlocal data
        data += 2

    opt = SGD([torch.Tensor([1, 1])], lr=0.001)
    handle = opt.register_step_post_hook(post_hook)

    for _ in range(2):
        opt.step()
    # check if post hooks were registered: two steps add 2 each
    self.assertEqual(data, 6)

    # remove handles, take step and verify that hook is no longer registered
    handle.remove()
    opt.step()
    self.assertEqual(data, 6)
@skipIfTorchDynamo()
def test_pre_hook(self):
    """Check that a step pre-hook fires once per ``step`` and stops firing after removal."""
    data = 5

    def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        nonlocal data
        data += 2

    opt = SGD([torch.Tensor([1, 1])], lr=0.001)
    handle = opt.register_step_pre_hook(pre_hook)

    for _ in range(2):
        opt.step()
    # check if pre hooks were registered: two steps add 2 each
    self.assertEqual(data, 9)

    # remove handles, take step and verify that hook is no longer registered
    handle.remove()
    opt.step()
    self.assertEqual(data, 9)
@skipIfTorchDynamo()
def test_pre_and_post_hook(self):
    """Check the firing order of global and per-optimizer step hooks, and their removal.

    Each hook appends a marker to ``data``; a single ``step`` is expected to
    append ``[0, 1, 2, 5]`` (global pre, local pre, local post, global post).
    All hooks are removed in a ``finally`` block so that a failing assertion
    cannot leave the global hooks registered for later tests.
    """
    data = []

    def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(0)

    def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(5)

    def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(1)

    def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
        data.append(2)

    params = [torch.Tensor([1, 1])]
    opt1 = SGD(params, lr=0.001)
    opt2 = Adam(params, lr=0.01)

    handles = []
    try:
        # register global hooks to both optimizers
        handles.append(register_optimizer_step_pre_hook(global_pre_hook))
        handles.append(register_optimizer_step_post_hook(global_post_hook))

        # register local hooks
        handles.append(opt1.register_step_pre_hook(local_pre_hook))
        handles.append(opt1.register_step_post_hook(local_post_hook))
        handles.append(opt2.register_step_pre_hook(local_pre_hook))
        handles.append(opt2.register_step_post_hook(local_post_hook))

        opt1.step()
        self.assertListEqual(data, [0, 1, 2, 5])
        opt2.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5])
        opt1.step()
        self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
    finally:
        # remove all hooks, even if an assertion above failed
        for handle in handles:
            handle.remove()

    # With every hook removed, further steps must not record anything.
    opt1.step()
    opt2.step()
    self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
def test_fused_optimizer_raises(self):
    """Check that Adam/AdamW reject ``fused=True`` combined with ``foreach`` or ``differentiable``."""
    if not torch.cuda.is_available():
        self.skipTest("Requires CUDA devices")

    rejected = (
        ("foreach", "`fused` and `foreach` cannot be `True` together."),
        ("differentiable", "`fused` does not support `differentiable`"),
    )
    for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
        for flag, message in rejected:
            with self.assertRaisesRegex(RuntimeError, message):
                optimizer_ctor([torch.empty((), device="cuda")], fused=True, **{flag: True})
@staticmethod
def _state_dict_pre_hook(optimizer: Optimizer) -> None:
    """State-dict pre-hook used by the tests: plant a ``"test"`` marker in the optimizer state."""
    optimizer.state.update({"test": 1})
@staticmethod
def _state_dict_post_hook(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """State-dict post-hook used by the tests.

    Removes the ``"test"`` marker from ``state_dict["state"]`` if present and
    records under ``"ran_state_dict_pre_hook"`` whether it was there. The same
    (mutated) ``state_dict`` is returned.
    """
    marker_present = "test" in state_dict["state"]
    if marker_present:
        state_dict["state"].pop("test")
    state_dict["ran_state_dict_pre_hook"] = marker_present
    return state_dict
@staticmethod
def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
    """Load-state-dict pre-hook used by the tests: set the first group's lr to 0.002 in place."""
    first_group = state_dict["param_groups"][0]
    first_group["lr"] = 0.002
@staticmethod
def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Load-state-dict pre-hook used by the tests: return a deep copy whose first group's lr is 0.003."""
    # The typical use case for returning a state dict is to drastically modify the state dict.
    # This is simulated by making a deep copy and checking that the copy is what gets used;
    # the incoming state_dict is left unchanged.
    replacement = deepcopy(state_dict)
    replacement["param_groups"][0]["lr"] = 0.003
    return replacement
@staticmethod
def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
    """Load-state-dict post-hook used by the tests.

    Records in ``optimizer.state`` whether the first group's lr equals 0.003
    (the value ``_load_state_dict_pre_hook2`` writes) and that this hook ran.
    """
    lr_is_hook2_value = optimizer.param_groups[0]["lr"] == 0.003
    optimizer.state.update(
        {
            "ran_load_state_dict_pre_hook2": lr_is_hook2_value,
            "ran_load_state_dict_post_hook": True,
        }
    )
def test_state_dict_pre_hook(self):
    """Check that a registered state-dict pre-hook runs before ``state_dict`` is built."""
    param = torch.rand(2, 3, requires_grad=True)
    param.grad = torch.rand(2, 3, requires_grad=True)
    opt = SGD([param], lr=0.001)
    opt.register_state_dict_pre_hook(self._state_dict_pre_hook)

    # The marker planted by the pre-hook must show up in the produced state dict.
    produced = opt.state_dict()
    self.assertEqual(produced["state"]["test"], 1)
def test_state_dict_post_hook(self):
    """Check that a state-dict post-hook runs and sees no pre-hook marker when none is registered."""
    param = torch.rand(2, 3, requires_grad=True)
    param.grad = torch.rand(2, 3, requires_grad=True)
    opt = SGD([param], lr=0.001)
    opt.register_state_dict_post_hook(self._state_dict_post_hook)

    produced = opt.state_dict()
    self.assertEqual(produced["ran_state_dict_pre_hook"], False)
def test_state_dict_pre_post_hook(self):
    """Check that the state-dict pre- and post-hooks cooperate: marker planted, then stripped."""
    param = torch.rand(2, 3, requires_grad=True)
    param.grad = torch.rand(2, 3, requires_grad=True)
    opt = SGD([param], lr=0.001)
    opt.register_state_dict_pre_hook(self._state_dict_pre_hook)
    opt.register_state_dict_post_hook(self._state_dict_post_hook)

    produced = opt.state_dict()
    # The post-hook removed the marker and recorded that it had been present.
    self.assertNotIn("test", produced["state"])
    self.assertEqual(produced["ran_state_dict_pre_hook"], True)
def test_load_state_dict_pre_hook_and_prepend(self):
    """Check load-state-dict pre-hooks, including the ordering effect of ``prepend=True``."""
    param = torch.rand(2, 3, requires_grad=True)
    param.grad = torch.rand(2, 3, requires_grad=True)
    opt = SGD([param], lr=0.001)
    saved = opt.state_dict()

    # usually one would have a new opt instance here, but it's all the same here
    opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
    opt.load_state_dict(saved)
    self.assertEqual(opt.param_groups[0]["lr"], 0.002)

    opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
    opt.load_state_dict(saved)
    # If prepend were False would be 0.003 but since prepend is True, the other hook overrides
    self.assertEqual(opt.param_groups[0]["lr"], 0.002)
def test_load_state_dict_post_hook(self):
    # Registers only the load_state_dict post-hook. The lr stays at 0.001, so
    # the post-hook's `lr == 0.003` marker must be False while its own "ran"
    # marker must be True.
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    optimizer.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
    snapshot = optimizer.state_dict()
    optimizer.load_state_dict(snapshot)
    markers = optimizer.state
    self.assertFalse(markers["ran_load_state_dict_pre_hook2"])
    self.assertTrue(markers["ran_load_state_dict_post_hook"])
def test_load_state_dict_pre_post_hook(self):
    # Registers a load_state_dict pre-hook and the post-hook together. The
    # post-hook records whether lr == 0.003, so both markers must be True.
    # NOTE(review): `_load_state_dict_pre_hook2` is presumably the hook that
    # rewrites the lr to 0.003; its definition starts outside this section.
    weight = torch.rand(2, 3, requires_grad=True)
    weight.grad = torch.rand(2, 3, requires_grad=True)
    optimizer = SGD([weight], lr=0.001)
    optimizer.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
    optimizer.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
    snapshot = optimizer.state_dict()
    optimizer.load_state_dict(snapshot)
    markers = optimizer.state
    self.assertTrue(markers["ran_load_state_dict_pre_hook2"])
    self.assertTrue(markers["ran_load_state_dict_post_hook"])
def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
# Ignored is the list of values in `opt_differentiable_state`, we do this
# for `gradcheck` to correctly track the state tensors as function inputs
# because otherwise it can't unpack the values in the `opt_differentiable_state`
# dict
p = p.clone()
p.grad = grad
opt_differentiable_state = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in opt_differentiable_state.items()
}
opt = opt_class([p], **kwargs)
opt.state[p].update(opt_differentiable_state)
opt.step()
return (p,) + tuple(
v
for v in opt.state[p].values()
if isinstance(v, torch.Tensor) and v.requires_grad
)
@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    # Each `test_<optimizer>` case below builds a float64 parameter, a float64
    # gradient and a hand-made optimizer state, then runs `gradcheck` on
    # `_diff_fn` with the optimizer constructed with `differentiable=True`.
    # The state values are additionally passed as trailing positional inputs
    # so that `gradcheck` treats the state tensors as function inputs.

    def test_sgd(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        momentum_buffer = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {"momentum_buffer": momentum_buffer}
        opt_kwargs = {"lr": 0.9, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.SGD, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_adam(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "max_exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "differentiable": True, "amsgrad": True}
        inputs = (param, grad, opt_state, torch.optim.Adam, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_rmsprop(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            "step": 0,
            "square_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "momentum_buffer": torch.rand(10, requires_grad=True, dtype=torch.float64),
            # Scaled down: large values here can cause issues and nan due to
            # the sqrt ops.
            "grad_avg": 1e-2 * torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {
            "lr": 0.9,
            "maximize": True,
            "momentum": 0.9,
            "differentiable": True,
            "centered": True,
            "weight_decay": 0.1,
        }
        inputs = (param, grad, opt_state, torch.optim.RMSprop, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_adadelta(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "square_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "acc_delta": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.Adadelta, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_adagrad(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "sum": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.Adagrad, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_adamax(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "exp_inf": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "weight_decay": 0.1, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.Adamax, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    @skipIfTorchDynamo("The inplace mu update fails with dynamo, "
                       "since this is only happening when differentiable is enabled, skipping for now")
    def test_asgd(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step`, `eta` and `mu` are not continuous variables (even though
            # they are stored as floats), so they must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "eta": torch.tensor(0.9, requires_grad=False, dtype=torch.float64),
            "mu": torch.tensor(1.0, requires_grad=False, dtype=torch.float64),
            "ax": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.ASGD, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_rprop(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "prev": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "step_size": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "differentiable": True}
        inputs = (param, grad, opt_state, torch.optim.Rprop, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_adamw(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "max_exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        opt_kwargs = {"lr": 0.9, "differentiable": True, "amsgrad": True}
        inputs = (param, grad, opt_state, torch.optim.AdamW, opt_kwargs, *opt_state.values())
        gradcheck(_diff_fn, inputs)

    def test_nadam(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "mu_product": torch.tensor(1.0, requires_grad=True, dtype=torch.float64),
        }
        # Same inputs are checked twice: default settings, then with
        # decoupled weight decay switched on.
        kwargs_variants = (
            {"lr": 0.9, "differentiable": True},
            {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True},
        )
        for opt_kwargs in kwargs_variants:
            inputs = (param, grad, opt_state, torch.optim.NAdam, opt_kwargs, *opt_state.values())
            gradcheck(_diff_fn, inputs)

    def test_radam(self):
        param = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        opt_state = {
            # `step` is not a continuous variable (even though it is stored as
            # a float), so it must not require gradients.
            "step": torch.tensor(10.0, requires_grad=False, dtype=torch.float64),
            "exp_avg": torch.rand(10, requires_grad=True, dtype=torch.float64),
            "exp_avg_sq": torch.rand(10, requires_grad=True, dtype=torch.float64),
        }
        # Same inputs are checked twice: default settings, then with
        # decoupled weight decay and a non-zero weight_decay.
        kwargs_variants = (
            {"lr": 0.9, "differentiable": True},
            {"lr": 0.9, "weight_decay": 0.1, "decoupled_weight_decay": True, "differentiable": True},
        )
        for opt_kwargs in kwargs_variants:
            inputs = (param, grad, opt_state, torch.optim.RAdam, opt_kwargs, *opt_state.values())
            gradcheck(_diff_fn, inputs)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_defaults_changed_to_foreach(self):
        # For every optimizer listed, construct it without passing `foreach`
        # on CUDA parameters, patch the module's `_multi_tensor_*` function
        # with a mock, and assert that `step()` called the mock.
        from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
                                 asgd, adamax, adadelta, adagrad)
        # (optimizer class, implementing module, name of its multi-tensor function)
        cases = (
            (optim.Adam, adam, "_multi_tensor_adam"),
            (optim.AdamW, adamw, "_multi_tensor_adamw"),
            (optim.NAdam, nadam, "_multi_tensor_nadam"),
            (optim.SGD, sgd, "_multi_tensor_sgd"),
            (optim.RAdam, radam, "_multi_tensor_radam"),
            (optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
            (optim.Rprop, rprop, "_multi_tensor_rprop"),
            (optim.ASGD, asgd, "_multi_tensor_asgd"),
            (optim.Adamax, adamax, "_multi_tensor_adamax"),
            (optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
            (optim.Adagrad, adagrad, "_multi_tensor_adagrad"),
        )

        model = torch.nn.Linear(5, 5)
        model.to(dtype=torch.float64, device="cuda")
        batch = torch.rand(2, 5, dtype=torch.float64, device="cuda")

        for opt_cls, opt_module, foreach_name in cases:
            # SGD is the only optimizer here that is given an explicit lr.
            ctor_kwargs = {"lr": 1e-2} if opt_cls is optim.SGD else {}
            optimizer = opt_cls(model.parameters(), **ctor_kwargs)
            optimizer.zero_grad()
            model(batch).sum().backward()
            with patch.object(opt_module, foreach_name) as mocked_foreach_impl:
                optimizer.step()
                self.assertTrue(mocked_foreach_impl.called)
# Running this module directly executes no tests; it only prints a pointer to
# the intended entry point.
if __name__ == "__main__":
    hint = "These tests should be run through test/test_optim.py instead"
    print(hint)