Fix bug in zero_grad when some parameters don't require grad
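
A minimal sketch of the failure mode this patch guards against, mirroring the new test_zero_grad test below (the layer type and sizes are taken from that test and are only illustrative). The old loop called p.grad.data.zero_() unconditionally, which breaks for parameters that do not require gradients and therefore never have a gradient to zero:

    import torch.nn as nn

    module = nn.Linear(5, 5)
    for p in module.parameters():
        p.requires_grad = False   # freeze everything, so no gradient is ever populated
    module.zero_grad()            # previously hit the bug; now skips frozen parameters

With the requires_grad check, zero_grad() only touches gradients of parameters that actually require them, so modules with partially or fully frozen parameters can be zeroed safely.
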
diff --git a/test/test_nn.py b/test/test_nn.py
index d852618..6f3bbc7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -341,6 +341,24 @@
         expected_grad = torch.ones(5, 5).mm(module.weight.data) * 2
         self.assertEqual(input.grad.data, expected_grad)
 
+    def test_zero_grad(self):
+        module = nn.Linear(5, 5)
+        for p in module.parameters():
+            p.requires_grad = False
+        module.zero_grad()
+
+        module.weight.requires_grad = True
+        module.weight.grad.data.fill_(1)
+        module.zero_grad()
+        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+
+        module.bias.requires_grad = True
+        module.weight.grad.data.fill_(1)
+        module.bias.grad.data.fill_(1)
+        module.zero_grad()
+        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+        self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
+
     def test_volatile(self):
         module = nn.Conv2d(2, 5, kernel_size=3, padding=1)
         input = torch.randn(1, 2, 10, 10)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 9f1e478..8ae422c 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -380,7 +380,8 @@
     def zero_grad(self):
         """Sets gradients of all model parameters to zero."""
         for p in self.parameters():
-            p.grad.data.zero_()
+            if p.requires_grad:
+                p.grad.data.zero_()
 
     def share_memory(self):
         return self._apply(lambda t: t.share_memory_())