Improve `bsr @ strided` performance in `baddmm` for `bfloat16/half` with Triton kernels. (#88078)
As per title.
Additionally, we introduce support for:
- Rectangular block sizes that are powers of 2 and at least 16 (a limitation of Triton's `dot`).
- Batch support with broadcasting for either argument; a short usage sketch follows below.
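For orientation, here is a minimal usage sketch distilled from the new tests. The shapes, dtype, and block size are illustrative only; it assumes a CUDA device with Triton installed and compute capability >= 7.0 (per the `_has_triton` gate added in this PR):

```python
import torch
from torch.sparse._triton_ops import bsr_dense_mm  # None if Triton is unavailable

# Illustrative shapes/dtype; any power-of-2 block size >= 16 works.
weight = torch.randn(256, 128, dtype=torch.half, device="cuda")
bsr = weight.to_sparse_bsr(32)                      # BSR matrix with 32x32 blocks
dense = torch.randn(128, 64, dtype=torch.half, device="cuda")

# Direct call into the Triton wrapper (max_grid is an optional cap on the launch grid).
out = bsr_dense_mm(bsr, dense)                      # shape (256, 64)
torch.testing.assert_close(out, bsr.to_dense() @ dense, rtol=1e-2, atol=1e-2)

# The 2D case is also reachable through nn.functional.linear with a BSR weight,
# which is what the addmm fast path in SparseBlasImpl.cpp targets.
x = torch.randn(64, 128, dtype=torch.half, device="cuda")
y = torch.nn.functional.linear(x, bsr)              # x @ weight.T -> (64, 256)
```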
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
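The new C++ fast path (see `SparseBlasImpl.cpp` in the diff below) only calls into Triton when a kernel has been registered for the `SparseCsrCUDA` dispatch key from Python. A simplified sketch of that registration mechanism follows; the in-tree code in `torch/sparse/__init__.py` defers it through `torch.cuda._lazy_call` so that importing torch does not initialize CUDA, and this standalone snippet only illustrates the mechanism (running it on top of the in-tree registration would conflict with it):

```python
import torch
from torch.sparse._triton_ops import bsr_dense_mm  # None when Triton is unavailable

lib = torch.library.Library("aten", "IMPL")
if bsr_dense_mm is not None:
    # Route aten::_triton_bsr_dense_mm to the Triton wrapper for BSR-on-CUDA inputs.
    lib.impl(
        "aten::_triton_bsr_dense_mm",
        lambda bsr, dense: bsr_dense_mm(bsr, dense, skip_checks=True),
        "SparseCsrCUDA",
    )
```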
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e50779a..3d34667 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6456,6 +6456,12 @@
SparseCPU: s_addmm_sparse_dense_cpu_
SparseCUDA: s_addmm_sparse_dense_cuda_
+- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor
+ variants: function
+ dispatch:
+ CPU: triton_bsr_dense_mm
+ autogen: _triton_bsr_dense_mm.out
+
- func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
structured: True
dispatch:
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index cdeb3e1..c147e8c 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,6 +4,10 @@
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
+// Required for checking whether Triton kernels are available
+#include <torch/csrc/jit/frontend/function_schema_parser.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
+#include <ATen/ops/_triton_bsr_dense_mm.h>
#endif
namespace at {
@@ -70,6 +75,31 @@
blocksize = {values.size(-2), values.size(-1)};
}
+// No stable support for ROCM in Triton yet.
+#ifndef USE_ROCM
+ // Triton works only with blocksizes which are powers of 2.
+ const auto is_power_of_2 = [](int64_t v) -> bool {
+ return !(v & (v - 1));
+ };
+
+ // Dtype and blocksize checks for potential Triton usage.
+ if ((strided.scalar_type() == ScalarType::Half
+ || strided.scalar_type() == ScalarType::BFloat16)
+ && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
+ && (blocksize[0] >= 16) && (blocksize[1] >= 16)
+ // lhs is retiled to (b0, b1) while rhs is retiled to (b1, b0),
+ // so the result is tiled to (b0, b0) and we need to make
+ // sure that dense.size(-1) is divisible by b0.
+ && n % blocksize[0] == 0) {
+ const auto triton_kernel = c10::Dispatcher::singleton()
+ .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
+ // Call Triton only if the dispatch key was overridden.
+ if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
+ return at::_triton_bsr_dense_mm_out(result, compressed, strided);
+ }
+ }
+#endif
+
// (..., r, c) -> (..., r / b0, c / b1, b0, b1)
// NOTE: this function ALWAYS creates a view upon successful execution.
const auto tile_tensor = [compressed_layout](
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index efa6926..f407b7b 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -1292,5 +1292,12 @@
return result;
}
+Tensor triton_bsr_dense_mm(
+ const Tensor& bsr,
+ const Tensor& dense) {
+ TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
+ return Tensor {};
+}
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseMatMul.cpp b/aten/src/ATen/native/sparse/SparseMatMul.cpp
index 548b66a..e5f283b 100644
--- a/aten/src/ATen/native/sparse/SparseMatMul.cpp
+++ b/aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,6 +274,5 @@
return output;
}
-
} // namespace native
} // namespace at
diff --git a/mypy.ini b/mypy.ini
index 4afe7dc..7108fee 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -188,6 +188,9 @@
# Third party dependencies that don't have types.
#
+[mypy-triton.*]
+ignore_missing_imports = True
+
[mypy-tensorflow.*]
ignore_missing_imports = True
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 30606d1..b1bfe59 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -20,6 +20,7 @@
floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
all_types_and_complex, floating_and_complex_types_and
)
+from torch._inductor.utils import has_triton
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
if TEST_SCIPY:
@@ -1464,6 +1465,63 @@
self.assertEqual(actual, out)
self.assertEqual(actual, expected)
+ @parametrize("block_size", [16, 32, 64])
+ @parametrize("index_dtype", [torch.int32, torch.int64])
+ @unittest.skipIf(not has_triton(), "Triton is not available")
+ @skipCUDAIfRocm
+ @onlyCUDA
+ @dtypes(torch.half, torch.bfloat16)
+ @dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
+ *[torch.bfloat16] if SM80OrLater else [])
+ def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
+ from functools import partial
+
+ # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
+ tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
+
+ # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
+ batches = [(), (2,)]
+ size = [128, 256, 0]
+
+ # Whether to make inputs orthogonal so that the product is zero
+ make_orthogonal = [True, False]
+
+ for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
+ bsr = tensor(bs + (m, k))
+ # NOTE: do not get confused, it will be transposed
+ dense = tensor(bd + (n, k))
+
+ if is_ortho:
+ bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
+ dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
+
+ bsr = bsr.to_sparse_bsr(block_size)
+
+ if bsr.dim() == 2:
+ # Test against linear to check dispatch.
+ res_tri = torch.nn.functional.linear(dense, bsr)
+ res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+ else:
+ # Otherwise check correctness against bmm
+ # since nn.linear does not support bsr.dim() > 2.
+ res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
+ res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+ self.assertEqual(res_tri, res_dense)
+
+ res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+ # check whether bsr_dense_mm handles different grid sizes
+ # None means max possible grid size which is CUDA-dependent.
+ grid_size = (None, 2, 4)
+ grid_gen = itertools.product(grid_size, repeat=3)
+ for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
+ res_tri = torch.sparse._triton_ops.bsr_dense_mm(
+ bsr,
+ dense.transpose(-2, -1),
+ max_grid=grid,
+ is_sparse_rowspace_mode=is_sparse_rowspace
+ )
+ self.assertEqual(res_tri, res_dense)
+
# TODO: block_size 1 is broken
@parametrize("block_size", [2, 3])
@parametrize("index_dtype", [torch.int32, torch.int64])
diff --git a/torch/__init__.py b/torch/__init__.py
index 310e012..5b12580 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1368,3 +1368,8 @@
'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
kwargs['check_invariants'] = False
return torch.sparse_coo_tensor(*args, **kwargs)
+
+
+# dynamic registration of sparse triton kernels
+from torch.sparse import _register_impls
+_register_impls(torch.library.Library("aten", "IMPL"))
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index eec9bc2..ec1b883 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -898,18 +898,18 @@
del _LegacyStorage
del _CudaLegacyStorage
-torch._storage_classes.add(DoubleStorage)
-torch._storage_classes.add(FloatStorage)
-torch._storage_classes.add(LongStorage)
-torch._storage_classes.add(IntStorage)
-torch._storage_classes.add(ShortStorage)
-torch._storage_classes.add(CharStorage)
-torch._storage_classes.add(ByteStorage)
-torch._storage_classes.add(HalfStorage)
-torch._storage_classes.add(BoolStorage)
-torch._storage_classes.add(BFloat16Storage)
-torch._storage_classes.add(ComplexDoubleStorage)
-torch._storage_classes.add(ComplexFloatStorage)
+torch._storage_classes.add(DoubleStorage) # type: ignore[has-type]
+torch._storage_classes.add(FloatStorage) # type: ignore[has-type]
+torch._storage_classes.add(LongStorage) # type: ignore[has-type]
+torch._storage_classes.add(IntStorage) # type: ignore[has-type]
+torch._storage_classes.add(ShortStorage) # type: ignore[has-type]
+torch._storage_classes.add(CharStorage) # type: ignore[has-type]
+torch._storage_classes.add(ByteStorage) # type: ignore[has-type]
+torch._storage_classes.add(HalfStorage) # type: ignore[has-type]
+torch._storage_classes.add(BoolStorage) # type: ignore[has-type]
+torch._storage_classes.add(BFloat16Storage) # type: ignore[has-type]
+torch._storage_classes.add(ComplexDoubleStorage) # type: ignore[has-type]
+torch._storage_classes.add(ComplexFloatStorage) # type: ignore[has-type]
from . import sparse
from . import profiler
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 3ceaf56..a7a909e 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -4,6 +4,8 @@
import torch
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
from torch import Tensor
+from torch.cuda import _lazy_call
+from torch._inductor.cuda_properties import get_device_capability
# A workaround to support both TorchScript and MyPy:
from typing import TYPE_CHECKING
@@ -462,3 +464,30 @@
return mth(*args, **kwargs)
return test_mth
+
+# Triton registrations
+def _has_triton():
+ if not torch.cuda.is_available():
+ return False
+ try:
+ import triton
+
+ return triton is not None and get_device_capability() >= (7, 0)
+ except ImportError:
+ return False
+
+
+def _register_impls(lib):
+ """This function is called from torch/__init__.py to do any dynamic registrations. """
+
+
+ def register_sparse_cuda_impls(lib=lib):
+ from ._triton_ops import bsr_dense_mm
+
+ if bsr_dense_mm is not None:
+ lib.impl("aten::_triton_bsr_dense_mm",
+ lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs), "SparseCsrCUDA")
+
+ # This code is evaluated on import torch and therefore cannot force initialization of the cuda rt
+ # We must schedule the registration to occur lazily.
+ _lazy_call(register_sparse_cuda_impls)
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
new file mode 100644
index 0000000..d7b34f3
--- /dev/null
+++ b/torch/sparse/_triton_ops.py
@@ -0,0 +1,608 @@
+import torch
+from torch._inductor.cuda_properties import get_device_capability
+
+def _has_triton():
+ if not torch.cuda.is_available():
+ return False
+ try:
+ import triton
+
+ return triton is not None and get_device_capability() >= (7, 0)
+ except ImportError:
+ return False
+
+def compressed_indices_to_plain_indices(cidx, pidx):
+ nnz = pidx.shape[-1]
+ cdim = cidx.shape[-1] - 1
+ batch_numel = cidx.shape[0]
+ batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
+ :, None
+ ]
+
+ cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
+ cidx_linear = torch.empty(
+ (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
+ )
+ cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
+ cidx_linear[-1] = nnz * batch_numel
+
+ idx_linear = torch._convert_indices_from_csr_to_coo(
+ cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
+ ).select(0, 0)
+
+ return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
+
+
+def slicer(dim, slice_range, *tensors):
+ for t in tensors:
+ slices = [slice(None)] * t.dim()
+ slices[dim] = slice_range
+ yield t[slices]
+
+if _has_triton():
+ import triton
+ import triton.language as tl
+ from typing import Optional, Tuple
+
+ @triton.jit
+ def _bsr_strided_dense_rowspace_kernel(
+ BLOCKSIZE_ROW: tl.constexpr,
+ BLOCKSIZE_COL: tl.constexpr,
+ # values prologue
+ values_ptr,
+ values_batch_stride,
+ values_nnz_stride,
+ values_row_block_stride,
+ values_col_block_stride,
+ # values epilogue
+ # crow_indices prologue
+ crow_indices_ptr,
+ crow_indices_batch_stride,
+ crow_indices_stride,
+ # crow_indices epilogue
+ # col_indices prologue
+ col_indices_ptr,
+ col_indices_batch_stride,
+ col_indices_stride,
+ # col_indices epilogue
+ # dense prologue
+ dense_ptr,
+ dense_batch_stride,
+ dense_tiled_row_stride,
+ dense_tiled_col_stride,
+ dense_row_block_stride,
+ dense_col_block_stride,
+ # dense epilogue
+ # output prologue
+ output_ptr,
+ output_batch_stride,
+ output_tiled_row_stride,
+ output_tiled_col_stride,
+ output_row_block_stride,
+ output_col_block_stride,
+ # output epilogue
+ GROUP_SIZE_ROW: tl.constexpr,
+ ):
+ batch_pid = tl.program_id(axis=2)
+ row_block_pid = tl.program_id(axis=0)
+ col_block_pid = tl.program_id(axis=1)
+ n_block_rows = tl.num_programs(axis=0)
+ n_block_cols = tl.num_programs(axis=1)
+
+ row_block_pid, col_block_pid = tl.swizzle2d(
+ row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
+ )
+
+ crow_indices_offset_ptr = (
+ crow_indices_ptr
+ + crow_indices_batch_stride * batch_pid
+ + crow_indices_stride * row_block_pid
+ )
+ nnz_offset = tl.load(crow_indices_offset_ptr)
+ nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
+
+ # Compute nnz for the row with number row_block_pid.
+ # If it is zero, skip the row.
+ row_nnz = nnz_offset_next - nnz_offset
+ if row_nnz == 0:
+ return
+
+ row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
+ col_block_arange = tl.arange(0, BLOCKSIZE_COL)
+
+ # Pointers are set to the first block of the current row.
+ values_block_ptrs = (
+ values_ptr
+ + values_batch_stride * batch_pid
+ + values_nnz_stride * nnz_offset
+ + values_row_block_stride * row_block_arange[:, None]
+ + values_col_block_stride * col_block_arange[None, :]
+ )
+
+ # NOTE: dense is advanced into all dimensions but the tiled row one.
+ # That will be advanced in the loop according to values in col_indices.
+ dense_block_ptrs = (
+ dense_ptr
+ + dense_batch_stride * batch_pid
+ + dense_tiled_col_stride * col_block_pid
+ + dense_row_block_stride * col_block_arange[:, None]
+ + dense_col_block_stride * row_block_arange[None, :]
+ )
+
+ # Pointers are set to exact write-to locations
+ output_ptrs = (
+ output_ptr
+ + output_batch_stride * batch_pid
+ + output_tiled_row_stride * row_block_pid
+ + output_tiled_col_stride * col_block_pid
+ + output_row_block_stride * row_block_arange[:, None]
+ + output_col_block_stride * row_block_arange[None, :]
+ )
+
+ # Set pointer to the first nonzero element in the current row
+ col_index_nnz_ptr = (
+ col_indices_ptr
+ + col_indices_batch_stride * batch_pid
+ + col_indices_stride * nnz_offset
+ )
+
+ output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+ for _ in range(row_nnz):
+ values_block = tl.load(values_block_ptrs)
+
+ # find which row of dense needs to get loaded
+ # for multiplication with values_block.
+ dense_row_idx = tl.load(col_index_nnz_ptr)
+ dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
+
+ # do block mm
+ output_acc_block += tl.dot(values_block, dense_block)
+
+ # move val/col_index ptrs to the next block in the row
+ values_block_ptrs += values_nnz_stride
+ col_index_nnz_ptr += col_indices_stride
+
+ # write back the result
+ tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
+
+
+ @triton.jit
+ def _bsr_strided_sparse_rowspace_kernel(
+ BLOCKSIZE_ROW: tl.constexpr,
+ BLOCKSIZE_COL: tl.constexpr,
+ batch_idx_ptr,
+ row_idx_ptr,
+ nnz_per_row_ptr,
+ nnz_per_row_cumsum_ptr,
+ col_indices_ptr,
+ col_indices_stride,
+ # values prologue
+ values_ptr,
+ values_nnz_stride,
+ values_row_block_stride,
+ values_col_block_stride,
+ # values epilogue
+ # dense prologue
+ dense_ptr,
+ dense_batch_stride,
+ dense_tiled_row_stride,
+ dense_tiled_col_stride,
+ dense_row_block_stride,
+ dense_col_block_stride,
+ # dense epilogue
+ # output prologue
+ output_ptr,
+ output_batch_stride,
+ output_tiled_row_stride,
+ output_tiled_col_stride,
+ output_row_block_stride,
+ output_col_block_stride,
+ # output epilogue
+ GROUP_SIZE_ROW: tl.constexpr,
+ ):
+ row_block_pid = tl.program_id(axis=0)
+ col_block_pid = tl.program_id(axis=1)
+ n_block_rows = tl.num_programs(axis=0)
+ n_block_cols = tl.num_programs(axis=1)
+
+ row_block_pid, col_block_pid = tl.swizzle2d(
+ row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
+ )
+
+ batch_idx = tl.load(batch_idx_ptr + row_block_pid)
+ row_idx = tl.load(row_idx_ptr + row_block_pid)
+ row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
+ row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
+ row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
+
+ row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
+ col_block_arange = tl.arange(0, BLOCKSIZE_COL)
+
+ # Pointers are set to the first block of the current row.
+ values_block_ptrs = (
+ values_ptr
+ + values_nnz_stride * row_idx_nnz_offset
+ + values_row_block_stride * row_block_arange[:, None]
+ + values_col_block_stride * col_block_arange[None, :]
+ )
+
+ # NOTE: dense is advanced into all dimensions but the tiled row one.
+ # That will be advanced in the loop according to values in col_indices.
+ dense_block_ptrs = (
+ dense_ptr
+ + dense_batch_stride * batch_idx
+ + dense_tiled_col_stride * col_block_pid
+ + dense_row_block_stride * col_block_arange[:, None]
+ + dense_col_block_stride * row_block_arange[None, :]
+ )
+
+ # Pointers are set to exact write-to locations
+ output_ptrs = (
+ output_ptr
+ + output_batch_stride * batch_idx
+ + output_tiled_row_stride * row_idx
+ + output_tiled_col_stride * col_block_pid
+ + output_row_block_stride * row_block_arange[:, None]
+ + output_col_block_stride * row_block_arange[None, :]
+ )
+
+ output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+ col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
+ for _ in range(row_idx_nnz):
+ values_block = tl.load(values_block_ptrs)
+
+ # find which row of dense needs to get loaded
+ # for multiplication with values_block.
+ dense_row_idx = tl.load(col_index_nnz_ptr)
+ dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
+
+ # do block mm
+ output_acc_block += tl.dot(values_block, dense_block)
+
+ # move val/col_index ptrs to the next block in the row
+ values_block_ptrs += values_nnz_stride
+ col_index_nnz_ptr += col_indices_stride
+
+ # write back the result
+ tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
+
+
+ def _run_sparse_rowspace_kernel(
+ blocksize, values, crow_indices, col_indices, dense, output, max_grid
+ ):
+ # Compute a vector of non-zero elements numbers per each row.
+ # We want to ultimately iterate over non-zero rows.
+ nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
+
+ # Compute indices of non-zero counts.
+ # batch_idx maps to a broadcasted batch index, while
+ # row_idx tracks non-zero rows of the sparse argument
+ # and rows of the output that get modified.
+ batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
+
+ # Compress the vector of counts to hold only non-zero values.
+ nnz_per_row = nnz_per_row[batch_idx, row_idx]
+ # Compute cumulative counts which along with nnz_per_row
+ # are used to compute offsets into nnz values.
+ nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
+
+ n_nnz_block_rows = row_idx.size(-1)
+ n_block_cols = dense.size(-3)
+ max_n_nnz_block_rows, max_n_block_cols = max_grid[:2]
+
+ for c_start in range(0, n_block_cols, max_n_block_cols):
+ c_dense, c_output = slicer(
+ -3, slice(c_start, c_start + max_n_block_cols), dense, output
+ )
+ c_grid = min(n_block_cols - c_start, max_n_block_cols)
+
+ for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows):
+ r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer(
+ 0,
+ slice(r_start, r_start + max_n_nnz_block_rows),
+ batch_idx,
+ row_idx,
+ nnz_per_row,
+ nnz_per_row_cumsum,
+ )
+ r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows)
+
+ _bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)](
+ *blocksize,
+ r_batch_idx,
+ r_row_idx,
+ r_nnz_per_row,
+ r_nnz_per_row_cumsum,
+ col_indices,
+ *col_indices.stride(),
+ values,
+ *values.stride(),
+ c_dense,
+ *c_dense.stride(),
+ c_output,
+ *c_output.stride(),
+ GROUP_SIZE_ROW=4,
+ num_stages=4,
+ num_warps=4,
+ )
+
+
+ def _run_dense_rowspace_kernel(
+ blocksize, values, crow_indices, col_indices, dense, output, max_grid
+ ):
+ # Launch kernel
+ n_batches = dense.size(0)
+ n_block_rows = crow_indices.size(-1) - 1
+ n_block_cols = dense.size(-3)
+ max_n_block_rows, max_n_block_cols, max_n_batches = max_grid
+
+ for b_start in range(0, n_batches, max_n_batches):
+ b_v, b_crow, b_col, b_d, b_o = slicer(
+ 0,
+ slice(b_start, b_start + max_n_batches),
+ values,
+ crow_indices,
+ col_indices,
+ dense,
+ output,
+ )
+ b_grid = min(n_batches - b_start, max_n_batches)
+
+ for c_start in range(0, n_block_cols, max_n_block_cols):
+ bc_d, bc_o = slicer(
+ -3, slice(c_start, c_start + max_n_block_cols), b_d, b_o
+ )
+ c_grid = min(n_block_cols - c_start, max_n_block_cols)
+
+ for r_start in range(0, n_block_rows, max_n_block_rows):
+ r_slice = slice(r_start, r_start + max_n_block_rows)
+ br_crow = next(slicer(-1, r_slice, b_crow))
+ brc_o = next(slicer(-4, r_slice, bc_o))
+ r_grid = min(n_block_rows - r_start, max_n_block_rows)
+
+ _bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)](
+ *blocksize,
+ b_v,
+ *b_v.stride(),
+ br_crow,
+ *br_crow.stride(),
+ b_col,
+ *b_col.stride(),
+ bc_d,
+ *bc_d.stride(),
+ brc_o,
+ *brc_o.stride(),
+ GROUP_SIZE_ROW=4,
+ num_stages=4,
+ num_warps=4,
+ )
+
+
+ def bsr_dense_mm(
+ bsr: torch.Tensor,
+ dense: torch.Tensor,
+ *,
+ skip_checks: bool = False,
+ is_sparse_rowspace_mode: Optional[bool] = None,
+ max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
+ out: Optional[torch.Tensor] = None,
+ ):
+ m, kl = bsr.shape[-2:]
+ kr, n = dense.shape[-2:]
+
+ def check(cond, msg):
+ if not cond:
+ raise ValueError(msg)
+
+ if not skip_checks:
+ check(
+ bsr.layout == torch.sparse_bsr,
+ "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
+ )
+
+ check(
+ bsr.device == dense.device and bsr.device.type == "cuda",
+ "bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
+ )
+
+ check(
+ bsr.dtype == dense.dtype
+ and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
+ "bsr_dense_mm(): all inputs are expected to be of the same dtype "
+ "and one of (half, bfloat16, float32), "
+ f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
+ )
+
+ check(
+ bsr.dim() >= 2 and dense.dim() >= 2,
+ "bsr_dense_mm(): all inputs are expected to be at least 2D, "
+ f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
+ )
+
+ check(
+ kl == kr,
+ "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
+ f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
+ )
+
+ row_block = bsr.values().shape[-2]
+ check(
+ not n % row_block,
+ f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
+ f"blocksize[0] == {row_block}.",
+ )
+
+ # Required to undo the fake batch dimension insertion.
+ original_batch_dims_broadcasted = torch.broadcast_shapes(
+ bsr.shape[:-2], dense.shape[:-2]
+ )
+
+ if out is not None and not skip_checks:
+ expected_out_shape = original_batch_dims_broadcasted + (m, n)
+ check(
+ out.shape == expected_out_shape,
+ "bsr_dense_mm(): `out` argument has wrong shape, "
+ f"expected {expected_out_shape}, but got {out.shape}.",
+ )
+ check(
+ out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
+ "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
+ "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
+ "should be True.",
+ )
+
+ # Short circuit if lhs is zero
+ if bsr._nnz() == 0:
+ return dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+
+ # TODO: insert switch
+ if is_sparse_rowspace_mode is None:
+ is_sparse_rowspace_mode = False
+
+ # Introduce fake batch dimension if not present for convenience.
+ def unsqueeze_batch_dim(t, n_non_batch_dims):
+ if t.dim() > n_non_batch_dims:
+ return t
+ else:
+ return t.unsqueeze(0)
+
+ def make_triton_contiguous(t):
+ # Triton does not distinguish between row- and col-majorness
+ # and will be fast as long as there is a contiguous dimension.
+ if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
+ return t.contiguous()
+ else:
+ return t
+
+ crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1)
+ col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1)
+ values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3))
+ dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2))
+ nnz = values.shape[-3]
+ blocksize = values.shape[-2:]
+
+ # Compute broadcasted batch dimension
+ bsr_batch_dims = values.shape[:-3]
+ dense_batch_dims = dense.shape[:-2]
+ batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
+
+ # Allocate out
+ if out is None:
+ out = dense.new_zeros(batch_dims_broadcasted + (m, n))
+
+ # Broadcast batch dimensions and squash
+ def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
+ return t.broadcast_to(batch_dims + invariant_dims).flatten(
+ 0, len(batch_dims) - 1
+ )
+
+ crow_indices = batch_broadcast_and_squash(
+ crow_indices, batch_dims_broadcasted, (-1,)
+ )
+
+ if is_sparse_rowspace_mode:
+ # Flatten batch dimension with nnz dimension
+ # as required by the sparse rowspace kernel.
+ col_indices = batch_broadcast_and_squash(
+ col_indices, batch_dims_broadcasted + (-1,), ()
+ )
+ values = batch_broadcast_and_squash(
+ values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
+ )
+ else:
+ col_indices = batch_broadcast_and_squash(
+ col_indices, batch_dims_broadcasted, (-1,)
+ )
+ values = batch_broadcast_and_squash(
+ values, batch_dims_broadcasted, values.shape[-3:]
+ )
+
+ dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
+
+ # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
+ out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
+
+ # NOTE: this function will ALWAYS create a view
+ def tile_to_blocksize(t, blocksize):
+ *rest, m, n = t.shape
+ new_shape = rest + [
+ m // blocksize[0],
+ blocksize[0],
+ n // blocksize[1],
+ blocksize[1],
+ ]
+ return t.reshape(new_shape).transpose(-3, -2)
+
+ # "Blockify" the row dimension of dense with blocksize[1]
+ # since dense is on the rhs of matmul
+ dense = tile_to_blocksize(dense, blocksize[::-1])
+ # "Blockify" the row dimension of out with blocksize[0]
+ # which is inherited from the bsr input.
+ # NOTE: tile_to_blocksize will create a view.
+ # NOTE: out.blocksize[-1] == dense.blocksize[-1],
+ # so it could be any value in [1, dense.shape[-1]).
+ # We should probably use the largest possible blocksize
+ # so that it fits into SRAM.
+ out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
+
+ # Launch kernel
+ if is_sparse_rowspace_mode:
+ kernel = _run_sparse_rowspace_kernel
+ else:
+ kernel = _run_dense_rowspace_kernel
+
+ # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
+ cuda_max_grid = (2147483647, 65535, 65535)
+ if max_grid is None:
+ max_grid = cuda_max_grid
+ else:
+
+ def valid_grid_dim(g, mg):
+ if g is None:
+ return mg
+ else:
+ # grid must be at least 1 and no greater than mg
+ return max(1, min(g, mg))
+
+ max_grid = tuple(
+ valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid)
+ ) # type: ignore[assignment]
+
+ kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
+
+ # Block dims need to rejoin with the corresponding block dimensions
+ # prior to reshape so that blocks do not end up being transposed.
+ # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
+ return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n)) # type: ignore[union-attr]
+else:
+ bsr_dense_mm = None # type: ignore[assignment]
+
+
+if __name__ == "__main__":
+ from torch._inductor.utils import has_triton
+
+ if has_triton():
+ torch.manual_seed(13)
+ dtype = torch.float32
+ p = 0.5
+ mask_size = (8, 8)
+ block_size = (64, 64)
+ size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
+
+ n_exp = 512
+ diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
+ for i in range(n_exp):
+ mask = torch.rand(*mask_size, device="cuda") < p
+ x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
+ x = (
+ (mask[:, :, None, None] * x)
+ .transpose(-3, -2)
+ .reshape(*size)
+ .to_sparse_bsr(*block_size)
+ )
+ y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
+ res_dense = x.to_dense() @ y
+ res = bsr_dense_mm(x, y)
+ diff[i] = (res - res_dense).abs().max()
+ print(f"mean: {diff.mean()}, std: {diff.std()}")
+ print(f"max diff: {diff.max()}")