Add documentation page for pipeline parallelism. (#50791)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50791

Add a dedicated pipeline parallelism doc page explaining the APIs and
the overall value of the module.
ghstack-source-id: 120257168

Test Plan:
1) View locally
2) waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D25967981

fbshipit-source-id: b607b788703173a5fa4e3526471140506171632b
diff --git a/docs/source/_static/img/pipeline_parallelism/no_pipe.png b/docs/source/_static/img/pipeline_parallelism/no_pipe.png
new file mode 100644
index 0000000..4b2b795
--- /dev/null
+++ b/docs/source/_static/img/pipeline_parallelism/no_pipe.png
Binary files differ
diff --git a/docs/source/_static/img/pipeline_parallelism/pipe.png b/docs/source/_static/img/pipeline_parallelism/pipe.png
new file mode 100644
index 0000000..084b455
--- /dev/null
+++ b/docs/source/_static/img/pipeline_parallelism/pipe.png
Binary files differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index a334bff..a105b6d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -71,6 +71,7 @@
    onnx
    optim
    complex_numbers
+   pipeline
    quantization
    rpc
    torch.random <random>
diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst
new file mode 100644
index 0000000..6f52e62
--- /dev/null
+++ b/docs/source/pipeline.rst
@@ -0,0 +1,71 @@
+.. _pipeline-parallelism:
+
+Pipeline Parallelism
+====================
+
+Pipeline parallelism was originally introduced in the
+`GPipe <https://arxiv.org/abs/1811.06965>`__ paper and is an efficient
+technique to train large models on multiple GPUs.
+
+.. warning::
+    Pipeline Parallelism is experimental and subject to change.
+
+Model Parallelism using multiple GPUs
+-------------------------------------
+
+Typically, for large models which don't fit on a single GPU, model parallelism
+is employed, where certain parts of the model are placed on different GPUs.
+However, if this is done naively for sequential models, the training process
+suffers from GPU underutilization since only one GPU is active at a time, as
+shown in the figure below:
+
+.. figure:: _static/img/pipeline_parallelism/no_pipe.png
+
+   The figure represents a model with 4 layers placed on 4 different GPUs
+   (vertical axis). The horizontal axis represents training this model through
+   time, demonstrating that only 1 GPU is utilized at a time
+   (`image source <https://arxiv.org/abs/1811.06965>`__).
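+
+As an illustration, such naive model parallelism for a tiny sequential model
+might look like the sketch below. The two-GPU placement and the layer sizes are
+assumptions made for the example::
+
+    import torch
+    import torch.nn as nn
+
+    class NaiveModelParallel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            # Each half of the model lives on a different GPU.
+            self.seq1 = nn.Sequential(nn.Linear(16, 8), nn.ReLU()).to('cuda:0')
+            self.seq2 = nn.Linear(8, 4).to('cuda:1')
+
+        def forward(self, x):
+            # Activations are moved between devices inside forward(), so only
+            # one GPU is doing work at any point in time.
+            x = self.seq1(x.to('cuda:0'))
+            return self.seq2(x.to('cuda:1'))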
+
+Pipelined Execution
+-------------------
+
+To alleviate this problem, pipeline parallelism splits the input minibatch into 
+multiple microbatches and pipelines the execution of these microbatches across 
+multiple GPUs. This is outlined in the figure below:
+
+.. figure:: _static/img/pipeline_parallelism/pipe.png
+
+   The figure represents a model with 4 layers placed on 4 different GPUs
+   (vertical axis). The horizontal axis represents training this model through
+   time, demonstrating that the GPUs are utilized much more efficiently.
+   However, there still exists a bubble (as illustrated in the figure) where
+   certain GPUs are not utilized
+   (`image source <https://arxiv.org/abs/1811.06965>`__).
+
+Pipe APIs in PyTorch
+--------------------
+.. autoclass:: torch.distributed.pipeline.sync.Pipe
+   :members: forward
+
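+Below is a minimal sketch of how a small sequential model could be wrapped with
+:class:`~torch.distributed.pipeline.sync.Pipe`. The single-process RPC setup,
+device placement, layer sizes and ``chunks`` value are illustrative assumptions
+rather than requirements::
+
+    import os
+    import torch
+    import torch.nn as nn
+    from torch.distributed import rpc
+    from torch.distributed.pipeline.sync import Pipe
+
+    # Pipe relies on the RPC framework, so it needs to be initialized first,
+    # even when running in a single process.
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '29500'
+    rpc.init_rpc('worker', rank=0, world_size=1)
+
+    # Place the two halves of the model on two different GPUs.
+    fc1 = nn.Linear(16, 8).cuda(0)
+    fc2 = nn.Linear(8, 4).cuda(1)
+    model = Pipe(nn.Sequential(fc1, fc2), chunks=8)
+
+    # Each minibatch is split into 8 micro-batches that are pipelined across
+    # the two GPUs; forward() returns an RRef to the output.
+    output_rref = model(torch.rand(32, 16).cuda(0))
+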
+Skip connections
+^^^^^^^^^^^^^^^^
+
+Certain models like ResNeXt are not completely sequential and have skip
+connections between layers. Naively implementing this as part of pipeline
+parallelism would imply that we need to copy outputs of certain layers through
+multiple GPUs until we eventually reach the GPU where the layer for the skip
+connection resides. To avoid this copy overhead, we provide APIs below to stash
+and pop Tensors in different layers of the model.
+
+.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable
+.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash
+.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop
+.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables
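+
+As a rough example (the module names and the skip tensor name ``'skip'`` are
+made up for illustration), a pair of skippable modules that stash an activation
+in one layer and pop it in a later layer might look like this::
+
+    import torch.nn as nn
+    from torch.distributed.pipeline.sync.skip.skippable import pop, skippable, stash
+
+    @skippable(stash=['skip'])
+    class Stash(nn.Module):
+        def forward(self, input):
+            # Stash the input so a later layer can consume it without the
+            # tensor being copied through every intermediate GPU.
+            yield stash('skip', input)
+            return input
+
+    @skippable(pop=['skip'])
+    class Pop(nn.Module):
+        def forward(self, input):
+            # Retrieve the tensor stashed by the Stash module above.
+            skip = yield pop('skip')
+            return input + skip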
+
+Acknowledgements
+----------------
+
+The implementation for pipeline parallelism is based on `fairscale's pipe implementation <https://github.com/facebookresearch/fairscale/tree/master/fairscale/nn/pipe>`__ and 
+`torchgpipe <https://github.com/kakaobrain/torchgpipe>`__. We would like to 
+thank both teams for their contributions and guidance towards bringing pipeline 
+parallelism into PyTorch.
diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py
index e0b0dae..6f6bcd7 100644
--- a/torch/distributed/pipeline/sync/skip/skippable.py
+++ b/torch/distributed/pipeline/sync/skip/skippable.py
@@ -242,7 +242,7 @@
     """The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
     connections. Decorated modules are called "skippable". This functionality
     works perfectly fine even when the module is not wrapped by
-    :class:`~torchpipe.Pipe`.
+    :class:`~torch.distributed.pipeline.sync.Pipe`.
 
     Each skip tensor is managed by its name. Before manipulating skip tensors,
     a skippable module must statically declare the names for skip tensors by
@@ -282,23 +282,10 @@
                 return input + carol
 
     Every skip tensor must be associated with exactly one pair of `stash` and
-    `pop`. :class:`~torchpipe.Pipe` checks this restriction automatically
-    when wrapping a module. You can also check the restriction by
-    :func:`~torchpipe.skip.verify_skippables` without
-    :class:`~torchpipe.Pipe`.
-
-    .. note::
-
-        :func:`@skippable <skippable>` changes the type of the wrapped class.
-        But currently (mypy v0.740), mypy could not understand class decorators
-        yet (`#3135 <https://github.com/python/mypy/issues/3135>`_).
-
-        There are two workarounds:
-
-        1. Naively ignore type errors by ``# type: ignore``.
-        2. Use ``skippable()()`` as a function instead of a decorator.
-
-    .. seealso:: :ref:`Long Skip Connections`
+    `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this
+    restriction automatically when wrapping a module. You can also check the
+    restriction by :func:`verify_skippables`
+    without :class:`~torch.distributed.pipeline.sync.Pipe`.
 
     """
     stashable_names = frozenset(stash)