Improve utils.checkpoint docs (#6526)
* improve utils.checkpoint docs
* change volatile to no_grad, and add more explanation
* address comments
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 2b115ea..dcd2607 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -39,59 +39,71 @@
return (None,) + tuple(inp.grad for inp in detached_inputs)
-def checkpoint(run_function, *args):
+def checkpoint(function, *args):
r"""Checkpoint a model or part of the model
- Checkpoint works by trading compute for memory. It can be applied on any
- part of the model. In the forward pass, the model activations are not
- stored. The forward pass save the inputs tuple and the run_function
- parameter. In the backwards pass, the saved inputs and run_function is
- retreived, and the forward pass is done on the model again (non-volatile
- this time) since we need to get the activations values for calculating the
- gradient and then the gradients are calculated.
+ Checkpointing works by trading compute for memory. Rather than storing all
+ intermediate activations of the entire computation graph for computing
+ backward, the checkpointed part does **not** save intermediate activations,
+ and instead recomputes them in the backward pass. It can be applied to any
+ part of a model.
+
+ Specifically, in the forward pass, :attr:`function` will run in
+ :func:`torch.no_grad` manner, i.e., not storing the intermediate
+ activations. Instead, the forward pass saves the inputs tuple and the
+ :attr:`function` parameter. In the backward pass, the saved inputs and
+ :attr:`function` are retrieved, and the forward pass is computed on
+ :attr:`function` again, now tracking the intermediate activations, and then
+ the gradients are calculated using these activation values.
.. warning::
- checkpointing doesn't work with torch.autograd.grad(), but only with
- torch.autograd.backward()
+ Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
+ with :func:`torch.autograd.backward`.
+ .. warning::
+ If the :attr:`function` invocation during backward does anything different
+ from the one during forward, e.g., due to some global variable, the
+ checkpointed version won't be equivalent, and unfortunately it can't be
+ detected.
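As a hypothetical illustration of this caveat (the function ``f`` and its global ``scale`` below are made up for this example, not part of the patch): a function that mutates global state is recomputed differently in backward, so the gradients silently disagree with the original forward.

    import torch
    from torch.utils.checkpoint import checkpoint

    scale = 1.0

    def f(x):
        global scale
        out = x * scale
        scale *= 2.0   # state changes between the original and recomputed forward
        return out

    x = torch.randn(4, requires_grad=True)
    y = checkpoint(f, x)   # original forward sees scale == 1.0
    y.sum().backward()     # recomputation sees scale == 2.0, so x.grad is 2, not 1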
Args:
- run_function: describes what to run in the forward pass of the model or
+ function: describes what to run in the forward pass of the model or
part of the model. It should also know how to handle the inputs
- passed as the tuple. For example, in LSTM, user passes (activation,
- hidden), run_function should correctly use first input as activation
- and second input as hidden
- args: tuple containing inputs to the run_function
+ passed as the tuple. For example, in an LSTM, if the user passes
+ ``(activation, hidden)``, :attr:`function` should correctly use the
+ first input as ``activation`` and the second input as ``hidden``
+ args: tuple containing inputs to the :attr:`function`
Returns:
- Output of running the run_function on *args
+ Output of running :attr:`function` on :attr:`*args`
"""
- return CheckpointFunction.apply(run_function, *args)
+ return CheckpointFunction.apply(function, *args)
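For reference, a minimal usage sketch of the renamed API (the block, tensor shape, and loss below are illustrative assumptions, not part of this patch):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 128))

    # Mark the input as requiring grad so gradients flow back through the
    # checkpointed call.
    x = torch.randn(32, 128, requires_grad=True)

    # Forward runs under no_grad; only (block, x) are saved, and the block's
    # intermediate activations are recomputed during backward.
    y = checkpoint(block, x)
    y.sum().backward()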
def checkpoint_sequential(functions, segments, *inputs):
- r"""A helper function for checkpointing Sequential based models.
+ r"""A helper function for checkpointing sequential models.
- For models that are constructed using Sequential, they normally are built
- using various modules/functions. For such models, given a list of modules/functions
- it executes in order (sequentially), we can divide the model in various
- segments and checkpoint the segments. All segments except the last will be
- run in volatile manner i.e. the model activations are not stored. The inputs
- of each checkpointed segment will be saved for re-running the segment in the
- backward pass.
+ Sequential models execute a list of modules/functions in order
+ (sequentially). Therefore, we can divide such a model into various segments
+ and checkpoint each segment. All segments except the last will run in
+ :func:`torch.no_grad` manner, i.e., not storing the intermediate
+ activations. The inputs of each checkpointed segment will be saved for
+ re-running the segment in the backward pass.
+
+ See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
-
- checkpointing doesn't work with torch.autograd.grad(), but only with
- torch.autograd.backward()
+ Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
+ with :func:`torch.autograd.backward`.
Args:
- functions: A sequential or the list of modules or functions (comprising the model) to run in order.
+ functions: A :class:`torch.nn.Sequential` or the list of modules or
+ functions (comprising the model) to run sequentially.
segments: Number of chunks to create in the model
- inputs: tuple of Tensors that are inputs to run_function
+ inputs: tuple of Tensors that are inputs to :attr:`functions`
Returns:
- Output of running the modules/functions on *inputs
+ Output of running :attr:`functions` sequentially on :attr:`*inputs`
Example:
>>> model = nn.Sequential(...)
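A minimal end-to-end sketch of how this helper might be called (the concrete layers, segment count, and input shape are illustrative assumptions, not part of this patch):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, 10),
    )
    x = torch.randn(32, 128, requires_grad=True)

    # Split the five modules into 2 segments; every segment except the last
    # runs under no_grad in forward and is recomputed during backward.
    out = checkpoint_sequential(model, 2, x)
    out.sum().backward()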