Patching EmeddingBag to accept 2D input (#2429)
* Patching EmeddingBag to accept 2D input
* fix for CUDA inputs
* fix lint
diff --git a/test/test_nn.py b/test/test_nn.py
index 6a58d66..9c1f3d9 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -868,29 +868,29 @@
# check a known test example
es = nn.EmbeddingBag(5, 2, mode=mode)
es.weight.data.copy_(torch.arange(1, 11).resize_as_(es.weight.data))
- input = Variable(torch.LongTensor([3, 1, 1, 1, 4]))
- offsets = Variable(torch.LongTensor([0, 2]))
+ input = Variable(torch.LongTensor([3, 1, 1, 1, 4, 0]))
+ offsets = Variable(torch.LongTensor([0, 3]))
grad_output = torch.arange(1, 5).view(2, 2).type(torch.Tensor)
if mode == 'sum':
expected_output = torch.Tensor(
- [[10, 12],
- [15, 18]])
+ [[13, 16],
+ [13, 16]])
expected_grad_weight = torch.Tensor(
- [[0, 0],
- [7, 10],
+ [[3, 4],
+ [5, 8],
[0, 0],
[1, 2],
[3, 4]])
else:
expected_output = torch.Tensor(
- [[10. / 2, 12. / 2],
- [15. / 3, 18. / 3]])
+ [[13. / 3, 16. / 3],
+ [13. / 3, 16. / 3]])
expected_grad_weight = torch.Tensor(
- [[0., 0.],
- [1. / 2 + 3. / 3 + 3. / 3, 2. / 2 + 4. / 3 + 4. / 3],
+ [[3. / 3, 4. / 3],
+ [1. / 3 + 1. / 3 + 3. / 3, 2. / 3 + 2. / 3 + 4. / 3],
[0., 0.],
- [1. / 2, 2. / 2],
+ [1. / 3, 2. / 3],
[3. / 3, 4. / 3]])
if cuda:
@@ -907,6 +907,15 @@
self.assertEqual(output.data, expected_output)
self.assertEqual(es.weight.grad.data, expected_grad_weight)
+ # check same example except as 2D (2 x 3)
+ input = Variable(input.data.view(2, -1))
+ es.zero_grad()
+ output = es(input)
+ output.backward(grad_output)
+
+ self.assertEqual(output.data, expected_output)
+ self.assertEqual(es.weight.grad.data, expected_grad_weight)
+
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
def _test_vs_Embedding(N, D, B, L):
es = nn.EmbeddingBag(N, D, mode=mode)
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 877ff5b..78d5b7c 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -1,5 +1,6 @@
import torch
from torch.nn.parameter import Parameter
+from torch.autograd import Variable
from .module import Module
@@ -198,7 +199,9 @@
" fixed length sequences. However, found "
"offsets of type {}".format(type(offsets)))
else:
- offsets = input.data.new(input.size(0)).fill_(input.size(1))
+ offsets = Variable(torch.arange(0, input.numel(), input.size(1),
+ out=input.data.new().long()))
+ input = input.view(-1)
elif input.dim() != 1:
raise ValueError("input has to be 1D or 2D Tensor,"
" but got Tensor of dimension {}".format(input.dim()))