# Owner(s): ["oncall: distributed"]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import torch
import torch.nn as nn
from torch.distributed.optim import _NamedOptimizer
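

# _NamedOptimizer wraps a regular torch.optim optimizer but keys its
# state_dict by parameter FQN (e.g. "net1.0.weight") rather than by integer
# index, which keeps checkpoints usable when the set of trained parameters
# changes (see test_load_state_dict_conditional_training below).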


def _run_model_training(model_optim_lists):
    """Run two forward/backward/step iterations for each (model, optimizers) pair.

    Gradients are not zeroed between iterations, but the plain and named
    optimizers under comparison see identical accumulated gradients, so the
    equality checks in the tests remain valid.
    """
    for _ in range(2):
        x = torch.rand(5, 8)
        for model, optim_list in model_optim_lists:
            y = model(x)
            y.sum().backward()
            for optim in optim_list:
                optim.step()


class TestDummyModel(torch.nn.Module):
    """Four-block MLP. Seeding in __init__ makes every instance start from
    identical weights, so a model and its duplicate stay comparable."""

    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))
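

# named_parameters() ordering for TestDummyModel; the index-based lookups
# into the plain optimizers' state dicts below rely on it:
#   0: net1.0.weight  1: net1.0.bias
#   2: net2.0.weight  3: net2.0.bias
#   4: net3.weight    5: net3.bias
#   6: net4.1.weight  7: net4.1.bias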


class NamedOptimizerTest(unittest.TestCase):
    def _compare_state_dict_group(self, group, named_group, assert_equal=True):
        # Compare every entry except "params", which the two optimizers key
        # differently (integer indices vs. parameter FQNs).
        for key, val in group.items():
            if key != "params":
                self.assertTrue(
                    key in named_group, f"{key} not in named optimizer state dict"
                )
                err_msg = (
                    f"{key} state not equal" if assert_equal else f"{key} state equal"
                )
                if isinstance(val, torch.Tensor):
                    fn = self.assertTrue if assert_equal else self.assertFalse
                    fn(torch.allclose(val, named_group[key]), err_msg)
                else:
                    fn = self.assertEqual if assert_equal else self.assertNotEqual
                    fn(val, named_group[key], err_msg)

    def _compare_param_groups(self, param_groups_1, param_groups_2):
        self.assertTrue(isinstance(param_groups_1, list))
        self.assertTrue(isinstance(param_groups_2, list))
        for group_1, group_2 in zip(param_groups_1, param_groups_2):
            self._compare_param_group(group_1, group_2)

    def _compare_param_group(self, group_1, group_2):
        self.assertTrue(isinstance(group_1, dict))
        self.assertTrue(isinstance(group_2, dict))
        for key, val in group_1.items():
            self.assertTrue(key in group_2)
            if key != "params":
                self.assertEqual(val, group_2[key])
            else:
                # param_groups hold the live tensors, so compare value-wise.
                for tensor_1, tensor_2 in zip(val, group_2[key]):
                    self.assertTrue(torch.allclose(tensor_1, tensor_2))

    def test_state_dict(self):
        """Check that _NamedOptimizer exposes the expected state dict interface."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            m.parameters(),
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        sd = optim.state_dict()
        named_sd = named_optim.state_dict()
        # Compare "state" in optim state dict: index keys on the plain
        # optimizer must match the corresponding FQN keys on the named one.
        self._compare_state_dict_group(
            sd["state"][0],
            named_sd["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][3],
            named_sd["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][4],
            named_sd["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd["state"][7],
            named_sd["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_state_dict_multi_param_group(self):
        """Check that _NamedOptimizer exposes the expected state dict
        interface when multiple param groups are specified."""
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim_1 = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        optim_2 = torch.optim.Adam(
            [
                {"params": m.net2.parameters()},
                {"params": m.net4.parameters(), "lr": 1e-5},
            ]
        )
        named_optim_1 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim_2 = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.Adam,
            [
                {"params": m_dup.net2.parameters()},
                {"params": m_dup.net4.parameters(), "lr": 1e-5},
            ],
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        _run_model_training(
            [(m, [optim_1, optim_2]), (m_dup, [named_optim_1, named_optim_2])]
        )
        self._compare_param_groups(optim_1.param_groups, named_optim_1.param_groups)
        self._compare_param_groups(optim_2.param_groups, named_optim_2.param_groups)
        sd_1 = optim_1.state_dict()
        sd_2 = optim_2.state_dict()
        named_sd_1 = named_optim_1.state_dict()
        named_sd_2 = named_optim_2.state_dict()
        # Compare "state" in optim state dict. Indices are local to each plain
        # optimizer: optim_1 holds net1 + net3, optim_2 holds net2 + net4.
        self._compare_state_dict_group(
            sd_1["state"][0],
            named_sd_1["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][1],
            named_sd_2["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_1["state"][2],
            named_sd_1["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["state"][3],
            named_sd_2["state"]["net4.1.bias"],
            assert_equal=True,
        )
        # Compare "param_groups" in optim state dict
        self._compare_state_dict_group(
            sd_1["param_groups"][0],
            named_sd_1["param_groups"][0],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            sd_2["param_groups"][1],
            named_sd_2["param_groups"][1],
            assert_equal=True,
        )

    def test_load_state_dict(self):
        """Check that _NamedOptimizer's load_state_dict works as expected."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()
        # A second optimizer with different momentum diverges from the first,
        # so its state must differ before the load and match after it.
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )
        _run_model_training([(m, [named_optim_2])])
        state_dict_before_load = named_optim_2.state_dict()
        # Compare "state" in optim state dict before the load: expect inequality.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_before_load["state"]["net1.0.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_before_load["state"]["net2.0.bias"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_before_load["state"]["net3.weight"],
            assert_equal=False,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_before_load["state"]["net4.1.bias"],
            assert_equal=False,
        )
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()
        # Compare "state" in optim state dict after the load: expect equality.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net2.0.bias"],
            state_dict_after_load["state"]["net2.0.bias"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net4.1.bias"],
            state_dict_after_load["state"]["net4.1.bias"],
            assert_equal=True,
        )

    def test_load_state_dict_conditional_training(self):
        """Check that _NamedOptimizer.load_state_dict works when the loaded
        state was produced while only a subset of the parameters (net1 and
        net3) was being trained."""
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )
        _run_model_training([(m, [named_optim_2])])
        named_optim_2.load_state_dict(state_dict_to_load)
        state_dict_after_load = named_optim_2.state_dict()
        # Only the parameters present in the loaded state dict are checked.
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net1.0.weight"],
            state_dict_after_load["state"]["net1.0.weight"],
            assert_equal=True,
        )
        self._compare_state_dict_group(
            state_dict_to_load["state"]["net3.weight"],
            state_dict_after_load["state"]["net3.weight"],
            assert_equal=True,
        )

    def test_load_state_dict_error(self):
        m = TestDummyModel()
        named_optim_1 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.9,
        )
        _run_model_training([(m, [named_optim_1])])
        state_dict_to_load = named_optim_1.state_dict()
        named_optim_2 = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            lr=1e-2,
            momentum=0.6,
        )
        # named_optim_2 has never stepped, so loading into it must fail.
        err_msg = (
            "Expects the optim to be initialized before load but found not initialized"
        )
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim_2.load_state_dict(state_dict_to_load)

    def test_add_param_group(self):
        m = TestDummyModel()
        m_dup = TestDummyModel()
        optim = torch.optim.SGD(
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_optim = _NamedOptimizer(
            m_dup.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m_dup.net1.parameters()},
                {"params": m_dup.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        # Add a group from a parameter generator, then one from a single
        # tensor; the named optimizer must track the plain one in both cases.
        optim.add_param_group({"params": m.net2.parameters(), "lr": 1e-5})
        named_optim.add_param_group({"params": m_dup.net2.parameters(), "lr": 1e-5})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)
        optim.add_param_group({"params": m.net4[1].weight, "lr": 1e-3})
        named_optim.add_param_group({"params": m_dup.net4[1].weight, "lr": 1e-3})
        _run_model_training([(m, [optim]), (m_dup, [named_optim])])
        self._compare_param_groups(optim.param_groups, named_optim.param_groups)

    def test_add_param_group_error(self):
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        # A tensor that does not belong to the module has no name to key on.
        err_msg = "some parameters are not in the module"
        with self.assertRaisesRegex(ValueError, err_msg):
            named_optim.add_param_group({"params": [torch.ones(8, 1)], "lr": 1e-5})

    def test_init_state(self):
        m = TestDummyModel()
        named_optim = _NamedOptimizer(
            m.named_parameters(),
            torch.optim.SGD,
            [
                {"params": m.net1.parameters()},
                {"params": m.net3.parameters(), "lr": 1e-3},
            ],
            lr=1e-2,
            momentum=0.9,
        )
        named_sd = named_optim.state_dict()
        # Before init_state() there are no gradients and no optimizer state.
        self.assertTrue(m.net1[0].weight.grad is None)
        self.assertTrue(len(named_sd["state"]) == 0)
        named_optim.init_state()
        named_sd = named_optim.state_dict()
        # init_state() materializes zero gradients and zero-valued momentum
        # buffers, so torch.all(...) over each buffer is False.
        self.assertTrue(m.net1[0].weight.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net1.0.weight"])
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.weight"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net1.0.bias"]["momentum_buffer"]).item()
        )
        self.assertTrue(m.net3.bias.grad is not None)
        self.assertTrue("momentum_buffer" in named_sd["state"]["net3.bias"])
        self.assertFalse(
            torch.all(named_sd["state"]["net3.bias"]["momentum_buffer"]).item()
        )
        self.assertFalse(
            torch.all(named_sd["state"]["net3.weight"]["momentum_buffer"]).item()
        )
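

if __name__ == "__main__":
    # Run the suite when this file is executed directly.
    unittest.main()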