Add Sparse CSC support to torch.empty

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77508

Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h
index 7bdde2b..dfc7ff8 100644
--- a/aten/src/ATen/SparseCsrTensorUtils.h
+++ b/aten/src/ATen/SparseCsrTensorUtils.h
@@ -73,6 +73,30 @@
     }                                                                   \
   } ()
 
+#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) /* Run ACTION when LAYOUT is kSparseCsr or kSparseCsc; otherwise raise through AT_ERROR. */ \
+  [&]() { /* immediately-invoked lambda, so the macro can be used as an expression */ \
+    const auto& the_layout = LAYOUT; /* LAYOUT is evaluated once and reused below */ \
+    switch (the_layout) {                                               \
+    case kSparseCsr: /* both non-block compressed layouts share the same action */ \
+    case kSparseCsc:                                                    \
+      return (ACTION)(); /* call ACTION with no arguments and forward its result */ \
+    default: /* any other layout is rejected; NOTE(review): #NAME stringifies NAME, so a string-literal NAME keeps its quotes in the message */ \
+      AT_ERROR(#NAME, " expected sparse compressed (non-block) tensor layout but got ", the_layout); \
+    }                                                                   \
+  } ()
+
+#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) /* Run ACTION when LAYOUT is kSparseBsr or kSparseBsc; otherwise raise through AT_ERROR. */ \
+  [&]() { /* immediately-invoked lambda, so the macro can be used as an expression */ \
+    const auto& the_layout = LAYOUT; /* LAYOUT is evaluated once and reused below */ \
+    switch (the_layout) {                                               \
+    case kSparseBsr: /* both block compressed layouts share the same action */ \
+    case kSparseBsc:                                                    \
+      return (ACTION)(); /* call ACTION with no arguments and forward its result */ \
+    default: /* any other layout is rejected; NOTE(review): #NAME stringifies NAME, so a string-literal NAME keeps its quotes in the message */ \
+      AT_ERROR(#NAME, " expected sparse compressed block tensor layout but got ", the_layout); \
+    }                                                                   \
+  } ()
+
 namespace at {
 namespace sparse_csr {
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2ab1de6..78a27e6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1944,7 +1944,7 @@
     Meta: empty_meta
     MkldnnCPU: empty_mkldnn
     SparseCPU, SparseCUDA: empty_sparse
-    SparseCsrCPU, SparseCsrCUDA: empty_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
     QuantizedCPU, QuantizedCUDA: empty_unknown_quantized
 
 # We do not make new_empty a composite that calls into new_empty_strided, as the strided version
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 3030486..4adb101 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -459,7 +459,7 @@
 SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
 SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
 
-Tensor empty_sparse_csr(
+Tensor empty_sparse_compressed(  // Builds an nnz == 0 sparse CSR/CSC tensor of the requested size; named as the SparseCsrCPU/SparseCsrCUDA target in native_functions.yaml.
     IntArrayRef size,
     c10::optional<ScalarType> dtype,
     c10::optional<Layout> layout,
@@ -467,32 +467,34 @@
     c10::optional<bool> pin_memory,
     c10::optional<MemoryFormat> optional_memory_format) {
   check_size_nonnegative(size);
+  TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse compressed (non-block) tensors are supported, but got size ", size);
 
-  TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
+  // Strided is the default layout for torch.empty.
+  Layout layout_ = layout.value_or(Layout::Strided);
 
-  auto rows = size[size.size() - 2];
+  // torch.empty cannot be used to create blocked tensors because its
+  // API lacks a method to specify the block size.
+  AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(layout_, "empty_sparse_compressed", [&]{});  // validation only: the no-op lambda runs for CSR/CSC, any other layout raises
+
   int64_t nnz = 0;
-
-  auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
-  crow_indices_size.push_back(rows + 1);
-  auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
-  col_indices_values_size.push_back(nnz);
+  auto compressed_indices_size = DimVector(size.slice(0, size.size() - 2));  // start from the batch dims: every entry of size except the last two
+  auto plain_indices_and_values_size = DimVector(size.slice(0, size.size() - 2));  // same batch prefix, shared by plain indices and values
+  compressed_indices_size.push_back(size[compressedDimension(layout_, size)] + 1);  // extent of the compressed dimension plus one; compressedDimension is defined elsewhere (the tests expect dim -2 for CSR, -1 for CSC)
+  plain_indices_and_values_size.push_back(nnz);  // nnz is 0, so the trailing dimension is empty
 
   TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
-  auto crow_indices = at::empty(crow_indices_size, options);
-  auto col_indices = at::empty(col_indices_values_size, options);
-  auto values = at::empty(col_indices_values_size, options.dtype(dtype));
+  auto compressed_indices = at::empty(compressed_indices_size, options);  // Long/strided index tensor; this function never writes its contents
+  auto plain_indices = at::empty(plain_indices_and_values_size, options);  // Long/strided index tensor with an empty trailing dimension
+  auto values = at::empty(plain_indices_and_values_size, options.dtype(dtype));  // same options, but with the caller's dtype instead of Long
 
-  return at::native::_sparse_csr_tensor_unsafe(
-      crow_indices,
-      col_indices,
-      values,
-      size,
-      dtype,
-      layout,
-      device,
-      pin_memory);
+  return at::native::_sparse_compressed_tensor_unsafe(compressed_indices,
+                                                      plain_indices,
+                                                      values,
+                                                      size,
+                                                      dtype,
+                                                      layout,  // the original optional, not layout_; the dispatch above only returns for CSR/CSC, so it holds a value here
+                                                      device,
+                                                      pin_memory);
 }
 
 const Tensor& resize_sparse_csr_(
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index f65a800..ca7bf2e 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -146,6 +146,12 @@
         subtest(torch.sparse_bsc, name='SparseBSC')])
 
 
+def sparse_compressed_nonblock_layouts(test_name='layout'):  # Decorator factory: parametrizes the argument named test_name over the CSR and CSC layouts.
+    return parametrize(test_name, [
+        subtest(torch.sparse_csr, name='SparseCSR'),  # the name is passed to subtest as the label for this case
+        subtest(torch.sparse_csc, name='SparseCSC')])
+
+
 class TestSparseCompressed(TestCase):
     """Testing sparse compressed (CSR, CSC, BSR, BSC) tensor generic features.
     """
@@ -261,6 +267,51 @@
                 self.assertEqual(plain_indices, plain_indices_mth(sparse))
                 self.assertEqual(values, sparse.values())
 
+    @skipMeta
+    @sparse_compressed_nonblock_layouts()
+    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
+    def test_empty(self, layout, device, dtype):  # torch.empty with a CSR/CSC layout: checks the result's metadata and the shape/device/dtype of its index and value members
+        ns = [5, 2, 0]  # row/column extents to combine, including the zero-sized case
+        batch_shapes = [(), (2,), (2, 3)]  # no batch dims, one batch dim, two batch dims
+        compressed_indices_mth = {  # accessor for the compressed indices of the layout under test
+            torch.sparse_csr: torch.Tensor.crow_indices,
+            torch.sparse_csc: torch.Tensor.ccol_indices,
+        }[layout]
+        plain_indices_mth = {  # accessor for the plain indices of the layout under test
+            torch.sparse_csr: torch.Tensor.col_indices,
+            torch.sparse_csc: torch.Tensor.row_indices,
+        }[layout]
+        compressed_dim = {  # dimension whose extent determines the compressed-indices length
+            torch.sparse_csr: -2,
+            torch.sparse_csc: -1,
+        }[layout]
+        for m, n, b in itertools.product(ns, ns, batch_shapes):
+            shape = (*b, m, n)
+            result = torch.empty(shape, dtype=dtype, device=device, layout=layout)
+            self.assertEqual(result.shape, shape)
+            self.assertEqual(result.dtype, dtype)
+            self.assertEqual(result.device, torch.device(device))
+            self.assertEqual(result.layout, layout)
+            self.assertEqual(compressed_indices_mth(result).shape, (*b, shape[compressed_dim] + 1,))  # one more entry than the compressed dimension's extent
+            self.assertEqual(plain_indices_mth(result).shape, (*b, 0,))  # no stored elements, so the trailing dimension is empty
+            self.assertEqual(result.values().shape, (*b, 0,))
+            self.assertEqual(result._nnz(), 0)
+            self.assertEqual(compressed_indices_mth(result).device, torch.device(device))  # all three members are expected on the requested device
+            self.assertEqual(plain_indices_mth(result).device, torch.device(device))
+            self.assertEqual(result.values().device, torch.device(device))
+            self.assertEqual(compressed_indices_mth(result).dtype, torch.int64)  # index tensors are expected to be int64 whatever the value dtype is
+            self.assertEqual(plain_indices_mth(result).dtype, torch.int64)
+            self.assertEqual(result.values().dtype, dtype)
+
+    @skipMeta
+    @sparse_compressed_nonblock_layouts()
+    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
+    def test_empty_errors(self, layout, device, dtype):  # a 1-D size with a CSR/CSC layout must raise the "Only batched ..." RuntimeError
+        with self.assertRaisesRegex(RuntimeError,
+                                    "torch.empty: Only batched sparse compressed \\(non-block\\) tensors are supported"  # parentheses are escaped because the message is matched as a regex
+                                    ", but got size"):
+            torch.empty((5,), dtype=dtype, device=device, layout=layout)  # size has a single dimension, fewer than the two the implementation requires
+
 
 class TestSparseCSR(TestCase):
 
@@ -334,35 +385,6 @@
             sparse[0, 0, 0, 0] = 99.0
 
     @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half))
-    def test_empty(self, device, dtype):
-        ns = [5, 2, 0]
-        batch_shapes = [(), (2,), (2, 3)]
-        for m, n, b in itertools.product(ns, ns, batch_shapes):
-            shape = (*b, m, n)
-            result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
-            self.assertEqual(result.shape, shape)
-            self.assertEqual(result.dtype, dtype)
-            self.assertEqual(result.device, torch.device(device))
-            self.assertEqual(result.layout, torch.sparse_csr)
-            self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
-            self.assertEqual(result.col_indices().shape, (*b, 0,))
-            self.assertEqual(result.values().shape, (*b, 0,))
-            self.assertEqual(result._nnz(), 0)
-            self.assertEqual(result.crow_indices().device, torch.device(device))
-            self.assertEqual(result.col_indices().device, torch.device(device))
-            self.assertEqual(result.values().device, torch.device(device))
-            self.assertEqual(result.crow_indices().dtype, torch.int64)
-            self.assertEqual(result.col_indices().dtype, torch.int64)
-            self.assertEqual(result.values().dtype, dtype)
-
-    @skipMeta
-    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
-    def test_empty_errors(self, device, dtype):
-        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
-            torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
-
-    @skipMeta
     @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
     def test_clone(self, device, dtype):
         from operator import mul