Fix nn docs
diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index e3f2c4d..c812a4f 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -2,4 +2,8 @@
===================================
.. automodule:: torch.nn
+.. currentmodule:: torch.nn
+.. autoclass:: Module
+ :members:
.. autoclass:: Container
+ :members:
diff --git a/torch/nn/modules/__init__.py b/torch/nn/modules/__init__.py
index 80d4f98..f85aa3d 100644
--- a/torch/nn/modules/__init__.py
+++ b/torch/nn/modules/__init__.py
@@ -1,3 +1,4 @@
+from .module import Module
from .linear import Linear
from .conv import Conv1d, Conv2d, ConvTranspose2d, Conv3d, ConvTranspose3d
from .activation import Threshold, ReLU, Hardtanh, ReLU6, Sigmoid, Tanh, \
@@ -17,4 +18,4 @@
from .sparse import Embedding
from .rnn import RNNBase, RNN, LSTM, GRU, \
RNNCell, LSTMCell, GRUCell
-from .pixelshuffle import PixelShuffle
\ No newline at end of file
+from .pixelshuffle import PixelShuffle
diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py
index b5aceec..7903f49 100644
--- a/torch/nn/modules/container.py
+++ b/torch/nn/modules/container.py
@@ -22,90 +22,30 @@
your operations.
To make it easier to understand, given is a small example.
- ```
- # Example of using Container
- class Net(nn.Container):
- def __init__(self):
- super(Net, self).__init__(
- conv1 = nn.Conv2d(1, 20, 5),
- relu = nn.ReLU()
- )
- def forward(self, input):
- output = self.relu(self.conv1(x))
- return output
- model = Net()
- ```
+
+ ::
+
+ # Example of using Container
+ class Net(nn.Container):
+ def __init__(self):
+ super(Net, self).__init__(
+ conv1 = nn.Conv2d(1, 20, 5),
+ relu = nn.ReLU()
+ )
+ def forward(self, input):
+            output = self.relu(self.conv1(input))
+ return output
+ model = Net()
One can also add new modules to a container after construction.
You can do this with the add_module function
- or by assigning them as Container attributes.
+ or by assigning them as Container attributes::
- ```python
- # one can add modules to the container after construction
- model.add_module('pool1', nn.MaxPool2d(2, 2))
+ # one can add modules to the container after construction
+ model.add_module('pool1', nn.MaxPool2d(2, 2))
- # one can also set modules as attributes of the container
- model.conv1 = nn.Conv2d(12, 24, 3)
- ```
-
- The container has some important additional methods:
-
- **`[generator] parameters()`**
-
- returns a generator over all learnable parameters in the container instance.
- This can typically be passed to the optimizer API
-
- ```python
- # .parameters()
- >>> for param in model.parameters():
- >>> print(type(param.data), param.size())
- <class 'torch.FloatTensor'> (20L,)
- <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
- ```
-
- **`[dict] state_dict()`**
-
- returns a dictionary containing a whole state of the Container (both
- parameters and persistent buffers e.g. running averages)
- For example: ['conv1.weight' : Parameter(torch.FloatTensor(20x1x5x5)),
- 'conv1.bias' : Parameter(torch.FloatTensor(20)),
- ]
-
- ```python
- # .state_dict()
- >>> pdict = model.state_dict()
- >>> print(pdict.keys())
- ['conv1.bias', 'conv1.weight']
- ```
-
-
- **`load_state_dict(dict)`**
-
- Given a state dict, updates the parameters of self if respective keys are
- present in the dict. Excessive or non-matching parameter names are ignored.
- For example, the input dict has an entry 'conv44.weight', but
- if the container does not have a module named 'conv44', then this entry is ignored.
-
- **`children()`**
-
- Returns a generator over all the children modules of self
-
- **`train()`**
-
- Sets the Container (and all it's child modules) to training mode (for modules such as batchnorm, dropout etc.)
-
- **`eval()`**
-
- Sets the Container (and all it's child modules) to evaluate mode (for modules such as batchnorm, dropout etc.)
-
- **`apply(closure)`**
-
- Applies the given closure to each parameter of the container.
-
-
- **__Note: Apart from these, the container will define the base functions that it has derived from nn.Module __**
-
-
+ # one can also set modules as attributes of the container
+ model.conv1 = nn.Conv2d(12, 24, 3)
"""
dump_patches = False
@@ -154,6 +94,16 @@
Module.__delattr__(self, name)
def state_dict(self, destination=None, prefix=''):
+        """Returns a dictionary containing the whole state of the model.
+
+        Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys follow the container's attribute hierarchy, joined
+        with dots (e.g. ``'subcontainer.module.weight'``); the container
+        itself is not part of the key.
+
+ Example:
+ >>> print(model.state_dict().keys())
+ ['conv1.bias', 'conv1.weight']
+ """
result = super(Container, self).state_dict(destination, prefix)
for name, module in self._modules.items():
if module is not None:
@@ -161,12 +111,35 @@
return result
def load_state_dict(self, state_dict, prefix=''):
+ """Replaces model parameters using values from a given state_dict.
+
+        Copies all state_dict entries whose keys match one of the submodules.
+        For example, if the state_dict has an entry ``'conv44.weight'`` but
+        the container has no submodule named ``'conv44'``, that entry is
+        ignored. Once a matching module is found, all of its values are
+        loaded from the state dict (including those that weren't registered
+        before loading).
+
+ Arguments:
+ state_dict (dict): A dict containing loaded parameters and
+ persistent buffers.
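+
+        A minimal usage sketch (illustrative; ``model`` is assumed to be a
+        container whose state was previously captured with
+        :func:`state_dict`):
+            >>> saved_state = model.state_dict()
+            >>> model.load_state_dict(saved_state)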
+ """
super(Container, self).load_state_dict(state_dict)
for name, module in self._modules.items():
if module is not None:
module.load_state_dict(state_dict, prefix + name + '.')
def parameters(self, memo=None):
+        """Returns an iterator over model parameters (including those of
+        submodules).
+
+ This is typically passed to an optimizer.
+
+ Example:
+ >>> for param in model.parameters():
+ >>> print(type(param.data), param.size())
+ <class 'torch.FloatTensor'> (20L,)
+ <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
+ """
if memo is None:
memo = set()
for p in super(Container, self).parameters(memo):
@@ -176,6 +149,7 @@
yield p
def children(self):
+        """Returns an iterator over child modules."""
memo = set()
for module in self._modules.values():
if module is not None and module not in memo:
@@ -193,12 +167,20 @@
yield m
def train(self):
+        """Sets the module and all its children to training mode.
+
+        This only has an effect on modules such as Dropout or BatchNorm.
+        """
super(Container, self).train()
for module in self.children():
module.train()
return self
def eval(self):
+        """Sets the module and all its children to evaluation mode.
+
+        This only has an effect on modules such as Dropout or BatchNorm.
+        """
super(Container, self).eval()
for module in self.children():
module.eval()
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 23d2240..988ee3d 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -8,119 +8,9 @@
class Module(object):
- """This is the base class for all Modules defined in the nn package.
- Even the Container class derives from this class.
+ """Base class for all Modules defined in the nn package.
- An nn.Module has the following interface:
-
- **Constructor:**
- nn.Module()
-
- **forward(...)**
-
- This is the function that one defines when subclassing to create
- their own modules.
- It takes in inputs and returns outputs.
-
- **__call__(...)**
-
- This calls the forward function, as well as the hooks
-
- **register_parameter(name, param)**
-
- Adds a parameter to the module. The parameter can be accessed as an
- attribute of the module by its name.
-
- **register_buffer(name, tensor)**
-
- This is typically used to register a buffer that is not a Parameter.
- For example, in BatchNorm, the running_mean is a buffer, so one would
- register it in the constructor of BatchNorm with:
-
- `self.register_buffer('running_mean', torch.zeros(num_features))`
-
- The registered buffers can simply be accessed as class members
- when needed.
-
- **cpu()**
-
- Recursively moves all it's parameters and buffers to the CPU
-
- **cuda(device_id=None)**
- Recursively moves all it's parameters and buffers to the CUDA memory.
- If device_id is given, moves it to GPU number device_id
-
- **float()**
- Typecasts the parameters and buffers to float
-
- **double()**
- Typecasts the parameters and buffers to double
-
- **register_forward_hook(name, hook)**
-
- This will register a user-defined closure on the module.
- Whenever the module finishes it's forward operation,
- the user closure is called.
- The signature of the closure is `def closure(input, output)`
-
- **register_backward_hook(name, hook)**
-
- This will register a user-defined closure on the module.
- Whenever the module finishes it's backward operation,
- the user closure is called.
- The signature of the closure is `def closure(gradOutput, gradInput)`
-
- **remove_forward_hook(name)**
-
- Removes a registered forward hook with the given name
-
- **remove_backward_hook(name)**
-
- Removes a registered backward hook with the given name
-
- **`[generator] parameters()`**
-
- returns a generator over all learnable parameters in the container instance.
- This can typically be passed to the optimizer API
-
- ```python
- # .parameters()
- >>> for param in model.parameters():
- >>> print(type(param.data), param.size())
- <class 'torch.FloatTensor'> (20L,)
- <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
- ```
-
- **`[dict] state_dict()`**
-
- returns a dictionary of learnable parameters of the Module.
- For example: ['weight' : Parameter(torch.FloatTensor(20x1x5x5)),
- 'bias' : Parameter(torch.FloatTensor(20)),
- ]
-
- ```python
- # .state_dict()
- >>> pdict = model.state_dict()
- >>> print(pdict.keys())
- ['bias', 'weight']
- ```
-
- **`load_state_dict(dict)`**
-
- Given a parameter dict, sets the parameters of self to be the given dict.
-
- **`train()`**
-
- Sets the Container to training mode (for modules such as batchnorm, dropout etc.)
-
- **`eval()`**
-
- Sets the Container to evaluate mode (for modules such as batchnorm, dropout etc.)
-
- **`zero_grad()`**
-
- Zeroes the gradients of each Parameter of the module
-
+ Even the Container class derives from it.
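+
+    A minimal sketch of subclassing Module (illustrative; the ``MyLinear``
+    name and the parameter shape are made up for this example)::
+
+        import torch
+        import torch.nn as nn
+
+        class MyLinear(nn.Module):
+            def __init__(self):
+                super(MyLinear, self).__init__()
+                # registered parameters are picked up by parameters() and state_dict()
+                self.register_parameter('weight', nn.Parameter(torch.randn(5, 3)))
+
+            def forward(self, input):
+                return input.mm(self.weight)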
"""
def __init__(self):
self._backend = thnn_backend
@@ -139,12 +29,31 @@
self._parameters[name] = param
def forward(self, *input):
+ """Defines the computation performed at every call.
+
+        Should be overridden by all subclasses.
+ """
raise NotImplementedError
def register_buffer(self, name, tensor):
+ """Adds a persistent buffer to the module.
+
+        This is typically used to register a buffer that should not be
+ considered a model parameter. For example, BatchNorm's ``running_mean``
+ is not a parameter, but is part of the persistent state.
+
+        Buffers can be accessed as attributes using the given names.
+
+ Example:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ """
self._buffers[name] = tensor
def register_parameter(self, name, param):
+ """Adds a parameter to the module.
+
+        The parameter can be accessed as an attribute using the given name.
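+
+        Example (illustrative; the parameter name and shape are arbitrary):
+            self.register_parameter('weight', nn.Parameter(torch.randn(5, 3)))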
+ """
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
@@ -182,36 +91,73 @@
return self
def cuda(self, device_id=None):
+ """Moves all model parameters and buffers to the GPU.
+
+ Arguments:
+ device_id (int, optional): if specified, all parameters will be
+ copied to that device
+ """
return self._apply(lambda t: t.cuda(device_id))
def cpu(self, device_id=None):
+ """Moves all model parameters and buffers to the CPU."""
return self._apply(lambda t: t.cpu())
def type(self, dst_type):
return self._apply(lambda t: t.type(dst_type))
def float(self):
+ """Casts all parameters and buffers to float datatype."""
return self._apply(lambda t: t.float())
def double(self):
+ """Casts all parameters and buffers to double datatype."""
return self._apply(lambda t: t.double())
+ def half(self):
+ """Casts all parameters and buffers to half datatype."""
+ return self._apply(lambda t: t.half())
+
def register_backward_hook(self, name, hook):
+ """Registers a backward hook on the module, under a given name.
+
+        The hook will be called every time the gradient w.r.t. module inputs
+        is computed. The callable should accept three arguments - the module,
+        the gradient w.r.t. the input, and the gradient w.r.t. the output,
+        where the gradients can be tuples if the module had multiple inputs
+        or outputs.
+        The hook should never modify its arguments in place, but it can
+        optionally return a new gradient w.r.t. the input, which will be used
+        in subsequent computation.
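+
+        A minimal sketch (illustrative; the hook name and the body of the
+        hook are arbitrary):
+            >>> def hook(module, grad_input, grad_output):
+            >>>     print(type(module).__name__)
+            >>> model.register_backward_hook('print_module', hook)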
+ """
assert name not in self._backward_hooks, \
"Trying to register a second backward hook with name {}".format(name)
self._backward_hooks[name] = lambda gi, go: hook(self, gi, go)
def remove_backward_hook(self, name):
+ """Removes a backward hook with a given name.
+
+        If no such hook exists, an error is raised.
+ """
assert name in self._backward_hooks, \
"Trying to remove an inexistent backward hook with name {}".format(name)
del self._backward_hooks[name]
def register_forward_hook(self, name, hook):
+ """Registers a forward hook on the module, under a given name.
+
+ The hook will be called every time :func:`forward` computes an output.
+        The callable should accept two arguments - the module's input and its
+        output. Neither should be modified by the hook.
+ """
assert name not in self._forward_hooks, \
"Trying to register a second forward hook with name {}".format(name)
self._forward_hooks[name] = hook
def remove_forward_hook(self, name):
+ """Removes a forward hook with a given name.
+
+        If no such hook exists, an error is raised.
+ """
assert name in self._forward_hooks, \
"Trying to remove an inexistent forward hook with name {}".format(name)
del self._forward_hooks[name]
@@ -256,6 +202,15 @@
object.__delattr__(self, name)
def state_dict(self, destination=None, prefix=''):
+        """Returns a dictionary containing the whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+        included. Keys are the corresponding parameter and buffer names.
+
+ Example:
+ >>> print(module.state_dict().keys())
+ ['bias', 'weight']
+ """
if destination is None:
destination = OrderedDict()
for name, param in chain(self._buffers.items(), self._parameters.items()):
@@ -264,6 +219,15 @@
return destination
def load_state_dict(self, state_dict, prefix=''):
+ """Replaces module parameters using values from a given state_dict.
+
+        This will load all values from the state dict (including those that
+        weren't registered before loading).
+
+ Arguments:
+ state_dict (dict): A dict containing loaded parameters and
+ persistent buffers.
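+
+        A typical usage sketch (illustrative; ``'weights.pt'`` is a made-up
+        file name, assumed to hold a state dict saved with ``torch.save``):
+            >>> module.load_state_dict(torch.load('weights.pt'))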
+ """
for name, param in self._parameters.items():
new_param = state_dict.get(prefix + name, param)
if not isinstance(new_param, Parameter) and new_param is not None:
@@ -275,6 +239,16 @@
self._buffers[name] = state_dict.get(prefix + name, buf)
def parameters(self, memo=None):
+ """Returns an iterator over module parameters.
+
+ This is typically passed to an optimizer.
+
+ Example:
+ >>> for param in model.parameters():
+ >>> print(type(param.data), param.size())
+ <class 'torch.FloatTensor'> (20L,)
+ <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
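+
+        A sketch of the typical optimizer usage (illustrative; assumes
+        ``torch.optim`` has been imported as ``optim``):
+            >>> optimizer = optim.SGD(model.parameters(), lr=0.1)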
+ """
if memo is None:
memo = set()
for p in self._parameters.values():
@@ -283,6 +257,7 @@
yield p
def children(self):
+        """Returns an iterator over child modules."""
if False:
yield
@@ -294,14 +269,23 @@
yield self
def train(self):
+        """Sets the module to training mode.
+
+        This only has an effect on modules such as Dropout or BatchNorm.
+        """
self.training = True
return self
def eval(self):
+        """Sets the module to evaluation mode.
+
+        This only has an effect on modules such as Dropout or BatchNorm.
+        """
self.training = False
return self
def zero_grad(self):
+ """Sets gradients of all model parameters to zero."""
for p in self.parameters():
p.grad.zero_()