Sparse CSR: Add tensor.resize_ and tensor.copy_ (#63510)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63510

Sparse CSR matrix resizing behavior (a short usage sketch follows the list):
- If we _increase the number of rows_, the number of specified elements stays the same, so the sizes of `col_indices` and `values` don't change and `crow_indices` is resized to `rows + 1`, with the new trailing entries filled with `nnz`.
- If we _decrease the number of rows_, the number of specified elements becomes `min(nnz, rows * cols)`, so `crow_indices` is resized to `rows + 1` with its last element set to `min(nnz, rows * cols)`, and `col_indices` and `values` shrink to `min(nnz, rows * cols)`.
- If we _increase the number of columns_, both the number of specified elements and the number of rows stay the same, so nothing needs to be resized; only the stored sizes are updated.
- We _cannot decrease the number of columns_, because that would require recomputing `crow_indices` and dropping the out-of-bounds entries from `col_indices` and `values`.
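
To make these rules concrete, here is a minimal hand-built sketch of the intended semantics (the tensor contents are illustrative, not taken from the test suite; it exercises the `resize_` and `copy_` paths added in this diff):

```python
import torch

# A 2x2 CSR tensor with 4 specified elements (nnz = 4).
a = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),     # crow_indices
                            torch.tensor([0, 1, 0, 1]),  # col_indices
                            torch.tensor([1., 2., 3., 4.]),
                            size=(2, 2))

# Growing both dimensions keeps all specified elements;
# crow_indices grows to rows + 1 = 5 entries, padded with nnz.
a.resize_((4, 5))
assert a._nnz() == 4 and a.crow_indices().numel() == 5

# Shrinking rows trims nnz to min(nnz, rows * cols) = min(4, 1 * 5) = 4 here,
# and crow_indices shrinks to rows + 1 = 2 entries.
a.resize_((1, 5))
assert a._nnz() == 4 and a.crow_indices().numel() == 2

# copy_ requires matching sizes, layouts, and numbers of specified elements.
b = torch.sparse_csr_tensor(torch.tensor([0, 4]),
                            torch.tensor([0, 1, 2, 3]),
                            torch.tensor([5., 6., 7., 8.]),
                            size=(1, 5))
a.copy_(b)
assert torch.equal(a.values(), b.values())
```

Shrinking rows only rewrites the last entry of `crow_indices` and truncates `col_indices`/`values`, which keeps the trim cheap; shrinking columns would require re-encoding `crow_indices` entirely, which is why it is rejected.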

cc nikitaved pearu cpuhrsch IvanYashchuk

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D31796680

Pulled By: cpuhrsch

fbshipit-source-id: 7d8a9701ce06d30a1841f94bba0a057cacea9401
diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp
index bfd81b7..1e9202a 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.cpp
+++ b/aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -57,6 +57,26 @@
       col_indices_(std::move(col_indices)),
       values_(std::move(values)) {}
 
+void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
+  auto rows = size[0];
+  auto cols = size[1];
+  auto old_crow_indices_size = crow_indices_.size(-1);
+  crow_indices_.resize_({rows + 1});
+  if (rows + 1 >= old_crow_indices_size) {
+    // Rows grew (or stayed the same): fill the new tail of crow_indices
+    // with nnz, since all specified elements remain within the old rows.
+    crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
+  } else {
+    // Rows shrank: clamp the last row pointer to the new number of
+    // specified elements, min(nnz, rows*cols).
+    crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
+  }
+  // col_indices and values always hold exactly min(nnz, rows*cols) elements.
+  col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
+  values_.resize_({std::min<int64_t>(nnz, rows*cols)});
+  sizes_and_strides_.set_sizes(size);
+}
+
 void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
   crow_indices_ = at::empty_like(
       src.crow_indices(),
diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h
index f776033..de0a6e5 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.h
+++ b/aten/src/ATen/SparseCsrTensorImpl.h
@@ -32,6 +32,7 @@
  public:
   explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
 
+  void resize_(int64_t nnz, IntArrayRef size);
   void resize_as_sparse_csr_tensor_(const Tensor& src);
   void set_member_tensors(
       const Tensor& crow_indices,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dc38c6a..fc5862b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1236,6 +1236,7 @@
     MkldnnCPU: copy_mkldnn_
     SparseCPU, SparseCUDA, SparseHIP, SparseXPU: copy_sparse_wrapper_
     CompositeExplicitAutograd: copy_
+    SparseCsrCPU, SparseCsrCUDA: copy_sparse_csr_
 
 - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
   dispatch: {}
@@ -1823,6 +1824,7 @@
     CPU, Meta: resize_
     CUDA: resize_cuda_
     QuantizedCPU: quantized_resize_cpu_
+    SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_
 
 - func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   category_override: factory
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 9892b87..ed50651 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -248,6 +248,41 @@
       pin_memory);
 }
 
+const Tensor& resize_sparse_csr_(
+    const Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  check_size_nonnegative(size);
+  TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
+  TORCH_CHECK(
+      self.size(1) <= size[1],
+      "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
+      "The original number of columns is ",
+      self.size(1),
+      " while the requested new number of columns is ", size[1], ".");
+  get_sparse_csr_impl(self)->resize_(self._nnz(), size);
+  return self;
+}
+
+Tensor& copy_sparse_csr_(Tensor& self, const Tensor& src, bool non_blocking) {
+  TORCH_CHECK(
+      self.sizes() == src.sizes(),
+      "copy_sparse_csr_: only same size tensors are supported.");
+  TORCH_CHECK(
+      self.is_sparse_csr() && src.is_sparse_csr(),
+      "copy_sparse_csr_: copy between different layouts is not supported. Found self type = ",
+      self.toString(),
+      " and src type = ",
+      src.toString());
+  TORCH_CHECK(
+      self._nnz() == src._nnz(),
+      "copy_sparse_csr_: only tensors with the same number of specified elements are supported.");
+  self.crow_indices().copy_(src.crow_indices(), non_blocking);
+  self.col_indices().copy_(src.col_indices(), non_blocking);
+  self.values().copy_(src.values(), non_blocking);
+  return self;
+}
+
 // Access members of CSR tensors.
 int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
   return get_sparse_csr_impl(self)->nnz();
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 6decd4c..480f6c3 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -136,6 +136,86 @@
         with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
             torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
 
+    @skipMeta
+    @dtypes(*get_all_dtypes())
+    def test_copy(self, device, dtype):
+
+        def run_test(shape, nnz, index_dtype):
+            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+            b = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+            a.copy_(b)
+
+            self.assertEqual(a.crow_indices(), b.crow_indices())
+            self.assertEqual(a.col_indices(), b.col_indices())
+            self.assertEqual(a.values(), b.values())
+
+        ns = [5, 2, 0]
+        for shape, index_dtype in itertools.product(itertools.product(ns, ns), [torch.int32, torch.int64]):
+            run_test(shape, 0, index_dtype)
+            run_test(shape, shape[0] * shape[1], index_dtype)
+
+    @skipMeta
+    @dtypes(*get_all_dtypes())
+    def test_copy_errors(self, device, dtype):
+        for index_dtype in [torch.int32, torch.int64]:
+            shape1 = (2, 3)
+            shape2 = (3, 2)
+            a = self.genSparseCSRTensor(shape1, 0, dtype=dtype, device=device, index_dtype=index_dtype)
+            b = self.genSparseCSRTensor(shape2, 0, dtype=dtype, device=device, index_dtype=index_dtype)
+
+            with self.assertRaisesRegex(RuntimeError, "only same size tensors are supported."):
+                a.copy_(b)
+
+            with self.assertRaisesRegex(RuntimeError, "copy between different layouts is not supported."):
+                a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
+
+            b = self.genSparseCSRTensor(shape1, 1, dtype=dtype, device=device, index_dtype=index_dtype)
+            with self.assertRaisesRegex(RuntimeError, "only tensors with the same number of specified elements are supported."):
+                a.copy_(b)
+
+    @skipMeta
+    @dtypes(*get_all_dtypes())
+    def test_resize(self, device, dtype):
+        for index_dtype in [torch.int32, torch.int64]:
+            shape = (2, 3)
+            nnz = 6
+            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+            new_shape = (4, 5)
+            a.resize_(new_shape)
+
+            self.assertEqual(a.shape, new_shape)
+            # resize to larger shape doesn't add specified elements
+            self.assertEqual(a._nnz(), nnz)
+
+            new_shape = (1, 5)
+            a.resize_(new_shape)
+
+            self.assertEqual(a.shape, new_shape)
+            # resize to smaller shape trims specified elements
+            self.assertEqual(a._nnz(), 5)
+
+    @skipMeta
+    @dtypes(*get_all_dtypes())
+    def test_resize_errors(self, device, dtype):
+        for index_dtype in [torch.int32, torch.int64]:
+            shape = (2, 3)
+            nnz = 6
+            a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
+                new_shape = (4,)
+                a.resize_(new_shape)
+
+            # resizing of columns to smaller size is not implemented
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
+            ):
+                new_shape = (2, 2)
+                a.resize_(new_shape)
+
     def test_factory_type_invariants_check(self, device):
         with self.assertRaisesRegex(RuntimeError, "both crow_indices and col_indices should have the same type."):
             torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64),