Support (non-batch) BSR/BSC to COO sparse tensor conversions (#90718)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90718
Approved by: https://github.com/cpuhrsch
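
As a quick illustration of what this change enables (the example tensor and the 2x2 blocksize below are arbitrary and not taken from the patch; only real public APIs such as `torch.sparse_bsr_tensor`, `Tensor.to_sparse`, and `Tensor.to_dense` are used):

```python
import torch

# A 4x4 matrix stored as a non-batch BSR tensor with 2x2 blocks:
# block (0, 0) = [[1, 2], [3, 4]], block (1, 1) = [[5, 6], [7, 8]].
bsr = torch.sparse_bsr_tensor(
    torch.tensor([0, 1, 2]),                    # crow_indices
    torch.tensor([0, 1]),                       # col_indices
    torch.tensor([[[1., 2.], [3., 4.]],
                  [[5., 6.], [7., 8.]]]),       # values, shape (nnz, 2, 2)
    size=(4, 4),
)

# Before this change, BSR/BSC -> COO raised "sparse_compressed_to_sparse
# expected SparseCsr or SparseCsc layout but got SparseBsr"; now it works.
coo = bsr.to_sparse()
assert coo.layout == torch.sparse_coo
assert torch.equal(coo.to_dense(), bsr.to_dense())
```

Batched BSR/BSC inputs remain unsupported; the test changes below make them raise the same "crow_indices is supposed to be a vector" error that batched CSR/CSC inputs already produce.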
diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp
index e6c7bd3..0799a37 100644
--- a/aten/src/ATen/native/TensorConversions.cpp
+++ b/aten/src/ATen/native/TensorConversions.cpp
@@ -508,6 +508,14 @@
     switch(input_.layout()) {
     case kSparseCsr: return grad.sparse_mask(input_.to_sparse()).to_sparse_csr();
     case kSparseCsc: return grad.sparse_mask(input_.to_sparse()).to_sparse_csc();
+    case kSparseBsr: {
+      auto blocksize = DimVector(input_.values().sizes().slice(1, 2));
+      return grad.sparse_mask(input_.to_sparse().coalesce()).to_sparse_bsr(blocksize);
+    }
+    case kSparseBsc: {
+      auto blocksize = DimVector(input_.values().sizes().slice(1, 2));
+      return grad.sparse_mask(input_.to_sparse().coalesce()).to_sparse_bsc(blocksize);
+    }
      // BSR and BSC should be handled via implementing sparse_compressed_mask
     default: ; // fall back to unsupported input layout error
     }
@@ -1734,24 +1742,27 @@
   // TODO: implement coo.to_sparse(sparse_dim) and then use
   // return self.to_sparse().to_sparse(sparse_dim);
   TORCH_CHECK(
-      sparse_dim == 2, "sparse dim 1 is not supported by sparse_csr_to_dense");
-  if (self.layout() == kSparseCsc) {
-    Tensor indices = at::_convert_indices_from_csr_to_coo(
-        self.ccol_indices(), self.row_indices(), false, true);
-    return at::native::_sparse_coo_tensor_unsafe(
-               indices, self.values(), self.sizes())
-        ._coalesced_(true);
-  }
-  if (self.layout() == kSparseCsr) {
-    Tensor indices = at::_convert_indices_from_csr_to_coo(
-        self.crow_indices(), self.col_indices(), false, false);
-    return at::native::_sparse_coo_tensor_unsafe(
-               indices, self.values(), self.sizes())
-        ._coalesced_(true);
-  }
-  AT_ERROR(
-      "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got ",
-      self.layout());
+      sparse_dim == 2, "sparse dim 1 is not supported by sparse_compressed_to_dense");
+  Layout layout = self.layout();
+  Tensor compressed_indices, plain_indices;
+  std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self);
+  Tensor values;
+  Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
+                                                        false, (layout == kSparseCsc || layout == kSparseBsc));
+  bool coalesced = true;
+  AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
+    [&] { values = self.values(); },
+    [&] {
+      auto size = DimVector(self.sizes().slice(0, 2));
+      auto blocksize = DimVector(self.values().sizes().slice(1, 2));
+      auto nnz = indices.size(1);
+      indices = indices.repeat_interleave(blocksize[0] * blocksize[1], 1)
+        .mul_(at::tensor({blocksize[0], blocksize[1]}, indices.options()).reshape({2, 1}))
+        .add_(at::stack(at::where(at::ones(blocksize, indices.options()))).repeat({1, nnz}));
+      values = self.values().flatten(0, 2);
+      coalesced = nnz == 1;
+    });
+  return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
 }
 
 Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize) {
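
The block-to-COO index expansion in the hunk above is fairly compact. Below is a rough Python rendition of the same arithmetic for a non-hybrid BSR input; the helper name `expand_block_indices` is invented for illustration and this is not the actual kernel:

```python
import torch

def expand_block_indices(block_indices, blocksize):
    # Illustrative-only Python version of the index expansion performed in
    # the C++ sparse_compressed_to_sparse change above (non-hybrid BSR case).
    b0, b1 = blocksize
    nnz = block_indices.size(1)
    # Repeat each block coordinate once per element of the block, ...
    expanded = block_indices.repeat_interleave(b0 * b1, dim=1)
    # ... scale block coordinates up to the block's top-left element, ...
    expanded = expanded * torch.tensor([[b0], [b1]])
    # ... then add the within-block offsets (0..b0-1) x (0..b1-1) for every block.
    offsets = torch.stack(torch.where(torch.ones(b0, b1)))
    return expanded + offsets.repeat(1, nnz)

# Two 2x2 blocks at block coordinates (0, 0) and (1, 1):
print(expand_block_indices(torch.tensor([[0, 1], [0, 1]]), (2, 2)))
# tensor([[0, 0, 1, 1, 2, 2, 3, 3],
#         [0, 1, 0, 1, 2, 3, 2, 3]])
```

The matching values come from self.values().flatten(0, 2), so value k*b0*b1 + i*b1 + j corresponds to element (i, j) of block k. Because the expanded indices are grouped per block rather than globally sorted, the result is only marked coalesced when there is a single block (nnz == 1).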
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 93a2241..528ae0a 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4181,7 +4181,8 @@
                         RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"):
                     explicit_to_sparse(t)
                 continue
-            elif from_layout in {torch.sparse_csr, torch.sparse_csc} and to_layout is torch.sparse_coo and is_batch:
+            elif from_layout in {torch.sparse_csr, torch.sparse_csc,
+                                 torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo and is_batch:
                 with self.assertRaisesRegex(RuntimeError,
                                             "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
                     t.to_sparse(layout=to_layout, blocksize=blocksize)
@@ -4189,16 +4190,6 @@
                                             "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
                     explicit_to_sparse(t)
                 continue
-            elif from_layout in {torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo:
-                with self.assertRaisesRegex(
-                        RuntimeError,
-                        "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"):
-                    t.to_sparse(layout=to_layout, blocksize=blocksize)
-                with self.assertRaisesRegex(
-                        RuntimeError,
-                        "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"):
-                    explicit_to_sparse(t)
-                self.skipTest('NOT IMPL')
             elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc),
                                               (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc),
                                               (torch.sparse_csc, torch.sparse_bsr), (torch.sparse_csc, torch.sparse_bsc),