| from __future__ import division |
| |
| import torch |
| from .module import Module |
| from torch.nn.parameter import Parameter |
| from .. import functional as F |
| from .. import init |
| from ..._jit_internal import weak_module, weak_script_method |
| |
| |
| # TODO: check contiguous in THNN |
| # TODO: use separate backend functions? |
@weak_module
class _BatchNorm(Module):
    """Common base for ``BatchNorm1d``, ``BatchNorm2d`` and ``BatchNorm3d``.

    Holds the optional affine parameters and running-statistics buffers and
    implements the forward pass via ``F.batch_norm``. Subclasses must
    override ``_check_input_dim`` to validate the rank of the input.
    """

    # Serialization version; version 2 added the ``num_batches_tracked``
    # buffer. ``_load_from_state_dict`` patches checkpoints older than that.
    _version = 2
    # Attribute names declared as constants for the weak-script forward.
    __constants__ = ['training', 'track_running_stats', 'momentum', 'eps']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Per-feature scale (weight) and shift (bias). When affine is off the
        # names are still registered, with value None.
        if affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            for param_name in ('weight', 'bias'):
                self.register_parameter(param_name, None)

        # Running statistics are registered as buffers rather than Parameters;
        # when tracking is off the names are registered with value None.
        if track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            for stat_name in ('running_mean', 'running_var', 'num_batches_tracked'):
                self.register_parameter(stat_name, None)

        # Give weight/bias their initial values and reset the running stats.
        self.reset_parameters()

    def reset_running_stats(self):
        """Set running mean to 0, running variance to 1 and the batch counter to 0.

        No-op when ``track_running_stats`` is False.
        """
        if not self.track_running_stats:
            return
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset running statistics, then re-initialise the affine parameters.

        ``weight`` is drawn from U(0, 1) and ``bias`` is zeroed (affine only).
        """
        self.reset_running_stats()
        if not self.affine:
            return
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def _check_input_dim(self, input):
        """Validate the rank of ``input``; subclasses must override this."""
        raise NotImplementedError

    @weak_script_method
    def forward(self, input):
        """Normalize ``input`` with batch or running statistics.

        In training mode with ``track_running_stats`` enabled, the batch
        counter is incremented and an update factor is handed to
        ``F.batch_norm`` for the running-statistic update.
        """
        self._check_input_dim(input)

        # Update factor for the running statistics; stays 0.0 unless we are
        # training while tracking running stats.
        avg_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:
                # Cumulative moving average: weight the n-th batch by 1/n.
                avg_factor = 1.0 / float(self.num_batches_tracked)
            else:
                # Exponential moving average with a fixed momentum.
                avg_factor = self.momentum

        # Batch statistics are used while training, and always when no
        # running statistics are tracked.
        use_batch_stats = self.training or not self.track_running_stats
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            use_batch_stats, avg_factor, self.eps)

    def extra_repr(self):
        template = ('{num_features}, eps={eps}, momentum={momentum}, '
                    'affine={affine}, track_running_stats={track_running_stats}')
        return template.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Inject a zero ``num_batches_tracked`` into pre-version-2 checkpoints.

        Checkpoints saved before version 2 have no ``num_batches_tracked``
        entry. When running stats are tracked, a default of 0 is added, but
        only if the key is absent. Loading is then delegated to ``Module``.
        """
        version = local_metadata.get('version', None)
        is_legacy = version is None or version < 2

        if is_legacy and self.track_running_stats:
            counter_key = prefix + 'num_batches_tracked'
            if counter_key not in state_dict:
                state_dict[counter_key] = torch.tensor(0, dtype=torch.long)

        super(_BatchNorm, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
| |
| |
@weak_module
class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D
    inputs with optional additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input).
    By default, the elements of :math:`\gamma` are sampled from :math:`\mathcal{U}(0, 1)`
    and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C)` or :math:`(N, C, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        # Accept (N, C) or (N, C, L); any other rank is rejected.
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
| |
| |
@weak_module
class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input).
    By default, the elements of :math:`\gamma` are sampled from :math:`\mathcal{U}(0, 1)`
    and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        # Only (N, C, H, W) inputs are accepted.
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
| |
| |
@weak_module
class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of channels of the input).
    By default, the elements of :math:`\gamma` are sampled from :math:`\mathcal{U}(0, 1)`
    and the elements of :math:`\beta` are set to 0.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from the one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        # Only (N, C, D, H, W) inputs are accepted.
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))