import torch
from torch._inductor.cuda_properties import get_device_capability
def _has_triton():
if not torch.cuda.is_available():
return False
try:
import triton
return triton is not None and get_device_capability() >= (7, 0)
except ImportError:
return False
def compressed_indices_to_plain_indices(cidx, pidx):
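    """Expand batched compressed (CSR/BSR-style) indices into plain (COO-style) indices.

    Given compressed indices ``cidx`` of shape ``(batch, cdim + 1)`` and plain
    indices ``pidx`` of shape ``(batch, nnz)``, return the complementary plain
    indices of shape ``(batch, nnz)``, i.e. the (block) row index of every
    stored element.

    A minimal sketch of the expected behavior (hypothetical values): for
    ``cidx = [[0, 1, 3]]`` and any ``pidx`` of shape ``(1, 3)``, the result is
    ``[[0, 1, 1]]``, since row 0 holds one stored element and row 1 holds two.
    """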
nnz = pidx.shape[-1]
cdim = cidx.shape[-1] - 1
batch_numel = cidx.shape[0]
batch_offset = torch.arange(batch_numel, dtype=cidx.dtype, device=cidx.device)[
:, None
]
cidx_batch_offsetted = cidx[:, :-1] + nnz * batch_offset
cidx_linear = torch.empty(
(batch_numel * cdim + 1,), dtype=cidx.dtype, device=cidx.device
)
cidx_linear[:-1] = cidx_batch_offsetted.reshape(-1)
cidx_linear[-1] = nnz * batch_numel
idx_linear = torch._convert_indices_from_csr_to_coo(
cidx_linear, pidx.reshape(-1), out_int32=(cidx.dtype == torch.int32)
).select(0, 0)
return idx_linear.reshape(batch_numel, -1).sub_(cdim * batch_offset)
def slicer(dim, slice_range, *tensors):
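    """Yield each of ``tensors`` sliced with ``slice_range`` along dimension ``dim``.

    A minimal sketch (hypothetical tensor ``t``): ``next(slicer(-1, slice(0, 2), t))``
    is equivalent to ``t[..., 0:2]``.
    """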
for t in tensors:
slices = [slice(None)] * t.dim()
slices[dim] = slice_range
        yield t[tuple(slices)]
if _has_triton():
import triton
import triton.language as tl
from typing import Optional, Tuple
@triton.jit
def _bsr_strided_dense_rowspace_kernel(
BLOCKSIZE_ROW: tl.constexpr,
BLOCKSIZE_COL: tl.constexpr,
# values prologue
values_ptr,
values_batch_stride,
values_nnz_stride,
values_row_block_stride,
values_col_block_stride,
# values epilogue
# crow_indices prologue
crow_indices_ptr,
crow_indices_batch_stride,
crow_indices_stride,
# crow_indices epilogue
# col_indices prologue
col_indices_ptr,
col_indices_batch_stride,
col_indices_stride,
# col_indices epilogue
# dense prologue
dense_ptr,
dense_batch_stride,
dense_tiled_row_stride,
dense_tiled_col_stride,
dense_row_block_stride,
dense_col_block_stride,
# dense epilogue
# output prologue
output_ptr,
output_batch_stride,
output_tiled_row_stride,
output_tiled_col_stride,
output_row_block_stride,
output_col_block_stride,
# output epilogue
GROUP_SIZE_ROW: tl.constexpr,
):
batch_pid = tl.program_id(axis=2)
row_block_pid = tl.program_id(axis=0)
col_block_pid = tl.program_id(axis=1)
n_block_rows = tl.num_programs(axis=0)
n_block_cols = tl.num_programs(axis=1)
row_block_pid, col_block_pid = tl.swizzle2d(
row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
)
crow_indices_offset_ptr = (
crow_indices_ptr
+ crow_indices_batch_stride * batch_pid
+ crow_indices_stride * row_block_pid
)
nnz_offset = tl.load(crow_indices_offset_ptr)
nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)
        # Compute the number of non-zero blocks in the row `row_block_pid`.
        # If it is zero, skip the row.
row_nnz = nnz_offset_next - nnz_offset
if row_nnz == 0:
return
row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
col_block_arange = tl.arange(0, BLOCKSIZE_COL)
# Pointers are set to the first block of the current row.
values_block_ptrs = (
values_ptr
+ values_batch_stride * batch_pid
+ values_nnz_stride * nnz_offset
+ values_row_block_stride * row_block_arange[:, None]
+ values_col_block_stride * col_block_arange[None, :]
)
        # NOTE: dense is advanced in every dimension except the tiled row one;
        # that dimension is advanced inside the loop below, driven by the
        # values read from col_indices.
dense_block_ptrs = (
dense_ptr
+ dense_batch_stride * batch_pid
+ dense_tiled_col_stride * col_block_pid
+ dense_row_block_stride * col_block_arange[:, None]
+ dense_col_block_stride * row_block_arange[None, :]
)
# Pointers are set to exact write-to locations
output_ptrs = (
output_ptr
+ output_batch_stride * batch_pid
+ output_tiled_row_stride * row_block_pid
+ output_tiled_col_stride * col_block_pid
+ output_row_block_stride * row_block_arange[:, None]
+ output_col_block_stride * row_block_arange[None, :]
)
# Set pointer to the first nonzero element in the current row
col_index_nnz_ptr = (
col_indices_ptr
+ col_indices_batch_stride * batch_pid
+ col_indices_stride * nnz_offset
)
output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
for _ in range(row_nnz):
values_block = tl.load(values_block_ptrs)
# find which row of dense needs to get loaded
# for multiplication with values_block.
dense_row_idx = tl.load(col_index_nnz_ptr)
dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
# do block mm
output_acc_block += tl.dot(values_block, dense_block)
# move val/col_index ptrs to the next block in the row
values_block_ptrs += values_nnz_stride
col_index_nnz_ptr += col_indices_stride
# write back the result
tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
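    # For reference, each program instance of the kernel above computes, in
    # PyTorch-equivalent terms (a sketch, not executable inside Triton):
    #
    #   acc = zeros(BLOCKSIZE_ROW, BLOCKSIZE_ROW)
    #   for k in range(crow_indices[row], crow_indices[row + 1]):
    #       acc += values[k] @ dense[col_indices[k], col]
    #   output[row, col] = acc
    #
    # where `row` and `col` are the (swizzled) program ids, the batch is
    # handled by program axis 2, and the indexing is over the tiled (block)
    # dimensions of dense and output.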
@triton.jit
def _bsr_strided_sparse_rowspace_kernel(
BLOCKSIZE_ROW: tl.constexpr,
BLOCKSIZE_COL: tl.constexpr,
batch_idx_ptr,
row_idx_ptr,
nnz_per_row_ptr,
nnz_per_row_cumsum_ptr,
col_indices_ptr,
col_indices_stride,
# values prologue
values_ptr,
values_nnz_stride,
values_row_block_stride,
values_col_block_stride,
# values epilogue
# dense prologue
dense_ptr,
dense_batch_stride,
dense_tiled_row_stride,
dense_tiled_col_stride,
dense_row_block_stride,
dense_col_block_stride,
# dense epilogue
# output prologue
output_ptr,
output_batch_stride,
output_tiled_row_stride,
output_tiled_col_stride,
output_row_block_stride,
output_col_block_stride,
# output epilogue
GROUP_SIZE_ROW: tl.constexpr,
):
row_block_pid = tl.program_id(axis=0)
col_block_pid = tl.program_id(axis=1)
n_block_rows = tl.num_programs(axis=0)
n_block_cols = tl.num_programs(axis=1)
row_block_pid, col_block_pid = tl.swizzle2d(
row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
)
batch_idx = tl.load(batch_idx_ptr + row_block_pid)
row_idx = tl.load(row_idx_ptr + row_block_pid)
row_idx_nnz = tl.load(nnz_per_row_ptr + row_block_pid)
row_idx_nnz_cumsum = tl.load(nnz_per_row_cumsum_ptr + row_block_pid)
row_idx_nnz_offset = row_idx_nnz_cumsum - row_idx_nnz
row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
col_block_arange = tl.arange(0, BLOCKSIZE_COL)
# Pointers are set to the first block of the current row.
values_block_ptrs = (
values_ptr
+ values_nnz_stride * row_idx_nnz_offset
+ values_row_block_stride * row_block_arange[:, None]
+ values_col_block_stride * col_block_arange[None, :]
)
        # NOTE: dense is advanced in every dimension except the tiled row one;
        # that dimension is advanced inside the loop below, driven by the
        # values read from col_indices.
dense_block_ptrs = (
dense_ptr
+ dense_batch_stride * batch_idx
+ dense_tiled_col_stride * col_block_pid
+ dense_row_block_stride * col_block_arange[:, None]
+ dense_col_block_stride * row_block_arange[None, :]
)
# Pointers are set to exact write-to locations
output_ptrs = (
output_ptr
+ output_batch_stride * batch_idx
+ output_tiled_row_stride * row_idx
+ output_tiled_col_stride * col_block_pid
+ output_row_block_stride * row_block_arange[:, None]
+ output_col_block_stride * row_block_arange[None, :]
)
output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), tl.float32)
col_index_nnz_ptr = col_indices_ptr + row_idx_nnz_offset * col_indices_stride
for _ in range(row_idx_nnz):
values_block = tl.load(values_block_ptrs)
# find which row of dense needs to get loaded
# for multiplication with values_block.
dense_row_idx = tl.load(col_index_nnz_ptr)
dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)
# do block mm
output_acc_block += tl.dot(values_block, dense_block)
# move val/col_index ptrs to the next block in the row
values_block_ptrs += values_nnz_stride
col_index_nnz_ptr += col_indices_stride
# write back the result
tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))
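    # The sparse rowspace kernel above differs from the dense rowspace kernel
    # in that it is launched over non-empty block rows only: the grid is 2D,
    # the batch dimension is folded into grid axis 0, and each program reads
    # its batch/row coordinates from batch_idx/row_idx instead of deriving
    # them from the program id.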
def _run_sparse_rowspace_kernel(
blocksize, values, crow_indices, col_indices, dense, output, max_grid
):
        # Compute the number of non-zero elements in each row.
        # We ultimately want to iterate over the non-empty rows only.
nnz_per_row = crow_indices[:, 1:] - crow_indices[:, :-1]
        # Compute the coordinates of the non-zero counts:
        # batch_idx is the broadcasted batch index, while
        # row_idx enumerates the non-empty rows of the sparse argument,
        # i.e. the rows of the output that get modified.
batch_idx, row_idx = nnz_per_row.nonzero(as_tuple=True)
# Compress the vector of counts to hold only non-zero values.
nnz_per_row = nnz_per_row[batch_idx, row_idx]
# Compute cumulative counts which along with nnz_per_row
# are used to compute offsets into nnz values.
nnz_per_row_cumsum = nnz_per_row.cumsum(-1)
n_nnz_block_rows = row_idx.size(-1)
n_block_cols = dense.size(-3)
max_n_nnz_block_rows, max_n_block_cols = max_grid[:2]
for c_start in range(0, n_block_cols, max_n_block_cols):
c_dense, c_output = slicer(
-3, slice(c_start, c_start + max_n_block_cols), dense, output
)
c_grid = min(n_block_cols - c_start, max_n_block_cols)
for r_start in range(0, n_nnz_block_rows, max_n_nnz_block_rows):
r_batch_idx, r_row_idx, r_nnz_per_row, r_nnz_per_row_cumsum = slicer(
0,
slice(r_start, r_start + max_n_nnz_block_rows),
batch_idx,
row_idx,
nnz_per_row,
nnz_per_row_cumsum,
)
r_grid = min(n_nnz_block_rows - r_start, max_n_nnz_block_rows)
_bsr_strided_sparse_rowspace_kernel[(r_grid, c_grid)](
*blocksize,
r_batch_idx,
r_row_idx,
r_nnz_per_row,
r_nnz_per_row_cumsum,
col_indices,
*col_indices.stride(),
values,
*values.stride(),
c_dense,
*c_dense.stride(),
c_output,
*c_output.stride(),
GROUP_SIZE_ROW=4,
num_stages=1,
num_warps=4,
)
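    # For reference, a small worked example of the offsets computed above
    # (hypothetical values): for compressed counts nnz_per_row = [2, 1, 3],
    # the cumulative sums are [2, 3, 6], so inside the kernel
    # row_idx_nnz_offset = cumsum - nnz gives start offsets [0, 2, 3] into
    # values and col_indices.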
def _run_dense_rowspace_kernel(
blocksize, values, crow_indices, col_indices, dense, output, max_grid
):
# Launch kernel
n_batches = dense.size(0)
n_block_rows = crow_indices.size(-1) - 1
n_block_cols = dense.size(-3)
max_n_block_rows, max_n_block_cols, max_n_batches = max_grid
for b_start in range(0, n_batches, max_n_batches):
b_v, b_crow, b_col, b_d, b_o = slicer(
0,
slice(b_start, b_start + max_n_batches),
values,
crow_indices,
col_indices,
dense,
output,
)
b_grid = min(n_batches - b_start, max_n_batches)
for c_start in range(0, n_block_cols, max_n_block_cols):
bc_d, bc_o = slicer(
-3, slice(c_start, c_start + max_n_block_cols), b_d, b_o
)
c_grid = min(n_block_cols - c_start, max_n_block_cols)
for r_start in range(0, n_block_rows, max_n_block_rows):
r_slice = slice(r_start, r_start + max_n_block_rows)
br_crow = next(slicer(-1, r_slice, b_crow))
brc_o = next(slicer(-4, r_slice, bc_o))
r_grid = min(n_block_rows - r_start, max_n_block_rows)
_bsr_strided_dense_rowspace_kernel[(r_grid, c_grid, b_grid)](
*blocksize,
b_v,
*b_v.stride(),
br_crow,
*br_crow.stride(),
b_col,
*b_col.stride(),
bc_d,
*bc_d.stride(),
brc_o,
*brc_o.stride(),
GROUP_SIZE_ROW=4,
num_stages=1,
num_warps=4,
)
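    # NOTE: the triple loop above splits the launch into several kernel calls
    # so that no grid dimension exceeds max_grid. E.g. (hypothetical sizes)
    # with n_block_rows = 100_000 and max_n_block_rows = 65_535, the kernel is
    # launched twice along the row axis: once with r_grid = 65_535 and once
    # with r_grid = 34_465.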
def bsr_dense_mm(
bsr: torch.Tensor,
dense: torch.Tensor,
*,
skip_checks: bool = False,
is_sparse_rowspace_mode: Optional[bool] = None,
max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
out: Optional[torch.Tensor] = None,
):
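        """Multiply a BSR sparse matrix ``bsr`` by a strided (dense) matrix ``dense``.

        Both arguments may carry broadcastable batch dimensions; the result is a
        strided tensor of shape ``broadcast(bsr.shape[:-2], dense.shape[:-2]) + (m, n)``,
        written into ``out`` when it is provided.

        A minimal usage sketch (hypothetical shapes; assumes a CUDA device with
        Triton available):

            a = torch.eye(128, device="cuda").to_sparse_bsr((32, 32))
            b = torch.randn(128, 64, device="cuda")
            c = bsr_dense_mm(a, b)  # same result as a.to_dense() @ b
        """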
m, kl = bsr.shape[-2:]
kr, n = dense.shape[-2:]
def check(cond, msg):
if not cond:
raise ValueError(msg)
if not skip_checks:
check(
bsr.layout == torch.sparse_bsr,
"bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
)
check(
bsr.device == dense.device and bsr.device.type == "cuda",
"bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
)
check(
bsr.dtype == dense.dtype
and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
"bsr_dense_mm(): all inputs are expected to be of the same dtype "
"and one of (half, bfloat16, float32), "
f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
)
check(
bsr.dim() >= 2 and dense.dim() >= 2,
"bsr_dense_mm(): all inputs are expected to be at least 2D, "
f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
)
check(
kl == kr,
"bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
)
row_block = bsr.values().shape[-2]
check(
not n % row_block,
f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
f"blocksize[0] == {row_block}.",
)
# Required to undo the fake batch dimension insertion.
original_batch_dims_broadcasted = torch.broadcast_shapes(
bsr.shape[:-2], dense.shape[:-2]
)
if out is not None and not skip_checks:
expected_out_shape = original_batch_dims_broadcasted + (m, n)
check(
out.shape == expected_out_shape,
"bsr_dense_mm(): `out` argument has wrong shape, "
f"expected {expected_out_shape}, but got {out.shape}.",
)
check(
out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
"bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
"i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
"should be True.",
)
# Allocate out
if out is None:
out = dense.new_zeros(original_batch_dims_broadcasted + (m, n))
else:
out.zero_()
        # Short-circuit if the sparse argument has no non-zero elements.
if bsr._nnz() == 0:
return out
# TODO: insert switch
if is_sparse_rowspace_mode is None:
is_sparse_rowspace_mode = False
        # For convenience, introduce a fake batch dimension if one is not present.
def unsqueeze_batch_dim(t, n_non_batch_dims):
if t.dim() > n_non_batch_dims:
return t
else:
return t.unsqueeze(0)
def make_triton_contiguous(t):
# Triton does not distinguish between row- and col-majorness
# and will be fast as long as there is a contiguous dimension.
if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
return t.contiguous()
else:
return t
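        # E.g. (hypothetical) a tensor produced by .permute(2, 0, 1) on a 3D
        # contiguous tensor has neither a contiguous row- nor col-major layout
        # in its last two dimensions and gets copied, while a plain
        # .transpose(-2, -1) of a contiguous tensor is left as-is.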
crow_indices = unsqueeze_batch_dim(bsr.crow_indices(), 1)
col_indices = unsqueeze_batch_dim(bsr.col_indices(), 1)
values = make_triton_contiguous(unsqueeze_batch_dim(bsr.values(), 3))
dense = make_triton_contiguous(unsqueeze_batch_dim(dense, 2))
nnz = values.shape[-3]
blocksize = values.shape[-2:]
# Compute broadcasted batch dimension
bsr_batch_dims = values.shape[:-3]
dense_batch_dims = dense.shape[:-2]
batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)
# Broadcast batch dimensions and squash
def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
return t.broadcast_to(batch_dims + invariant_dims).flatten(
0, len(batch_dims) - 1
)
crow_indices = batch_broadcast_and_squash(
crow_indices, batch_dims_broadcasted, (-1,)
)
if is_sparse_rowspace_mode:
# Flatten batch dimension with nnz dimension
# as required by the sparse rowspace kernel.
col_indices = batch_broadcast_and_squash(
col_indices, batch_dims_broadcasted + (-1,), ()
)
values = batch_broadcast_and_squash(
values, batch_dims_broadcasted + (values.shape[-3],), values.shape[-2:]
)
else:
col_indices = batch_broadcast_and_squash(
col_indices, batch_dims_broadcasted, (-1,)
)
values = batch_broadcast_and_squash(
values, batch_dims_broadcasted, values.shape[-3:]
)
dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])
        # NOTE: `out` is contiguous, so batch_broadcast_and_squash returns a view of it.
        # The kernel writes into that view in-place, so we keep a reference to the
        # original (unflattened) tensor to return.
out_backup = out
out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])
# NOTE: this function will ALWAYS create a view
def tile_to_blocksize(t, blocksize):
*rest, m, n = t.shape
new_shape = rest + [
m // blocksize[0],
blocksize[0],
n // blocksize[1],
blocksize[1],
]
return t.reshape(new_shape).transpose(-3, -2)
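        # Shape-wise, tile_to_blocksize maps (..., m, n) to
        # (..., m // bm, n // bn, bm, bn). E.g. (hypothetical) a (256, 128)
        # tensor tiled with blocksize (32, 16) becomes a (8, 8, 32, 16) view
        # whose [i, j] entry is the (i, j)-th (32, 16) tile.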
# "Blockify" the row dimension of dense with blocksize[1]
# since dense is on the rhs of matmul
dense = tile_to_blocksize(dense, blocksize[::-1])
# "Blockify" the row dimension of out with blocksize[0]
# which is inherited from the bsr input.
# NOTE: tile_to_blocksize will create a view.
# NOTE: out.blocksize[-1] == dense.blocksize[-1],
# so it could be any value in [1, dense.shape[-1]).
        # We should probably use the largest blocksize that still fits into SRAM.
out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
# Launch kernel
if is_sparse_rowspace_mode:
kernel = _run_sparse_rowspace_kernel
else:
kernel = _run_dense_rowspace_kernel
# cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
cuda_max_grid = (2147483647, 65535, 65535)
if max_grid is None:
max_grid = cuda_max_grid
else:
def valid_grid_dim(g, mg):
if g is None:
return mg
else:
# grid must be at least 1 and no greater than mg
return max(1, min(g, mg))
max_grid = tuple(
valid_grid_dim(g, mg) for g, mg in zip(max_grid, cuda_max_grid)
) # type: ignore[assignment]
kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
return out_backup
else:
bsr_dense_mm = None # type: ignore[assignment]
if __name__ == "__main__":
from torch._inductor.utils import has_triton
if has_triton():
torch.manual_seed(13)
dtype = torch.float32
p = 0.5
mask_size = (8, 8)
block_size = (64, 64)
size = (mask_size[0] * block_size[0], mask_size[1] * block_size[1])
n_exp = 512
diff = torch.ones(n_exp, device="cuda", dtype=torch.float32)
for i in range(n_exp):
mask = torch.rand(*mask_size, device="cuda") < p
x = torch.rand(*mask_size, *block_size, dtype=dtype, device="cuda") / 10
x = (
(mask[:, :, None, None] * x)
.transpose(-3, -2)
.reshape(*size)
.to_sparse_bsr(*block_size)
)
y = torch.rand(5, *size, dtype=dtype, device="cuda") / 10
res_dense = x.to_dense() @ y
res = bsr_dense_mm(x, y)
diff[i] = (res - res_dense).abs().max()
print(f"mean: {diff.mean()}, std: {diff.std()}")
print(f"max diff: {diff.max()}")