| import copy |
| import math |
| import multiprocessing |
| import os |
| import sys |
| import tempfile |
| import time |
| import unittest |
| from datetime import timedelta |
| |
| from functools import wraps |
| from collections import namedtuple |
| |
| import torch |
| import common_utils as common |
| from torch import nn |
| import torch.nn.functional as F |
| import torch.distributed as c10d |
| from torch.nn.parallel import DistributedDataParallel |
| |
| from common_utils import TestCase, load_tests, run_tests |
| |
| # load_tests from common_utils is used to automatically filter tests for |
| # sharding on sandcastle. This assignment silences flake8 warnings about an unused import. |
| load_tests = load_tests |
| |
| if not c10d.is_available(): |
| print('c10d not available, skipping tests') |
| sys.exit(0) |
| |
| |
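| # Per-test timeouts, in seconds. TIMEOUT_OVERRIDE maps a bare test method |
| # name (e.g. 'test_broadcast_stress') to a custom limit; every other test |
| # falls back to TIMEOUT_DEFAULT (see get_timeout below). |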
| TIMEOUT_DEFAULT = 15 |
| TIMEOUT_OVERRIDE = {} |
| |
| TestSkip = namedtuple('TestSkip', 'exit_code, message') |
| |
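| # A subprocess cannot raise unittest.SkipTest in a way the parent process |
| # can observe, so skip conditions are signalled through dedicated exit |
| # codes, which _check_return_codes translates back into SkipTest. |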
| TEST_SKIPS = { |
| "multi-gpu": TestSkip(75, "Need at least 2 CUDA devices"), |
| "nccl": TestSkip(76, "c10d not compiled with NCCL support"), |
| "known_issues": TestSkip(77, "Test skipped due to known issues") |
| } |
| |
| |
| def skip_if_not_multigpu(func): |
| """Multi-GPU tests require at least 2 GPUs. Skip if this is not met.""" |
| @wraps(func) |
| def wrapper(*args, **kwargs): |
| if torch.cuda.is_available() and torch.cuda.device_count() >= 2: |
| return func(*args, **kwargs) |
| sys.exit(TEST_SKIPS['multi-gpu'].exit_code) |
| |
| return wrapper |
| |
| |
| def skip_if_not_nccl(func): |
| """Skips a test if NCCL is not available (for c10d).""" |
| @wraps(func) |
| def wrapper(*args, **kwargs): |
| if hasattr(c10d, "ProcessGroupNCCL"): |
| return func(*args, **kwargs) |
| sys.exit(TEST_SKIPS['nccl'].exit_code) |
| |
| return wrapper |
| |
| |
| def skip_for_known_issues(func): |
| """Skips a test due to known issues (for c10d).""" |
| @wraps(func) |
| def wrapper(*args, **kwargs): |
| sys.exit(TEST_SKIPS['known_issues'].exit_code) |
| |
| return wrapper |
| |
| |
| def get_timeout(test_id): |
| return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT) |
| |
| |
| def gpus_for_rank(world_size): |
| """Multi-GPU tests simulate multiple nodes, each with multiple GPUs. |
| The NCCL backend requires every process to use the same number of GPUs. |
| On a single node, all visible GPUs are evenly divided into subsets and |
| each process uses only its own subset. |
| """ |
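| # For example, with 4 visible GPUs and world_size == 2, rank 0 gets |
| # GPUs [0, 1] and rank 1 gets GPUs [2, 3]. |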
| visible_devices = list(range(torch.cuda.device_count())) |
| gpus_per_process = torch.cuda.device_count() // world_size |
| gpus_for_rank = [] |
| for rank in range(world_size): |
| gpus_for_rank.append(visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]) |
| return gpus_for_rank |
| |
| |
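| # Each tuple is (op, per-rank input, expected result). The input on every |
| # rank is rank + 1, so SUM yields world_size * (world_size + 1) / 2, |
| # PRODUCT yields world_size!, MIN yields 1, and MAX yields world_size. |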
| def simple_reduce_tests(rank, world_size): |
| return [ |
| ( |
| c10d.ReduceOp.SUM, |
| torch.Tensor([rank + 1.0]), |
| torch.Tensor([float(world_size * (world_size + 1) / 2)]), |
| ), |
| ( |
| c10d.ReduceOp.PRODUCT, |
| torch.Tensor([rank + 1.0]), |
| torch.Tensor([float(math.factorial(world_size))]), |
| ), |
| ( |
| c10d.ReduceOp.MIN, |
| torch.Tensor([rank + 1.0]), |
| torch.Tensor([1.0]), |
| ), |
| ( |
| c10d.ReduceOp.MAX, |
| torch.Tensor([rank + 1.0]), |
| torch.Tensor([world_size]), |
| ), |
| ] |
| |
| |
| class StoreTestBase(object): |
| def _create_store(self): |
| raise RuntimeError("not implemented") |
| |
| def _test_set_get(self, fs): |
| fs.set("key0", "value0") |
| fs.set("key1", "value1") |
| fs.set("key2", "value2") |
| self.assertEqual(b"value0", fs.get("key0")) |
| self.assertEqual(b"value1", fs.get("key1")) |
| self.assertEqual(b"value2", fs.get("key2")) |
| |
| def test_set_get(self): |
| self._test_set_get(self._create_store()) |
| |
| |
| class FileStoreTest(TestCase, StoreTestBase): |
| def setUp(self): |
| self.file = tempfile.NamedTemporaryFile() |
| |
| def tearDown(self): |
| self.file.close() |
| |
| def _create_store(self): |
| store = c10d.FileStore(self.file.name) |
| store.set_timeout(timedelta(seconds=300)) |
| return store |
| |
| |
| class PrefixFileStoreTest(TestCase, StoreTestBase): |
| def setUp(self): |
| self.file = tempfile.NamedTemporaryFile() |
| self.filestore = c10d.FileStore(self.file.name) |
| self.prefix = "test_prefix" |
| self.filestore.set_timeout(timedelta(seconds=300)) |
| |
| def tearDown(self): |
| self.file.close() |
| |
| def _create_store(self): |
| return c10d.PrefixStore(self.prefix, self.filestore) |
| |
| |
| def create_tcp_store(addr): |
| """ |
| Creates a TCP store. Retries if the chosen port is already in use. |
| """ |
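| # The port returned by find_free_port() is released again before TCPStore |
| # binds it, so another process may grab it in the meantime. If that |
| # happens, TCPStore raises "Address already in use" and we retry with a |
| # fresh port. |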
| while True: |
| try: |
| port = common.find_free_port() |
| return c10d.TCPStore(addr, port, True) |
| except RuntimeError as error: |
| if str(error) == "Address already in use": |
| continue |
| raise |
| |
| |
| class TCPStoreTest(TestCase, StoreTestBase): |
| def _create_store(self): |
| store = create_tcp_store('localhost') |
| store.set_timeout(timedelta(seconds=300)) |
| return store |
| |
| def test_address_already_in_use(self): |
| with self.assertRaisesRegex(RuntimeError, "^Address already in use$"): |
| addr = 'localhost' |
| port = common.find_free_port() |
| |
| # Use noqa to silence flake8. |
| # Need to store in an unused variable here to ensure the first |
| # object is not destroyed before the second object is created. |
| store1 = c10d.TCPStore(addr, port, True) # noqa: F841 |
| store2 = c10d.TCPStore(addr, port, True) # noqa: F841 |
| |
| |
| class PrefixTCPStoreTest(TestCase, StoreTestBase): |
| def setUp(self): |
| self.tcpstore = create_tcp_store('localhost') |
| self.prefix = "test_prefix" |
| self.tcpstore.set_timeout(timedelta(seconds=300)) |
| |
| def _create_store(self): |
| return c10d.PrefixStore(self.prefix, self.tcpstore) |
| |
| |
| class RendezvousTest(TestCase): |
| def test_unknown_handler(self): |
| with self.assertRaisesRegex(RuntimeError, "^No rendezvous handler"): |
| c10d.rendezvous('invalid://') |
| |
| |
| class RendezvousEnvTest(TestCase): |
| def test_common_errors(self): |
| vars = { |
| "WORLD_SIZE": "2", |
| "RANK": "0", |
| "MASTER_ADDR": "127.0.0.1", |
| "MASTER_PORT": common.find_free_port(), |
| } |
| |
| class Env(object): |
| def __init__(self, vars): |
| self.vars = vars |
| |
| def __enter__(self): |
| for key, value in self.vars.items(): |
| os.environ[key] = str(value) |
| |
| def __exit__(self, type, value, traceback): |
| for key in self.vars.keys(): |
| del os.environ[key] |
| |
| def without(d, key): |
| d = d.copy() |
| d.pop(key) |
| return d |
| |
| with Env(without(vars, 'WORLD_SIZE')): |
| with self.assertRaisesRegex(ValueError, 'WORLD_SIZE expected'): |
| gen = c10d.rendezvous('env://') |
| next(gen) |
| with Env(without(vars, 'RANK')): |
| with self.assertRaisesRegex(ValueError, 'RANK expected'): |
| gen = c10d.rendezvous('env://') |
| next(gen) |
| with Env(without(vars, 'MASTER_ADDR')): |
| with self.assertRaisesRegex(ValueError, 'MASTER_ADDR expected'): |
| gen = c10d.rendezvous('env://') |
| next(gen) |
| with Env(without(vars, 'MASTER_PORT')): |
| with self.assertRaisesRegex(ValueError, 'MASTER_PORT expected'): |
| gen = c10d.rendezvous('env://') |
| next(gen) |
| |
| def test_nominal(self): |
| os.environ['WORLD_SIZE'] = '2' |
| os.environ['MASTER_ADDR'] = '127.0.0.1' |
| os.environ['MASTER_PORT'] = str(common.find_free_port()) |
| |
| # First rank |
| os.environ['RANK'] = '0' |
| gen0 = c10d.rendezvous('env://') |
| store0, rank0, size0 = next(gen0) |
| self.assertEqual(0, rank0) |
| self.assertEqual(2, size0) |
| |
| # Second rank |
| os.environ['RANK'] = '1' |
| gen1 = c10d.rendezvous('env://') |
| store1, rank1, size1 = next(gen1) |
| self.assertEqual(1, rank1) |
| self.assertEqual(2, size1) |
| |
| # Set value on both stores |
| store0.set("key0", "value0") |
| store1.set("key1", "value1") |
| |
| # Cross check with get |
| self.assertEqual(b"value0", store1.get("key0")) |
| self.assertEqual(b"value1", store0.get("key1")) |
| |
| |
| class RendezvousFileTest(TestCase): |
| def test_common_errors(self): |
| with self.assertRaisesRegex(ValueError, 'path missing'): |
| gen = c10d.rendezvous('file://?rank=0&world_size=1') |
| next(gen) |
| with self.assertRaisesRegex(ValueError, 'rank parameter missing'): |
| gen = c10d.rendezvous('file:///tmp/foo?world_size=1') |
| next(gen) |
| with self.assertRaisesRegex(ValueError, 'size parameter missing'): |
| gen = c10d.rendezvous('file:///tmp/foo?rank=0') |
| next(gen) |
| |
| def test_nominal(self): |
| with tempfile.NamedTemporaryFile() as file: |
| url = 'file://%s?world_size=%d' % (file.name, 2) |
| gen0 = c10d.rendezvous(url + "&rank=0") |
| store0, rank0, size0 = next(gen0) |
| self.assertEqual(0, rank0) |
| self.assertEqual(2, size0) |
| gen1 = c10d.rendezvous(url + "&rank=1") |
| store1, rank1, size1 = next(gen1) |
| self.assertEqual(1, rank1) |
| self.assertEqual(2, size1) |
| |
| # Set value on both stores |
| store0.set("key0", "value0") |
| store1.set("key1", "value1") |
| |
| # Cross check with get |
| self.assertEqual(b"value0", store1.get("key0")) |
| self.assertEqual(b"value1", store0.get("key1")) |
| |
| |
| class RendezvousTCPTest(TestCase): |
| def test_common_errors(self): |
| with self.assertRaisesRegex(ValueError, 'port number missing'): |
| gen = c10d.rendezvous('tcp://127.0.0.1?rank=0&world_size=1') |
| next(gen) |
| with self.assertRaisesRegex(ValueError, 'rank parameter missing'): |
| gen = c10d.rendezvous('tcp://127.0.0.1:23456?world_size=1') |
| next(gen) |
| with self.assertRaisesRegex(ValueError, 'size parameter missing'): |
| gen = c10d.rendezvous('tcp://127.0.0.1:23456?rank=0') |
| next(gen) |
| |
| def test_nominal(self): |
| addr = 'localhost' |
| port = common.find_free_port() |
| url = 'tcp://%s:%d?world_size=%d' % (addr, port, 2) |
| gen0 = c10d.rendezvous(url + "&rank=0") |
| store0, rank0, size0 = next(gen0) |
| self.assertEqual(0, rank0) |
| self.assertEqual(2, size0) |
| gen1 = c10d.rendezvous(url + "&rank=1") |
| store1, rank1, size1 = next(gen1) |
| self.assertEqual(1, rank1) |
| self.assertEqual(2, size1) |
| |
| # Set value on both stores |
| store0.set("key0", "value0") |
| store1.set("key1", "value1") |
| |
| # Cross check with get |
| self.assertEqual(b"value0", store1.get("key0")) |
| self.assertEqual(b"value1", store0.get("key1")) |
| |
| |
| class MultiProcessTestCase(TestCase): |
| MAIN_PROCESS_RANK = -1 |
| |
| @property |
| def world_size(self): |
| return 4 |
| |
| @staticmethod |
| def join_or_run(fn): |
| @wraps(fn) |
| def wrapper(self): |
| if self.rank == self.MAIN_PROCESS_RANK: |
| self._join_processes(fn) |
| else: |
| fn(self) |
| return wrapper |
| |
| # The main process spawns N subprocesses that run the test. |
| # This function patches every test function so that the main process |
| # joins its subprocesses, while each subprocess runs the underlying |
| # test function. |
| @classmethod |
| def setUpClass(cls): |
| for attr in dir(cls): |
| if attr.startswith('test'): |
| fn = getattr(cls, attr) |
| setattr(cls, attr, cls.join_or_run(fn)) |
| |
| def setUp(self): |
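| # The main process keeps rank MAIN_PROCESS_RANK (-1); every spawned |
| # subprocess overwrites self.rank in _run. The temporary file gives the |
| # subprocesses a common path for the FileStore they use to bootstrap |
| # their process group. |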
| self.rank = self.MAIN_PROCESS_RANK |
| self.file = tempfile.NamedTemporaryFile() |
| self.processes = [self._spawn_process(rank) for rank in range(int(self.world_size))] |
| |
| def tearDown(self): |
| for p in self.processes: |
| p.terminate() |
| self.file.close() |
| |
| def _spawn_process(self, rank): |
| name = 'process ' + str(rank) |
| process = multiprocessing.Process(target=self._run, name=name, args=(rank,)) |
| process.start() |
| return process |
| |
| def _run(self, rank): |
| self.rank = rank |
| |
| # self.id() == e.g. '__main__.ProcessGroupGlooTest.test_broadcast_basics' |
| # We're retrieving the corresponding test function and executing it. |
| getattr(self, self.id().split(".")[2])() |
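| # Exit 0 to signal success. If the test raised, the exception propagates |
| # out of _run and the subprocess exits with a nonzero code, which the |
| # parent detects in _check_return_codes. |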
| sys.exit(0) |
| |
| def _join_processes(self, fn): |
| timeout = get_timeout(self.id()) |
| start_time = time.time() |
| for p in self.processes: |
| p.join(timeout) |
| elapsed_time = time.time() - start_time |
| self._check_return_codes(elapsed_time) |
| |
| def _check_return_codes(self, elapsed_time): |
| """ |
| Checks that the return codes of all spawned processes match, and skips |
| the test if they returned an exit code indicating a skip condition. |
| """ |
| first_process = self.processes[0] |
| for i, p in enumerate(self.processes): |
| if p.exitcode is None: |
| raise RuntimeError('Process {} timed out after {} seconds'.format(i, elapsed_time)) |
| self.assertEqual(p.exitcode, first_process.exitcode) |
| for skip in TEST_SKIPS.values(): |
| if first_process.exitcode == skip.exit_code: |
| raise unittest.SkipTest(skip.message) |
| self.assertEqual(first_process.exitcode, 0) |
| |
| @property |
| def is_master(self): |
| return self.rank == 0 |
| |
| |
| class ProcessGroupGlooTest(MultiProcessTestCase): |
| def opts(self, threads=2): |
| opts = c10d.ProcessGroupGloo.Options() |
| opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] |
| opts.timeout = 1.0 |
| opts.threads = threads |
| return opts |
| |
| def test_broadcast_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| t2 = torch.zeros([1], dtype=torch.float64) |
| t3 = torch.zeros([2], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = -1 |
| opts.rootTensor = 0 |
| pg.broadcast([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.world_size |
| opts.rootTensor = 0 |
| pg.broadcast([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root tensor"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = -1 |
| pg.broadcast([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root tensor"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 1 |
| pg.broadcast([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root tensor"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 0 |
| pg.broadcast([], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 0 |
| pg.broadcast([t1, t2], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 0 |
| pg.broadcast([t1, t3], opts) |
| |
| def _test_broadcast_basics(self, fn): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| def broadcast(xs, rootRank, rootTensor): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = rootRank |
| opts.rootTensor = rootTensor |
| work = pg.broadcast(xs, opts) |
| work.wait() |
| |
| # Every rank is root once |
| for i in range(self.world_size): |
| # Run with 1 input tensor |
| x = fn(torch.Tensor([self.rank])) |
| broadcast([x], i, 0) |
| self.assertEqual(torch.Tensor([i]), x) |
| |
| # Run with 2 input tensors |
| num = 2 |
| for j in range(num): |
| xs = [ |
| fn(torch.Tensor([self.rank * num + 0.0])), |
| fn(torch.Tensor([self.rank * num + 1.0])), |
| ] |
| |
| broadcast(xs, i, j) |
| self.assertEqual(torch.Tensor([i * num + j]), xs[0]) |
| self.assertEqual(torch.Tensor([i * num + j]), xs[1]) |
| |
| # Test overloaded convenience function |
| x = torch.Tensor([self.rank + 1.0]) |
| work = pg.broadcast(x, root=0) |
| work.wait() |
| self.assertEqual(torch.Tensor([1.0]), x) |
| |
| def test_broadcast_basics(self): |
| self._test_broadcast_basics(lambda t: t.clone()) |
| |
| @skip_if_not_multigpu |
| def test_broadcast_basics_cuda(self): |
| self._test_broadcast_basics(lambda t: t.clone().cuda()) |
| |
| def _test_broadcast_stress(self, inputs): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) |
| work_handles = [ |
| pg.broadcast(inputs[i], root=(i % self.world_size)) |
| for i in range(len(inputs)) |
| ] |
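| # Broadcast i is rooted at rank (i % world_size), whose input is |
| # i * world_size + (i % world_size); every rank should end up with that value. |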
| for i, work_handle in enumerate(work_handles): |
| work_handle.wait() |
| self.assertEqual( |
| torch.Tensor([ |
| (i * self.world_size) + (i % self.world_size) |
| ]), |
| inputs[i], |
| "Mismatch in iteration %d" % i, |
| ) |
| |
| def test_broadcast_stress(self): |
| inputs = [torch.Tensor([i * self.world_size + self.rank]) for i in range(1000)] |
| self._test_broadcast_stress(inputs) |
| |
| @skip_if_not_multigpu |
| def test_broadcast_stress_cuda(self): |
| inputs = [torch.Tensor([i * self.world_size + self.rank]).cuda() for i in range(1000)] |
| self._test_broadcast_stress(inputs) |
| |
| def test_allreduce_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| t2 = torch.zeros([1], dtype=torch.float64) |
| t3 = torch.zeros([2], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "requires non-empty tensor list"): |
| opts = c10d.AllreduceOptions() |
| pg.allreduce([], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| opts = c10d.AllreduceOptions() |
| pg.allreduce([t1, t2], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| opts = c10d.AllreduceOptions() |
| pg.allreduce([t1, t3], opts) |
| |
| def _test_allreduce_basics(self, fn): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| for (op, input, output) in simple_reduce_tests(self.rank, self.world_size): |
| opts = c10d.AllreduceOptions() |
| opts.reduceOp = op |
| tmp = fn(input) |
| work = pg.allreduce([tmp], opts) |
| work.wait() |
| self.assertEqual(output, tmp) |
| |
| # Test overloaded convenience function (defaults to using sum) |
| x = fn(torch.Tensor([self.rank + 1.0])) |
| work = pg.allreduce(x) |
| work.wait() |
| self.assertEqual(torch.Tensor([float(self.world_size * (self.world_size + 1) / 2)]), x) |
| |
| def test_allreduce_basics(self): |
| self._test_allreduce_basics(lambda t: t.clone()) |
| |
| @skip_if_not_multigpu |
| def test_allreduce_basics_cuda(self): |
| self._test_allreduce_basics(lambda t: t.clone().cuda()) |
| |
| def _test_allreduce_stress(self, inputs): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts(threads=8)) |
| work_handles = [pg.allreduce(inputs[i]) for i in range(len(inputs))] |
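| # Summing the inputs (i + rank) over all ranks gives |
| # i * world_size + world_size * (world_size - 1) / 2. |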
| for i, work_handle in enumerate(work_handles): |
| work_handle.wait() |
| self.assertEqual( |
| torch.Tensor([ |
| (i * self.world_size) + |
| (self.world_size * (self.world_size - 1) / 2) |
| ]), |
| inputs[i], |
| "Mismatch in iteration %d" % i, |
| ) |
| |
| def test_allreduce_stress(self): |
| inputs = [torch.Tensor([i + self.rank]) for i in range(1000)] |
| self._test_allreduce_stress(inputs) |
| |
| @skip_if_not_multigpu |
| def test_allreduce_stress_cuda(self): |
| inputs = [torch.Tensor([i + self.rank]).cuda() for i in range(1000)] |
| self._test_allreduce_stress(inputs) |
| |
| def test_scatter_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| t2 = torch.zeros([1], dtype=torch.float64) |
| t3 = torch.zeros([2], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = -1 |
| pg.scatter([t1], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.world_size |
| pg.scatter([t1], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output tensor list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = 0 |
| pg.scatter([], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output tensor list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = 0 |
| pg.scatter([t1, t1], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t1] * self.world_size, [t1] * self.world_size], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t1] * (self.world_size - 1)], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t1] * (self.world_size + 1)], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input list"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t1] * (self.world_size + 1)], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t2] * self.world_size], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = self.rank |
| pg.scatter([t1], [[t3] * self.world_size], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires empty input on non-root"): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = (self.rank + 1) % self.world_size |
| pg.scatter([t1], [[t1] * self.world_size], opts) |
| |
| def test_scatter_basics(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| # Preallocate tensors for input/output |
| input = [torch.Tensor([self.rank]) for _ in range(self.world_size)] |
| outputs = [torch.Tensor([-1]) for _ in range(self.world_size)] |
| |
| # Take turns being the scatter root and accumulate work items |
| work = [] |
| for i in range(self.world_size): |
| opts = c10d.ScatterOptions() |
| opts.rootRank = i |
| if i == self.rank: |
| work.append(pg.scatter([outputs[i]], [input], opts)) |
| else: |
| work.append(pg.scatter([outputs[i]], [], opts)) |
| |
| # Wait for work to complete |
| for i in range(self.world_size): |
| work[i].wait() |
| self.assertEqual(torch.Tensor([i]), outputs[i]) |
| |
| def test_gather_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| t2 = torch.zeros([1], dtype=torch.float64) |
| t3 = torch.zeros([2], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = -1 |
| pg.gather([], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.world_size |
| pg.gather([], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input tensor list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = 0 |
| pg.gather([], [], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element input tensor list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = 0 |
| pg.gather([], [t1, t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([[t1] * self.world_size, [t1] * self.world_size], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([[t1] * (self.world_size - 1)], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element output list"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([[t1] * (self.world_size + 1)], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([[t2] * self.world_size], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = self.rank |
| pg.gather([[t3] * self.world_size], [t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires empty output on non-root"): |
| opts = c10d.GatherOptions() |
| opts.rootRank = (self.rank + 1) % self.world_size |
| pg.gather([[t1] * self.world_size], [t1], opts) |
| |
| def test_gather_basics(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| # Preallocate tensors for input/output |
| input = [torch.Tensor([self.rank])] |
| outputs = [torch.Tensor([-1]) for _ in range(self.world_size)] |
| |
| # Take turns being the gather root and accumulate work items |
| work = [] |
| for i in range(self.world_size): |
| opts = c10d.GatherOptions() |
| opts.rootRank = i |
| if i == self.rank: |
| work.append(pg.gather([outputs], input, opts)) |
| else: |
| work.append(pg.gather([], input, opts)) |
| |
| # Wait for work to complete |
| expected = [torch.Tensor([rank]) for rank in range(self.world_size)] |
| for i in range(self.world_size): |
| work[i].wait() |
| if i == self.rank: |
| self.assertEqual(expected, outputs) |
| |
| def test_allgather_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| t2 = torch.zeros([1], dtype=torch.float64) |
| t3 = torch.zeros([2], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "requires non-empty input tensor list"): |
| pg.allgather([], []) |
| |
| with self.assertRaisesRegex(ValueError, "requires input/output tensor lists to have the same length"): |
| pg.allgather([], [t1]) |
| |
| with self.assertRaisesRegex(ValueError, "requires input/output tensor lists to have the same length"): |
| pg.allgather([[t1] * self.world_size, [t1] * self.world_size], [t1]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid output tensor list"): |
| pg.allgather([[t1] * (self.world_size - 1)], [t1]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid output tensor list"): |
| pg.allgather([[t1] * (self.world_size + 1)], [t1]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| pg.allgather([[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t2]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| pg.allgather([[t1, t1] * (self.world_size), [t1, t1] * (self.world_size)], [t1, t3]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor type"): |
| pg.allgather([([t1, t2] * (self.world_size))[:self.world_size]], [t1]) |
| |
| with self.assertRaisesRegex(ValueError, "invalid tensor size"): |
| pg.allgather([([t1, t3] * (self.world_size))[:self.world_size]], [t1]) |
| |
| def test_allgather_basics(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| # Run with N input tensors per rank |
| for n in [1, 2, 3]: |
| input = [ |
| torch.Tensor([n * self.rank + i]) for i in range(n) |
| ] |
| output = [ |
| [ |
| torch.Tensor([-1]) for _ in range(n * self.world_size) |
| ] for _ in range(n) |
| ] |
| expected_output = [ |
| [ |
| torch.Tensor([i]) for i in range(n * self.world_size) |
| ] for _ in range(n) |
| ] |
| work = pg.allgather(output, input) |
| work.wait() |
| self.assertEqual(expected_output, output) |
| |
| def test_reduce_checks(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| t1 = torch.zeros([1], dtype=torch.float32) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.ReduceOptions() |
| opts.rootRank = -1 |
| opts.rootTensor = 0 |
| pg.reduce([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root rank"): |
| opts = c10d.ReduceOptions() |
| opts.rootRank = self.world_size |
| opts.rootTensor = 0 |
| pg.reduce([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "invalid root tensor"): |
| opts = c10d.ReduceOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 1 |
| pg.reduce([t1], opts) |
| |
| with self.assertRaisesRegex(ValueError, "requires a single-element tensor list"): |
| opts = c10d.ReduceOptions() |
| opts.rootRank = self.rank |
| opts.rootTensor = 0 |
| pg.reduce([t1, t1], opts) |
| |
| def test_reduce_basics(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| for (op, input, output) in simple_reduce_tests(self.rank, self.world_size): |
| for root in range(self.world_size): |
| opts = c10d.ReduceOptions() |
| opts.reduceOp = op |
| opts.rootRank = root |
| tmp = input.clone() |
| work = pg.reduce([tmp], opts) |
| work.wait() |
| if root == self.rank: |
| self.assertEqual(output, tmp) |
| |
| def test_send_recv_all_to_all(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts()) |
| |
| # Preallocate tensors for input/output |
| inputs = [torch.Tensor([self.rank]) for _ in range(self.world_size)] |
| outputs = [torch.Tensor([-1]) for _ in range(self.world_size)] |
| |
| # Issue sends |
| send_work = [] |
| for i in range(self.world_size): |
| if i == self.rank: |
| continue |
| send_work.append(pg.send([inputs[i]], i, 0)) |
| |
| # Issue recvs |
| recv_work = [] |
| for i in range(self.world_size): |
| if i == self.rank: |
| continue |
| recv_work.append(pg.recv([outputs[i]], i, 0)) |
| |
| # Wait for sends to complete |
| for work in send_work: |
| work.wait() |
| |
| # Wait for recvs to complete |
| for work in recv_work: |
| work.wait() |
| |
| # Test that every output other than our own contains the respective rank |
| for i in range(self.world_size): |
| if i == self.rank: |
| continue |
| self.assertEqual(torch.Tensor([i]), outputs[i]) |
| |
| def test_timeout_kwarg(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupGloo( |
| store, |
| self.rank, |
| self.world_size, |
| timeout=timedelta(seconds=0.5)) |
| |
| # Wait on barrier |
| self.assertTrue(pg.barrier().wait()) |
| |
| # Sleep on one of the processes to trigger barrier timeout |
| if self.rank == 0: |
| time.sleep(0.6) |
| |
| # The barrier will now time out |
| self.assertFalse(pg.barrier().wait()) |
| |
| |
| class ProcessGroupNCCLTest(TestCase): |
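| # Unlike the multi-process Gloo tests above, these tests run in a single |
| # process (world_size == 1) that drives every local GPU, so each |
| # collective operates on one tensor per GPU. |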
| MAIN_PROCESS_RANK = 0 |
| |
| def setUp(self): |
| if not hasattr(c10d, "ProcessGroupNCCL"): |
| raise unittest.SkipTest("C10D is not built with NCCL process group," |
| " skipping test") |
| |
| self.rank = self.MAIN_PROCESS_RANK |
| self.world_size = 1 |
| self.file = tempfile.NamedTemporaryFile() |
| self.num_gpus = torch.cuda.device_count() |
| if self.num_gpus < 2: |
| raise unittest.SkipTest("NCCL test requires 2+ GPUs") |
| |
| def tearDown(self): |
| self.file.close() |
| |
| def test_broadcast_ops(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| def broadcast(xs, rootRank, rootTensor): |
| opts = c10d.BroadcastOptions() |
| opts.rootRank = rootRank |
| opts.rootTensor = rootTensor |
| work = pg.broadcast(xs, opts) |
| work.wait() |
| |
| # for every root tensor |
| for rt in range(self.num_gpus): |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i]).cuda(i)) |
| |
| broadcast(tensors, self.rank, rt) |
| |
| for i in range(self.num_gpus): |
| self.assertEqual(tensors[i], tensors[rt]) |
| |
| def test_allreduce_ops(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| def allreduce(tensors, op): |
| opts = c10d.AllreduceOptions() |
| opts.reduceOp = op |
| work = pg.allreduce(tensors, opts) |
| work.wait() |
| |
| # Sum |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i + 1]).cuda(i)) |
| |
| allreduce(tensors, c10d.ReduceOp.SUM) |
| |
| for i in range(self.num_gpus): |
| self.assertEqual( |
| torch.Tensor([float(self.num_gpus * (self.num_gpus + 1) / 2)]), |
| tensors[i]) |
| |
| # Product |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i + 1]).cuda(i)) |
| |
| allreduce(tensors, c10d.ReduceOp.PRODUCT) |
| |
| for i in range(self.num_gpus): |
| self.assertEqual( |
| torch.Tensor([float(math.factorial(self.num_gpus))]), |
| tensors[i]) |
| |
| # Min |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i + 1]).cuda(i)) |
| |
| allreduce(tensors, c10d.ReduceOp.MIN) |
| |
| for i in range(self.num_gpus): |
| self.assertEqual(torch.Tensor([1.0]), tensors[i]) |
| |
| # Max |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i + 1]).cuda(i)) |
| |
| allreduce(tensors, c10d.ReduceOp.MAX) |
| |
| for i in range(self.num_gpus): |
| self.assertEqual(torch.Tensor([self.num_gpus]), tensors[i]) |
| |
| def test_reduce_ops(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| def reduce(xs, rootRank, rootTensor): |
| opts = c10d.ReduceOptions() |
| opts.rootRank = rootRank |
| opts.rootTensor = rootTensor |
| work = pg.reduce(xs, opts) |
| work.wait() |
| |
| # for every root tensor |
| for rt in range(self.num_gpus): |
| tensors = [] |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i + 1]).cuda(i)) |
| |
| reduce(tensors, self.rank, rt) |
| |
| self.assertEqual( |
| torch.Tensor([float(self.num_gpus * (self.num_gpus + 1) / 2)]), |
| tensors[rt]) |
| |
| def test_allgather_ops(self): |
| store = c10d.FileStore(self.file.name) |
| pg = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| def allgather(output_ts, input_ts): |
| work = pg.allgather(output_ts, input_ts) |
| work.wait() |
| |
| tensors = [] |
| output_ts = [[] for _ in range(self.num_gpus)] |
| |
| for idx, ls in enumerate(output_ts): |
| for _ in range(self.world_size * self.num_gpus): |
| ls.append(torch.Tensor([0]).cuda(idx)) |
| |
| for i in range(self.num_gpus): |
| tensors.append(torch.Tensor([i]).cuda(i)) |
| |
| allgather(output_ts, tensors) |
| |
| # Verification |
| for device_ts in output_ts: |
| for s_idx, t in enumerate(device_ts): |
| self.assertEqual(torch.Tensor([s_idx]), t) |
| |
| |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| self.fc1 = nn.Linear(2, 10, bias=False) |
| self.fc2 = nn.Linear(10, 50, bias=False) |
| self.fc3 = nn.Linear(50, 4, bias=False) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.relu(self.fc1(x)) |
| x = self.relu(self.fc2(x)) |
| x = self.fc3(x) |
| return F.softmax(x, dim=1) |
| |
| |
| class DistributedDataParallelTest(MultiProcessTestCase): |
| |
| @property |
| def world_size(self): |
| return 2 |
| |
| def _test_ddp_with_process_group(self, process_group, gpus): |
| model = Net() |
| ddp_model = DistributedDataParallel( |
| copy.deepcopy(model).cuda(gpus[0]), |
| device_ids=gpus, |
| process_group=process_group) |
| |
| model.cuda(gpus[0]) |
| |
| local_batch_size = len(gpus) |
| global_batch_size = self.world_size * local_batch_size |
| input = torch.randn(global_batch_size, 2).cuda(gpus[0]) |
| target = torch.randn(global_batch_size, 4).cuda(gpus[0]) |
| |
| def step_model(model, input, target): |
| model.train() |
| output = model(input) |
| loss = F.mse_loss(output, target) |
| loss.backward() |
| |
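| # Apply a plain gradient step (effectively SGD with lr=1) and clear the |
| # gradients manually, so the local model and the DDP model stay directly |
| # comparable. |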
| def update_parameters(model): |
| for param in model.parameters(): |
| param.data -= param.grad |
| param.grad = None |
| |
| # Check that the parameters of the two models stay equal over 2 iterations |
| for iteration in range(2): |
| # single cpu/gpu training |
| step_model(model, input, target) |
| |
| # DDP training; DDP scatters this rank's slice of the input across its GPUs |
| step_model(ddp_model, |
| input[self.rank * local_batch_size: (self.rank + 1) * local_batch_size], |
| target[self.rank * local_batch_size: (self.rank + 1) * local_batch_size]) |
| |
| # Update weights and run a second iteration to shake out errors |
| update_parameters(model) |
| update_parameters(ddp_model) |
| self.assertEqual(len(list(model.parameters())), len(list(ddp_model.parameters()))) |
| for i, j in zip(model.parameters(), ddp_model.parameters()): |
| self.assertEqual(i, j) |
| |
| # Shuffle the input so that DDP input is different |
| torch.manual_seed(1337 + iteration) |
| input = input[torch.randperm(global_batch_size)] |
| |
| @skip_if_not_multigpu |
| def test_gloo_backend(self): |
| store = c10d.FileStore(self.file.name) |
| options = c10d.ProcessGroupGloo.Options() |
| options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] |
| process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) |
| gpus = gpus_for_rank(self.world_size)[self.rank] |
| self._test_ddp_with_process_group(process_group, gpus) |
| self._test_ddp_with_process_group(process_group, [torch.device('cuda:' + str(i)) for i in gpus]) |
| |
| @skip_if_not_multigpu |
| @skip_if_not_nccl |
| def test_nccl_backend(self): |
| store = c10d.FileStore(self.file.name) |
| process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| gpus = gpus_for_rank(self.world_size)[self.rank] |
| self._test_ddp_with_process_group(process_group, gpus) |
| self._test_ddp_with_process_group(process_group, [torch.device('cuda:' + str(i)) for i in gpus]) |
| |
| @skip_if_not_multigpu |
| @skip_if_not_nccl |
| @skip_for_known_issues |
| def test_dist_broadcast_coalesced_nccl(self): |
| store = c10d.FileStore(self.file.name) |
| process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| device = torch.device('cuda') |
| |
| for fine_grained in [False, True]: |
| target = torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float64, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) |
| |
| if self.is_master: |
| # All processes should have these tensors in the end. |
| tensors = target |
| else: |
| # Non-master processes start with zero tensors and should be |
| # filled with the tensors from the master. |
| tensors = torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float32, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float64, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float32, device=device).chunk(5) |
| |
| c10d._dist_broadcast_coalesced( |
| process_group, |
| tensors, |
| buffer_size=256, |
| fine_grained=fine_grained) |
| |
| self.assertEqual(tensors, target) |
| |
| @skip_if_not_multigpu |
| def test_dist_broadcast_coalesced_gloo(self): |
| store = c10d.FileStore(self.file.name) |
| options = c10d.ProcessGroupGloo.Options() |
| options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] |
| process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) |
| |
| device = torch.device('cuda') |
| |
| for fine_grained in [False, True]: |
| target = torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float64, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float16, device=device).chunk(5) |
| target += torch.arange(60, dtype=torch.float32, device=device).chunk(5) |
| |
| if self.is_master: |
| # All processes should have these tensors in the end. |
| tensors = target |
| else: |
| # Non-master processes start with zero tensors and should be |
| # filled with the tensors from the master. |
| tensors = torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float32, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float64, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float16, device=device).chunk(5) |
| tensors += torch.zeros(60, dtype=torch.float32, device=device).chunk(5) |
| |
| c10d._dist_broadcast_coalesced( |
| process_group, |
| tensors, |
| buffer_size=128, |
| fine_grained=fine_grained) |
| |
| self.assertEqual(tensors, target) |
| |
| @skip_if_not_multigpu |
| def test_sync_params_no_buffers(self): |
| store = c10d.FileStore(self.file.name) |
| options = c10d.ProcessGroupGloo.Options() |
| options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] |
| process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) |
| |
| # Use all available devices on every process here (data is small, so should be fine). |
| devices = gpus_for_rank(self.world_size)[self.rank] |
| target = torch.arange(10, dtype=torch.float64, device='cuda:0').chunk(5) |
| parameter_data = [target] |
| parameter_data += [torch.zeros(10, device=torch.device('cuda', d)).chunk(5) for d in devices[1:]] |
| buffer_data = [[]] * len(parameter_data) |
| |
| c10d._sync_params( |
| process_group, |
| parameter_data=parameter_data, |
| buffer_data=buffer_data, |
| devices=devices, |
| broadcast_bucket_size=10, |
| broadcast_buffers=False) |
| |
| for device_data in parameter_data: |
| for i, parameter in enumerate(device_data): |
| self.assertEqual(parameter, target[i]) |
| |
| @skip_if_not_multigpu |
| def test_sync_params_with_buffers(self): |
| store = c10d.FileStore(self.file.name) |
| options = c10d.ProcessGroupGloo.Options() |
| options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] |
| process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) |
| |
| devices = gpus_for_rank(self.world_size)[self.rank] |
| target = torch.arange(10, dtype=torch.float64, device='cuda:0').chunk(5) |
| parameter_data = [target] |
| parameter_data += [torch.zeros(10, device=torch.device('cuda', d)).chunk(5) for d in devices[1:]] |
| |
| # sync_params should do a dist_broadcast for buffers, so we only populate the master buffers and |
| # then check that other processes' tensors end up matching. |
| |
| if self.is_master: |
| buffer_data = [target] |
| buffer_data += [torch.zeros(10, device=torch.device('cuda', d)).chunk(5) for d in devices[1:]] |
| else: |
| buffer_data = [torch.zeros(10, device=torch.device('cuda', d)).chunk(5) for d in devices] |
| |
| c10d._sync_params( |
| process_group, |
| parameter_data=parameter_data, |
| buffer_data=buffer_data, |
| devices=devices, |
| broadcast_bucket_size=10, |
| broadcast_buffers=True) |
| |
| for device_data in parameter_data: |
| for i, parameter in enumerate(device_data): |
| self.assertEqual(parameter, target[i]) |
| |
| for device_data in buffer_data: |
| for i, buffer in enumerate(device_data): |
| self.assertEqual(buffer, target[i]) |
| |
| @skip_if_not_multigpu |
| @skip_if_not_nccl |
| def test_fp16(self): |
| store = c10d.FileStore(self.file.name) |
| process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| gpus = gpus_for_rank(self.world_size)[self.rank] |
| model = nn.Linear(1, 1, bias=False).cuda(gpus[0]).half() |
| nn.init.constant_(model.weight, 1) |
| ddp_model = DistributedDataParallel( |
| model, |
| device_ids=[gpus[0]], |
| process_group=process_group, |
| bucket_cap_mb=1, |
| ) |
| |
| # Input 2**15, so that the gradients will overflow with a |
| # world_size of 2, unless we normalize the gradient by the |
| # world_size before the reduction |
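| # (float16 has a maximum of ~65504, so an unnormalized sum of |
| # 2 * 2**15 == 65536 would already overflow to inf.) |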
| input = torch.Tensor([[2**15]]).cuda(gpus[0]).half() |
| |
| # Step model |
| ddp_model.train() |
| output = ddp_model(input) |
| loss = output.sum() |
| loss.backward() |
| |
| self.assertFalse( |
| any(torch.isinf(p.grad).any() for p in ddp_model.parameters()) |
| ) |
| |
| @skip_if_not_nccl |
| def test_queue_reduction(self): |
| # Set up process group. |
| store = c10d.FileStore(self.file.name) |
| process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| # Get this process' split of devices. |
| devices = gpus_for_rank(self.world_size)[self.rank] |
| grads_batch = [(torch.ones(10, device=torch.device('cuda', d)) * |
| (self.rank + 1)).chunk(5) |
| for d in devices] |
| |
| work, local_grad_sum = c10d._queue_reduction(process_group, |
| grads_batch, |
| devices) |
| # The first return value should be the allreduce work item. |
| self.assertTrue(isinstance(work, c10d.Work)) |
| # The second return value will be the finished allreduced gradients. |
| self.assertTrue(isinstance(local_grad_sum, torch.Tensor)) |
| |
| # Wait for the allreduce to finish. |
| work.wait() |
| |
| # The expected result of the allreduce should be the average |
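| # Each rank contributes (rank + 1), so the sum is |
| # world_size * (world_size + 1) / 2 and the average is (world_size + 1) / 2. |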
| self.assertEqual(local_grad_sum, |
| torch.ones(10) * (self.world_size + 1) / 2.0) |
| |
| @skip_if_not_nccl |
| def test_sync_reduction(self): |
| # Set up process group. |
| store = c10d.FileStore(self.file.name) |
| process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) |
| |
| # Get this process' split of devices. |
| devices = gpus_for_rank(self.world_size)[self.rank] |
| grads_batch = [(torch.ones(10, device=torch.device('cuda', d)) * |
| (self.rank + 1)).chunk(5) |
| for d in devices] |
| work, local_grad_sum = c10d._queue_reduction(process_group, |
| grads_batch, |
| devices) |
| c10d._sync_reduction(work, grads_batch[0], local_grad_sum) |
| # The expected result of the allreduce should be the average |
| self.assertEqual(grads_batch[0], (torch.ones(10) * (self.world_size + 1) / 2.0).chunk(5)) |
| |
| |
| if __name__ == '__main__': |
| assert not torch.cuda._initialized, "test_c10d must not have initialized CUDA context on main process" |
| |
| run_tests() |