Sparse CSR: Add tensor.resize_ and tensor.copy_ (#63510)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63510
Sparse CSR matrix resizing behavior:
If we _increase the number of rows_ the number of specified elements in the matrix remains the same -> the size of col_indices, values doesn't change, the size of crow_indices becomes `rows+1`.
If we _decrease the number of rows_ the number of specified elements will be `min(nnz, rows*cols)` -> need to resize `crow_indices` to `rows+1` and set the last element to `min(nnz, rows*cols)`; decrease the size of col_indices and values to `min(nnz, rows*cols)`.
If we _increase the number of columns_ the number of specified elements in the matrix remains the same, the number of rows remains the same -> no need to resize anything, just set new sizes.
We _cannot decrease the number of columns_ because it would require recomputing `crow_indices`.
cc nikitaved pearu cpuhrsch IvanYashchuk
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D31796680
Pulled By: cpuhrsch
fbshipit-source-id: 7d8a9701ce06d30a1841f94bba0a057cacea9401
diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp
index bfd81b7..1e9202a 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.cpp
+++ b/aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -57,6 +57,21 @@
col_indices_(std::move(col_indices)),
values_(std::move(values)) {}
+void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
+ auto rows = size[0];
+ auto cols = size[1];
+ auto old_crow_indices_size = crow_indices_.size(-1);
+ crow_indices_.resize_({rows + 1});
+ if (rows + 1 >= old_crow_indices_size) {
+ crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
+ } else {
+ crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
+ }
+ col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
+ values_.resize_({std::min<int64_t>(nnz, rows*cols)});
+ sizes_and_strides_.set_sizes(size);
+}
+
void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
crow_indices_ = at::empty_like(
src.crow_indices(),
diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h
index f776033..de0a6e5 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.h
+++ b/aten/src/ATen/SparseCsrTensorImpl.h
@@ -32,6 +32,7 @@
public:
explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
+ void resize_(int64_t nnz, IntArrayRef size);
void resize_as_sparse_csr_tensor_(const Tensor& src);
void set_member_tensors(
const Tensor& crow_indices,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dc38c6a..fc5862b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1236,6 +1236,7 @@
MkldnnCPU: copy_mkldnn_
SparseCPU, SparseCUDA, SparseHIP, SparseXPU: copy_sparse_wrapper_
CompositeExplicitAutograd: copy_
+ SparseCsrCPU, SparseCsrCUDA: copy_sparse_csr_
- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
dispatch: {}
@@ -1823,6 +1824,7 @@
CPU, Meta: resize_
CUDA: resize_cuda_
QuantizedCPU: quantized_resize_cpu_
+ SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_
- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
category_override: factory
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 9892b87..ed50651 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -248,6 +248,41 @@
pin_memory);
}
+const Tensor& resize_sparse_csr_(
+ const Tensor& self,
+ IntArrayRef size,
+ c10::optional<MemoryFormat> optional_memory_format) {
+ check_size_nonnegative(size);
+ TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
+ TORCH_CHECK(
+ self.size(1) <= size[1],
+ "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
+ "The original number of columns is ",
+ self.size(1),
+ " while the requested new number of columns is ", size[1], ".");
+ get_sparse_csr_impl(self)->resize_(self._nnz(), size);
+ return self;
+}
+
+Tensor& copy_sparse_csr_(Tensor& self, const Tensor& src, bool non_blocking) {
+ TORCH_CHECK(
+ self.sizes() == src.sizes(),
+ "copy_sparse_csr_: only same size tensors are supported.");
+ TORCH_CHECK(
+ self.is_sparse_csr() && src.is_sparse_csr(),
+ "copy_sparse_csr_: copy between different layouts is not supported. Found self type = ",
+ self.toString(),
+ " and src type = ",
+ src.toString());
+ TORCH_CHECK(
+ self._nnz() == src._nnz(),
+ "copy_sparse_csr_: only tensors with the same number of specified elements are supported.");
+ self.crow_indices().copy_(src.crow_indices(), non_blocking);
+ self.col_indices().copy_(src.col_indices(), non_blocking);
+ self.values().copy_(src.values(), non_blocking);
+ return self;
+}
+
// Access members of CSR tensors.
int64_t _nnz_sparse_csr(const SparseCsrTensor& self) {
return get_sparse_csr_impl(self)->nnz();
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 6decd4c..480f6c3 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -136,6 +136,86 @@
with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
+ @skipMeta
+ @dtypes(*get_all_dtypes())
+ def test_copy(self, device, dtype):
+
+ def run_test(shape, nnz, index_type):
+ a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+ b = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+ a.copy_(b)
+
+ self.assertEqual(a.crow_indices(), b.crow_indices())
+ self.assertEqual(a.col_indices(), b.col_indices())
+ self.assertEqual(a.values(), b.values())
+
+ ns = [5, 2, 0]
+ for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
+ run_test(shape, 0, index_dtype)
+ run_test(shape, shape[0] * shape[1], index_dtype)
+
+ @skipMeta
+ @dtypes(*get_all_dtypes())
+ def test_copy_errors(self, device, dtype):
+ for index_dtype in [torch.int32, torch.int64]:
+ shape1 = (2, 3)
+ shape2 = (3, 2)
+ a = self.genSparseCSRTensor(shape1, 0, dtype=dtype, device=device, index_dtype=index_dtype)
+ b = self.genSparseCSRTensor(shape2, 0, dtype=dtype, device=device, index_dtype=index_dtype)
+
+ with self.assertRaisesRegex(RuntimeError, "only same size tensors are supported."):
+ a.copy_(b)
+
+ with self.assertRaisesRegex(RuntimeError, "copy between different layouts is not supported."):
+ a.copy_(torch.empty(a.shape, dtype=dtype, device=device))
+
+ b = self.genSparseCSRTensor(shape1, 1, dtype=dtype, device=device, index_dtype=index_dtype)
+ with self.assertRaisesRegex(RuntimeError, "only tensors with the same number of specified elements are supported."):
+ a.copy_(b)
+
+ @skipMeta
+ @dtypes(*get_all_dtypes())
+ def test_resize(self, device, dtype):
+ for index_dtype in [torch.int32, torch.int64]:
+ shape = (2, 3)
+ nnz = 6
+ a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+ new_shape = (4, 5)
+ a.resize_(new_shape)
+
+ self.assertEqual(a.shape, new_shape)
+ # resize to larger shape doesn't add specified elements
+ self.assertEqual(a._nnz(), nnz)
+
+ new_shape = (1, 5)
+ a.resize_(new_shape)
+
+ self.assertEqual(a.shape, new_shape)
+ # resize to smaller shape trims specified elements
+ self.assertEqual(a._nnz(), 5)
+
+ @skipMeta
+ @dtypes(*get_all_dtypes())
+ def test_resize_errors(self, device, dtype):
+ for index_dtype in [torch.int32, torch.int64]:
+ shape = (2, 3)
+ nnz = 6
+ a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+
+ with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
+ new_shape = (4,)
+ a.resize_(new_shape)
+
+ # resizing of columns to smaller size is not implemented
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported.",
+ ):
+ new_shape = (2, 2)
+ a.resize_(new_shape)
+
def test_factory_type_invariants_check(self, device):
with self.assertRaisesRegex(RuntimeError, "both crow_indices and col_indices should have the same type."):
torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64),