Reformat (#62073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62073
As the title says: a formatting-only pass over test/distributed/test_c10d_common.py, test_c10d_gloo.py, test_c10d_nccl.py, torch/csrc/distributed/c10d/reducer.cpp, and torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py. No functional changes.
ghstack-source-id: 134159445
Test Plan: N/A
Reviewed By: rohan-varma
Differential Revision: D29869185
fbshipit-source-id: 17a32d56860e9469bd26c4eb4ca2d483827d946e
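Everything below is mechanical reformatting: multi-line signatures and call sites move to a uniform four-space hanging indent with one argument per line and a trailing comma, slices whose bounds are expressions gain spaces around the colon, imports are reordered into sorted position, and stray blank lines after class headers are dropped. For quick reference, here is a minimal, self-contained Python sketch of the style the diff converges on; the names are invented for illustration, and the exact formatter and line length (presumably black at its default 88 columns) are an assumption, since the PR does not say.

# Hypothetical names; a minimal, self-contained sketch of the target style.
def shard_for_rank(values, rank, local_batch_size):
    # Spaces go around the slice colon because both bounds are expressions.
    return values[rank * local_batch_size : (rank + 1) * local_batch_size]


def build_ddp_kwargs(
    process_group,
    devices,
    device_ids,
    global_batch_size,
    gradient_as_bucket_view=False,
):
    # Signatures that no longer fit on one line are exploded to one
    # parameter per line, with a trailing comma and a four-space indent.
    return dict(
        process_group=process_group,
        devices=devices,
        device_ids=device_ids,
        global_batch_size=global_batch_size,
        gradient_as_bucket_view=gradient_as_bucket_view,
    )


if __name__ == "__main__":
    print(shard_for_rank(list(range(8)), rank=1, local_batch_size=4))  # [4, 5, 6, 7]

The diff hunks that follow are reproduced verbatim from the changeset.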
diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 520c009..54f87ec 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -55,7 +55,7 @@
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
- visible_devices[rank * gpus_per_process: (rank + 1) * gpus_per_process]
+ visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
)
return gpus_for_rank
@@ -248,12 +248,12 @@
return 2
def _prepare_single_device_module(
- self,
- process_group,
- devices,
- device_ids,
- global_batch_size,
- gradient_as_bucket_view=False,
+ self,
+ process_group,
+ devices,
+ device_ids,
+ global_batch_size,
+ gradient_as_bucket_view=False,
):
model = Net()
device = devices[0] if devices else torch.device("cuda:%d" % self.rank)
@@ -273,12 +273,12 @@
return model, ddp_model, input, target
def _prepare_multi_device_module(
- self,
- process_group,
- devices,
- device_ids,
- global_batch_size,
- gradient_as_bucket_view=False,
+ self,
+ process_group,
+ devices,
+ device_ids,
+ global_batch_size,
+ gradient_as_bucket_view=False,
):
self.assertTrue(
len(devices) == 2 or len(devices) == 4,
@@ -303,12 +303,12 @@
return model, ddp_model, input, target
def _test_ddp_with_process_group(
- self,
- process_group,
- devices,
- device_ids,
- multi_device=False,
- gradient_as_bucket_view=False,
+ self,
+ process_group,
+ devices,
+ device_ids,
+ multi_device=False,
+ gradient_as_bucket_view=False,
):
"""
Note: we pass down `device_ids` all the way to DistributedDataParallel
@@ -362,10 +362,10 @@
step_model(
ddp_model,
input[
- self.rank * local_batch_size: (self.rank + 1) * local_batch_size
+ self.rank * local_batch_size : (self.rank + 1) * local_batch_size
],
target[
- self.rank * local_batch_size: (self.rank + 1) * local_batch_size
+ self.rank * local_batch_size : (self.rank + 1) * local_batch_size
],
)
@@ -383,7 +383,7 @@
input = input[torch.randperm(global_batch_size)]
def _gpu_model_with_ddp_comm_hook(
- self, process_group, hook=None, gradient_as_bucket_view=False, state=None
+ self, process_group, hook=None, gradient_as_bucket_view=False, state=None
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
@@ -400,7 +400,7 @@
return gpu_model
def _gpu_model_with_builtin_ddp_comm_hook(
- self, process_group, hook=None, gradient_as_bucket_view=False
+ self, process_group, hook=None, gradient_as_bucket_view=False
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
@@ -426,7 +426,7 @@
[self.assertEqual(p.grad, expected_grad) for p in model.parameters()]
def _simple_hook(
- self, state: object, bucket: dist.GradBucket
+ self, state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
fut = torch.futures.Future()
fut.set_result([torch.ones_like(bucket.get_tensor())])
@@ -442,8 +442,9 @@
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
-class DistributedDataParallelTest(AbstractDistributedDataParallelTest, MultiProcessTestCase):
-
+class DistributedDataParallelTest(
+ AbstractDistributedDataParallelTest, MultiProcessTestCase
+):
def setUp(self):
super(DistributedDataParallelTest, self).setUp()
if sys.platform == "win32":
@@ -453,14 +454,14 @@
def test_invalid_powerSGD_state(self):
for start_powerSGD_iter, use_error_feedback, warm_start in product(
- [0, 1], [True, False], [True, False]
+ [0, 1], [True, False], [True, False]
):
if not use_error_feedback and not warm_start:
continue
with self.assertRaisesRegex(
- ValueError,
- "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
- "because PowerSGD can only be applied after the first two iterations in DDP.",
+ ValueError,
+ "Expect `start_powerSGD_iter` > 1 if `use_error_feedback` or `warm_start` is enabled, "
+ "because PowerSGD can only be applied after the first two iterations in DDP.",
):
state = powerSGD.PowerSGDState(
process_group=None,
@@ -518,7 +519,6 @@
class AbstractCommTest(object):
-
@property
def op_timeout_sec(self):
return 1
@@ -651,7 +651,6 @@
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class CommTest(AbstractCommTest, MultiProcessTestCase):
-
def setUp(self):
super(CommTest, self).setUp()
if sys.platform == "win32":
diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py
index 11b91ed..c16f5c6 100644
--- a/test/distributed/test_c10d_gloo.py
+++ b/test/distributed/test_c10d_gloo.py
@@ -17,9 +17,17 @@
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)
+import test_c10d_common
import torch.distributed as dist
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
+from test_c10d_common import (
+ LOOPBACK,
+ gpus_for_rank,
+ Task,
+ ModuleForDdpCommHook,
+ SparseGradientModule,
+)
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import (
@@ -38,14 +46,6 @@
retry_on_connect_failures,
TEST_WITH_TSAN,
)
-import test_c10d_common
-from test_c10d_common import (
- LOOPBACK,
- gpus_for_rank,
- Task,
- ModuleForDdpCommHook,
- SparseGradientModule,
-)
def simple_reduce_tests(rank, world_size):
@@ -202,6 +202,7 @@
def test_default_store_timeout_gloo(self):
self._test_default_store_timeout("gloo")
+
@requires_gloo()
@unittest.skipIf(
TEST_WITH_TSAN,
@@ -213,7 +214,6 @@
dist.barrier(group=pg)
return pg
-
def setUp(self):
super(ProcessGroupGlooTest, self).setUp()
@@ -246,7 +246,9 @@
def test_empty_tensors(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
xs = [torch.FloatTensor([])]
fut = pg.broadcast(xs).get_future()
@@ -257,7 +259,9 @@
def test_broadcast_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
t2 = torch.zeros([1], dtype=torch.float64)
@@ -307,7 +311,9 @@
def _test_broadcast_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
def broadcast(xs, rootRank, rootTensor):
opts = c10d.BroadcastOptions()
@@ -383,7 +389,9 @@
def test_allreduce_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
t2 = torch.zeros([1], dtype=torch.float64)
@@ -403,7 +411,9 @@
def _test_allreduce_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Single input tests
tests = simple_reduce_tests(self.rank, self.world_size)
@@ -436,7 +446,8 @@
fut.wait()
result = fut.value()
self.assertEqual(
- torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), result[0]
+ torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
+ result[0],
)
def test_allreduce_basics(self):
@@ -450,7 +461,9 @@
# This should go away as we deprecate it.
def _test_allreduce_basics_using_work_api(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Single input tests
tests = simple_reduce_tests(self.rank, self.world_size)
@@ -483,7 +496,8 @@
work.wait()
result = work.result()
self.assertEqual(
- torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), result[0]
+ torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
+ result[0],
)
def test_allreduce_basics_using_work_api(self):
@@ -498,7 +512,9 @@
pg = self._create_process_group_gloo(
store, self.rank, self.world_size, self.opts(threads=8)
)
- future_handles = [pg.allreduce(inputs[i]).get_future() for i in range(len(inputs))]
+ future_handles = [
+ pg.allreduce(inputs[i]).get_future() for i in range(len(inputs))
+ ]
for i, future_handle in enumerate(future_handles):
future_handle.wait()
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
@@ -525,7 +541,9 @@
def test_allreduce_coalesced_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros(1, dtype=torch.float32)
t2 = torch.zeros(1, dtype=torch.float64)
@@ -550,7 +568,9 @@
@skip_if_lt_x_gpu(1)
def test_allreduce_coalesced_checks_cuda(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros(1, dtype=torch.float32)
@@ -560,7 +580,9 @@
def _test_allreduce_coalesced_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
test_cases = simple_coalesced_reduce_tests(self.rank, self.world_size)
for op, inputs, outputs in test_cases:
@@ -582,7 +604,9 @@
pg = self._create_process_group_gloo(
store, self.rank, self.world_size, self.opts(threads=8)
)
- future_handles = [pg.allreduce_coalesced(input).get_future() for input in inputs]
+ future_handles = [
+ pg.allreduce_coalesced(input).get_future() for input in inputs
+ ]
for i, future_handle in enumerate(future_handles):
future_handle.wait()
result = future_handle.value()
@@ -607,7 +631,9 @@
def test_sparse_allreduce_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1])
t2 = torch.sparse_coo_tensor([[0]], [1], size=(2,))
@@ -634,7 +660,9 @@
def _test_sparse_allreduce_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
for num_inputs_per_rank in [1, 2]:
tests = simple_sparse_reduce_tests(
@@ -658,7 +686,9 @@
def test_scatter_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
t2 = torch.zeros([1], dtype=torch.float64)
@@ -733,7 +763,9 @@
def _test_scatter_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Preallocate tensors for input/output
input = [fn(torch.tensor([self.rank])) for _ in range(self.world_size)]
@@ -814,7 +846,9 @@
def test_gather_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
t2 = torch.zeros([1], dtype=torch.float64)
@@ -893,7 +927,9 @@
def _test_gather_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Preallocate tensors for input/output
input = [fn(torch.tensor([self.rank]))]
@@ -972,7 +1008,9 @@
def test_allgather_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
t2 = torch.zeros([1], dtype=torch.float64)
@@ -1015,7 +1053,9 @@
def _test_allgather_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Run with N input tensor per rank
for n in [1, 2, 3]:
@@ -1081,7 +1121,9 @@
def test_allgather_coalesced_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
dummy_input = [torch.zeros([1], dtype=torch.float32)]
dummy_output_lists = [
[torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
@@ -1117,7 +1159,9 @@
def test_reduce_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
t1 = torch.zeros([1], dtype=torch.float32)
@@ -1149,7 +1193,9 @@
def _test_reduce_basics(self, fn):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
for (op, input, output) in simple_reduce_tests(self.rank, self.world_size):
for root in range(self.world_size):
opts = c10d.ReduceOptions()
@@ -1216,7 +1262,9 @@
def test_send_recv_all_to_all(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Preallocate tensors for input/output
inputs = [torch.tensor([self.rank]) for _ in range(self.world_size)]
@@ -1254,7 +1302,9 @@
def test_barrier_implies_wait(self):
store = c10d.FileStore(self.file_name, self.world_size)
- pg = self._create_process_group_gloo(store, self.rank, self.world_size, self.opts())
+ pg = self._create_process_group_gloo(
+ store, self.rank, self.world_size, self.opts()
+ )
# Kick off allreduce operations
size = (100, 100)
@@ -1277,7 +1327,10 @@
pg = c10d._round_robin_process_groups(
[
c10d.ProcessGroupGloo(
- c10d.PrefixStore(str(i), store), self.rank, self.world_size, self.opts()
+ c10d.PrefixStore(str(i), store),
+ self.rank,
+ self.world_size,
+ self.opts(),
)
for i in range(num_process_groups)
]
@@ -1300,7 +1353,7 @@
c10d.PrefixStore("%s/%d" % (prefix, i), store),
self.rank,
self.world_size,
- self.opts()
+ self.opts(),
)
for i in range(num)
]
@@ -1321,8 +1374,9 @@
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
-class DistributedDataParallelTest(test_c10d_common.AbstractDistributedDataParallelTest, MultiProcessTestCase):
-
+class DistributedDataParallelTest(
+ test_c10d_common.AbstractDistributedDataParallelTest, MultiProcessTestCase
+):
def setUp(self):
super(DistributedDataParallelTest, self).setUp()
if sys.platform == "win32":
@@ -1331,7 +1385,7 @@
self._fork_processes()
def _test_gloo_backend(
- self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
+ self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
):
store = c10d.FileStore(self.file_name, self.world_size)
options = c10d.ProcessGroupGloo._Options()
@@ -1349,7 +1403,9 @@
@requires_gloo()
def test_gloo_backend_cpu_module_grad_is_view(self):
- self._test_gloo_backend([torch.device("cpu")], None, gradient_as_bucket_view=True)
+ self._test_gloo_backend(
+ [torch.device("cpu")], None, gradient_as_bucket_view=True
+ )
@requires_gloo()
@skip_if_lt_x_gpu(2)
@@ -1379,7 +1435,9 @@
devices = [torch.device("cuda:" + str(i)) for i in int_devices]
self._test_gloo_backend(devices, None, multi_device=True)
- def _test_global_local_unused_params_grad(self, gradient_as_bucket_view=False, static_graph=False):
+ def _test_global_local_unused_params_grad(
+ self, gradient_as_bucket_view=False, static_graph=False
+ ):
"""
By simulating a multi-task training, this test is to make sure:
1) DDP does not touch the grad of globally unused parameters.
@@ -1683,7 +1741,9 @@
criterion = nn.CrossEntropyLoss()
optimizer_withload = torch.optim.SGD(ddp_withload.parameters(), lr=0.001)
- optimizer_non_ddp_withload = torch.optim.SGD(model_withload.parameters(), lr=0.001)
+ optimizer_non_ddp_withload = torch.optim.SGD(
+ model_withload.parameters(), lr=0.001
+ )
optimizer_withoutload = torch.optim.SGD(ddp_withoutload.parameters(), lr=0.001)
input = torch.rand([batch_size, 2], dtype=torch.float).to(device_id)
@@ -1709,7 +1769,9 @@
p.zero_()
ddp_withload.load_state_dict(ddp_state_dict)
# the non-DDP model needs to first remove the prefix of "module." from the DDP state dict
- torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(ddp_state_dict, "module.")
+ torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
+ ddp_state_dict, "module."
+ )
model_withload.load_state_dict(ddp_state_dict)
train_loop(ddp_withload, optimizer_withload, 3)
@@ -1719,7 +1781,9 @@
train_loop(ddp_withoutload, optimizer_withoutload, 6)
for p_withload, p_withoutload, p_non_ddp_withload in zip(
- ddp_withload.parameters(), ddp_withoutload.parameters(), model_withload.parameters()
+ ddp_withload.parameters(),
+ ddp_withoutload.parameters(),
+ model_withload.parameters(),
):
self.assertEqual(p_withload, p_withoutload)
self.assertEqual(p_non_ddp_withload, p_withoutload)
@@ -1770,7 +1834,7 @@
self._run_and_verify_hook(cpu_model, 8, 2 * torch.ones(2, 2))
def _gpu_model_with_ddp_comm_hook(
- self, process_group, hook=None, gradient_as_bucket_view=False, state=None
+ self, process_group, hook=None, gradient_as_bucket_view=False, state=None
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
@@ -1821,8 +1885,9 @@
model.register_comm_hook(state=None, hook=1)
with self.assertRaisesRegex(
- ValueError, "bucket annotation should be dist.GradBucket."
+ ValueError, "bucket annotation should be dist.GradBucket."
):
+
def comm_hook(state: object, bucket: int) -> torch.futures.Future:
return torch.futures.Future()
@@ -1844,9 +1909,10 @@
expected_err = "Communication hook: return annotation should be torch.futures.Future or torch._C.Future."
with self.assertRaisesRegex(
- ValueError,
- expected_err,
+ ValueError,
+ expected_err,
):
+
def comm_hook(state: object, bucket: dist.GradBucket) -> int:
return torch.futures.Future()
@@ -1855,9 +1921,10 @@
verify_ddp_error_logged(model, expected_err)
with self.assertRaisesRegex(
- RuntimeError,
- "callback must return a torch.futures.Future or torch._C.Future object, but got",
+ RuntimeError,
+ "callback must return a torch.futures.Future or torch._C.Future object, but got",
):
+
def comm_hook(state: object, bucket: dist.GradBucket):
return 1
@@ -1890,8 +1957,8 @@
model.register_comm_hook(None, dummy_hook)
with self.assertRaisesRegex(
- RuntimeError,
- "register_comm_hook or register_builtin_comm_hook can only be called once.",
+ RuntimeError,
+ "register_comm_hook or register_builtin_comm_hook can only be called once.",
):
model.register_comm_hook(None, dummy_hook)
@@ -1914,7 +1981,7 @@
)
def allreduce_hook_gloo(
- state: object, bucket: dist.GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
def div_by_world_size(fut):
# Divide the result by 2 * world_size.
@@ -2000,8 +2067,8 @@
num_replicas = 2
models = [self._create_mixed_precision_model() for _ in range(num_replicas)]
with self.assertRaisesRegex(
- RuntimeError,
- "Expected exactly one model replica.",
+ RuntimeError,
+ "Expected exactly one model replica.",
):
reducer = self._create_reducer_for_models(models)
@@ -2067,7 +2134,6 @@
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
-
def setUp(self):
super(CommTest, self).setUp()
if sys.platform == "win32":
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index ea476ad..bcad92d 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -20,14 +20,15 @@
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)
+import test_c10d_common
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn.functional as F
import torch.testing._internal.common_utils as common
+from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
from torch import nn
from torch.nn.parallel import DistributedDataParallel
-from torch.utils.checkpoint import checkpoint
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
@@ -44,8 +45,7 @@
retry_on_connect_failures,
TEST_WITH_TSAN,
)
-import test_c10d_common
-from test_c10d_common import gpus_for_rank, DoubleGpuNet, ConvNet, ModuleForDdpCommHook
+from torch.utils.checkpoint import checkpoint
class RendezvousEnvTest(TestCase):
@@ -158,6 +158,7 @@
raise unittest.SkipTest("No GPUs available, skipping test")
self._test_default_store_timeout("nccl")
+
class ProcessGroupNCCLNoGPUTest(TestCase):
MAIN_PROCESS_RANK = 0
@@ -176,7 +177,7 @@
def test_init_no_gpus(self):
store = c10d.FileStore(self.file.name, self.world_size)
with self.assertRaisesRegex(
- RuntimeError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"
+ RuntimeError, "ProcessGroupNCCL is only supported with GPUs, no GPUs found!"
):
c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@@ -312,7 +313,7 @@
for op in (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR):
with self.assertRaisesRegex(
- RuntimeError, "Cannot use " + str(op) + " with NCCL"
+ RuntimeError, "Cannot use " + str(op) + " with NCCL"
):
allreduce(tensors, op)
@@ -346,7 +347,7 @@
for op in (c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR):
with self.assertRaisesRegex(
- RuntimeError, "Cannot use " + str(op) + " with NCCL"
+ RuntimeError, "Cannot use " + str(op) + " with NCCL"
):
reduce(tensors, self.rank, rt, op)
@@ -407,16 +408,25 @@
device_id = self.rank % self.num_gpus
# anticpate an error
- with self.assertRaisesRegex(RuntimeError, "output tensor size must be equal to world_size times input tensor size"):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "output tensor size must be equal to world_size times input tensor size",
+ ):
tensor = torch.tensor([self.rank]).cuda(device_id)
- output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).cuda(device_id)
+ output_t = torch.empty((self.world_size + 1), dtype=tensor.dtype).cuda(
+ device_id
+ )
# fails the check because output_t is not correctly sized
allgather_base(output_t, tensor)
# anticpate an error
- with self.assertRaisesRegex(RuntimeError, "output tensor must have the same type as input tensor"):
+ with self.assertRaisesRegex(
+ RuntimeError, "output tensor must have the same type as input tensor"
+ ):
tensor = torch.tensor([self.rank], dtype=torch.float).cuda(device_id)
- output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(device_id)
+ output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
+ device_id
+ )
# fails the check because the dtype is different
allgather_base(output_t, tensor)
@@ -431,16 +441,25 @@
device_id = self.rank % self.num_gpus
# anticpate an error
- with self.assertRaisesRegex(RuntimeError, "input tensor must be the same size as output size times world size"):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "input tensor must be the same size as output size times world size",
+ ):
input_t = torch.tensor([self.rank]).cuda(device_id)
- output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).cuda(device_id)
+ output_t = torch.empty((self.world_size + 1), dtype=input_t.dtype).cuda(
+ device_id
+ )
# fails the check because output_t is not correctly sized
reduce_scatter_base(output_t, input_t)
# anticpate an error
- with self.assertRaisesRegex(RuntimeError, "input tensor must be the same type as the outut tensor."):
+ with self.assertRaisesRegex(
+ RuntimeError, "input tensor must be the same type as the outut tensor."
+ ):
tensor = torch.tensor([self.rank], dtype=torch.float).cuda(device_id)
- output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(device_id)
+ output_t = torch.empty((self.world_size + 1), dtype=torch.long).cuda(
+ device_id
+ )
# fails the check because the dtype is different
reduce_scatter_base(output_t, tensor)
@@ -578,8 +597,9 @@
TEST_WITH_TSAN,
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
-class DistributedDataParallelTest(test_c10d_common.AbstractDistributedDataParallelTest, MultiProcessTestCase):
-
+class DistributedDataParallelTest(
+ test_c10d_common.AbstractDistributedDataParallelTest, MultiProcessTestCase
+):
def setUp(self):
super(DistributedDataParallelTest, self).setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -591,7 +611,7 @@
self._fork_processes()
def _test_nccl_backend(
- self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
+ self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
):
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@@ -604,7 +624,9 @@
def test_nccl_backend_multi_device_ids_not_allowed(self):
int_devices = list(range(torch.cuda.device_count()))
devices = [torch.device("cuda:" + str(i)) for i in int_devices]
- with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
+ with self.assertRaisesRegex(
+ ValueError, "device_ids can only be None or contain a single element."
+ ):
self._test_nccl_backend(devices, int_devices)
@requires_nccl()
@@ -669,27 +691,31 @@
model = DoubleGpuNet(gpus)
with self.assertRaisesRegex(
- ValueError,
- "DistributedDataParallel device_ids and output_device arguments only work with "
- "single-device/multiple-device GPU modules or CPU modules",
+ ValueError,
+ "DistributedDataParallel device_ids and output_device arguments only work with "
+ "single-device/multiple-device GPU modules or CPU modules",
):
ddp_model = DistributedDataParallel(
model, output_device=gpus[1], process_group=process_group
)
- with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
+ with self.assertRaisesRegex(
+ ValueError, "device_ids can only be None or contain a single element."
+ ):
ddp_model = DistributedDataParallel(
model, device_ids=gpus, process_group=process_group
)
with self.assertRaisesRegex(
- ValueError, "input module must be on the same type of devices"
+ ValueError, "input module must be on the same type of devices"
):
model.fc1 = model.fc1.cpu()
ddp_model = DistributedDataParallel(model, process_group=process_group)
model = model.cpu()
- with self.assertRaisesRegex(ValueError, "device_ids can only be None or contain a single element."):
+ with self.assertRaisesRegex(
+ ValueError, "device_ids can only be None or contain a single element."
+ ):
ddp_model = DistributedDataParallel(
model, device_ids=gpus, process_group=process_group
)
@@ -839,7 +865,7 @@
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
with self.assertRaisesRegex(
- RuntimeError, "Modules with uninitialized parameters"
+ RuntimeError, "Modules with uninitialized parameters"
):
DistributedDataParallel(
torch.nn.LazyLinear(10), process_group=process_group
@@ -855,7 +881,7 @@
backend="nccl",
world_size=self.world_size,
rank=self.rank,
- init_method=f"file://{self.file_name}"
+ init_method=f"file://{self.file_name}",
)
process_group = c10d.distributed_c10d._get_default_group()
@@ -886,7 +912,7 @@
ddp_model = None
def test_find_unused_parameters(
- find_unused_parameters, test_default=False, gradient_as_bucket_view=False
+ find_unused_parameters, test_default=False, gradient_as_bucket_view=False
):
if test_default:
model = DistributedDataParallel(
@@ -929,9 +955,7 @@
model = ddp_model.module
for module_name, module in model.named_modules():
if module == model.fc3:
- for parameter_name, _ in module.named_parameters(
- recurse=False
- ):
+ for parameter_name, _ in module.named_parameters(recurse=False):
unused_fqn = f"{module_name}.{parameter_name}"
# Only one such parameter in model.fc3, since bias=False
break
@@ -1144,16 +1168,16 @@
# Skip gradients sync without calling prepare_for_backward
step_model(
ddp_model.module,
- input[self.rank: (self.rank + 1)],
- target[self.rank: (self.rank + 1)],
+ input[self.rank : (self.rank + 1)],
+ target[self.rank : (self.rank + 1)],
)
for i, j in zip(model.parameters(), ddp_model.parameters()):
self.assertNotEqual(i.grad, j.grad)
else:
step_model(
ddp_model,
- input[self.rank: (self.rank + 1)],
- target[self.rank: (self.rank + 1)],
+ input[self.rank : (self.rank + 1)],
+ target[self.rank : (self.rank + 1)],
)
for i, j in zip(model.parameters(), ddp_model.parameters()):
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
@@ -1299,10 +1323,10 @@
dist._DEFAULT_FIRST_BUCKET_BYTES = old_DEFAULT_FIRST_BUCKET_BYTES
with torch.backends.cudnn.flags(
- enabled=True, deterministic=True, benchmark=False
+ enabled=True, deterministic=True, benchmark=False
):
for formats, dtypes, bucketsize in product(
- layer_formats, layer_dtypes, bucketsizes
+ layer_formats, layer_dtypes, bucketsizes
):
with first_bucket_size(bucketsize):
model_msg = (
@@ -1341,7 +1365,7 @@
target[local_batch_start:local_batch_end],
).backward()
for i, ((layer_name, m_child), m_ddp_child) in enumerate(
- zip(m.named_children(), m_ddp.module.children())
+ zip(m.named_children(), m_ddp.module.children())
):
named_msg = layer_name + ".weight" + " " + iter_msg
self.assertTrue(
@@ -1357,10 +1381,10 @@
named_msg,
)
for j, ((param_name, p), p_ddp) in enumerate(
- zip(
- m_child.named_parameters(),
- m_ddp_child.parameters(),
- )
+ zip(
+ m_child.named_parameters(),
+ m_ddp_child.parameters(),
+ )
):
named_msg = (
layer_name + "." + param_name + " " + iter_msg
@@ -1433,15 +1457,20 @@
)
else:
with self.assertRaisesRegex(
- RuntimeError,
- ".* appears not to match strides of the same param in process 0",
+ RuntimeError,
+ ".* appears not to match strides of the same param in process 0",
):
m_ddp = DistributedDataParallel(
m, device_ids=[dev0], process_group=process_group
)
def _gpu_model_with_ddp_comm_hook(
- self, process_group, hook=None, gradient_as_bucket_view=False, state=None, static_graph=False
+ self,
+ process_group,
+ hook=None,
+ gradient_as_bucket_view=False,
+ state=None,
+ static_graph=False,
):
device_id = gpus_for_rank(self.world_size)[self.rank][0]
gpu_model = DistributedDataParallel(
@@ -1477,7 +1506,9 @@
# without the comm_hook, result would be 0.25 * torch.ones(2, 2).
self._run_and_verify_hook(gpu_model, 8, 2 * torch.ones(2, 2))
- def _test_ddp_comm_hook_allreduce_hook_nccl(self, gradient_as_bucket_view=False, static_graph=False):
+ def _test_ddp_comm_hook_allreduce_hook_nccl(
+ self, gradient_as_bucket_view=False, static_graph=False
+ ):
"""
This unit test verifies whether a DDP communication hook that just calls
allreduce gives the same result with the case of no hook registered.
@@ -1529,14 +1560,17 @@
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
powerSGD_state = powerSGD.PowerSGDState(process_group=process_group)
- hook_args = [(powerSGD.powerSGD_hook, powerSGD_state), (default.allreduce_hook, process_group)]
+ hook_args = [
+ (powerSGD.powerSGD_hook, powerSGD_state),
+ (default.allreduce_hook, process_group),
+ ]
for hook, state in hook_args:
gpu_model = self._gpu_model_with_ddp_comm_hook(
process_group,
default.fp16_compress_wrapper(hook),
gradient_as_bucket_view,
- state
+ state,
)
# check whether the grads are equal to what DDP without hook would return.
@@ -1653,7 +1687,7 @@
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
def allreduce_with_then_hook(
- state: object, bucket: dist.GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
tensors = [bucket.get_tensor() / self.world_size]
fut = process_group.allreduce(tensors).get_future()
@@ -1699,14 +1733,17 @@
p = torch.nn.Parameter(torch.randn(size, requires_grad=True))
for try_set_to_none, use_bucket_view in product((False, True), (False, True)):
- m = torch.nn.Sequential(self.AcceptsParam(p, dev + 1),
- self.AcceptsParam(p, dev + 1)).cuda(dev)
+ m = torch.nn.Sequential(
+ self.AcceptsParam(p, dev + 1), self.AcceptsParam(p, dev + 1)
+ ).cuda(dev)
- m = torch.nn.parallel.DistributedDataParallel(m,
- bucket_cap_mb=1,
- gradient_as_bucket_view=use_bucket_view,
- device_ids=[dev],
- process_group=process_group)
+ m = torch.nn.parallel.DistributedDataParallel(
+ m,
+ bucket_cap_mb=1,
+ gradient_as_bucket_view=use_bucket_view,
+ device_ids=[dev],
+ process_group=process_group,
+ )
for i in range(3):
m.zero_grad(set_to_none=try_set_to_none)
@@ -1715,11 +1752,20 @@
# Each param value is multiplied by "rank + 1" twice in forward, so the grad
# values produced by a particular rank should be 2. * (rank + 1).
# Summing these over ranks and dividing by world size gives the expected result:
- analytic = torch.full_like(p, 2. * (world * (world + 1.) / 2.) / world, device=dev)
+ analytic = torch.full_like(
+ p, 2.0 * (world * (world + 1.0) / 2.0) / world, device=dev
+ )
for name, p in m.named_parameters():
- self.assertEqual(p.grad, analytic, "mismatch at " + name + ".grad for " +
- "set_to_none = {}, use_bucket_view = {}".format(try_set_to_none,
- use_bucket_view))
+ self.assertEqual(
+ p.grad,
+ analytic,
+ "mismatch at "
+ + name
+ + ".grad for "
+ + "set_to_none = {}, use_bucket_view = {}".format(
+ try_set_to_none, use_bucket_view
+ ),
+ )
# A list of tests for ddp with activation checkpointing
# when gradient_as_bucket_view=True, False.
@@ -1752,8 +1798,8 @@
input = torch.rand((bs, 20), device="cuda", requires_grad=True)
target = torch.randn((bs, 20), device="cuda")
offset = self.rank * ddp_bs
- ddp_input = input[offset: offset + ddp_bs]
- ddp_target = target[offset: offset + ddp_bs]
+ ddp_input = input[offset : offset + ddp_bs]
+ ddp_target = target[offset : offset + ddp_bs]
return input, ddp_input, target, ddp_target
def _train_model(self, model, input_var, target, loss, run_checkpoint=False):
@@ -1772,7 +1818,7 @@
use_bucket_view,
find_unused_parameters=False,
static_graph=False,
- run_checkpoint=False
+ run_checkpoint=False,
):
# to reprodce the same training results
torch.cuda.set_device(self.rank)
@@ -1785,18 +1831,22 @@
gradient_as_bucket_view=use_bucket_view,
device_ids=[self.rank],
process_group=process_group,
- find_unused_parameters=find_unused_parameters
+ find_unused_parameters=find_unused_parameters,
)
if static_graph:
ddp_model._set_static_graph()
- self.assertEqual(ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph)
+ self.assertEqual(
+ ddp_model._get_ddp_logging_data().get("static_graph", 0), static_graph
+ )
input, ddp_input, target, ddp_target = self._prepare_dummy_data()
loss = nn.MSELoss()
for i in range(5):
model.zero_grad(set_to_none=False)
ddp_model.zero_grad(set_to_none=False)
self._train_model(model, input, target, loss, run_checkpoint=run_checkpoint)
- self._train_model(ddp_model, ddp_input, ddp_target, loss, run_checkpoint=run_checkpoint)
+ self._train_model(
+ ddp_model, ddp_input, ddp_target, loss, run_checkpoint=run_checkpoint
+ )
for i, j in zip(model.parameters(), ddp_model.parameters()):
self.assertTrue(i.grad is not None)
self.assertTrue(j.grad is not None)
@@ -1809,10 +1859,12 @@
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
for use_bucket_view, static_graph in product((False, True), (False, True)):
- self._test_ddp_checkpointing(self.CheckpointOnceModule(),
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- static_graph=static_graph)
+ self._test_ddp_checkpointing(
+ self.CheckpointOnceModule(),
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ static_graph=static_graph,
+ )
# DDP will fail when there are unused_parameters in the model
@requires_nccl()
@@ -1825,17 +1877,21 @@
RuntimeError,
"Expected to mark a variable ready only once.",
):
- model = self._test_ddp_checkpointing(self.CheckpointOnceModule(),
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- find_unused_parameters=True,
- static_graph=False)
+ model = self._test_ddp_checkpointing(
+ self.CheckpointOnceModule(),
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ find_unused_parameters=True,
+ static_graph=False,
+ )
# test passes when static_graph is true
- model = self._test_ddp_checkpointing(self.CheckpointOnceModule(),
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- find_unused_parameters=True,
- static_graph=True)
+ model = self._test_ddp_checkpointing(
+ self.CheckpointOnceModule(),
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ find_unused_parameters=True,
+ static_graph=True,
+ )
# DDP will fail when the same layer is checkponted twice
@requires_nccl()
@@ -1848,14 +1904,18 @@
RuntimeError,
"Expected to mark a variable ready only once.",
):
- model = self._test_ddp_checkpointing(self.CheckpointTwiceModule(),
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- static_graph=False)
- model = self._test_ddp_checkpointing(self.CheckpointTwiceModule(),
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- static_graph=True)
+ model = self._test_ddp_checkpointing(
+ self.CheckpointTwiceModule(),
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ static_graph=False,
+ )
+ model = self._test_ddp_checkpointing(
+ self.CheckpointTwiceModule(),
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ static_graph=True,
+ )
# DDP works as expected if there is weight sharing among layers
@requires_nccl()
@@ -1870,11 +1930,14 @@
l2 = nn.Linear(20, 20)
l1.weight = l2.weight
model = nn.Sequential(l1, l2)
- self._test_ddp_checkpointing(model,
- process_group=process_group,
- use_bucket_view=use_bucket_view,
- static_graph=static_graph,
- run_checkpoint=True)
+ self._test_ddp_checkpointing(
+ model,
+ process_group=process_group,
+ use_bucket_view=use_bucket_view,
+ static_graph=static_graph,
+ run_checkpoint=True,
+ )
+
@unittest.skipIf(
TEST_WITH_TSAN,
@@ -1926,7 +1989,9 @@
# Note: we unset and restore NCCL_ASYNC_ERROR_HANDLING for this test
# since test_c10d_common runs with async error handling by default, but this
# tests behavior when it is not enabled.
- prev_nccl_async_error_handling = os.environ.get("NCCL_ASYNC_ERROR_HANDLING", None)
+ prev_nccl_async_error_handling = os.environ.get(
+ "NCCL_ASYNC_ERROR_HANDLING", None
+ )
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@@ -1999,7 +2064,9 @@
@requires_nccl_version(2400, "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm
- @unittest.skip("Frequently times out see https://github.com/pytorch/pytorch/issues/58920")
+ @unittest.skip(
+ "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
+ )
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda: os.abort())
@@ -2096,7 +2163,6 @@
"TSAN is not fork-safe since we're forking in a multi-threaded environment",
)
class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
-
def setUp(self):
super(CommTest, self).setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
@@ -2192,7 +2258,7 @@
world_size=self.world_size,
rank=self.rank,
store=store,
- pg_options=pg_opts
+ pg_options=pg_opts,
)
# Test with new_group
@@ -2244,7 +2310,7 @@
store = c10d.FileStore(self.file_name, self.world_size)
if self.rank == 0:
with self.assertRaisesRegex(
- RuntimeError, "Timed out initializing process group"
+ RuntimeError, "Timed out initializing process group"
):
c10d.init_process_group(
backend="nccl",
@@ -2268,12 +2334,12 @@
if self.rank == 0:
with self.assertRaisesRegex(
- RuntimeError, "Timed out initializing process group"
+ RuntimeError, "Timed out initializing process group"
):
c10d.new_group([0, 1], timeout=timedelta(seconds=1))
with self.assertRaisesRegex(
- RuntimeError, "Timed out initializing process group"
+ RuntimeError, "Timed out initializing process group"
):
c10d.new_group([0], timeout=timedelta(seconds=1))
@@ -2291,12 +2357,12 @@
if self.rank == 1:
with self.assertRaisesRegex(
- RuntimeError, "Timed out initializing process group"
+ RuntimeError, "Timed out initializing process group"
):
c10d.new_group([0, 1], timeout=timedelta(seconds=1))
with self.assertRaisesRegex(
- RuntimeError, "Timed out initializing process group"
+ RuntimeError, "Timed out initializing process group"
):
c10d.new_group([0], timeout=timedelta(seconds=1))
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 676ade2..eb0d914 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -38,7 +38,12 @@
} // namespace
-C10_DEFINE_TYPED_REGISTRY(TimerRegistry, c10::DeviceType, Timer, std::unique_ptr, c10::Device);
+C10_DEFINE_TYPED_REGISTRY( // NOLINT
+ TimerRegistry,
+ c10::DeviceType,
+ Timer,
+ std::unique_ptr,
+ c10::Device);
namespace {
@@ -391,8 +396,10 @@
// previous iterations, no copy is needed.
if (!grad.is_alias_of(bucket_view)) {
if (comm_hook_ == nullptr) {
- auto wrapped = at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
- // Divides while copying into the bucket view to save one scan over all the input parameters.
+ auto wrapped =
+ at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
+ // Divides while copying into the bucket view to save one scan over
+ // all the input parameters.
at::mul_out(bucket_view, grad, wrapped);
} else {
bucket_view.copy_(grad);
@@ -440,7 +447,8 @@
// Directly assign the sparse tensor to the `contents` field.
replica.contents = grad;
// If no DDP comm hook is registered,
- // the allreduce only sums up the value, and a separate division is required.
+ // the allreduce only sums up the value, and a separate division is
+ // required.
if (comm_hook_ == nullptr) {
replica.contents.div_(div_factor_);
}
@@ -458,14 +466,13 @@
auto& bucket = buckets_[i];
auto variables_for_bucket = get_variables_for_bucket(i, bucket);
gradBuckets.emplace_back(
- i,
- return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents)
+ i,
+ return_zero_tensors ? at::zeros_like(bucket.replicas[0].contents)
: bucket.replicas[0].contents,
- bucket.replicas[0].offsets,
- bucket.replicas[0].lengths,
- bucket.replicas[0].sizes_vec,
- variables_for_bucket
- );
+ bucket.replicas[0].offsets,
+ bucket.replicas[0].lengths,
+ bucket.replicas[0].sizes_vec,
+ variables_for_bucket);
}
return gradBuckets;
}
@@ -874,7 +881,7 @@
if (has_rebuilt_bucket_ &&
cached_variables_for_bucket_.find(bucket_index) !=
cached_variables_for_bucket_.end()) {
- return cached_variables_for_bucket_[bucket_index];
+ return cached_variables_for_bucket_[bucket_index];
}
std::vector<at::Tensor> variables_for_bucket;
variables_for_bucket.reserve(bucket.variable_indices.size());
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index d0f583e..fd10c57 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -15,6 +15,7 @@
return dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
+
def allreduce_hook(
process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future: