Clarify track_running_stats docs; Make SyncBatchNorm track_running_stats behavior consistent (#44445)

Summary:
context: https://github.com/pytorch/pytorch/pull/38084

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44445

Reviewed By: colesbury

Differential Revision: D23634216

Pulled By: mrshenli

fbshipit-source-id: d1242c694dec0e7794651f8031327625eb9989ee
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index ae316ab..d13bcef 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -116,7 +116,8 @@
                 else:  # use exponential moving average
                     exponential_average_factor = self.momentum
 
-        """ Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+        r"""
+        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
         Mini-batch stats are used in training mode, and in eval mode when buffers are None.
         """
         if self.training:
@@ -124,7 +125,8 @@
         else:
             bn_training = (self.running_mean is None) and (self.running_var is None)
 
-        """Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+        r"""
+        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
@@ -184,8 +186,10 @@
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
-            this module does not track such statistics and uses batch statistics instead
-            in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C)` or :math:`(N, C, L)`
@@ -255,8 +259,10 @@
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
-            this module does not track such statistics and uses batch statistics instead
-            in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C, H, W)`
@@ -327,8 +333,10 @@
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
-            this module does not track such statistics and uses batch statistics instead
-            in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
 
     Shape:
         - Input: :math:`(N, C, D, H, W)`
@@ -407,8 +415,10 @@
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
             module tracks the running mean and variance, and when set to ``False``,
-            this module does not track such statistics and uses batch statistics instead
-            in both training and eval modes if the running mean and variance are ``None``. Default: ``True``
+            this module does not track such statistics, and initializes statistics
+            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
+            When these buffers are ``None``, this module always uses batch statistics
+            in both training and eval modes. Default: ``True``
         process_group: synchronization of stats happen within each process group
             individually. Default behavior is synchronization across the whole
             world
@@ -485,7 +495,25 @@
             else:  # use exponential moving average
                 exponential_average_factor = self.momentum
 
-        need_sync = self.training or not self.track_running_stats
+        r"""
+        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
+        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
+        """
+        if self.training:
+            bn_training = True
+        else:
+            bn_training = (self.running_mean is None) and (self.running_var is None)
+
+        r"""
+        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
+        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
+        used for normalization (i.e. in eval mode when buffers are not None).
+        """
+        # If buffers are not to be tracked, ensure that they won't be updated
+        running_mean = self.running_mean if not self.training or self.track_running_stats else None
+        running_var = self.running_var if not self.training or self.track_running_stats else None
+
+        need_sync = bn_training
         if need_sync:
             process_group = torch.distributed.group.WORLD
             if self.process_group:
@@ -496,15 +524,15 @@
         # fallback to framework BN when synchronization is not necessary
         if not need_sync:
             return F.batch_norm(
-                input, self.running_mean, self.running_var, self.weight, self.bias,
-                self.training or not self.track_running_stats,
-                exponential_average_factor, self.eps)
+                input, running_mean, running_var, self.weight, self.bias,
+                bn_training, exponential_average_factor, self.eps)
         else:
             if not self.ddp_gpu_size:
                 raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
 
+            assert bn_training
             return sync_batch_norm.apply(
-                input, self.weight, self.bias, self.running_mean, self.running_var,
+                input, self.weight, self.bias, running_mean, running_var,
                 self.eps, exponential_average_factor, process_group, world_size)
 
     @classmethod