(bsr/csr) x dense mm (#85551)

As per title. This implementation is not optimal and could be improved further, albeit only with native kernels (i.e., the block matching would not need to be materialized).

Compared to existing kernels it offers:

- Half-precision float support (in fact, any dtype that supports `matmul` will work).
- Arbitrary block sizes.
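
For illustration only, a minimal sketch of the kind of call this unlocks (not taken from the PR itself; assumes a CUDA device where half precision is available, cf. the `SM53OrLater` gating added to the tests below):

```python
import torch

device, dtype = "cuda", torch.float16

# A 4 x 6 BSR matrix built from non-square (2, 3) blocks: a 2 x 2 grid of
# blocks with only the two diagonal blocks stored.
crow_indices = torch.tensor([0, 1, 2], device=device)
col_indices = torch.tensor([0, 1], device=device)
values = torch.randn(2, 2, 3, dtype=dtype, device=device)
a = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(4, 6))

b = torch.randn(6, 5, dtype=dtype, device=device)

# Sparse x dense matrix product in half precision with non-square blocks.
c = a @ b
torch.testing.assert_close(c, a.to_dense() @ b, rtol=1e-2, atol=1e-2)
```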
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85551
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index 4ad0d55..cdeb3e1 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -2,11 +2,215 @@
 #include <ATen/Config.h>
 #include <ATen/native/mkl/SparseBlasImpl.h>
 #include <ATen/native/sparse/SparseBlasImpl.h>
+#include <ATen/SparseCsrTensorUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Operators.h>
+#else
+#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
+#include <ATen/ops/empty_like.h>
+#include <ATen/ops/zeros.h>
+#endif
 
 namespace at {
 namespace native {
 namespace sparse {
 namespace impl {
+
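+// Computes result = compressed @ strided for a Csr/Bsr `compressed` and a dense
+// `strided` by materializing pairwise block products and reducing them with
+// index_add_ over block rows of the sparse operand.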
+Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
+  const auto compressed_layout = compressed.layout();
+  const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);
+
+  // Device restrictions
+  TORCH_CHECK(compressed.device() == strided.device()
+      && compressed.device() == result.device(),
+      "spmm_out(): all input arguments are expected to be on the same device.");
+
+  // Layout restrictions.
+  TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
+      "spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
+  TORCH_CHECK(result.layout() == kStrided,
+      "spmm_out(): out argument is expected to be strided.");
+
+  // Dtype restrictions.
+  TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
+      "spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");
+
+  // Dim restrictions.
+  TORCH_CHECK(compressed.dim() == 2,
+      "spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
+  TORCH_CHECK(strided.dim() >= 2,
+      "spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");
+
+  const auto m = compressed.sizes()[0];
+  const auto k = compressed.sizes()[1];
+  const auto n = strided.size(-1);
+  // Matrix product size compatibility.
+  TORCH_CHECK(strided.size(-2) == k,
+      "spmm(", compressed_layout_str, "Strided): argument sizes are not compatible for matrix multiplication. ",
+      "Got ", compressed_layout_str, ".sizes(-1) == ", k, " is not equal to ",
+      "Strided.sizes(-2) == ", strided.size(-2), ".");
+
+  // We assume that result is properly resized.
+  auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
+  result_expected_size.push_back(m);
+  result_expected_size.push_back(n);
+  TORCH_CHECK(result.sizes() == result_expected_size,
+      "spmm_out(): out argument has wrong size. ",
+      "Expected (", result_expected_size, ") but got (", result.sizes(), ").");
+
+  auto values = compressed.values();
+
+  using Blocksize = std::array<int64_t, 2>;
+  // We refer to these as (b0, b1) in the comments below.
+  Blocksize blocksize = {1, 1};
+  if (compressed_layout == kSparseBsr) {
+    blocksize = {values.size(-2), values.size(-1)};
+  }
+
+  // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
+  // NOTE: this function ALWAYS creates a view upon successful execution.
+  const auto tile_tensor = [compressed_layout](
+      const Tensor& t, Blocksize blocksize) -> Tensor {
+    if (compressed_layout == kSparseCsr) {
+      return t.unsqueeze(-1).unsqueeze_(-1);
+    }
+    else {
+      const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
+      const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
+      auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
+      tiled_sizes.push_back(size_neg_2_blocked);
+      tiled_sizes.push_back(blocksize[0]);
+      tiled_sizes.push_back(size_neg_1_blocked);
+      tiled_sizes.push_back(blocksize[1]);
+      return t.reshape(tiled_sizes).transpose(-3, -2);
+    }
+  };
+
+  // Note that sparse values are (..., b0, b1). This means that
+  // the strided input has to be "tilable" to (..., b1, x) with
+  // any x >= 1 such that all the shapes are (block) matrix product
+  // compatible. The matrix product will then have shape (..., b0, x).
+  // This in turn means that the result has to be "tilable" to
+  // (..., b0, x).
+  //
+  // These observations imply the following restrictions:
+  // 1. strided.size(-2) has to be divisible by b1.
+  // 2. result.size(-2) has to be divisible by b0.
+  // 3. both strided.size(-1) and result.size(-1)
+  //    have to be divisible by x.
+  //
+  // Restrictions 1 and 2 are trivially satisfied.
+  // Regarding restriction 3:
+  // it would make sense to take the largest possible x for better
+  // performance since it is very likely that the last dimension
+  // is contiguous. As such, this value is exactly
+  // x = strided.size(-1), since strided.size(-1) == result.size(-1).
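+  //
+  // Illustrative example (Bsr input with blocksize (b0, b1) = (2, 3)):
+  // compressed is (4, 6), strided is (6, 5), result is (4, 5), so x = 5.
+  // strided then tiles to (2, 1, 3, 5) and result tiles to (2, 1, 2, 5);
+  // each (2, 3) value block multiplies a (3, 5) tile of strided to
+  // produce a (2, 5) tile of result.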
+
+  // See the comments above. This is our x.
+  const auto outer_blocksize = n;
+
+  Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
+  const auto strided_tiled = tile_tensor(strided, strided_blocksize);
+
+  // Left argument is (..., b0, b1) and right is (..., b1, x).
+  // This naturally implies the result should be "tilable" as
+  // (..., b0, x).
+  Blocksize result_blocksize = {blocksize[0], outer_blocksize};
+  auto result_tiled = tile_tensor(result, result_blocksize);
+
+  if (compressed_layout == kSparseCsr) {
+    values.unsqueeze_(-1).unsqueeze_(-1);
+  }
+
+  Tensor compressed_indices, plain_indices;
+  std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(compressed);
+
+  // Select block rows of the strided input that intersect with the block columns of the sparse input.
+  auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);
+
+  // Promote to float if output is half or bfloat16 for better precision
+  const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
+    ? kFloat : result.scalar_type();
+  // Now that we know which block rows intersect with which block columns,
+  // we can perform matrix products between pairs of blocks.
+  // NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
+  const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
+    .matmul(strided_tiled_selected_rows.to(mm_dtype));
+
+  // Having pairwise block matrix products stored in pairwise_block_mm,
+  // it is sufficient to sum all the block products that share the same row
+  // encoded in the sparse index. Since the reduction step is done via
+  // advanced indexing, the compressed indices need to be converted
+  // to the COO format.
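+  // For example, compressed_indices == [0, 2, 3] with plain_indices == [0, 2, 1]
+  // yields COO row indices [0, 0, 1].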
+  const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
+      compressed_indices,
+      plain_indices,
+      compressed_indices.scalar_type() == kInt).select(0, 0);
+
+  // Reduction step.
+  // If result is neither half nor bfloat16, do everything in-place.
+  if (result.scalar_type() == mm_dtype) {
+    // Zero out and sum over the blocks that share the same row indices.
+    result_tiled.zero_();
+    result_tiled.index_add_(
+        /*dim=*/-4,
+        /*index=*/compressed_indices_coo,
+        /*source=*/pairwise_block_mm);
+  }
+  // Otherwise accumulate into a buffer and then copy.
+  else {
+    // No need to zero out result: the sum over the blocks goes into
+    // a zero-initialized buffer which is then copied into result.
+    auto promoted_result_tiled = at::zeros(
+        result_tiled.sizes(),
+        result_tiled.options().dtype(mm_dtype));
+    promoted_result_tiled.index_add_(
+        /*dim=*/-4,
+        /*index=*/compressed_indices_coo,
+        /*source=*/pairwise_block_mm);
+    result_tiled.copy_(promoted_result_tiled);
+  }
+
+  return result;
+}
+
+Tensor& _compressed_row_strided_addmm_out(
+    const Tensor& self,
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Scalar& beta,
+    const Scalar& alpha,
+    Tensor& result) {
+  // If result is not the same as self, it can always be used as the out argument to mm.
+  if (!result.is_same(self)) {
+    _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
+
+    // Process beta
+    if (beta.toComplexDouble() != 0.) {
+      result.add_(self.mul(beta));
+    }
+  }
+  // Otherwise we need to allocate external memory for mm if beta != 0.
+  else {
+    // Process beta
+    if (beta.toComplexDouble() != 0.) {
+      result.mul_(beta);
+      auto mm = at::empty_like(result);
+      _compressed_row_strided_mm_out(mat1, mat2, mm);
+      mm.mul_(alpha);
+      result.add_(mm);
+    }
+    else {
+      _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
+    }
+  }
+
+  return result;
+}
+
 namespace cpu {
 
 /*
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.h b/aten/src/ATen/native/sparse/SparseBlasImpl.h
index b488396..acdd6b3 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.h
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.h
@@ -7,6 +7,20 @@
 namespace native {
 namespace sparse {
 namespace impl {
+
+TORCH_API Tensor& _compressed_row_strided_mm_out(
+    const Tensor& compressed_row_sparse,
+    const Tensor& strided,
+    Tensor& result);
+
+TORCH_API Tensor& _compressed_row_strided_addmm_out(
+    const Tensor& self,
+    const Tensor& mat1,
+    const Tensor& mat2,
+    const Scalar& beta,
+    const Scalar& alpha,
+    Tensor& result);
+
 namespace cpu {
 
 void addmv_out_sparse_csr(
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 04f7739..e5393e5 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -22,6 +22,7 @@
 #include <ATen/ops/_conj_physical_native.h>
 #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
 #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
+#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index 91e20b5..379640b 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -8,6 +8,7 @@
 #include <ATen/cuda/CUDASparseDescriptors.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/cuda/MiscUtils.h>
+#include <ATen/native/sparse/SparseBlasImpl.h>
 #include <ATen/native/sparse/cuda/SparseBlasImpl.h>
 #include <ATen/native/sparse/cuda/SparseBlasLegacy.h>
 
@@ -480,6 +481,22 @@
       mat1.values().is_contiguous() ||
       mat1.values().transpose(-2, -1).is_contiguous());
 
+  // NOTE: the code below allows arbitrary block sizes
+  // and might be faster than the cuSPARSE implementation,
+  // especially for inputs that are not very sparse.
+  if (mat1.scalar_type() == ScalarType::Half || mat1.scalar_type() == ScalarType::BFloat16) {
+    at::native::sparse::impl::_compressed_row_strided_addmm_out(
+        result,
+        mat1,
+        mat2,
+        /*beta=*/beta,
+        /*alpha=*/alpha,
+        // @nikitaved: not sure whether `const Tensor& result` makes sense,
+        // but let's keep the interface intact, hence the const cast.
+        const_cast<Tensor&>(result));
+    return;
+  }
+
   const cusparseDirection_t block_layout = mat1.values().is_contiguous()
       ? CUSPARSE_DIRECTION_ROW
       : CUSPARSE_DIRECTION_COLUMN;
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 7ac720c..eaa77e2 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1491,16 +1491,8 @@
 
         out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
         addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
+        expected = ref(c, a, b, alpha, beta)
 
-        a_bsr = sp.bsr_matrix(
-            (
-                a.values().cpu().numpy(),
-                a.col_indices().cpu().numpy(),
-                a.crow_indices().cpu().numpy(),
-            ),
-            shape=a.shape,
-        )
-        expected = ref(c.cpu().numpy(), a_bsr, b.cpu().resolve_conj().numpy(), alpha, beta)
         self.assertEqual(actual, out)
         self.assertEqual(actual, expected)
 
@@ -1510,8 +1502,13 @@
     @parametrize("noncontiguous", [True, False])
     @skipCPUIfNoMklSparse
     @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
-    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
-    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-5, torch.complex128: 1e-5})
+    @dtypes(*floating_and_complex_types())
+    @dtypesIfCUDA(*floating_and_complex_types_and(
+                  *[torch.half] if SM53OrLater else [],
+                  *[torch.bfloat16] if SM80OrLater else []))
+    @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+                        torch.float64: 1e-5, torch.complex128: 1e-5,
+                        torch.float16: 1e-3, torch.bfloat16: 1e-3})
     def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous):
 
         def make_transposed_addmm_op(f):
@@ -1536,25 +1533,72 @@
 
             return wrapper
 
+        def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None):
+
+            def prep_input(t):
+
+                def to_sp_block_compressed(t):
+
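+                    # scipy has no BSC format, so a BSC input is handled by building
+                    # the scipy BSR matrix of its transpose and transposing it back.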
+                    if t.layout is torch.sparse_bsc:
+                        tt = t.transpose(-1, -2)
+                    else:
+                        tt = t
+
+                    t_sp_bsr = sp.bsr_matrix(
+                        (
+                            tt.values().cpu().numpy(),
+                            tt.col_indices().cpu().numpy(),
+                            tt.crow_indices().cpu().numpy(),
+                        ),
+                        shape=tt.shape,
+                    )
+
+                    if t.layout is torch.sparse_bsc:
+                        return t_sp_bsr.transpose()
+                    else:
+                        return t_sp_bsr
+
+                if t.layout is not torch.strided:
+                    return to_sp_block_compressed(t)
+                else:
+                    return t.cpu().resolve_conj().numpy()
+
+            res = _npref_block_addmm_addmv(
+                *map(lambda t: prep_input(t), (c, a, b)),
+                alpha,
+                beta
+            )
+
+            if out is not None:
+                out.copy_(res)
+                return out
+            else:
+                return res
+
+        def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
+            res = alpha * (a.to_dense() @ b.to_dense()) + beta * c
+            if out is not None:
+                out.copy_(res)
+                return out
+            else:
+                return res
+
+        if dtype in (torch.half, torch.bfloat16):
+            ref = ref_half_bfloat16
+        else:
+            ref = ref_sp_numpy
+
         for (m, n, k) in itertools.product([2, 5], repeat=3):
             nnz = random.randint(0, m * k)
-            if not noncontiguous:
-                a = self.genSparseCSRTensor((m * block_size, k * block_size),
-                                            nnz,
-                                            dtype=dtype,
-                                            device=device,
-                                            index_dtype=index_dtype)
-                a = a.to_sparse_bsr((block_size, block_size))
-            else:
-                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
-                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
-                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
-                a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(), a_data,
-                                                    (m * block_size, k * block_size))
+            a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+            a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
+            a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
+            a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
+                                                a_data, (m * block_size, k * block_size))
             b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
             c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
             for op_b, op_out in itertools.product([True, False], repeat=2):
-                self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device)
+                self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref)
                 self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm),
                                                 c,
                                                 a,
@@ -1563,7 +1607,7 @@
                                                 op_out,
                                                 dtype=dtype,
                                                 device=device,
-                                                ref=make_transposed_addmm_op(_npref_block_addmm_addmv))
+                                                ref=make_transposed_addmm_op(ref))
 
     @parametrize("block_size", [2, 3])
     @parametrize("index_dtype", [torch.int32, torch.int64])