Fix the doc of PostLocalSGDState (#72792)
Summary:
The first arg of the `PostLocalSGDState` ctor, `process_group`, has no default value and cannot be skipped, so the existing doc examples that only pass `subgroup` would fail. The fixed examples pass `process_group=None` to use the default process group and, to simplify the usage, do not even create a subgroup explicitly: `subgroup=None` falls back to the default subgroup.
See the example in unit test: https://github.com/pytorch/pytorch/blob/4feef6c97092cfde7d57a97d8390a79551e92369/torch/testing/_internal/distributed/distributed_test.py#L4260
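
For reference, a minimal sketch of the corrected usage (assumes the process group is already initialized and `model` is a DistributedDataParallel instance; `start_localSGD_iter=100` just mirrors the value used in the docstrings):

    from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
        PostLocalSGDState,
        post_localSGD_hook,
    )

    # `process_group` is a required arg with no default, so it cannot be
    # skipped; passing None uses the default process group. `subgroup=None`
    # lets PostLocalSGDState fall back to its default subgroup, so no
    # explicit dist.new_subgroups() call is needed.
    state = PostLocalSGDState(
        process_group=None, subgroup=None, start_localSGD_iter=100
    )
    model.register_comm_hook(state, post_localSGD_hook)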
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72792
Reviewed By: samdow
Differential Revision: D34213221
Pulled By: rohan-varma
fbshipit-source-id: 078343f3ee138e175bf835897f190032eb970662
(cherry picked from commit bf90af704fb371eef799a951007cc5d41dbe07a1)
diff --git a/torch/distributed/algorithms/model_averaging/averagers.py b/torch/distributed/algorithms/model_averaging/averagers.py
index a084ab7..cb67057 100644
--- a/torch/distributed/algorithms/model_averaging/averagers.py
+++ b/torch/distributed/algorithms/model_averaging/averagers.py
@@ -60,8 +60,7 @@
>>> module, device_ids=[rank], output_device=rank
>>> )
>>> # Register a post-localSGD communication hook.
- >>> subgroup, subgroups = dist.new_subgroups()
- >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
+ >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py
index 1a80bab..f242934 100644
--- a/torch/distributed/optim/post_localSGD_optimizer.py
+++ b/torch/distributed/optim/post_localSGD_optimizer.py
@@ -26,8 +26,7 @@
>>> )
>>>
>>> # Register a post-localSGD communication hook.
- >>> subgroup, subgroups = dist.new_subgroups()
- >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
+ >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.