Implement indexing methods for sparse tensors (#24937)

Summary:
Resolves https://github.com/pytorch/pytorch/issues/7416 .

This PR implements the following indexing methods for sparse tensors:
-  [x] `select`
-  [x] `index_select`

Note that this PR also modifies [gen.py](https://github.com/pytorch/pytorch/pull/24937/files#diff-76aa8cb3d0fad99c5f761d08cbcb4d19) that is not directly required to resolve the original issue but to work around a CI build issue reported in issue https://github.com/pytorch/pytorch/issues/24931 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24937

Differential Revision: D17163796

Pulled By: ezyang

fbshipit-source-id: 06613301ec456d9ed3491b9ce48e804048600f09
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 8ebd53c..826279b 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -4195,6 +4195,9 @@
         case Backend::CPU:
             return CPUType::index_select(const_cast<Tensor&>(*this), dim, index);
             break;
+        case Backend::SparseCPU:
+            return SparseCPUType::index_select(const_cast<Tensor&>(*this), dim, index);
+            break;
         default:
             AT_ERROR("index_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id())));
     }
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 40c1fed..22b595b 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -369,6 +369,17 @@
                 results[0].append(x)
             else:
                 results[1].append(x)
+                import difflib
+                import sys
+                d = difflib.Differ()
+                sys.stdout.write('-' * 80 + '\n')
+                sys.stdout.write('x={}, a={}, b={}\n'.format(x, a, b))
+                for i, line in enumerate(list(d.compare(ax.splitlines(), bx.splitlines()))):
+                    if line[:2] != '  ':
+                        sys.stdout.write('{:5d}: {}\n'.format(i, line))
+                sys.stdout.write('-' * 80 + '\n')
+                sys.stdout.write(ax)
+                sys.stdout.write('-' * 80 + '\n')
         except OSError:
             results[2].append(x)
     return results
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 37c631c..ce8678a 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -456,6 +456,39 @@
   return self.reshape(other.sizes());
 }
 
+static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
+  int64_t sparse_dim = self.sparse_dim();
+  int64_t dense_dim = self.dense_dim();
+  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);
+
+  auto indices = self._indices();
+  auto values = self._values();
+  auto new_sizes = self.sizes().vec();
+  new_sizes.erase(new_sizes.begin() + dim);
+
+  if (dim < sparse_dim) {
+    auto nzIndices = (indices[dim] == index).nonzero().view(-1);
+    auto new_values = values.index_select(0, nzIndices);
+    if (sparse_dim == 1) {
+      // return dense part:
+      if (new_values.size(0) == 1) {
+        return new_values[0];
+      } else {
+        return new_values.sum(0);
+      }
+    } else {
+      auto dimIndices = (arange(0, sparse_dim, self.device()) != dim).nonzero().view(-1);
+      auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
+      return _sparse_coo_tensor_with_dims_and_tensors(
+            sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
+    }
+  } else {
+    auto new_values = values.select(dim - sparse_dim + 1, index);
+    return _sparse_coo_tensor_with_dims_and_tensors(
+         sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
+  }
+}
+
 Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
@@ -476,6 +509,9 @@
   if (index < 0) {
     index += size;
   }
+  if (self.is_sparse()) {
+    return select_sparse(self, dim, index);
+  }
   auto sizes = self.sizes().vec();
   auto strides = self.strides().vec();
   auto storage_offset = self.storage_offset() + index * strides[dim];
@@ -494,6 +530,91 @@
 }
 #endif
 
+Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
+  /*
+    Algorithm:
+    index - a 1-D tensor of indicies with shape (n,)
+    self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
+      indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
+      values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
+    index_select(dim, index) returns a sparse tensor with the follwing data
+      new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
+      new_indices - shape is (sparse_dims, new_nnz)
+      new_values - shape is (new_nnz,) + dense_shape
+
+      if dim < len(sparse_shape):
+          for i, idx in enumerate(index):
+              for j, jdx in enumerate(indices[dim]):
+                  if idx == jdx:
+                      icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
+                      new_indices.add_column(icol)
+                      new_values.add_row(values[j])
+      else:
+          new_indices = indices
+          new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
+    */
+  auto ndim = self.dim();
+  if (ndim == 0) {
+    AT_INDEX_ERROR("index_select() cannot be applied to a 0-dim tensor.");
+  }
+  if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
+    AT_INDEX_ERROR("index_select() argument index must be 1-D long-tensor.");
+  }
+  dim = maybe_wrap_dim(dim, ndim);
+  auto size = self.size(dim);
+  auto sparse_dim = self.sparse_dim();
+  auto dense_dim = self.dense_dim();
+  auto indices = self._indices();
+  auto values = self._values();
+  auto nnz = values.size(0);
+  auto new_sizes = self.sizes().vec();
+  new_sizes[dim] = index.size(0);
+
+  if (dim < sparse_dim) {
+
+    auto dim_indices = indices[dim];
+    std::vector<int64_t> zindices;
+    std::vector<int64_t> iindices;
+    int64_t new_nnz = 0;
+    for (int64_t i=0; i < new_sizes[dim]; i++) {
+      auto idx = index[i].item<int64_t>();
+      if (idx < -size || idx >= size) {
+        AT_INDEX_ERROR("index_select(): index contains ", idx, " that is out of range for tensor of size ",
+                   self.sizes(), " at dimension ", dim);
+      }
+      if (idx < 0) {
+        idx += size;
+      }
+      for (int64_t j=0; j < nnz; j++) {
+        auto jdx = dim_indices[j].item<int64_t>();
+        if (idx == jdx) {
+          new_nnz++;
+          iindices.push_back(i);
+          zindices.push_back(j);
+        }
+      }
+    }
+    auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
+    auto new_indices = indices.index_select(1, zIndices);
+    new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
+    auto new_values = values.index_select(0, zIndices);
+    return _sparse_coo_tensor_with_dims_and_tensors(
+        sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());
+
+  } else {
+
+    auto vsize = values.sizes().vec();
+    vsize[dim + 1 - sparse_dim] = index.size(0);
+    auto new_values = at::empty(vsize, values.options());
+    for (int64_t k=0; k < nnz; k++) {
+      new_values[k] = values[k].index_select(dim - sparse_dim, index);
+    }
+    return _sparse_coo_tensor_with_dims_and_tensors(
+        sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());
+
+  }
+}
+
 Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 58c6695..d42637e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3847,6 +3847,8 @@
   dispatch:
     CPU: legacy::cpu::_th_index_select
     CUDA: legacy::cuda::_th_index_select
+    SparseCPU: index_select_sparse
+    SparseCUDA: index_select_sparse
 
 - func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 4481b94..b69f24f 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -811,6 +811,61 @@
         test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
         test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")
 
+    def test_select(self):
+        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
+            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
+            if fail_message:
+                with self.assertRaisesRegex(IndexError, fail_message):
+                    torch.select(x, select_dim, select_index)
+            else:
+                result = torch.select(x, select_dim, select_index)
+                if result.is_sparse:
+                    result = result.to_dense()
+                dense_result = torch.select(x.to_dense(), select_dim, select_index)
+                self.assertEqual(dense_result, result)
+
+
+        sizes = [5, 7, 11, 13, 17]
+        # hybrid sparse/dense, select sparse dim, result is dense
+        for i in range(sizes[0]):
+            test_shape(1, 10, sizes, 0, i)
+        test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')
+
+        # hybrid sparse/dense, select sparse dim, result is sparse
+        for d in range(3):
+            for i in range(sizes[d]):
+                test_shape(3, 10, sizes, d, i)
+
+        # hybrid sparse/dense, select dense dim, result is sparse
+        for d in range(1, 3):
+            for i in range(sizes[d]):
+                test_shape(1, 10, sizes, d, i)
+
+
+    def test_index_select(self):
+        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
+            if isinstance(select_index, int):
+                select_index = [select_index]
+            if isinstance(select_index, list):
+                select_index = torch.tensor(select_index, device=self.device, dtype=torch.long)
+            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
+            if fail_message:
+                with self.assertRaisesRegex(IndexError, fail_message):
+                    torch.index_select(x, select_dim, select_index)
+            else:
+                result = torch.index_select(x, select_dim, select_index)
+                if result.is_sparse:
+                    result = result.to_dense()
+                dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
+                self.assertEqual(dense_result, result)
+
+        sizes = [5, 7, 11, 13, 17]
+        for d in range(len(sizes)):
+            for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
+                test_shape(1, 10, sizes, d, index)
+                test_shape(len(sizes) // 2, 10, sizes, d, index)
+                test_shape(len(sizes), 10, sizes, d, index)
+
     @cpu_only
     def test_mm(self):
         def test_shape(di, dj, dk, nnz):