# blob: 05247dbcb9970a67a4f30a4da21680f67d0f2a7f [file] [log] [blame]
from collections import OrderedDict
import torch
from torch.autograd import Variable
from .module import Module
class Container(Module):
    """This is the base container class for all neural networks you would define.

    You will subclass your container from this class.
    In the constructor you define the modules that you would want to use,
    and in the __call__ function you use the constructed modules in
    your operations.

    To make it easier to understand, given is a small example.
    ```
    # Example of using Container
    class Net(nn.Container):
        def __init__(self):
            super(Net, self).__init__(
                conv1 = nn.Conv2d(1, 20, 5),
                relu = nn.ReLU()
            )
        def __call__(self, input):
            output = self.relu(self.conv1(input))
            return output
    model = Net()
    ```

    One can also add new modules to a container after construction.
    You can do this with the add_module function.
    ```
    # one can add modules to the container after construction
    model.add_module('pool1', nn.MaxPool2d(2, 2))
    ```

    Registered modules are stored in the ordered dict `self.modules` and can
    be read back as attributes (e.g. `model.conv1`).

    The container has one additional method `parameters()` which
    yields the learnable parameters of the container instance and of every
    module registered in it.
    """

    def __init__(self, **kwargs):
        """Create the container and register every keyword argument as a module.

        Args:
            **kwargs: mapping of attribute name -> module (or None), each
                passed to `add_module`.

        NOTE(review): registration follows the iteration order of `kwargs`,
        which is only guaranteed to match the call site on Python >= 3.6.
        """
        super(Container, self).__init__()
        self.modules = OrderedDict()
        for key, value in kwargs.items():
            self.add_module(key, value)

    def add_module(self, name, module):
        """Register `module` under `name`.

        Args:
            name: attribute name for the module.
            module: a `Module` instance, or None as an empty placeholder.

        Raises:
            KeyError: if the container already has an attribute `name`
                (including a previously registered module).
            ValueError: if `module` is neither None nor a `Module` instance.
        """
        if hasattr(self, name):
            raise KeyError("attribute already exists '{}'".format(name))
        if module is not None and not isinstance(module, Module):
            raise ValueError("{} is not a Module subclass".format(
                torch.typename(module)))
        self.modules[name] = module

    def __getattr__(self, name):
        """Resolve registered modules as attributes, else defer to `Module`."""
        # Look in __dict__ directly: touching `self.modules` before __init__
        # has assigned it would re-enter __getattr__ and recurse forever.
        if 'modules' in self.__dict__:
            modules = self.__dict__['modules']
            if name in modules:
                return modules[name]
        return Module.__getattr__(self, name)

    def parameters(self, memo=None):
        """Yield the parameters of this container and of its modules.

        Args:
            memo: optional set forwarded unchanged to every nested
                `parameters()` call; a fresh set is created when omitted.
                Presumably used by `Module.parameters` to skip parameters
                that were already yielded -- confirm against `Module`.

        Yields:
            First whatever `Module.parameters(memo)` produces for this
            instance, then the parameters of each registered module in
            registration order.
        """
        if memo is None:
            memo = set()
        # The result of the base-class call must be iterated; merely calling
        # it (as before) silently dropped the container's own parameters.
        for p in super(Container, self).parameters(memo):
            yield p
        for module in self.modules.values():
            # add_module() accepts None placeholders; they own no parameters.
            if module is None:
                continue
            for p in module.parameters(memo):
                yield p

    def _apply(self, fn):
        """Apply `fn` to every registered module, then to this container.

        Returns:
            Whatever `Module._apply(fn)` returns for this instance.
        """
        for module in self.modules.values():
            # Skip None placeholders permitted by add_module().
            if module is not None:
                module._apply(fn)
        return super(Container, self)._apply(fn)
class Sequential(Container):
    """A container whose modules are applied one after another.

    Modules are either passed positionally, in which case they are
    registered under the names '0', '1', ..., or as a single `OrderedDict`
    mapping names to modules.
    """

    def __init__(self, *args):
        """Register the given modules in order.

        Args:
            *args: either one `OrderedDict` of name -> module, or any number
                of modules that are named by their position.
        """
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def __getitem__(self, idx):
        """Return the module at position `idx` in registration order.

        Args:
            idx: integer position; negative values count from the end.

        Raises:
            IndexError: if `idx` is outside `[-len, len)`.
        """
        size = len(self.modules)
        if not -size <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        # values() is a list on Python 2 and a view on Python 3; neither has
        # .next(), so materialise it and index directly. This also fixes the
        # previous off-by-one (it skipped idx-1 items instead of idx).
        return list(self.modules.values())[idx]

    def forward(self, input):
        """Feed `input` through every module in order and return the result."""
        for module in self.modules.values():
            input = module(input)
        return input