Add torch.utils.data docs and improve notes (#460)
* Add torch.utils.data docs and improve notes
diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst
index 1464d96..59dfe22 100644
--- a/docs/source/notes/cuda.rst
+++ b/docs/source/notes/cuda.rst
@@ -11,7 +11,7 @@
Cross-GPU operations are not allowed by default, with the only exception of
:meth:`~torch.Tensor.copy_`. Unless you enable peer-to-peer memory accesses
-any attempts to launch ops on tensors spread accross different devices will
+any attempts to launch ops on tensors spread across different devices will
raise an error.
Below you can find a small example showcasing this::
@@ -35,3 +35,26 @@
# even within a context, you can give a GPU id to the .cuda call
c = torch.randn(2).cuda(2)
# c.get_device() == 2
+
+Best practices
+--------------
+
+Use pinned memory buffers
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. warning::
+
+    This is an advanced tip. Overuse of pinned memory can cause serious
+    problems when you're running low on RAM, and you should be aware that
+    pinning is often an expensive operation.
+
+Host to GPU copies are much faster when they originate from pinned (page-locked)
+memory. CPU tensors and storages expose a :meth:`~torch.Tensor.pin_memory`
+method that returns a copy of the object, with its data put in a pinned region.
+
+Also, once you pin a tensor or storage, you can use asynchronous GPU copies.
+Just pass an additional ``async=True`` argument to a :meth:`~torch.Tensor.cuda`
+call. This can be used to overlap data transfers with computation.
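+
+A minimal sketch of this pattern (assuming a CUDA device is available) might
+look like this::
+
+    x = torch.randn(1024, 1024).pin_memory()  # copy the data into page-locked memory
+    y = x.cuda(async=True)                     # host-to-GPU copy can overlap other work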
+
+You can make the :class:`~torch.utils.data.DataLoader` return batches placed in
+pinned memory by passing ``pin_memory=True`` to its constructor.
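+
+For example, a loader constructed roughly like this (``dataset`` is assumed to
+be an existing :class:`~torch.utils.data.Dataset`) will yield batches backed by
+pinned memory::
+
+    loader = torch.utils.data.DataLoader(dataset, batch_size=32, pin_memory=True)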
diff --git a/docs/source/notes/multiprocessing.rst b/docs/source/notes/multiprocessing.rst
index b1dba03..94a563d 100644
--- a/docs/source/notes/multiprocessing.rst
+++ b/docs/source/notes/multiprocessing.rst
@@ -34,8 +34,28 @@
apply to shared CPU memory.
-Best practices
---------------
+Best practices and tips
+-----------------------
+
+Avoiding and fighting deadlocks
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+There are a lot of things that can go wrong when a new process is spawned, with
+the most common cause of deadlocks being background threads. If there's any
+thread that holds a lock or imports a module, and ``fork`` is called, it's very
+likely that the subprocess will be in a corrupted state and will deadlock or
+fail in a different way. Note that even if your own code doesn't spawn any
+threads, Python's built-in libraries do - no need to look further than
+:mod:`python:multiprocessing`. :class:`python:multiprocessing.Queue` is
+actually a very complex class that spawns multiple threads used to serialize,
+send and receive objects, and they can cause the aforementioned problems too.
+If you find yourself in such a situation, try using a
+:class:`~python:multiprocessing.queues.SimpleQueue`, which doesn't use any
+additional threads.
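+
+As a sketch (assuming Python 3, where the class is exposed at the top level of
+:mod:`python:multiprocessing`), the swap is usually a one-line change::
+
+    import multiprocessing as mp
+
+    # queue = mp.Queue()      # uses background threads; can deadlock after fork
+    queue = mp.SimpleQueue()  # no additional threads are spawned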
+
+We're trying our best to make it easy for you and ensure these deadlocks don't
+happen, but some things are out of our control. If you run into an issue you
+can't resolve for a while, try reaching out on the forums, and we'll see if
+it's an issue we can fix.
Reuse buffers passed through a Queue
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -47,8 +67,61 @@
processes sending data to a single one, make it send the buffers back - this
is nearly free and will let you avoid a copy when sending next batch.
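+
+A rough sketch of this round-trip, with a single reusable buffer and
+illustrative process and queue names, might look like this::
+
+    import torch
+    import torch.multiprocessing as mp
+
+    def producer(free_queue, data_queue):
+        for _ in range(10):
+            buffer = free_queue.get()        # receive a buffer that's already in shared memory
+            buffer.copy_(torch.randn(1000))  # fill it in place with the next batch
+            data_queue.put(buffer)           # no new shared-memory allocation is needed
+
+    if __name__ == '__main__':
+        free_queue, data_queue = mp.Queue(), mp.Queue()
+        p = mp.Process(target=producer, args=(free_queue, data_queue))
+        p.start()
+        free_queue.put(torch.zeros(1000))    # seed the pool with one buffer
+        for _ in range(10):
+            batch = data_queue.get()
+            # ... consume ``batch`` ...
+            free_queue.put(batch)            # send the buffer back to be reused
+        p.join()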
+
+Asynchronous multiprocess training (e.g. Hogwild)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Using :mod:`torch.multiprocessing`, it is possible to train a model
+asynchronously, with parameters either shared all the time, or being
+periodically synchronized. In the first case, we recommend sending over the
+whole model object, while in the latter, we advise sending only the
+:meth:`~torch.nn.Module.state_dict`.
+
+We recommend using :class:`python:multiprocessing.Queue` for passing all kinds
+of PyTorch objects between processes. It is possible to e.g. inherit tensors
+and storages that are already in shared memory, when using the ``fork`` start
+method, however it is very bug prone and should be used with care, and only by
+advanced users. Queues, even though they're sometimes a less elegant solution,
+will work properly in all cases.
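+
+A rough sketch of the queue-based approach (``MyModel`` is the same placeholder
+module as in the Hogwild example below, and the worker shown here is purely
+illustrative) might look like this::
+
+    import torch.multiprocessing as mp
+    from model import MyModel
+
+    def worker(queue):
+        model = MyModel()
+        model.load_state_dict(queue.get())  # receive a snapshot of the parameters
+        # ... run some training steps, then optionally send parameters back ...
+
+    if __name__ == '__main__':
+        model = MyModel()
+        queue = mp.Queue()
+        process = mp.Process(target=worker, args=(queue,))
+        process.start()
+        queue.put(model.state_dict())       # the state_dict travels through the queue
+        process.join()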
+
+.. warning::
+
+    You should be careful about global statements that are not guarded with an
+    ``if __name__ == '__main__'``. If a start method other than ``fork`` is
+    used, they will be executed in all subprocesses.
+
+Hogwild
+~~~~~~~
+
+A concrete Hogwild implementation can be found in the `examples repository`__,
+but to showcase the overall structure of the code, there's also a minimal
+example below::
+
+    import torch.multiprocessing as mp
+    from model import MyModel
+
+    def train(model):
+        # This for loop will break sharing of gradient buffers. It's not
+        # necessary but it reduces the contention, and has a small memory cost
+        # (equal to the total size of parameters).
+        for param in model.parameters():
+            param.grad.data = param.grad.data.clone()
+        # Construct data_loader, optimizer, etc.
+        for data, labels in data_loader:
+            optimizer.zero_grad()
+            loss_fn(model(data), labels).backward()
+            optimizer.step()  # This will update the shared parameters
+
+    if __name__ == '__main__':
+        num_processes = 4
+        model = MyModel()
+        # NOTE: this is required for the ``fork`` method to work
+        model.share_memory()
+        processes = []
+        for rank in range(num_processes):
+            p = mp.Process(target=train, args=(model,))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+
+.. __: https://github.com/pytorch/examples/tree/master/mnist_hogwild
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 6605bfb..2084626 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -34,7 +34,7 @@
class Model(nn.Module):
def __init__(self):
- super(Net, self).__init__()
+ super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 761da1b..bf254ce 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -216,6 +216,20 @@
"""
Data loader. Combines a dataset and a sampler, and provides
single- or multi-process iterators over the dataset.
+
+    Arguments:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: 1).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: False).
+        sampler (Sampler, optional): defines the strategy to draw samples from
+            the dataset. If specified, the ``shuffle`` argument is ignored.
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. 0 means that the data will be loaded in the main process
+            (default: 0).
+        collate_fn (callable, optional): merges a list of samples to form a
+            mini-batch.
+        pin_memory (bool, optional): if ``True``, the data loader will copy
+            tensors into pinned (page-locked) memory before returning them.
"""
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py
index 69b4c36..bdf8d52 100644
--- a/torch/utils/data/dataset.py
+++ b/torch/utils/data/dataset.py
@@ -1,5 +1,11 @@
class Dataset(object):
+    """An abstract class representing a Dataset.
+
+    All other datasets should subclass it. All subclasses should override
+    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
+    which supports integer indexing in the range from 0 to len(self) exclusive.
+    """
def __getitem__(self, index):
raise NotImplementedError
@@ -9,6 +15,15 @@
class TensorDataset(Dataset):
+ """Dataset wrapping data and target tensors.
+
+ Each sample will be retrieved by indexing both tensors along the first
+ dimension.
+
+ Arguments:
+ data_tensor (Tensor): contains sample data.
+ target_tensor (Tensor): contains sample targets (labels).
+ """
def __init__(self, data_tensor, target_tensor):
assert data_tensor.size(0) == target_tensor.size(0)
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 53e3ff5..ffbdaa9 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -2,6 +2,12 @@
class Sampler(object):
+    """Base class for all Samplers.
+
+    Every Sampler subclass has to provide an ``__iter__`` method, providing a
+    way to iterate over indices of dataset elements, and a ``__len__`` method
+    that returns the length of the returned iterator.
+    """
def __init__(self, data_source):
pass
@@ -14,6 +20,11 @@
class SequentialSampler(Sampler):
+ """Samples elements sequentially, always in the same order.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ """
def __init__(self, data_source):
self.num_samples = len(data_source)
@@ -26,6 +37,11 @@
class RandomSampler(Sampler):
+ """Samples elements randomly, without replacement.
+
+ Arguments:
+ data_source (Dataset): dataset to sample from
+ """
def __init__(self, data_source):
self.num_samples = len(data_source)