(bsr/csr) x dense mm (#85551)
As per title. This implementation is not optimal and could be improved further, albeit only with native kernels (i.e. the block matching would not need to be materialized). A minimal usage sketch follows the list below.
Compared to existing kernels it offers:
- Half float support (in fact, any dtype that supports `matmul` will work).
- Arbitrary block sizes.
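For illustration, a minimal usage sketch (the shapes, block size, and device are hypothetical; a CUDA device with fp16 support is assumed, since that is where the new path kicks in for half/bfloat16 inputs):

```python
import torch

# Illustrative only: a 4x6 Bsr matrix with a non-square (2, 3) block size,
# multiplied against a dense fp16 matrix on CUDA.
crow_indices = torch.tensor([0, 2, 3], device="cuda")
col_indices = torch.tensor([0, 1, 1], device="cuda")
values = torch.randn(3, 2, 3, dtype=torch.half, device="cuda")  # nnz x b0 x b1 blocks
a = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(4, 6))

b = torch.randn(6, 5, dtype=torch.half, device="cuda")
c = torch.randn(4, 5, dtype=torch.half, device="cuda")

# Half-precision Bsr x dense addmm, which the existing cuSPARSE-backed path does not cover.
out = torch.addmm(c, a, b, alpha=1.0, beta=1.0)
torch.testing.assert_close(out, torch.addmm(c, a.to_dense(), b), rtol=1e-2, atol=1e-2)
```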
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85551
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index 4ad0d55..cdeb3e1 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -2,11 +2,215 @@
#include <ATen/Config.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
+#include <ATen/SparseCsrTensorUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/Operators.h>
+#else
+#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
+#include <ATen/ops/empty_like.h>
+#include <ATen/ops/zeros.h>
+#endif
namespace at {
namespace native {
namespace sparse {
namespace impl {
+
+Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
+ const auto compressed_layout = compressed.layout();
+ const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);
+
+ // Device restrictions
+ TORCH_CHECK(compressed.device() == strided.device()
+ && compressed.device() == result.device(),
+ "spmm_out(): all input arguments are expected to be on the same device.");
+
+ // Layout restrictions.
+ TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
+ "spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
+ TORCH_CHECK(result.layout() == kStrided,
+ "spmm_out(): out argument is expected to be strided.");
+
+ // Dtype restrictions.
+ TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
+ "spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");
+
+ // Dim restrictions.
+ TORCH_CHECK(compressed.dim() == 2,
+ "spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
+ TORCH_CHECK(strided.dim() >= 2,
+ "spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");
+
+ const auto m = compressed.sizes()[0];
+ const auto k = compressed.sizes()[1];
+ const auto n = strided.size(-1);
+ // Matrix product size compatibility.
+ TORCH_CHECK(strided.size(-2) == k,
+ "spmm(", compressed_layout_str, "Strided): argument sizes are not compatible for matrix multiplication. ",
+ "Got ", compressed_layout_str, ".sizes(-1) == ", k, " is not equal to ",
+ "Strided.sizes(-2) == ", strided.size(-2), ".");
+
+ // We assume that result is properly resized.
+ auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
+ result_expected_size.push_back(m);
+ result_expected_size.push_back(n);
+ TORCH_CHECK(result.sizes() == result_expected_size,
+ "spmm_out(): out argument has wrong size. ",
+ "Expected (", result_expected_size, ") but got (", result.sizes(), ").");
+
+ auto values = compressed.values();
+
+ using Blocksize = std::array<int64_t, 2>;
+ // We refer to these as (b0, b1) in the comments below.
+ Blocksize blocksize = {1, 1};
+ if (compressed_layout == kSparseBsr) {
+ blocksize = {values.size(-2), values.size(-1)};
+ }
+
+ // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
+ // NOTE: this function ALWAYS creates a view upon successful execution.
+ const auto tile_tensor = [compressed_layout](
+ const Tensor& t, Blocksize blocksize) -> Tensor {
+ if (compressed_layout == kSparseCsr) {
+ return t.unsqueeze(-1).unsqueeze_(-1);
+ }
+ else {
+ const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
+ const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
+ auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
+ tiled_sizes.push_back(size_neg_2_blocked);
+ tiled_sizes.push_back(blocksize[0]);
+ tiled_sizes.push_back(size_neg_1_blocked);
+ tiled_sizes.push_back(blocksize[1]);
+ return t.reshape(tiled_sizes).transpose(-3, -2);
+ }
+ };
+
+ // Note that sparse values are (..., b0, b1). This means that
+ // the strided input has to be "tilable" to (..., b1, x) with
+ // any x >= 1 such that all the shapes are (block) matrix product
+ // compatible. The matrix product will then have shape (..., b0, x).
+ // This in turn means that the result has to be "tilable" to
+ // (..., b0, x).
+ //
+ // These observations imply the following restrictions:
+ // 1. strided.size(-2) has to be divisible by b1.
+ // 2. result.size(-2) has to be divisible by b0.
+ // 3. both strided.size(-1) and result.size(-1)
+ // have to be divisible by x.
+ //
+ // Restrictions 1 and 2 are trivially satisfied.
+ // Regarding restriction 3:
+ // it would make sense to take the largest possible x for better
+ // performance since it is very likely that the last dimension
+ // is contiguous. As such, this value is exactly
+ // x = strided.size(-1), since strided.size(-1) == result.size(-1).
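+ //
+ // Worked example (illustrative shapes): for a Bsr input of shape (4, 6)
+ // with blocksize (b0, b1) == (2, 3) and a strided input of shape (6, 5),
+ // we have x = 5. The strided input is tiled to shape (2, 1, 3, 5), each
+ // (2, 3) value block is matmul'ed with a (3, 5) tile, and the resulting
+ // (2, 5) block products are accumulated into the result tiled as
+ // (2, 1, 2, 5).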
+
+ // See the comments above. This is our x.
+ const auto outer_blocksize = n;
+
+ Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
+ const auto strided_tiled = tile_tensor(strided, strided_blocksize);
+
+ // Left argument is (..., b0, b1) and right is (..., b1, x).
+ // This naturally implies the result should be "tilable" as
+ // (..., b0, x).
+ Blocksize result_blocksize = {blocksize[0], outer_blocksize};
+ auto result_tiled = tile_tensor(result, result_blocksize);
+
+ if (compressed_layout == kSparseCsr) {
+ values.unsqueeze_(-1).unsqueeze_(-1);
+ }
+
+ Tensor compressed_indices, plain_indices;
+ std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(compressed);
+
+ // Select block rows of the strided input that intersect with the block columns of the sparse input.
+ auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);
+
+ // Promote to float if output is half or bfloat16 for better precision
+ const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
+ ? kFloat : result.scalar_type();
+ // Now that we know which block rows intersect with which block columns,
+ // we can perform matrix products between pairs of blocks.
+ // NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
+ const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
+ .matmul(strided_tiled_selected_rows.to(mm_dtype));
+
+ // Having pairwise block matrix products stored in pairwise_block_mm,
+ // it is sufficient to sum all the block products that share the same row
+ // encoded in the sparse index. Since the reduction step is done via
+ // advanced indexing methods, the compressed index ought to get converted
+ // to the COO format.
+ const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
+ compressed_indices,
+ plain_indices,
+ compressed_indices.scalar_type() == kInt).select(0, 0);
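+ // For example, with compressed_indices == [0, 2, 3] (so nnz == 3), the
+ // resulting COO row indices are [0, 0, 1]: the first two block products
+ // are summed into block row 0, the third into block row 1.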
+
+ // Reduction step.
+ // If result is neither half nor bfloat16, do everything in-place.
+ if (result.scalar_type() == mm_dtype) {
+ // Zero out and sum over the blocks that share the same row indices.
+ result_tiled.zero_();
+ result_tiled.index_add_(
+ /*dim=*/-4,
+ /*index=*/compressed_indices_coo,
+ /*source=*/pairwise_block_mm);
+ }
+ // Otherwise accumulate into a buffer and then copy.
+ else {
+ // No need to zero out result: the sum over the blocks goes into a
+ // zero-initialized buffer, followed by a copy into result.
+ auto promoted_result_tiled = at::zeros(
+ result_tiled.sizes(),
+ result_tiled.options().dtype(mm_dtype));
+ promoted_result_tiled.index_add_(
+ /*dim=*/-4,
+ /*index=*/compressed_indices_coo,
+ /*source=*/pairwise_block_mm);
+ result_tiled.copy_(promoted_result_tiled);
+ }
+
+ return result;
+}
+
+Tensor& _compressed_row_strided_addmm_out(
+ const Tensor& self,
+ const Tensor& mat1,
+ const Tensor& mat2,
+ const Scalar& beta,
+ const Scalar& alpha,
+ Tensor& result) {
+ // If result is not the same tensor as self, it can always be used as the out argument to mm.
+ if (!result.is_same(self)) {
+ _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
+
+ // Process beta
+ if (beta.toComplexDouble() != 0.) {
+ result.add_(self.mul(beta));
+ }
+ }
+ // Otherwise we need to allocate external memory for mm if beta != 0.
+ else {
+ // Process beta
+ if (beta.toComplexDouble() != 0.) {
+ result.mul_(beta);
+ auto mm = at::empty_like(result);
+ _compressed_row_strided_mm_out(mat1, mat2, mm);
+ mm.mul_(alpha);
+ result.add_(mm);
+ }
+ else {
+ _compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
+ }
+ }
+
+ return result;
+}
+
namespace cpu {
/*
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.h b/aten/src/ATen/native/sparse/SparseBlasImpl.h
index b488396..acdd6b3 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.h
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.h
@@ -7,6 +7,20 @@
namespace native {
namespace sparse {
namespace impl {
+
+TORCH_API Tensor& _compressed_row_strided_mm_out(
+ const Tensor& compressed_row_sparse,
+ const Tensor& strided,
+ Tensor& result);
+
+TORCH_API Tensor& _compressed_row_strided_addmm_out(
+ const Tensor& self,
+ const Tensor& mat1,
+ const Tensor& mat2,
+ const Scalar& beta,
+ const Scalar& alpha,
+ Tensor& result);
+
namespace cpu {
void addmv_out_sparse_csr(
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 04f7739..e5393e5 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -22,6 +22,7 @@
#include <ATen/ops/_conj_physical_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
+#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index 91e20b5..379640b 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -8,6 +8,7 @@
#include <ATen/cuda/CUDASparseDescriptors.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>
+#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/sparse/cuda/SparseBlasImpl.h>
#include <ATen/native/sparse/cuda/SparseBlasLegacy.h>
@@ -480,6 +481,22 @@
mat1.values().is_contiguous() ||
mat1.values().transpose(-2, -1).is_contiguous());
+ // NOTE: the code below allows arbitrary block sizes
+ // and might be faster than the cuSPARSE implementation,
+ // especially for inputs that are not very sparse.
+ if (mat1.scalar_type() == ScalarType::Half || mat1.scalar_type() == ScalarType::BFloat16) {
+ at::native::sparse::impl::_compressed_row_strided_addmm_out(
+ result,
+ mat1,
+ mat2,
+ /*beta=*/beta,
+ /*alpha=*/alpha,
+ // @nikitaved: not sure whether `const Tensor& result` makes sense,
+ // but let's keep the interface intact, hence the const cast.
+ const_cast<Tensor&>(result));
+ return;
+ }
+
const cusparseDirection_t block_layout = mat1.values().is_contiguous()
? CUSPARSE_DIRECTION_ROW
: CUSPARSE_DIRECTION_COLUMN;
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 7ac720c..eaa77e2 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -1491,16 +1491,8 @@
out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
+ expected = ref(c, a, b, alpha, beta)
- a_bsr = sp.bsr_matrix(
- (
- a.values().cpu().numpy(),
- a.col_indices().cpu().numpy(),
- a.crow_indices().cpu().numpy(),
- ),
- shape=a.shape,
- )
- expected = ref(c.cpu().numpy(), a_bsr, b.cpu().resolve_conj().numpy(), alpha, beta)
self.assertEqual(actual, out)
self.assertEqual(actual, expected)
@@ -1510,8 +1502,13 @@
@parametrize("noncontiguous", [True, False])
@skipCPUIfNoMklSparse
@unittest.skipIf(not TEST_SCIPY, "SciPy not found")
- @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
- @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-5, torch.complex128: 1e-5})
+ @dtypes(*floating_and_complex_types())
+ @dtypesIfCUDA(*floating_and_complex_types_and(
+ *[torch.half] if SM53OrLater else [],
+ *[torch.bfloat16] if SM80OrLater else []))
+ @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3,
+ torch.float64: 1e-5, torch.complex128: 1e-5,
+ torch.float16: 1e-3, torch.bfloat16: 1e-3})
def test_block_addmm(self, device, dtype, index_dtype, block_size, noncontiguous):
def make_transposed_addmm_op(f):
@@ -1536,25 +1533,72 @@
return wrapper
+ def ref_sp_numpy(c, a, b, alpha=None, beta=None, out=None):
+
+ def prep_input(t):
+
+ def to_sp_block_compressed(t):
+
+ if t.layout is torch.sparse_bsc:
+ tt = t.transpose(-1, -2)
+ else:
+ tt = t
+
+ t_sp_bsr = sp.bsr_matrix(
+ (
+ tt.values().cpu().numpy(),
+ tt.col_indices().cpu().numpy(),
+ tt.crow_indices().cpu().numpy(),
+ ),
+ shape=tt.shape,
+ )
+
+ if t.layout is torch.sparse_bsc:
+ return t_sp_bsr.transpose()
+ else:
+ return t_sp_bsr
+
+ if t.layout is not torch.strided:
+ return to_sp_block_compressed(t)
+ else:
+ return t.cpu().resolve_conj().numpy()
+
+ res = _npref_block_addmm_addmv(
+ *map(lambda t: prep_input(t), (c, a, b)),
+ alpha,
+ beta
+ )
+
+ if out is not None:
+ out.copy_(res)
+ return out
+ else:
+ return res
+
+ def ref_half_bfloat16(c, a, b, alpha=None, beta=None, out=None):
+ res = alpha * (a.to_dense() @ b.to_dense()) + beta * c
+ if out is not None:
+ out.copy_(res)
+ return out
+ else:
+ return res
+
+ if dtype in (torch.half, torch.bfloat16):
+ ref = ref_half_bfloat16
+ else:
+ ref = ref_sp_numpy
+
for (m, n, k) in itertools.product([2, 5], repeat=3):
nnz = random.randint(0, m * k)
- if not noncontiguous:
- a = self.genSparseCSRTensor((m * block_size, k * block_size),
- nnz,
- dtype=dtype,
- device=device,
- index_dtype=index_dtype)
- a = a.to_sparse_bsr((block_size, block_size))
- else:
- a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
- a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
- a_data = a_data.mT if noncontiguous else a_data # Test column-major blocks
- a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(), a_data,
- (m * block_size, k * block_size))
+ a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+ a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
+ a_data = a_data.mT if noncontiguous else a_data
+ a = torch._sparse_bsr_tensor_unsafe(a.crow_indices(), a.col_indices(),
+ a_data, (m * block_size, k * block_size))
b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
for op_b, op_out in itertools.product([True, False], repeat=2):
- self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device)
+ self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device, ref=ref)
self.run_test_block_addmm_addmv(make_transposed_addmm_op(torch.addmm),
c,
a,
@@ -1563,7 +1607,7 @@
op_out,
dtype=dtype,
device=device,
- ref=make_transposed_addmm_op(_npref_block_addmm_addmv))
+ ref=make_transposed_addmm_op(ref))
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])