Revert D29701479: [pytorch][PR] Remove `_broadcast_object()` from `ZeroRedundancyOptimizer`

Test Plan: revert-hammer

Differential Revision:
D29701479 (https://github.com/pytorch/pytorch/commit/9b5d9b404927d4438d449a36749fbe62273fdde4)

Original commit changeset: c8d5f9057b32

fbshipit-source-id: 35ab1f399513fb9d1c4e73b1fa906e559d2a6994
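
For reference, a minimal sketch of the broadcast pattern this revert restores versus the one it removes (illustration only, not part of the patch; the rank, device, and `state` values below are placeholders, and a process group is assumed to have been initialized via `dist.init_process_group()`):

    import torch
    import torch.distributed as dist
    from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object

    # Placeholder values for illustration; only the source rank holds the object.
    SRC_RANK = 0
    device = torch.device("cpu")
    state = {"step": torch.zeros(1)} if dist.get_rank() == SRC_RANK else None

    # Call pattern removed by this revert: in-place list broadcast with the
    # (now-reverted) `map_location` argument, which loaded received tensors
    # onto `device` on non-source ranks.
    #   obj_list = [state]
    #   dist.broadcast_object_list(obj_list, src=SRC_RANK,
    #                              group=dist.group.WORLD, map_location=device)
    #   state = obj_list[0]

    # Call pattern restored by this revert: the helper torch.save()s the
    # object on the source rank, broadcasts a length tensor followed by the
    # pickled byte buffer, and torch.load()s it onto `device` on receivers.
    state = _broadcast_object(state, src_rank=SRC_RANK,
                              group=dist.group.WORLD, device=device)

The restored helper loads received tensors directly onto `device`, which is the behavior the reverted `map_location` argument of `broadcast_object_list` was intended to replace (see the TODO reinstated in test_zero_redundancy_optimizer.py below).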
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 4e1864a..574314a 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -19,6 +19,7 @@
     sys.exit(0)
 from torch.distributed.algorithms.join import _Join, _Joinable, _JoinHook
 from torch.distributed.optim import ZeroRedundancyOptimizer
+from torch.distributed.optim.zero_redundancy_optimizer import _broadcast_object
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 from torch.testing._internal import common_distributed, common_utils
@@ -457,15 +458,12 @@
         else:
             optimizer_state_dict = {}
 
-        optimizer_state_dict_list = [optimizer_state_dict]
-        with torch.cuda.device(self.device):
-            dist.broadcast_object_list(
-                optimizer_state_dict_list,
-                src=RECIPIENT_RANK,
-                group=dist.group.WORLD,
-                map_location=self.device
-            )
-        optimizer_state_dict = optimizer_state_dict_list[0]
+        optimizer_state_dict = _broadcast_object(
+            optimizer_state_dict,
+            src_rank=RECIPIENT_RANK,
+            group=dist.group.WORLD,
+            device=self.device,
+        )
 
         # Load the optimizer state dict, check that no exception is raised
         optimizer.load_state_dict(optimizer_state_dict)
@@ -709,14 +707,12 @@
         # Broadcast the saved gradients and parameters to all of the other
         # ranks (which joined early)
         grads_and_params = [grads_at_each_iter, params_at_each_iter]
-        dist.broadcast_object_list(
-            grads_and_params,
-            src=world_size - 1,
-            group=dist.group.WORLD,
-            map_location=device
-        )
+        grads_and_params = _broadcast_object(grads_and_params, src_rank=world_size - 1, group=dist.group.WORLD, device=device)
         grads_at_each_iter = grads_and_params[0]
         params_at_each_iter = grads_and_params[1]
+        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
+        # once the latter supports loading to the destination device instead
+        # of the source device
 
         # A process must still set the remaining gradients after joining, so we
         # define a join hook to do this before the ZeRO join hook
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 688e368..ec10ac4 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1494,16 +1494,16 @@
 
 def _object_to_tensor(obj):
     f = io.BytesIO()
-    torch.save(obj, f)
+    _pickler(f).dump(obj)
     byte_storage = torch.ByteStorage.from_buffer(f.getvalue())  # type: ignore[attr-defined]
     byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
     local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
     return byte_tensor, local_size
 
 
-def _tensor_to_object(tensor, tensor_size, map_location=None):
+def _tensor_to_object(tensor, tensor_size):
     buf = tensor.numpy().tobytes()[:tensor_size]
-    return torch.load(io.BytesIO(buf), map_location=map_location)
+    return _unpickler(io.BytesIO(buf)).load()
 
 
 def all_gather_object(object_list, obj, group=None):
@@ -1611,9 +1611,8 @@
             collective and will contain the output. Must be ``None`` on non-dst
             ranks. (default is ``None``)
         dst (int, optional): Destination rank. (default is 0)
-        group: (ProcessGroup, optional): The process group to work on. If
-            ``None``, the default process group will be used. Default is
-            ``None``.
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
 
     Returns:
         None. On the ``dst`` rank, ``object_gather_list`` will contain the
@@ -1701,7 +1700,7 @@
         object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
 
 
-def broadcast_object_list(object_list, src=0, group=None, device=None, map_location=None):
+def broadcast_object_list(object_list, src=0, group=None, device=None):
     """
     Broadcasts picklable objects in ``object_list`` to the whole group. Similar
     to :func:`broadcast`, but Python objects can be passed in.
@@ -1715,13 +1714,9 @@
         src (int): Source rank from which to broadcast ``object_list``.
         group: (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used. Default is ``None``.
-        device (torch.device, optional): If not ``None``, the objects are
+        device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before broadcasting. Default is ``None``.
-        map_location (torch.device, optional): The device to load tensors
-            contained in the received objects; this argument does not affect
-            the source rank. If ``None``, the tensors are loaded to the device
-            they were on when passed into this function. Default is ``None``.
 
     Returns:
         ``None``. If rank is part of the group, ``object_list`` will contain the
@@ -1816,8 +1811,7 @@
             if obj_view.device != torch.device("cpu"):
                 obj_view = obj_view.cpu()
             offset += obj_size
-            # Deserialize contained tensors directly to `map_location`
-            object_list[i] = _tensor_to_object(obj_view, obj_size, map_location)
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
 
 
 def scatter_object_list(
diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
index 398068d..e55dc7e 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -4,8 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import collections
-import contextlib
 import copy
+import io
 import logging
 from itertools import chain
 from typing import Any, Callable, Dict, List, Optional, Type
@@ -55,6 +55,46 @@
     return param.requires_grad
 
 
+def _broadcast_object(
+    obj: Any, src_rank: int,
+    group: object = dist.group.WORLD,
+    device: torch.device = torch.device("cpu")
+) -> Any:
+    r"""
+    Broadcasts an object to the given group, sending the object if called from
+    the source rank and receiving the object otherwise.
+
+    Arguments:
+        obj: object to broadcast; only used if called on the source rank.
+        src_rank (int): source rank.
+        group (``ProcessGroup``, optional): group used for the broadcast
+            (default: ``dist.group.WORLD``).
+        device (``torch.device``, optional): device to send from or receive
+            on (default: ``torch.device("cpu")``).
+
+    Returns:
+        The broadcasted object.
+    """
+    if dist.get_rank() == src_rank:
+        # Send the object
+        buffer = io.BytesIO()
+        torch.save(obj, buffer)
+        data = bytearray(buffer.getbuffer())
+        length_tensor = torch.LongTensor([len(data)]).to(device)
+        data_send_tensor = torch.ByteTensor(data).to(device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
+    else:
+        # Receive the object
+        length_tensor = torch.LongTensor([0]).to(device)
+        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
+        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=device)
+        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
+        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
+        obj = torch.load(buffer, map_location=device)
+    return obj
+
+
 def _get_global_rank(group: Any, rank: int) -> int:
     r"""
     Returns the global rank for the given group and rank.
@@ -259,56 +299,51 @@
         self._sync_param_groups(self.param_groups, self.optim.param_groups)
 
         # Pull the sharded state from all ranks and store them in rank order
+        empty_messenger = torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+
         # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
         # due to compatibility issues with NCCL backend; a possible follow-up
         # is to move all sharded state management to RPC RRef
         self._all_state_dicts = []
-        # Set `map_location` to CPU to save one GPU -> CPU transfer
-        map_location = torch.device("cpu")
-        is_gpu_device = self._default_device.type == "cuda"
-        # TODO: Once the minimum supported version is Python 3.7, replace
-        # `contextlib.suppress()` with `contextlib.nullcontext()`
-        with torch.cuda.device(self._default_device) if is_gpu_device else contextlib.suppress():
-            for rank in range(self.world_size):
-                global_rank = _get_global_rank(self.process_group, rank)
-                if self.rank == to:
-                    # Consolidate all local `state_dict`s on this rank, storing on
-                    # CPU to save GPU memory
-                    if rank == self.rank:
-                        # Directly append own optimizer state
-                        self._all_state_dicts.append(
-                            _recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"))
-                        )
-                    else:
-                        # Receive the optimizer state from the source rank
-                        local_state_dict_list = [None]
-                        dist.broadcast_object_list(
-                            local_state_dict_list,
-                            src=global_rank,
-                            group=self.process_group,
-                            map_location=map_location
-                        )
-                        local_state_dict = local_state_dict_list[0]
-                        self._all_state_dicts.append(
-                            _recursive_copy_to_device(local_state_dict, non_blocking=True, device=torch.device("cpu"))
-                        )
+        for rank in range(self.world_size):
+            global_rank = _get_global_rank(self.process_group, rank)
+            if self.rank == to:
+                # Consolidate all local `state_dict`s on this rank, storing on
+                # CPU to save GPU memory
+                if rank == self.rank:
+                    # Directly append own optimizer state
+                    self._all_state_dicts.append(
+                        _recursive_copy_to_device(self.optim.state_dict(), non_blocking=True, device=torch.device("cpu"),)
+                    )
                 else:
-                    if rank == self.rank:
-                        # Send the optimizer state to the target rank
-                        dist.broadcast_object_list(
-                            [self.optim.state_dict()],
-                            src=self.global_rank,
-                            group=self.process_group
-                        )
-                    elif rank != to:
-                        # Discard the received object; `broadcast()` is used for
-                        # compatibility reasons
-                        dist.broadcast_object_list(
-                            [None],
-                            src=global_rank,
-                            group=self.process_group,
-                            map_location=map_location
-                        )
+                    # Receive the optimizer state from the source rank
+                    local_state_dict = _broadcast_object(
+                        empty_messenger,
+                        src_rank=global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
+                    self._all_state_dicts.append(
+                        _recursive_copy_to_device(local_state_dict, non_blocking=True, device=torch.device("cpu"))
+                    )
+            else:
+                if rank == self.rank:
+                    # Send the optimizer state to the target rank
+                    _ = _broadcast_object(
+                        self.optim.state_dict(),
+                        src_rank=self.global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
+                elif rank != to:
+                    # Discard the received object; `broadcast()` is used for
+                    # compatibility reasons
+                    _ = _broadcast_object(
+                        empty_messenger,
+                        src_rank=global_rank,
+                        group=self.process_group,
+                        device=self._default_device,
+                    )
 
     def _partition_parameters(self) -> List[List[Dict]]:
         r"""
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 3ec6ce3..db95ab7 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -1,4 +1,3 @@
-import collections
 import copy
 import itertools
 import math
@@ -12,13 +11,13 @@
 from contextlib import contextmanager, suppress
 from datetime import timedelta
 from functools import reduce
-from typing import Any, Callable, NamedTuple, Union
+from typing import Union, NamedTuple, Callable, Any
 
 import torch
 import torch.cuda
 import torch.distributed as dist
-import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
+import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
 import torch.distributed.algorithms.model_averaging.averagers as averagers
 import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
 import torch.nn as nn
@@ -26,43 +25,41 @@
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
 from torch.cuda.amp import GradScaler, autocast
-from torch.distributed.algorithms.ddp_comm_hooks import (
-    default_hooks as default,
-)
+from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
 from torch.distributed.algorithms.ddp_comm_hooks import (
     quantization as quantization_hooks,
 )
 from torch.distributed.distributed_c10d import (
+    get_world_size,
+    _get_default_group,
     AllreduceOptions,
     GroupMember,
-    _get_default_group,
-    get_world_size,
 )
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
 from torch.testing._internal.common_distributed import (
-    TEST_SKIPS,
     MultiProcessTestCase,
-    captured_output,
-    cleanup_temp_dir,
+    TEST_SKIPS,
     initialize_temp_directories,
-    nccl_skip_if_lt_x_gpu,
-    require_n_gpus_for_nccl_backend,
-    requires_nccl_version,
+    cleanup_temp_dir,
     simple_sparse_reduce_tests,
-    skip_if_lt_x_gpu,
-    skip_if_no_gpu,
     skip_if_rocm,
     skip_if_small_worldsize,
-    verify_ddp_error_logged,
-    with_dist_debug_levels,
+    skip_if_lt_x_gpu,
+    nccl_skip_if_lt_x_gpu,
+    skip_if_no_gpu,
+    require_n_gpus_for_nccl_backend,
+    requires_nccl_version,
+    captured_output,
     with_nccl_blocking_wait,
+    with_dist_debug_levels,
+    verify_ddp_error_logged,
 )
 from torch.testing._internal.common_utils import (
-    FILE_SCHEMA,
-    IS_FBCODE,
     IS_MACOS,
     IS_WINDOWS,
+    FILE_SCHEMA,
+    IS_FBCODE,
     NO_MULTIPROCESSING_SPAWN,
 )
 from torch.utils.data.distributed import DistributedSampler
@@ -5870,83 +5867,15 @@
             single_obj_list = [objects[0]]
             if self.rank != src_rank:
                 self.assertNotEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
-            dist.broadcast_object_list(single_obj_list, src=src_rank)
+            dist.broadcast_object_list(single_obj_list, src=0)
             self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
 
             # Multiple input objects test
             if self.rank != src_rank:
                 self.assertNotEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
-            dist.broadcast_object_list(objects, src=src_rank)
+            dist.broadcast_object_list(objects, src=0)
             self.assertEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
 
-        @require_backend({"nccl", "gloo"})
-        @require_n_gpus_for_nccl_backend(
-            int(os.environ["WORLD_SIZE"]), os.environ["BACKEND"]
-        )
-        def test_broadcast_object_list_map_location(self):
-            # Test the `map_location` argument
-            backend = os.environ["BACKEND"]
-            if backend == "nccl":
-                torch.cuda.set_device(self.rank)
-
-            def _check_tensor_map_location(obj, device):
-                # Checks that all tensors contained in `obj` are on `device`
-                # Does not account for tensors stored as fields of a class
-                if isinstance(obj, torch.Tensor):
-                    return obj.device == device
-                if isinstance(obj, (list, tuple)):
-                    elems_match = [_check_tensor_map_location(e, device) for e in obj]
-                    return all(elems_match)
-                if isinstance(obj, collections.abc.Mapping):
-                    values_match = [_check_tensor_map_location(v, device) for _, v in obj.items()]
-                    return all(values_match)
-                return False
-
-            def _copy_to_device(obj, device):
-                # Copies `obj` to `device`
-                if isinstance(obj, torch.Tensor):
-                    return obj.to(device)
-                if isinstance(obj, (list, tuple)):
-                    elems = [_copy_to_device(elem, device=device) for elem in obj]
-                    return elems if isinstance(obj, list) else tuple(elems)
-                if isinstance(obj, collections.abc.Mapping):
-                    return {k: _copy_to_device(v, device=device) for k, v in obj.items()}
-                return obj
-
-            src_rank = 0
-            device = torch.device(f"cuda:{src_rank}") if backend == "nccl" else torch.device("cpu")
-            source_objects = [
-                torch.ones(1, device=device),
-                {"key": torch.ones(1, device=device)},
-                [torch.ones(1, device=device)],
-            ]
-            objects = source_objects if self.rank == src_rank else [None for _ in source_objects]
-
-            # `map_location` as CPU test
-            map_location = torch.device("cpu")
-            if self.rank != src_rank:
-                self.assertNotEqual(objects, source_objects)
-            dist.broadcast_object_list(objects, src=src_rank, map_location=map_location)
-            if self.rank != src_rank:
-                self.assertTrue(_check_tensor_map_location(objects, map_location))
-                self.assertEqual(_copy_to_device(source_objects, map_location), objects)
-            else:
-                self.assertEqual(source_objects, objects)
-
-            # `map_location` as GPU test
-            if not torch.cuda.is_available():
-                return
-            map_location = torch.device("cuda:0")
-            objects = source_objects if self.rank == src_rank else [None for _ in source_objects]
-            if self.rank != src_rank:
-                self.assertNotEqual(objects, source_objects)
-            dist.broadcast_object_list(objects, src=src_rank, map_location=map_location)
-            if self.rank != src_rank:
-                self.assertTrue(_check_tensor_map_location(objects, map_location))
-                self.assertEqual(_copy_to_device(source_objects, map_location), objects)
-            else:
-                self.assertEqual(source_objects, objects)
-
         def _test_ddp_ignore_params_arg(self, static_graph=False):
             class TestModel(nn.Module):
                 def __init__(self, rank):