Enable resetting of batchnorm running moments and cumulative ("simple") moving average (#5766)
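A minimal usage sketch of the two additions (assuming this patch is applied; nn.BatchNorm1d otherwise behaves as stock torch.nn):

import torch
import torch.nn as nn

# momentum=None switches the running-stat update from an exponential moving
# average to a cumulative ("simple") moving average over all batches seen.
bn = nn.BatchNorm1d(3, momentum=None)
bn.train()
for _ in range(5):
    bn(torch.rand(4, 3))        # each training pass increments num_batches_tracked
print(bn.num_batches_tracked)   # new buffer: a LongTensor now holding 5

# New method: restore running_mean, running_var and num_batches_tracked to
# their initial values (zeros, ones, zero) without touching weight/bias.
bn.reset_running_stats()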
diff --git a/test/expect/TestJit.test_batchnorm.expect b/test/expect/TestJit.test_batchnorm.expect
index 9616ff5..442ec52 100644
--- a/test/expect/TestJit.test_batchnorm.expect
+++ b/test/expect/TestJit.test_batchnorm.expect
@@ -2,7 +2,8 @@
%1 : Double(2)
%2 : Double(2)
%3 : Double(2)
- %4 : Double(2)) {
- %5 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
- return (%5);
+ %4 : Double(2)
+ %5 : Long(1)) {
+ %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+ return (%6);
}
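(The traced graph gains the extra %5 : Long(1) input because the new num_batches_tracked buffer, like the other module buffers, is lifted to a graph input when the module is traced; the aten::batch_norm call itself still consumes only %0-%4.)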
diff --git a/test/test_nn.py b/test/test_nn.py
index 32c79ac..87299e2 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1781,6 +1781,13 @@
def test_batchnorm_eval_cuda(self):
self._test_batchnorm_eval(torch.cuda.FloatTensor)
+ def test_batchnorm_simple_average(self):
+ self._test_batchnorm_simple_average()
+
+ @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+ def test_batchnorm_simple_average_cuda(self):
+ self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
+
def test_MaxPool1d_indices(self):
self._test_maxpool_indices(1)
@@ -1919,6 +1926,7 @@
for i, replica in enumerate(replicas):
self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+ self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_parallel_apply(self):
@@ -2235,7 +2243,7 @@
net.add_module('empty', None)
state_dict = net.state_dict()
- self.assertEqual(len(state_dict), 9)
+ self.assertEqual(len(state_dict), 10)
self.assertIn('linear1.weight', state_dict)
self.assertIn('linear1.bias', state_dict)
self.assertIn('linear2.weight', state_dict)
@@ -2247,6 +2255,7 @@
self.assertIn('bn.bias', state_dict)
self.assertIn('bn.running_var', state_dict)
self.assertIn('bn.running_mean', state_dict)
+ self.assertIn('bn.num_batches_tracked', state_dict)
self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys())))
for k, v in state_dict.items():
param = net
@@ -3693,17 +3702,21 @@
# training pass
old_running_mean = module.running_mean.clone()
old_running_var = module.running_var.clone()
+ old_num_batches_tracked = module.num_batches_tracked.clone()
module(data)
self.assertNotEqual(old_running_mean, module.running_mean)
self.assertNotEqual(old_running_var, module.running_var)
+ self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
# eval pass
module.eval()
old_running_mean = module.running_mean.clone()
old_running_var = module.running_var.clone()
+ old_num_batches_tracked = module.num_batches_tracked.clone()
module(data)
self.assertEqual(old_running_mean, module.running_mean)
self.assertEqual(old_running_var, module.running_var)
+ self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
def test_batchnorm_update_stats(self):
self._test_batchnorm_update_stats()
@@ -3792,6 +3805,48 @@
self.assertEqual(res1, res2)
self.assertEqual(grad1, grad2)
+ def _test_batchnorm_simple_average(self, test_type=torch.FloatTensor):
+ module = nn.BatchNorm1d(3, momentum=None).type(test_type)
+ zeros = torch.zeros(3).type(test_type)
+ ones = torch.ones(3).type(test_type)
+ self.assertEqual(module.running_mean, zeros)
+ self.assertEqual(module.running_var, ones)
+
+ data1 = torch.rand(4, 3).type(test_type)
+ data2 = torch.rand(4, 3).type(test_type)
+
+ # 1st pass
+ res1 = module(data1)
+ running_mean1 = module.running_mean.clone()
+ running_var1 = module.running_var.clone()
+ self.assertNotEqual(running_mean1, zeros)
+ self.assertNotEqual(running_var1, ones)
+
+ # reset stats
+ module.reset_running_stats()
+ self.assertEqual(module.running_mean, zeros)
+ self.assertEqual(module.running_var, ones)
+
+ # 2nd pass
+ res2 = module(data2)
+ running_mean2 = module.running_mean.clone()
+ running_var2 = module.running_var.clone()
+ self.assertNotEqual(running_mean2, zeros)
+ self.assertNotEqual(running_var2, ones)
+
+ # reset stats
+ module.reset_running_stats()
+ self.assertEqual(module.running_mean, zeros)
+ self.assertEqual(module.running_var, ones)
+
+ # 3rd (combined) pass
+ res3 = module(data1)
+ res4 = module(data2)
+ self.assertEqual(res3, res1)
+ self.assertEqual(res4, res2)
+ self.assertAlmostEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
+ self.assertAlmostEqual(module.running_var, (running_var1 + running_var2) / 2)
+
def test_pairwise_distance(self):
input1 = Variable(torch.randn(4, 4), requires_grad=True)
input2 = Variable(torch.randn(4, 4), requires_grad=True)
@@ -5479,6 +5534,14 @@
),
dict(
module_name='BatchNorm1d',
+ constructor_args=(10, 1e-3, None),
+ input_size=(4, 10),
+ cudnn=True,
+ check_eval=True,
+ desc='affine_simple_average',
+ ),
+ dict(
+ module_name='BatchNorm1d',
constructor_args=(10, 1e-3, 0.3, False),
input_size=(4, 10),
cudnn=True,
@@ -5510,6 +5573,14 @@
),
dict(
module_name='BatchNorm2d',
+ constructor_args=(3, 1e-3, None),
+ input_size=(2, 3, 6, 6),
+ cudnn=True,
+ check_eval=True,
+ desc='2d_simple_average',
+ ),
+ dict(
+ module_name='BatchNorm2d',
constructor_args=(3, 1e-3, 0.8),
input_size=(2, 3, 6, 6),
cudnn=True,
@@ -5541,6 +5612,14 @@
),
dict(
module_name='BatchNorm3d',
+ constructor_args=(3, 1e-3, None),
+ input_size=(2, 3, 4, 4, 4),
+ cudnn=True,
+ check_eval=True,
+ desc='3d_simple_average',
+ ),
+ dict(
+ module_name='BatchNorm3d',
constructor_args=(3, 1e-3, 0.7),
input_size=(2, 3, 4, 4, 4),
cudnn=True,
diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py
index afac23f..c270f05 100644
--- a/torch/nn/modules/batchnorm.py
+++ b/torch/nn/modules/batchnorm.py
@@ -25,15 +25,21 @@
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
+ self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
+ self.register_parameter('num_batches_tracked', None)
self.reset_parameters()
- def reset_parameters(self):
+ def reset_running_stats(self):
if self.track_running_stats:
self.running_mean.zero_()
self.running_var.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
if self.affine:
self.weight.data.uniform_()
self.bias.data.zero_()
@@ -44,9 +50,19 @@
def forward(self, input):
self._check_input_dim(input)
+ exponential_average_factor = 0.0
+
+ if self.training and self.track_running_stats:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
- self.training or not self.track_running_stats, self.momentum, self.eps)
+ self.training or not self.track_running_stats,
+ exponential_average_factor, self.eps)
def __repr__(self):
return ('{name}({num_features}, eps={eps}, momentum={momentum},'
@@ -93,7 +109,8 @@
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
- computation. Default: 0.1
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
@@ -162,7 +179,8 @@
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
- computation. Default: 0.1
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
@@ -232,7 +250,8 @@
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
- computation. Default: 0.1
+ computation. Can be set to ``None`` for cumulative moving average
+ (i.e. simple average). Default: 0.1
affine: a boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
track_running_stats: a boolean value that when set to ``True``, this
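The forward() change above keeps the module-level momentum attribute (and its repr) intact and instead passes a concrete exponential_average_factor to F.batch_norm. A free-standing sketch of just that selection logic; the helper name _select_average_factor is hypothetical, not part of the patch:

def _select_average_factor(momentum, num_batches_tracked):
    # Mirrors the branch added to _BatchNorm.forward(): a float momentum
    # means an exponential moving average; None means the cumulative
    # average, whose per-batch weight shrinks as more batches are tracked.
    if momentum is None:
        return 1.0 / max(1, num_batches_tracked)
    return momentum

assert _select_average_factor(0.1, 7) == 0.1
assert _select_average_factor(None, 1) == 1.0   # first batch: running stats = batch stats
assert _select_average_factor(None, 4) == 0.25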