[Gradient Compression] Update the docstring of fp16_compress_wrapper (#53955)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53955
Per title: update the docstring of ``fp16_compress_wrapper`` to note that ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.
ghstack-source-id: 123852836
Test Plan: N/A
Reviewed By: iseessel
Differential Revision: D27032700
fbshipit-source-id: 6f9bbc028efe6cc9b54f4ec729fea745368efb2e
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index ba7d6e3..a442aa9 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -73,9 +73,12 @@
This wrapper casts the input gradient tensors of a given DDP communication hook to half-precision
floating point format (``torch.float16``), and casts the resulting tensors of the given hook back to
the input data type, such as ``float32``.
+
+ Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.
+
Example::
- >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
- >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
+ >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10)
+ >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
"""
def fp16_compress_wrapper_hook(hook_state, bucket: dist.GradBucket) -> torch.futures.Future:
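
For context, the change documents existing behavior: ``fp16_compress_wrapper`` casts the gradients to ``torch.float16`` before calling the wrapped hook, then casts the hook's output back to the original dtype. Below is a minimal standalone sketch of that idea; it is illustrative only, operating on plain tensors with hypothetical names (``fp16_compress_wrapper_sketch``, ``noop_hook``) rather than the real ``GradBucket``-based hook signature.

    import torch

    def fp16_compress_wrapper_sketch(hook):
        # `hook` takes (state, tensor) and returns a torch.futures.Future that
        # resolves to a tensor (a stand-in for the real DDP comm hook signature).
        def wrapper(state, grad: torch.Tensor) -> torch.futures.Future:
            input_dtype = grad.dtype                    # e.g. torch.float32
            fut = hook(state, grad.to(torch.float16))   # run the wrapped hook on the fp16 tensor
            # Once the hook finishes, cast its result back to the original dtype.
            return fut.then(lambda f: f.value().to(input_dtype))
        return wrapper

    # Usage with a trivial "allreduce-like" hook that simply returns its input.
    def noop_hook(state, t):
        fut = torch.futures.Future()
        fut.set_result(t)
        return fut

    hooked = fp16_compress_wrapper_sketch(noop_hook)
    result = hooked(None, torch.randn(4)).wait()
    print(result.dtype)  # torch.float32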