Enable simple indexing into CSR tensor, add torch.select for CSR

This PR implements `torch.select` for CSR tensors. Selecting along a batch dimension slices the indices and values directly; selecting rows or columns of a batched CSR tensor is not yet supported. The non-batched case works by converting to COO and calling select. I initially implemented the row/column case with raw index manipulation, but converting to COO is only slightly slower and the result is more readable.
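
For context, a minimal usage sketch of the non-batched case (the 2x10 tensor below is an illustrative assumption, not taken from the PR):

```python
import torch

# Build a small 2x10 CSR tensor (values chosen for illustration).
crow_indices = torch.tensor([0, 2, 4])
col_indices = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1., 2., 3., 4.])
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 10))

# Select a row; torch.select(csr, 0, 0) is equivalent. The row/column case
# is routed through a COO conversion, so the result is a sparse tensor.
row0 = csr.select(0, 0)
print(row0.to_dense())  # [1., 2., 0., 0., 0., 0., 0., 0., 0., 0.]
```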

This PR also enables indexing into a batched CSR tensor with `[x, y, z]`. Assignment through indexing is disabled.
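
A sketch of the batched indexing, assuming a (2, 3) batch of identical 2x10 blocks like the one used in the new test:

```python
import torch

# Stack the same 2x10 CSR block into a (2, 3) batch (illustrative data).
crow = torch.tensor([0, 2, 4]).repeat(6, 1).reshape(2, 3, -1)
col = torch.tensor([0, 1, 0, 1]).repeat(6, 1).reshape(2, 3, -1)
vals = torch.tensor([1., 2., 3., 4.]).repeat(6, 1).reshape(2, 3, -1)
batched = torch.sparse_csr_tensor(crow, col, vals, size=(2, 3, 2, 10))

block = batched[0, 1]         # indexing the batch dims selects a single 2x10 CSR block
scalar = batched[0, 1, 0, 0]  # indexing down to a single element also works

try:
    batched[0, 1, 0, 0] = 99.0  # assignment through indexing is rejected
except TypeError as e:
    print(e)                    # "Cannot assign to a sparse tensor"
```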

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76228
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f8ad9b0..840d623 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3872,6 +3872,7 @@
   device_guard: False
   dispatch:
     CompositeExplicitAutograd: select
+    SparseCsrCPU, SparseCsrCUDA: select_sparse_csr
 
 - func: select_backward(Tensor grad_output, int[] input_sizes, int dim, int index) -> Tensor
   variants: function
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 92ec4b1..a7d6665 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -15,8 +15,10 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/_nnz_native.h>
 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
+#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
 #include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
 #include <ATen/ops/_validate_sparse_csr_tensor_args_native.h>
 #include <ATen/ops/clone_native.h>
@@ -28,6 +30,7 @@
 #include <ATen/ops/empty_native.h>
 #include <ATen/ops/resize_as_sparse_native.h>
 #include <ATen/ops/resize_native.h>
+#include <ATen/ops/select_native.h>
 #include <ATen/ops/sparse_csr_tensor_native.h>
 #include <ATen/ops/values_native.h>
 #endif
@@ -468,5 +471,43 @@
   }
 }
 
+Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
+  TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
+  TORCH_CHECK_INDEX(self.dim() != 0, "select() cannot be applied to a 0-dim tensor.");
+  dim = maybe_wrap_dim(dim, self.dim());
+  auto size = self.size(dim);
+  if (index < -size || index >= size) {
+    TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
+                   self.sizes(), " at dimension ", dim);
+  }
+  if (index < 0) {
+    index += size;
+  }
+
+  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < self.dim());
+
+  auto new_sizes = DimVector(self.sizes());
+  new_sizes.erase(new_sizes.begin() + dim);
+  auto options = self.options();
+
+  // Selecting batch dimension
+  if (dim < self.dim() - 2) {
+    return at::native::_sparse_csr_tensor_unsafe(
+        self.crow_indices().select(dim, index),
+        self.col_indices().select(dim, index),
+        self.values().select(dim, index),
+        new_sizes,
+        optTypeMetaToScalarType(options.dtype_opt()),
+        options.layout_opt(),
+        options.device_opt(),
+        options.pinned_memory_opt());
+  } else {
+    TORCH_CHECK(self.dim() == 2, "select(): selecting rows or columns is not implemented for batched sparse CSR tensors.");
+    // Converting to COO and calling select is slightly slower than operating on the CSR indices
+    // directly to construct the COO indices; however, the current version is more readable and easier to understand.
+    return self.to_sparse().select(dim, index);
+  }
+}
+
 } // namespace native
 } // namespace at
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index c7b4e2b..70d3771 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -255,6 +255,50 @@
             self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device), sparse.col_indices())
             self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), sparse.values())
 
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    def test_sparse_csr_select(self, device, dtype):
+        batch_shape = (2, 3)
+        crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
+        col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
+        values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
+        sparse = torch.sparse_csr_tensor(crow_indices,
+                                         col_indices,
+                                         values,
+                                         size=(*batch_shape, 2, 10),
+                                         dtype=dtype,
+                                         device=device)
+
+        # select from batch dimensions
+        sparse_selected12 = sparse.select(1, 2)
+        expected_sparse_selected12 = torch.sparse_csr_tensor(crow_indices.select(1, 2).contiguous(),
+                                                             col_indices.select(1, 2).contiguous(),
+                                                             values.select(1, 2).contiguous(),
+                                                             size=(2, 2, 10),
+                                                             dtype=dtype,
+                                                             device=device)
+        self.assertEqual(expected_sparse_selected12, sparse_selected12)
+
+        # select from rows or columns
+        sparse_non_batched = sparse[0, 0]
+        for selects_args in [(0, 0), (1, 1)]:
+            sparse_selected = sparse_non_batched.select(*selects_args)
+            dense_selected = sparse_non_batched.to_dense().select(*selects_args)
+            self.assertEqual(dense_selected, sparse_selected)
+
+        # index a single element
+        self.assertEqual(sparse[0, 0, 0, 0], sparse.to_dense()[0, 0, 0, 0])
+
+        # selecting from rows or columns for batched CSR is not yet implemented
+        with self.assertRaisesRegex(RuntimeError, "selecting rows or columns is not implemented for batched"):
+            sparse.select(-2, 0)
+
+        with self.assertRaisesRegex(RuntimeError, "selecting rows or columns is not implemented for batched"):
+            sparse.select(-1, 0)
+
+        # assigning to sparse through indexing is disabled
+        with self.assertRaisesRegex(TypeError, "Cannot assign to a sparse tensor"):
+            sparse[0, 0, 0, 0] = 99.0
+
     @skipMeta
     @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
     def test_empty(self, device, dtype):
diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp
index 6b7b7b6..27016f4 100644
--- a/torch/csrc/autograd/python_variable_indexing.cpp
+++ b/torch/csrc/autograd/python_variable_indexing.cpp
@@ -368,7 +368,7 @@
   }
 
   const auto& self_ = THPVariable_Unpack(self);
-  if (self_.is_sparse())
+  if (self_.is_sparse() || self_.is_sparse_csr())
   {
     throw TypeError("Cannot assign to a sparse tensor");
   }