| from collections import OrderedDict, namedtuple |
| import itertools |
| import warnings |
| import functools |
| |
| import torch |
| from ..parameter import Parameter |
| import torch.utils.hooks as hooks |
| |
| from torch import Tensor, device, dtype |
| from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List |
| from ...utils.hooks import RemovableHandle |
| |
| _grad_t = Union[Tuple[Tensor, ...], Tensor] |
| # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use |
| # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be |
| # the type of the subclass, not the looser type of `Module`. |
| T = TypeVar('T', bound='Module') |
| |
| class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): |
| def __repr__(self): |
| if not self.missing_keys and not self.unexpected_keys: |
| return '<All keys matched successfully>' |
| return super(_IncompatibleKeys, self).__repr__() |
| |
| __str__ = __repr__ |
| |
| |
| class ModuleAttributeError(AttributeError): |
| """ When `__getattr__` raises AttributeError inside a property, |
| AttributeError is raised with the property name instead of the |
| attribute that initially raised AttributeError, making the error |
| message uninformative. Using `ModuleAttributeError` instead |
| fixes this issue.""" |
| |
| |
| def _addindent(s_, numSpaces): |
| s = s_.split('\n') |
| # don't do anything for single-line stuff |
| if len(s) == 1: |
| return s_ |
| first = s.pop(0) |
| s = [(numSpaces * ' ') + line for line in s] |
| s = '\n'.join(s) |
| s = first + '\n' + s |
| return s |
| |
| |
| r"""This tracks hooks common to all modules that are executed before/after |
| calling forward and backward. This is global state used for debugging/profiling |
| purposes""" |
| _global_backward_hooks: Dict[int, Callable] = OrderedDict() |
| _global_is_full_backward_hook: Optional[bool] = None |
| _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() |
| _global_forward_hooks: Dict[int, Callable] = OrderedDict() |
| |
| |
| def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: |
| r"""Registers a forward pre-hook common to all modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| The hook will be called every time before :func:`forward` is invoked. |
| It should have the following signature:: |
| |
| hook(module, input) -> None or modified input |
| |
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks; they are only passed to ``forward``.
The hook can modify the input. The user can return either a tuple or a
single modified value from the hook. The value will be wrapped into a tuple
if a single value is returned (unless that value is already a tuple).
| |
| This hook has precedence over the specific module hooks registered with |
| ``register_forward_pre_hook``. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
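
Example (a minimal sketch; assumes this module is importable as
``torch.nn.modules.module``)::

>>> import torch
>>> import torch.nn as nn
>>> from torch.nn.modules import module as nn_module
>>> def log_inputs(mod, inp):
>>>     # returning None leaves the positional inputs unchanged
>>>     print(type(mod).__name__, [x.shape for x in inp])
>>> handle = nn_module.register_module_forward_pre_hook(log_inputs)
>>> _ = nn.Linear(2, 3)(torch.randn(4, 2))
>>> handle.remove()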
| """ |
| handle = hooks.RemovableHandle(_global_forward_pre_hooks) |
| _global_forward_pre_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle: |
| r"""Registers a global forward hook for all the modules |
| |
| .. warning :: |
| |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| The hook will be called every time after :func:`forward` has computed an output. |
| It should have the following signature:: |
| |
| hook(module, input, output) -> None or modified output |
| |
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks; they are only passed to ``forward``.
The hook can modify the output. It can also modify the input in-place,
but that will have no effect on the forward pass, since the hook is
called after :func:`forward` has run.
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| This hook will be executed before specific module hooks registered with |
| ``register_forward_hook``. |
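
Example (a minimal sketch; assumes this module is importable as
``torch.nn.modules.module``)::

>>> import torch
>>> import torch.nn as nn
>>> from torch.nn.modules import module as nn_module
>>> def log_output(mod, inp, out):
>>>     # returning None keeps the original output
>>>     print(type(mod).__name__, out.shape)
>>> handle = nn_module.register_module_forward_hook(log_output)
>>> _ = nn.Linear(2, 3)(torch.randn(4, 2))
>>> handle.remove()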
| """ |
| handle = hooks.RemovableHandle(_global_forward_hooks) |
| _global_forward_hooks[handle.id] = hook |
| return handle |
| |
| def register_module_backward_hook( |
| hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] |
| ) -> RemovableHandle: |
| r"""Registers a backward hook common to all the modules. |
| |
| This function is deprecated in favor of :meth:`nn.module.register_module_full_backward_hook` |
| and the behavior of this function will change in future versions. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| global _global_is_full_backward_hook |
| if _global_is_full_backward_hook is True: |
| raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " |
| "global Module hook. Please use only one of them.") |
| |
| _global_is_full_backward_hook = False |
| |
| handle = hooks.RemovableHandle(_global_backward_hooks) |
| _global_backward_hooks[handle.id] = hook |
| return handle |
| |
| def register_module_full_backward_hook( |
| hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] |
| ) -> RemovableHandle: |
| r"""Registers a backward hook common to all the modules. |
| |
| .. warning :: |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
The current implementation will not have the presented behavior
for a complex :class:`Module` that performs many operations.
In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
contain the gradients for a subset of the inputs and outputs.
For such a :class:`Module`, you should use :func:`torch.Tensor.register_hook`
directly on a specific input or output to get the required gradients.
| |
| The hook will be called every time the gradients with respect to module |
| inputs are computed. The hook should have the following signature:: |
| |
| hook(module, grad_input, grad_output) -> Tensor or None |
| |
| The :attr:`grad_input` and :attr:`grad_output` are tuples. The hook should |
| not modify its arguments, but it can optionally return a new gradient with |
| respect to the input that will be used in place of :attr:`grad_input` in |
| subsequent computations. :attr:`grad_input` will only correspond to the inputs given |
| as positional arguments and all kwarg arguments will not appear in the hook. Entries |
| in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor |
| arguments. |
| |
Global hooks are called before hooks registered with `register_backward_hook`.
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
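
Example (a minimal sketch; assumes this module is importable as
``torch.nn.modules.module``)::

>>> import torch
>>> import torch.nn as nn
>>> from torch.nn.modules import module as nn_module
>>> def report_grads(mod, grad_input, grad_output):
>>>     # inspect gradients flowing through every module; returning None
>>>     # leaves grad_input unchanged
>>>     print(type(mod).__name__, len(grad_input), len(grad_output))
>>> handle = nn_module.register_module_full_backward_hook(report_grads)
>>> out = nn.Linear(2, 3)(torch.randn(4, 2, requires_grad=True))
>>> out.sum().backward()
>>> handle.remove()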
| |
| """ |
| global _global_is_full_backward_hook |
| if _global_is_full_backward_hook is False: |
| raise RuntimeError("Cannot use both regular backward hooks and full backward hooks as a " |
| "global Module hook. Please use only one of them.") |
| |
| _global_is_full_backward_hook = True |
| |
| handle = hooks.RemovableHandle(_global_backward_hooks) |
| _global_backward_hooks[handle.id] = hook |
| return handle |
| |
| |
| # Trick mypy into not applying contravariance rules to inputs by defining |
| # forward as a value, rather than a function. See also |
| # https://github.com/python/mypy/issues/8795 |
| def _forward_unimplemented(self, *input: Any) -> None: |
| r"""Defines the computation performed at every call. |
| |
| Should be overridden by all subclasses. |
| |
| .. note:: |
Although the recipe for the forward pass needs to be defined within
this function, one should call the :class:`Module` instance instead of
calling this method directly, since the former takes care of running the
registered hooks while the latter silently ignores them.
| """ |
| raise NotImplementedError |
| |
| |
| class Module: |
| r"""Base class for all neural network modules. |
| |
| Your models should also subclass this class. |
| |
Modules can also contain other Modules, allowing them to be nested in
a tree structure. You can assign the submodules as regular attributes::
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| class Model(nn.Module): |
| def __init__(self): |
| super(Model, self).__init__() |
| self.conv1 = nn.Conv2d(1, 20, 5) |
| self.conv2 = nn.Conv2d(20, 20, 5) |
| |
| def forward(self, x): |
| x = F.relu(self.conv1(x)) |
| return F.relu(self.conv2(x)) |
| |
| Submodules assigned in this way will be registered, and will have their |
| parameters converted too when you call :meth:`to`, etc. |
| |
:ivar training: Boolean representing whether this module is in training or
evaluation mode.
| :vartype training: bool |
| """ |
| |
| dump_patches: bool = False |
| |
| r"""This allows better BC support for :meth:`load_state_dict`. In |
| :meth:`state_dict`, the version number will be saved as in the attribute |
| `_metadata` of the returned state dict, and thus pickled. `_metadata` is a |
| dictionary with keys that follow the naming convention of state dict. See |
| ``_load_from_state_dict`` on how to use this information in loading. |
| |
| If new parameters/buffers are added/removed from a module, this number shall |
| be bumped, and the module's `_load_from_state_dict` method can compare the |
| version number and do appropriate changes if the state dict is from before |
| the change.""" |
| _version: int = 1 |
| |
| training: bool |
| _is_full_backward_hook: Optional[bool] |
| |
| def __init__(self): |
| """ |
| Initializes internal Module state, shared by both nn.Module and ScriptModule. |
| """ |
| torch._C._log_api_usage_once("python.nn_module") |
| |
| self.training = True |
| self._parameters = OrderedDict() |
| self._buffers = OrderedDict() |
| self._non_persistent_buffers_set = set() |
| self._backward_hooks = OrderedDict() |
| self._is_full_backward_hook = None |
| self._forward_hooks = OrderedDict() |
| self._forward_pre_hooks = OrderedDict() |
| self._state_dict_hooks = OrderedDict() |
| self._load_state_dict_pre_hooks = OrderedDict() |
| self._modules = OrderedDict() |
| |
| forward: Callable[..., Any] = _forward_unimplemented |
| |
| def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: |
| r"""Adds a buffer to the module. |
| |
This is typically used to register a buffer that should not be
| considered a model parameter. For example, BatchNorm's ``running_mean`` |
| is not a parameter, but is part of the module's state. Buffers, by |
| default, are persistent and will be saved alongside parameters. This |
| behavior can be changed by setting :attr:`persistent` to ``False``. The |
| only difference between a persistent buffer and a non-persistent buffer |
| is that the latter will not be a part of this module's |
| :attr:`state_dict`. |
| |
| Buffers can be accessed as attributes using given names. |
| |
| Args: |
| name (string): name of the buffer. The buffer can be accessed |
| from this module using the given name |
| tensor (Tensor): buffer to be registered. |
| persistent (bool): whether the buffer is part of this module's |
| :attr:`state_dict`. |
| |
| Example:: |
| |
| >>> self.register_buffer('running_mean', torch.zeros(num_features)) |
| |
| """ |
| if persistent is False and isinstance(self, torch.jit.ScriptModule): |
| raise RuntimeError("ScriptModule does not support non-persistent buffers") |
| |
| if '_buffers' not in self.__dict__: |
| raise AttributeError( |
| "cannot assign buffer before Module.__init__() call") |
| elif not isinstance(name, torch._six.string_classes): |
| raise TypeError("buffer name should be a string. " |
| "Got {}".format(torch.typename(name))) |
| elif '.' in name: |
| raise KeyError("buffer name can't contain \".\"") |
| elif name == '': |
| raise KeyError("buffer name can't be empty string \"\"") |
| elif hasattr(self, name) and name not in self._buffers: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| elif tensor is not None and not isinstance(tensor, torch.Tensor): |
| raise TypeError("cannot assign '{}' object to buffer '{}' " |
| "(torch Tensor or None required)" |
| .format(torch.typename(tensor), name)) |
| else: |
| self._buffers[name] = tensor |
| if persistent: |
| self._non_persistent_buffers_set.discard(name) |
| else: |
| self._non_persistent_buffers_set.add(name) |
| |
| def register_parameter(self, name: str, param: Optional[Parameter]) -> None: |
| r"""Adds a parameter to the module. |
| |
| The parameter can be accessed as an attribute using given name. |
| |
| Args: |
| name (string): name of the parameter. The parameter can be accessed |
| from this module using the given name |
| param (Parameter): parameter to be added to the module. |
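
Example (a minimal sketch; assumes ``nn`` refers to ``torch.nn``)::

>>> self.register_parameter('scale', nn.Parameter(torch.ones(1)))
>>> self.scale  # accessible as an attribute afterwards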
| """ |
| if '_parameters' not in self.__dict__: |
| raise AttributeError( |
| "cannot assign parameter before Module.__init__() call") |
| |
| elif not isinstance(name, torch._six.string_classes): |
| raise TypeError("parameter name should be a string. " |
| "Got {}".format(torch.typename(name))) |
| elif '.' in name: |
| raise KeyError("parameter name can't contain \".\"") |
| elif name == '': |
| raise KeyError("parameter name can't be empty string \"\"") |
| elif hasattr(self, name) and name not in self._parameters: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| |
| if param is None: |
| self._parameters[name] = None |
| elif not isinstance(param, Parameter): |
| raise TypeError("cannot assign '{}' object to parameter '{}' " |
| "(torch.nn.Parameter or None required)" |
| .format(torch.typename(param), name)) |
| elif param.grad_fn: |
| raise ValueError( |
| "Cannot assign non-leaf Tensor to parameter '{0}'. Model " |
| "parameters must be created explicitly. To express '{0}' " |
| "as a function of another Tensor, compute the value in " |
| "the forward() method.".format(name)) |
| else: |
| self._parameters[name] = param |
| |
| def add_module(self, name: str, module: Optional['Module']) -> None: |
| r"""Adds a child module to the current module. |
| |
| The module can be accessed as an attribute using the given name. |
| |
| Args: |
| name (string): name of the child module. The child module can be |
| accessed from this module using the given name |
| module (Module): child module to be added to the module. |
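
Example (a minimal sketch; assumes ``nn`` refers to ``torch.nn``)::

>>> self.add_module('proj', nn.Linear(8, 4))
>>> self.proj  # equivalent to having assigned ``self.proj = nn.Linear(8, 4)``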
| """ |
| if not isinstance(module, Module) and module is not None: |
| raise TypeError("{} is not a Module subclass".format( |
| torch.typename(module))) |
| elif not isinstance(name, torch._six.string_classes): |
| raise TypeError("module name should be a string. Got {}".format( |
| torch.typename(name))) |
| elif hasattr(self, name) and name not in self._modules: |
| raise KeyError("attribute '{}' already exists".format(name)) |
| elif '.' in name: |
| raise KeyError("module name can't contain \".\", got: {}".format(name)) |
| elif name == '': |
| raise KeyError("module name can't be empty string \"\"") |
| self._modules[name] = module |
| |
| def _apply(self, fn): |
| for module in self.children(): |
| module._apply(fn) |
| |
| def compute_should_use_set_data(tensor, tensor_applied): |
| if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): |
| # If the new tensor has compatible tensor type as the existing tensor, |
| # the current behavior is to change the tensor in-place using `.data =`, |
| # and the future behavior is to overwrite the existing tensor. However, |
| # changing the current behavior is a BC-breaking change, and we want it |
| # to happen in future releases. So for now we introduce the |
| # `torch.__future__.get_overwrite_module_params_on_conversion()` |
| # global flag to let the user control whether they want the future |
| # behavior of overwriting the existing tensor or not. |
| return not torch.__future__.get_overwrite_module_params_on_conversion() |
| else: |
| return False |
| |
| for key, param in self._parameters.items(): |
| if param is not None: |
| # Tensors stored in modules are graph leaves, and we don't want to |
| # track autograd history of `param_applied`, so we have to use |
| # `with torch.no_grad():` |
| with torch.no_grad(): |
| param_applied = fn(param) |
| should_use_set_data = compute_should_use_set_data(param, param_applied) |
| if should_use_set_data: |
| param.data = param_applied |
| else: |
| assert isinstance(param, Parameter) |
| assert param.is_leaf |
| self._parameters[key] = Parameter(param_applied, param.requires_grad) |
| |
| if param.grad is not None: |
| with torch.no_grad(): |
| grad_applied = fn(param.grad) |
| should_use_set_data = compute_should_use_set_data(param.grad, grad_applied) |
| if should_use_set_data: |
| param.grad.data = grad_applied |
| else: |
| assert param.grad.is_leaf |
| self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad) |
| |
| for key, buf in self._buffers.items(): |
| if buf is not None: |
| self._buffers[key] = fn(buf) |
| |
| return self |
| |
| def apply(self: T, fn: Callable[['Module'], None]) -> T: |
| r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``) |
| as well as self. Typical use includes initializing the parameters of a model |
| (see also :ref:`nn-init-doc`). |
| |
| Args: |
| fn (:class:`Module` -> None): function to be applied to each submodule |
| |
| Returns: |
| Module: self |
| |
| Example:: |
| |
| >>> @torch.no_grad() |
| >>> def init_weights(m): |
| >>> print(m) |
| >>> if type(m) == nn.Linear: |
| >>> m.weight.fill_(1.0) |
| >>> print(m.weight) |
| >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) |
| >>> net.apply(init_weights) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[ 1., 1.], |
| [ 1., 1.]]) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[ 1., 1.], |
| [ 1., 1.]]) |
| Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| """ |
| for module in self.children(): |
| module.apply(fn) |
| fn(self) |
| return self |
| |
| def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: |
| r"""Moves all model parameters and buffers to the GPU. |
| |
| This also makes associated parameters and buffers different objects. So |
it should be called before constructing the optimizer if the module will
| live on GPU while being optimized. |
| |
| Args: |
| device (int, optional): if specified, all parameters will be |
| copied to that device |
| |
| Returns: |
| Module: self |
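
Example (a minimal sketch; assumes a CUDA device is available and ``nn``
refers to ``torch.nn``)::

>>> model = nn.Linear(2, 2).cuda()
>>> # construct the optimizer only after moving the module to the GPU
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)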
| """ |
| return self._apply(lambda t: t.cuda(device)) |
| |
| def cpu(self: T) -> T: |
| r"""Moves all model parameters and buffers to the CPU. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.cpu()) |
| |
| def type(self: T, dst_type: Union[dtype, str]) -> T: |
| r"""Casts all parameters and buffers to :attr:`dst_type`. |
| |
| Args: |
| dst_type (type or string): the desired type |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.type(dst_type)) |
| |
| def float(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to float datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.float() if t.is_floating_point() else t) |
| |
| def double(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``double`` datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.double() if t.is_floating_point() else t) |
| |
| def half(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``half`` datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.half() if t.is_floating_point() else t) |
| |
| def bfloat16(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) |
| |
| @overload |
| def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., |
| non_blocking: bool = ...) -> T: |
| ... |
| |
| @overload |
| def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: |
| ... |
| |
| @overload |
| def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: |
| ... |
| |
| def to(self, *args, **kwargs): |
| r"""Moves and/or casts the parameters and buffers. |
| |
| This can be called as |
| |
| .. function:: to(device=None, dtype=None, non_blocking=False) |
| |
| .. function:: to(dtype, non_blocking=False) |
| |
| .. function:: to(tensor, non_blocking=False) |
| |
| .. function:: to(memory_format=torch.channels_last) |
| |
| Its signature is similar to :meth:`torch.Tensor.to`, but only accepts |
| floating point or complex :attr:`dtype`s. In addition, this method will |
| only cast the floating point or complex parameters and buffers to :attr:`dtype` |
(if given). The integral parameters and buffers will be moved to
:attr:`device`, if that is given, but with dtypes unchanged. When
| :attr:`non_blocking` is set, it tries to convert/move asynchronously |
| with respect to the host if possible, e.g., moving CPU Tensors with |
| pinned memory to CUDA devices. |
| |
| See below for examples. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Args: |
| device (:class:`torch.device`): the desired device of the parameters |
| and buffers in this module |
| dtype (:class:`torch.dtype`): the desired floating point or complex dtype of |
| the parameters and buffers in this module |
| tensor (torch.Tensor): Tensor whose dtype and device are the desired |
| dtype and device for all parameters and buffers in this module |
| memory_format (:class:`torch.memory_format`): the desired memory |
| format for 4D parameters and buffers in this module (keyword |
| only argument) |
| |
| Returns: |
| Module: self |
| |
| Examples:: |
| |
| >>> linear = nn.Linear(2, 2) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]]) |
| >>> linear.to(torch.double) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]], dtype=torch.float64) |
| >>> gpu1 = torch.device("cuda:1") |
| >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') |
| >>> cpu = torch.device("cpu") |
| >>> linear.to(cpu) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16) |
| |
| >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.3741+0.j, 0.2382+0.j], |
| [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) |
| >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) |
| tensor([[0.6122+0.j, 0.1150+0.j], |
| [0.6122+0.j, 0.1150+0.j], |
| [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) |
| |
| """ |
| |
| device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) |
| |
| if dtype is not None: |
| if not (dtype.is_floating_point or dtype.is_complex): |
| raise TypeError('nn.Module.to only accepts floating point or complex ' |
| 'dtypes, but got desired dtype={}'.format(dtype)) |
| if dtype.is_complex: |
| warnings.warn( |
| "Complex modules are a new feature under active development whose design may change, " |
| "and some modules might not work as expected when using complex tensors as parameters or buffers. " |
| "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.md " |
| "if a complex module does not work as expected.") |
| |
| def convert(t): |
| if convert_to_format is not None and t.dim() == 4: |
| return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, |
| non_blocking, memory_format=convert_to_format) |
| return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) |
| |
| return self._apply(convert) |
| |
| def register_backward_hook( |
| self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] |
| ) -> RemovableHandle: |
| r"""Registers a backward hook on the module. |
| |
| This function is deprecated in favor of :meth:`nn.Module.register_full_backward_hook` and |
| the behavior of this function will change in future versions. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| if self._is_full_backward_hook is True: |
| raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " |
| "single Module. Please use only one of them.") |
| |
| self._is_full_backward_hook = False |
| |
| handle = hooks.RemovableHandle(self._backward_hooks) |
| self._backward_hooks[handle.id] = hook |
| return handle |
| |
| def register_full_backward_hook( |
| self, hook: Callable[['Module', _grad_t, _grad_t], Union[None, Tensor]] |
| ) -> RemovableHandle: |
| r"""Registers a backward hook on the module. |
| |
| The hook will be called every time the gradients with respect to module |
| inputs are computed. The hook should have the following signature:: |
| |
| hook(module, grad_input, grad_output) -> tuple(Tensor) or None |
| |
| The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients |
| with respect to the inputs and outputs respectively. The hook should |
| not modify its arguments, but it can optionally return a new gradient with |
| respect to the input that will be used in place of :attr:`grad_input` in |
| subsequent computations. :attr:`grad_input` will only correspond to the inputs given |
| as positional arguments and all kwarg arguments are ignored. Entries |
| in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor |
| arguments. |
| |
| .. warning :: |
| Modifying inputs or outputs inplace is not allowed when using backward hooks and |
| will raise an error. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
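
Example (a minimal sketch; the hook below only inspects the gradients and
returns ``None``, so the original ``grad_input`` is used unchanged)::

>>> def report_grads(module, grad_input, grad_output):
>>>     print([g.shape if g is not None else None for g in grad_input])
>>> linear = nn.Linear(2, 3)
>>> handle = linear.register_full_backward_hook(report_grads)
>>> linear(torch.randn(4, 2, requires_grad=True)).sum().backward()
>>> handle.remove()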
| |
| """ |
| if self._is_full_backward_hook is False: |
| raise RuntimeError("Cannot use both regular backward hooks and full backward hooks on a " |
| "single Module. Please use only one of them.") |
| |
| self._is_full_backward_hook = True |
| |
| handle = hooks.RemovableHandle(self._backward_hooks) |
| self._backward_hooks[handle.id] = hook |
| return handle |
| |
| def _get_backward_hooks(self): |
| r"""Returns the backward hooks for use in the call function. |
| It returns two lists, one with the full backward hooks and one with the non-full |
| backward hooks. |
| """ |
| full_backward_hooks: List[Callable] = [] |
| if (_global_is_full_backward_hook is True): |
| full_backward_hooks += _global_backward_hooks.values() |
| if (self._is_full_backward_hook is True): |
| full_backward_hooks += self._backward_hooks.values() |
| |
| non_full_backward_hooks: List[Callable] = [] |
| if (_global_is_full_backward_hook is False): |
| non_full_backward_hooks += _global_backward_hooks.values() |
| if (self._is_full_backward_hook is False): |
| non_full_backward_hooks += self._backward_hooks.values() |
| |
| return full_backward_hooks, non_full_backward_hooks |
| |
| def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): |
| if not isinstance(result, torch.Tensor): |
| if not (isinstance(result, tuple) and all([isinstance(r, torch.Tensor) for r in result])): |
| warnings.warn("Using non-full backward hooks on a Module that does not return a " |
| "single Tensor or a tuple of Tensors is deprecated and will be removed " |
| "in future versions. This hook will be missing some of the grad_output. " |
| "Please use register_full_backward_hook to get the documented behavior.") |
| return |
| else: |
| result = (result,) |
| |
| if not isinstance(inputs, torch.Tensor): |
| if not (isinstance(inputs, tuple) and all([isinstance(i, torch.Tensor) for i in inputs])): |
| warnings.warn("Using non-full backward hooks on a Module that does not take as input a " |
| "single Tensor or a tuple of Tensors is deprecated and will be removed " |
| "in future versions. This hook will be missing some of the grad_input. " |
| "Please use register_full_backward_hook to get the documented behavior.") |
| return |
| else: |
| inputs = (inputs,) |
| |
| # At this point we are sure that inputs and result are tuple of Tensors |
| out_grad_fn = set([r.grad_fn for r in result if r.grad_fn is not None]) |
| if len(out_grad_fn) == 0 or (len(out_grad_fn) == 1 and grad_fn not in out_grad_fn): |
| warnings.warn("Using a non-full backward hook when outputs are nested in python data structure " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_output.") |
| elif len(out_grad_fn) > 1: |
| warnings.warn("Using a non-full backward hook when outputs are generated by different autograd Nodes " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_output. Please use register_full_backward_hook to get the documented behavior.") |
| else: |
# At this point the grad_output part of the hook will most likely be correct
| inputs_grad_fn = set([i.grad_fn for i in inputs if i.grad_fn is not None]) |
| |
| next_functions = set([n[0] for n in grad_fn.next_functions]) |
| |
| if inputs_grad_fn != next_functions: |
| warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_input. Please use register_full_backward_hook to get the documented " |
| "behavior.") |
| |
| def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle: |
| r"""Registers a forward pre-hook on the module. |
| |
| The hook will be called every time before :func:`forward` is invoked. |
| It should have the following signature:: |
| |
| hook(module, input) -> None or modified input |
| |
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks; they are only passed to ``forward``.
The hook can modify the input. The user can return either a tuple or a
single modified value from the hook. The value will be wrapped into a tuple
if a single value is returned (unless that value is already a tuple).
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
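
Example (a minimal sketch; the hook below returns a modified input tuple,
which is what ``forward`` then receives)::

>>> def double_input(module, inp):
>>>     return (inp[0] * 2,)
>>> linear = nn.Linear(2, 3)
>>> handle = linear.register_forward_pre_hook(double_input)
>>> out = linear(torch.ones(1, 2))
>>> handle.remove()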
| """ |
| handle = hooks.RemovableHandle(self._forward_pre_hooks) |
| self._forward_pre_hooks[handle.id] = hook |
| return handle |
| |
| def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: |
| r"""Registers a forward hook on the module. |
| |
| The hook will be called every time after :func:`forward` has computed an output. |
| It should have the following signature:: |
| |
| hook(module, input, output) -> None or modified output |
| |
The input contains only the positional arguments given to the module.
Keyword arguments won't be passed to the hooks; they are only passed to ``forward``.
The hook can modify the output. It can also modify the input in-place,
but that will have no effect on the forward pass, since the hook is
called after :func:`forward` has run.
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
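
Example (a minimal sketch; a common use is caching intermediate activations)::

>>> activations = {}
>>> def save_output(module, inp, out):
>>>     activations['linear'] = out.detach()
>>> linear = nn.Linear(2, 3)
>>> handle = linear.register_forward_hook(save_output)
>>> _ = linear(torch.randn(4, 2))
>>> handle.remove()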
| """ |
| handle = hooks.RemovableHandle(self._forward_hooks) |
| self._forward_hooks[handle.id] = hook |
| return handle |
| |
| def _slow_forward(self, *input, **kwargs): |
| tracing_state = torch._C._get_tracing_state() |
| if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): |
| return self.forward(*input, **kwargs) |
| recording_scopes = torch.jit._trace._trace_module_map is not None |
| if recording_scopes: |
| # type ignore was added because at this point one knows that |
| # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] |
| name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore |
| if name: |
| tracing_state.push_scope(name) |
| else: |
| recording_scopes = False |
| try: |
| result = self.forward(*input, **kwargs) |
| finally: |
| if recording_scopes: |
| tracing_state.pop_scope() |
| return result |
| |
| def _call_impl(self, *input, **kwargs): |
| # Do not call functions when jit is used |
| full_backward_hooks, non_full_backward_hooks = [], [] |
| if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0: |
| full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() |
| |
| for hook in itertools.chain( |
| _global_forward_pre_hooks.values(), |
| self._forward_pre_hooks.values()): |
| result = hook(self, input) |
| if result is not None: |
| if not isinstance(result, tuple): |
| result = (result,) |
| input = result |
| |
| bw_hook = None |
| if len(full_backward_hooks) > 0: |
| bw_hook = hooks.BackwardHook(self, full_backward_hooks) |
| input = bw_hook.setup_input_hook(input) |
| |
| if torch._C._get_tracing_state(): |
| result = self._slow_forward(*input, **kwargs) |
| else: |
| result = self.forward(*input, **kwargs) |
| for hook in itertools.chain( |
| _global_forward_hooks.values(), |
| self._forward_hooks.values()): |
| hook_result = hook(self, input, result) |
| if hook_result is not None: |
| result = hook_result |
| |
| if bw_hook: |
| result = bw_hook.setup_output_hook(result) |
| |
| # Handle the non-full backward hooks |
| if len(non_full_backward_hooks) > 0: |
| var = result |
| while not isinstance(var, torch.Tensor): |
| if isinstance(var, dict): |
| var = next((v for v in var.values() if isinstance(v, torch.Tensor))) |
| else: |
| var = var[0] |
| grad_fn = var.grad_fn |
| if grad_fn is not None: |
| for hook in non_full_backward_hooks: |
| wrapper = functools.partial(hook, self) |
| functools.update_wrapper(wrapper, hook) |
| grad_fn.register_hook(wrapper) |
| self._maybe_warn_non_full_backward_hook(input, result, grad_fn) |
| |
| return result |
| |
| __call__ : Callable[..., Any] = _call_impl |
| |
| def __setstate__(self, state): |
| self.__dict__.update(state) |
| # Support loading old checkpoints that don't have the following attrs: |
| if '_forward_pre_hooks' not in self.__dict__: |
| self._forward_pre_hooks = OrderedDict() |
| if '_state_dict_hooks' not in self.__dict__: |
| self._state_dict_hooks = OrderedDict() |
| if '_load_state_dict_pre_hooks' not in self.__dict__: |
| self._load_state_dict_pre_hooks = OrderedDict() |
| if '_non_persistent_buffers_set' not in self.__dict__: |
| self._non_persistent_buffers_set = set() |
| if '_is_full_backward_hook' not in self.__dict__: |
| self._is_full_backward_hook = None |
| |
| def __getattr__(self, name: str) -> Union[Tensor, 'Module']: |
| if '_parameters' in self.__dict__: |
| _parameters = self.__dict__['_parameters'] |
| if name in _parameters: |
| return _parameters[name] |
| if '_buffers' in self.__dict__: |
| _buffers = self.__dict__['_buffers'] |
| if name in _buffers: |
| return _buffers[name] |
| if '_modules' in self.__dict__: |
| modules = self.__dict__['_modules'] |
| if name in modules: |
| return modules[name] |
| raise ModuleAttributeError("'{}' object has no attribute '{}'".format( |
| type(self).__name__, name)) |
| |
| def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: |
| def remove_from(*dicts_or_sets): |
| for d in dicts_or_sets: |
| if name in d: |
| if isinstance(d, dict): |
| del d[name] |
| else: |
| d.discard(name) |
| |
| params = self.__dict__.get('_parameters') |
| if isinstance(value, Parameter): |
| if params is None: |
| raise AttributeError( |
| "cannot assign parameters before Module.__init__() call") |
| remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set) |
| self.register_parameter(name, value) |
| elif params is not None and name in params: |
| if value is not None: |
| raise TypeError("cannot assign '{}' as parameter '{}' " |
| "(torch.nn.Parameter or None expected)" |
| .format(torch.typename(value), name)) |
| self.register_parameter(name, value) |
| else: |
| modules = self.__dict__.get('_modules') |
| if isinstance(value, Module): |
| if modules is None: |
| raise AttributeError( |
| "cannot assign module before Module.__init__() call") |
| remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set) |
| modules[name] = value |
| elif modules is not None and name in modules: |
| if value is not None: |
| raise TypeError("cannot assign '{}' as child module '{}' " |
| "(torch.nn.Module or None expected)" |
| .format(torch.typename(value), name)) |
| modules[name] = value |
| else: |
| buffers = self.__dict__.get('_buffers') |
| if buffers is not None and name in buffers: |
| if value is not None and not isinstance(value, torch.Tensor): |
| raise TypeError("cannot assign '{}' as buffer '{}' " |
| "(torch.Tensor or None expected)" |
| .format(torch.typename(value), name)) |
| buffers[name] = value |
| else: |
| object.__setattr__(self, name, value) |
| |
| def __delattr__(self, name): |
| if name in self._parameters: |
| del self._parameters[name] |
| elif name in self._buffers: |
| del self._buffers[name] |
| self._non_persistent_buffers_set.discard(name) |
| elif name in self._modules: |
| del self._modules[name] |
| else: |
| object.__delattr__(self, name) |
| |
| def _register_state_dict_hook(self, hook): |
| r"""These hooks will be called with arguments: `self`, `state_dict`, |
| `prefix`, `local_metadata`, after the `state_dict` of `self` is set. |
| Note that only parameters and buffers of `self` or its children are |
| guaranteed to exist in `state_dict`. The hooks may modify `state_dict` |
| inplace or return a new one. |
| """ |
| handle = hooks.RemovableHandle(self._state_dict_hooks) |
| self._state_dict_hooks[handle.id] = hook |
| return handle |
| |
| def _save_to_state_dict(self, destination, prefix, keep_vars): |
| r"""Saves module state to `destination` dictionary, containing a state |
| of the module, but not its descendants. This is called on every |
| submodule in :meth:`~torch.nn.Module.state_dict`. |
| |
| In rare cases, subclasses can achieve class-specific behavior by |
| overriding this method with custom logic. |
| |
| Args: |
| destination (dict): a dict where state will be stored |
prefix (str): the prefix for parameters and buffers used in this
module
keep_vars (bool): if ``True``, parameters and buffers are stored
without detaching; otherwise their detached versions are stored
"""
| for name, param in self._parameters.items(): |
| if param is not None: |
| destination[prefix + name] = param if keep_vars else param.detach() |
| for name, buf in self._buffers.items(): |
| if buf is not None and name not in self._non_persistent_buffers_set: |
| destination[prefix + name] = buf if keep_vars else buf.detach() |
| |
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
| T_destination = TypeVar('T_destination', bound=Mapping[str, Tensor]) |
| |
| @overload |
| def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: |
| ... |
| |
| # TODO: annotate with OrderedDict not Dict, but there is a problem: |
| # https://docs.python.org/3/library/typing.html#typing.OrderedDict |
| @overload |
| def state_dict(self, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Tensor]: |
| ... |
| |
| def state_dict(self, destination=None, prefix='', keep_vars=False): |
| r"""Returns a dictionary containing a whole state of the module. |
| |
| Both parameters and persistent buffers (e.g. running averages) are |
| included. Keys are corresponding parameter and buffer names. |
| |
| Returns: |
| dict: |
| a dictionary containing a whole state of the module |
| |
| Example:: |
| |
| >>> module.state_dict().keys() |
| ['bias', 'weight'] |
| |
| """ |
| if destination is None: |
| destination = OrderedDict() |
| destination._metadata = OrderedDict() |
| destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) |
| self._save_to_state_dict(destination, prefix, keep_vars) |
| for name, module in self._modules.items(): |
| if module is not None: |
| module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) |
| for hook in self._state_dict_hooks.values(): |
| hook_result = hook(self, destination, prefix, local_metadata) |
| if hook_result is not None: |
| destination = hook_result |
| return destination |
| |
| def _register_load_state_dict_pre_hook(self, hook): |
| r"""These hooks will be called with arguments: `state_dict`, `prefix`, |
| `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, |
| `error_msgs`, before loading `state_dict` into `self`. These arguments |
| are exactly the same as those of `_load_from_state_dict`. |
| """ |
| handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) |
| self._load_state_dict_pre_hooks[handle.id] = hook |
| return handle |
| |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs): |
| r"""Copies parameters and buffers from :attr:`state_dict` into only |
| this module, but not its descendants. This is called on every submodule |
| in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this |
| module in input :attr:`state_dict` is provided as :attr:`local_metadata`. |
| For state dicts without metadata, :attr:`local_metadata` is empty. |
| Subclasses can achieve class-specific backward compatible loading using |
| the version number at `local_metadata.get("version", None)`. |
| |
| .. note:: |
| :attr:`state_dict` is not the same object as the input |
| :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So |
| it can be modified. |
| |
| Args: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| prefix (str): the prefix for parameters and buffers used in this |
| module |
local_metadata (dict): a dict containing the metadata for this module.
| strict (bool): whether to strictly enforce that the keys in |
| :attr:`state_dict` with :attr:`prefix` match the names of |
| parameters and buffers in this module |
| missing_keys (list of str): if ``strict=True``, add missing keys to |
| this list |
| unexpected_keys (list of str): if ``strict=True``, add unexpected |
| keys to this list |
| error_msgs (list of str): error messages should be added to this |
| list, and will be reported together in |
| :meth:`~torch.nn.Module.load_state_dict` |
| """ |
| for hook in self._load_state_dict_pre_hooks.values(): |
| hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) |
| |
| persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} |
| local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) |
| local_state = {k: v for k, v in local_name_params if v is not None} |
| |
| for name, param in local_state.items(): |
| key = prefix + name |
| if key in state_dict: |
| input_param = state_dict[key] |
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks.
# In such a case, accessing the .shape attribute would raise an error.
| is_param_lazy = isinstance(param, torch.nn.parameter.UninitializedParameter) |
| # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ |
| if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: |
| input_param = input_param[0] |
| |
| if not is_param_lazy and input_param.shape != param.shape: |
| # local shape should match the one in checkpoint |
| error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' |
| 'the shape in current model is {}.' |
| .format(key, input_param.shape, param.shape)) |
| continue |
| try: |
| with torch.no_grad(): |
| param.copy_(input_param) |
| except Exception as ex: |
| error_msgs.append('While copying the parameter named "{}", ' |
| 'whose dimensions in the model are {} and ' |
| 'whose dimensions in the checkpoint are {}, ' |
| 'an exception occurred : {}.' |
| .format(key, param.size(), input_param.size(), ex.args)) |
| elif strict: |
| missing_keys.append(key) |
| |
| if strict: |
| for key in state_dict.keys(): |
| if key.startswith(prefix): |
| input_name = key[len(prefix):] |
| input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child |
| if input_name not in self._modules and input_name not in local_state: |
| unexpected_keys.append(key) |
| |
| def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', |
| strict: bool = True): |
| r"""Copies parameters and buffers from :attr:`state_dict` into |
| this module and its descendants. If :attr:`strict` is ``True``, then |
| the keys of :attr:`state_dict` must exactly match the keys returned |
| by this module's :meth:`~torch.nn.Module.state_dict` function. |
| |
| Args: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| strict (bool, optional): whether to strictly enforce that the keys |
| in :attr:`state_dict` match the keys returned by this module's |
| :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` |
| |
| Returns: |
| ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: |
| * **missing_keys** is a list of str containing the missing keys |
| * **unexpected_keys** is a list of str containing the unexpected keys |
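
Example (a minimal sketch; ``PATH`` is a placeholder for a checkpoint created
with ``torch.save(model.state_dict(), PATH)``)::

>>> incompatible = model.load_state_dict(torch.load(PATH))
>>> print(incompatible.missing_keys, incompatible.unexpected_keys)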
| """ |
| missing_keys: List[str] = [] |
| unexpected_keys: List[str] = [] |
| error_msgs: List[str] = [] |
| |
| # copy state_dict so _load_from_state_dict can modify it |
| metadata = getattr(state_dict, '_metadata', None) |
| state_dict = state_dict.copy() |
| if metadata is not None: |
| # mypy isn't aware that "_metadata" exists in state_dict |
| state_dict._metadata = metadata # type: ignore[attr-defined] |
| |
| def load(module, prefix=''): |
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) |
| module._load_from_state_dict( |
| state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) |
| for name, child in module._modules.items(): |
| if child is not None: |
| load(child, prefix + name + '.') |
| |
| load(self) |
| del load |
| |
| if strict: |
| if len(unexpected_keys) > 0: |
| error_msgs.insert( |
| 0, 'Unexpected key(s) in state_dict: {}. '.format( |
| ', '.join('"{}"'.format(k) for k in unexpected_keys))) |
| if len(missing_keys) > 0: |
| error_msgs.insert( |
| 0, 'Missing key(s) in state_dict: {}. '.format( |
| ', '.join('"{}"'.format(k) for k in missing_keys))) |
| |
| if len(error_msgs) > 0: |
| raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( |
| self.__class__.__name__, "\n\t".join(error_msgs))) |
| return _IncompatibleKeys(missing_keys, unexpected_keys) |
| |
| def _named_members(self, get_members_fn, prefix='', recurse=True): |
| r"""Helper method for yielding various names + members of modules.""" |
| memo = set() |
| modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] |
| for module_prefix, module in modules: |
| members = get_members_fn(module) |
| for k, v in members: |
| if v is None or v in memo: |
| continue |
| memo.add(v) |
| name = module_prefix + ('.' if module_prefix else '') + k |
| yield name, v |
| |
| def parameters(self, recurse: bool = True) -> Iterator[Parameter]: |
| r"""Returns an iterator over module parameters. |
| |
| This is typically passed to an optimizer. |
| |
| Args: |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| |
| Yields: |
| Parameter: module parameter |
| |
| Example:: |
| |
| >>> for param in model.parameters(): |
| >>> print(type(param), param.size()) |
| <class 'torch.Tensor'> (20L,) |
| <class 'torch.Tensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
| for name, param in self.named_parameters(recurse=recurse): |
| yield param |
| |
| def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: |
| r"""Returns an iterator over module parameters, yielding both the |
| name of the parameter as well as the parameter itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all parameter names. |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| |
| Yields: |
| (string, Parameter): Tuple containing the name and parameter |
| |
| Example:: |
| |
| >>> for name, param in self.named_parameters(): |
| >>> if name in ['bias']: |
| >>> print(param.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._parameters.items(), |
| prefix=prefix, recurse=recurse) |
| for elem in gen: |
| yield elem |
| |
| def buffers(self, recurse: bool = True) -> Iterator[Tensor]: |
| r"""Returns an iterator over module buffers. |
| |
| Args: |
| recurse (bool): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. |
| |
| Yields: |
| torch.Tensor: module buffer |
| |
| Example:: |
| |
| >>> for buf in model.buffers(): |
| >>> print(type(buf), buf.size()) |
| <class 'torch.Tensor'> (20L,) |
| <class 'torch.Tensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
| for name, buf in self.named_buffers(recurse=recurse): |
| yield buf |
| |
| def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]: |
| r"""Returns an iterator over module buffers, yielding both the |
| name of the buffer as well as the buffer itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all buffer names. |
| recurse (bool): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. |
| |
| Yields: |
| (string, torch.Tensor): Tuple containing the name and buffer |
| |
| Example:: |
| |
| >>> for name, buf in self.named_buffers(): |
| >>> if name in ['running_var']: |
| >>> print(buf.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._buffers.items(), |
| prefix=prefix, recurse=recurse) |
| for elem in gen: |
| yield elem |
| |
| def children(self) -> Iterator['Module']: |
| r"""Returns an iterator over immediate children modules. |
| |
| Yields: |
| Module: a child module |
| """ |
| for name, module in self.named_children(): |
| yield module |
| |
| def named_children(self) -> Iterator[Tuple[str, 'Module']]: |
| r"""Returns an iterator over immediate children modules, yielding both |
| the name of the module as well as the module itself. |
| |
| Yields: |
| (string, Module): Tuple containing a name and child module |
| |
| Example:: |
| |
| >>> for name, module in model.named_children(): |
| >>> if name in ['conv4', 'conv5']: |
| >>> print(module) |
| |
| """ |
| memo = set() |
| for name, module in self._modules.items(): |
| if module is not None and module not in memo: |
| memo.add(module) |
| yield name, module |
| |
| def modules(self) -> Iterator['Module']: |
| r"""Returns an iterator over all modules in the network. |
| |
| Yields: |
| Module: a module in the network |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.modules()): |
| print(idx, '->', m) |
| |
| 0 -> Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| 1 -> Linear(in_features=2, out_features=2, bias=True) |
| |
| """ |
| for name, module in self.named_modules(): |
| yield module |
| |
| def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''): |
| r"""Returns an iterator over all modules in the network, yielding |
| both the name of the module as well as the module itself. |
| |
| Yields: |
| (string, Module): Tuple of name and module |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.named_modules()): |
| print(idx, '->', m) |
| |
| 0 -> ('', Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| )) |
| 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) |
| |
| """ |
| |
| if memo is None: |
| memo = set() |
| if self not in memo: |
| memo.add(self) |
| yield prefix, self |
| for name, module in self._modules.items(): |
| if module is None: |
| continue |
| submodule_prefix = prefix + ('.' if prefix else '') + name |
| for m in module.named_modules(memo, submodule_prefix): |
| yield m |
| |
| def train(self: T, mode: bool = True) -> T: |
| r"""Sets the module in training mode. |
| |
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
| mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, |
| etc. |
| |
| Args: |
| mode (bool): whether to set training mode (``True``) or evaluation |
| mode (``False``). Default: ``True``. |
| |
| Returns: |
| Module: self |
| """ |
| self.training = mode |
| for module in self.children(): |
| module.train(mode) |
| return self |
| |
| def eval(self: T) -> T: |
| r"""Sets the module in evaluation mode. |
| |
This has an effect only on certain modules. See the documentation of
particular modules for details of their behaviors in training/evaluation
| mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, |
| etc. |
| |
This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
| |
| Returns: |
| Module: self |
| """ |
| return self.train(False) |
| |
| def requires_grad_(self: T, requires_grad: bool = True) -> T: |
| r"""Change if autograd should record operations on parameters in this |
| module. |
| |
| This method sets the parameters' :attr:`requires_grad` attributes |
| in-place. |
| |
| This method is helpful for freezing part of the module for finetuning |
| or training parts of a model individually (e.g., GAN training). |
| |
| Args: |
| requires_grad (bool): whether autograd should record operations on |
| parameters in this module. Default: ``True``. |
| |
| Returns: |
| Module: self |
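
Example (a minimal sketch of freezing part of a model; ``model.backbone`` and
``model.head`` are hypothetical submodules)::

>>> model.backbone.requires_grad_(False)  # freeze the backbone
>>> model.head.requires_grad_(True)       # keep training the head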
| """ |
| for p in self.parameters(): |
| p.requires_grad_(requires_grad) |
| return self |
| |
| def zero_grad(self, set_to_none: bool = False) -> None: |
| r"""Sets gradients of all model parameters to zero. See similar function |
| under :class:`torch.optim.Optimizer` for more context. |
| |
| Args: |
| set_to_none (bool): instead of setting to zero, set the grads to None. |
| See :meth:`torch.optim.Optimizer.zero_grad` for details. |
| """ |
| if getattr(self, '_is_replica', False): |
| warnings.warn( |
| "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " |
| "The parameters are copied (in a differentiable manner) from the original module. " |
| "This means they are not leaf nodes in autograd and so don't accumulate gradients. " |
| "If you need gradients in your forward method, consider using autograd.grad instead.") |
| |
| for p in self.parameters(): |
| if p.grad is not None: |
| if set_to_none: |
| p.grad = None |
| else: |
| if p.grad.grad_fn is not None: |
| p.grad.detach_() |
| else: |
| p.grad.requires_grad_(False) |
| p.grad.zero_() |
| |
| def share_memory(self: T) -> T: |
| return self._apply(lambda t: t.share_memory_()) |
| |
| def _get_name(self): |
| return self.__class__.__name__ |
| |
| def extra_repr(self) -> str: |
| r"""Set the extra representation of the module |
| |
| To print customized extra information, you should re-implement |
| this method in your own modules. Both single-line and multi-line |
| strings are acceptable. |
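
Example (a minimal sketch of overriding ``extra_repr`` in a custom module)::

>>> class Scale(nn.Module):
>>>     def __init__(self, factor):
>>>         super().__init__()
>>>         self.factor = factor
>>>     def forward(self, x):
>>>         return x * self.factor
>>>     def extra_repr(self):
>>>         return 'factor={}'.format(self.factor)
>>> print(Scale(2.0))
Scale(factor=2.0)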
| """ |
| return '' |
| |
| def __repr__(self): |
| # We treat the extra repr like the sub-module, one item per line |
| extra_lines = [] |
| extra_repr = self.extra_repr() |
| # empty string will be split into list [''] |
| if extra_repr: |
| extra_lines = extra_repr.split('\n') |
| child_lines = [] |
| for key, module in self._modules.items(): |
| mod_str = repr(module) |
| mod_str = _addindent(mod_str, 2) |
| child_lines.append('(' + key + '): ' + mod_str) |
| lines = extra_lines + child_lines |
| |
| main_str = self._get_name() + '(' |
| if lines: |
| # simple one-liner info, which most builtin Modules will use |
| if len(extra_lines) == 1 and not child_lines: |
| main_str += extra_lines[0] |
| else: |
| main_str += '\n ' + '\n '.join(lines) + '\n' |
| |
| main_str += ')' |
| return main_str |
| |
| def __dir__(self): |
| module_attrs = dir(self.__class__) |
| attrs = list(self.__dict__.keys()) |
| parameters = list(self._parameters.keys()) |
| modules = list(self._modules.keys()) |
| buffers = list(self._buffers.keys()) |
| keys = module_attrs + attrs + parameters + modules + buffers |
| |
| # Eliminate attrs that are not legal Python variable names |
| keys = [key for key in keys if not key[0].isdigit()] |
| |
| return sorted(keys) |
| |
| def _replicate_for_data_parallel(self): |
| replica = self.__new__(type(self)) |
| replica.__dict__ = self.__dict__.copy() |
| |
# Replicas do not have parameters themselves; they reference the parameters
# of the original module.
| replica._parameters = OrderedDict() |
| replica._buffers = replica._buffers.copy() |
| replica._modules = replica._modules.copy() |
| replica._is_replica = True |
| |
| return replica |