Fix bug in zero_grad when some parameters don't require grad
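
Module.zero_grad() used to call p.grad.data.zero_() on every parameter,
which breaks for parameters that have been frozen with
requires_grad=False. It now skips parameters that do not require
gradients.

A minimal sketch of the affected usage, mirroring the setup in the new
test (the exact error raised before this change depends on how .grad is
initialized for frozen parameters):

    import torch.nn as nn

    module = nn.Linear(5, 5)
    for p in module.parameters():
        p.requires_grad = False  # freeze the whole module
    module.zero_grad()  # previously hit p.grad.data.zero_() on frozen params; now skipped
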
diff --git a/test/test_nn.py b/test/test_nn.py
index d852618..6f3bbc7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -341,6 +341,24 @@
         expected_grad = torch.ones(5, 5).mm(module.weight.data) * 2
         self.assertEqual(input.grad.data, expected_grad)
 
+    def test_zero_grad(self):
+        module = nn.Linear(5, 5)
+        for p in module.parameters():
+            p.requires_grad = False
+        module.zero_grad()
+
+        module.weight.requires_grad = True
+        module.weight.grad.data.fill_(1)
+        module.zero_grad()
+        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+
+        module.bias.requires_grad = True
+        module.weight.grad.data.fill_(1)
+        module.bias.grad.data.fill_(1)
+        module.zero_grad()
+        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
+
     def test_volatile(self):
         module = nn.Conv2d(2, 5, kernel_size=3, padding=1)
         input = torch.randn(1, 2, 10, 10)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 9f1e478..8ae422c 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -380,7 +380,8 @@
     def zero_grad(self):
         """Sets gradients of all model parameters to zero."""
         for p in self.parameters():
-            p.grad.data.zero_()
+            if p.requires_grad:
+                p.grad.data.zero_()
 
     def share_memory(self):
         return self._apply(lambda t: t.share_memory_())