# Owner(s): ["oncall: distributed"]
import os
from copy import deepcopy
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.replicate import replicate
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
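# Simple three-layer fully connected model used throughout these tests.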
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(2, 2)
self.fc2 = nn.Linear(2, 2)
self.fc3 = nn.Linear(2, 2)
def forward(self, x):
return self.fc3(self.fc2(self.fc1(x)))
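# Verifies that wrapping with replicate() does not change a module's state_dict.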
class ReplicateStateDictTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
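    # Assert that two state dicts have identical keys and identical values.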
def _check_state_dict_parity(self, sd_1, sd_2):
for k1, k2 in zip(sd_1.keys(), sd_2.keys()):
self.assertEqual(k1, k2)
for v1, v2 in zip(sd_1.values(), sd_2.values()):
self.assertEqual(v1, v2)
def test_replicate_single_module_save_load(self):
"""
        Tests that the state_dict of a module wrapped with replicate()
        matches the local module's state_dict.
"""
model = Net()
replicate_model = replicate(deepcopy(model))
local_sd = model.state_dict()
ddp_sd = replicate_model.state_dict()
self._check_state_dict_parity(local_sd, ddp_sd)
def test_replicate_non_root_multiple_save_load(self):
"""
        Tests that replicate() applied to multiple submodules yields a
        state_dict that matches the local module's state_dict.
"""
model = Net()
replicate_model = deepcopy(model)
replicate(replicate_model.fc1)
replicate(replicate_model.fc2)
replicate(replicate_model.fc3)
local_sd = model.state_dict()
ddp_sd = replicate_model.state_dict()
self._check_state_dict_parity(local_sd, ddp_sd)
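# Training behavior and configuration tests for replicate() (gradient sync,
# device placement, ignored modules).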
class ReplicateTest(MultiProcessTestCase):
@property
def world_size(self) -> int:
return 2
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
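    # Train the local and the replicated module in lockstep and check that
    # their parameters stay equal after every step.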
def _compare_module(self, mod, replicate_mod):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
local_batch_size = 1
global_batch_size = self.world_size * local_batch_size
input = torch.randn(global_batch_size, 2)
target = torch.randn(global_batch_size, 2)
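        # One forward/backward pass followed by a manual SGD update (lr=1).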
def step_model(model, input, target):
model.train()
output = model(input)
loss = F.mse_loss(output, target.to(output.device))
loss.backward()
for param in model.parameters():
with torch.no_grad():
param -= param.grad
param.grad = None
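        # The local model sees the full batch; the replicated model sees only
        # this rank's shard. Gradient averaging should keep parameters equal.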
for iteration in range(2):
step_model(mod, input, target)
step_model(
replicate_mod,
input[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
],
target[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
],
)
self.assertEqual(
len(list(mod.parameters())),
len(list(replicate_mod.parameters())),
)
for i, j in zip(mod.parameters(), replicate_mod.parameters()):
self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
            # Shuffle the input so that the DDP inputs differ across iterations.
torch.manual_seed(iteration)
input = input[torch.randperm(global_batch_size)]
def test_replicate_single_module(self):
model = Net()
replicate_model = replicate(deepcopy(model))
self._compare_module(model, replicate_model)
@skip_if_lt_x_gpu(2)
def test_replicate_move_args_kwargs_to_device(self):
class MyNet(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(2, 2)
def forward(self, inp, *, kwarg=None):
if kwarg is not None:
inp = inp @ kwarg
return self.a(inp)
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
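        # Pin this process to its own GPU.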
torch.cuda.set_device(self.rank)
model = MyNet().cuda()
replicate(model, device_id=torch.cuda.current_device())
        # CPU inputs ensure that replicate() moves both args and kwargs to the device.
a, b = torch.randn(2, 2), torch.randn(2, 2)
model(a, kwarg=b).sum().backward()
@skip_if_lt_x_gpu(2)
def test_replicate_ignore_module(self):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
torch.cuda.set_device(self.rank)
        # Per-rank seeds ensure different inputs and thus different local grads across ranks.
torch.manual_seed(self.rank)
torch.cuda.manual_seed(self.rank)
model = Net().cuda()
replicate(model, ignored_modules=[model.fc1])
        # Rank-dependent input ensures that local gradients differ across ranks.
        inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
out = model(inp) * 10
out.sum().backward()
        # fc1 grads should not be synchronized; fc2 and fc3 grads should be.
fc1_grad = model.fc1.weight.grad
        tensor_list = [
            torch.zeros_like(fc1_grad) for _ in range(dist.get_world_size())
        ]
dist.all_gather(tensor_list, fc1_grad)
grad, rest = tensor_list[0], tensor_list[1:]
for g in rest:
self.assertNotEqual(grad, g)
for dp_grad in [model.fc2.weight.grad, model.fc3.weight.grad]:
tensor_list = [
torch.zeros_like(dp_grad) for _ in range(dist.get_world_size())
]
dist.all_gather(tensor_list, dp_grad)
grad, rest = tensor_list[0], tensor_list[1:]
for g in rest:
self.assertEqual(grad, g)
def test_replicate_multi_module(self):
model = Net()
replicate_model = deepcopy(model)
replicate(replicate_model.fc1)
replicate(replicate_model.fc2)
replicate(replicate_model.fc3)
self._compare_module(model, replicate_model)
def test_replicate_with_kwargs(self):
model = Net()
replicate_model = replicate(
deepcopy(model), bucket_cap_mb=1, gradient_as_bucket_view=True
)
self._compare_module(model, replicate_model)
@skip_if_lt_x_gpu(2)
def test_replicate_device_id(self):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
model = Net()
replicate(model, device_id=torch.device("cpu"))
        # The DDP instance is attached during the first pre-forward.
model(torch.randn(2, 2))
replicate_ddp_weakref = replicate.state(model)._ddp_weakref()
# Should be None for CPU training
self.assertEqual(None, replicate_ddp_weakref.device_ids)
model.cuda()
model_cuda = deepcopy(model)
replicate(model_cuda, device_id=torch.device(torch.cuda.current_device()))
        # The DDP instance is attached during the first pre-forward.
model_cuda(torch.randn(2, 2))
replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
self.assertEqual([0], replicate_ddp_weakref.device_ids)
# Pass in int as device_id
model_cuda = deepcopy(model_cuda)
replicate(model_cuda, device_id=int(torch.cuda.current_device()))
        # The DDP instance is attached during the first pre-forward.
model_cuda(torch.randn(2, 2))
replicate_ddp_weakref = replicate.state(model_cuda)._ddp_weakref()
self.assertEqual([0], replicate_ddp_weakref.device_ids)
def test_replicate_wrong_device_id_type(self):
dist.init_process_group(
backend="gloo",
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
model = Net()
with self.assertRaisesRegex(
RuntimeError, "Expected device_id to be int or torch.device"
):
replicate(model, device_id=[torch.device("cpu")])
if __name__ == "__main__":
run_tests()