[Gradient Compression] Make GradBucket class public (#53099)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53099

Publish the GradBucket APIs used for writing DDP communication hooks.

s/_GradBucket/GradBucket
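
As a quick illustration of the renamed type, a DDP communication hook can now be annotated
with the public dist.GradBucket name. The sketch below is modeled on the test hooks updated
in this diff; the process group and the ddp_model it is registered on are assumed to be set
up elsewhere and are only shown in comments.

    import torch
    import torch.distributed as dist

    # Sketch only: average the bucketed gradients across workers and hand DDP
    # the future returned by the async allreduce.
    def allreduce_hook(
        process_group: dist.ProcessGroup, bucket: dist.GradBucket
    ) -> torch.futures.Future:
        tensors = [t / process_group.size() for t in bucket.get_tensors()]
        return process_group.allreduce(tensors).get_future()

    # Hypothetical registration on an existing DistributedDataParallel model:
    # ddp_model.register_comm_hook(state=process_group, hook=allreduce_hook)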
ghstack-source-id: 123030921

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D26721121

fbshipit-source-id: ee5f68e33095b9965b51937b86cdeb331fd2419a
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 7d450ca..e8c48ba 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -360,7 +360,7 @@
             client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
             self.assertEqual("value".encode(), client_store.get("key"))
             client_store.set(f"new_key{index}", f"new_value{index}")
-            self.assertEqual(f"next_value{index}".encode(), 
+            self.assertEqual(f"next_value{index}".encode(),
                              client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
         except Exception:
             messages.put('Caught exception: \n{}exiting process with exit code: {}'
@@ -3057,7 +3057,7 @@
         """
 
         def allreduce_hook(
-            process_group: object, bucket: dist._GradBucket
+            process_group: object, bucket: dist.GradBucket
         ) -> torch._C.Future:
             tensors = [t / self.world_size for t in bucket.get_tensors()]
             return process_group.allreduce(tensors).get_future()
@@ -3077,7 +3077,7 @@
         """
 
         def allreduce_with_then_hook(
-            process_group: object, bucket: dist._GradBucket
+            process_group: object, bucket: dist.GradBucket
         ) -> torch.futures.Future:
             fut = process_group.allreduce(bucket.get_tensors()).get_future()
 
@@ -3727,7 +3727,7 @@
         [self.assertEqual(p.grad, expected_grad) for p in model.parameters()]
 
     def _simple_hook(
-        self, state: object, bucket: dist._GradBucket
+        self, state: object, bucket: dist.GradBucket
     ) -> torch.futures.Future:
         fut = torch.futures.Future()
         fut.set_result([torch.ones_like(t) for t in bucket.get_tensors()])
@@ -3782,7 +3782,7 @@
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
 
-        def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
+        def allreduce_hook(state: object, bucket: dist.GradBucket) -> torch._C.Future:
             tensors = [t / self.world_size for t in bucket.get_tensors()]
             return process_group.allreduce(tensors).get_future()
 
@@ -3930,7 +3930,7 @@
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
 
         def allreduce_with_then_hook(
-            state: object, bucket: dist._GradBucket
+            state: object, bucket: dist.GradBucket
         ) -> torch.futures.Future:
             tensors = [t / self.world_size for t in bucket.get_tensors()]
             fut = process_group.allreduce(tensors).get_future()
@@ -3972,7 +3972,7 @@
             model.register_comm_hook(state=None, hook=1)
 
         with self.assertRaisesRegex(
-            ValueError, "bucket annotation should be dist._GradBucket."
+            ValueError, "bucket annotation should be dist.GradBucket."
         ):
 
             def comm_hook(state: object, bucket: int) -> torch.futures.Future:
@@ -3999,7 +3999,7 @@
             "Communication hook: return annotation should be torch.futures.Future or torch._C.Future.",
         ):
 
-            def comm_hook(state: object, bucket: dist._GradBucket) -> int:
+            def comm_hook(state: object, bucket: dist.GradBucket) -> int:
                 return torch.futures.Future()
 
             model.register_comm_hook(state=None, hook=comm_hook)
@@ -4009,7 +4009,7 @@
             "callback must return a torch.futures.Future or torch._C.Future object, but got",
         ):
 
-            def comm_hook(state: object, bucket: dist._GradBucket):
+            def comm_hook(state: object, bucket: dist.GradBucket):
                 return 1
 
             model.register_comm_hook(state=None, hook=comm_hook)
@@ -4067,7 +4067,7 @@
         # "get_future" API does not support gloo backend, see GH Issue #42048.
         # Instead, we wait for an allreduce work, and write its result to a Future.
         def allreduce_hook_gloo(
-            state: object, bucket: dist._GradBucket
+            state: object, bucket: dist.GradBucket
         ) -> torch.futures.Future:
             # Prepare allreduced grad bucket tensors by running an async work.
             work = process_group.allreduce(bucket.get_tensors())
diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 57f3f6a..14516e2 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -15,7 +15,7 @@
 def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
 def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ...
 
-class _GradBucket:
+class GradBucket:
     def __init__(self, tensors: List[Tensor]): ...
     def get_index(self) -> int: ...
     def get_tensors(self) -> List[Tensor]: ...
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index ed016f9..504805f 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -183,7 +183,7 @@
           py::arg("reducer"),
           py::arg("comm_hook_type"));
 
-  shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket")
+  shared_ptr_class_<::c10d::GradBucket>(module, "GradBucket")
       .def(
           py::init<
               size_t,
@@ -1231,7 +1231,7 @@
                 ``get_future` API to retrieve a Future associated with the completion of
                 ``allreduce`` work.
 
-                >>> def allreduce(state: object, bucket: dist._GradBucket): -> torch._C.Future
+                >>> def allreduce(state: object, bucket: dist.GradBucket) -> torch._C.Future:
                 >>>     tensors = [t / process_group.world_size for t in bucket.get_tensors()]
                 >>>     work = process_group.allreduce(tensors)
                 >>>     return work.get_future()
diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py
index 4288520..3604f06 100644
--- a/torch/distributed/__init__.py
+++ b/torch/distributed/__init__.py
@@ -63,8 +63,8 @@
         Reducer,
         Logger,
         BuiltinCommHookType,
+        GradBucket,
         _DEFAULT_FIRST_BUCKET_BYTES,
-        _GradBucket,
         _register_comm_hook,
         _register_builtin_comm_hook,
         _broadcast_coalesced,
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index 15a3aca..6a93be5 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -17,7 +17,7 @@
 
 
 def allreduce_hook(
-    process_group: dist.ProcessGroup, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook just calls ``allreduce`` using ``GradBucket``
@@ -35,7 +35,7 @@
 
 
 def fp16_compress_hook(
-    process_group: dist.ProcessGroup, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
 ) -> torch.futures.Future:
     """
     This DDP communication hook implements a simple gradient compression
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index 16baadd..ed80b4a 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -172,7 +172,7 @@
 
 
 def powerSGD_hook(
-    state: PowerSGDState, bucket: dist._GradBucket
+    state: PowerSGDState, bucket: dist.GradBucket
 ) -> torch.futures.Future:
     r"""
     This DDP communication hook implements PowerSGD gradient compression
@@ -217,7 +217,7 @@
         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
             To tune the compression configs, mainly need to tune ``matrix_approximation_rank``, ``start_powerSGD_iter``
             and ``min_compression_rate``.
-        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
             Note that since DDP comm hook only supports single process single device mode at this time,
             only exactly one tensor is stored in this bucket.
 
@@ -440,7 +440,7 @@
 
 
 def batched_powerSGD_hook(
-    state: PowerSGDState, bucket: dist._GradBucket
+    state: PowerSGDState, bucket: dist.GradBucket
 ) -> torch.futures.Future:
     r"""
     This DDP communication hook implements a simplified PowerSGD gradient compression
@@ -484,7 +484,7 @@
     Args:
         state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
             To tune the compression configs, mainly need to tune ``matrix_approximation_rank`` and ``start_powerSGD_iter``.
-        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
+        bucket (dist.GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
             Note that since DDP comm hook only supports single process single device mode at this time,
             only exactly one tensor is stored in this bucket.
 
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
index dda2615..3ebd25c 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py
@@ -43,7 +43,7 @@
 
 
 def quantization_pertensor_hook(
-    process_group: dist.ProcessGroup, bucket: dist._GradBucket
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather``
@@ -116,7 +116,7 @@
 
 
 def quantization_perchannel_hook(
-    process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
+    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
 ) -> torch.futures.Future:
     """
     Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather``
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index c240e04..5e912ea 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1065,7 +1065,7 @@
                             It is locally stored by each worker
                             and shared by all the gradient tensors on the worker.
             hook (callable): Averages gradient tensors across workers and defined as:
-                             ``hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future``:
+                             ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future``:
 
                              This function is called once the bucket is ready. The
                              hook can perform whatever processing is needed and return
@@ -1107,7 +1107,7 @@
         Example::
             Below is an example of a noop hook that returns the same tensors.
 
-            >>> def noop(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+            >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
             >>>     fut = torch.futures.Future()
             >>>     fut.set_result(bucket.get_tensors())
             >>>     return fut
@@ -1118,7 +1118,7 @@
             Below is an example of a Parallel SGD algorithm where gradients are encoded before
             allreduce, and then decoded after allreduce.
 
-            >>> def encode_and_decode(state: object, bucket: dist._GradBucket): -> torch.futures.Future
+            >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future:
             >>>     tensors = [t / process_group.world_size for t in bucket.get_tensors()]
             >>>     encoded_tensors = encode(tensors) # encode gradients
             >>>     fut = process_group.allreduce(encoded_tensors).get_future()
@@ -1270,10 +1270,10 @@
         sig = inspect.signature(hook)
         if (
             sig.parameters["bucket"].annotation != inspect._empty
-            and sig.parameters["bucket"].annotation != dist._GradBucket
+            and sig.parameters["bucket"].annotation != dist.GradBucket
         ):
             raise ValueError(
-                "Communication hook: bucket annotation should be dist._GradBucket."
+                "Communication hook: bucket annotation should be dist.GradBucket."
             )
 
         if sig.return_annotation != inspect._empty and (