Improved DDP checkpoint documentation (#106985)
Amended the DistributedDataParallel documentation to distinguish gradient checkpointing with use_reentrant=False (works without limitations) from use_reentrant=True (the default, which keeps the previously documented restrictions).
Fixes #84589
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106985
Approved by: https://github.com/wanchaol, https://github.com/fduwjj
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 24f63e1..5a6facf 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -456,9 +456,12 @@
.. note::
DistributedDataParallel currently offers limited support for gradient
- checkpointing with :meth:`torch.utils.checkpoint`. DDP will work as
- expected when there are no unused parameters in the model and each layer
- is checkpointed at most once (make sure you are not passing
+ checkpointing with :meth:`torch.utils.checkpoint`.
+ If the checkpoint is done with use_reentrant=False (recommended), DDP
+ will work as expected without any limitations.
+ If, however, the checkpoint is done with use_reentrant=True (the default),
+ DDP will work as expected when there are no unused parameters in the model
+ and each layer is checkpointed at most once (make sure you are not passing
`find_unused_parameters=True` to DDP). We currently do not support the
 case where a layer is checkpointed multiple times, or when there are unused
parameters in the checkpointed model.
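
For reference, a minimal sketch (not part of this PR) of the recommended combination described in the amended note: wrapping a checkpointed module in DDP with use_reentrant=False. The single-process "gloo" process group, the MASTER_ADDR/MASTER_PORT values, and the Block module are illustrative assumptions, not code from the PR.

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

# Single-process process group purely for illustration (hypothetical address/port).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class Block(nn.Module):
    """Toy block whose forward activations are recomputed in backward."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x):
        # Non-reentrant checkpointing: per the amended note, DDP works as
        # expected here without the restrictions that apply when
        # use_reentrant=True (the default).
        return checkpoint(self.net, x, use_reentrant=False)

model = DDP(nn.Sequential(Block(16), Block(16)))
out = model(torch.randn(8, 16))
out.sum().backward()

dist.destroy_process_group()
```

With use_reentrant=True instead, the note's original caveats would apply: no unused parameters, each layer checkpointed at most once, and no `find_unused_parameters=True`.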