Add Sparse CSC support to torch.empty
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77508
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h
index 7bdde2b..dfc7ff8 100644
--- a/aten/src/ATen/SparseCsrTensorUtils.h
+++ b/aten/src/ATen/SparseCsrTensorUtils.h
@@ -73,6 +73,30 @@
} \
} ()
+// Calls the nullary callable passed after NAME when LAYOUT is kSparseCsr or
+// kSparseCsc and returns its result; any other layout is reported through
+// AT_ERROR with a message that starts with the stringized NAME.
+// The callable is taken variadically so that a lambda containing top-level
+// commas (capture lists, multiple declarators, template arguments) is not
+// split into several macro arguments.
+#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ...) \
+  [&]() { \
+    const auto& the_layout = LAYOUT; \
+    switch (the_layout) { \
+    case kSparseCsr: \
+    case kSparseCsc: \
+      return (__VA_ARGS__)(); \
+    default: \
+      AT_ERROR(#NAME, " expected sparse compressed (non-block) tensor layout but got ", the_layout); \
+    } \
+  } ()
+
+// Calls the nullary callable passed after NAME when LAYOUT is kSparseBsr or
+// kSparseBsc and returns its result; any other layout is reported through
+// AT_ERROR with a message that starts with the stringized NAME.
+// The callable is taken variadically so that a lambda containing top-level
+// commas (capture lists, multiple declarators, template arguments) is not
+// split into several macro arguments.
+#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ...) \
+  [&]() { \
+    const auto& the_layout = LAYOUT; \
+    switch (the_layout) { \
+    case kSparseBsr: \
+    case kSparseBsc: \
+      return (__VA_ARGS__)(); \
+    default: \
+      AT_ERROR(#NAME, " expected sparse compressed block tensor layout but got ", the_layout); \
+    } \
+  } ()
+
namespace at {
namespace sparse_csr {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2ab1de6..78a27e6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1944,7 +1944,7 @@
Meta: empty_meta
MkldnnCPU: empty_mkldnn
SparseCPU, SparseCUDA: empty_sparse
- SparseCsrCPU, SparseCsrCUDA: empty_sparse_csr
+ SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
QuantizedCPU, QuantizedCUDA: empty_unknown_quantized
# We do not make new_empty a composite that calls into new_empty_strided, as the strided version
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 3030486..4adb101 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -459,7 +459,7 @@
SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
-Tensor empty_sparse_csr(
+// Implementation behind torch.empty for sparse compressed layouts: returns a
+// sparse compressed tensor of the requested size with nnz = 0.
+// - size must have at least two dimensions (TORCH_CHECK below).
+// - layout must be one accepted by
+//   AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS; a missing layout is
+//   treated as Strided and therefore rejected by that dispatch.
+// - Index tensors are allocated as strided ScalarType::Long tensors on the
+//   requested device; values use the requested dtype.
+Tensor empty_sparse_compressed(
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
@@ -467,32 +467,34 @@
c10::optional<bool> pin_memory,
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
+ TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);
- TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
+ // Strided is the default layout for torch.empty.
+ Layout layout_ = layout.value_or(Layout::Strided);
- auto rows = size[size.size() - 2];
+ // torch.empty cannot be used to create blocked tensors because its
+ // API lacks a method to specify the block size.
+ AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});
+
int64_t nnz = 0;
-
- auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
- crow_indices_size.push_back(rows + 1);
- auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
- col_indices_values_size.push_back(nnz);
+ // Every dimension except the last two is carried over as a leading (batch)
+ // dimension; compressed indices then get size[compressed dim] + 1 entries and
+ // plain indices / values get nnz (= 0) entries.
+ auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));
+ auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));
+ compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);
+ plain_indices_and_values_size.push_back(nnz);
TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
- auto crow_indices = at::empty(crow_indices_size, options);
- auto col_indices = at::empty(col_indices_values_size, options);
- auto values = at::empty(col_indices_values_size, options.dtype(dtype));
+ // NOTE(review): all three tensors come from at::empty, so the contents of
+ // compressed_indices are presumably unspecified - confirm callers do not read them.
+ auto compressed_indices = at::empty(compressed_indices_size, options);
+ auto plain_indices = at::empty(plain_indices_and_values_size, options);
+ auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));
- return at::native::_sparse_csr_tensor_unsafe(
- crow_indices,
- col_indices,
- values,
- size,
- dtype,
- layout,
- device,
- pin_memory);
+ return at::native::_sparse_compressed_tensor_unsafe(compressed_indices,
+ plain_indices,
+ values,
+ size,
+ dtype,
+ layout,
+ device,
+ pin_memory);
}
const Tensor& resize_sparse_csr_(
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index f65a800..ca7bf2e 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -146,6 +146,12 @@
subtest(torch.sparse_bsc, name='SparseBSC')])
+def sparse_compressed_nonblock_layouts(test_name='layout'):
+ """Return a parametrize decorator covering torch.sparse_csr and torch.sparse_csc.
+
+ test_name is forwarded as the first argument of parametrize; the two
+ subtests are named 'SparseCSR' and 'SparseCSC'.
+ """
+ return parametrize(test_name, [
+ subtest(torch.sparse_csr, name='SparseCSR'),
+ subtest(torch.sparse_csc, name='SparseCSC')])
+
+
class TestSparseCompressed(TestCase):
"""Testing sparse compressed (CSR, CSC, BSR, BSC) tensor generic features.
"""
@@ -261,6 +267,51 @@
self.assertEqual(plain_indices, plain_indices_mth(sparse))
self.assertEqual(values, sparse.values())
+ @skipMeta
+ @sparse_compressed_nonblock_layouts()
+ @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
+ def test_empty(self, layout, device, dtype):
+ """Check torch.empty for the CSR and CSC layouts.
+
+ For every combination of matrix sizes in ns and batch shapes, the result
+ must report the requested shape, dtype, device and layout, have
+ _nnz() == 0, and expose int64 index tensors of the expected shapes on
+ the requested device.
+ """
+ ns = [5, 2, 0]
+ batch_shapes = [(), (2,), (2, 3)]
+ # Layout-specific accessors for the compressed and the plain index tensors.
+ compressed_indices_mth = {
+ torch.sparse_csr: torch.Tensor.crow_indices,
+ torch.sparse_csc: torch.Tensor.ccol_indices,
+ }[layout]
+ plain_indices_mth = {
+ torch.sparse_csr: torch.Tensor.col_indices,
+ torch.sparse_csc: torch.Tensor.row_indices,
+ }[layout]
+ # Index into shape of the dimension whose extent (+ 1) gives the length of
+ # the compressed indices: second-to-last for CSR, last for CSC.
+ compressed_dim = {
+ torch.sparse_csr: -2,
+ torch.sparse_csc: -1,
+ }[layout]
+ for m, n, b in itertools.product(ns, ns, batch_shapes):
+ shape = (*b, m, n)
+ result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
+ self.assertEqual(result.shape, shape)
+ self.assertEqual(result.dtype, dtype)
+ self.assertEqual(result.device, torch.device(device))
+ self.assertEqual(result.layout, layout)
+ self.assertEqual(compressed_indices_mth(result).shape, (*b, shape[compressed_dim] + 1,))
+ self.assertEqual(plain_indices_mth(result).shape, (*b, 0,))
+ self.assertEqual(result.values().shape, (*b, 0,))
+ self.assertEqual(result._nnz(), 0)
+ self.assertEqual(compressed_indices_mth(result).device, torch.device(device))
+ self.assertEqual(plain_indices_mth(result).device, torch.device(device))
+ self.assertEqual(result.values().device, torch.device(device))
+ self.assertEqual(compressed_indices_mth(result).dtype, torch.int64)
+ self.assertEqual(plain_indices_mth(result).dtype, torch.int64)
+ self.assertEqual(result.values().dtype, dtype)
+
+ @skipMeta
+ @sparse_compressed_nonblock_layouts()
+ @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
+ def test_empty_errors(self, layout, device, dtype):
+ """torch.empty with a CSR or CSC layout must raise RuntimeError for a 1-D size."""
+ with self.assertRaisesRegex(RuntimeError,
+ "torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"
+ ", but got size"):
+ torch.empty((5,), dtype=dtype, device=device, layout=layout)
+
class TestSparseCSR(TestCase):
@@ -334,35 +385,6 @@
sparse[0, 0, 0, 0] = 99.0
@skipMeta
- @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
- def test_empty(self, device, dtype):
- ns = [5, 2, 0]
- batch_shapes = [(), (2,), (2, 3)]
- for m, n, b in itertools.product(ns, ns, batch_shapes):
- shape = (*b, m, n)
- result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
- self.assertEqual(result.shape, shape)
- self.assertEqual(result.dtype, dtype)
- self.assertEqual(result.device, torch.device(device))
- self.assertEqual(result.layout, torch.sparse_csr)
- self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
- self.assertEqual(result.col_indices().shape, (*b, 0,))
- self.assertEqual(result.values().shape, (*b, 0,))
- self.assertEqual(result._nnz(), 0)
- self.assertEqual(result.crow_indices().device, torch.device(device))
- self.assertEqual(result.col_indices().device, torch.device(device))
- self.assertEqual(result.values().device, torch.device(device))
- self.assertEqual(result.crow_indices().dtype, torch.int64)
- self.assertEqual(result.col_indices().dtype, torch.int64)
- self.assertEqual(result.values().dtype, dtype)
-
- @skipMeta
- @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
- def test_empty_errors(self, device, dtype):
- with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
- torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
-
- @skipMeta
@dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
def test_clone(self, device, dtype):
from operator import mul