# blob: 988ee3dc7f182f51170492fdd2762a9e4fdddb74 [file] [log] [blame]
from itertools import chain
from collections import OrderedDict
import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
class Module(object):
    """Base class for all Modules defined in the nn package.

    Even the Container class derives from it. Subclasses implement
    :meth:`forward`. Parameters and buffers registered on the module are
    tracked in ordered dicts so they can be cast, moved, saved and restored
    together.
    """

    def __init__(self):
        # None of these values is a Parameter, so __setattr__ stores each of
        # them as a plain attribute.
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self.training = True

    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def register_buffer(self, name, tensor):
        """Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the persistent state.

        Buffers can be accessed and reassigned as attributes using the given
        names. They are included in :meth:`state_dict` and are converted by
        :meth:`cuda`, :meth:`float` and the other casting methods.

        Example:
            self.register_buffer('running_mean', torch.zeros(num_features))

        Raises:
            AttributeError: if called before ``Module.__init__()``.
        """
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        self._buffers[name] = tensor

    def register_parameter(self, name, param):
        """Adds a parameter to the module.

        The parameter can be accessed as an attribute using given name.

        Arguments:
            name (str): attribute name of the parameter.
            param (Parameter or None): the parameter. ``None`` is stored as
                a placeholder and skipped by :meth:`parameters`,
                :meth:`state_dict` and the casting methods.

        Raises:
            AttributeError: if called before ``Module.__init__()``.
            TypeError: if ``param`` is neither a Parameter nor None.
            ValueError: if ``param`` has a creator (it is not a graph leaf).
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.creator:
            raise ValueError(
                "Cannot assign non-leaf Variable to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another variable, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param

    def _apply(self, fn):
        """Replaces each parameter's data (and gradient, when present) and
        each buffer with ``fn`` applied to it. Returns ``self``.

        ``None`` parameters and buffers are left untouched.
        """
        for param in self._parameters.values():
            if param is None:
                continue
            # Variables stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            # NOTE(review): presence is tested via ``grad`` but the value is
            # read and written via ``_grad``; presumably ``grad`` exposes
            # ``_grad`` -- confirm against Variable.
            if param.grad is not None:
                param._grad = fn(param._grad)
        for key, buf in self._buffers.items():
            if buf is not None:
                # Overwriting an existing key does not resize the dict, so
                # this is safe while iterating.
                self._buffers[key] = fn(buf)
        return self

    def apply(self, fn):
        """Calls ``fn(self)`` and returns ``self``.

        This implementation calls ``fn`` only on this module itself. It does
        not iterate over :meth:`children`.
        """
        fn(self)
        return self

    def cuda(self, device_id=None):
        """Moves all model parameters and buffers to the GPU.

        Arguments:
            device_id (int, optional): if specified, all parameters will be
                copied to that device
        """
        return self._apply(lambda t: t.cuda(device_id))

    def cpu(self, device_id=None):
        """Moves all model parameters and buffers to the CPU.

        ``device_id`` is accepted for signature symmetry with :meth:`cuda`
        and is ignored.
        """
        return self._apply(lambda t: t.cpu())

    def type(self, dst_type):
        """Casts all parameters and buffers to ``dst_type``."""
        return self._apply(lambda t: t.type(dst_type))

    def float(self):
        """Casts all parameters and buffers to float datatype."""
        return self._apply(lambda t: t.float())

    def double(self):
        """Casts all parameters and buffers to double datatype."""
        return self._apply(lambda t: t.double())

    def half(self):
        """Casts all parameters and buffers to half datatype."""
        return self._apply(lambda t: t.half())

    def register_backward_hook(self, name, hook):
        """Registers a backward hook on the module, under a given name.

        The hook will be called every time the gradient w.r.t. module inputs
        is computed. It is called as ``hook(module, grad_input, grad_output)``.
        Both gradient arguments can be tuples if the module had multiple
        inputs or outputs.

        The hook should never modify its arguments in-place, but it can
        optionally return a new gradient w.r.t. the input, that will be used
        in subsequent computation.

        Raises:
            RuntimeError: if a backward hook with this name already exists.
        """
        if name in self._backward_hooks:
            raise RuntimeError(
                "Trying to register a second backward hook with name {}"
                .format(name))
        # The stored callable takes (grad_input, grad_output). The module is
        # bound here and the hook's return value is passed through.
        self._backward_hooks[name] = lambda gi, go: hook(self, gi, go)

    def remove_backward_hook(self, name):
        """Removes a backward hook with a given name.

        Raises:
            RuntimeError: if no such hook exists.
        """
        if name not in self._backward_hooks:
            raise RuntimeError(
                "Trying to remove an inexistent backward hook with name {}"
                .format(name))
        del self._backward_hooks[name]

    def register_forward_hook(self, name, hook):
        """Registers a forward hook on the module, under a given name.

        The hook will be called every time :func:`__call__` runs
        :func:`forward`. It is called as ``hook(module, input, output)``,
        where ``input`` is the tuple of positional arguments. The hook must
        not modify its arguments and must return None.

        Raises:
            RuntimeError: if a forward hook with this name already exists.
        """
        if name in self._forward_hooks:
            raise RuntimeError(
                "Trying to register a second forward hook with name {}"
                .format(name))
        self._forward_hooks[name] = hook

    def remove_forward_hook(self, name):
        """Removes a forward hook with a given name.

        Raises:
            RuntimeError: if no such hook exists.
        """
        if name not in self._forward_hooks:
            raise RuntimeError(
                "Trying to remove an inexistent forward hook with name {}"
                .format(name))
        del self._forward_hooks[name]

    def __call__(self, *input, **kwargs):
        """Runs :meth:`forward`, then every forward hook, and hands this
        module's backward hooks to the creator of the output.

        Raises:
            RuntimeError: if a forward hook returns anything other than None.
        """
        result = self.forward(*input, **kwargs)
        for name, hook in self._forward_hooks.items():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError("forward hooks should never return any "
                                   "values, but '{}' didn't return None".format(name))
        var = self._first_variable(result)
        if var is not None:
            creator = var.creator
            if creator is not None:
                # The dict is shared by reference, so hooks registered later
                # are visible to the creator too.
                creator._backward_hooks = self._backward_hooks
        return result

    @staticmethod
    def _first_variable(output):
        """Returns the first Variable reached by repeatedly taking element 0
        of nested lists/tuples in ``output``, or None if there is none."""
        while not isinstance(output, Variable):
            # Only descend into non-empty lists/tuples. Blindly indexing
            # crashes on scalars/None and never terminates on strings.
            if not isinstance(output, (list, tuple)) or not output:
                return None
            output = output[0]
        return output

    def __getattr__(self, name):
        """Looks up registered parameters, then buffers. Python calls this
        only after normal attribute lookup has failed."""
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        # Raises the standard AttributeError.
        return object.__getattribute__(self, name)

    def __setattr__(self, name, value):
        """Routes attribute assignment to the right place.

        Parameters, and names already registered as parameters, go through
        :meth:`register_parameter`. Names registered as buffers update
        ``_buffers``. Everything else is stored as a plain attribute.
        """
        params = self.__dict__.get('_parameters')
        buffers = self.__dict__.get('_buffers')
        if isinstance(value, Parameter) or (params and name in params):
            self.register_parameter(name, value)
        elif buffers is not None and name in buffers:
            # A plain attribute would shadow the buffer (__getattr__ is never
            # reached), and state_dict()/_apply() would keep the stale value.
            buffers[name] = value
        else:
            object.__setattr__(self, name, value)

    def __delattr__(self, name):
        """Deletes a registered parameter or buffer, or a plain attribute."""
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        else:
            object.__delattr__(self, name)

    def state_dict(self, destination=None, prefix=''):
        """Returns a dictionary containing a whole state of the module.

        Both parameters and persistent buffers (e.g. running averages) are
        included; ``None`` entries are skipped. Keys are ``prefix`` plus the
        parameter or buffer name. Values are the stored objects themselves,
        not copies.

        Example:
            >>> print(module.state_dict().keys())
            ['bias', 'weight']
        """
        if destination is None:
            destination = OrderedDict()
        for name, param in chain(self._buffers.items(), self._parameters.items()):
            if param is not None:
                destination[prefix + name] = param
        return destination

    def load_state_dict(self, state_dict, prefix=''):
        """Replaces module parameters and buffers using values from a given
        state_dict.

        Only names already registered on this module are looked up, as
        ``prefix + name``. Registered names missing from ``state_dict`` keep
        their current value, and extra keys in ``state_dict`` are ignored.
        Parameter entries are validated before anything is replaced, so a
        failed load leaves the module unchanged.

        Arguments:
            state_dict (dict): A dict containing loaded parameters and
                persistent buffers.
            prefix (str, optional): prepended to every name when looking it
                up in ``state_dict``.

        Raises:
            TypeError: if a parameter entry is neither a Parameter nor None.
        """
        new_params = []
        for name, param in self._parameters.items():
            key = prefix + name
            new_param = state_dict.get(key, param)
            if new_param is not None and not isinstance(new_param, Parameter):
                raise TypeError(
                    "expected torch.nn.Parameter for key '{}' (got {})"
                    .format(key, torch.typename(new_param)))
            new_params.append((name, new_param))
        for name, new_param in new_params:
            self._parameters[name] = new_param
        for name, buf in self._buffers.items():
            self._buffers[name] = state_dict.get(prefix + name, buf)

    def parameters(self, memo=None):
        """Returns an iterator over module parameters.

        This is typically passed to an optimizer. Parameters already in
        ``memo`` (a set shared across calls) and ``None`` entries are
        skipped. Each yielded parameter is added to ``memo``.

        Example:
            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
        """
        if memo is None:
            memo = set()
        for p in self._parameters.values():
            if p is not None and p not in memo:
                memo.add(p)
                yield p

    def children(self):
        """Returns an iterator over children modules.

        The base module has no children, so this yields nothing.
        """
        return
        yield  # unreachable; makes this a generator function

    def modules(self, memo=None):
        """Returns an iterator over all modules.

        The base implementation yields only ``self``, unless it is already
        in ``memo``. It is added to ``memo`` when yielded.
        """
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield self

    def train(self):
        """Sets the module in training mode (``self.training = True``).

        This has any effect only on modules such as Dropout or BatchNorm.
        """
        self.training = True
        return self

    def eval(self):
        """Sets the module in evaluation mode (``self.training = False``).

        This has any effect only on modules such as Dropout or BatchNorm.
        """
        self.training = False
        return self

    def zero_grad(self):
        """Sets gradients of all model parameters to zero.

        Parameters whose ``grad`` is None are skipped.
        """
        for p in self.parameters():
            # grad can be None (_apply guards the same case).
            if p.grad is not None:
                p.grad.zero_()

    def share_memory(self):
        """Calls ``share_memory_()`` on every parameter's data and gradient
        and on every buffer, storing the results back. Returns ``self``."""
        return self._apply(lambda t: t.share_memory_())