allow single non-tuple sequence to trigger advanced indexing (#2323)
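
Previously, advanced indexing was only triggered when the index was a
tuple containing sequences or LongTensors. With this change, a single
non-tuple sequence of integers used directly as the index also triggers
advanced indexing, applied along the first dimension. For example,
mirroring the new tests below:

    x = torch.arange(1, 11)  # 1D Tensor holding 1..10
    x[[0]]                   # torch.Tensor([1])
    x[[2, 3, 4]]             # torch.Tensor([3, 4, 5])
    x[[2, 3, 4]] = 0         # advanced-index assignment works as well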
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 0883009..d60f6e7 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -601,6 +601,7 @@
check_index(x, y, ([[2, 3], slice(None)]))
# advanced indexing, with less dim, or ellipsis
+ check_index(x, y, ([0]))
check_index(x, y, ([0], ))
x = torch.arange(1, 49).view(4, 3, 4)
diff --git a/test/test_torch.py b/test/test_torch.py
index f1354b3..24ad99d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2658,14 +2658,20 @@
# Case 1: Purely Integer Array Indexing
reference = conv_fn(consec((10,)))
+ self.assertEqual(reference[[0]], consec((1,)))
self.assertEqual(reference[ri([0]), ], consec((1,)))
self.assertEqual(reference[ri([3]), ], consec((1,), 4))
+ self.assertEqual(reference[[2, 3, 4]], consec((3,), 3))
self.assertEqual(reference[ri([2, 3, 4]), ], consec((3,), 3))
self.assertEqual(reference[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5]))
# setting values
- reference[ri([0],), ] = -1
+ reference[[0]] = -2
+ self.assertEqual(reference[[0]], torch.Tensor([-2]))
+ reference[[0]] = -1
self.assertEqual(reference[ri([0]), ], torch.Tensor([-1]))
+ reference[[2, 3, 4]] = 4
+ self.assertEqual(reference[[2, 3, 4]], torch.Tensor([4, 4, 4]))
reference[ri([2, 3, 4]), ] = 3
self.assertEqual(reference[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]))
reference[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3]))
@@ -2679,8 +2685,10 @@
strided.set_(reference.storage(), storage_offset=0,
size=torch.Size([4]), stride=[2])
+ self.assertEqual(strided[[0]], torch.Tensor([1]))
self.assertEqual(strided[ri([0]), ], torch.Tensor([1]))
self.assertEqual(strided[ri([3]), ], torch.Tensor([7]))
+ self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5]))
self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5]))
self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
torch.Tensor([[5, 3], [1, 7]]))
@@ -2689,8 +2697,10 @@
strided = conv_fn(torch.Tensor())
strided.set_(reference.storage(), storage_offset=4,
size=torch.Size([2]), stride=[4])
+ self.assertEqual(strided[[0]], torch.Tensor([5]))
self.assertEqual(strided[ri([0]), ], torch.Tensor([5]))
self.assertEqual(strided[ri([1]), ], torch.Tensor([9]))
+ self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9]))
self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9]))
self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
torch.Tensor([[5, 9], [9, 5]]))
diff --git a/torch/csrc/generic/Tensor.cpp b/torch/csrc/generic/Tensor.cpp
index 4a0a1fe..bb3bcc5 100644
--- a/torch/csrc/generic/Tensor.cpp
+++ b/torch/csrc/generic/Tensor.cpp
@@ -587,6 +587,22 @@
#ifndef TH_REAL_IS_HALF
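+ // Returns true if arg is a single non-tuple sequence composed entirely of
+ // integer indexers, e.g. the [0, 1, 4] in x[[0, 1, 4]]. Such an index
+ // triggers advanced indexing along the first dimension of the Tensor.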
+static bool THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(PyObject *arg) {
+ if (PySequence_Check(arg) && !PyTuple_Check(arg)) {
+ auto fast = THPObjectPtr(PySequence_Fast(arg, NULL));
+ if (!fast) { PyErr_Clear(); return false; }
+ for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(fast.get()); ++i) {
+ if (!THPUtils_checkLong(PySequence_Fast_GET_ITEM(fast.get(), i)))
+ return false;
+ }
+ return true;
+ }
+ return false;
+}
+
static bool THPTensor_(_checkBasicIntegerArrayIndexing)(THPTensor *indexed, PyObject *arg) {
long ndim = THTensor_(nDimension)(LIBRARY_STATE indexed->cdata);
@@ -606,9 +622,12 @@
static bool THPTensor_(_checkAdvancedIndexing)(THPTensor *indexed, PyObject *arg) {
- // Currently we only support two forms of advanced indexing:
+ // Currently we only support three forms of advanced indexing:
//
- // 1. "Basic Integer Array Indexing" the integer-array indexing strategy
+ // 1. Indexing with a single top-level (non-tuple) sequence that is composed
+ //    only of integer indexers, e.g. x[[0, 1, 4]]
+ // 2. "Basic Integer Array Indexing", the integer-array indexing strategy
// where we have ndim sequence/LongTensor arguments
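+ //    e.g. x[[0, 2], [1, 1]] for a Tensor with 2 dimensions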
- // 2. Combining Advanced Indexing with ":", or "..." , with the limitation that
+ // 3. Combining Advanced Indexing with ":" or "...", with the limitation that
// the advanced indexing dimensions must be adjacent, i.e.:
//
// x[:, :, [1,2], [3,4], :] --> valid
@@ -616,10 +635,13 @@
// x[[1,2], [3,4], ...] --> valid
// x[:, [1,2], :, [3,4], :] --> not valid
- // Verification, Step #1 -- ndim sequencers
+ // Verification, Step #1 -- single non-tuple sequencer
+ if (THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(arg)) return true;
+
+ // Verification, Step #2 -- ndim sequencers
if (THPTensor_(_checkBasicIntegerArrayIndexing)(indexed, arg)) return true;
- // Verification, Step #2 -- at least one sequencer, all the rest are
+ // Verification, Step #3 -- at least one sequencer, all the rest are
// ':' and/or a single '...', can be less than ndim indexers, all sequencers
// adjacent
@@ -737,58 +759,68 @@
// If they can be broadcasted, we store each of the broadcasted Tensors in the
// output map, with the dimension of the original tensor as the key.
- // Indexes all indexing Tensors (pre-broadcast) by which dimension they occurred.
- // Because we rely upon the THPIndexTensor constructor to handle sequence -> tensor
- // conversions, we store THPTensors rather than THTensors. We use an ordered map
- // to maintain the order of Tensors via dimension. Because this is limited to
- // ndim(Tensor), it should always be small + fast.
+ // indexingDims stores the dimensions at which an advanced indexing sequence
+ // occurs, and indexers stores the corresponding indexing object, such that the
+ // indexer at indexers[i] is associated with the dim at indexingDims[i]. This is
+ // pre-broadcast. Because we rely upon the THPIndexTensor constructor to handle
+ // sequence -> tensor conversions, we store THPTensors rather than THTensors.
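+ // For example, indexing a 3-dim Tensor x as x[:, [0, 1], [2, 3]] would yield
+ // indexingDims = {1, 2}, with indexers holding the two converted THPIndexTensors.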
std::vector<Py_ssize_t> indexingDims;
std::vector<THPIndexTensor*>indexers;
- // The indexing matches advanced indexing requirements. In the case that
- // the user has an Ellipsis, and/or less dimensions than are in the
- // Tensor being indexed, we "fill in" empty Slices to these dimensions
- // so that the the resulting advanced indexing code still works
-
-
-
- // The top-level indexer should be a sequence, per the check above
- THPObjectPtr fast(PySequence_Fast(index, NULL));
- sequenceLength = PySequence_Fast_GET_SIZE(fast.get());
- int ellipsisOffset = 0;
-
- for (Py_ssize_t i = 0; i < sequenceLength; ++i) {
- PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
-
- // If this is an ellipsis, the all subsequent advanced indexing
- // objects "positions" should be shifted, e.g. if we have a 5D Tensor
- // x, and then x[..., [2, 3]], then the "position" of [2, 3] is 4
- if (Py_TYPE(item) == &PyEllipsis_Type) {
- ellipsisOffset = THTensor_(nDimension)(LIBRARY_STATE indexed) - sequenceLength;
- continue;
+ if (THPTensor_(_checkSingleSequenceTriggersAdvancedIndexing)(index)) {
+ // Handle the special case where we only have a single indexer
+ THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs(
+ THPIndexTensorClass, index, NULL);
+ if (!indexer) {
+ PyErr_Format(PyExc_IndexError,
+ "When performing advanced indexing the indexing objects must be LongTensors or "
+ "convertible to LongTensors");
+ return false;
}
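+ // A bare integer sequence can only index along the first dimension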
+ indexingDims.push_back(0);
+ indexers.push_back(indexer);
+ } else {
+ // The top-level indexer should be a sequence, per the check above
+ THPObjectPtr fast(PySequence_Fast(index, NULL));
+ sequenceLength = PySequence_Fast_GET_SIZE(fast.get());
+ int ellipsisOffset = 0;
- if (!PySlice_Check(item)) {
- PyObject *obj = PySequence_Fast_GET_ITEM(fast.get(), i);
- // Returns NULL upon conversion failure
- THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs(
- THPIndexTensorClass, obj, NULL);
- if (!indexer) {
- PyErr_Format(PyExc_IndexError,
- "When performing advanced indexing the indexing objects must be LongTensors or "
- "convertible to LongTensors. The indexing object at position %d is of type %s "
- "and cannot be converted", i, THPUtils_typename(obj));
+ for (Py_ssize_t i = 0; i < sequenceLength; ++i) {
+ PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
- // Clean up Indexers
- for (auto& idx : indexers) {
- THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
- Py_DECREF(idx);
- }
- return false;
+ // If this is an ellipsis, then all subsequent advanced indexing
+ // objects' positions should be shifted, e.g. for a 5D Tensor
+ // x, in x[..., [2, 3]] the "position" of [2, 3] is 4
+ if (Py_TYPE(item) == &PyEllipsis_Type) {
+ ellipsisOffset = THTensor_(nDimension)(LIBRARY_STATE indexed) - sequenceLength;
+ continue;
}
- indexingDims.push_back(i + ellipsisOffset);
- indexers.push_back(indexer);
+
+ if (!PySlice_Check(item)) {
+ PyObject *obj = PySequence_Fast_GET_ITEM(fast.get(), i);
+ // Returns NULL upon conversion failure
+ THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs(
+ THPIndexTensorClass, obj, NULL);
+ if (!indexer) {
+ PyErr_Format(PyExc_IndexError,
+ "When performing advanced indexing the indexing objects must be LongTensors or "
+ "convertible to LongTensors. The indexing object at position %d is of type %s "
+ "and cannot be converted", i, THPUtils_typename(obj));
+
+ // Clean up Indexers
+ for (auto& idx : indexers) {
+ THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
+ Py_DECREF(idx);
+ }
+ return false;
+ }
+ indexingDims.push_back(i + ellipsisOffset);
+ indexers.push_back(indexer);
+ }
}
}