from itertools import chain
from collections import OrderedDict
import functools
import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
import torch.utils.hooks as hooks
def _addindent(s_, numSpaces):
s = s_.split('\n')
    # don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
class Module(object):
"""Base class for all neural network modules.
    Your models should subclass this class.
    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call .cuda(), etc.
"""
dump_patches = False
def __init__(self):
self._backend = thnn_backend
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._modules = OrderedDict()
self.training = True
for name, param in self._parameters.items():
if not isinstance(param, Parameter):
if isinstance(param, Variable):
raise TypeError("can't use a Variable as a module "
"parameter. Convert it to torch.nn.Parameter first.")
if param is not None:
param = Parameter(param)
self._parameters[name] = param
def forward(self, *input):
"""Defines the computation performed at every call.
        Should be overridden by all subclasses.
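        Note: you should usually call the :class:`Module` instance itself
        rather than this method directly, since ``__call__`` also takes
        care of running any registered hooks.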
"""
raise NotImplementedError
def register_buffer(self, name, tensor):
"""Adds a persistent buffer to the module.
        This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the persistent state.
        Buffers can be accessed as attributes using the given names.
Example:
>>> self.register_buffer('running_mean', torch.zeros(num_features))
"""
self._buffers[name] = tensor
def register_parameter(self, name, param):
"""Adds a parameter to the module.
        The parameter can be accessed as an attribute using the given name.
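        Example (a minimal sketch; assumes ``nn`` is ``torch.nn``, and the
        name and size below are arbitrary):
            >>> self.register_parameter('bias', nn.Parameter(torch.zeros(10)))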
"""
if '_parameters' not in self.__dict__:
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or None required)"
.format(torch.typename(param), name))
elif param.creator:
raise ValueError(
"Cannot assign non-leaf Variable to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another variable, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
def add_module(self, name, module):
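        """Adds a child module to the current module.
        The module can be accessed as an attribute using the given name.
        """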
if hasattr(self, name):
raise KeyError("attribute already exists '{}'".format(name))
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
torch.typename(module)))
self._modules[name] = module
def _apply(self, fn):
for module in self.children():
module._apply(fn)
for param in self._parameters.values():
if param is not None:
# Variables stored in modules are graph leaves, and we don't
# want to create copy nodes, so we have to unpack the data.
param.data = fn(param.data)
if param.grad is not None:
param._grad.data = fn(param._grad.data)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
def apply(self, fn):
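        """Applies ``fn`` recursively to every submodule (as returned by
        :func:`children`) as well as ``self``. Typical use includes
        initializing the parameters of a model.
        Example (a minimal sketch; ``net`` and ``init_weights`` are
        illustrative names, not part of this module):
            >>> def init_weights(m):
            >>>     if isinstance(m, nn.Linear):
            >>>         m.weight.data.fill_(0.01)
            >>> net.apply(init_weights)
        """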
for module in self.children():
module.apply(fn)
fn(self)
return self
def cuda(self, device_id=None):
"""Moves all model parameters and buffers to the GPU.
Arguments:
device_id (int, optional): if specified, all parameters will be
copied to that device
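        Example (a minimal sketch; ``net`` is assumed to be an existing
        module):
            >>> net.cuda()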
"""
return self._apply(lambda t: t.cuda(device_id))
def cpu(self, device_id=None):
"""Moves all model parameters and buffers to the CPU."""
return self._apply(lambda t: t.cpu())
def type(self, dst_type):
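        """Casts all parameters and buffers to :attr:`dst_type`."""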
return self._apply(lambda t: t.type(dst_type))
def float(self):
"""Casts all parameters and buffers to float datatype."""
return self._apply(lambda t: t.float())
def double(self):
"""Casts all parameters and buffers to double datatype."""
return self._apply(lambda t: t.double())
def half(self):
"""Casts all parameters and buffers to half datatype."""
return self._apply(lambda t: t.half())
def register_backward_hook(self, hook):
"""Registers a backward hook on the module.
The hook will be called every time the gradients with respect to module
inputs are computed. The hook should have the following signature::
hook(module, grad_input, grad_output) -> Tensor or None
The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
module has multiple inputs or outputs. The hook should not modify its
arguments, but it can optionally return a new gradient with respect to
input that will be used in place of :attr:`grad_input` in subsequent
computations.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
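        Example (a minimal sketch; ``model`` and the hook below are
        illustrative):
            >>> def print_grads(module, grad_input, grad_output):
            >>>     print(grad_output)
            >>> handle = model.register_backward_hook(print_grads)
            >>> # ... run forward and backward ...
            >>> handle.remove()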
"""
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[id(handle)] = hook
return handle
def register_forward_hook(self, hook):
"""Registers a forward hook on the module.
The hook will be called every time :func:`forward` computes an output.
It should have the following signature::
hook(module, input, output) -> None
The hook should not modify the input or output.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
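        Example (a minimal sketch; assumes ``model`` exists and its
        :func:`forward` returns a single Variable):
            >>> def print_output_size(module, input, output):
            >>>     print(output.size())
            >>> handle = model.register_forward_hook(print_output_size)
            >>> handle.remove()  # when the hook is no longer needed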
"""
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[id(handle)] = hook
return handle
def __call__(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
var = result
while not isinstance(var, Variable):
var = var[0]
creator = var.creator
if creator is not None and len(self._backward_hooks) > 0:
if creator._backward_hooks is None:
creator._backward_hooks = OrderedDict()
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
creator._backward_hooks[id(wrapper)] = wrapper
return result
def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
return object.__getattribute__(self, name)
def __setattr__(self, name, value):
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
else:
object.__setattr__(self, name, value)
def __delattr__(self, name):
if name in self._parameters:
del self._parameters[name]
elif name in self._modules:
del self._modules[name]
else:
object.__delattr__(self, name)
def state_dict(self, destination=None, prefix=''):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Example:
>>> module.state_dict().keys()
['bias', 'weight']
"""
if destination is None:
destination = OrderedDict()
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = param.data
for name, buf in self._buffers.items():
if buf is not None:
destination[prefix + name] = buf
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.')
return destination
def load_state_dict(self, state_dict):
"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. The keys of :attr:`state_dict` must
exactly match the keys returned by this module's :func:`state_dict()`
        function.
Arguments:
state_dict (dict): A dict containing parameters and
persistent buffers.
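        Example (a minimal sketch; ``model`` and ``'weights.pth'`` are
        placeholders for an existing module and a file produced with
        :func:`torch.save`):
            >>> model.load_state_dict(torch.load('weights.pth'))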
"""
own_state = self.state_dict()
for name, param in state_dict.items():
if name not in own_state:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if isinstance(param, Parameter):
# backwards compatibility for serialized parameters
param = param.data
own_state[name].copy_(param)
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
def parameters(self, memo=None):
"""Returns an iterator over module parameters.
This is typically passed to an optimizer.
Example:
>>> for param in model.parameters():
>>> print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
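        Example (a minimal sketch; assumes ``optim`` is ``torch.optim`` and
        ``model`` is an existing module):
            >>> optimizer = optim.SGD(model.parameters(), lr=0.01)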
"""
if memo is None:
memo = set()
for p in self._parameters.values():
if p is not None and p not in memo:
memo.add(p)
yield p
for module in self.children():
for p in module.parameters(memo):
yield p
def children(self):
"""Returns an iterator over children modules."""
memo = set()
for module in self._modules.values():
if module is not None and module not in memo:
memo.add(module)
yield module
def modules(self, memo=None):
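        """Returns an iterator over all modules in the network, yielding
        ``self`` first and then every descendant module. Each module is
        yielded only once, even if it is reachable through several paths.
        """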
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield self
for module in self.children():
for m in module.modules(memo):
yield m
def train(self):
"""Sets the module in training mode.
        This has an effect only on modules such as Dropout or BatchNorm.
"""
self.training = True
for module in self.children():
module.train()
return self
def eval(self):
"""Sets the module in evaluation mode.
        This has an effect only on modules such as Dropout or BatchNorm.
"""
self.training = False
for module in self.children():
module.eval()
return self
def zero_grad(self):
"""Sets gradients of all model parameters to zero."""
for p in self.parameters():
p.grad.data.zero_()
def share_memory(self):
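        """Moves the underlying storage of all parameters, gradients and
        buffers to shared memory (via ``share_memory_()``), so the module
        can be shared between processes.
        """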
return self._apply(lambda t: t.share_memory_())
def __repr__(self):
tmpstr = self.__class__.__name__ + ' (\n'
for key, module in self._modules.items():
modstr = module.__repr__()
modstr = _addindent(modstr, 2)
tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n'
tmpstr = tmpstr + ')'
return tmpstr