blob: 520c00927d71be4de2fc6ee75db554e35c23b3b2 [file] [log] [blame]
import copy
import os
import sys
import tempfile
import threading
import time
import unittest
from datetime import timedelta
from itertools import product
from sys import platform
import torch
import torch.distributed as c10d
if not c10d.is_available():
    # Nothing in this module can run without c10d, so report it and exit
    # with status 0 instead of failing at the imports below.
    sys.stderr.write("c10d not available, skipping tests\n")
    raise SystemExit(0)
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
run_tests,
TEST_WITH_TSAN,
)
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
if platform == "darwin":
LOOPBACK = "lo0"
else:
LOOPBACK = "lo"
torch.backends.cuda.matmul.allow_tf32 = False
def gpus_for_rank(world_size, num_devices=None):
    """Split the visible GPUs evenly across ``world_size`` ranks.

    Multigpu tests are designed to simulate multiple nodes with multiple
    GPUs on each node. The NCCL backend requires an equal number of GPUs in
    each process, so on a single node all visible GPUs are divided into
    equally sized, contiguous subsets and each process only uses its subset.
    Leftover devices (``num_devices % world_size``) are not assigned to any
    rank.

    Args:
        world_size: number of ranks to split the devices across.
        num_devices: number of visible devices to split. Defaults to
            ``torch.cuda.device_count()``.

    Returns:
        A list with ``world_size`` entries; entry ``r`` is the list of device
        indices assigned to rank ``r``.

    Raises:
        ZeroDivisionError: if ``world_size`` is 0.
    """
    if num_devices is None:
        num_devices = torch.cuda.device_count()
    per_rank = num_devices // world_size
    # Rank r gets the half-open index range [r * per_rank, (r + 1) * per_rank).
    return [
        list(range(rank * per_rank, (rank + 1) * per_rank))
        for rank in range(world_size)
    ]
class AbstractTimeoutTest(object):
def _test_store_timeout(self, backend, init_method, c2p):
try:
c10d.distributed_c10d.init_process_group(
backend=backend,
init_method=init_method,
world_size=1,
rank=0,
timeout=timedelta(seconds=1),
)
default_store = c10d.distributed_c10d._get_default_store()
tik = time.time()
with self.assertRaisesRegex(RuntimeError, "Timeout"):
default_store.get("nonexistent key")
tok = time.time()
c10d.destroy_process_group()
c2p.append(float(tok - tik))
except RuntimeError as e:
# catch "Address already in use" error and report it to the main
# thread
c2p.append(e)
def _init_methods(self):
f = tempfile.NamedTemporaryFile(delete=False)
if sys.platform == "win32":
yield "file:///%s" % f.name.replace("\\", "/")
f.close()
else:
yield "file://%s" % f.name
f.close()
yield "tcp://127.0.0.1:%d" % common.find_free_port()
def _test_default_store_timeout(self, backend):
for init_method in self._init_methods():
c2p = []
t = threading.Thread(
target=self._test_store_timeout, args=(backend, init_method, c2p)
)
t.daemon = True
t.start()
t.join(5)
self.assertEqual(1, len(c2p))
if isinstance(c2p[0], float):
# waiting time should be 1s, use 3s to rule out false alarm
self.assertGreater(3, c2p[0])
elif isinstance(c2p[0], RuntimeError):
# let @retry_on_connect_failures handle the error
raise c2p[0]
else:
raise RuntimeError("Unexpected type {}".format(type(c2p[0])))
class Net(nn.Module):
    """Bias-free 2 -> 10 -> 50 -> 4 MLP with ReLU and a row-wise softmax."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 50, bias=False)
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)
class DoubleGpuNet(nn.Module):
    """Three-layer MLP split over two devices.

    ``fc1`` lives on ``gpus[0]``; ``fc2`` and ``fc3`` live on ``gpus[1]``.
    The softmax output is moved back to the device of ``fc1``.
    """

    def __init__(self, gpus):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[1])
        self.relu = nn.ReLU()
        # NOTE(review): when `.to()` has to copy (e.g. CPU -> CUDA) it likely
        # returns a plain Tensor rather than an nn.Parameter, so this may not
        # be registered in `parameters()`. Confirm intent before changing it;
        # helpers that do `param -= param.grad` would fail on a registered
        # no-grad parameter.
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        home = self.fc1.weight.device
        remote = self.fc2.weight.device
        hidden = self.relu(self.fc1(x.to(home)))
        hidden = self.relu(self.fc2(hidden.to(remote)))
        return F.softmax(self.fc3(hidden), dim=1).to(home)
class QuadraGpuNet(nn.Module):
    """Four-layer MLP with one linear layer on each of ``gpus[0..3]``.

    Activations are moved to each layer's device before it runs. The softmax
    output is returned on the device of ``fc1``.
    """

    def __init__(self, gpus):
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False).to(gpus[0])
        self.fc2 = nn.Linear(10, 50, bias=False).to(gpus[1])
        self.fc3 = nn.Linear(50, 4, bias=False).to(gpus[2])
        self.fc4 = nn.Linear(4, 4, bias=False).to(gpus[3])
        self.relu = nn.ReLU()
        # NOTE(review): when `.to()` has to copy (e.g. CPU -> CUDA) it likely
        # returns a plain Tensor rather than an nn.Parameter, so this may not
        # be registered in `parameters()`. Confirm intent before changing it.
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        ).to(gpus[0])

    def forward(self, x):
        layers = (self.fc1, self.fc2, self.fc3, self.fc4)
        devices = [layer.weight.device for layer in layers]
        hidden = x
        # fc1..fc3 are each followed by ReLU; fc4 is not.
        for layer, device in zip(layers[:3], devices[:3]):
            hidden = self.relu(layer(hidden.to(device)))
        logits = self.fc4(hidden.to(devices[3]))
        return F.softmax(logits, dim=1).to(devices[0])
class ConvNet(nn.Module):
    """Four 2x2 convolutions with a per-layer device, memory format and dtype.

    Args:
        gpus: a list with one device per conv layer, or a single device used
            for all four layers.
        layouts: four memory formats, one per conv layer.
        dtypes: four dtypes, one per conv layer. Activations are cast to the
            next layer's dtype between layers.
    """

    def __init__(self, gpus, layouts, dtypes):
        super().__init__()
        self.dtypes = dtypes
        if isinstance(gpus, list):
            # Only an explicit per-layer list is remembered; forward() uses it.
            self.layer_gpus = gpus
            placement = gpus
        else:
            placement = [gpus] * 4
        # (in_channels, out_channels) for conv0..conv3, in registration order.
        channels = ((8, 16), (16, 32), (32, 16), (16, 8))
        for idx, (c_in, c_out) in enumerate(channels):
            layer = torch.nn.Conv2d(c_in, c_out, (2, 2)).to(
                device=placement[idx], memory_format=layouts[idx], dtype=dtypes[idx]
            )
            setattr(self, "conv%d" % idx, layer)

    def forward(self, x):
        x = x.to(self.dtypes[0])
        # Could route by `self.convN.weight.device`, but part of this test's
        # purpose is to verify weights are where expected if the model gets
        # replicated, so use the recorded devices (or the input's device).
        route = self.layer_gpus if hasattr(self, "layer_gpus") else [x.device] * 4
        for idx, layer in enumerate((self.conv0, self.conv1, self.conv2)):
            x = layer(x).to(device=route[idx + 1], dtype=self.dtypes[idx + 1])
        return self.conv3(x)
class Task(nn.Module):
    """Adds a single 2x2 parameter ``p`` (initialised to ones) to its input."""

    def __init__(self):
        super(Task, self).__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        return torch.add(self.p, x)
class ModuleForDdpCommHook(nn.Module):
    """Wraps one ``Task`` and offsets the input by ``rank`` before it."""

    def __init__(self):
        super(ModuleForDdpCommHook, self).__init__()
        self.t0 = Task()

    def forward(self, x, rank):
        shifted = x + rank
        return self.t0(shifted)
class SparseGradientModule(nn.Module):
    """``EmbeddingBag(10, 10)`` with sparse gradients, then a row-wise softmax."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.EmbeddingBag(10, 10, sparse=True)

    def forward(self, x):
        pooled = self.embedding(x)
        return F.softmax(pooled, dim=1)
class AbstractDistributedDataParallelTest(object):
    """Mixin with shared helpers for DistributedDataParallel tests.

    Concrete subclasses are expected to also derive from a multi-process test
    case that provides ``rank``, ``file_name`` and the ``assert*`` methods.
    """

    def tearDown(self):
        # Chain to the next tearDown in the MRO (e.g. MultiProcessTestCase's)
        # so its cleanup still runs. The lookup is guarded because this mixin
        # itself derives from object, which has no tearDown.
        parent_teardown = getattr(super(), "tearDown", None)
        if parent_teardown is not None:
            parent_teardown()
        # DistributedDataParallel test doesn't seem to call FileStore destructor
        # TODO: investigate this test and the test is known to have issues
        # Use this hack to remove files for that test
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 2

    def _prepare_single_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        """Build a ``Net`` and a DDP-wrapped deep copy on a single device.

        The device is ``devices[0]`` when ``devices`` is non-empty, otherwise
        ``cuda:<rank>``.

        Returns:
            ``(model, ddp_model, input, target)``. ``input`` is
            ``(global_batch_size, 2)`` and ``target`` is
            ``(global_batch_size, 4)``; both are on the chosen device.
        """
        model = Net()
        device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
        ddp_model = DistributedDataParallel(
            copy.deepcopy(model).to(device),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )
        model.to(device)
        input = torch.randn(global_batch_size, 2).to(device)
        target = torch.randn(global_batch_size, 4).to(device)
        return model, ddp_model, input, target

    def _prepare_multi_device_module(
        self,
        process_group,
        devices,
        device_ids,
        global_batch_size,
        gradient_as_bucket_view=False,
    ):
        """Build a multi-device model (2 or 4 devices) and its DDP wrapper.

        ``devices`` must have length 2 (``DoubleGpuNet``) or 4
        (``QuadraGpuNet``). ``input`` is placed on ``devices[0]``. ``target``
        stays on CPU; the training step moves it to the output's device.

        Returns:
            ``(model, ddp_model, input, target)``.
        """
        self.assertTrue(
            len(devices) == 2 or len(devices) == 4,
            "unexpected devices for ddp tests {}".format(devices),
        )
        if len(devices) == 2:
            model = DoubleGpuNet(devices)
        elif len(devices) == 4:
            model = QuadraGpuNet(devices)
        ddp_model = DistributedDataParallel(
            copy.deepcopy(model),
            device_ids=device_ids,
            process_group=process_group,
            bucket_cap_mb=0.001,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )
        input = torch.randn(global_batch_size, 2).cuda(devices[0])
        target = torch.randn(global_batch_size, 4)
        return model, ddp_model, input, target

    def _test_ddp_with_process_group(
        self,
        process_group,
        devices,
        device_ids,
        multi_device=False,
        gradient_as_bucket_view=False,
    ):
        """
        Note: we pass down `device_ids` all the way to DistributedDataParallel
        as part of the test. Below you find tests that either use a list of
        integers, a list of `torch.Device` instances, or an empty list.
        The `devices` argument is used to control placement of the model and
        must always be specified as list of `torch.Device` instances.

        Trains a local model on the full global batch and a DDP model on this
        rank's shard for two iterations. After each iteration it asserts that
        their parameters match.
        """
        # One sample per device. Fall back to one sample when no devices are
        # given (None or empty), matching how _prepare_single_device_module
        # treats a falsy `devices`.
        local_batch_size = len(devices) if devices else 1
        global_batch_size = self.world_size * local_batch_size

        if multi_device:
            model, ddp_model, input, target = self._prepare_multi_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertTrue(ddp_logging_data.get("is_multi_device_module"))
        else:
            model, ddp_model, input, target = self._prepare_single_device_module(
                process_group,
                devices,
                device_ids,
                global_batch_size,
                gradient_as_bucket_view,
            )
            ddp_logging_data = ddp_model._get_ddp_logging_data()
            self.assertFalse(ddp_logging_data.get("is_multi_device_module"))

        def step_model(model, input, target):
            # One forward/backward pass with an MSE loss.
            model.train()
            output = model(input)
            loss = F.mse_loss(output, target.to(output.device))
            loss.backward()

        def update_parameters(model):
            # Plain SGD step with learning rate 1, then clear the gradients.
            for param in model.parameters():
                with torch.no_grad():
                    param -= param.grad
                param.grad = None

        # Portion of the global batch that this rank feeds to DDP.
        shard = slice(self.rank * local_batch_size, (self.rank + 1) * local_batch_size)

        # check two model parameters over 2 iterations
        for iteration in range(2):
            # single cpu/gpu training
            step_model(model, input, target)

            # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
            step_model(ddp_model, input[shard], target[shard])

            # Update weights and run a second iteration to shake out errors
            update_parameters(model)
            update_parameters(ddp_model)
            self.assertEqual(
                len(list(model.parameters())), len(list(ddp_model.parameters()))
            )
            for i, j in zip(model.parameters(), ddp_model.parameters()):
                self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)

            # Shuffle the input so that DDP input is different
            torch.manual_seed(1337 + iteration)
            input = input[torch.randperm(global_batch_size)]

    def _ddp_model_on_rank_gpu(self, process_group, gradient_as_bucket_view):
        """Wrap a ``ModuleForDdpCommHook`` in DDP on this rank's first GPU."""
        device_id = gpus_for_rank(self.world_size)[self.rank][0]
        return DistributedDataParallel(
            ModuleForDdpCommHook().to(device_id),
            device_ids=[device_id],
            process_group=process_group,
            gradient_as_bucket_view=gradient_as_bucket_view,
        )

    def _gpu_model_with_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False, state=None
    ):
        """Return a GPU DDP model, registering ``hook`` with ``state`` if given."""
        gpu_model = self._ddp_model_on_rank_gpu(process_group, gradient_as_bucket_view)
        # Register a DDP communication hook if any.
        if hook is not None:
            gpu_model.register_comm_hook(state, hook)
        return gpu_model

    def _gpu_model_with_builtin_ddp_comm_hook(
        self, process_group, hook=None, gradient_as_bucket_view=False
    ):
        """Return a GPU DDP model, registering built-in ``hook`` if given."""
        gpu_model = self._ddp_model_on_rank_gpu(process_group, gradient_as_bucket_view)
        # Register a built-in DDP communication hook if defined
        if hook is not None:
            gpu_model._register_builtin_comm_hook(hook)
        return gpu_model

    def _run_and_verify_hook(self, model, input, expected_grad):
        """Run forward with this rank, backprop ``output.mean()`` and check
        that every parameter's gradient equals ``expected_grad``."""
        # Run forward
        output = model(input, self.rank)
        # Run backward
        output.mean().backward()
        for p in model.parameters():
            self.assertEqual(p.grad, expected_grad)

    def _simple_hook(
        self, state: object, bucket: dist.GradBucket
    ) -> torch.futures.Future:
        """Comm hook that ignores the gradient values.

        It returns a future resolving to ``[2 * ones_like(bucket tensor)]``:
        the result is set to ones, and a ``then`` callback adds one more.
        """
        fut = torch.futures.Future()
        fut.set_result([torch.ones_like(bucket.get_tensor())])

        def fut_then(fut):
            # Add ones to fut's result.
            return [t + torch.ones_like(t) for t in fut.value()]

        return fut.then(fut_then)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class DistributedDataParallelTest(AbstractDistributedDataParallelTest, MultiProcessTestCase):
    """Multi-process DDP tests that do not depend on a specific backend."""

    def setUp(self):
        super(DistributedDataParallelTest, self).setUp()
        # Worker processes are spawned on Windows and forked elsewhere.
        if sys.platform == "win32":
            self._spawn_processes()
        else:
            self._fork_processes()

    def test_invalid_powerSGD_state(self):
        """PowerSGDState must reject ``start_powerSGD_iter`` <= 1 whenever
        error feedback or warm start is enabled."""
        for start_powerSGD_iter, use_error_feedback, warm_start in product(
            [0, 1], [True, False], [True, False]
        ):
            if not use_error_feedback and not warm_start:
                continue
            with self.assertRaisesRegex(
                ValueError,
                "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
                "because PowerSGD can only be applied after the first two iterations in DDP.",
            ):
                # Only the constructor's exception matters; the object is discarded.
                powerSGD.PowerSGDState(
                    process_group=None,
                    matrix_approximation_rank=1,
                    start_powerSGD_iter=start_powerSGD_iter,
                    use_error_feedback=use_error_feedback,
                    warm_start=warm_start,
                )
class ComputeBucketAssignmentTest(TestCase):
def test_single_limit_single_dtype(self):
tensors = [
torch.empty([100], dtype=torch.float),
torch.empty([200], dtype=torch.float),
torch.empty([100], dtype=torch.float),
torch.empty([50], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0], [1], [2], [3]], result)
def test_single_limit_multi_dtype(self):
tensors = [
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [400])
self.assertEqual([[0, 2], [1, 3], [4], [5]], result)
def test_multi_limit_single_dtype(self):
tensors = [
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
torch.empty([10], dtype=torch.float),
]
result = dist._compute_bucket_assignment_by_size(tensors, [40, 80])
self.assertEqual([[0], [1, 2], [3]], result)
def test_multi_limit_multi_dtype(self):
tensors = [
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
torch.empty([50], dtype=torch.float),
torch.empty([25], dtype=torch.double),
]
result = dist._compute_bucket_assignment_by_size(tensors, [200, 400])
self.assertEqual([[0], [1], [2, 4], [3, 5]], result)
class AbstractCommTest(object):
    """Mixin with process-group sequence-number checks shared by comm tests.

    Concrete subclasses are expected to also derive from a multi-process test
    case that provides ``rank``, ``file_name`` and the ``assert*`` methods.
    """

    @property
    def op_timeout_sec(self):
        return 1

    @property
    def world_size(self):
        return 2

    def _init_default_pg_with_file_store(self, backend):
        """Initialise the default process group for this rank.

        Ranks rendezvous through a FileStore at ``self.file_name``.
        """
        store = c10d.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )

    def _verify_sequence_number_across_pg(self, pg, verify_pg):
        """Assert that all ranks of ``verify_pg`` report the same sequence
        number for ``pg``, and return that number."""
        seq_num = pg._get_sequence_number_for_group()
        obj_list = [None] * dist.get_world_size(verify_pg)
        # We use a separate pg to verify the sequence numbers, otherwise these
        # collectives will themselves increment the sequence number.
        dist.all_gather_object(obj_list, seq_num, group=verify_pg)
        self.assertEqual(len(set(obj_list)), 1)
        return obj_list[0]

    def _test_sequence_num_incremented(self, process_group, ranks):
        """Check that each all_reduce on ``process_group`` bumps its sequence
        number by one. With more than 2 ranks, also check that ranks 0 and 2
        lag behind after skipping a collective."""
        # verify initial sequence numbers. Use a distinct process group for
        # verification to keep counts as expected with respect to process_group.
        verify_pg = dist.new_group(
            ranks=ranks,
            backend="gloo",
        )
        self.assertEqual(
            dist.get_world_size(process_group), dist.get_world_size(verify_pg)
        )
        # Group membership does not change below, so check it once.
        in_group = not c10d.distributed_c10d._rank_not_in_group(process_group)

        initial_num = (
            self._verify_sequence_number_across_pg(
                pg=process_group, verify_pg=verify_pg
            )
            if in_group
            else -1
        )

        # Verify sequence numbers are appropriately incremented
        for i in range(10):
            t = torch.ones(1, device=torch.cuda.current_device())
            dist.all_reduce(t, group=process_group)
            if in_group:
                seq_num = self._verify_sequence_number_across_pg(
                    pg=process_group,
                    verify_pg=verify_pg,
                )
                self.assertEqual(initial_num + i + 1, seq_num)

        if dist.get_world_size(process_group) > 2:
            # Test when certain ranks don't call collectives
            if dist.get_rank(process_group) not in [0, 2]:
                dist.all_reduce(t, group=process_group, async_op=True)
            # Now ranks 0 and 2 should be lagging by 1.
            if in_group:
                seq_num = process_group._get_sequence_number_for_group()
                rank = dist.get_rank(process_group)
                obj_list = [None] * dist.get_world_size(verify_pg)
                dist.all_gather_object(obj_list, (rank, seq_num), group=verify_pg)
                rank_to_seq_num = dict(obj_list)
                self.assertEqual(len(set(rank_to_seq_num.values())), 2)
                self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2])
                expected_same = {
                    num
                    for other_rank, num in rank_to_seq_num.items()
                    if other_rank not in [0, 2]
                }
                self.assertEqual(len(expected_same), 1)
                self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1])

    def _test_sequence_num_incremented_default_group(self, backend_name):
        torch.cuda.set_device(self.rank)
        self._init_default_pg_with_file_store(backend_name)
        self._test_sequence_num_incremented(
            c10d.distributed_c10d._get_default_group(),
            ranks=list(range(dist.get_world_size())),
        )

    def _test_sequence_num_incremented_subgroup(self, backend_name):
        torch.cuda.set_device(self.rank)
        self._init_default_pg_with_file_store(backend_name)
        subgroup_ranks = [0, 1, 2]
        subgroup = dist.new_group(subgroup_ranks)
        self._test_sequence_num_incremented(subgroup, subgroup_ranks)

    def _test_sequence_num_set_default_pg(self, backend):
        """All ranks must start with the same default-group sequence number."""
        self._init_default_pg_with_file_store(backend)
        default_pg = c10d.distributed_c10d._get_default_group()
        seq_num = default_pg._get_sequence_number_for_group()
        obj_list = [None] * dist.get_world_size()
        dist.all_gather_object(obj_list, seq_num)
        self.assertEqual(len(set(obj_list)), 1)

    def _test_sequence_num_set_new_group(self, backend):
        """Members of a new subgroup must agree on its sequence number."""
        self._init_default_pg_with_file_store(backend)
        subgroup = dist.new_group([0, 1])
        if not c10d.distributed_c10d._rank_not_in_group(subgroup):
            subgroup_seq = subgroup._get_sequence_number_for_group()
            obj_list = [None] * dist.get_world_size(subgroup)
            dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
            self.assertEqual(len(set(obj_list)), 1)
@unittest.skipIf(
    TEST_WITH_TSAN,
    "TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class CommTest(AbstractCommTest, MultiProcessTestCase):
    """Backend-independent communication tests run in multiple processes."""

    def setUp(self):
        super(CommTest, self).setUp()
        # Worker processes are spawned on Windows and forked elsewhere.
        if sys.platform == "win32":
            self._spawn_processes()
        else:
            self._fork_processes()

    def tearDown(self):
        super(CommTest, self).tearDown()
        # Best-effort removal of the FileStore backing file.
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def test_distributed_debug_mode(self):
        """TORCH_DISTRIBUTED_DEBUG values map to debug levels; invalid ones raise."""
        # Remember the caller's value so the env var is restored afterwards
        # and the mutations below do not leak into later code in this process.
        saved_env = os.environ.get("TORCH_DISTRIBUTED_DEBUG")
        # Default should be off
        default_debug_mode = dist._get_debug_mode()
        self.assertEqual(default_debug_mode, dist._DistributedDebugLevel.OFF)
        mapping = {
            "OFF": dist._DistributedDebugLevel.OFF,
            "INFO": dist._DistributedDebugLevel.INFO,
            "DETAIL": dist._DistributedDebugLevel.DETAIL,
        }
        invalid_debug_modes = ["foo", 0, 1, -1]
        try:
            for mode, expected in mapping.items():
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
                set_debug_mode = dist._get_debug_mode()
                self.assertEqual(
                    set_debug_mode,
                    expected,
                    f"Expected {mode} to map to {mapping[mode]} but got {set_debug_mode}",
                )

            for mode in invalid_debug_modes:
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode)
                with self.assertRaisesRegex(RuntimeError, "to be one of"):
                    dist._get_debug_mode()
        finally:
            if saved_env is None:
                os.environ.pop("TORCH_DISTRIBUTED_DEBUG", None)
            else:
                os.environ["TORCH_DISTRIBUTED_DEBUG"] = saved_env
if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_distributed must not have initialized CUDA context on main process"
run_tests()