Symbolic implementation of Index supporting tuple of slices. (#3294)
diff --git a/torch/_six.py b/torch/_six.py
index 8390bc5..0642e88 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -32,6 +32,12 @@
string_classes = (str, bytes)
+if PY2:
+ int_classes = (int, long)
+else:
+ int_classes = int
+
+
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass."""
# This requires a bit of explanation: the basic idea is to make a dummy
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index 5d2cd7e..ec9de92 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -1,5 +1,6 @@
from functools import reduce
import torch
+from torch._six import int_classes
from torch._utils import _accumulate
from ..function import Function, InplaceFunction, once_differentiable, traceable
@@ -24,13 +25,59 @@
# We should only expect index as an integer in this case.
# We use "Slice" to get the index-th element in i,
# Then we reduce the dimension using "Reshape".
- if not isinstance(index, int):
- raise ValueError('Right now, only int-type index is suppported.')
- axes = g.constant(0, [1], "int")
- starts = g.constant(index, [1], "long")
- ends = g.constant(index + 1, [1], "long")
- slice = g.op("Slice", i, axes, starts, ends)
- return g.op("Squeeze", slice, axes_i=[0])
+ if isinstance(index, int_classes):
+ axes = g.constant(0, [1], "int")
+ starts = g.constant(index, [1], "long")
+ ends = g.constant(index + 1, [1], "long")
+ slice_node = g.op("Slice", i, axes, starts, ends)
+ return g.op("Squeeze", slice_node, axes_i=[0])
+ elif isinstance(index, tuple):
+ dims = i.type().sizes()
+ axes_ten = torch.IntTensor([idx for idx in range(len(index))])
+ axes = g.op("Constant", value_t=axes_ten)
+ starts_list = []
+ ends_list = []
+ squeeze_indices = []
+
+ # Given an index, size of dimension, a list, and a default fill val,
+ # fill in based on these conditions:
+ # 1) not specified (None) - fill with fillval (e.g. 0 or size)
+ # 2) negative index - calculate corresponding positive index and append
+ # 3) positive index - append to list
+ # 4) integer - keep only that integer and squeeze it at the end
+ def append_index(index, dim, append_list, fillval):
+ if index is None:
+ append_list.append(fillval)
+ else:
+ addend = (dim if index < 0 else 0)
+ append_list.append(index + addend)
+
+ for idx in range(len(index)):
+ if isinstance(index[idx], int_classes):
+ starts_list.append(index[idx])
+ ends_list.append(index[idx] + 1)
+ squeeze_indices.append(idx)
+ continue
+
+ # Start index
+ append_index(index[idx].start, dims[idx], starts_list, 0)
+ # End index
+ append_index(index[idx].stop, dims[idx], ends_list, dims[idx])
+
+ if index[idx].step is not None:
+ raise ValueError("Strided slice is not supported at this time")
+
+ starts_ten = torch.LongTensor(starts_list)
+ starts = g.op("Constant", value_t=starts_ten)
+ ends_ten = torch.LongTensor(ends_list)
+ ends = g.op("Constant", value_t=ends_ten)
+ slice_node = g.op("Slice", i, axes, starts, ends)
+ if squeeze_indices:
+ return g.op('Squeeze', slice_node, axes_i=squeeze_indices)
+ else:
+ return slice_node
+ else:
+ raise ValueError('Unsupported index type {}'.format(type(index)))
@staticmethod
def forward(ctx, i, index):