Support (non-batch) BSR/BSC to COO sparse tensor conversions (#90718)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90718
Approved by: https://github.com/cpuhrsch
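As a quick illustration of what this enables (a minimal sketch; the tensor values below are invented for the example, not taken from the PR):

    import torch

    # A 4x4 dense tensor holding two nonzero 2x2 blocks.
    dense = torch.tensor([[1., 2., 0., 0.],
                          [3., 4., 0., 0.],
                          [0., 0., 5., 6.],
                          [0., 0., 7., 8.]])
    bsr = dense.to_sparse_bsr((2, 2))

    # Before this change, BSR/BSC -> COO raised "sparse_compressed_to_sparse
    # expected SparseCsr or SparseCsc layout"; now it round-trips.
    coo = bsr.to_sparse()
    assert torch.equal(coo.to_dense(), dense)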
diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp
index e6c7bd3..0799a37 100644
--- a/aten/src/ATen/native/TensorConversions.cpp
+++ b/aten/src/ATen/native/TensorConversions.cpp
@@ -508,6 +508,17 @@
switch(input_.layout()) {
case kSparseCsr: return grad.sparse_mask(input_.to_sparse()).to_sparse_csr();
case kSparseCsc: return grad.sparse_mask(input_.to_sparse()).to_sparse_csc();
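+ // For block layouts, read the blocksize off the values tensor
+ // (shape: nnz x blocksize[0] x blocksize[1]) and reuse the COO path;
+ // sparse_mask requires the mask to be coalesced.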
+ case kSparseBsr: {
+ auto blocksize = DimVector(input_.values().sizes().slice(1, 2));
+ return grad.sparse_mask(input_.to_sparse().coalesce()).to_sparse_bsr(blocksize);
+ }
+ case kSparseBsc: {
+ auto blocksize = DimVector(input_.values().sizes().slice(1, 2));
+ return grad.sparse_mask(input_.to_sparse().coalesce()).to_sparse_bsc(blocksize);
+ }
// TODO: BSR and BSC should eventually be handled via a sparse_compressed_mask implementation
default: ; // fall back to unsupported input layout error
}
@@ -1734,24 +1742,32 @@
// TODO: implement coo.to_sparse(sparse_dim) and then use
// return self.to_sparse().to_sparse(sparse_dim);
TORCH_CHECK(
- sparse_dim == 2, "sparse dim 1 is not supported by sparse_csr_to_dense");
- if (self.layout() == kSparseCsc) {
- Tensor indices = at::_convert_indices_from_csr_to_coo(
- self.ccol_indices(), self.row_indices(), false, true);
- return at::native::_sparse_coo_tensor_unsafe(
- indices, self.values(), self.sizes())
- ._coalesced_(true);
- }
- if (self.layout() == kSparseCsr) {
- Tensor indices = at::_convert_indices_from_csr_to_coo(
- self.crow_indices(), self.col_indices(), false, false);
- return at::native::_sparse_coo_tensor_unsafe(
- indices, self.values(), self.sizes())
- ._coalesced_(true);
- }
- AT_ERROR(
- "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got ",
- self.layout());
+ sparse_dim == 2, "sparse dim 1 is not supported by sparse_compressed_to_sparse");
+ Layout layout = self.layout();
+ Tensor compressed_indices, plain_indices;
+ std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self);
+ Tensor values;
+ Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
+ false, (layout == kSparseCsc || layout == kSparseBsc));
+ bool coalesced = true;
+ AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
+ [&] { values = self.values(); },
+ [&] {
+ auto blocksize = DimVector(self.values().sizes().slice(1, 2));
+ auto nnz = indices.size(1);
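+ // Expand each block's (row, col) index into the indices of all
+ // blocksize[0] * blocksize[1] elements it covers: scale to the
+ // block's top-left element, then add the row-major within-block offsets.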
+ indices = indices.repeat_interleave(blocksize[0] * blocksize[1], 1)
+ .mul_(at::tensor({blocksize[0], blocksize[1]}, indices.options()).reshape({2, 1}))
+ .add_(at::stack(at::where(at::ones(blocksize, indices.options()))).repeat({1, nnz}));
+ values = self.values().flatten(0, 2);
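+ // The expanded indices are ordered block by block, so rows from
+ // different blocks may interleave; only a single block is
+ // guaranteed to already be in coalesced order.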
+ coalesced = nnz == 1;
+ });
+ return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
}
Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize) {
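To follow the index arithmetic in the block branch above, here is a rough Python equivalent of the expansion (a sketch with invented values; the variable names are not from the PR):

    import torch

    # One 2x2 block at block position (row 1, col 0), so nnz == 1.
    indices = torch.tensor([[1], [0]])   # per-block (row, col), shape (2, nnz)
    blocksize = (2, 2)
    nnz = indices.size(1)

    # Repeat each block index once per element of the block ...
    expanded = indices.repeat_interleave(blocksize[0] * blocksize[1], dim=1)
    # ... scale to the block's top-left element ...
    expanded = expanded * torch.tensor([[blocksize[0]], [blocksize[1]]])
    # ... and add the within-block offsets, which torch.where enumerates
    # in row-major order over a ones(blocksize) tensor.
    offsets = torch.stack(torch.where(torch.ones(blocksize, dtype=torch.bool)))
    expanded = expanded + offsets.repeat(1, nnz)

    print(expanded)
    # tensor([[2, 2, 3, 3],
    #         [0, 1, 0, 1]])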
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 93a2241..528ae0a 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4181,7 +4181,8 @@
RuntimeError, "Only tensors with two sparse dimensions can be converted to the Sparse(Csr|Csc) layout"):
explicit_to_sparse(t)
continue
- elif from_layout in {torch.sparse_csr, torch.sparse_csc} and to_layout is torch.sparse_coo and is_batch:
+ elif from_layout in {torch.sparse_csr, torch.sparse_csc,
+ torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo and is_batch:
with self.assertRaisesRegex(RuntimeError,
"crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
t.to_sparse(layout=to_layout, blocksize=blocksize)
@@ -4189,16 +4190,6 @@
"crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
explicit_to_sparse(t)
continue
- elif from_layout in {torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo:
- with self.assertRaisesRegex(
- RuntimeError,
- "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"):
- t.to_sparse(layout=to_layout, blocksize=blocksize)
- with self.assertRaisesRegex(
- RuntimeError,
- "sparse_compressed_to_sparse expected SparseCsr or SparseCsc layout but got Sparse(Bsr|Bsc)"):
- explicit_to_sparse(t)
- self.skipTest('NOT IMPL')
elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc),
(torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc),
(torch.sparse_csc, torch.sparse_bsr), (torch.sparse_csc, torch.sparse_bsc),
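For the batch case exercised above, batched BSR/BSC -> COO remains unsupported and now fails with the same indices-shape error as batched CSR/CSC, instead of the removed layout error. A hypothetical reproduction (shapes and exact wording illustrative):

    import torch

    # A batch of two 4x4 matrices stored as BSR with 2x2 blocks.
    batch = torch.eye(4).repeat(2, 1, 1).to_sparse_bsr((2, 2))
    try:
        batch.to_sparse(layout=torch.sparse_coo)
    except RuntimeError as e:
        # e.g. "crow_indices is supposed to be a vector, but got 2 dimensional tensor"
        print(e)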