[Gradient Compression] Make GradBucket class public (#53099)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53099
Publish the GradBucket APIs so that users can write custom DDP communication hooks against them.
s/_GradBucket/GradBucket
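
For example, after this change a user-defined communication hook annotates its bucket argument with the public name. A minimal sketch, mirroring the allreduce hooks updated in this diff (the `process_group` state and the `ddp_model` registration call are illustrative, not part of this change):

```python
import torch
import torch.distributed as dist

def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
    # Average the bucket's gradient tensors across ranks and return the
    # Future tied to the asynchronous allreduce.
    tensors = [t / process_group.size() for t in bucket.get_tensors()]
    return process_group.allreduce(tensors).get_future()

# Hypothetical registration on a DistributedDataParallel instance:
# ddp_model.register_comm_hook(state=process_group, hook=allreduce_hook)
```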
ghstack-source-id: 123030921
Test Plan: waitforbuildbot
Reviewed By: rohan-varma
Differential Revision: D26721121
fbshipit-source-id: ee5f68e33095b9965b51937b86cdeb331fd2419a
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 7d450ca..e8c48ba 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -360,7 +360,7 @@
client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
self.assertEqual("value".encode(), client_store.get("key"))
client_store.set(f"new_key{index}", f"new_value{index}")
- self.assertEqual(f"next_value{index}".encode(),
+ self.assertEqual(f"next_value{index}".encode(),
client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
except Exception:
messages.put('Caught exception: \n{}exiting process with exit code: {}'
@@ -3057,7 +3057,7 @@
"""
def allreduce_hook(
- process_group: object, bucket: dist._GradBucket
+ process_group: object, bucket: dist.GradBucket
) -> torch._C.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
return process_group.allreduce(tensors).get_future()
@@ -3077,7 +3077,7 @@
"""
def allreduce_with_then_hook(
- process_group: object, bucket: dist._GradBucket
+ process_group: object, bucket: dist.GradBucket
) -> torch.futures.Future:
fut = process_group.allreduce(bucket.get_tensors()).get_future()
@@ -3727,7 +3727,7 @@
[self.assertEqual(p.grad, expected_grad) for p in model.parameters()]
def _simple_hook(
- self, state: object, bucket: dist._GradBucket
+ self, state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
fut = torch.futures.Future()
fut.set_result([torch.ones_like(t) for t in bucket.get_tensors()])
@@ -3782,7 +3782,7 @@
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
- def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
+ def allreduce_hook(state: object, bucket: dist.GradBucket) -> torch._C.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
return process_group.allreduce(tensors).get_future()
@@ -3930,7 +3930,7 @@
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
def allreduce_with_then_hook(
- state: object, bucket: dist._GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
tensors = [t / self.world_size for t in bucket.get_tensors()]
fut = process_group.allreduce(tensors).get_future()
@@ -3972,7 +3972,7 @@
model.register_comm_hook(state=None, hook=1)
with self.assertRaisesRegex(
- ValueError, "bucket annotation should be dist._GradBucket."
+ ValueError, "bucket annotation should be dist.GradBucket."
):
def comm_hook(state: object, bucket: int) -> torch.futures.Future:
@@ -3999,7 +3999,7 @@
"Communication hook: return annotation should be torch.futures.Future or torch._C.Future.",
):
- def comm_hook(state: object, bucket: dist._GradBucket) -> int:
+ def comm_hook(state: object, bucket: dist.GradBucket) -> int:
return torch.futures.Future()
model.register_comm_hook(state=None, hook=comm_hook)
@@ -4009,7 +4009,7 @@
"callback must return a torch.futures.Future or torch._C.Future object, but got",
):
- def comm_hook(state: object, bucket: dist._GradBucket):
+ def comm_hook(state: object, bucket: dist.GradBucket):
return 1
model.register_comm_hook(state=None, hook=comm_hook)
@@ -4067,7 +4067,7 @@
# "get_future" API does not support gloo backend, see GH Issue #42048.
# Instead, we wait for an allreduce work, and write its result to a Future.
def allreduce_hook_gloo(
- state: object, bucket: dist._GradBucket
+ state: object, bucket: dist.GradBucket
) -> torch.futures.Future:
# Prepare allreduced grad bucket tensors by running an async work.
work = process_group.allreduce(bucket.get_tensors())
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 57f3f6a..14516e2 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -15,7 +15,7 @@
def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ...
-class _GradBucket:
+class GradBucket:
def __init__(self, tensors: List[Tensor]): ...
def get_index(self) -> int: ...
def get_tensors(self) -> List[Tensor]: ...
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index ed016f9..504805f 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -183,7 +183,7 @@
py::arg("reducer"),
py::arg("comm_hook_type"));
- shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
+ shared_ptr_class_<::c10d::GradBucket>(module, "GradBucket")
.def(
py::init<
size_t,
@@ -1231,7 +1231,7 @@
``get_future`` API to retrieve a Future associated with the completion of
``allreduce`` work.
- >>> def allreduce(state: object, bucket: dist._GradBucket): -> torch._C.Future
+ >>> def allreduce(state: object, bucket: dist.GradBucket) -> torch._C.Future:
>>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
>>> work = process_group.allreduce(tensors)
>>> return work.get_future()
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index 4288520..3604f06 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -63,8 +63,8 @@
Reducer,
Logger,
BuiltinCommHookType,
+ GradBucket,
_DEFAULT_FIRST_BUCKET_BYTES,
- _GradBucket,
_register_comm_hook,
_register_builtin_comm_hook,
_broadcast_coalesced,
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index 15a3aca..6a93be5 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -17,7 +17,7 @@
def allreduce_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook just calls ``allreduce`` using ``GradBucket``
@@ -35,7 +35,7 @@
def fp16_compress_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
This DDP communication hook implements a simple gradient compression
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 16baadd..ed80b4a 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -172,7 +172,7 @@
def powerSGD_hook(
- state: PowerSGDState, bucket: dist._GradBucket
+ state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future:
r"""
This DDP communication hook implements PowerSGD gradient compression
@@ -217,7 +217,7 @@
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
and ``min_compression_rate``.
- bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+ bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode at this time,
only exactly one tensor is stored in this bucket.
@@ -440,7 +440,7 @@
def batched_powerSGD_hook(
- state: PowerSGDState, bucket: dist._GradBucket
+ state: PowerSGDState, bucket: dist.GradBucket
) -> torch.futures.Future:
r"""
This DDP communication hook implements a simplified PowerSGD gradient compression
@@ -484,7 +484,7 @@
Args:
state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
- bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+ bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
Note that since DDP comm hook only supports single process single device mode at this time,
only exactly one tensor is stored in this bucket.
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
index dda2615..3ebd25c 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -43,7 +43,7 @@
def quantization_pertensor_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future:
"""
Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
@@ -116,7 +116,7 @@
def quantization_perchannel_hook(
- process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
+ process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future:
"""
Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index c240e04..5e912ea 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1065,7 +1065,7 @@
It is locally stored by each worker
and shared by all the gradient tensors on the worker.
hook (callable): Averages gradient tensors across workers and defined as:
- ``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:
+ ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future``:
This function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
@@ -1107,7 +1107,7 @@
Example::
Below is an example of a noop hook that returns the same tensors.
- >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+ >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
>>> fut = torch.futures.Future()
>>> fut.set_result(bucket.get_tensors())
>>> return fut
@@ -1118,7 +1118,7 @@
Below is an example of a Parallel SGD algorithm where gradients are encoded before
allreduce, and then decoded after allreduce.
- >>> def encode_and_decode(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+ >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
>>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
>>> encoded_tensors = encode(tensors) # encode gradients
>>> fut = process_group.allreduce(encoded_tensors).get_future()
@@ -1270,10 +1270,10 @@
sig = inspect.signature(hook)
if (
sig.parameters["bucket"].annotation != inspect._empty
- and sig.parameters["bucket"].annotation != dist._GradBucket
+ and sig.parameters["bucket"].annotation != dist.GradBucket
):
raise ValueError(
- "Communication hook: bucket annotation should be dist._GradBucket."
+ "Communication hook: bucket annotation should be dist.GradBucket."
)
if sig.return_annotation != inspect._empty and (