relax index dim check
diff --git a/aten/src/TH/generic/THTensorMath.c b/aten/src/TH/generic/THTensorMath.c
index ce56ee0..51aed3f 100644
--- a/aten/src/TH/generic/THTensorMath.c
+++ b/aten/src/TH/generic/THTensorMath.c
@@ -245,9 +245,9 @@
int64_t *index_data;
real *tensor_data, *src_data;
- THArgCheck(index->nDimension == 1, 3, "Index is supposed to be a vector");
- THArgCheck(dim < src->nDimension, 4,"Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
- THArgCheck(src->nDimension > 0,2,"Source tensor is empty");
+ THArgCheck(index->nDimension <= 1, 3, "Index is supposed to be an empty tensor or a vector");
+ THArgCheck(dim < src->nDimension, 4, "Indexing dim %d is out of bounds of tensor", dim + TH_INDEX_BASE);
+ THArgCheck(src->nDimension > 0, 2, "Source tensor is empty");
numel = THLongTensor_nElement(index);
diff --git a/aten/src/THC/generic/THCTensorIndex.cu b/aten/src/THC/generic/THCTensorIndex.cu
index 69e4671..523657f 100644
--- a/aten/src/THC/generic/THCTensorIndex.cu
+++ b/aten/src/THC/generic/THCTensorIndex.cu
@@ -437,7 +437,7 @@
void THCTensor_(indexSelect_long)(THCState *state, THCTensor *dst, THCTensor *src, int dim, THLongTensor *indices)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, dst, src));
- THArgCheck(indices->nDimension == 1, 3, "Index is supposed to be a vector");
+ THArgCheck(indices->nDimension <= 1, 3, "Index is supposed to be an empty tensor or a vector");
THCudaLongTensor *indices_ = THCudaLongTensor_newWithSize1d(state, indices->size[0]);
THCudaLongTensor_copyLong(state, indices_, indices);
@@ -463,12 +463,22 @@
int srcDims = THCTensor_(nDimension)(state, src);
cudaStream_t stream = THCState_getCurrentStream(state);
- THArgCheck(THCudaLongTensor_nDimension(state, indices) == 1, 3,
- "expecting vector of indices");
+ THArgCheck(THCudaLongTensor_nDimension(state, indices) <= 1, 3,
+ "Index is supposed to be an empty tensor or a vector");
THArgCheck(dim < srcDims, 4, "Indexing dim is out of bounds");
THArgCheck(srcDims > 0, 2, "Source tensor is empty");
- THLongStorage *newSize = THCTensor_(newSizeOf)(state, src);
+ THLongStorage *newSize;
+
+ if (numIndices == 0) {
+ newSize = THCTensor_(newSizeOf)(state, src);
+ THLongStorage_set(newSize, 0, numIndices);
+ THCTensor_(resize)(state, dst, newSize, NULL);
+ THLongStorage_free(newSize);
+ return;
+ }
+
+ newSize = THCTensor_(newSizeOf)(state, src);
THLongStorage_set(newSize, dim, numIndices);
THCTensor_(resize)(state, dst, newSize, NULL);
THLongStorage_free(newSize);
diff --git a/test/test_torch.py b/test/test_torch.py
index e030a90..7d15e2c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2729,6 +2729,10 @@
def test_index(self):
reference = self._consecutive((3, 3, 3))
+
+ # empty tensor indexing
+ self.assertEqual(reference[torch.LongTensor()], reference.new())
+
self.assertEqual(reference[0], self._consecutive((3, 3)), 0)
self.assertEqual(reference[1], self._consecutive((3, 3), 10), 0)
self.assertEqual(reference[2], self._consecutive((3, 3), 19), 0)
diff --git a/torch/csrc/generic/Tensor.cpp b/torch/csrc/generic/Tensor.cpp
index 98a731e..099db36 100644
--- a/torch/csrc/generic/Tensor.cpp
+++ b/torch/csrc/generic/Tensor.cpp
@@ -1486,15 +1486,15 @@
// TH will also throw an error, but its a Runtime Error that is less interpretable
// than doing it at this layer
- if (THIndexTensor_(nDimension)(LIBRARY_STATE index_t) != 1) {
+ if (THIndexTensor_(nDimension)(LIBRARY_STATE index_t) > 1) {
PyErr_Format(PyExc_IndexError, "Indexing a Tensor with a "
#ifndef THC_GENERIC_FILE
"torch.LongTensor "
#else
"torch.cuda.LongTensor "
#endif
- "triggers index_select semantics, and thus we expect a vector, but the indexing "
- "Tensor passed has %lld dimensions",
+ "triggers index_select semantics, and thus we expect an empty tensor or a vector, "
+ "but the indexing Tensor passed has %lld dimensions",
(long long) THIndexTensor_(nDimension)(LIBRARY_STATE index_t));
throw python_error();
}