Add documentation page for pipeline parallelism. (#50791)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50791
Add a dedicated pipeline parallelism doc page explaining the APIs and
the overall value of the module.
ghstack-source-id: 120257168
Test Plan:
1) View locally
2) waitforbuildbot
Reviewed By: rohan-varma
Differential Revision: D25967981
fbshipit-source-id: b607b788703173a5fa4e3526471140506171632b
diff --git a/docs/source/_static/img/pipeline_parallelism/no_pipe.png b/docs/source/_static/img/pipeline_parallelism/no_pipe.png
new file mode 100644
index 0000000..4b2b795
--- /dev/null
+++ b/docs/source/_static/img/pipeline_parallelism/no_pipe.png
Binary files differ
diff --git a/docs/source/_static/img/pipeline_parallelism/pipe.png b/docs/source/_static/img/pipeline_parallelism/pipe.png
new file mode 100644
index 0000000..084b455
--- /dev/null
+++ b/docs/source/_static/img/pipeline_parallelism/pipe.png
Binary files differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index a334bff..a105b6d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -71,6 +71,7 @@
onnx
optim
complex_numbers
+ pipeline
quantization
rpc
torch.random <random>
diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst
new file mode 100644
index 0000000..6f52e62
--- /dev/null
+++ b/docs/source/pipeline.rst
@@ -0,0 +1,71 @@
+.. _pipeline-parallelism:
+
+Pipeline Parallelism
+====================
+
+Pipeline parallelism was originally introduced in the
+`GPipe <https://arxiv.org/abs/1811.06965>`__ paper and is an efficient
+technique for training large models on multiple GPUs.
+
+.. warning ::
+ Pipeline Parallelism is experimental and subject to change.
+
+Model Parallelism using multiple GPUs
+-------------------------------------
+
+Typically, for large models that don't fit on a single GPU, model parallelism
+is employed, where certain parts of the model are placed on different GPUs.
+However, if this is done naively for sequential models, the training process
+suffers from GPU under-utilization since only one GPU is active at a time, as
+shown in the figure below:
+
+.. figure:: _static/img/pipeline_parallelism/no_pipe.png
+
+  The figure represents a model with 4 layers placed on 4 different GPUs
+  (vertical axis). The horizontal axis represents training this model through
+  time, demonstrating that only 1 GPU is utilized at a time
+  (`image source <https://arxiv.org/abs/1811.06965>`__).
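+
+As a minimal sketch (assuming two available GPUs and a hypothetical
+``NaiveModelParallel`` module), naive model parallelism for a sequential model
+places each stage on a different device and moves activations between devices
+inside ``forward``::
+
+    import torch
+    import torch.nn as nn
+
+    class NaiveModelParallel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            # Each stage lives on a different GPU.
+            self.stage1 = nn.Linear(16, 8).to('cuda:0')
+            self.stage2 = nn.Linear(8, 4).to('cuda:1')
+
+        def forward(self, x):
+            # Only one GPU is busy at any point in time.
+            x = self.stage1(x.to('cuda:0'))
+            return self.stage2(x.to('cuda:1'))
+
+    model = NaiveModelParallel()
+    out = model(torch.rand(32, 16))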
+
+Pipelined Execution
+-------------------
+
+To alleviate this problem, pipeline parallelism splits the input minibatch into
+multiple microbatches and pipelines the execution of these microbatches across
+multiple GPUs. This is outlined in the figure below:
+
+.. figure:: _static/img/pipeline_parallelism/pipe.png
+
+  The figure represents a model with 4 layers placed on 4 different GPUs
+  (vertical axis). The horizontal axis represents training this model through
+  time, demonstrating that the GPUs are utilized much more efficiently.
+  However, there still exists a bubble (as demonstrated in the figure) where
+  certain GPUs are not utilized
+  (`image source <https://arxiv.org/abs/1811.06965>`__).
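+
+Conceptually, the microbatch split corresponds to chunking the input along the
+batch dimension. A minimal sketch of the idea (the actual split is handled
+internally and controlled by the ``chunks`` argument of
+:class:`~torch.distributed.pipeline.sync.Pipe`)::
+
+    import torch
+
+    # A minibatch of 32 samples split into 8 microbatches of 4 samples each.
+    minibatch = torch.rand(32, 16)
+    microbatches = torch.chunk(minibatch, 8, dim=0)
+    assert len(microbatches) == 8
+    assert microbatches[0].shape == (4, 16)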
+
+Pipe APIs in PyTorch
+--------------------
+.. autoclass:: torch.distributed.pipeline.sync.Pipe
+ :members: forward
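+
+As a minimal usage sketch (assuming two available GPUs and that any required
+distributed setup, such as initializing the RPC framework, has been performed
+as described in the API documentation above)::
+
+    import torch
+    import torch.nn as nn
+    from torch.distributed.pipeline.sync import Pipe
+
+    # Build a sequential model with each stage on a different GPU.
+    fc1 = nn.Linear(16, 8).cuda(0)
+    fc2 = nn.Linear(8, 4).cuda(1)
+    model = nn.Sequential(fc1, fc2)
+
+    # Split each minibatch into 8 microbatches pipelined across the GPUs.
+    model = Pipe(model, chunks=8)
+
+    input = torch.rand(32, 16).cuda(0)
+    # See the ``forward`` documentation above for the exact return type.
+    output = model(input)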
+
+Skip connections
+^^^^^^^^^^^^^^^^
+
+Certain models like ResNeXt are not completely sequential and have skip
+connections between layers. Naively implementing these as part of pipeline
+parallelism would imply that we need to copy the outputs of certain layers
+through multiple GPUs until we eventually reach the GPU where the layer for the
+skip connection resides. To avoid this copy overhead, we provide the APIs below
+to stash and pop tensors in different layers of the model.
+
+.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable
+.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash
+.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop
+.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables
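+
+As a minimal sketch of the stash/pop pattern (with hypothetical layers
+``Layer1`` and ``Layer3`` connected by a skip connection named ``'1to3'``)::
+
+    import torch.nn as nn
+    from torch.distributed.pipeline.sync.skip.skippable import pop, skippable, stash
+
+    @skippable(stash=['1to3'])
+    class Layer1(nn.Module):
+        def forward(self, input):
+            # Stash the input so a later layer (possibly on another GPU)
+            # can reuse it without routing it through intermediate layers.
+            yield stash('1to3', input)
+            return input
+
+    @skippable(pop=['1to3'])
+    class Layer3(nn.Module):
+        def forward(self, input):
+            # Retrieve the tensor stashed by Layer1 and use it for the
+            # skip connection.
+            skip = yield pop('1to3')
+            return input + skip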
+
+Acknowledgements
+----------------
+
+The implementation for pipeline parallelism is based on `fairscale's pipe implementation <https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/pipe>`__ and
+`torchgpipe <https://github.com/kakaobrain/torchgpipe>`__. We would like to
+thank both teams for their contributions and guidance towards bringing pipeline
+parallelism into PyTorch.
diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py
index e0b0dae..6f6bcd7 100644
--- a/torch/distributed/pipeline/sync/skip/skippable.py
+++ b/torch/distributed/pipeline/sync/skip/skippable.py
@@ -242,7 +242,7 @@
"""The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
connections. Decorated modules are called "skippable". This functionality
works perfectly fine even when the module is not wrapped by
- :class:`~torchpipe.Pipe`.
+ :class:`~torch.distributed.pipeline.sync.Pipe`.
Each skip tensor is managed by its name. Before manipulating skip tensors,
a skippable module must statically declare the names for skip tensors by
@@ -282,23 +282,10 @@
return input + carol
Every skip tensor must be associated with exactly one pair of `stash` and
- `pop`. :class:`~torchpipe.Pipe` checks this restriction automatically
- when wrapping a module. You can also check the restriction by
- :func:`~torchpipe.skip.verify_skippables` without
- :class:`~torchpipe.Pipe`.
-
- .. note::
-
- :func:`@skippable <skippable>` changes the type of the wrapped class.
- But currently (mypy v0.740), mypy could not understand class decorators
- yet (`#3135 <https://github.com/python/mypy/issues/3135>`_).
-
- There are two workarounds:
-
- 1. Naively ignore type errors by ``# type: ignore``.
- 2. Use ``skippable()()`` as a function instead of a decorator.
-
- .. seealso:: :ref:`Long Skip Connections`
+ `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this
+ restriction automatically when wrapping a module. You can also check the
+ restriction by :func:`verify_skippables`
+ without :class:`~torch.distributed.pipeline.sync.Pipe`.
"""
stashable_names = frozenset(stash)