Symbolic implementation of Index supporting tuple of slices. (#3294)

diff --git a/torch/_six.py b/torch/_six.py
index 8390bc5..0642e88 100644
--- a/torch/_six.py
+++ b/torch/_six.py
@@ -32,6 +32,12 @@
     string_classes = (str, bytes)
 
 
+if PY2:
+    int_classes = (int, long)
+else:
+    int_classes = int
+
+
 def with_metaclass(meta, *bases):
     """Create a base class with a metaclass."""
     # This requires a bit of explanation: the basic idea is to make a dummy
diff --git a/torch/autograd/_functions/tensor.py b/torch/autograd/_functions/tensor.py
index 5d2cd7e..ec9de92 100644
--- a/torch/autograd/_functions/tensor.py
+++ b/torch/autograd/_functions/tensor.py
@@ -1,5 +1,6 @@
 from functools import reduce
 import torch
+from torch._six import int_classes
 from torch._utils import _accumulate
 
 from ..function import Function, InplaceFunction, once_differentiable, traceable
@@ -24,13 +25,59 @@
         # We should only expect index as an integer in this case.
         # We use "Slice" to get the index-th element in i,
         # Then we reduce the dimension using "Reshape".
-        if not isinstance(index, int):
-            raise ValueError('Right now, only int-type index is suppported.')
-        axes = g.constant(0, [1], "int")
-        starts = g.constant(index, [1], "long")
-        ends = g.constant(index + 1, [1], "long")
-        slice = g.op("Slice", i, axes, starts, ends)
-        return g.op("Squeeze", slice, axes_i=[0])
+        if isinstance(index, int_classes):
+            axes = g.constant(0, [1], "int")
+            starts = g.constant(index, [1], "long")
+            ends = g.constant(index + 1, [1], "long")
+            slice_node = g.op("Slice", i, axes, starts, ends)
+            return g.op("Squeeze", slice_node, axes_i=[0])
+        elif isinstance(index, tuple):
+            dims = i.type().sizes()
+            axes_ten = torch.IntTensor([idx for idx in range(len(index))])
+            axes = g.op("Constant", value_t=axes_ten)
+            starts_list = []
+            ends_list = []
+            squeeze_indices = []
+
+            # Helper: given an index value, the dimension size, a target list,
+            # and a default fill value, append the resolved slice bound:
+            # 1) None (unspecified)  - append the default fill value (0 or dim size)
+            # 2) negative index      - convert to the equivalent positive index
+            # 3) non-negative index  - append as-is
+            # (Integer indices are handled by the caller loop, which squeezes them later.)
+            def append_index(index, dim, append_list, fillval):
+                if index is None:
+                    append_list.append(fillval)
+                else:
+                    addend = (dim if index < 0 else 0)
+                    append_list.append(index + addend)
+
+            for idx in range(len(index)):
+                if isinstance(index[idx], int_classes):
+                    starts_list.append(index[idx])
+                    ends_list.append(index[idx] + 1)
+                    squeeze_indices.append(idx)
+                    continue
+
+                # Start index
+                append_index(index[idx].start, dims[idx], starts_list, 0)
+                # End index
+                append_index(index[idx].stop, dims[idx], ends_list, dims[idx])
+
+                if index[idx].step is not None:
+                    raise ValueError("Strided slice is not supported at this time")
+
+            starts_ten = torch.LongTensor(starts_list)
+            starts = g.op("Constant", value_t=starts_ten)
+            ends_ten = torch.LongTensor(ends_list)
+            ends = g.op("Constant", value_t=ends_ten)
+            slice_node = g.op("Slice", i, axes, starts, ends)
+            if squeeze_indices:
+                return g.op('Squeeze', slice_node, axes_i=squeeze_indices)
+            else:
+                return slice_node
+        else:
+            raise ValueError('Unsupported index type {}'.format(type(index)))
 
     @staticmethod
     def forward(ctx, i, index):