Fix group-convolution w/o biases on CPU. (#1273)
* Fix group-convolution w/o biases on CPU.
Not having this guard causes a crash further down in the `cat` function,
which uses the first element of the passed list to create a new tensor
(and even past that point, `cat` doesn't handle nulls well). A minimal
reproduction is sketched below, before the diff.
* Added test for group convolution w/o bias on CPU.
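For context, a minimal reproduction of the pre-fix crash, sketched with the
same Variable-era API as the new test; with the guard in place the backward
pass completes instead of crashing:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    # Grouped convolution with bias disabled: before this fix, the CPU
    # backward pass called cat() on a list of null per-group bias
    # gradients and crashed.
    m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False)
    i = Variable(torch.randn(2, 4, 6, 6), requires_grad=True)
    output = m(i)
    # This backward call is the one that used to crash on CPU.
    output.backward(torch.randn(2, 4, 4, 4))
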
diff --git a/test/test_nn.py b/test/test_nn.py
index 179fdf5..e1830ee 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1309,6 +1309,33 @@
self.assertEqual(m.weight.grad.data,
torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0))
+ # For https://github.com/pytorch/pytorch/pull/1273
+ # Almost identical to the above `test_Conv2d_naive_groups`
+ def test_Conv2d_groups_nobias(self):
+ m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False)
+ i = Variable(torch.randn(2, 4, 6, 6), requires_grad=True)
+ output = m(i)
+ grad_output = torch.randn(2, 4, 4, 4)
+ output.backward(grad_output)
+
+ m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False)
+ m1.weight.data.copy_(m.weight.data[:2])
+ i1 = Variable(i.data[:, :2].contiguous(), requires_grad=True)
+ output1 = m1(i1)
+ output1.backward(grad_output[:, :2].contiguous())
+
+ m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False)
+ m2.weight.data.copy_(m.weight.data[2:])
+ i2 = Variable(i.data[:, 2:].contiguous(), requires_grad=True)
+ output2 = m2(i2)
+ output2.backward(grad_output[:, 2:].contiguous())
+
+ self.assertEqual(output, torch.cat([output1, output2], 1))
+ self.assertEqual(i.grad.data,
+ torch.cat([i1.grad.data, i2.grad.data], 1))
+ self.assertEqual(m.weight.grad.data,
+ torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0))
+
def test_MaxUnpool2d_output_size(self):
m = nn.MaxPool2d(3, stride=2, return_indices=True)
mu = nn.MaxUnpool2d(3, stride=2)
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
index 8992f61..ba5c467 100644
--- a/torch/csrc/autograd/functions/convolution.cpp
+++ b/torch/csrc/autograd/functions/convolution.cpp
@@ -292,7 +292,9 @@
columns[g].get(), ones[g].get(), kernel_size, *this);
}
grad_weight = cat(grad_weights, 0);
- grad_bias = cat(grad_biases, 0);
+ if (bias && needs_input_grad(2)) {
+ grad_bias = cat(grad_biases, 0);
+ }
}
}