| import math |
| import multiprocessing |
| import sys |
| import tempfile |
| import unittest |
| from functools import wraps |
| |
| import torch |
| import torch.distributed.c10d as c10d |
| |
| from common import TestCase |
| |
| |
# Address/port used by the TCP-backed store tests.
TCP_ADDR = '127.0.0.1'
TCP_PORT = 29500

# Per-test join timeout (seconds); individual tests may override by name.
TIMEOUT_DEFAULT = 5
TIMEOUT_OVERRIDE = {}


def get_timeout(test_id):
    """Return the join timeout (seconds) for the given unittest id.

    `test_id` looks like 'module.Class.test_name'; the override table is
    keyed by the bare test name.
    """
    test_name = test_id.rsplit('.', 1)[-1]
    return TIMEOUT_OVERRIDE.get(test_name, TIMEOUT_DEFAULT)
| |
| |
# c10d bindings are an optional build component; when they are absent, skip
# this whole module by exiting successfully at import time.
if not c10d.is_available():
    print('c10d not available, skipping tests')
    sys.exit(0)
| |
| |
class StoreTestBase(object):
    """Mixin with store tests shared by the concrete store test cases.

    Subclasses must also inherit from a unittest TestCase (which provides
    `assertEqual`) and implement `_create_store`.
    """

    def _create_store(self):
        # BUGFIX: the stub previously took a required extra argument `i`,
        # while test_set_get calls `self._create_store()` with none, so a
        # subclass missing the override got a TypeError instead of this
        # error. NotImplementedError subclasses RuntimeError, so callers
        # catching RuntimeError still work.
        raise NotImplementedError("not implemented")

    def _test_set_get(self, fs):
        """Write three keys to `fs` and verify they read back as bytes."""
        fs.set("key0", "value0")
        fs.set("key1", "value1")
        fs.set("key2", "value2")
        self.assertEqual(b"value0", fs.get("key0"))
        self.assertEqual(b"value1", fs.get("key1"))
        self.assertEqual(b"value2", fs.get("key2"))

    def test_set_get(self):
        self._test_set_get(self._create_store())
| |
| |
class FileStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a file-backed store."""

    def setUp(self):
        # One fresh backing file per test; closed (and removed) in tearDown.
        self.file = tempfile.NamedTemporaryFile()

    def tearDown(self):
        self.file.close()

    def _create_store(self):
        store = c10d.FileStore(self.file.name)
        return store
| |
| |
class TCPStoreTest(TestCase, StoreTestBase):
    """Runs the shared store tests against a TCP-backed store."""

    def _create_store(self):
        # NOTE(review): the trailing True presumably marks this process as
        # the one hosting the store server — confirm against c10d.TCPStore.
        store = c10d.TCPStore(TCP_ADDR, TCP_PORT, True)
        return store
| |
| |
class RendezvousTest(TestCase):
    """Covers generic rendezvous URL dispatch."""

    def test_unknown_handler(self):
        # An unregistered URL scheme must be rejected up front.
        pattern = "^No rendezvous handler"
        with self.assertRaisesRegex(RuntimeError, pattern):
            c10d.rendezvous('invalid://')
| |
| |
class RendezvousFileTest(TestCase):
    """Exercises the file:// rendezvous handler."""

    def test_common_errors(self):
        # Each malformed URL must be rejected with a descriptive ValueError.
        cases = [
            ('path missing', 'file://?rank=0&size=1'),
            ('rank parameter missing', 'file:///tmp/foo?size=1'),
            ('size parameter missing', 'file:///tmp/foo?rank=0'),
        ]
        for pattern, url in cases:
            with self.assertRaisesRegex(ValueError, pattern):
                gen = c10d.rendezvous(url)
                next(gen)

    def test_nominal(self):
        with tempfile.NamedTemporaryFile() as file:
            url = 'file://%s?size=%d' % (file.name, 2)
            # Keep the generators alive while the stores are in use.
            gen0 = c10d.rendezvous(url + "&rank=0")
            store0, rank0, size0 = next(gen0)
            self.assertEqual(0, rank0)
            self.assertEqual(2, size0)
            gen1 = c10d.rendezvous(url + "&rank=1")
            store1, rank1, size1 = next(gen1)
            self.assertEqual(1, rank1)
            self.assertEqual(2, size1)

            # Each store writes one key, then reads back the key written
            # by the other to confirm they share the same backing file.
            store0.set("key0", "value0")
            store1.set("key1", "value1")
            self.assertEqual(b"value0", store1.get("key0"))
            self.assertEqual(b"value1", store0.get("key1"))
| |
| |
class RendezvousTCPTest(TestCase):
    """Exercises the tcp:// rendezvous handler."""

    def test_common_errors(self):
        # Each malformed URL must be rejected with a descriptive ValueError.
        cases = [
            ('port number missing', 'tcp://127.0.0.1?rank=0&size=1'),
            ('rank parameter missing', 'tcp://127.0.0.1:23456?size=1'),
            ('size parameter missing', 'tcp://127.0.0.1:23456?rank=0'),
        ]
        for pattern, url in cases:
            with self.assertRaisesRegex(ValueError, pattern):
                gen = c10d.rendezvous(url)
                next(gen)

    def test_nominal(self):
        url = 'tcp://127.0.0.1:23456?size=%d' % 2
        # Keep the generators alive while the stores are in use.
        gen0 = c10d.rendezvous(url + "&rank=0")
        store0, rank0, size0 = next(gen0)
        self.assertEqual(0, rank0)
        self.assertEqual(2, size0)
        gen1 = c10d.rendezvous(url + "&rank=1")
        store1, rank1, size1 = next(gen1)
        self.assertEqual(1, rank1)
        self.assertEqual(2, size1)

        # Each store writes one key, then reads back the key written by
        # the other to confirm both talk to the same TCP store.
        store0.set("key0", "value0")
        store1.set("key1", "value1")
        self.assertEqual(b"value0", store1.get("key0"))
        self.assertEqual(b"value1", store0.get("key1"))
| |
| |
class ProcessGroupGlooTest(TestCase):
    """Multi-process tests for the Gloo process group.

    The main process (rank == MAIN_PROCESS_RANK) spawns `size` subprocesses
    that each re-run the current test function with their own rank, then
    joins them and checks their exit codes.
    """

    MAIN_PROCESS_RANK = -1

    @staticmethod
    def join_or_run(fn):
        """Wrap a test fn: the main process joins; subprocesses run it."""
        @wraps(fn)
        def wrapper(self):
            if self.rank == self.MAIN_PROCESS_RANK:
                self._join_processes(fn)
            else:
                fn(self)
        return wrapper

    # The main process spawns N subprocesses that run the test.
    # This hook patches every test function to either assume the role of
    # the main process and join its subprocesses, or run the underlying
    # test function.
    @classmethod
    def setUpClass(cls):
        for attr in dir(cls):
            if attr.startswith('test'):
                fn = getattr(cls, attr)
                setattr(cls, attr, cls.join_or_run(fn))

    def setUp(self):
        self.rank = self.MAIN_PROCESS_RANK
        self.size = 4
        # Backing file shared by all ranks' FileStore instances.
        self.file = tempfile.NamedTemporaryFile()
        self.processes = [self._spawn_process(rank) for rank in range(self.size)]

    def tearDown(self):
        for p in self.processes:
            p.terminate()
        self.file.close()

    def _spawn_process(self, rank):
        name = 'process ' + str(rank)
        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
        process.start()
        return process

    def _run(self, rank):
        self.rank = rank

        # self.id() == e.g. '__main__.ProcessGroupGlooTest.test_broadcast_ops';
        # we retrieve the corresponding test function and execute it.
        # BUGFIX: use the last component instead of a hard-coded index [2],
        # which breaks when the test id has a different nesting depth.
        getattr(self, self.id().split(".")[-1])()
        sys.exit(0)

    def _join_processes(self, fn):
        timeout = get_timeout(self.id())
        for p in self.processes:
            p.join(timeout)
            # BUGFIX: previously nothing checked the outcome of the join, so
            # a subprocess that crashed (non-zero exitcode) or hung past the
            # timeout (exitcode is None) let the test pass silently.
            self.assertEqual(p.exitcode, 0)

    def test_broadcast_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)

        def broadcast(xs, rootRank, rootTensor):
            # Broadcast tensor at index `rootTensor` from rank `rootRank`.
            opts = c10d.BroadcastOptions()
            opts.rootRank = rootRank
            opts.rootTensor = rootTensor
            work = pg.broadcast(xs, opts)
            work.wait()

        # Every rank is root once, every tensor index is root once.
        for i in range(self.size):
            for j in range(2):
                xs = [
                    torch.Tensor([self.rank * self.size + 0.0]),
                    torch.Tensor([self.rank * self.size + 1.0]),
                ]

                broadcast(xs, i, j)
                # Both tensors must now hold the root's value i*size + j.
                self.assertEqual(torch.Tensor([i * self.size + j]), xs[0])
                self.assertEqual(torch.Tensor([i * self.size + j]), xs[1])

        # Test overloaded convenience function.
        x = torch.Tensor([self.rank + 1.0])
        work = pg.broadcast(x, root=0)
        work.wait()
        self.assertEqual(torch.Tensor([1.0]), x)

    def test_allreduce_ops(self):
        store = c10d.FileStore(self.file.name)
        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)

        def allreduce(x, op):
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            work = pg.allreduce([x], opts)
            work.wait()

        # Sum: ranks contribute 1..size, so the result is size*(size+1)/2.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.SUM)
        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)

        # Product: ranks contribute 1..size, so the result is size!.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.PRODUCT)
        self.assertEqual(torch.Tensor([float(math.factorial(self.size))]), x)

        # Min over 1..size is 1.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.MIN)
        self.assertEqual(torch.Tensor([1.0]), x)

        # Max over 1..size is size.
        x = torch.Tensor([self.rank + 1.0])
        allreduce(x, c10d.ReduceOp.MAX)
        self.assertEqual(torch.Tensor([self.size]), x)

        # Test overloaded convenience function (defaults to using sum).
        x = torch.Tensor([self.rank + 1.0])
        work = pg.allreduce(x)
        work.wait()
        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)
| |
| |
# Run the tests under unittest's CLI when this file is executed directly.
if __name__ == '__main__':
    unittest.main()