| from collections import OrderedDict |
| import functools |
| import itertools |
| |
| import torch |
| from ..backends.thnn import backend as thnn_backend |
| from ..parameter import Parameter |
| import torch.utils.hooks as hooks |
| |
| |
| def _addindent(s_, numSpaces): |
| s = s_.split('\n') |
| # don't do anything for single-line stuff |
| if len(s) == 1: |
| return s_ |
| first = s.pop(0) |
| s = [(numSpaces * ' ') + line for line in s] |
| s = '\n'.join(s) |
| s = first + '\n' + s |
| return s |
| |
| |
def _if_float_tensor(fn):
    r"""Wraps ``fn`` so that it is applied only to floating point tensors.

    Any other tensor is returned unchanged; non-tensor values are passed
    through to ``fn``.
    """
    def apply(t):
        if not isinstance(t, torch.Tensor) or t.is_floating_point():
            return fn(t)
        return t
    return apply
| |
| |
| class Module(object): |
| r"""Base class for all neural network modules. |
| |
| Your models should also subclass this class. |
| |
    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| class Model(nn.Module): |
| def __init__(self): |
| super(Model, self).__init__() |
| self.conv1 = nn.Conv2d(1, 20, 5) |
| self.conv2 = nn.Conv2d(20, 20, 5) |
| |
| def forward(self, x): |
| x = F.relu(self.conv1(x)) |
| return F.relu(self.conv2(x)) |
| |
| Submodules assigned in this way will be registered, and will have their |
| parameters converted too when you call :meth:`to`, etc. |
| """ |
| |
| dump_patches = False |
| |
| r"""This allows better BC support for :meth:`load_state_dict`. In |
    :meth:`state_dict`, the version number will be saved in the attribute
| `_metadata` of the returned state dict, and thus pickled. `_metadata` is a |
| dictionary with keys that follow the naming convention of state dict. See |
| ``_load_from_state_dict`` on how to use this information in loading. |
| |
| If new parameters/buffers are added/removed from a module, this number shall |
| be bumped, and the module's `_load_from_state_dict` method can compare the |
| version number and do appropriate changes if the state dict is from before |
| the change.""" |
| _version = 1 |
| |
| def __init__(self): |
| self._backend = thnn_backend |
| self._parameters = OrderedDict() |
| self._buffers = OrderedDict() |
| self._backward_hooks = OrderedDict() |
| self._forward_hooks = OrderedDict() |
| self._forward_pre_hooks = OrderedDict() |
| self._state_dict_hooks = OrderedDict() |
| self._load_state_dict_pre_hooks = OrderedDict() |
| self._modules = OrderedDict() |
| self.training = True |
| |
| def forward(self, *input): |
| r"""Defines the computation performed at every call. |
| |
| Should be overridden by all subclasses. |
| |
| .. note:: |
            Although the recipe for the forward pass needs to be defined within
| this function, one should call the :class:`Module` instance afterwards |
| instead of this since the former takes care of running the |
| registered hooks while the latter silently ignores them. |
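
        For example, with a hypothetical module instance ``model`` and an input
        tensor ``x``::

            >>> output = model(x)            # runs registered hooks, then forward()
            >>> # output = model.forward(x)  # would silently skip the hooks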
| """ |
| raise NotImplementedError |
| |
| def register_buffer(self, name, tensor): |
| r"""Adds a persistent buffer to the module. |
| |
        This is typically used to register a buffer that should not be
| considered a model parameter. For example, BatchNorm's ``running_mean`` |
| is not a parameter, but is part of the persistent state. |
| |
| Buffers can be accessed as attributes using given names. |
| |
| Args: |
| name (string): name of the buffer. The buffer can be accessed |
| from this module using the given name |
| tensor (Tensor): buffer to be registered. |
| |
| Example:: |
| |
| >>> self.register_buffer('running_mean', torch.zeros(num_features)) |
| |
| """ |
| if not isinstance(name, torch._six.string_classes): |
| raise TypeError("buffer name should be a string. " |
| "Got {}".format(torch.typename(name))) |
| elif '.' in name: |
| raise KeyError("buffer name can't contain \".\"") |
| elif name == '': |
| raise KeyError("buffer name can't be empty string \"\"") |
| elif hasattr(self, name) and name not in self._buffers: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| elif tensor is not None and not isinstance(tensor, torch.Tensor): |
| raise TypeError("cannot assign '{}' object to buffer '{}' " |
| "(torch Tensor or None required)" |
| .format(torch.typename(tensor), name)) |
| else: |
| self._buffers[name] = tensor |
| |
| def register_parameter(self, name, param): |
| r"""Adds a parameter to the module. |
| |
        The parameter can be accessed as an attribute using the given name.
| |
| Args: |
| name (string): name of the parameter. The parameter can be accessed |
| from this module using the given name |
            param (Parameter): parameter to be added to the module.
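
        Example::

            >>> # A hypothetical sketch, inside a module's ``__init__``; the
            >>> # name and shape below are illustrative.
            >>> self.register_parameter('weight', nn.Parameter(torch.randn(20, 20)))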
| """ |
| if '_parameters' not in self.__dict__: |
| raise AttributeError( |
| "cannot assign parameter before Module.__init__() call") |
| |
| elif not isinstance(name, torch._six.string_classes): |
| raise TypeError("parameter name should be a string. " |
| "Got {}".format(torch.typename(name))) |
| elif '.' in name: |
| raise KeyError("parameter name can't contain \".\"") |
| elif name == '': |
| raise KeyError("parameter name can't be empty string \"\"") |
| elif hasattr(self, name) and name not in self._parameters: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| |
| if param is None: |
| self._parameters[name] = None |
| elif not isinstance(param, Parameter): |
| raise TypeError("cannot assign '{}' object to parameter '{}' " |
| "(torch.nn.Parameter or None required)" |
| .format(torch.typename(param), name)) |
| elif param.grad_fn: |
| raise ValueError( |
| "Cannot assign non-leaf Tensor to parameter '{0}'. Model " |
| "parameters must be created explicitly. To express '{0}' " |
| "as a function of another Tensor, compute the value in " |
| "the forward() method.".format(name)) |
| else: |
| self._parameters[name] = param |
| |
| def add_module(self, name, module): |
| r"""Adds a child module to the current module. |
| |
| The module can be accessed as an attribute using the given name. |
| |
| Args: |
| name (string): name of the child module. The child module can be |
| accessed from this module using the given name |
            module (Module): child module to be added to the module.
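
        Example::

            >>> # A hypothetical sketch, inside a module's ``__init__``; this is
            >>> # equivalent to ``self.conv1 = nn.Conv2d(1, 20, 5)``.
            >>> self.add_module('conv1', nn.Conv2d(1, 20, 5))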
| """ |
| if not isinstance(module, Module) and module is not None: |
| raise TypeError("{} is not a Module subclass".format( |
| torch.typename(module))) |
| elif not isinstance(name, torch._six.string_classes): |
| raise TypeError("module name should be a string. Got {}".format( |
| torch.typename(name))) |
| elif hasattr(self, name) and name not in self._modules: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| elif '.' in name: |
| raise KeyError("module name can't contain \".\"") |
| elif name == '': |
| raise KeyError("module name can't be empty string \"\"") |
| self._modules[name] = module |
| |
| def _apply(self, fn): |
| for module in self.children(): |
            module._apply(fn)
| |
| for param in self._parameters.values(): |
| if param is not None: |
| # Tensors stored in modules are graph leaves, and we don't |
| # want to create copy nodes, so we have to unpack the data. |
| param.data = fn(param.data) |
| if param._grad is not None: |
| param._grad.data = fn(param._grad.data) |
| |
| for key, buf in self._buffers.items(): |
| if buf is not None: |
| self._buffers[key] = fn(buf) |
| |
| return self |
| |
| def apply(self, fn): |
| r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``) |
| as well as self. Typical use includes initializing the parameters of a model |
| (see also :ref:`torch-nn-init`). |
| |
| Args: |
| fn (:class:`Module` -> None): function to be applied to each submodule |
| |
| Returns: |
| Module: self |
| |
| Example:: |
| |
| >>> def init_weights(m): |
| print(m) |
| if type(m) == nn.Linear: |
| m.weight.data.fill_(1.0) |
| print(m.weight) |
| |
| >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) |
| >>> net.apply(init_weights) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[ 1., 1.], |
| [ 1., 1.]]) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[ 1., 1.], |
| [ 1., 1.]]) |
| Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| """ |
| for module in self.children(): |
| module.apply(fn) |
| fn(self) |
| return self |
| |
| def cuda(self, device=None): |
| r"""Moves all model parameters and buffers to the GPU. |
| |
        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on GPU while being optimized.
| |
| Arguments: |
| device (int, optional): if specified, all parameters will be |
| copied to that device |
| |
| Returns: |
| Module: self |
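
        Example::

            >>> # A sketch assuming a CUDA device is available; ``net`` is a
            >>> # hypothetical module. Move it first, then build the optimizer.
            >>> net = net.cuda()
            >>> optimizer = torch.optim.SGD(net.parameters(), lr=0.1)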
| """ |
| return self._apply(lambda t: t.cuda(device)) |
| |
| def cpu(self): |
| r"""Moves all model parameters and buffers to the CPU. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.cpu()) |
| |
| def type(self, dst_type): |
| r"""Casts all parameters and buffers to :attr:`dst_type`. |
| |
| Arguments: |
| dst_type (type or string): the desired type |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.type(dst_type)) |
| |
    def float(self):
        r"""Casts all floating point parameters and buffers to ``float`` datatype.
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(_if_float_tensor(lambda t: t.float())) |
| |
| def double(self): |
| r"""Casts all floating point parameters and buffers to ``double`` datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(_if_float_tensor(lambda t: t.double())) |
| |
| def half(self): |
| r"""Casts all floating point parameters and buffers to ``half`` datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(_if_float_tensor(lambda t: t.half())) |
| |
| def to(self, *args, **kwargs): |
| r"""Moves and/or casts the parameters and buffers. |
| |
| This can be called as |
| |
| .. function:: to(device=None, dtype=None, non_blocking=False) |
| |
| .. function:: to(dtype, non_blocking=False) |
| |
| .. function:: to(tensor, non_blocking=False) |
| |
        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
| :attr:`non_blocking` is set, it tries to convert/move asynchronously |
| with respect to the host if possible, e.g., moving CPU Tensors with |
| pinned memory to CUDA devices. |
| |
| See below for examples. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Args: |
| device (:class:`torch.device`): the desired device of the parameters |
| and buffers in this module |
| dtype (:class:`torch.dtype`): the desired floating point type of |
| the floating point parameters and buffers in this module |
| tensor (torch.Tensor): Tensor whose dtype and device are the desired |
| dtype and device for all parameters and buffers in this module |
| |
| Returns: |
| Module: self |
| |
| Example:: |
| |
| >>> linear = nn.Linear(2, 2) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]]) |
| >>> linear.to(torch.double) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]], dtype=torch.float64) |
| >>> gpu1 = torch.device("cuda:1") |
| >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') |
| >>> cpu = torch.device("cpu") |
| >>> linear.to(cpu) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16) |
| |
| """ |
| |
| device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) |
| |
| if dtype is not None: |
| if not dtype.is_floating_point: |
| raise TypeError('nn.Module.to only accepts floating point ' |
| 'dtypes, but got desired dtype={}'.format(dtype)) |
| |
| def convert(t): |
| if isinstance(t, torch.Tensor): |
| return t.to(device, dtype if t.is_floating_point() else None, non_blocking) |
| return t.to(device, dtype, non_blocking) |
| |
| return self._apply(convert) |
| |
| def register_backward_hook(self, hook): |
| r"""Registers a backward hook on the module. |
| |
| The hook will be called every time the gradients with respect to module |
| inputs are computed. The hook should have the following signature:: |
| |
| hook(module, grad_input, grad_output) -> Tensor or None |
| |
| The :attr:`grad_input` and :attr:`grad_output` may be tuples if the |
| module has multiple inputs or outputs. The hook should not modify its |
| arguments, but it can optionally return a new gradient with respect to |
| input that will be used in place of :attr:`grad_input` in subsequent |
| computations. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| .. warning :: |
| |
            The current implementation will not have the presented behavior
            for a complex :class:`Module` that performs many operations.
            In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
            contain the gradients for a subset of the inputs and outputs.
            For such a :class:`Module`, you should use :func:`torch.Tensor.register_hook`
            directly on a specific input or output to get the required gradients.
| |
| """ |
| handle = hooks.RemovableHandle(self._backward_hooks) |
| self._backward_hooks[handle.id] = hook |
| return handle |
| |
| def register_forward_pre_hook(self, hook): |
| r"""Registers a forward pre-hook on the module. |
| |
| The hook will be called every time before :func:`forward` is invoked. |
| It should have the following signature:: |
| |
| hook(module, input) -> None |
| |
| The hook should not modify the input. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
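
        Example::

            >>> # A minimal sketch: log the shape of the first positional input
            >>> # (assumes it is a tensor); ``net`` is a hypothetical module.
            >>> def pre_hook(module, input):
            >>>     print(module.__class__.__name__, input[0].shape)
            >>> handle = net.register_forward_pre_hook(pre_hook)
            >>> # ... run some forward passes, then ...
            >>> handle.remove()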
| """ |
| handle = hooks.RemovableHandle(self._forward_pre_hooks) |
| self._forward_pre_hooks[handle.id] = hook |
| return handle |
| |
| def register_forward_hook(self, hook): |
| r"""Registers a forward hook on the module. |
| |
| The hook will be called every time after :func:`forward` has computed an output. |
| It should have the following signature:: |
| |
| hook(module, input, output) -> None |
| |
| The hook should not modify the input or output. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
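
        Example::

            >>> # A minimal sketch: stash each output for later inspection
            >>> # (assumes the module returns a single tensor); ``net`` is a
            >>> # hypothetical module.
            >>> activations = []
            >>> def save_output(module, input, output):
            >>>     activations.append(output.detach())
            >>> handle = net.register_forward_hook(save_output)
            >>> # ... run some forward passes, then ...
            >>> handle.remove()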
| """ |
| handle = hooks.RemovableHandle(self._forward_hooks) |
| self._forward_hooks[handle.id] = hook |
| return handle |
| |
| def _tracing_name(self, tracing_state): |
| if not tracing_state._traced_module_stack: |
| return None |
| module = tracing_state._traced_module_stack[-1] |
| for name, child in module.named_children(): |
| if child is self: |
| return name |
| return None |
| |
| def _slow_forward(self, *input, **kwargs): |
| input_vars = tuple(torch.autograd.function._iter_tensors(input)) |
| tracing_state = torch._C._get_tracing_state() |
| if not tracing_state: |
| return self.forward(*input, **kwargs) |
| if not hasattr(tracing_state, '_traced_module_stack'): |
| tracing_state._traced_module_stack = [] |
| name = self._tracing_name(tracing_state) |
| if name: |
| tracing_state.push_scope('%s[%s]' % (self._get_name(), name)) |
| else: |
| tracing_state.push_scope(self._get_name()) |
| tracing_state._traced_module_stack.append(self) |
| try: |
| result = self.forward(*input, **kwargs) |
| finally: |
| tracing_state.pop_scope() |
| tracing_state._traced_module_stack.pop() |
| return result |
| |
| def __call__(self, *input, **kwargs): |
| for hook in self._forward_pre_hooks.values(): |
| hook(self, input) |
| if torch._C._get_tracing_state(): |
| result = self._slow_forward(*input, **kwargs) |
| else: |
| result = self.forward(*input, **kwargs) |
| for hook in self._forward_hooks.values(): |
| hook_result = hook(self, input, result) |
| if hook_result is not None: |
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}' "
                    "didn't return None".format(hook))
| if len(self._backward_hooks) > 0: |
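            # Find a tensor in the (possibly nested) output and attach the
            # registered backward hooks to its grad_fn, so that they run when
            # gradients with respect to the module's inputs are computed.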
| var = result |
| while not isinstance(var, torch.Tensor): |
| if isinstance(var, dict): |
| var = next((v for v in var.values() if isinstance(v, torch.Tensor))) |
| else: |
| var = var[0] |
| grad_fn = var.grad_fn |
| if grad_fn is not None: |
| for hook in self._backward_hooks.values(): |
| wrapper = functools.partial(hook, self) |
| functools.update_wrapper(wrapper, hook) |
| grad_fn.register_hook(wrapper) |
| return result |
| |
| def __setstate__(self, state): |
| self.__dict__.update(state) |
| # Support loading old checkpoints that don't have the following attrs: |
| if '_forward_pre_hooks' not in self.__dict__: |
| self._forward_pre_hooks = OrderedDict() |
| if '_state_dict_hooks' not in self.__dict__: |
| self._state_dict_hooks = OrderedDict() |
| if '_load_state_dict_pre_hooks' not in self.__dict__: |
| self._load_state_dict_pre_hooks = OrderedDict() |
| |
| def __getattr__(self, name): |
| if '_parameters' in self.__dict__: |
| _parameters = self.__dict__['_parameters'] |
| if name in _parameters: |
| return _parameters[name] |
| if '_buffers' in self.__dict__: |
| _buffers = self.__dict__['_buffers'] |
| if name in _buffers: |
| return _buffers[name] |
| if '_modules' in self.__dict__: |
| modules = self.__dict__['_modules'] |
| if name in modules: |
| return modules[name] |
| raise AttributeError("'{}' object has no attribute '{}'".format( |
| type(self).__name__, name)) |
| |
| def __setattr__(self, name, value): |
| def remove_from(*dicts): |
| for d in dicts: |
| if name in d: |
| del d[name] |
| |
| params = self.__dict__.get('_parameters') |
| if isinstance(value, Parameter): |
| if params is None: |
| raise AttributeError( |
| "cannot assign parameters before Module.__init__() call") |
| remove_from(self.__dict__, self._buffers, self._modules) |
| self.register_parameter(name, value) |
| elif params is not None and name in params: |
| if value is not None: |
| raise TypeError("cannot assign '{}' as parameter '{}' " |
| "(torch.nn.Parameter or None expected)" |
| .format(torch.typename(value), name)) |
| self.register_parameter(name, value) |
| else: |
| modules = self.__dict__.get('_modules') |
| if isinstance(value, Module): |
| if modules is None: |
| raise AttributeError( |
| "cannot assign module before Module.__init__() call") |
| remove_from(self.__dict__, self._parameters, self._buffers) |
| modules[name] = value |
| elif modules is not None and name in modules: |
| if value is not None: |
| raise TypeError("cannot assign '{}' as child module '{}' " |
| "(torch.nn.Module or None expected)" |
| .format(torch.typename(value), name)) |
| modules[name] = value |
| else: |
| buffers = self.__dict__.get('_buffers') |
| if buffers is not None and name in buffers: |
| if value is not None and not isinstance(value, torch.Tensor): |
| raise TypeError("cannot assign '{}' as buffer '{}' " |
| "(torch.Tensor or None expected)" |
| .format(torch.typename(value), name)) |
| buffers[name] = value |
| else: |
| object.__setattr__(self, name, value) |
| |
| def __delattr__(self, name): |
| if name in self._parameters: |
| del self._parameters[name] |
| elif name in self._buffers: |
| del self._buffers[name] |
| elif name in self._modules: |
| del self._modules[name] |
| else: |
| object.__delattr__(self, name) |
| |
| def _register_state_dict_hook(self, hook): |
| r"""These hooks will be called with arguments: `self`, `state_dict`, |
| `prefix`, `local_metadata`, after the `state_dict` of `self` is set. |
| Note that only parameters and buffers of `self` or its children are |
| guaranteed to exist in `state_dict`. The hooks may modify `state_dict` |
| inplace or return a new one. |
| """ |
| handle = hooks.RemovableHandle(self._state_dict_hooks) |
| self._state_dict_hooks[handle.id] = hook |
| return handle |
| |
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        r"""Returns a dictionary containing the whole state of the module.
| |
| Both parameters and persistent buffers (e.g. running averages) are |
| included. Keys are corresponding parameter and buffer names. |
| |
| Returns: |
| dict: |
                a dictionary containing the whole state of the module
| |
| Example:: |
| |
| >>> module.state_dict().keys() |
| ['bias', 'weight'] |
| |
| """ |
| if destination is None: |
| destination = OrderedDict() |
| destination._metadata = OrderedDict() |
| destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) |
| for name, param in self._parameters.items(): |
| if param is not None: |
| destination[prefix + name] = param if keep_vars else param.data |
| for name, buf in self._buffers.items(): |
| if buf is not None: |
| destination[prefix + name] = buf if keep_vars else buf.data |
| for name, module in self._modules.items(): |
| if module is not None: |
| module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) |
| for hook in self._state_dict_hooks.values(): |
| hook_result = hook(self, destination, prefix, local_metadata) |
| if hook_result is not None: |
| destination = hook_result |
| return destination |
| |
| def _register_load_state_dict_pre_hook(self, hook): |
| r"""These hooks will be called with arguments: `state_dict`, `prefix`, |
| `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, |
| `error_msgs`, before loading `state_dict` into `self`. These arguments |
| are exactly the same as those of `_load_from_state_dict`. |
| """ |
| handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) |
| self._load_state_dict_pre_hooks[handle.id] = hook |
| return handle |
| |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs): |
| r"""Copies parameters and buffers from :attr:`state_dict` into only |
| this module, but not its descendants. This is called on every submodule |
| in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this |
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
| Subclasses can achieve class-specific backward compatible loading using |
| the version number at `local_metadata.get("version", None)`. |
| |
| .. note:: |
| :attr:`state_dict` is not the same object as the input |
| :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So |
| it can be modified. |
| |
| Arguments: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| prefix (str): the prefix for parameters and buffers used in this |
| module |
            local_metadata (dict): a dict containing the metadata for this
                module.
| strict (bool): whether to strictly enforce that the keys in |
| :attr:`state_dict` with :attr:`prefix` match the names of |
| parameters and buffers in this module |
| missing_keys (list of str): if ``strict=False``, add missing keys to |
| this list |
| unexpected_keys (list of str): if ``strict=False``, add unexpected |
| keys to this list |
| error_msgs (list of str): error messages should be added to this |
| list, and will be reported together in |
| :meth:`~torch.nn.Module.load_state_dict` |
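
        Example::

            >>> # A hypothetical sketch of class-specific backward-compatible
            >>> # loading: rename a key written by an older version of the
            >>> # (illustrative) ``MyModule`` before delegating to the default
            >>> # implementation.
            >>> class MyModule(nn.Module):
            >>>     _version = 2
            >>>
            >>>     def _load_from_state_dict(self, state_dict, prefix, local_metadata,
            >>>                               strict, missing_keys, unexpected_keys,
            >>>                               error_msgs):
            >>>         version = local_metadata.get('version', None)
            >>>         if version is None or version < 2:
            >>>             old_key, new_key = prefix + 'gamma', prefix + 'weight'
            >>>             if old_key in state_dict:
            >>>                 state_dict[new_key] = state_dict.pop(old_key)
            >>>         super(MyModule, self)._load_from_state_dict(
            >>>             state_dict, prefix, local_metadata, strict,
            >>>             missing_keys, unexpected_keys, error_msgs)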
| """ |
| for hook in self._load_state_dict_pre_hooks.values(): |
| hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) |
| |
| local_name_params = itertools.chain(self._parameters.items(), self._buffers.items()) |
| local_state = {k: v.data for k, v in local_name_params if v is not None} |
| |
| for name, param in local_state.items(): |
| key = prefix + name |
| if key in state_dict: |
| input_param = state_dict[key] |
| |
| # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ |
| if len(param.shape) == 0 and len(input_param.shape) == 1: |
| input_param = input_param[0] |
| |
| if input_param.shape != param.shape: |
| # local shape should match the one in checkpoint |
| error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' |
| 'the shape in current model is {}.' |
| .format(key, input_param.shape, param.shape)) |
| continue |
| |
| if isinstance(input_param, Parameter): |
| # backwards compatibility for serialized parameters |
| input_param = input_param.data |
| try: |
| param.copy_(input_param) |
| except Exception: |
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}, '
                                      'an exception occurred.'
                                      .format(key, param.size(), input_param.size()))
| elif strict: |
| missing_keys.append(key) |
| |
| if strict: |
| for key, input_param in state_dict.items(): |
| if key.startswith(prefix): |
| input_name = key[len(prefix):] |
| input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child |
| if input_name not in self._modules and input_name not in local_state: |
| unexpected_keys.append(key) |
| |
| def load_state_dict(self, state_dict, strict=True): |
| r"""Copies parameters and buffers from :attr:`state_dict` into |
| this module and its descendants. If :attr:`strict` is ``True``, then |
| the keys of :attr:`state_dict` must exactly match the keys returned |
| by this module's :meth:`~torch.nn.Module.state_dict` function. |
| |
| Arguments: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| strict (bool, optional): whether to strictly enforce that the keys |
| in :attr:`state_dict` match the keys returned by this module's |
| :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` |
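
        Example::

            >>> # A typical round trip; ``net`` and the file name are illustrative.
            >>> torch.save(net.state_dict(), 'checkpoint.pth')
            >>> net.load_state_dict(torch.load('checkpoint.pth'))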
| """ |
| missing_keys = [] |
| unexpected_keys = [] |
| error_msgs = [] |
| |
| # copy state_dict so _load_from_state_dict can modify it |
| metadata = getattr(state_dict, '_metadata', None) |
| state_dict = state_dict.copy() |
| if metadata is not None: |
| state_dict._metadata = metadata |
| |
| def load(module, prefix=''): |
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) |
| module._load_from_state_dict( |
| state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) |
| for name, child in module._modules.items(): |
| if child is not None: |
| load(child, prefix + name + '.') |
| |
| load(self) |
| |
| if strict: |
| error_msg = '' |
| if len(unexpected_keys) > 0: |
| error_msgs.insert( |
| 0, 'Unexpected key(s) in state_dict: {}. '.format( |
| ', '.join('"{}"'.format(k) for k in unexpected_keys))) |
| if len(missing_keys) > 0: |
| error_msgs.insert( |
| 0, 'Missing key(s) in state_dict: {}. '.format( |
| ', '.join('"{}"'.format(k) for k in missing_keys))) |
| |
| if len(error_msgs) > 0: |
| raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( |
| self.__class__.__name__, "\n\t".join(error_msgs))) |
| |
| def _named_members(self, get_members_fn, prefix='', recurse=True): |
| r"""Helper method for yielding various names + members of modules.""" |
| memo = set() |
| modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] |
| for module_prefix, module in modules: |
| members = get_members_fn(module) |
| for k, v in members: |
| if v is None or v in memo: |
| continue |
| memo.add(v) |
| name = module_prefix + ('.' if module_prefix else '') + k |
| yield name, v |
| |
| def parameters(self, recurse=True): |
| r"""Returns an iterator over module parameters. |
| |
| This is typically passed to an optimizer. |
| |
| Args: |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| |
| Yields: |
| Parameter: module parameter |
| |
| Example:: |
| |
| >>> for param in model.parameters(): |
| >>> print(type(param.data), param.size()) |
| <class 'torch.FloatTensor'> (20L,) |
| <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
| for name, param in self.named_parameters(recurse=recurse): |
| yield param |
| |
| def named_parameters(self, prefix='', recurse=True): |
| r"""Returns an iterator over module parameters, yielding both the |
| name of the parameter as well as the parameter itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all parameter names. |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| |
| Yields: |
| (string, Parameter): Tuple containing the name and parameter |
| |
| Example:: |
| |
| >>> for name, param in self.named_parameters(): |
| >>> if name in ['bias']: |
| >>> print(param.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._parameters.items(), |
| prefix=prefix, recurse=recurse) |
| for elem in gen: |
| yield elem |
| |
| def buffers(self, recurse=True): |
| r"""Returns an iterator over module buffers. |
| |
| Args: |
| recurse (bool): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. |
| |
| Yields: |
| torch.Tensor: module buffer |
| |
| Example:: |
| |
| >>> for buf in model.buffers(): |
| >>> print(type(buf.data), buf.size()) |
| <class 'torch.FloatTensor'> (20L,) |
| <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
| for name, buf in self.named_buffers(recurse=recurse): |
| yield buf |
| |
| def named_buffers(self, prefix='', recurse=True): |
| r"""Returns an iterator over module buffers, yielding both the |
| name of the buffer as well as the buffer itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all buffer names. |
| recurse (bool): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. |
| |
| Yields: |
| (string, torch.Tensor): Tuple containing the name and buffer |
| |
| Example:: |
| |
| >>> for name, buf in self.named_buffers(): |
| >>> if name in ['running_var']: |
| >>> print(buf.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._buffers.items(), |
| prefix=prefix, recurse=recurse) |
| for elem in gen: |
| yield elem |
| |
| def children(self): |
| r"""Returns an iterator over immediate children modules. |
| |
| Yields: |
| Module: a child module |
| """ |
| for name, module in self.named_children(): |
| yield module |
| |
| def named_children(self): |
| r"""Returns an iterator over immediate children modules, yielding both |
| the name of the module as well as the module itself. |
| |
| Yields: |
| (string, Module): Tuple containing a name and child module |
| |
| Example:: |
| |
| >>> for name, module in model.named_children(): |
| >>> if name in ['conv4', 'conv5']: |
| >>> print(module) |
| |
| """ |
| memo = set() |
| for name, module in self._modules.items(): |
| if module is not None and module not in memo: |
| memo.add(module) |
| yield name, module |
| |
| def modules(self): |
| r"""Returns an iterator over all modules in the network. |
| |
| Yields: |
| Module: a module in the network |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.modules()): |
| print(idx, '->', m) |
| |
| 0 -> Sequential ( |
| (0): Linear (2 -> 2) |
| (1): Linear (2 -> 2) |
| ) |
| 1 -> Linear (2 -> 2) |
| |
| """ |
| for name, module in self.named_modules(): |
| yield module |
| |
| def named_modules(self, memo=None, prefix=''): |
| r"""Returns an iterator over all modules in the network, yielding |
| both the name of the module as well as the module itself. |
| |
| Yields: |
| (string, Module): Tuple of name and module |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.named_modules()): |
| print(idx, '->', m) |
| |
| 0 -> ('', Sequential ( |
| (0): Linear (2 -> 2) |
| (1): Linear (2 -> 2) |
| )) |
| 1 -> ('0', Linear (2 -> 2)) |
| |
| """ |
| |
| if memo is None: |
| memo = set() |
| if self not in memo: |
| memo.add(self) |
| yield prefix, self |
| for name, module in self._modules.items(): |
| if module is None: |
| continue |
| submodule_prefix = prefix + ('.' if prefix else '') + name |
| for m in module.named_modules(memo, submodule_prefix): |
| yield m |
| |
| def train(self, mode=True): |
| r"""Sets the module in training mode. |
| |
        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.
| |
| Returns: |
| Module: self |
| """ |
| self.training = mode |
| for module in self.children(): |
| module.train(mode) |
| return self |
| |
| def eval(self): |
| r"""Sets the module in evaluation mode. |
| |
        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.
| """ |
| return self.train(False) |
| |
| def zero_grad(self): |
| r"""Sets gradients of all model parameters to zero.""" |
| for p in self.parameters(): |
| if p.grad is not None: |
| p.grad.detach_() |
| p.grad.zero_() |
| |
| def share_memory(self): |
| return self._apply(lambda t: t.share_memory_()) |
| |
| def _get_name(self): |
| return self.__class__.__name__ |
| |
| def extra_repr(self): |
| r"""Set the extra representation of the module |
| |
| To print customized extra information, you should reimplement |
| this method in your own modules. Both single-line and multi-line |
| strings are acceptable. |
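
        Example::

            >>> # A hypothetical module whose ``repr`` shows its sizes.
            >>> class MyLinear(nn.Module):
            >>>     def __init__(self, in_features, out_features):
            >>>         super(MyLinear, self).__init__()
            >>>         self.weight = nn.Parameter(torch.randn(out_features, in_features))
            >>>
            >>>     def extra_repr(self):
            >>>         return 'in_features={}, out_features={}'.format(
            >>>             self.weight.shape[1], self.weight.shape[0])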
| """ |
| return '' |
| |
| def __repr__(self): |
| # We treat the extra repr like the sub-module, one item per line |
| extra_lines = [] |
| extra_repr = self.extra_repr() |
| # empty string will be split into list [''] |
| if extra_repr: |
| extra_lines = extra_repr.split('\n') |
| child_lines = [] |
| for key, module in self._modules.items(): |
| mod_str = repr(module) |
| mod_str = _addindent(mod_str, 2) |
| child_lines.append('(' + key + '): ' + mod_str) |
| lines = extra_lines + child_lines |
| |
| main_str = self._get_name() + '(' |
| if lines: |
| # simple one-liner info, which most builtin Modules will use |
| if len(extra_lines) == 1 and not child_lines: |
| main_str += extra_lines[0] |
| else: |
| main_str += '\n ' + '\n '.join(lines) + '\n' |
| |
| main_str += ')' |
| return main_str |
| |
| def __dir__(self): |
| module_attrs = dir(self.__class__) |
| attrs = list(self.__dict__.keys()) |
| parameters = list(self._parameters.keys()) |
| modules = list(self._modules.keys()) |
| buffers = list(self._buffers.keys()) |
| keys = module_attrs + attrs + parameters + modules + buffers |
| |
| # Eliminate attrs that are not legal Python variable names |
| keys = [key for key in keys if not key[0].isdigit()] |
| |
| return sorted(keys) |