| from __future__ import division |
| |
| import torch |
| from torch import Tensor |
| from ._functions import SyncBatchNorm as sync_batch_norm |
| from .module import Module |
| from torch.nn.parameter import Parameter |
| from .. import functional as F |
| from .. import init |
| |
| from typing import Optional, Any |
| |
| |
| class _NormBase(Module): |
| """Common base of _InstanceNorm and _BatchNorm""" |
| _version = 2 |
| __constants__ = ['track_running_stats', 'momentum', 'eps', |
| 'num_features', 'affine'] |
| num_features: int |
| eps: float |
| momentum: float |
| affine: bool |
| track_running_stats: bool |
| # WARNING: weight and bias purposely not defined here. |
| # See https://github.com/pytorch/pytorch/issues/39670 |
| |
| def __init__( |
| self, |
| num_features: int, |
| eps: float = 1e-5, |
| momentum: float = 0.1, |
| affine: bool = True, |
| track_running_stats: bool = True |
| ) -> None: |
| super(_NormBase, self).__init__() |
| self.num_features = num_features |
| self.eps = eps |
| self.momentum = momentum |
| self.affine = affine |
| self.track_running_stats = track_running_stats |
| if self.affine: |
| self.weight = Parameter(torch.Tensor(num_features)) |
| self.bias = Parameter(torch.Tensor(num_features)) |
| else: |
| self.register_parameter('weight', None) |
| self.register_parameter('bias', None) |
| if self.track_running_stats: |
| self.register_buffer('running_mean', torch.zeros(num_features)) |
| self.register_buffer('running_var', torch.ones(num_features)) |
| self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) |
| else: |
| self.register_parameter('running_mean', None) |
| self.register_parameter('running_var', None) |
| self.register_parameter('num_batches_tracked', None) |
| self.reset_parameters() |
| |
| def reset_running_stats(self) -> None: |
| if self.track_running_stats: |
| self.running_mean.zero_() |
| self.running_var.fill_(1) |
| self.num_batches_tracked.zero_() |
| |
| def reset_parameters(self) -> None: |
| self.reset_running_stats() |
| if self.affine: |
| init.ones_(self.weight) |
| init.zeros_(self.bias) |
| |
| def _check_input_dim(self, input): |
| raise NotImplementedError |
| |
| def extra_repr(self): |
| return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ |
| 'track_running_stats={track_running_stats}'.format(**self.__dict__) |
| |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs): |
| version = local_metadata.get('version', None) |
| |
| if (version is None or version < 2) and self.track_running_stats: |
| # at version 2: added num_batches_tracked buffer |
| # this should have a default value of 0 |
| num_batches_tracked_key = prefix + 'num_batches_tracked' |
| if num_batches_tracked_key not in state_dict: |
| state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) |
| |
| super(_NormBase, self)._load_from_state_dict( |
| state_dict, prefix, local_metadata, strict, |
| missing_keys, unexpected_keys, error_msgs) |
| |
| |
| class _BatchNorm(_NormBase): |
| |
| def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, |
| track_running_stats=True): |
| super(_BatchNorm, self).__init__( |
| num_features, eps, momentum, affine, track_running_stats) |
| |
| def forward(self, input: Tensor) -> Tensor: |
| self._check_input_dim(input) |
| |
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
| if self.momentum is None: |
| exponential_average_factor = 0.0 |
| else: |
| exponential_average_factor = self.momentum |
| |
| if self.training and self.track_running_stats: |
| # TODO: if statement only here to tell the jit to skip emitting this when it is None |
| if self.num_batches_tracked is not None: |
| self.num_batches_tracked = self.num_batches_tracked + 1 |
| if self.momentum is None: # use cumulative moving average |
| exponential_average_factor = 1.0 / float(self.num_batches_tracked) |
| else: # use exponential moving average |
| exponential_average_factor = self.momentum |
| |
| r""" |
| Decide whether the mini-batch stats should be used for normalization rather than the buffers. |
| Mini-batch stats are used in training mode, and in eval mode when buffers are None. |
| """ |
| if self.training: |
| bn_training = True |
| else: |
| bn_training = (self.running_mean is None) and (self.running_var is None) |
| |
| r""" |
| Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be |
| passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are |
| used for normalization (i.e. in eval mode when buffers are not None). |
| """ |
| return F.batch_norm( |
| input, |
| # If buffers are not to be tracked, ensure that they won't be updated |
| self.running_mean if not self.training or self.track_running_stats else None, |
| self.running_var if not self.training or self.track_running_stats else None, |
| self.weight, self.bias, bn_training, exponential_average_factor, self.eps) |
| |
| |
| class BatchNorm1d(_BatchNorm): |
| r"""Applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D |
| inputs with optional additional channel dimension) as described in the paper |
| `Batch Normalization: Accelerating Deep Network Training by Reducing |
| Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . |
| |
| .. math:: |
| |
| y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta |
| |
| The mean and standard-deviation are calculated per-dimension over |
| the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors |
| of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set |
| to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated |
| via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. |
| |
| Also by default, during training this layer keeps running estimates of its |
| computed mean and variance, which are then used for normalization during |
| evaluation. The running estimates are kept with a default :attr:`momentum` |
| of 0.1. |
| |
| If :attr:`track_running_stats` is set to ``False``, this layer then does not |
| keep running estimates, and batch statistics are instead used during |
| evaluation time as well. |
| |
| .. note:: |
| This :attr:`momentum` argument is different from one used in optimizer |
| classes and the conventional notion of momentum. Mathematically, the |
| update rule for running statistics here is |
| :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, |
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the |
| new observed value. |
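
        With the default :attr:`momentum` of 0.1, for example, a running
        estimate of :math:`1.0` and an observed batch statistic of
        :math:`2.0` give

        .. math::
            \hat{x}_\text{new} = 0.9 \times 1.0 + 0.1 \times 2.0 = 1.1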
| |
| Because the Batch Normalization is done over the `C` dimension, computing statistics |
| on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization. |
| |
| Args: |
| num_features: :math:`C` from an expected input of size |
| :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` |
| eps: a value added to the denominator for numerical stability. |
| Default: 1e-5 |
| momentum: the value used for the running_mean and running_var |
| computation. Can be set to ``None`` for cumulative moving average |
| (i.e. simple average). Default: 0.1 |
| affine: a boolean value that when set to ``True``, this module has |
| learnable affine parameters. Default: ``True`` |
| track_running_stats: a boolean value that when set to ``True``, this |
| module tracks the running mean and variance, and when set to ``False``, |
| this module does not track such statistics, and initializes statistics |
| buffers :attr:`running_mean` and :attr:`running_var` as ``None``. |
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
| |
| Shape: |
| - Input: :math:`(N, C)` or :math:`(N, C, L)` |
| - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) |
| |
| Examples:: |
| |
| >>> # With Learnable Parameters |
| >>> m = nn.BatchNorm1d(100) |
| >>> # Without Learnable Parameters |
| >>> m = nn.BatchNorm1d(100, affine=False) |
| >>> input = torch.randn(20, 100) |
| >>> output = m(input) |
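        >>> # batch statistics are used even in eval mode when running
        >>> # statistics are disabled, as described above
        >>> m = nn.BatchNorm1d(100, track_running_stats=False)
        >>> output = m(input)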
| """ |
| |
| def _check_input_dim(self, input): |
| if input.dim() != 2 and input.dim() != 3: |
| raise ValueError('expected 2D or 3D input (got {}D input)' |
| .format(input.dim())) |
| |
| |
| class BatchNorm2d(_BatchNorm): |
| r"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs |
| with additional channel dimension) as described in the paper |
| `Batch Normalization: Accelerating Deep Network Training by Reducing |
| Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . |
| |
| .. math:: |
| |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta |
| |
| The mean and standard-deviation are calculated per-dimension over |
| the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors |
| of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set |
| to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated |
| via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. |
| |
| Also by default, during training this layer keeps running estimates of its |
| computed mean and variance, which are then used for normalization during |
| evaluation. The running estimates are kept with a default :attr:`momentum` |
| of 0.1. |
| |
| If :attr:`track_running_stats` is set to ``False``, this layer then does not |
| keep running estimates, and batch statistics are instead used during |
| evaluation time as well. |
| |
| .. note:: |
| This :attr:`momentum` argument is different from one used in optimizer |
| classes and the conventional notion of momentum. Mathematically, the |
| update rule for running statistics here is |
| :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, |
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the |
| new observed value. |
| |
| Because the Batch Normalization is done over the `C` dimension, computing statistics |
| on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization. |
| |
| Args: |
| num_features: :math:`C` from an expected input of size |
| :math:`(N, C, H, W)` |
| eps: a value added to the denominator for numerical stability. |
| Default: 1e-5 |
| momentum: the value used for the running_mean and running_var |
| computation. Can be set to ``None`` for cumulative moving average |
| (i.e. simple average). Default: 0.1 |
| affine: a boolean value that when set to ``True``, this module has |
| learnable affine parameters. Default: ``True`` |
| track_running_stats: a boolean value that when set to ``True``, this |
| module tracks the running mean and variance, and when set to ``False``, |
| this module does not track such statistics, and initializes statistics |
| buffers :attr:`running_mean` and :attr:`running_var` as ``None``. |
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
| |
| Shape: |
| - Input: :math:`(N, C, H, W)` |
| - Output: :math:`(N, C, H, W)` (same shape as input) |
| |
| Examples:: |
| |
| >>> # With Learnable Parameters |
| >>> m = nn.BatchNorm2d(100) |
| >>> # Without Learnable Parameters |
| >>> m = nn.BatchNorm2d(100, affine=False) |
| >>> input = torch.randn(20, 100, 35, 45) |
| >>> output = m(input) |
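        >>> # with momentum=None the running estimates become a cumulative
        >>> # (simple) average over all batches seen, as described above
        >>> m = nn.BatchNorm2d(100, momentum=None)
        >>> output = m(input)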
| """ |
| |
| def _check_input_dim(self, input): |
| if input.dim() != 4: |
| raise ValueError('expected 4D input (got {}D input)' |
| .format(input.dim())) |
| |
| |
| class BatchNorm3d(_BatchNorm): |
| r"""Applies Batch Normalization over a 5D input (a mini-batch of 3D inputs |
| with additional channel dimension) as described in the paper |
| `Batch Normalization: Accelerating Deep Network Training by Reducing |
| Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . |
| |
| .. math:: |
| |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta |
| |
| The mean and standard-deviation are calculated per-dimension over |
| the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors |
| of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set |
| to 1 and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated |
| via the biased estimator, equivalent to `torch.var(input, unbiased=False)`. |
| |
| Also by default, during training this layer keeps running estimates of its |
| computed mean and variance, which are then used for normalization during |
| evaluation. The running estimates are kept with a default :attr:`momentum` |
| of 0.1. |
| |
| If :attr:`track_running_stats` is set to ``False``, this layer then does not |
| keep running estimates, and batch statistics are instead used during |
| evaluation time as well. |
| |
| .. note:: |
| This :attr:`momentum` argument is different from one used in optimizer |
| classes and the conventional notion of momentum. Mathematically, the |
| update rule for running statistics here is |
| :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`, |
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the |
| new observed value. |
| |
| Because the Batch Normalization is done over the `C` dimension, computing statistics |
| on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization |
| or Spatio-temporal Batch Normalization. |
| |
| Args: |
| num_features: :math:`C` from an expected input of size |
| :math:`(N, C, D, H, W)` |
| eps: a value added to the denominator for numerical stability. |
| Default: 1e-5 |
| momentum: the value used for the running_mean and running_var |
| computation. Can be set to ``None`` for cumulative moving average |
| (i.e. simple average). Default: 0.1 |
| affine: a boolean value that when set to ``True``, this module has |
| learnable affine parameters. Default: ``True`` |
| track_running_stats: a boolean value that when set to ``True``, this |
| module tracks the running mean and variance, and when set to ``False``, |
| this module does not track such statistics, and initializes statistics |
| buffers :attr:`running_mean` and :attr:`running_var` as ``None``. |
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
| |
| Shape: |
| - Input: :math:`(N, C, D, H, W)` |
| - Output: :math:`(N, C, D, H, W)` (same shape as input) |
| |
| Examples:: |
| |
| >>> # With Learnable Parameters |
| >>> m = nn.BatchNorm3d(100) |
| >>> # Without Learnable Parameters |
| >>> m = nn.BatchNorm3d(100, affine=False) |
| >>> input = torch.randn(20, 100, 35, 45, 10) |
| >>> output = m(input) |
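        >>> # in eval mode the stored running estimates, not the current
        >>> # batch statistics, are used for normalization
        >>> m = m.eval()
        >>> output = m(input)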
| """ |
| |
| def _check_input_dim(self, input): |
| if input.dim() != 5: |
| raise ValueError('expected 5D input (got {}D input)' |
| .format(input.dim())) |
| |
| |
| class SyncBatchNorm(_BatchNorm): |
| r"""Applies Batch Normalization over a N-Dimensional input (a mini-batch of [N-2]D inputs |
| with additional channel dimension) as described in the paper |
| `Batch Normalization: Accelerating Deep Network Training by Reducing |
| Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ . |
| |
| .. math:: |
| |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta |
| |
    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process group. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements
    of :math:`\beta` are set to 0.
| The standard-deviation is calculated via the biased estimator, equivalent to |
| `torch.var(input, unbiased=False)`. |
| |
| Also by default, during training this layer keeps running estimates of its |
| computed mean and variance, which are then used for normalization during |
| evaluation. The running estimates are kept with a default :attr:`momentum` |
| of 0.1. |
| |
| If :attr:`track_running_stats` is set to ``False``, this layer then does not |
| keep running estimates, and batch statistics are instead used during |
| evaluation time as well. |
| |
| .. note:: |
| This :attr:`momentum` argument is different from one used in optimizer |
| classes and the conventional notion of momentum. Mathematically, the |
| update rule for running statistics here is |
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
| where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the |
| new observed value. |
| |
| Because the Batch Normalization is done for each channel in the ``C`` dimension, computing |
| statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch |
| Normalization or Spatio-temporal Batch Normalization. |
| |
    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with a single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert a
    :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
    the network with DDP.
| |
| Args: |
| num_features: :math:`C` from an expected input of size |
| :math:`(N, C, +)` |
| eps: a value added to the denominator for numerical stability. |
| Default: ``1e-5`` |
| momentum: the value used for the running_mean and running_var |
| computation. Can be set to ``None`` for cumulative moving average |
| (i.e. simple average). Default: 0.1 |
| affine: a boolean value that when set to ``True``, this module has |
| learnable affine parameters. Default: ``True`` |
| track_running_stats: a boolean value that when set to ``True``, this |
| module tracks the running mean and variance, and when set to ``False``, |
| this module does not track such statistics, and initializes statistics |
| buffers :attr:`running_mean` and :attr:`running_var` as ``None``. |
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happens within each process group
            individually. Default behavior is synchronization across the whole
            world.
| |
| Shape: |
| - Input: :math:`(N, C, +)` |
| - Output: :math:`(N, C, +)` (same shape as input) |
| |
| Examples:: |
| |
| >>> # With Learnable Parameters |
| >>> m = nn.SyncBatchNorm(100) |
| >>> # creating process group (optional) |
| >>> # process_ids is a list of int identifying rank ids. |
| >>> process_group = torch.distributed.new_group(process_ids) |
| >>> # Without Learnable Parameters |
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
| >>> input = torch.randn(20, 100, 35, 45, 10) |
| >>> output = m(input) |
| |
| >>> # network is nn.BatchNorm layer |
| >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group) |
| >>> # only single gpu per process is currently supported |
| >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel( |
| >>> sync_bn_network, |
| >>> device_ids=[args.local_rank], |
| >>> output_device=args.local_rank) |
| """ |
| |
| def __init__( |
| self, |
| num_features: int, |
| eps: float = 1e-5, |
| momentum: float = 0.1, |
| affine: bool = True, |
| track_running_stats: bool = True, |
| process_group: Optional[Any] = None |
| ) -> None: |
| super(SyncBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats) |
| self.process_group = process_group |
        # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
        # under the supported condition (a single GPU per process)
| self.ddp_gpu_size = None |
| |
| def _check_input_dim(self, input): |
| if input.dim() < 2: |
| raise ValueError('expected at least 2D input (got {}D input)' |
| .format(input.dim())) |
| |
| def _specify_ddp_gpu_num(self, gpu_size): |
| if gpu_size > 1: |
| raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process') |
| self.ddp_gpu_size = gpu_size |
| |
| def forward(self, input: Tensor) -> Tensor: |
| # currently only GPU input is supported |
| if not input.is_cuda: |
| raise ValueError('SyncBatchNorm expected input tensor to be on GPU') |
| |
| self._check_input_dim(input) |
| |
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in the ONNX graph when this node is exported to ONNX.
| if self.momentum is None: |
| exponential_average_factor = 0.0 |
| else: |
| exponential_average_factor = self.momentum |
| |
| if self.training and self.track_running_stats: |
| self.num_batches_tracked = self.num_batches_tracked + 1 |
| if self.momentum is None: # use cumulative moving average |
| exponential_average_factor = 1.0 / self.num_batches_tracked.item() |
| else: # use exponential moving average |
| exponential_average_factor = self.momentum |
| |
| r""" |
| Decide whether the mini-batch stats should be used for normalization rather than the buffers. |
| Mini-batch stats are used in training mode, and in eval mode when buffers are None. |
| """ |
| if self.training: |
| bn_training = True |
| else: |
| bn_training = (self.running_mean is None) and (self.running_var is None) |
| |
| r""" |
| Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be |
| passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are |
| used for normalization (i.e. in eval mode when buffers are not None). |
| """ |
| # If buffers are not to be tracked, ensure that they won't be updated |
| running_mean = self.running_mean if not self.training or self.track_running_stats else None |
| running_var = self.running_var if not self.training or self.track_running_stats else None |
| |
| need_sync = bn_training |
| if need_sync: |
| process_group = torch.distributed.group.WORLD |
| if self.process_group: |
| process_group = self.process_group |
| world_size = torch.distributed.get_world_size(process_group) |
| need_sync = world_size > 1 |
| |
| # fallback to framework BN when synchronization is not necessary |
| if not need_sync: |
| return F.batch_norm( |
| input, running_mean, running_var, self.weight, self.bias, |
| bn_training, exponential_average_factor, self.eps) |
| else: |
| if not self.ddp_gpu_size: |
| raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel') |
| |
| assert bn_training |
| return sync_batch_norm.apply( |
| input, self.weight, self.bias, running_mean, running_var, |
| self.eps, exponential_average_factor, process_group, world_size) |
| |
| @classmethod |
| def convert_sync_batchnorm(cls, module, process_group=None): |
| r"""Helper function to convert all :attr:`BatchNorm*D` layers in the model to |
| :class:`torch.nn.SyncBatchNorm` layers. |
| |
| Args: |
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
| process_group (optional): process group to scope synchronization, |
| default is the whole world |
| |
| Returns: |
| The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm` |
| layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer, |
| a new :class:`torch.nn.SyncBatchNorm` layer object will be returned |
| instead. |
| |
| Example:: |
| |
| >>> # Network with nn.BatchNorm layer |
| >>> module = torch.nn.Sequential( |
| >>> torch.nn.Linear(20, 100), |
| >>> torch.nn.BatchNorm1d(100), |
| >>> ).cuda() |
| >>> # creating process group (optional) |
| >>> # process_ids is a list of int identifying rank ids. |
| >>> process_group = torch.distributed.new_group(process_ids) |
| >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group) |
| |
| """ |
| module_output = module |
| if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): |
| module_output = torch.nn.SyncBatchNorm(module.num_features, |
| module.eps, module.momentum, |
| module.affine, |
| module.track_running_stats, |
| process_group) |
| if module.affine: |
| with torch.no_grad(): |
| module_output.weight = module.weight |
| module_output.bias = module.bias |
| module_output.running_mean = module.running_mean |
| module_output.running_var = module.running_var |
| module_output.num_batches_tracked = module.num_batches_tracked |
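        # recursively convert any BatchNorm*D children and re-register them
        # under their original attribute names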
| for name, child in module.named_children(): |
| module_output.add_module(name, cls.convert_sync_batchnorm(child, process_group)) |
| del module |
| return module_output |