Revert D29701479: [pytorch][PR] Remove `_broadcast_object()` from `ZeroRedundancyOptimizer`

Test Plan: revert-hammer

Differential Revision: D29701479 (https://github.com/pytorch/pytorch/commit/9b5d9b404927d4438d449a36749fbe62273fdde4)

Original commit changeset: c8d5f9057b32

fbshipit-source-id: 35ab1f399513fb9d1c4e73b1fa906e559d2a6994
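
For context, the two broadcast paths this revert swaps can be sketched as follows. This is a minimal usage sketch, not part of the patch; it assumes a gloo process group brought up through the usual env:// rendezvous variables, and the signatures shown are the ones appearing in the diff below.

    import torch
    import torch.distributed as dist
    from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object

    # Assumes MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are set by the launcher.
    dist.init_process_group(backend="gloo")
    obj = {"step": torch.zeros(1)} if dist.get_rank() == 0 else None

    # Helper restored by this revert: receiving ranks deserialize the payload
    # directly onto `device` via torch.load(..., map_location=device).
    obj = _broadcast_object(
        obj, src_rank=0, group=dist.group.WORLD, device=torch.device("cpu")
    )

    # Public collective after the revert: no `map_location` argument, so
    # received tensors land on whatever device they were pickled on.
    object_list = [obj]
    dist.broadcast_object_list(object_list, src=0, group=dist.group.WORLD)
    obj = object_list[0]

    dist.destroy_process_group()
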
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 4e1864a..574314a 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -19,6 +19,7 @@
sys.exit(0)
from torch.distributed.algorithms.join import _Join, _Joinable, _JoinHook
from torch.distributed.optim import ZeroRedundancyOptimizer
+from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from torch.testing._internal import common_distributed, common_utils
@@ -457,15 +458,12 @@
else:
optimizer_state_dict = {}
- optimizer_state_dict_list = [optimizer_state_dict]
- with torch.cuda.device(self.device):
- dist.broadcast_object_list(
- optimizer_state_dict_list,
- src=RECIPIENT_RANK,
- group=dist.group.WORLD,
- map_location=self.device
- )
- optimizer_state_dict = optimizer_state_dict_list[0]
+ optimizer_state_dict = _broadcast_object(
+ optimizer_state_dict,
+ src_rank=RECIPIENT_RANK,
+ group=dist.group.WORLD,
+ device=self.device,
+ )
# Load the optimizer state dict, check that no exception is raised
optimizer.load_state_dict(optimizer_state_dict)
@@ -709,14 +707,12 @@
# Broadcast the saved gradients and parameters to all of the other
# ranks (which joined early)
grads_and_params = [grads_at_each_iter, params_at_each_iter]
- dist.broadcast_object_list(
- grads_and_params,
- src=world_size - 1,
- group=dist.group.WORLD,
- map_location=device
- )
+ grads_and_params = _broadcast_object(grads_and_params, src_rank=world_size - 1, group=dist.group.WORLD, device=device)
grads_at_each_iter = grads_and_params[0]
params_at_each_iter = grads_and_params[1]
+ # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
+ # once the latter supports loading to the destination device instead
+ # of the source device
# A process must still set the remaining gradients after joining, so we
# define a join hook to do this before the ZeRO join hook
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 688e368..ec10ac4 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1494,16 +1494,16 @@
def _object_to_tensor(obj):
f = io.BytesIO()
- torch.save(obj, f)
+ _pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
return byte_tensor, local_size
-def _tensor_to_object(tensor, tensor_size, map_location=None):
+def _tensor_to_object(tensor, tensor_size):
buf = tensor.numpy().tobytes()[:tensor_size]
- return torch.load(io.BytesIO(buf), map_location=map_location)
+ return _unpickler(io.BytesIO(buf)).load()
def all_gather_object(object_list, obj, group=None):
@@ -1611,9 +1611,8 @@
collective and will contain the output. Must be ``None`` on non-dst
ranks. (default is ``None``)
dst (int, optional): Destination rank. (default is 0)
- group: (ProcessGroup, optional): The process group to work on. If
- ``None``, the default process group will be used. Default is
- ``None``.
+ group: (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used. Default is ``None``.
Returns:
None. On the ``dst`` rank, ``object_gather_list`` will contain the
@@ -1701,7 +1700,7 @@
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
-def broadcast_object_list(object_list, src=0, group=None, device=None, map_location=None):
+def broadcast_object_list(object_list, src=0, group=None, device=None):
"""
Broadcasts picklable objects in ``object_list`` to the whole group. Similar
to :func:`broadcast`, but Python objects can be passed in.
@@ -1715,13 +1714,9 @@
src (int): Source rank from which to broadcast ``object_list``.
group: (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.
- device (torch.device, optional): If not ``None``, the objects are
+ device (``torch.device``, optional): If not None, the objects are
serialized and converted to tensors which are moved to the
``device`` before broadcasting. Default is ``None``.
- map_location (torch.device, optional): The device to load tensors
- contained in the received objects; this argument does not affect
- the source rank. If ``None``, the tensors are loaded to the device
- they were on when passed into this function. Default is ``None``.
Returns:
``None``. If rank is part of the group, ``object_list`` will contain the
@@ -1816,8 +1811,7 @@
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
offset += obj_size
- # Deserialize contained tensors directly to `map_location`
- object_list[i] = _tensor_to_object(obj_view, obj_size, map_location)
+ object_list[i] = _tensor_to_object(obj_view, obj_size)
def scatter_object_list(
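
Dropping `map_location` above goes hand in hand with switching `_object_to_tensor`/`_tensor_to_object` back to the raw pickler: remapping storages at load time is a `torch.save`/`torch.load` feature, and the plain pickle path has no equivalent hook. A minimal illustration of the difference (plain `pickle` stands in for the `_pickler`/`_unpickler` classes that c10d defines internally):

    import io
    import pickle
    import torch

    t = torch.ones(2)  # imagine the sender pickled this on a GPU

    # torch.save/torch.load round trip: the loader can remap storages,
    # which is what the removed `map_location` argument exposed.
    buf = io.BytesIO()
    torch.save(t, buf)
    buf.seek(0)
    on_cpu = torch.load(buf, map_location=torch.device("cpu"))

    # Plain pickle round trip (the behavior restored here): the tensor comes
    # back on the device recorded at pickling time, so there is no remapping
    # hook to expose.
    same_device = pickle.loads(pickle.dumps(t))
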
diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
index 398068d..e55dc7e 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -4,8 +4,8 @@
# LICENSE file in the root directory of this source tree.
import collections
-import contextlib
import copy
+import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Type
@@ -55,6 +55,46 @@
return param.requires_grad
+def _broadcast_object(
+ obj: Any, src_rank: int,
+ group: object = dist.group.WORLD,
+ device: torch.device = torch.device("cpu")
+) -> Any:
+ r"""
+ Broadcasts an object to the given group, sending the object if called from
+ the source rank and receiving the object otherwise.
+
+ Arguments:
+ obj: object to broadcast; only used if called on the source rank.
+ src_rank (int): source rank.
+ group (``ProcessGroup``, optional): group used for the broadcast
+ (default: ``dist.group.WORLD``).
+ device (``torch.device``, optional): device to send from or receive
+ to (default: ``torch.device("cpu")``).
+
+ Returns:
+ The broadcasted object.
+ """
+ if dist.get_rank() == src_rank:
+ # Send the object
+ buffer = io.BytesIO()
+ torch.save(obj, buffer)
+ data = bytearray(buffer.getbuffer())
+ length_tensor = torch.LongTensor([len(data)]).to(device)
+ data_send_tensor = torch.ByteTensor(data).to(device)
+ dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+ dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
+ else:
+ # Receive the object
+ length_tensor = torch.LongTensor([0]).to(device)
+ dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+ data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
+ dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
+ buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
+ obj = torch.load(buffer, map_location=device)
+ return obj
+
+
def _get_global_rank(group: Any, rank: int) -> int:
r"""
Returns the global rank for the given group and rank.
@@ -259,56 +299,51 @@
self._sync_param_groups(self.param_groups, self.optim.param_groups)
# Pull the sharded state from all ranks and store them in rank order
+ empty_messenger = torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+
# NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
# due to compatibility issues with NCCL backend; a possible follow-up
# is to move all sharded state management to RPC RRef
self._all_state_dicts = []
- # Set `map_location` to CPU to save one GPU -> CPU transfer
- map_location = torch.device("cpu")
- is_gpu_device = self._default_device.type == "cuda"
- # TODO: Once the minimum supported version is Python 3.7, replace
- # `contextlib.suppress()` with `contextlib.nullcontext()`
- with torch.cuda.device(self._default_device) if is_gpu_device else contextlib.suppress():
- for rank in range(self.world_size):
- global_rank = _get_global_rank(self.process_group, rank)
- if self.rank == to:
- # Consolidate all local `state_dict`s on this rank, storing on
- # CPU to save GPU memory
- if rank == self.rank:
- # Directly append own optimizer state
- self._all_state_dicts.append(
- _recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
- )
- else:
- # Receive the optimizer state from the source rank
- local_state_dict_list = [None]
- dist.broadcast_object_list(
- local_state_dict_list,
- src=global_rank,
- group=self.process_group,
- map_location=map_location
- )
- local_state_dict = local_state_dict_list[0]
- self._all_state_dicts.append(
- _recursive_copy_to_device(local_state_dict, non_blocking=True, device=torch.device("cpu"))
- )
+ for rank in range(self.world_size):
+ global_rank = _get_global_rank(self.process_group, rank)
+ if self.rank == to:
+ # Consolidate all local `state_dict`s on this rank, storing on
+ # CPU to save GPU memory
+ if rank == self.rank:
+ # Directly append own optimizer state
+ self._all_state_dicts.append(
+ _recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"),)
+ )
else:
- if rank == self.rank:
- # Send the optimizer state to the target rank
- dist.broadcast_object_list(
- [self.optim.state_dict()],
- src=self.global_rank,
- group=self.process_group
- )
- elif rank != to:
- # Discard the received object; `broadcast()` is used for
- # compatibility reasons
- dist.broadcast_object_list(
- [None],
- src=global_rank,
- group=self.process_group,
- map_location=map_location
- )
+ # Receive the optimizer state from the source rank
+ local_state_dict = _broadcast_object(
+ empty_messenger,
+ src_rank=global_rank,
+ group=self.process_group,
+ device=self._default_device,
+ )
+ self._all_state_dicts.append(
+ _recursive_copy_to_device(local_state_dict, non_blocking=True, device=torch.device("cpu"))
+ )
+ else:
+ if rank == self.rank:
+ # Send the optimizer state to the target rank
+ _ = _broadcast_object(
+ self.optim.state_dict(),
+ src_rank=self.global_rank,
+ group=self.process_group,
+ device=self._default_device,
+ )
+ elif rank != to:
+ # Discard the received object; `broadcast()` is used for
+ # compatibility reasons
+ _ = _broadcast_object(
+ empty_messenger,
+ src_rank=global_rank,
+ group=self.process_group,
+ device=self._default_device,
+ )
def _partition_parameters(self) -> List[List[Dict]]:
r"""
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 3ec6ce3..db95ab7 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1,4 +1,3 @@
-import collections
import copy
import itertools
import math
@@ -12,13 +11,13 @@
from contextlib import contextmanager, suppress
from datetime import timedelta
from functools import reduce
-from typing import Any, Callable, NamedTuple, Union
+from typing import Union, NamedTuple, Callable, Any
import torch
import torch.cuda
import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
+import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
@@ -26,43 +25,41 @@
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.cuda.amp import GradScaler, autocast
-from torch.distributed.algorithms.ddp_comm_hooks import (
- default_hooks as default,
-)
+from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.distributed.algorithms.ddp_comm_hooks import (
quantization as quantization_hooks,
)
from torch.distributed.distributed_c10d import (
+ get_world_size,
+ _get_default_group,
AllreduceOptions,
GroupMember,
- _get_default_group,
- get_world_size,
)
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
from torch.testing._internal.common_distributed import (
- TEST_SKIPS,
MultiProcessTestCase,
- captured_output,
- cleanup_temp_dir,
+ TEST_SKIPS,
initialize_temp_directories,
- nccl_skip_if_lt_x_gpu,
- require_n_gpus_for_nccl_backend,
- requires_nccl_version,
+ cleanup_temp_dir,
simple_sparse_reduce_tests,
- skip_if_lt_x_gpu,
- skip_if_no_gpu,
skip_if_rocm,
skip_if_small_worldsize,
- verify_ddp_error_logged,
- with_dist_debug_levels,
+ skip_if_lt_x_gpu,
+ nccl_skip_if_lt_x_gpu,
+ skip_if_no_gpu,
+ require_n_gpus_for_nccl_backend,
+ requires_nccl_version,
+ captured_output,
with_nccl_blocking_wait,
+ with_dist_debug_levels,
+ verify_ddp_error_logged,
)
from torch.testing._internal.common_utils import (
- FILE_SCHEMA,
- IS_FBCODE,
IS_MACOS,
IS_WINDOWS,
+ FILE_SCHEMA,
+ IS_FBCODE,
NO_MULTIPROCESSING_SPAWN,
)
from torch.utils.data.distributed import DistributedSampler
@@ -5870,83 +5867,15 @@
single_obj_list = [objects[0]]
if self.rank != src_rank:
self.assertNotEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
- dist.broadcast_object_list(single_obj_list, src=src_rank)
+ dist.broadcast_object_list(single_obj_list, src=0)
self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
# Multiple input objects test
if self.rank != src_rank:
self.assertNotEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
- dist.broadcast_object_list(objects, src=src_rank)
+ dist.broadcast_object_list(objects, src=0)
self.assertEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
- @require_backend({"nccl", "gloo"})
- @require_n_gpus_for_nccl_backend(
- int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
- )
- def test_broadcast_object_list_map_location(self):
- # Test the `map_location` argument
- backend = os.environ["BACKEND"]
- if backend == "nccl":
- torch.cuda.set_device(self.rank)
-
- def _check_tensor_map_location(obj, device):
- # Checks that all tensors contained in `obj` are on `device`
- # Does not account for tensors stored as fields of a class
- if isinstance(obj, torch.Tensor):
- return obj.device == device
- if isinstance(obj, (list, tuple)):
- elems_match = [_check_tensor_map_location(e, device) for e in obj]
- return all(elems_match)
- if isinstance(obj, collections.abc.Mapping):
- values_match = [_check_tensor_map_location(v, device) for _, v in obj.items()]
- return all(values_match)
- return False
-
- def _copy_to_device(obj, device):
- # Copies `obj` to `device`
- if isinstance(obj, torch.Tensor):
- return obj.to(device)
- if isinstance(obj, (list, tuple)):
- elems = [_copy_to_device(elem, device=device) for elem in obj]
- return elems if isinstance(obj, list) else tuple(elems)
- if isinstance(obj, collections.abc.Mapping):
- return {k: _copy_to_device(v, device=device) for k, v in obj.items()}
- return obj
-
- src_rank = 0
- device = torch.device(f"cuda:{src_rank}") if backend == "nccl" else torch.device("cpu")
- source_objects = [
- torch.ones(1, device=device),
- {"key": torch.ones(1, device=device)},
- [torch.ones(1, device=device)],
- ]
- objects = source_objects if self.rank == src_rank else [None for _ in source_objects]
-
- # `map_location` as CPU test
- map_location = torch.device("cpu")
- if self.rank != src_rank:
- self.assertNotEqual(objects, source_objects)
- dist.broadcast_object_list(objects, src=src_rank, map_location=map_location)
- if self.rank != src_rank:
- self.assertTrue(_check_tensor_map_location(objects, map_location))
- self.assertEqual(_copy_to_device(source_objects, map_location), objects)
- else:
- self.assertEqual(source_objects, objects)
-
- # `map_location` as GPU test
- if not torch.cuda.is_available():
- return
- map_location = torch.device("cuda:0")
- objects = source_objects if self.rank == src_rank else [None for _ in source_objects]
- if self.rank != src_rank:
- self.assertNotEqual(objects, source_objects)
- dist.broadcast_object_list(objects, src=src_rank, map_location=map_location)
- if self.rank != src_rank:
- self.assertTrue(_check_tensor_map_location(objects, map_location))
- self.assertEqual(_copy_to_device(source_objects, map_location), objects)
- else:
- self.assertEqual(source_objects, objects)
-
def _test_ddp_ignore_params_arg(self, static_graph=False):
class TestModel(nn.Module):
def __init__(self, rank):
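
With the `map_location` test removed above, callers of `broadcast_object_list` that need the received tensors on a particular device move them themselves after the collective. A minimal sketch follows; it assumes a process group is already initialized, and `move_to` is an illustrative helper mirroring the `_copy_to_device` test utility deleted above, not a torch API.

    import torch
    import torch.distributed as dist

    def move_to(obj, device):
        # Recursively move tensors nested in lists/tuples/dicts to `device`.
        if isinstance(obj, torch.Tensor):
            return obj.to(device)
        if isinstance(obj, (list, tuple)):
            return type(obj)(move_to(e, device) for e in obj)
        if isinstance(obj, dict):
            return {k: move_to(v, device) for k, v in obj.items()}
        return obj

    objects = [{"w": torch.ones(1)}] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(objects, src=0)
    objects = move_to(objects, torch.device("cpu"))
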