from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
           'LazyBatchNorm3d', 'SyncBatchNorm']

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: float
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
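
# A minimal sketch of what the version hook above enables (illustrative comment only, not
# part of the module's API; the concrete module and sizes are assumptions of this example):
# a checkpoint that predates version 2, and therefore carries no ``num_batches_tracked``
# entry and no version metadata, still loads cleanly because a default zero tensor is
# injected on the fly.
#
#     bn = torch.nn.BatchNorm1d(4)
#     old_state = bn.state_dict()
#     old_state.pop("num_batches_tracked")  # simulate a checkpoint saved before the buffer existed
#     if hasattr(old_state, "_metadata"):
#         del old_state._metadata           # no version info -> treated as an old checkpoint
#     bn.load_state_dict(old_state)         # loads; num_batches_tracked defaults to 0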


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
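
# A minimal sketch of the running-stat update that F.batch_norm performs above (illustrative
# comment only; the shapes and tolerances are assumptions of this example): the update is
# ``new = (1 - factor) * old + factor * batch_stat`` with ``factor = momentum`` (default 0.1),
# or ``1 / num_batches_tracked`` when ``momentum=None`` (cumulative moving average).
#
#     bn = torch.nn.BatchNorm1d(1, momentum=0.1)
#     x = torch.randn(32, 1)
#     bn.train()
#     bn(x)
#     expected_mean = 0.9 * 0.0 + 0.1 * x.mean()              # running_mean starts at 0
#     expected_var = 0.9 * 1.0 + 0.1 * x.var(unbiased=True)   # running_var starts at 1, unbiased update
#     assert torch.allclose(bn.running_mean, expected_mean, atol=1e-6)
#     assert torch.allclose(bn.running_var, expected_var, atol=1e-6)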


class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            self.reset_parameters()
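
# A minimal sketch of the lazy-initialization flow implemented above (illustrative comment
# only; the concrete sizes are assumptions of this example): ``num_features`` starts at 0,
# the first forward call infers it from ``input.shape[1]`` and materializes the tensors, and
# the module then turns into its ``cls_to_become`` class.
#
#     m = torch.nn.LazyBatchNorm1d()
#     assert m.num_features == 0            # nothing allocated yet
#     m(torch.randn(8, 16))                 # first forward: num_features inferred as 16
#     assert type(m).__name__ == "BatchNorm1d"
#     assert m.weight.shape == (16,)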


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )
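
# A minimal sketch of the normalization described in the docstring above (illustrative
# comment only): in training mode the output uses the *biased* batch variance, computed
# independently per feature.
#
#     bn = torch.nn.BatchNorm1d(3, affine=False)
#     x = torch.randn(64, 3)
#     bn.train()
#     y = bn(x)
#     ref = (x - x.mean(0)) / torch.sqrt(x.var(0, unbiased=False) + bn.eps)
#     assert torch.allclose(y, ref, atol=1e-6)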


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization of
    the ``num_features`` argument, which is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")
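
# A minimal sketch of the per-channel statistics described in the docstring above
# (illustrative comment only): for a 4D input each channel is normalized with statistics
# reduced over its (N, H, W) slice.
#
#     bn = torch.nn.BatchNorm2d(3, affine=False)
#     x = torch.randn(8, 3, 5, 5)
#     bn.train()
#     y = bn(x)
#     mean = x.mean(dim=(0, 2, 3), keepdim=True)
#     var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
#     ref = (x - mean) / torch.sqrt(var + bn.eps)
#     assert torch.allclose(y, ref, atol=1e-6)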


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization of
    the ``num_features`` argument, which is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization of
    the ``num_features`` argument, which is inferred from ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process group. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0.
    The standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)

        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")

            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,
                world_size,
            )

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to
        :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output
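

# A minimal usage sketch of the single-process fallback in SyncBatchNorm.forward
# (illustrative comment only; assumes a non-distributed run): when no process group is
# initialized, or the group has a single member, the layer behaves like ordinary batch norm.
#
#     m = torch.nn.SyncBatchNorm(4)
#     x = torch.randn(2, 4, 8)
#     y = m(x)                      # torch.distributed not initialized -> plain F.batch_norm path
#     assert y.shape == x.shape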