Enable resetting of batchnorm running moments and cumulative ("simple") moving average (#5766)

diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect
index 9616ff5..442ec52 100644
--- a/test/expect/TestJit.test_batchnorm.expect
+++ b/test/expect/TestJit.test_batchnorm.expect
@@ -2,7 +2,8 @@
       %1 : Double(2)
       %2 : Double(2)
       %3 : Double(2)
-      %4 : Double(2)) {
-  %5 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
-  return (%5);
+      %4 : Double(2)
+      %5 : Long(1)) {
+  %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+  return (%6);
 }
diff --git a/test/test_nn.py b/test/test_nn.py
index 32c79ac..87299e2 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1781,6 +1781,13 @@
     def test_batchnorm_eval_cuda(self):
         self._test_batchnorm_eval(torch.cuda.FloatTensor)
 
+    def test_batchnorm_simple_average(self):
+        self._test_batchnorm_simple_average()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_batchnorm_simple_average_cuda(self):
+        self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
+
     def test_MaxPool1d_indices(self):
         self._test_maxpool_indices(1)
 
@@ -1919,6 +1926,7 @@
         for i, replica in enumerate(replicas):
             self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
             self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+            self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
@@ -2235,7 +2243,7 @@
         net.add_module('empty', None)
 
         state_dict = net.state_dict()
-        self.assertEqual(len(state_dict), 9)
+        self.assertEqual(len(state_dict), 10)
         self.assertIn('linear1.weight', state_dict)
         self.assertIn('linear1.bias', state_dict)
         self.assertIn('linear2.weight', state_dict)
@@ -2247,6 +2255,7 @@
         self.assertIn('bn.bias', state_dict)
         self.assertIn('bn.running_var', state_dict)
         self.assertIn('bn.running_mean', state_dict)
+        self.assertIn('bn.num_batches_tracked', state_dict)
         self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys())))
         for k, v in state_dict.items():
             param = net
@@ -3693,17 +3702,21 @@
         # training pass
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertNotEqual(old_running_mean, module.running_mean)
         self.assertNotEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
 
         # eval pass
         module.eval()
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertEqual(old_running_mean, module.running_mean)
         self.assertEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
 
     def test_batchnorm_update_stats(self):
         self._test_batchnorm_update_stats()
@@ -3792,6 +3805,48 @@
         self.assertEqual(res1, res2)
         self.assertEqual(grad1, grad2)
 
+    def _test_batchnorm_simple_average(self, test_type=torch.FloatTensor):
+        module = nn.BatchNorm1d(3, momentum=None).type(test_type)
+        zeros = torch.zeros(3).type(test_type)
+        ones = torch.ones(3).type(test_type)
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        data1 = torch.rand(4, 3).type(test_type)
+        data2 = torch.rand(4, 3).type(test_type)
+
+        # 1st pass
+        res1 = module(data1)
+        running_mean1 = module.running_mean.clone()
+        running_var1 = module.running_var.clone()
+        self.assertNotEqual(running_mean1, zeros)
+        self.assertNotEqual(running_var1, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 2nd pass
+        res2 = module(data2)
+        running_mean2 = module.running_mean.clone()
+        running_var2 = module.running_var.clone()
+        self.assertNotEqual(running_mean2, zeros)
+        self.assertNotEqual(running_var2, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 3rd (combined) pass
+        res3 = module(data1)
+        res4 = module(data2)
+        self.assertEqual(res3, res1)
+        self.assertEqual(res4, res2)
+        self.assertAlmostEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
+        self.assertAlmostEqual(module.running_var, (running_var1 + running_var2) / 2)
+
     def test_pairwise_distance(self):
         input1 = Variable(torch.randn(4, 4), requires_grad=True)
         input2 = Variable(torch.randn(4, 4), requires_grad=True)
@@ -5479,6 +5534,14 @@
     ),
     dict(
         module_name='BatchNorm1d',
+        constructor_args=(10, 1e-3, None),
+        input_size=(4, 10),
+        cudnn=True,
+        check_eval=True,
+        desc='affine_simple_average',
+    ),
+    dict(
+        module_name='BatchNorm1d',
         constructor_args=(10, 1e-3, 0.3, False),
         input_size=(4, 10),
         cudnn=True,
@@ -5510,6 +5573,14 @@
     ),
     dict(
         module_name='BatchNorm2d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+        desc='2d_simple_average',
+    ),
+    dict(
+        module_name='BatchNorm2d',
         constructor_args=(3, 1e-3, 0.8),
         input_size=(2, 3, 6, 6),
         cudnn=True,
@@ -5541,6 +5612,14 @@
     ),
     dict(
         module_name='BatchNorm3d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_simple_average',
+    ),
+    dict(
+        module_name='BatchNorm3d',
         constructor_args=(3, 1e-3, 0.7),
         input_size=(2, 3, 4, 4, 4),
         cudnn=True,
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index afac23f..c270f05 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -25,15 +25,21 @@
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features))
             self.register_buffer('running_var', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
         else:
             self.register_parameter('running_mean', None)
             self.register_parameter('running_var', None)
+            self.register_parameter('num_batches_tracked', None)
         self.reset_parameters()
 
-    def reset_parameters(self):
+    def reset_running_stats(self):
         if self.track_running_stats:
             self.running_mean.zero_()
             self.running_var.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        self.reset_running_stats()
         if self.affine:
             self.weight.data.uniform_()
             self.bias.data.zero_()
@@ -44,9 +50,19 @@
     def forward(self, input):
         self._check_input_dim(input)
 
+        exponential_average_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
-            self.training or not self.track_running_stats, self.momentum, self.eps)
+            self.training or not self.track_running_stats,
+            exponential_average_factor, self.eps)
 
     def __repr__(self):
         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
@@ -93,7 +109,8 @@
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -162,7 +179,8 @@
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -232,7 +250,8 @@
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this