Clarify track_running_stats docs; Make SyncBatchNorm track_running_stats behavior consistent (#44445)
Summary:
context: https://github.com/pytorch/pytorch/pull/38084
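A minimal sketch of the behavior the clarified docstring describes (the module choice and shapes below are illustrative, assuming a build that includes this change): with ``track_running_stats=False`` the ``running_mean``/``running_var`` buffers are initialized to ``None``, so the module normalizes with batch statistics in both training and eval modes.

```python
import torch
import torch.nn as nn

# With track_running_stats=False the statistics buffers are registered as None.
bn = nn.BatchNorm1d(4, track_running_stats=False)
assert bn.running_mean is None and bn.running_var is None

x = torch.randn(8, 4)
bn.train()
out_train = bn(x)
bn.eval()
out_eval = bn(x)

# Since no buffers are tracked, eval mode keeps using per-batch statistics
# and produces the same output as training mode for the same input.
assert torch.allclose(out_train, out_eval)
```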
Fixes #{issue number}
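The SyncBatchNorm side of the change reuses the same decision logic as the plain BatchNorm modules. Roughly, as a simplified standalone restatement (the helper name is illustrative, not the exact code in the diff):

```python
# Simplified restatement of the decision logic this patch makes SyncBatchNorm
# share with BatchNormNd (illustrative helper, not part of the patch itself).
def decide_bn_mode(training, track_running_stats, running_mean, running_var):
    # Mini-batch stats are used in training mode, and in eval mode when the
    # buffers are None (i.e. track_running_stats=False).
    bn_training = training or (running_mean is None and running_var is None)

    # Buffers are passed only when they may be updated (training while tracked)
    # or when they are needed for normalization (eval mode with buffers present).
    mean = running_mean if (not training or track_running_stats) else None
    var = running_var if (not training or track_running_stats) else None

    # Cross-process synchronization is only needed when batch statistics are
    # actually computed; otherwise forward falls back to F.batch_norm.
    need_sync = bn_training
    return bn_training, mean, var, need_sync
```

In particular, an eval-mode SyncBatchNorm with tracked running stats now takes the non-sync F.batch_norm path with ``bn_training=False``, matching the other BatchNorm modules.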
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44445
Reviewed By: colesbury
Differential Revision: D23634216
Pulled By: mrshenli
fbshipit-source-id: d1242c694dec0e7794651f8031327625eb9989ee
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index ae316ab..d13bcef 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -116,7 +116,8 @@
else: # use exponential moving average
exponential_average_factor = self.momentum
- """ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+ r"""
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
@@ -124,7 +125,8 @@
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
- """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+ r"""
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
@@ -184,8 +186,10 @@
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
- this module does not track such statistics and uses batch statistics instead
- in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+ this module does not track such statistics, and initializes statistics
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+ When these buffers are ``None``, this module always uses batch statistics
+ in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
@@ -255,8 +259,10 @@
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
- this module does not track such statistics and uses batch statistics instead
- in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+ this module does not track such statistics, and initializes statistics
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+ When these buffers are ``None``, this module always uses batch statistics
+ in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
@@ -327,8 +333,10 @@
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
- this module does not track such statistics and uses batch statistics instead
- in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+ this module does not track such statistics, and initializes statistics
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+ When these buffers are ``None``, this module always uses batch statistics
+ in both training and eval modes. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
@@ -407,8 +415,10 @@
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
module tracks the running mean and variance, and when set to ``False``,
- this module does not track such statistics and uses batch statistics instead
- in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+ this module does not track such statistics, and initializes statistics
+ buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+ When these buffers are ``None``, this module always uses batch statistics
+ in both training and eval modes. Default: ``True``
process_group: synchronization of stats happen within each process group
individually. Default behavior is synchronization across the whole
world
@@ -485,7 +495,25 @@
else: # use exponential moving average
exponential_average_factor = self.momentum
- need_sync = self.training or not self.track_running_stats
+ r"""
+ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+ Mini-batch stats are used in training mode, and in eval mode when buffers are None.
+ """
+ if self.training:
+ bn_training = True
+ else:
+ bn_training = (self.running_mean is None) and (self.running_var is None)
+
+ r"""
+ Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+ passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
+ used for normalization (i.e. in eval mode when buffers are not None).
+ """
+ # If buffers are not to be tracked, ensure that they won't be updated
+ running_mean = self.running_mean if not self.training or self.track_running_stats else None
+ running_var = self.running_var if not self.training or self.track_running_stats else None
+
+ need_sync = bn_training
if need_sync:
process_group = torch.distributed.group.WORLD
if self.process_group:
@@ -496,15 +524,15 @@
# fallback to framework BN when synchronization is not necessary
if not need_sync:
return F.batch_norm(
- input, self.running_mean, self.running_var, self.weight, self.bias,
- self.training or not self.track_running_stats,
- exponential_average_factor, self.eps)
+ input, running_mean, running_var, self.weight, self.bias,
+ bn_training, exponential_average_factor, self.eps)
else:
if not self.ddp_gpu_size:
raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
+ assert bn_training
return sync_batch_norm.apply(
- input, self.weight, self.bias, self.running_mean, self.running_var,
+ input, self.weight, self.bias, running_mean, running_var,
self.eps, exponential_average_factor, process_group, world_size)
@classmethod