[DDP Comm Hook] Do not expose hook_then_optimizer as a public method (#62532)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62532
This method is not yet stable, so keep it private rather than exposing it when the DDP communication hook feature is released as stable.
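
For context, a rough usage sketch of the now-private wrapper, mirroring the updated tests. Illustrative only: the model setup is a placeholder, and the _FunctionalSGD import path is an assumption, not something this diff establishes.

    # Illustrative sketch, not authoritative; assumes dist.init_process_group
    # has already been called and that _FunctionalSGD lives at this path.
    import torch
    import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default
    from torch.distributed.optim.functional_sgd import _FunctionalSGD
    from torch.nn.parallel import DistributedDataParallel as DDP

    model = DDP(torch.nn.Linear(8, 8).cuda(), device_ids=[0])  # placeholder model
    opt_hook_state = default._OptimizerHookState(
        _FunctionalSGD, 1e-2, momentum=0.9, weight_decay=0.01
    )
    # Allreduce each gradient bucket first, then run the functional SGD step on it.
    model.register_comm_hook(
        None, default._hook_then_optimizer(default.allreduce_hook, opt_hook_state)
    )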
ghstack-source-id: 134787831
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_ddp_hook_with_optimizer_parity
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_hook_then_optimizer_nccl
Reviewed By: rohan-varma
Differential Revision: D30031222
fbshipit-source-id: e03a8e13fee5116a5ddd724eb76316ee98f2a676
diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index f611e5b..cc60920 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1607,12 +1607,12 @@
sgd_lr = 1e-2
sgd_momentum = 0.9
sgd_weight_decay = 0.01
- opt_hook_state = default.OptimizerHookState(
+ opt_hook_state = default._OptimizerHookState(
_FunctionalSGD, sgd_lr, momentum=sgd_momentum, weight_decay=sgd_weight_decay
)
gpu_model = self._gpu_model_with_ddp_comm_hook(
process_group,
- default.hook_then_optimizer(hook, opt_hook_state),
+ default._hook_then_optimizer(hook, opt_hook_state),
gradient_as_bucket_view,
hook_state,
)
diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
index e2021af..50667ff 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py
@@ -70,7 +70,7 @@
return fut.then(decompress)
-class OptimizerHookState(object):
+class _OptimizerHookState(object):
"""
Holds state for running optimizer in-line after DDP communication hook.
Currently contains only optimizer class which must have a method `step_param`.
@@ -93,11 +93,18 @@
)
-def hook_then_optimizer(
+# TODO: Add an example showing how to use this wrapper.
+def _hook_then_optimizer(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
- optimizer_state: OptimizerHookState,
+ optimizer_state: _OptimizerHookState,
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
- """Runs optimizer in a functional fashion after DDP communication hook."""
+ r"""
+    Runs the optimizer in a functional fashion after the DDP communication hook.
+
+ .. warning ::
+        This API is experimental and subject to change.
+ """
+
def hook_then_optimizer_wrapper(
hook_state, bucket: dist.GradBucket
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 2ff88ea..f73e0a2 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -3865,7 +3865,7 @@
# Register hook that runs allreduce + functional SGD step.
allreduce_hook = default.allreduce_hook
- opt_hook_state = default.OptimizerHookState(
+ opt_hook_state = default._OptimizerHookState(
_FunctionalSGD,
sgd_lr,
momentum=sgd_momentum,
@@ -3873,7 +3873,7 @@
)
ddp_model_with_optimizer_hook.register_comm_hook(
None,
- default.hook_then_optimizer(allreduce_hook, opt_hook_state),
+ default._hook_then_optimizer(allreduce_hook, opt_hook_state),
)
# Create DDP model with no hook that does optimizer after
# backward.