from collections import OrderedDict
from itertools import islice
import operator
import string
import warnings

import torch

from .module import Module
| |
| |
class Container(Module):
    """Deprecated container that registers keyword arguments as submodules.

    Each ``name=module`` keyword passed to the constructor is registered as a
    child module under ``name``. ``nn.Module`` provides all of this already,
    so new code should subclass ``nn.Module`` directly.
    """

    def __init__(self, **kwargs):
        super(Container, self).__init__()
        # Emitted with the default warning category on purpose: a
        # DeprecationWarning is hidden by Python's default filters, and this
        # notice is meant to be seen.
        deprecation_notice = ("nn.Container is deprecated. All of it's functionality "
                              "is now implemented in nn.Module. Subclass that instead.")
        warnings.warn(deprecation_notice)
        for child_name in kwargs:
            self.add_module(child_name, kwargs[child_name])
| |
| |
class Sequential(Module):
    """A sequential container.

    Modules are added to it in the order they are passed in the constructor.
    Alternatively, a single ``OrderedDict`` of modules can be passed in, in
    which case its keys become the submodule names. Positional modules are
    named ``"0"``, ``"1"``, ... in order.

    Calling the container feeds the input through every submodule in
    registration order and returns the output of the last one. Integer
    indexing (negative indices included) returns a submodule; slicing returns
    a new ``Sequential`` holding the selected submodules under their original
    names.

    To make it easier to understand, here is a small example::

        # Example of using Sequential
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # Keep the caller-chosen names.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx):
        """Return the ``idx``-th item of ``iterator`` (negative indices allowed).

        Raises:
            TypeError: if ``idx`` is not an integer-like value.
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # operator.index rejects floats etc. up front with a clear TypeError
        # instead of letting them slip past the range check.
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        idx %= size  # size > 0 here, so this maps negatives to 0..size-1
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx):
        """Return the submodule at integer ``idx``, or a new Sequential for a slice."""
        if isinstance(idx, slice):
            return Sequential(OrderedDict(list(self._modules.items())[idx]))
        return self._get_item_by_idx(self._modules.values(), idx)

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        # Without this, iteration falls back to repeated __getitem__ calls,
        # each of which walks the registry from the start (O(n^2) overall).
        return iter(self._modules.values())

    def forward(self, input):
        """Apply each submodule to the running value in registration order."""
        for module in self._modules.values():
            input = module(input)
        return input
| |
| |
class ModuleList(Module):
    """Holds submodules in a list.

    ModuleList can be indexed like a regular Python list, but modules it contains
    are properly registered, and will be visible by all Module methods.
    Submodules are registered under the names ``"0"``, ``"1"``, ... in order.

    Arguments:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    def __init__(self, modules=None):
        super(ModuleList, self).__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Map a possibly-negative integer index to its registration name.

        Raises:
            TypeError: if ``idx`` is not an integer-like value.
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx):
        """Return the module at integer ``idx``, or a new ModuleList for a slice."""
        if isinstance(idx, slice):
            return ModuleList(list(self._modules.values())[idx])
        return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx, module):
        """Replace the module at an existing (possibly negative) index.

        The index is normalised and bounds-checked so the registry keys stay
        exactly ``"0".."len-1"``; otherwise ``lst[-1] = m`` would register a
        new child named ``"-1"`` and break ``len``/indexing.
        """
        return setattr(self, self._get_abs_string_index(idx), module)

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules.values())

    def __iadd__(self, modules):
        return self.extend(modules)

    def append(self, module):
        """Appends a given module at the end of the list.

        Arguments:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules):
        """Appends modules from a Python iterable at the end.

        Arguments:
            modules (iterable): iterable of modules to append

        Raises:
            TypeError: if ``modules`` is not iterable.
        """
        try:
            iterator = iter(modules)
        except TypeError:
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        # Materialise first so extending with ``self`` (or a view of it) does
        # not mutate the registry while it is being iterated.
        new_modules = list(iterator)
        offset = len(self)
        for i, module in enumerate(new_modules):
            self.add_module(str(offset + i), module)
        return self
| |
| |
class ParameterList(Module):
    """Holds parameters in a list.

    ParameterList can be indexed like a regular Python list, but parameters it contains
    are properly registered, and will be visible by all Module methods.
    Parameters are registered under the names ``"0"``, ``"1"``, ... in order.

    Arguments:
        parameters (list, optional): a list of :class:`nn.Parameter` to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, parameters=None):
        super(ParameterList, self).__init__()
        if parameters is not None:
            self += parameters

    def _get_abs_string_index(self, idx):
        """Map a possibly-negative integer index to its registration name.

        Raises:
            TypeError: if ``idx`` is not an integer-like value.
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx):
        """Return the parameter at integer ``idx``, or a new ParameterList for a slice."""
        if isinstance(idx, slice):
            return ParameterList(list(self._parameters.values())[idx])
        return self._parameters[self._get_abs_string_index(idx)]

    def __setitem__(self, idx, param):
        """Replace the parameter at an existing (possibly negative) index.

        The index is normalised and bounds-checked so the registry keys stay
        exactly ``"0".."len-1"``; otherwise ``lst[-1] = p`` would register a
        new parameter named ``"-1"`` and break ``len``/indexing.
        """
        return self.register_parameter(self._get_abs_string_index(idx), param)

    def __len__(self):
        return len(self._parameters)

    def __iter__(self):
        return iter(self._parameters.values())

    def __iadd__(self, parameters):
        return self.extend(parameters)

    def append(self, parameter):
        """Appends a given parameter at the end of the list.

        Arguments:
            parameter (nn.Parameter): parameter to append
        """
        self.register_parameter(str(len(self)), parameter)
        return self

    def extend(self, parameters):
        """Appends parameters from a Python list at the end.

        Arguments:
            parameters (list): list of parameters to append

        Raises:
            TypeError: if ``parameters`` is not a list.
        """
        # Deliberately list-only: tensors are themselves iterable, so
        # accepting any iterable would let ``extend(p)`` walk a single
        # parameter's rows instead of rejecting it up front.
        if not isinstance(parameters, list):
            raise TypeError("ParameterList.extend should be called with a "
                            "list, but got " + type(parameters).__name__)
        offset = len(self)
        for i, param in enumerate(parameters):
            self.register_parameter(str(offset + i), param)
        return self