Improve `bsr @ strided` performance in `baddmm` for `bfloat16/half` with Triton kernels. (#88078)

As per title.

Additionally we also introduce support for:
- Rectangular block sizes which are powers of 2 and at least 16 (triton's `dot` limitation).
- Batch support with broadcasting for either of the arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88078
Approved by: https://github.com/cpuhrsch
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e50779a..3d34667 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6456,6 +6456,12 @@
     SparseCPU: s_addmm_sparse_dense_cpu_
     SparseCUDA: s_addmm_sparse_dense_cuda_
 
+- func: _triton_bsr_dense_mm(Tensor bsr, Tensor dense) -> Tensor
+  variants: function
+  dispatch:
+    CPU: triton_bsr_dense_mm
+  autogen: _triton_bsr_dense_mm.out
+
 - func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
   structured: True
   dispatch:
diff --git a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
index cdeb3e1..c147e8c 100644
--- a/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/SparseBlasImpl.cpp
@@ -4,6 +4,10 @@
 #include <ATen/native/sparse/SparseBlasImpl.h>
 #include <ATen/SparseCsrTensorUtils.h>
 
+// Required for checking whether Triton kernels are available
+#include <torch/csrc/jit/frontend/function_schema_parser.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
@@ -12,6 +16,7 @@
 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/zeros.h>
+#include <ATen/ops/_triton_bsr_dense_mm.h>
 #endif
 
 namespace at {
@@ -70,6 +75,31 @@
     blocksize = {values.size(-2), values.size(-1)};
   }
 
+// No stable support for ROCM in Triton yet.
+#ifndef USE_ROCM
+  // Triton works only with blocksizes which are powers of 2.
+  const auto is_power_of_2 = [](int64_t v) -> bool {
+    return !(v & (v - 1));
+  };
+
+  // Dtype and blocksize checks for potential Triton usage.
+  if ((strided.scalar_type() == ScalarType::Half
+    || strided.scalar_type() == ScalarType::BFloat16)
+   && is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
+   && (blocksize[0] >= 16) && (blocksize[1] >= 16)
+   // lhs is retiled to (b0, b1) while rhs is to (b1, b0),
+   // so the result is tiled to (b0, b0) and we need to make
+   // sure that dense.size(-1) is divisible by b0.
+   && n % blocksize[0] == 0) {
+    const auto triton_kernel = c10::Dispatcher::singleton()
+      .findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
+    // Call Triton only if dispatch key was overwritten.
+    if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
+      return at::_triton_bsr_dense_mm_out(result, compressed, strided);
+    }
+  }
+#endif
+
   // (..., r, c) -> (..., r / b0, c / b1, b0, b1)
   // NOTE: this function ALWAYS creates a view upon successful execution.
   const auto tile_tensor = [compressed_layout](
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index efa6926..f407b7b 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -1292,5 +1292,12 @@
   return result;
 }
 
+Tensor triton_bsr_dense_mm(
+    const Tensor& bsr,
+    const Tensor& dense) {
+  TORCH_CHECK(false, "_triton_bsr_dense_mm: Triton kernel should be overwritten in Python.");
+  return Tensor {};
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseMatMul.cpp b/aten/src/ATen/native/sparse/SparseMatMul.cpp
index 548b66a..e5f283b 100644
--- a/aten/src/ATen/native/sparse/SparseMatMul.cpp
+++ b/aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -274,6 +274,5 @@
   return output;
 }
 
-
 } // namespace native
 } // namespace at
diff --git a/mypy.ini b/mypy.ini
index 4afe7dc..7108fee 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -188,6 +188,9 @@
 # Third party dependencies that don't have types.
 #
 
+[mypy-triton.*]
+ignore_missing_imports = True
+
 [mypy-tensorflow.*]
 ignore_missing_imports = True
 
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 30606d1..b1bfe59 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -20,6 +20,7 @@
     floating_types, all_types_and_complex_and, floating_and_complex_types, floating_types_and,
     all_types_and_complex, floating_and_complex_types_and
 )
+from torch._inductor.utils import has_triton
 from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
 
 if TEST_SCIPY:
@@ -1464,6 +1465,63 @@
         self.assertEqual(actual, out)
         self.assertEqual(actual, expected)
 
+    @parametrize("block_size", [16, 32, 64])
+    @parametrize("index_dtype", [torch.int32, torch.int64])
+    @unittest.skipIf(not has_triton(), "Triton is not available")
+    @skipCUDAIfRocm
+    @onlyCUDA
+    @dtypes(torch.half, torch.bfloat16)
+    @dtypesIfCUDA(*[torch.half] if SM53OrLater else [],
+                  *[torch.bfloat16] if SM80OrLater else [])
+    def test_triton_bsr_dense_bmm(self, device, dtype, index_dtype, block_size):
+        from functools import partial
+
+        # Note that each value in a non-zero block is in range block_size * [low^2, high^2).
+        tensor = partial(make_tensor, device=device, dtype=dtype, low=0.5, high=1.5)
+
+        # NOTE: batch dims with zero sizes are not supported in `to_sparse_bsr`.
+        batches = [(), (2,)]
+        size = [128, 256, 0]
+
+        # Whether to make inputs orthogonal so that the product is zero
+        make_orthogonal = [True, False]
+
+        for bd, bs, m, n, k, is_ortho in itertools.product(batches, batches, size, size, size, make_orthogonal):
+            bsr = tensor(bs + (m, k))
+            # NOTE: do not get confused, it will be transposed
+            dense = tensor(bd + (n, k))
+
+            if is_ortho:
+                bsr = torch.cat((bsr, torch.zeros_like(bsr)), dim=-1)
+                dense = torch.cat((torch.zeros_like(dense), dense), dim=-1)
+
+            bsr = bsr.to_sparse_bsr(block_size)
+
+            if bsr.dim() == 2:
+                # Test against linear to check dispatch.
+                res_tri = torch.nn.functional.linear(dense, bsr)
+                res_dense = torch.nn.functional.linear(dense, bsr.to_dense())
+            else:
+                # Otherwise check correctness against bmm
+                # since nn.linear does not support bsr.dim() > 2.
+                res_tri = torch._triton_bsr_dense_mm(bsr, dense.transpose(-2, -1))
+                res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+            self.assertEqual(res_tri, res_dense)
+
+            res_dense = bsr.to_dense() @ dense.transpose(-2, -1)
+            # check whether bsr_dense_mm handles different grid sizes
+            # None means max possible grid size which is CUDA-dependent.
+            grid_size = (None, 2, 4)
+            grid_gen = itertools.product(grid_size, repeat=3)
+            for is_sparse_rowspace, grid in itertools.product((True, False), grid_gen):
+                res_tri = torch.sparse._triton_ops.bsr_dense_mm(
+                    bsr,
+                    dense.transpose(-2, -1),
+                    max_grid=grid,
+                    is_sparse_rowspace_mode=is_sparse_rowspace
+                )
+                self.assertEqual(res_tri, res_dense)
+
     # TODO: block_size 1 is broken
     @parametrize("block_size", [2, 3])
     @parametrize("index_dtype", [torch.int32, torch.int64])
diff --git a/torch/__init__.py b/torch/__init__.py
index 310e012..5b12580 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1368,3 +1368,8 @@
                   'use torch.sparse_coo_tensor(..., check_invariants=False) instead.')
     kwargs['check_invariants'] = False
     return torch.sparse_coo_tensor(*args, **kwargs)
+
+
+# dynamic registration of sparse triton kernels
+from torch.sparse import _register_impls
+_register_impls(torch.library.Library("aten", "IMPL"))
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index eec9bc2..ec1b883 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -898,18 +898,18 @@
 del _LegacyStorage
 del _CudaLegacyStorage
 
-torch._storage_classes.add(DoubleStorage)
-torch._storage_classes.add(FloatStorage)
-torch._storage_classes.add(LongStorage)
-torch._storage_classes.add(IntStorage)
-torch._storage_classes.add(ShortStorage)
-torch._storage_classes.add(CharStorage)
-torch._storage_classes.add(ByteStorage)
-torch._storage_classes.add(HalfStorage)
-torch._storage_classes.add(BoolStorage)
-torch._storage_classes.add(BFloat16Storage)
-torch._storage_classes.add(ComplexDoubleStorage)
-torch._storage_classes.add(ComplexFloatStorage)
+torch._storage_classes.add(DoubleStorage)  # type: ignore[has-type]
+torch._storage_classes.add(FloatStorage)  # type: ignore[has-type]
+torch._storage_classes.add(LongStorage)  # type: ignore[has-type]
+torch._storage_classes.add(IntStorage)  # type: ignore[has-type]
+torch._storage_classes.add(ShortStorage)  # type: ignore[has-type]
+torch._storage_classes.add(CharStorage)  # type: ignore[has-type]
+torch._storage_classes.add(ByteStorage)  # type: ignore[has-type]
+torch._storage_classes.add(HalfStorage)  # type: ignore[has-type]
+torch._storage_classes.add(BoolStorage)  # type: ignore[has-type]
+torch._storage_classes.add(BFloat16Storage)  # type: ignore[has-type]
+torch._storage_classes.add(ComplexDoubleStorage)  # type: ignore[has-type]
+torch._storage_classes.add(ComplexFloatStorage)  # type: ignore[has-type]
 
 from . import sparse
 from . import profiler
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 3ceaf56..a7a909e 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -4,6 +4,8 @@
 import torch
 from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
 from torch import Tensor
+from torch.cuda import _lazy_call
+from torch._inductor.cuda_properties import get_device_capability
 
 # A workaround to support both TorchScript and MyPy:
 from typing import TYPE_CHECKING
@@ -462,3 +464,30 @@
                 return mth(*args, **kwargs)
 
         return test_mth
+
+# Triton registrations
+def _has_triton():
+    if not torch.cuda.is_available():
+        return False
+    try:
+        import triton
+
+        return triton is not None and get_device_capability() >= (7, 0)
+    except ImportError:
+        return False
+
+
+def _register_impls(lib):
+    """This function is called from torch/__init__.py to do any dynamic registrations. """
+
+
+    def register_sparse_cuda_impls(lib=lib):
+        from ._triton_ops import bsr_dense_mm
+
+        if bsr_dense_mm is not None:
+            lib.impl("aten::_triton_bsr_dense_mm",
+                     lambda *args, **kwargs: bsr_dense_mm(*args, skip_checks=True, **kwargs), "SparseCsrCUDA")
+
+    # This code is evaluated on import torch and therefore cannot force initialization of the cuda rt
+    # We must schedule the registration to occur lazily.
+    _lazy_call(register_sparse_cuda_impls)
diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py
new file mode 100644
index 0000000..d7b34f3
--- /dev/null
+++ b/torch/sparse/_triton_ops.py
@@ -0,0 +1,608 @@
+import torch
+from torch._inductor.cuda_properties import get_device_capability
+
+def _has_triton():
+    if not torch.cuda.is_available():
+        return False
+    try:
+        import triton
+
+        return triton is not None and get_device_capability() >= (7, 0)
+    except ImportError:
+        return False
+
+def compressed_indices_to_plain_indices(cidx, pidx):
+    nnz = pidx.shape[-1]
+    cdim = cidx.shape[-1] - 1
+    batch_numel = cidx.shape[0]
+    batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
+        :, None
+    ]
+
+    cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
+    cidx_linear = torch.empty(
+        (batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
+    )
+    cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
+    cidx_linear[-1] = nnz * batch_numel
+
+    idx_linear = torch._convert_indices_from_csr_to_coo(
+        cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
+    ).select(0, 0)
+
+    return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
+
+
+def slicer(dim, slice_range, *tensors):
+    for t in tensors:
+        slices = [slice(None)] * t.dim()
+        slices[dim] = slice_range
+        yield t[slices]
+
+if _has_triton():
+    import triton
+    import triton.language as tl
+    from typing import Optional, Tuple
+
+    @triton.jit
+    def _bsr_strided_dense_rowspace_kernel(
+        BLOCKSIZE_ROW: tl.constexpr,
+        BLOCKSIZE_COL: tl.constexpr,
+        # values prologue
+        values_ptr,
+        values_batch_stride,
+        values_nnz_stride,
+        values_row_block_stride,
+        values_col_block_stride,
+        # values epilogue
+        # crow_indices prologue
+        crow_indices_ptr,
+        crow_indices_batch_stride,
+        crow_indices_stride,
+        # crow_indices epilogue
+        # col_indices prologue
+        col_indices_ptr,
+        col_indices_batch_stride,
+        col_indices_stride,
+        # col_indices epilogue
+        # dense prologue
+        dense_ptr,
+        dense_batch_stride,
+        dense_tiled_row_stride,
+        dense_tiled_col_stride,
+        dense_row_block_stride,
+        dense_col_block_stride,
+        # dense epilogue
+        # output prologue
+        output_ptr,
+        output_batch_stride,
+        output_tiled_row_stride,
+        output_tiled_col_stride,
+        output_row_block_stride,
+        output_col_block_stride,
+        # output epilogue
+        GROUP_SIZE_ROW: tl.constexpr,
+    ):
+        batch_pid = tl.program_id(axis=2)
+        row_block_pid = tl.program_id(axis=0)
+        col_block_pid = tl.program_id(axis=1)
+        n_block_rows = tl.num_programs(axis=0)
+        n_block_cols = tl.num_programs(axis=1)
+
+        row_block_pid, col_block_pid = tl.swizzle2d(
+            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
+        )
+
+        crow_indices_offset_ptr = (
+            crow_indices_ptr
+            + crow_indices_batch_stride * batch_pid
+            + crow_indices_stride * row_block_pid
+        )
+        nnz_offset = tl.load(crow_indices_offset_ptr)
+        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
+
+        # Compute nnz for the row with number row_block_pid.
+        # If it is zero, skip the row.
+        row_nnz = nnz_offset_next - nnz_offset
+        if row_nnz == 0:
+            return
+
+        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
+        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
+
+        # Pointers are set to the first block of the current row.
+        values_block_ptrs = (
+            values_ptr
+            + values_batch_stride * batch_pid
+            + values_nnz_stride * nnz_offset
+            + values_row_block_stride * row_block_arange[:, None]
+            + values_col_block_stride * col_block_arange[None, :]
+        )
+
+        # NOTE: dense is advanced into all dimensions but the tiled row one.
+        # That will be advanced in the loop according to values in col_indices.
+        dense_block_ptrs = (
+            dense_ptr
+            + dense_batch_stride * batch_pid
+            + dense_tiled_col_stride * col_block_pid
+            + dense_row_block_stride * col_block_arange[:, None]
+            + dense_col_block_stride * row_block_arange[None, :]
+        )
+
+        # Pointers are set to exact write-to locations
+        output_ptrs = (
+            output_ptr
+            + output_batch_stride * batch_pid
+            + output_tiled_row_stride * row_block_pid
+            + output_tiled_col_stride * col_block_pid
+            + output_row_block_stride * row_block_arange[:, None]
+            + output_col_block_stride * row_block_arange[None, :]
+        )
+
+        # Set pointer to the first nonzero element in the current row
+        col_index_nnz_ptr = (
+            col_indices_ptr
+            + col_indices_batch_stride * batch_pid
+            + col_indices_stride * nnz_offset
+        )
+
+        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+        for _ in range(row_nnz):
+            values_block = tl.load(values_block_ptrs)
+
+            # find which row of dense needs to get loaded
+            # for multiplication with values_block.
+            dense_row_idx = tl.load(col_index_nnz_ptr)
+            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
+
+            # do block mm
+            output_acc_block += tl.dot(values_block, dense_block)
+
+            # move val/col_index ptrs to the next block in the row
+            values_block_ptrs += values_nnz_stride
+            col_index_nnz_ptr += col_indices_stride
+
+        # write back the result
+        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
+
+
+    @triton.jit
+    def _bsr_strided_sparse_rowspace_kernel(
+        BLOCKSIZE_ROW: tl.constexpr,
+        BLOCKSIZE_COL: tl.constexpr,
+        batch_idx_ptr,
+        row_idx_ptr,
+        nnz_per_row_ptr,
+        nnz_per_row_cumsum_ptr,
+        col_indices_ptr,
+        col_indices_stride,
+        # values prologue
+        values_ptr,
+        values_nnz_stride,
+        values_row_block_stride,
+        values_col_block_stride,
+        # values epilogue
+        # dense prologue
+        dense_ptr,
+        dense_batch_stride,
+        dense_tiled_row_stride,
+        dense_tiled_col_stride,
+        dense_row_block_stride,
+        dense_col_block_stride,
+        # dense epilogue
+        # output prologue
+        output_ptr,
+        output_batch_stride,
+        output_tiled_row_stride,
+        output_tiled_col_stride,
+        output_row_block_stride,
+        output_col_block_stride,
+        # output epilogue
+        GROUP_SIZE_ROW: tl.constexpr,
+    ):
+        row_block_pid = tl.program_id(axis=0)
+        col_block_pid = tl.program_id(axis=1)
+        n_block_rows = tl.num_programs(axis=0)
+        n_block_cols = tl.num_programs(axis=1)
+
+        row_block_pid, col_block_pid = tl.swizzle2d(
+            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
+        )
+
+        batch_idx = tl.load(batch_idx_ptr + row_block_pid)
+        row_idx = tl.load(row_idx_ptr + row_block_pid)
+        row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
+        row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
+        row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
+
+        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
+        col_block_arange = tl.arange(0, BLOCKSIZE_COL)
+
+        # Pointers are set to the first block of the current row.
+        values_block_ptrs = (
+            values_ptr
+            + values_nnz_stride * row_idx_nnz_offset
+            + values_row_block_stride * row_block_arange[:, None]
+            + values_col_block_stride * col_block_arange[None, :]
+        )
+
+        # NOTE: dense is advanced into all dimensions but the tiled row one.
+        # That will be advanced in the loop according to values in col_indices.
+        dense_block_ptrs = (
+            dense_ptr
+            + dense_batch_stride * batch_idx
+            + dense_tiled_col_stride * col_block_pid
+            + dense_row_block_stride * col_block_arange[:, None]
+            + dense_col_block_stride * row_block_arange[None, :]
+        )
+
+        # Pointers are set to exact write-to locations
+        output_ptrs = (
+            output_ptr
+            + output_batch_stride * batch_idx
+            + output_tiled_row_stride * row_idx
+            + output_tiled_col_stride * col_block_pid
+            + output_row_block_stride * row_block_arange[:, None]
+            + output_col_block_stride * row_block_arange[None, :]
+        )
+
+        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
+        col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
+        for _ in range(row_idx_nnz):
+            values_block = tl.load(values_block_ptrs)
+
+            # find which row of dense needs to get loaded
+            # for multiplication with values_block.
+            dense_row_idx = tl.load(col_index_nnz_ptr)
+            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
+
+            # do block mm
+            output_acc_block += tl.dot(values_block, dense_block)
+
+            # move val/col_index ptrs to the next block in the row
+            values_block_ptrs += values_nnz_stride
+            col_index_nnz_ptr += col_indices_stride
+
+        # write back the result
+        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
+
+
+    def _run_sparse_rowspace_kernel(
+        blocksize, values, crow_indices, col_indices, dense, output, max_grid
+    ):
+        # Compute a vector of non-zero elements numbers per each row.
+        # We want to ultimately iterate over non-zero rows.
+        nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
+
+        # Compute indices of non-zero counts.
+        # batch_idx maps to a broadcasted batch index, while
+        # row_idx tracks non-zero rows of the sparse argument
+        # and rows of the output that get modified.
+        batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
+
+        # Compress the vector of counts to hold only non-zero values.
+        nnz_per_row = nnz_per_row[batch_idx, row_idx]
+        # Compute cumulative counts which along with nnz_per_row
+        # are used to compute offsets into nnz values.
+        nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
+
+        n_nnz_block_rows = row_idx.size(-1)
+        n_block_cols = dense.size(-3)
+        max_n_nnz_block_rows, max_n_block_cols = max_grid[:2]
+
+        for c_start in range(0, n_block_cols, max_n_block_cols):
+            c_dense, c_output = slicer(
+                -3, slice(c_start, c_start + max_n_block_cols), dense, output
+            )
+            c_grid = min(n_block_cols - c_start, max_n_block_cols)
+
+            for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows):
+                r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer(
+                    0,
+                    slice(r_start, r_start + max_n_nnz_block_rows),
+                    batch_idx,
+                    row_idx,
+                    nnz_per_row,
+                    nnz_per_row_cumsum,
+                )
+                r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows)
+
+                _bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)](
+                    *blocksize,
+                    r_batch_idx,
+                    r_row_idx,
+                    r_nnz_per_row,
+                    r_nnz_per_row_cumsum,
+                    col_indices,
+                    *col_indices.stride(),
+                    values,
+                    *values.stride(),
+                    c_dense,
+                    *c_dense.stride(),
+                    c_output,
+                    *c_output.stride(),
+                    GROUP_SIZE_ROW=4,
+                    num_stages=4,
+                    num_warps=4,
+                )
+
+
+    def _run_dense_rowspace_kernel(
+        blocksize, values, crow_indices, col_indices, dense, output, max_grid
+    ):
+        # Launch kernel
+        n_batches = dense.size(0)
+        n_block_rows = crow_indices.size(-1) - 1
+        n_block_cols = dense.size(-3)
+        max_n_block_rows, max_n_block_cols, max_n_batches = max_grid
+
+        for b_start in range(0, n_batches, max_n_batches):
+            b_v, b_crow, b_col, b_d, b_o = slicer(
+                0,
+                slice(b_start, b_start + max_n_batches),
+                values,
+                crow_indices,
+                col_indices,
+                dense,
+                output,
+            )
+            b_grid = min(n_batches - b_start, max_n_batches)
+
+            for c_start in range(0, n_block_cols, max_n_block_cols):
+                bc_d, bc_o = slicer(
+                    -3, slice(c_start, c_start + max_n_block_cols), b_d, b_o
+                )
+                c_grid = min(n_block_cols - c_start, max_n_block_cols)
+
+                for r_start in range(0, n_block_rows, max_n_block_rows):
+                    r_slice = slice(r_start, r_start + max_n_block_rows)
+                    br_crow = next(slicer(-1, r_slice, b_crow))
+                    brc_o = next(slicer(-4, r_slice, bc_o))
+                    r_grid = min(n_block_rows - r_start, max_n_block_rows)
+
+                    _bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)](
+                        *blocksize,
+                        b_v,
+                        *b_v.stride(),
+                        br_crow,
+                        *br_crow.stride(),
+                        b_col,
+                        *b_col.stride(),
+                        bc_d,
+                        *bc_d.stride(),
+                        brc_o,
+                        *brc_o.stride(),
+                        GROUP_SIZE_ROW=4,
+                        num_stages=4,
+                        num_warps=4,
+                    )
+
+
+    def bsr_dense_mm(
+        bsr: torch.Tensor,
+        dense: torch.Tensor,
+        *,
+        skip_checks: bool = False,
+        is_sparse_rowspace_mode: Optional[bool] = None,
+        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
+        out: Optional[torch.Tensor] = None,
+    ):
+        m, kl = bsr.shape[-2:]
+        kr, n = dense.shape[-2:]
+
+        def check(cond, msg):
+            if not cond:
+                raise ValueError(msg)
+
+        if not skip_checks:
+            check(
+                bsr.layout == torch.sparse_bsr,
+                "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
+            )
+
+            check(
+                bsr.device == dense.device and bsr.device.type == "cuda",
+                "bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
+            )
+
+            check(
+                bsr.dtype == dense.dtype
+                and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
+                "bsr_dense_mm(): all inputs are expected to be of the same dtype "
+                "and one of (half, bfloat16, float32), "
+                f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
+            )
+
+            check(
+                bsr.dim() >= 2 and dense.dim() >= 2,
+                "bsr_dense_mm(): all inputs are expected to be at least 2D, "
+                f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
+            )
+
+            check(
+                kl == kr,
+                "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
+                f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
+            )
+
+            row_block = bsr.values().shape[-2]
+            check(
+                not n % row_block,
+                f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
+                f"blocksize[0] == {row_block}.",
+            )
+
+        # Required to undo the fake batch dimension insertion.
+        original_batch_dims_broadcasted = torch.broadcast_shapes(
+            bsr.shape[:-2], dense.shape[:-2]
+        )
+
+        if out is not None and not skip_checks:
+            expected_out_shape = original_batch_dims_broadcasted + (m, n)
+            check(
+                out.shape == expected_out_shape,
+                "bsr_dense_mm(): `out` argument has wrong shape, "
+                f"expected {expected_out_shape}, but got {out.shape}.",
+            )
+            check(
+                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
+                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
+                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
+                "should be True.",
+            )
+
+        # Short circuit if lhs is zero
+        if bsr._nnz() == 0:
+            return dense.new_zeros(original_batch_dims_broadcasted + (m, n))
+
+        # TODO: insert switch
+        if is_sparse_rowspace_mode is None:
+            is_sparse_rowspace_mode = False
+
+        # Introduce fake batch dimension if not present for convenience.
+        def unsqueeze_batch_dim(t, n_non_batch_dims):
+            if t.dim() > n_non_batch_dims:
+                return t
+            else:
+                return t.unsqueeze(0)
+
+        def make_triton_contiguous(t):
+            # Triton does not distinguish between row- and col-majorness
+            # and will be fast as long as there is a contiguous dimension.
+            if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
+                return t.contiguous()
+            else:
+                return t
+
+        crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1)
+        col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1)
+        values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3))
+        dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2))
+        nnz = values.shape[-3]
+        blocksize = values.shape[-2:]
+
+        # Compute broadcasted batch dimension
+        bsr_batch_dims = values.shape[:-3]
+        dense_batch_dims = dense.shape[:-2]
+        batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
+
+        # Allocate out
+        if out is None:
+            out = dense.new_zeros(batch_dims_broadcasted + (m, n))
+
+        # Broadcast batch dimensions and squash
+        def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
+            return t.broadcast_to(batch_dims + invariant_dims).flatten(
+                0, len(batch_dims) - 1
+            )
+
+        crow_indices = batch_broadcast_and_squash(
+            crow_indices, batch_dims_broadcasted, (-1,)
+        )
+
+        if is_sparse_rowspace_mode:
+            # Flatten batch dimension with nnz dimension
+            # as required by the sparse rowspace kernel.
+            col_indices = batch_broadcast_and_squash(
+                col_indices, batch_dims_broadcasted + (-1,), ()
+            )
+            values = batch_broadcast_and_squash(
+                values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
+            )
+        else:
+            col_indices = batch_broadcast_and_squash(
+                col_indices, batch_dims_broadcasted, (-1,)
+            )
+            values = batch_broadcast_and_squash(
+                values, batch_dims_broadcasted, values.shape[-3:]
+            )
+
+        dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
+
+        # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view
+        out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
+
+        # NOTE: this function will ALWAYS create a view
+        def tile_to_blocksize(t, blocksize):
+            *rest, m, n = t.shape
+            new_shape = rest + [
+                m // blocksize[0],
+                blocksize[0],
+                n // blocksize[1],
+                blocksize[1],
+            ]
+            return t.reshape(new_shape).transpose(-3, -2)
+
+        # "Blockify" the row dimension of dense with blocksize[1]
+        # since dense is on the rhs of matmul
+        dense = tile_to_blocksize(dense, blocksize[::-1])
+        # "Blockify" the row dimension of out with blocksize[0]
+        # which is inherited from the bsr input.
+        # NOTE: tile_to_blocksize will create a view.
+        # NOTE: out.blocksize[-1] == dense.blocksize[-1],
+        # so it could be any value in [1, dense.shape[-1]).
+        # We need to probably use the largest possible blocksize
+        # so that it fits into SRAM.
+        out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
+
+        # Launch kernel
+        if is_sparse_rowspace_mode:
+            kernel = _run_sparse_rowspace_kernel
+        else:
+            kernel = _run_dense_rowspace_kernel
+
+        # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
+        cuda_max_grid = (2147483647, 65535, 65535)
+        if max_grid is None:
+            max_grid = cuda_max_grid
+        else:
+
+            def valid_grid_dim(g, mg):
+                if g is None:
+                    return mg
+                else:
+                    # grid must be at least 1 and no greater than mg
+                    return max(1, min(g, mg))
+
+            max_grid = tuple(
+                valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid)
+            )  # type: ignore[assignment]
+
+        kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
+
+        # Block dims need to rejoin with the corresponding block dimensions
+        # prior to reshape so that blocks do not end up being transposed.
+        # NB: type checker is not able to narrow Optional[Tensor] to tensor by this point
+        return out.transpose(-3, -2).reshape(original_batch_dims_broadcasted + (m, n))  # type: ignore[union-attr]
+else:
+    bsr_dense_mm = None  # type: ignore[assignment]
+
+
+if __name__ == "__main__":
+    from torch._inductor.utils import has_triton
+
+    if has_triton():
+        torch.manual_seed(13)
+        dtype = torch.float32
+        p = 0.5
+        mask_size = (8, 8)
+        block_size = (64, 64)
+        size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
+
+        n_exp = 512
+        diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
+        for i in range(n_exp):
+            mask = torch.rand(*mask_size, device="cuda") < p
+            x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
+            x = (
+                (mask[:, :, None, None] * x)
+                .transpose(-3, -2)
+                .reshape(*size)
+                .to_sparse_bsr(*block_size)
+            )
+            y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
+            res_dense = x.to_dense() @ y
+            res = bsr_dense_mm(x, y)
+            diff[i] = (res - res_dense).abs().max()
+        print(f"mean: {diff.mean()}, std: {diff.std()}")
+        print(f"max diff: {diff.max()}")