| from collections import OrderedDict, namedtuple | 
 | import functools | 
 | import itertools | 
 |  | 
 | import torch | 
 | from ..backends.thnn import backend as thnn_backend | 
 | from ..parameter import Parameter | 
 | import torch.utils.hooks as hooks | 
 |  | 
 |  | 
 | class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): | 
 |     def __repr__(self): | 
 |         if not self.missing_keys and not self.unexpected_keys: | 
 |             return '<All keys matched successfully>' | 
 |         return super(_IncompatibleKeys, self).__repr__() | 
 |  | 
 |     __str__ = __repr__ | 
 |  | 
 |  | 
 | def _addindent(s_, numSpaces): | 
 |     s = s_.split('\n') | 
 |     # don't do anything for single-line stuff | 
 |     if len(s) == 1: | 
 |         return s_ | 
 |     first = s.pop(0) | 
 |     s = [(numSpaces * ' ') + line for line in s] | 
 |     s = '\n'.join(s) | 
 |     s = first + '\n' + s | 
 |     return s | 
 |  | 
 |  | 
 | class Module(object): | 
 |     r"""Base class for all neural network modules. | 
 |  | 
 |     Your models should also subclass this class. | 
 |  | 
    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::
 |  | 
 |         import torch.nn as nn | 
 |         import torch.nn.functional as F | 
 |  | 
 |         class Model(nn.Module): | 
 |             def __init__(self): | 
 |                 super(Model, self).__init__() | 
 |                 self.conv1 = nn.Conv2d(1, 20, 5) | 
 |                 self.conv2 = nn.Conv2d(20, 20, 5) | 
 |  | 
 |             def forward(self, x): | 
 |                 x = F.relu(self.conv1(x)) | 
 |                 return F.relu(self.conv2(x)) | 
 |  | 
 |     Submodules assigned in this way will be registered, and will have their | 
 |     parameters converted too when you call :meth:`to`, etc. | 
 |     """ | 
 |  | 
 |     dump_patches = False | 
 |  | 
 |     r"""This allows better BC support for :meth:`load_state_dict`. In | 
 |     :meth:`state_dict`, the version number will be saved as in the attribute | 
 |     `_metadata` of the returned state dict, and thus pickled. `_metadata` is a | 
 |     dictionary with keys that follow the naming convention of state dict. See | 
 |     ``_load_from_state_dict`` on how to use this information in loading. | 
 |  | 
 |     If new parameters/buffers are added/removed from a module, this number shall | 
 |     be bumped, and the module's `_load_from_state_dict` method can compare the | 
 |     version number and do appropriate changes if the state dict is from before | 
 |     the change.""" | 
 |     _version = 1 | 
 |  | 
 |     def __init__(self): | 
 |         self.__construct() | 
 |         # initialize self.training separately from the rest of the internal | 
 |         # state, as it is managed differently by nn.Module and ScriptModule | 
 |         self.training = True | 
 |  | 
 |     def __construct(self): | 
 |         """ | 
 |         Initializes internal Module state, shared by both nn.Module and ScriptModule. | 
 |         """ | 
 |         torch._C._log_api_usage_once("python.nn_module") | 
 |         self._backend = thnn_backend | 
 |         self._parameters = OrderedDict() | 
 |         self._buffers = OrderedDict() | 
 |         self._backward_hooks = OrderedDict() | 
 |         self._forward_hooks = OrderedDict() | 
 |         self._forward_pre_hooks = OrderedDict() | 
 |         self._state_dict_hooks = OrderedDict() | 
 |         self._load_state_dict_pre_hooks = OrderedDict() | 
 |         self._modules = OrderedDict() | 
 |  | 
 |     def forward(self, *input): | 
 |         r"""Defines the computation performed at every call. | 
 |  | 
 |         Should be overridden by all subclasses. | 
 |  | 
 |         .. note:: | 
 |             Although the recipe for forward pass needs to be defined within | 
 |             this function, one should call the :class:`Module` instance afterwards | 
 |             instead of this since the former takes care of running the | 
 |             registered hooks while the latter silently ignores them. | 
 |         """ | 
 |         raise NotImplementedError | 
 |  | 
 |     def register_buffer(self, name, tensor): | 
 |         r"""Adds a persistent buffer to the module. | 
 |  | 
 |         This is typically used to register a buffer that should not to be | 
 |         considered a model parameter. For example, BatchNorm's ``running_mean`` | 
 |         is not a parameter, but is part of the persistent state. | 
 |  | 
 |         Buffers can be accessed as attributes using given names. | 
 |  | 
 |         Args: | 
 |             name (string): name of the buffer. The buffer can be accessed | 
 |                 from this module using the given name | 
 |             tensor (Tensor): buffer to be registered. | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> self.register_buffer('running_mean', torch.zeros(num_features)) | 
 |  | 
 |         """ | 
 |         if '_buffers' not in self.__dict__: | 
 |             raise AttributeError( | 
 |                 "cannot assign buffer before Module.__init__() call") | 
 |         elif not isinstance(name, torch._six.string_classes): | 
 |             raise TypeError("buffer name should be a string. " | 
 |                             "Got {}".format(torch.typename(name))) | 
 |         elif '.' in name: | 
 |             raise KeyError("buffer name can't contain \".\"") | 
 |         elif name == '': | 
 |             raise KeyError("buffer name can't be empty string \"\"") | 
 |         elif hasattr(self, name) and name not in self._buffers: | 
 |             raise KeyError("attribute '{}' already exists".format(name)) | 
 |         elif tensor is not None and not isinstance(tensor, torch.Tensor): | 
 |             raise TypeError("cannot assign '{}' object to buffer '{}' " | 
 |                             "(torch Tensor or None required)" | 
 |                             .format(torch.typename(tensor), name)) | 
 |         else: | 
 |             self._buffers[name] = tensor | 
 |  | 
 |     def register_parameter(self, name, param): | 
 |         r"""Adds a parameter to the module. | 
 |  | 
 |         The parameter can be accessed as an attribute using given name. | 
 |  | 
 |         Args: | 
 |             name (string): name of the parameter. The parameter can be accessed | 
 |                 from this module using the given name | 
 |             param (Parameter): parameter to be added to the module. | 
 |         """ | 
 |         if '_parameters' not in self.__dict__: | 
 |             raise AttributeError( | 
 |                 "cannot assign parameter before Module.__init__() call") | 
 |  | 
 |         elif not isinstance(name, torch._six.string_classes): | 
 |             raise TypeError("parameter name should be a string. " | 
 |                             "Got {}".format(torch.typename(name))) | 
 |         elif '.' in name: | 
 |             raise KeyError("parameter name can't contain \".\"") | 
 |         elif name == '': | 
 |             raise KeyError("parameter name can't be empty string \"\"") | 
 |         elif hasattr(self, name) and name not in self._parameters: | 
 |             raise KeyError("attribute '{}' already exists".format(name)) | 
 |  | 
 |         if param is None: | 
 |             self._parameters[name] = None | 
 |         elif not isinstance(param, Parameter): | 
 |             raise TypeError("cannot assign '{}' object to parameter '{}' " | 
 |                             "(torch.nn.Parameter or None required)" | 
 |                             .format(torch.typename(param), name)) | 
 |         elif param.grad_fn: | 
 |             raise ValueError( | 
 |                 "Cannot assign non-leaf Tensor to parameter '{0}'. Model " | 
 |                 "parameters must be created explicitly. To express '{0}' " | 
 |                 "as a function of another Tensor, compute the value in " | 
 |                 "the forward() method.".format(name)) | 
 |         else: | 
 |             self._parameters[name] = param | 
 |  | 
 |     def add_module(self, name, module): | 
 |         r"""Adds a child module to the current module. | 
 |  | 
 |         The module can be accessed as an attribute using the given name. | 
 |  | 
 |         Args: | 
 |             name (string): name of the child module. The child module can be | 
 |                 accessed from this module using the given name | 
 |             module (Module): child module to be added to the module. | 
 |         """ | 
 |         if not isinstance(module, Module) and module is not None: | 
 |             raise TypeError("{} is not a Module subclass".format( | 
 |                 torch.typename(module))) | 
 |         elif not isinstance(name, torch._six.string_classes): | 
 |             raise TypeError("module name should be a string. Got {}".format( | 
 |                 torch.typename(name))) | 
 |         elif hasattr(self, name) and name not in self._modules: | 
 |             raise KeyError("attribute '{}' already exists".format(name)) | 
 |         elif '.' in name: | 
 |             raise KeyError("module name can't contain \".\"") | 
 |         elif name == '': | 
 |             raise KeyError("module name can't be empty string \"\"") | 
 |         self._modules[name] = module | 
 |  | 
 |     def _apply(self, fn): | 
 |         for module in self.children(): | 
 |             module._apply(fn) | 
 |  | 
 |         def compute_should_use_set_data(tensor, tensor_applied): | 
 |             if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): | 
 |                 # If the new tensor has compatible tensor type as the existing tensor, | 
 |                 # the current behavior is to change the tensor in-place using `.data =`, | 
 |                 # and the future behavior is to overwrite the existing tensor. However, | 
 |                 # changing the current behavior is a BC-breaking change, and we want it | 
 |                 # to happen in future releases. So for now we introduce the | 
 |                 # `torch.__future__.get_overwrite_module_params_on_conversion()` | 
 |                 # global flag to let the user control whether they want the future | 
 |                 # behavior of overwriting the existing tensor or not. | 
 |                 return not torch.__future__.get_overwrite_module_params_on_conversion() | 
 |             else: | 
 |                 return False | 
 |  | 
 |         for key, param in self._parameters.items(): | 
 |             if param is not None: | 
 |                 # Tensors stored in modules are graph leaves, and we don't want to | 
 |                 # track autograd history of `param_applied`, so we have to use | 
 |                 # `with torch.no_grad():` | 
 |                 with torch.no_grad(): | 
 |                     param_applied = fn(param) | 
 |                 should_use_set_data = compute_should_use_set_data(param, param_applied) | 
 |                 if should_use_set_data: | 
 |                     param.data = param_applied | 
 |                 else: | 
 |                     assert isinstance(param, Parameter) | 
 |                     assert param.is_leaf | 
 |                     self._parameters[key] = Parameter(param_applied, param.requires_grad) | 
 |  | 
 |                 if param.grad is not None: | 
 |                     with torch.no_grad(): | 
 |                         grad_applied = fn(param.grad) | 
 |                     should_use_set_data = compute_should_use_set_data(param.grad, grad_applied) | 
 |                     if should_use_set_data: | 
 |                         param.grad.data = grad_applied | 
 |                     else: | 
 |                         assert param.grad.is_leaf | 
 |                         self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad) | 
 |  | 
 |         for key, buf in self._buffers.items(): | 
 |             if buf is not None: | 
 |                 self._buffers[key] = fn(buf) | 
 |  | 
 |         return self | 
 |  | 
 |     def apply(self, fn): | 
 |         r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``) | 
 |         as well as self. Typical use includes initializing the parameters of a model | 
 |         (see also :ref:`torch-nn-init`). | 
 |  | 
 |         Args: | 
 |             fn (:class:`Module` -> None): function to be applied to each submodule | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> def init_weights(m): | 
 |             >>>     print(m) | 
 |             >>>     if type(m) == nn.Linear: | 
 |             >>>         m.weight.data.fill_(1.0) | 
 |             >>>         print(m.weight) | 
 |             >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) | 
 |             >>> net.apply(init_weights) | 
 |             Linear(in_features=2, out_features=2, bias=True) | 
 |             Parameter containing: | 
 |             tensor([[ 1.,  1.], | 
 |                     [ 1.,  1.]]) | 
 |             Linear(in_features=2, out_features=2, bias=True) | 
 |             Parameter containing: | 
 |             tensor([[ 1.,  1.], | 
 |                     [ 1.,  1.]]) | 
 |             Sequential( | 
 |               (0): Linear(in_features=2, out_features=2, bias=True) | 
 |               (1): Linear(in_features=2, out_features=2, bias=True) | 
 |             ) | 
 |             Sequential( | 
 |               (0): Linear(in_features=2, out_features=2, bias=True) | 
 |               (1): Linear(in_features=2, out_features=2, bias=True) | 
 |             ) | 
 |         """ | 
 |         for module in self.children(): | 
 |             module.apply(fn) | 
 |         fn(self) | 
 |         return self | 
 |  | 
 |     def cuda(self, device=None): | 
 |         r"""Moves all model parameters and buffers to the GPU. | 
 |  | 
 |         This also makes associated parameters and buffers different objects. So | 
 |         it should be called before constructing optimizer if the module will | 
 |         live on GPU while being optimized. | 
 |  | 
 |         Arguments: | 
 |             device (int, optional): if specified, all parameters will be | 
 |                 copied to that device | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.cuda(device)) | 
 |  | 
 |     def cpu(self): | 
 |         r"""Moves all model parameters and buffers to the CPU. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.cpu()) | 
 |  | 
 |     def type(self, dst_type): | 
 |         r"""Casts all parameters and buffers to :attr:`dst_type`. | 
 |  | 
 |         Arguments: | 
 |             dst_type (type or string): the desired type | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.type(dst_type)) | 
 |  | 
 |     def float(self): | 
 |         r"""Casts all floating point parameters and buffers to float datatype. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.float() if t.is_floating_point() else t) | 
 |  | 
 |     def double(self): | 
 |         r"""Casts all floating point parameters and buffers to ``double`` datatype. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.double() if t.is_floating_point() else t) | 
 |  | 
 |     def half(self): | 
 |         r"""Casts all floating point parameters and buffers to ``half`` datatype. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self._apply(lambda t: t.half() if t.is_floating_point() else t) | 
 |  | 
 |     def to(self, *args, **kwargs): | 
 |         r"""Moves and/or casts the parameters and buffers. | 
 |  | 
 |         This can be called as | 
 |  | 
 |         .. function:: to(device=None, dtype=None, non_blocking=False) | 
 |  | 
 |         .. function:: to(dtype, non_blocking=False) | 
 |  | 
 |         .. function:: to(tensor, non_blocking=False) | 
 |  | 
 |         Its signature is similar to :meth:`torch.Tensor.to`, but only accepts | 
 |         floating point desired :attr:`dtype` s. In addition, this method will | 
 |         only cast the floating point parameters and buffers to :attr:`dtype` | 
 |         (if given). The integral parameters and buffers will be moved | 
 |         :attr:`device`, if that is given, but with dtypes unchanged. When | 
 |         :attr:`non_blocking` is set, it tries to convert/move asynchronously | 
 |         with respect to the host if possible, e.g., moving CPU Tensors with | 
 |         pinned memory to CUDA devices. | 
 |  | 
 |         See below for examples. | 
 |  | 
 |         .. note:: | 
 |             This method modifies the module in-place. | 
 |  | 
 |         Args: | 
 |             device (:class:`torch.device`): the desired device of the parameters | 
 |                 and buffers in this module | 
 |             dtype (:class:`torch.dtype`): the desired floating point type of | 
 |                 the floating point parameters and buffers in this module | 
 |             tensor (torch.Tensor): Tensor whose dtype and device are the desired | 
 |                 dtype and device for all parameters and buffers in this module | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> linear = nn.Linear(2, 2) | 
 |             >>> linear.weight | 
 |             Parameter containing: | 
 |             tensor([[ 0.1913, -0.3420], | 
 |                     [-0.5113, -0.2325]]) | 
 |             >>> linear.to(torch.double) | 
 |             Linear(in_features=2, out_features=2, bias=True) | 
 |             >>> linear.weight | 
 |             Parameter containing: | 
 |             tensor([[ 0.1913, -0.3420], | 
 |                     [-0.5113, -0.2325]], dtype=torch.float64) | 
 |             >>> gpu1 = torch.device("cuda:1") | 
 |             >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) | 
 |             Linear(in_features=2, out_features=2, bias=True) | 
 |             >>> linear.weight | 
 |             Parameter containing: | 
 |             tensor([[ 0.1914, -0.3420], | 
 |                     [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') | 
 |             >>> cpu = torch.device("cpu") | 
 |             >>> linear.to(cpu) | 
 |             Linear(in_features=2, out_features=2, bias=True) | 
 |             >>> linear.weight | 
 |             Parameter containing: | 
 |             tensor([[ 0.1914, -0.3420], | 
 |                     [-0.5112, -0.2324]], dtype=torch.float16) | 
 |  | 
 |         """ | 
 |  | 
 |         device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs) | 
 |  | 
 |         if dtype is not None: | 
 |             if not dtype.is_floating_point: | 
 |                 raise TypeError('nn.Module.to only accepts floating point ' | 
 |                                 'dtypes, but got desired dtype={}'.format(dtype)) | 
 |  | 
 |         def convert(t): | 
 |             return t.to(device, dtype if t.is_floating_point() else None, non_blocking) | 
 |  | 
 |         return self._apply(convert) | 
 |  | 
 |     def register_backward_hook(self, hook): | 
 |         r"""Registers a backward hook on the module. | 
 |  | 
 |         The hook will be called every time the gradients with respect to module | 
 |         inputs are computed. The hook should have the following signature:: | 
 |  | 
 |             hook(module, grad_input, grad_output) -> Tensor or None | 
 |  | 
 |         The :attr:`grad_input` and :attr:`grad_output` may be tuples if the | 
 |         module has multiple inputs or outputs. The hook should not modify its | 
 |         arguments, but it can optionally return a new gradient with respect to | 
 |         input that will be used in place of :attr:`grad_input` in subsequent | 
 |         computations. | 
 |  | 
 |         Returns: | 
 |             :class:`torch.utils.hooks.RemovableHandle`: | 
 |                 a handle that can be used to remove the added hook by calling | 
 |                 ``handle.remove()`` | 
 |  | 
 |         .. warning :: | 
 |  | 
 |             The current implementation will not have the presented behavior | 
 |             for complex :class:`Module` that perform many operations. | 
 |             In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only | 
 |             contain the gradients for a subset of the inputs and outputs. | 
 |             For such :class:`Module`, you should use :func:`torch.Tensor.register_hook` | 
 |             directly on a specific input or output to get the required gradients. | 
 |  | 
 |         """ | 
 |         handle = hooks.RemovableHandle(self._backward_hooks) | 
 |         self._backward_hooks[handle.id] = hook | 
 |         return handle | 
 |  | 
 |     def register_forward_pre_hook(self, hook): | 
 |         r"""Registers a forward pre-hook on the module. | 
 |  | 
 |         The hook will be called every time before :func:`forward` is invoked. | 
 |         It should have the following signature:: | 
 |  | 
 |             hook(module, input) -> None or modified input | 
 |  | 
 |         The hook can modify the input. User can either return a tuple or a | 
 |         single modified value in the hook. We will wrap the value into a tuple | 
 |         if a single value is returned(unless that value is already a tuple). | 
 |  | 
 |         Returns: | 
 |             :class:`torch.utils.hooks.RemovableHandle`: | 
 |                 a handle that can be used to remove the added hook by calling | 
 |                 ``handle.remove()`` | 
 |         """ | 
 |         handle = hooks.RemovableHandle(self._forward_pre_hooks) | 
 |         self._forward_pre_hooks[handle.id] = hook | 
 |         return handle | 
 |  | 
 |     def register_forward_hook(self, hook): | 
 |         r"""Registers a forward hook on the module. | 
 |  | 
 |         The hook will be called every time after :func:`forward` has computed an output. | 
 |         It should have the following signature:: | 
 |  | 
 |             hook(module, input, output) -> None or modified output | 
 |  | 
 |         The hook can modify the output. It can modify the input inplace but | 
 |         it will not have effect on forward since this is called after | 
 |         :func:`forward` is called. | 
 |  | 
 |         Returns: | 
 |             :class:`torch.utils.hooks.RemovableHandle`: | 
 |                 a handle that can be used to remove the added hook by calling | 
 |                 ``handle.remove()`` | 
 |         """ | 
 |         handle = hooks.RemovableHandle(self._forward_hooks) | 
 |         self._forward_hooks[handle.id] = hook | 
 |         return handle | 
 |  | 
 |     def _tracing_name(self, tracing_state): | 
 |         if not tracing_state._traced_module_stack: | 
 |             return None | 
 |         module = tracing_state._traced_module_stack[-1] | 
 |         for name, child in module.named_children(): | 
 |             if child is self: | 
 |                 return name | 
 |         return None | 
 |  | 
 |     def _slow_forward(self, *input, **kwargs): | 
 |         tracing_state = torch._C._get_tracing_state() | 
 |         if not tracing_state: | 
 |             return self.forward(*input, **kwargs) | 
 |         if not hasattr(tracing_state, '_traced_module_stack'): | 
 |             tracing_state._traced_module_stack = [] | 
 |         name = self._tracing_name(tracing_state) | 
 |         if name: | 
 |             tracing_state.push_scope('%s[%s]' % (self._get_name(), name)) | 
 |         else: | 
 |             tracing_state.push_scope(self._get_name()) | 
 |         tracing_state._traced_module_stack.append(self) | 
 |         try: | 
 |             result = self.forward(*input, **kwargs) | 
 |         finally: | 
 |             tracing_state.pop_scope() | 
 |             tracing_state._traced_module_stack.pop() | 
 |         return result | 
 |  | 
 |     def __call__(self, *input, **kwargs): | 
 |         for hook in self._forward_pre_hooks.values(): | 
 |             result = hook(self, input) | 
 |             if result is not None: | 
 |                 if not isinstance(result, tuple): | 
 |                     result = (result,) | 
 |                 input = result | 
 |         if torch._C._get_tracing_state(): | 
 |             result = self._slow_forward(*input, **kwargs) | 
 |         else: | 
 |             result = self.forward(*input, **kwargs) | 
 |         for hook in self._forward_hooks.values(): | 
 |             hook_result = hook(self, input, result) | 
 |             if hook_result is not None: | 
 |                 result = hook_result | 
 |         if len(self._backward_hooks) > 0: | 
 |             var = result | 
 |             while not isinstance(var, torch.Tensor): | 
 |                 if isinstance(var, dict): | 
 |                     var = next((v for v in var.values() if isinstance(v, torch.Tensor))) | 
 |                 else: | 
 |                     var = var[0] | 
 |             grad_fn = var.grad_fn | 
 |             if grad_fn is not None: | 
 |                 for hook in self._backward_hooks.values(): | 
 |                     wrapper = functools.partial(hook, self) | 
 |                     functools.update_wrapper(wrapper, hook) | 
 |                     grad_fn.register_hook(wrapper) | 
 |         return result | 
 |  | 
 |     def __setstate__(self, state): | 
 |         self.__dict__.update(state) | 
 |         # Support loading old checkpoints that don't have the following attrs: | 
 |         if '_forward_pre_hooks' not in self.__dict__: | 
 |             self._forward_pre_hooks = OrderedDict() | 
 |         if '_state_dict_hooks' not in self.__dict__: | 
 |             self._state_dict_hooks = OrderedDict() | 
 |         if '_load_state_dict_pre_hooks' not in self.__dict__: | 
 |             self._load_state_dict_pre_hooks = OrderedDict() | 
 |  | 
 |     def __getattr__(self, name): | 
 |         if '_parameters' in self.__dict__: | 
 |             _parameters = self.__dict__['_parameters'] | 
 |             if name in _parameters: | 
 |                 return _parameters[name] | 
 |         if '_buffers' in self.__dict__: | 
 |             _buffers = self.__dict__['_buffers'] | 
 |             if name in _buffers: | 
 |                 return _buffers[name] | 
 |         if '_modules' in self.__dict__: | 
 |             modules = self.__dict__['_modules'] | 
 |             if name in modules: | 
 |                 return modules[name] | 
 |         raise AttributeError("'{}' object has no attribute '{}'".format( | 
 |             type(self).__name__, name)) | 
 |  | 
 |     def __setattr__(self, name, value): | 
 |         def remove_from(*dicts): | 
 |             for d in dicts: | 
 |                 if name in d: | 
 |                     del d[name] | 
 |  | 
 |         params = self.__dict__.get('_parameters') | 
 |         if isinstance(value, Parameter): | 
 |             if params is None: | 
 |                 raise AttributeError( | 
 |                     "cannot assign parameters before Module.__init__() call") | 
 |             remove_from(self.__dict__, self._buffers, self._modules) | 
 |             self.register_parameter(name, value) | 
 |         elif params is not None and name in params: | 
 |             if value is not None: | 
 |                 raise TypeError("cannot assign '{}' as parameter '{}' " | 
 |                                 "(torch.nn.Parameter or None expected)" | 
 |                                 .format(torch.typename(value), name)) | 
 |             self.register_parameter(name, value) | 
 |         else: | 
 |             modules = self.__dict__.get('_modules') | 
 |             if isinstance(value, Module): | 
 |                 if modules is None: | 
 |                     raise AttributeError( | 
 |                         "cannot assign module before Module.__init__() call") | 
 |                 remove_from(self.__dict__, self._parameters, self._buffers) | 
 |                 modules[name] = value | 
 |             elif modules is not None and name in modules: | 
 |                 if value is not None: | 
 |                     raise TypeError("cannot assign '{}' as child module '{}' " | 
 |                                     "(torch.nn.Module or None expected)" | 
 |                                     .format(torch.typename(value), name)) | 
 |                 modules[name] = value | 
 |             else: | 
 |                 buffers = self.__dict__.get('_buffers') | 
 |                 if buffers is not None and name in buffers: | 
 |                     if value is not None and not isinstance(value, torch.Tensor): | 
 |                         raise TypeError("cannot assign '{}' as buffer '{}' " | 
 |                                         "(torch.Tensor or None expected)" | 
 |                                         .format(torch.typename(value), name)) | 
 |                     buffers[name] = value | 
 |                 else: | 
 |                     object.__setattr__(self, name, value) | 
 |  | 
 |     def __delattr__(self, name): | 
 |         if name in self._parameters: | 
 |             del self._parameters[name] | 
 |         elif name in self._buffers: | 
 |             del self._buffers[name] | 
 |         elif name in self._modules: | 
 |             del self._modules[name] | 
 |         else: | 
 |             object.__delattr__(self, name) | 
 |  | 
 |     def _register_state_dict_hook(self, hook): | 
 |         r"""These hooks will be called with arguments: `self`, `state_dict`, | 
 |         `prefix`, `local_metadata`, after the `state_dict` of `self` is set. | 
 |         Note that only parameters and buffers of `self` or its children are | 
 |         guaranteed to exist in `state_dict`. The hooks may modify `state_dict` | 
 |         inplace or return a new one. | 
 |         """ | 
 |         handle = hooks.RemovableHandle(self._state_dict_hooks) | 
 |         self._state_dict_hooks[handle.id] = hook | 
 |         return handle | 
 |  | 
 |     def _save_to_state_dict(self, destination, prefix, keep_vars): | 
 |         r"""Saves module state to `destination` dictionary, containing a state | 
 |         of the module, but not its descendants. This is called on every | 
 |         submodule in :meth:`~torch.nn.Module.state_dict`. | 
 |  | 
 |         In rare cases, subclasses can achieve class-specific behavior by | 
 |         overriding this method with custom logic. | 
 |  | 
 |         Arguments: | 
 |             destination (dict): a dict where state will be stored | 
 |             prefix (str): the prefix for parameters and buffers used in this | 
 |                 module | 
 |         """ | 
 |         for name, param in self._parameters.items(): | 
 |             if param is not None: | 
 |                 destination[prefix + name] = param if keep_vars else param.data | 
 |         for name, buf in self._buffers.items(): | 
 |             if buf is not None: | 
 |                 destination[prefix + name] = buf if keep_vars else buf.data | 
 |  | 
 |     def state_dict(self, destination=None, prefix='', keep_vars=False): | 
 |         r"""Returns a dictionary containing a whole state of the module. | 
 |  | 
 |         Both parameters and persistent buffers (e.g. running averages) are | 
 |         included. Keys are corresponding parameter and buffer names. | 
 |  | 
 |         Returns: | 
 |             dict: | 
 |                 a dictionary containing a whole state of the module | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> module.state_dict().keys() | 
 |             ['bias', 'weight'] | 
 |  | 
 |         """ | 
 |         if destination is None: | 
 |             destination = OrderedDict() | 
 |             destination._metadata = OrderedDict() | 
 |         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) | 
 |         self._save_to_state_dict(destination, prefix, keep_vars) | 
 |         for name, module in self._modules.items(): | 
 |             if module is not None: | 
 |                 module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) | 
 |         for hook in self._state_dict_hooks.values(): | 
 |             hook_result = hook(self, destination, prefix, local_metadata) | 
 |             if hook_result is not None: | 
 |                 destination = hook_result | 
 |         return destination | 
 |  | 
 |     def _register_load_state_dict_pre_hook(self, hook): | 
 |         r"""These hooks will be called with arguments: `state_dict`, `prefix`, | 
 |         `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, | 
 |         `error_msgs`, before loading `state_dict` into `self`. These arguments | 
 |         are exactly the same as those of `_load_from_state_dict`. | 
 |         """ | 
 |         handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) | 
 |         self._load_state_dict_pre_hooks[handle.id] = hook | 
 |         return handle | 
 |  | 
 |     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | 
 |                               missing_keys, unexpected_keys, error_msgs): | 
 |         r"""Copies parameters and buffers from :attr:`state_dict` into only | 
 |         this module, but not its descendants. This is called on every submodule | 
 |         in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this | 
 |         module in input :attr:`state_dict` is provided as :attr:`local_metadata`. | 
 |         For state dicts without metadata, :attr:`local_metadata` is empty. | 
 |         Subclasses can achieve class-specific backward compatible loading using | 
 |         the version number at `local_metadata.get("version", None)`. | 
 |  | 
 |         .. note:: | 
 |             :attr:`state_dict` is not the same object as the input | 
 |             :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So | 
 |             it can be modified. | 
 |  | 
 |         Arguments: | 
 |             state_dict (dict): a dict containing parameters and | 
 |                 persistent buffers. | 
 |             prefix (str): the prefix for parameters and buffers used in this | 
 |                 module | 
 |             local_metadata (dict): a dict containing the metadata for this module. | 
 |                 See | 
 |             strict (bool): whether to strictly enforce that the keys in | 
 |                 :attr:`state_dict` with :attr:`prefix` match the names of | 
 |                 parameters and buffers in this module | 
 |             missing_keys (list of str): if ``strict=True``, add missing keys to | 
 |                 this list | 
 |             unexpected_keys (list of str): if ``strict=True``, add unexpected | 
 |                 keys to this list | 
 |             error_msgs (list of str): error messages should be added to this | 
 |                 list, and will be reported together in | 
 |                 :meth:`~torch.nn.Module.load_state_dict` | 
 |         """ | 
 |         for hook in self._load_state_dict_pre_hooks.values(): | 
 |             hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) | 
 |  | 
 |         local_name_params = itertools.chain(self._parameters.items(), self._buffers.items()) | 
 |         local_state = {k: v.data for k, v in local_name_params if v is not None} | 
 |  | 
 |         for name, param in local_state.items(): | 
 |             key = prefix + name | 
 |             if key in state_dict: | 
 |                 input_param = state_dict[key] | 
 |  | 
 |                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ | 
 |                 if len(param.shape) == 0 and len(input_param.shape) == 1: | 
 |                     input_param = input_param[0] | 
 |  | 
 |                 if input_param.shape != param.shape: | 
 |                     # local shape should match the one in checkpoint | 
 |                     error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' | 
 |                                       'the shape in current model is {}.' | 
 |                                       .format(key, input_param.shape, param.shape)) | 
 |                     continue | 
 |  | 
 |                 if isinstance(input_param, Parameter): | 
 |                     # backwards compatibility for serialized parameters | 
 |                     input_param = input_param.data | 
 |                 try: | 
 |                     param.copy_(input_param) | 
 |                 except Exception: | 
 |                     error_msgs.append('While copying the parameter named "{}", ' | 
 |                                       'whose dimensions in the model are {} and ' | 
 |                                       'whose dimensions in the checkpoint are {}.' | 
 |                                       .format(key, param.size(), input_param.size())) | 
 |             elif strict: | 
 |                 missing_keys.append(key) | 
 |  | 
 |         if strict: | 
 |             for key in state_dict.keys(): | 
 |                 if key.startswith(prefix): | 
 |                     input_name = key[len(prefix):] | 
 |                     input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child | 
 |                     if input_name not in self._modules and input_name not in local_state: | 
 |                         unexpected_keys.append(key) | 
 |  | 
 |     def load_state_dict(self, state_dict, strict=True): | 
 |         r"""Copies parameters and buffers from :attr:`state_dict` into | 
 |         this module and its descendants. If :attr:`strict` is ``True``, then | 
 |         the keys of :attr:`state_dict` must exactly match the keys returned | 
 |         by this module's :meth:`~torch.nn.Module.state_dict` function. | 
 |  | 
 |         Arguments: | 
 |             state_dict (dict): a dict containing parameters and | 
 |                 persistent buffers. | 
 |             strict (bool, optional): whether to strictly enforce that the keys | 
 |                 in :attr:`state_dict` match the keys returned by this module's | 
 |                 :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` | 
 |  | 
 |         Returns: | 
 |             ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: | 
 |                 * **missing_keys** is a list of str containing the missing keys | 
 |                 * **unexpected_keys** is a list of str containing the unexpected keys | 
 |         """ | 
 |         missing_keys = [] | 
 |         unexpected_keys = [] | 
 |         error_msgs = [] | 
 |  | 
 |         # copy state_dict so _load_from_state_dict can modify it | 
 |         metadata = getattr(state_dict, '_metadata', None) | 
 |         state_dict = state_dict.copy() | 
 |         if metadata is not None: | 
 |             state_dict._metadata = metadata | 
 |  | 
 |         def load(module, prefix=''): | 
 |             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | 
 |             module._load_from_state_dict( | 
 |                 state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | 
 |             for name, child in module._modules.items(): | 
 |                 if child is not None: | 
 |                     load(child, prefix + name + '.') | 
 |  | 
 |         load(self) | 
 |         load = None  # break load->load reference cycle | 
 |  | 
 |         if strict: | 
 |             if len(unexpected_keys) > 0: | 
 |                 error_msgs.insert( | 
 |                     0, 'Unexpected key(s) in state_dict: {}. '.format( | 
 |                         ', '.join('"{}"'.format(k) for k in unexpected_keys))) | 
 |             if len(missing_keys) > 0: | 
 |                 error_msgs.insert( | 
 |                     0, 'Missing key(s) in state_dict: {}. '.format( | 
 |                         ', '.join('"{}"'.format(k) for k in missing_keys))) | 
 |  | 
 |         if len(error_msgs) > 0: | 
 |             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( | 
 |                                self.__class__.__name__, "\n\t".join(error_msgs))) | 
 |         return _IncompatibleKeys(missing_keys, unexpected_keys) | 
 |  | 
 |     def _named_members(self, get_members_fn, prefix='', recurse=True): | 
 |         r"""Helper method for yielding various names + members of modules.""" | 
 |         memo = set() | 
 |         modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] | 
 |         for module_prefix, module in modules: | 
 |             members = get_members_fn(module) | 
 |             for k, v in members: | 
 |                 if v is None or v in memo: | 
 |                     continue | 
 |                 memo.add(v) | 
 |                 name = module_prefix + ('.' if module_prefix else '') + k | 
 |                 yield name, v | 
 |  | 
 |     def parameters(self, recurse=True): | 
 |         r"""Returns an iterator over module parameters. | 
 |  | 
 |         This is typically passed to an optimizer. | 
 |  | 
 |         Args: | 
 |             recurse (bool): if True, then yields parameters of this module | 
 |                 and all submodules. Otherwise, yields only parameters that | 
 |                 are direct members of this module. | 
 |  | 
 |         Yields: | 
 |             Parameter: module parameter | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> for param in model.parameters(): | 
 |             >>>     print(type(param.data), param.size()) | 
 |             <class 'torch.FloatTensor'> (20L,) | 
 |             <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) | 
 |  | 
 |         """ | 
 |         for name, param in self.named_parameters(recurse=recurse): | 
 |             yield param | 
 |  | 
 |     def named_parameters(self, prefix='', recurse=True): | 
 |         r"""Returns an iterator over module parameters, yielding both the | 
 |         name of the parameter as well as the parameter itself. | 
 |  | 
 |         Args: | 
 |             prefix (str): prefix to prepend to all parameter names. | 
 |             recurse (bool): if True, then yields parameters of this module | 
 |                 and all submodules. Otherwise, yields only parameters that | 
 |                 are direct members of this module. | 
 |  | 
 |         Yields: | 
 |             (string, Parameter): Tuple containing the name and parameter | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> for name, param in self.named_parameters(): | 
 |             >>>    if name in ['bias']: | 
 |             >>>        print(param.size()) | 
 |  | 
 |         """ | 
 |         gen = self._named_members( | 
 |             lambda module: module._parameters.items(), | 
 |             prefix=prefix, recurse=recurse) | 
 |         for elem in gen: | 
 |             yield elem | 
 |  | 
 |     def buffers(self, recurse=True): | 
 |         r"""Returns an iterator over module buffers. | 
 |  | 
 |         Args: | 
 |             recurse (bool): if True, then yields buffers of this module | 
 |                 and all submodules. Otherwise, yields only buffers that | 
 |                 are direct members of this module. | 
 |  | 
 |         Yields: | 
 |             torch.Tensor: module buffer | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> for buf in model.buffers(): | 
 |             >>>     print(type(buf.data), buf.size()) | 
 |             <class 'torch.FloatTensor'> (20L,) | 
 |             <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) | 
 |  | 
 |         """ | 
 |         for name, buf in self.named_buffers(recurse=recurse): | 
 |             yield buf | 
 |  | 
 |     def named_buffers(self, prefix='', recurse=True): | 
 |         r"""Returns an iterator over module buffers, yielding both the | 
 |         name of the buffer as well as the buffer itself. | 
 |  | 
 |         Args: | 
 |             prefix (str): prefix to prepend to all buffer names. | 
 |             recurse (bool): if True, then yields buffers of this module | 
 |                 and all submodules. Otherwise, yields only buffers that | 
 |                 are direct members of this module. | 
 |  | 
 |         Yields: | 
 |             (string, torch.Tensor): Tuple containing the name and buffer | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> for name, buf in self.named_buffers(): | 
 |             >>>    if name in ['running_var']: | 
 |             >>>        print(buf.size()) | 
 |  | 
 |         """ | 
 |         gen = self._named_members( | 
 |             lambda module: module._buffers.items(), | 
 |             prefix=prefix, recurse=recurse) | 
 |         for elem in gen: | 
 |             yield elem | 
 |  | 
 |     def children(self): | 
 |         r"""Returns an iterator over immediate children modules. | 
 |  | 
 |         Yields: | 
 |             Module: a child module | 
 |         """ | 
 |         for name, module in self.named_children(): | 
 |             yield module | 
 |  | 
 |     def named_children(self): | 
 |         r"""Returns an iterator over immediate children modules, yielding both | 
 |         the name of the module as well as the module itself. | 
 |  | 
 |         Yields: | 
 |             (string, Module): Tuple containing a name and child module | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> for name, module in model.named_children(): | 
 |             >>>     if name in ['conv4', 'conv5']: | 
 |             >>>         print(module) | 
 |  | 
 |         """ | 
 |         memo = set() | 
 |         for name, module in self._modules.items(): | 
 |             if module is not None and module not in memo: | 
 |                 memo.add(module) | 
 |                 yield name, module | 
 |  | 
 |     def modules(self): | 
 |         r"""Returns an iterator over all modules in the network. | 
 |  | 
 |         Yields: | 
 |             Module: a module in the network | 
 |  | 
 |         Note: | 
 |             Duplicate modules are returned only once. In the following | 
 |             example, ``l`` will be returned only once. | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> l = nn.Linear(2, 2) | 
 |             >>> net = nn.Sequential(l, l) | 
 |             >>> for idx, m in enumerate(net.modules()): | 
 |                     print(idx, '->', m) | 
 |  | 
 |             0 -> Sequential( | 
 |               (0): Linear(in_features=2, out_features=2, bias=True) | 
 |               (1): Linear(in_features=2, out_features=2, bias=True) | 
 |             ) | 
 |             1 -> Linear(in_features=2, out_features=2, bias=True) | 
 |  | 
 |         """ | 
 |         for name, module in self.named_modules(): | 
 |             yield module | 
 |  | 
 |     def named_modules(self, memo=None, prefix=''): | 
 |         r"""Returns an iterator over all modules in the network, yielding | 
 |         both the name of the module as well as the module itself. | 
 |  | 
 |         Yields: | 
 |             (string, Module): Tuple of name and module | 
 |  | 
 |         Note: | 
 |             Duplicate modules are returned only once. In the following | 
 |             example, ``l`` will be returned only once. | 
 |  | 
 |         Example:: | 
 |  | 
 |             >>> l = nn.Linear(2, 2) | 
 |             >>> net = nn.Sequential(l, l) | 
 |             >>> for idx, m in enumerate(net.named_modules()): | 
 |                     print(idx, '->', m) | 
 |  | 
 |             0 -> ('', Sequential( | 
 |               (0): Linear(in_features=2, out_features=2, bias=True) | 
 |               (1): Linear(in_features=2, out_features=2, bias=True) | 
 |             )) | 
 |             1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) | 
 |  | 
 |         """ | 
 |  | 
 |         if memo is None: | 
 |             memo = set() | 
 |         if self not in memo: | 
 |             memo.add(self) | 
 |             yield prefix, self | 
 |             for name, module in self._modules.items(): | 
 |                 if module is None: | 
 |                     continue | 
 |                 submodule_prefix = prefix + ('.' if prefix else '') + name | 
 |                 for m in module.named_modules(memo, submodule_prefix): | 
 |                     yield m | 
 |  | 
 |     def train(self, mode=True): | 
 |         r"""Sets the module in training mode. | 
 |  | 
 |         This has any effect only on certain modules. See documentations of | 
 |         particular modules for details of their behaviors in training/evaluation | 
 |         mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, | 
 |         etc. | 
 |  | 
 |         Args: | 
 |             mode (bool): whether to set training mode (``True``) or evaluation | 
 |                          mode (``False``). Default: ``True``. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         self.training = mode | 
 |         for module in self.children(): | 
 |             module.train(mode) | 
 |         return self | 
 |  | 
 |     def eval(self): | 
 |         r"""Sets the module in evaluation mode. | 
 |  | 
 |         This has any effect only on certain modules. See documentations of | 
 |         particular modules for details of their behaviors in training/evaluation | 
 |         mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, | 
 |         etc. | 
 |  | 
 |         This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         return self.train(False) | 
 |  | 
 |     def requires_grad_(self, requires_grad=True): | 
 |         r"""Change if autograd should record operations on parameters in this | 
 |         module. | 
 |  | 
 |         This method sets the parameters' :attr:`requires_grad` attributes | 
 |         in-place. | 
 |  | 
 |         This method is helpful for freezing part of the module for finetuning | 
 |         or training parts of a model individually (e.g., GAN training). | 
 |  | 
 |         Args: | 
 |             requires_grad (bool): whether autograd should record operations on | 
 |                                   parameters in this module. Default: ``True``. | 
 |  | 
 |         Returns: | 
 |             Module: self | 
 |         """ | 
 |         for p in self.parameters(): | 
 |             p.requires_grad_(requires_grad) | 
 |         return self | 
 |  | 
 |     def zero_grad(self): | 
 |         r"""Sets gradients of all model parameters to zero.""" | 
 |         for p in self.parameters(): | 
 |             if p.grad is not None: | 
 |                 p.grad.detach_() | 
 |                 p.grad.zero_() | 
 |  | 
 |     def share_memory(self): | 
 |         return self._apply(lambda t: t.share_memory_()) | 
 |  | 
 |     def _get_name(self): | 
 |         return self.__class__.__name__ | 
 |  | 
 |     def extra_repr(self): | 
 |         r"""Set the extra representation of the module | 
 |  | 
 |         To print customized extra information, you should reimplement | 
 |         this method in your own modules. Both single-line and multi-line | 
 |         strings are acceptable. | 
 |         """ | 
 |         return '' | 
 |  | 
 |     def __repr__(self): | 
 |         # We treat the extra repr like the sub-module, one item per line | 
 |         extra_lines = [] | 
 |         extra_repr = self.extra_repr() | 
 |         # empty string will be split into list [''] | 
 |         if extra_repr: | 
 |             extra_lines = extra_repr.split('\n') | 
 |         child_lines = [] | 
 |         for key, module in self._modules.items(): | 
 |             mod_str = repr(module) | 
 |             mod_str = _addindent(mod_str, 2) | 
 |             child_lines.append('(' + key + '): ' + mod_str) | 
 |         lines = extra_lines + child_lines | 
 |  | 
 |         main_str = self._get_name() + '(' | 
 |         if lines: | 
 |             # simple one-liner info, which most builtin Modules will use | 
 |             if len(extra_lines) == 1 and not child_lines: | 
 |                 main_str += extra_lines[0] | 
 |             else: | 
 |                 main_str += '\n  ' + '\n  '.join(lines) + '\n' | 
 |  | 
 |         main_str += ')' | 
 |         return main_str | 
 |  | 
 |     def __dir__(self): | 
 |         module_attrs = dir(self.__class__) | 
 |         attrs = list(self.__dict__.keys()) | 
 |         parameters = list(self._parameters.keys()) | 
 |         modules = list(self._modules.keys()) | 
 |         buffers = list(self._buffers.keys()) | 
 |         keys = module_attrs + attrs + parameters + modules + buffers | 
 |  | 
 |         # Eliminate attrs that are not legal Python variable names | 
 |         keys = [key for key in keys if not key[0].isdigit()] | 
 |  | 
 |         return sorted(keys) |