Implement indexing methods for sparse tensors (#24937)
Summary:
Resolves https://github.com/pytorch/pytorch/issues/7416 .
This PR implements the following indexing methods for sparse tensors:
- [x] `select`
- [x] `index_select`
Note that this PR also modifies [gen.py](https://github.com/pytorch/pytorch/pull/24937/files#diff-76aa8cb3d0fad99c5f761d08cbcb4d19) that is not directly required to resolve the original issue but to work around a CI build issue reported in issue https://github.com/pytorch/pytorch/issues/24931 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24937
Differential Revision: D17163796
Pulled By: ezyang
fbshipit-source-id: 06613301ec456d9ed3491b9ce48e804048600f09
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 8ebd53c..826279b 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -4195,6 +4195,9 @@
case Backend::CPU:
return CPUType::index_select(const_cast<Tensor&>(*this), dim, index);
break;
+ case Backend::SparseCPU:
+ return SparseCPUType::index_select(const_cast<Tensor&>(*this), dim, index);
+ break;
default:
AT_ERROR("index_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id())));
}
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 40c1fed..22b595b 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -369,6 +369,17 @@
results[0].append(x)
else:
results[1].append(x)
+ import difflib
+ import sys
+ d = difflib.Differ()
+ sys.stdout.write('-' * 80 + '\n')
+ sys.stdout.write('x={}, a={}, b={}\n'.format(x, a, b))
+ for i, line in enumerate(list(d.compare(ax.splitlines(), bx.splitlines()))):
+ if line[:2] != ' ':
+ sys.stdout.write('{:5d}: {}\n'.format(i, line))
+ sys.stdout.write('-' * 80 + '\n')
+ sys.stdout.write(ax)
+ sys.stdout.write('-' * 80 + '\n')
except OSError:
results[2].append(x)
return results
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 37c631c..ce8678a 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -456,6 +456,39 @@
return self.reshape(other.sizes());
}
+static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
+ int64_t sparse_dim = self.sparse_dim();
+ int64_t dense_dim = self.dense_dim();
+ TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);
+
+ auto indices = self._indices();
+ auto values = self._values();
+ auto new_sizes = self.sizes().vec();
+ new_sizes.erase(new_sizes.begin() + dim);
+
+ if (dim < sparse_dim) {
+ auto nzIndices = (indices[dim] == index).nonzero().view(-1);
+ auto new_values = values.index_select(0, nzIndices);
+ if (sparse_dim == 1) {
+ // return dense part:
+ if (new_values.size(0) == 1) {
+ return new_values[0];
+ } else {
+ return new_values.sum(0);
+ }
+ } else {
+ auto dimIndices = (arange(0, sparse_dim, self.device()) != dim).nonzero().view(-1);
+ auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
+ return _sparse_coo_tensor_with_dims_and_tensors(
+ sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
+ }
+ } else {
+ auto new_values = values.select(dim - sparse_dim + 1, index);
+ return _sparse_coo_tensor_with_dims_and_tensors(
+ sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
+ }
+}
+
Tensor select(const Tensor& self, int64_t dim, int64_t index) {
int64_t ndim = self.dim();
if (ndim == 0) {
@@ -476,6 +509,9 @@
if (index < 0) {
index += size;
}
+ if (self.is_sparse()) {
+ return select_sparse(self, dim, index);
+ }
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();
auto storage_offset = self.storage_offset() + index * strides[dim];
@@ -494,6 +530,91 @@
}
#endif
+Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
+ /*
+ Algorithm:
+ index - a 1-D tensor of indicies with shape (n,)
+ self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
+ indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
+ values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
+ index_select(dim, index) returns a sparse tensor with the follwing data
+ new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
+ new_indices - shape is (sparse_dims, new_nnz)
+ new_values - shape is (new_nnz,) + dense_shape
+
+ if dim < len(sparse_shape):
+ for i, idx in enumerate(index):
+ for j, jdx in enumerate(indices[dim]):
+ if idx == jdx:
+ icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
+ new_indices.add_column(icol)
+ new_values.add_row(values[j])
+ else:
+ new_indices = indices
+ new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
+ */
+ auto ndim = self.dim();
+ if (ndim == 0) {
+ AT_INDEX_ERROR("index_select() cannot be applied to a 0-dim tensor.");
+ }
+ if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
+ AT_INDEX_ERROR("index_select() argument index must be 1-D long-tensor.");
+ }
+ dim = maybe_wrap_dim(dim, ndim);
+ auto size = self.size(dim);
+ auto sparse_dim = self.sparse_dim();
+ auto dense_dim = self.dense_dim();
+ auto indices = self._indices();
+ auto values = self._values();
+ auto nnz = values.size(0);
+ auto new_sizes = self.sizes().vec();
+ new_sizes[dim] = index.size(0);
+
+ if (dim < sparse_dim) {
+
+ auto dim_indices = indices[dim];
+ std::vector<int64_t> zindices;
+ std::vector<int64_t> iindices;
+ int64_t new_nnz = 0;
+ for (int64_t i=0; i < new_sizes[dim]; i++) {
+ auto idx = index[i].item<int64_t>();
+ if (idx < -size || idx >= size) {
+ AT_INDEX_ERROR("index_select(): index contains ", idx, " that is out of range for tensor of size ",
+ self.sizes(), " at dimension ", dim);
+ }
+ if (idx < 0) {
+ idx += size;
+ }
+ for (int64_t j=0; j < nnz; j++) {
+ auto jdx = dim_indices[j].item<int64_t>();
+ if (idx == jdx) {
+ new_nnz++;
+ iindices.push_back(i);
+ zindices.push_back(j);
+ }
+ }
+ }
+ auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
+ auto new_indices = indices.index_select(1, zIndices);
+ new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
+ auto new_values = values.index_select(0, zIndices);
+ return _sparse_coo_tensor_with_dims_and_tensors(
+ sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());
+
+ } else {
+
+ auto vsize = values.sizes().vec();
+ vsize[dim + 1 - sparse_dim] = index.size(0);
+ auto new_values = at::empty(vsize, values.options());
+ for (int64_t k=0; k < nnz; k++) {
+ new_values[k] = values[k].index_select(dim - sparse_dim, index);
+ }
+ return _sparse_coo_tensor_with_dims_and_tensors(
+ sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());
+
+ }
+}
+
Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 58c6695..d42637e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3847,6 +3847,8 @@
dispatch:
CPU: legacy::cpu::_th_index_select
CUDA: legacy::cuda::_th_index_select
+ SparseCPU: index_select_sparse
+ SparseCUDA: index_select_sparse
- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 4481b94..b69f24f 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -811,6 +811,61 @@
test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")
+ def test_select(self):
+ def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
+ x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
+ if fail_message:
+ with self.assertRaisesRegex(IndexError, fail_message):
+ torch.select(x, select_dim, select_index)
+ else:
+ result = torch.select(x, select_dim, select_index)
+ if result.is_sparse:
+ result = result.to_dense()
+ dense_result = torch.select(x.to_dense(), select_dim, select_index)
+ self.assertEqual(dense_result, result)
+
+
+ sizes = [5, 7, 11, 13, 17]
+ # hybrid sparse/dense, select sparse dim, result is dense
+ for i in range(sizes[0]):
+ test_shape(1, 10, sizes, 0, i)
+ test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')
+
+ # hybrid sparse/dense, select sparse dim, result is sparse
+ for d in range(3):
+ for i in range(sizes[d]):
+ test_shape(3, 10, sizes, d, i)
+
+ # hybrid sparse/dense, select dense dim, result is sparse
+ for d in range(1, 3):
+ for i in range(sizes[d]):
+ test_shape(1, 10, sizes, d, i)
+
+
+ def test_index_select(self):
+ def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
+ if isinstance(select_index, int):
+ select_index = [select_index]
+ if isinstance(select_index, list):
+ select_index = torch.tensor(select_index, device=self.device, dtype=torch.long)
+ x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
+ if fail_message:
+ with self.assertRaisesRegex(IndexError, fail_message):
+ torch.index_select(x, select_dim, select_index)
+ else:
+ result = torch.index_select(x, select_dim, select_index)
+ if result.is_sparse:
+ result = result.to_dense()
+ dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
+ self.assertEqual(dense_result, result)
+
+ sizes = [5, 7, 11, 13, 17]
+ for d in range(len(sizes)):
+ for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
+ test_shape(1, 10, sizes, d, index)
+ test_shape(len(sizes) // 2, 10, sizes, d, index)
+ test_shape(len(sizes), 10, sizes, d, index)
+
@cpu_only
def test_mm(self):
def test_shape(di, dj, dk, nnz):