import itertools
from typing import Optional, Tuple

import torch
from torch._inductor.cuda_properties import get_device_capability


def _has_triton():
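    """Return True if CUDA is available, the `triton` package imports, and
    the device has compute capability at least (7, 0).
    """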
    if not torch.cuda.is_available():
        return False
    try:
        import triton

        return triton is not None and get_device_capability() >= (7, 0)
    except ImportError:
        return False


def make_triton_contiguous(t):
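    """Return `t` unchanged if it is row-major (`t.is_contiguous()`) or
    column-major in its last two dimensions; otherwise return a contiguous
    copy.
    """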
    # Triton does not distinguish between row- and col-majorness
    # and will be fast as long as there is a contiguous dimension.
    if not (t.is_contiguous() or t.transpose(-2, -1).is_contiguous()):
        return t.contiguous()
    else:
        return t


def broadcast_batch_dims(*tensors):
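    """Broadcast the batch dimensions (all dimensions but the last two) of
    the given tensors, e.g. shapes (2, 1, 3, 4) and (5, 3, 4) produce the
    batch shape (2, 5).
    """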
    return torch.broadcast_shapes(*(t.shape[:-2] for t in tensors))


def slicer(dim, slice_range, *tensors):
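    """Yield each tensor sliced with `slice_range` along dimension `dim`,
    e.g. slicer(-1, slice(0, 2), t) yields t[..., 0:2].
    """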
    for t in tensors:
        slices = [slice(None)] * t.dim()
        slices[dim] = slice_range
        # Index with a tuple: indexing with a plain list of slices is deprecated.
        yield t[tuple(slices)]


def multidim_slicer(dims, slices, *tensors):
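    """Like `slicer`, but over several dimensions at once; entries of `dims`
    that are None leave the corresponding dimension unsliced.
    """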
    for t in tensors:
        s = [slice(None)] * t.dim()
        for d, d_slice in zip(dims, slices):
            if d is not None:
                s[d] = d_slice
        yield t[tuple(s)]


def ptr_stride_extractor(*tensors):
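    """Yield each tensor followed by all of its strides, flattening the
    argument list for the Triton kernel below; Triton maps each tensor
    argument to its data pointer.
    """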
    for t in tensors:
        yield t
        yield from t.stride()


def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
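    """Split `full_grid` into chunks of at most `grid_blocks` along each
    dimension. For every chunk, yield the CUDA-ordered (i.e. reversed)
    launch grid together with matching slices of the tensors in
    `tensor_dims_map`, which maps each tensor to the tensor dimensions
    tied to the grid dimensions (None meaning "not tied").
    """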
    assert 0 <= len(full_grid) <= 3
    assert 0 <= len(grid_blocks) <= 3

    def generate_grid_points():
        for fg, mg in zip(full_grid, grid_blocks):
            yield range(0, fg, mg)

    def generate_sliced_tensors(slices):
        for t, t_dims in tensor_dims_map.items():
            yield next(multidim_slicer(t_dims, slices, t))

    for grid_point in itertools.product(*generate_grid_points()):
        grid = [min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)]
        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
        # grid_points are iterated in a "contiguous" order, i.e.
        # left dimensions traversed slower than right dimensions.
        # This order is reversed for CUDA grids.
        yield grid[::-1], *generate_sliced_tensors(slices)


def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
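    """Launch `kernel` over `full_grid`, splitting the launch into several
    calls whenever the grid exceeds `grid_blocks` (clamped to CUDA's
    maximum grid sizes) along some dimension.
    """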
    # cuda_max_grid = (2 ** 31 - 1, 2 ** 16 - 1, 2 ** 16 - 1)
    cuda_max_grid = (2147483647, 65535, 65535)[::-1]
    if grid_blocks is None:
        grid_blocks = cuda_max_grid
    else:

        def valid_grid_dim(g, mg):
            if g is None:
                return mg
            else:
                # grid must be at least 1 and no greater than mg
                return max(1, min(g, mg))

        grid_blocks = tuple(
            valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
        )  # type: ignore[assignment]

    for grid, *sliced_tensors in grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
        kernel(grid, *sliced_tensors)


if _has_triton():
    import triton
    import triton.language as tl

    @triton.jit
    def _bsr_strided_dense_rowspace_kernel(
        BLOCKSIZE_ROW: tl.constexpr,
        BLOCKSIZE_COL: tl.constexpr,
        # values prologue
        values_ptr,
        values_batch_stride,
        values_nnz_stride,
        values_row_block_stride,
        values_col_block_stride,
        # values epilogue
        # crow_indices prologue
        crow_indices_ptr,
        crow_indices_batch_stride,
        crow_indices_stride,
        # crow_indices epilogue
        # col_indices prologue
        col_indices_ptr,
        col_indices_batch_stride,
        col_indices_stride,
        # col_indices epilogue
        # dense prologue
        dense_ptr,
        dense_batch_stride,
        dense_tiled_row_stride,
        dense_tiled_col_stride,
        dense_row_block_stride,
        dense_col_block_stride,
        # dense epilogue
        # output prologue
        output_ptr,
        output_batch_stride,
        output_tiled_row_stride,
        output_tiled_col_stride,
        output_row_block_stride,
        output_col_block_stride,
        # output epilogue
        acc_dtype: tl.constexpr,
        allow_tf32: tl.constexpr,
        GROUP_SIZE_ROW: tl.constexpr,
    ):
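        # Each program instance accumulates one (BLOCKSIZE_ROW, BLOCKSIZE_ROW)
        # output tile: it walks the nonzero blocks of one BSR block row and
        # multiplies each against the matching tile of dense.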
        batch_pid = tl.program_id(axis=2)
        row_block_pid = tl.program_id(axis=0)
        col_block_pid = tl.program_id(axis=1)
        n_block_rows = tl.num_programs(axis=0)
        n_block_cols = tl.num_programs(axis=1)

        row_block_pid, col_block_pid = tl.swizzle2d(
            row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW
        )

        crow_indices_offset_ptr = (
            crow_indices_ptr
            + crow_indices_batch_stride * batch_pid
            + crow_indices_stride * row_block_pid
        )
        nnz_offset = tl.load(crow_indices_offset_ptr)
        nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride)

        # Compute nnz for the row with number row_block_pid.
        # If it is zero, skip the row.
        row_nnz = nnz_offset_next - nnz_offset
        if row_nnz == 0:
            return

        row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
        col_block_arange = tl.arange(0, BLOCKSIZE_COL)

        # Pointers are set to the first block of the current row.
        values_block_ptrs = (
            values_ptr
            + values_batch_stride * batch_pid
            + values_nnz_stride * nnz_offset
            + values_row_block_stride * row_block_arange[:, None]
            + values_col_block_stride * col_block_arange[None, :]
        )

        # NOTE: dense is advanced into all dimensions but the tiled row one.
        # That will be advanced in the loop according to values in col_indices.
        dense_block_ptrs = (
            dense_ptr
            + dense_batch_stride * batch_pid
            + dense_tiled_col_stride * col_block_pid
            + dense_row_block_stride * col_block_arange[:, None]
            + dense_col_block_stride * row_block_arange[None, :]
        )

        # Pointers are set to exact write-to locations.
        output_ptrs = (
            output_ptr
            + output_batch_stride * batch_pid
            + output_tiled_row_stride * row_block_pid
            + output_tiled_col_stride * col_block_pid
            + output_row_block_stride * row_block_arange[:, None]
            + output_col_block_stride * row_block_arange[None, :]
        )

        # Set pointer to the first nonzero element in the current row.
        col_index_nnz_ptr = (
            col_indices_ptr
            + col_indices_batch_stride * batch_pid
            + col_indices_stride * nnz_offset
        )

        output_acc_block = tl.zeros((BLOCKSIZE_ROW, BLOCKSIZE_ROW), dtype=acc_dtype)
        for _ in range(row_nnz):
            values_block = tl.load(values_block_ptrs)

            # Find which row of dense needs to be loaded
            # for multiplication with values_block.
            dense_row_idx = tl.load(col_index_nnz_ptr)
            dense_block = tl.load(dense_block_ptrs + dense_tiled_row_stride * dense_row_idx)

            # Do block mm.
            output_acc_block += tl.dot(values_block, dense_block, allow_tf32=allow_tf32)

            # Move val/col_index ptrs to the next block in the row.
            values_block_ptrs += values_nnz_stride
            col_index_nnz_ptr += col_indices_stride

        # Write back the result.
        tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty))

    def _run_dense_rowspace_kernel(
        blocksize, values, crow_indices, col_indices, dense, output, max_grid
    ):
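        """Set up the launch grid and the tensor/grid-dimension mapping for
        _bsr_strided_dense_rowspace_kernel, then launch it (chunked through
        launch_kernel whenever the grid exceeds `max_grid`).
        """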
        n_batches = dense.size(0)
        n_block_rows = crow_indices.size(-1) - 1
        n_block_cols = dense.size(-3)

        full_grid = (n_batches, n_block_cols, n_block_rows)
        if max_grid is not None:
            grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3]))
        else:
            grid_blocks = None
        tensor_dims_map = {
            values: (0, None, None),
            crow_indices: (0, None, -1),
            col_indices: (0, None, None),
            dense: (0, -3, None),
            output: (0, -3, -4),
        }
        if values.dtype in (torch.half, torch.bfloat16):
            acc_dtype = tl.float32
            allow_tf32 = True
        else:
            acc_dtype = tl.float64
            allow_tf32 = False

        def kernel(grid, *sliced_tensors):
            _bsr_strided_dense_rowspace_kernel[grid](
                *blocksize,
                *ptr_stride_extractor(*sliced_tensors),
                acc_dtype=acc_dtype,
                allow_tf32=allow_tf32,
                GROUP_SIZE_ROW=4,
                num_stages=1,
                num_warps=4,
            )

        launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks)

    def bsr_dense_mm(
        bsr: torch.Tensor,
        dense: torch.Tensor,
        *,
        skip_checks: bool = False,
        max_grid: Optional[Tuple[Optional[int], Optional[int], Optional[int]]] = None,
        out: Optional[torch.Tensor] = None,
    ):
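        """Multiply the sparse BSR tensor `bsr` by the dense tensor `dense`,
        broadcasting over batch dimensions.

        Illustrative usage sketch (assumes a CUDA device that Triton
        supports; the block size must be a power of 2 and at least 16):

            lhs = torch.randn(512, 512, dtype=torch.half, device="cuda")
            rhs = torch.randn(512, 512, dtype=torch.half, device="cuda")
            res = bsr_dense_mm(lhs.to_sparse_bsr((32, 32)), rhs)
            # res approximately equals lhs @ rhs
        """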
        def check(cond, msg):
            if not cond:
                raise ValueError(msg)

        if not skip_checks:
            check(
                bsr.layout == torch.sparse_bsr,
                "bsr_dense_mm(): only BSR sparse format is supported for the sparse argument.",
            )
            check(
                bsr.device == dense.device and bsr.device.type == "cuda",
                "bsr_dense_mm(): all inputs are expected to be on the same GPU device.",
            )
            check(
                bsr.dtype == dense.dtype
                and bsr.dtype in (torch.half, torch.bfloat16, torch.float),
                "bsr_dense_mm(): all inputs are expected to be of the same dtype "
                "and one of (half, bfloat16, float32), "
                f"but got bsr.dtype == {bsr.dtype} and dense.dtype == {dense.dtype}.",
            )
            check(
                bsr.dim() >= 2 and dense.dim() >= 2,
                "bsr_dense_mm(): all inputs are expected to be at least 2D, "
                f"but got bsr.dim() == {bsr.dim()} and dense.dim() == {dense.dim()}.",
            )
            m, kl = bsr.shape[-2:]
            kr, n = dense.shape[-2:]
            check(
                kl == kr,
                "bsr_dense_mm(): argument sizes are not compatible for matrix multiplication, "
                f"got bsr.shape[-1] == {kl} which is not equal to dense.shape[-2] == {kr}.",
            )
            row_block, col_block = bsr.values().shape[-2:]
            check(
                not n % row_block,
                f"bsr_dense_mm(): dense.size(-1) == {n} should be divisible by "
                f"blocksize[0] == {row_block}.",
            )

            def is_power_of_two(v):
                return not (v & (v - 1))

            def is_compatible_blocksize(b):
                assert len(b) == 2
                res = True
                for blocksize in b:
                    # Triton loads only blocks which are at least 16 and powers of 2.
                    res = (blocksize >= 16 and is_power_of_two(blocksize)) and res
                return res

            check(
                is_compatible_blocksize((row_block, col_block)),
                f"bsr_dense_mm(): sparse inputs' blocksize ({row_block}, {col_block}) "
                "should be at least 16 and a power of 2 in each dimension.",
            )
        else:
            m, kl = bsr.shape[-2:]
            kr, n = dense.shape[-2:]

        original_batch_dims_broadcasted = broadcast_batch_dims(bsr, dense)
        if out is not None and not skip_checks:
            expected_out_shape = original_batch_dims_broadcasted + (m, n)
            check(
                out.shape == expected_out_shape,
                "bsr_dense_mm(): `out` argument has wrong shape, "
                f"expected {expected_out_shape}, but got {out.shape}.",
            )
            check(
                out.is_contiguous() or out.transpose(-2, -1).is_contiguous(),
                "bsr_dense_mm(): only row-major/col-major `out` arguments are supported, "
                "i.e. (out.is_contiguous() or out.transpose(-2, -1).is_contiguous()) "
                "should be True.",
            )

        # Allocate out
        if out is None:
            out = dense.new_zeros(original_batch_dims_broadcasted + (m, n))
        else:
            out.zero_()

        # Short circuit if lhs is zero
        if bsr._nnz() == 0:
            return out

        # Introduce fake batch dimension if not present for convenience.
        crow_indices = bsr.crow_indices().unsqueeze(0)
        col_indices = bsr.col_indices().unsqueeze(0)
        values = make_triton_contiguous(bsr.values().unsqueeze(0))
        dense = make_triton_contiguous(dense.unsqueeze(0))
        nnz = values.shape[-3]
        blocksize = values.shape[-2:]

        # Compute broadcasted batch dimension
        bsr_batch_dims = values.shape[:-3]
        dense_batch_dims = dense.shape[:-2]
        batch_dims_broadcasted = torch.broadcast_shapes(bsr_batch_dims, dense_batch_dims)

        # Broadcast batch dimensions and squash
        def batch_broadcast_and_squash(t, batch_dims, invariant_dims):
            return t.broadcast_to(batch_dims + invariant_dims).flatten(
                0, len(batch_dims) - 1
            )

        crow_indices = batch_broadcast_and_squash(
            crow_indices, batch_dims_broadcasted, (-1,)
        )
        col_indices = batch_broadcast_and_squash(
            col_indices, batch_dims_broadcasted, (-1,)
        )
        values = batch_broadcast_and_squash(
            values, batch_dims_broadcasted, values.shape[-3:]
        )
        dense = batch_broadcast_and_squash(dense, batch_dims_broadcasted, dense.shape[-2:])

        # NOTE: out is contiguous, so batch_broadcast_and_squash will create a view.
        # out gets modified in-place, so we store a backup copy.
        out_backup = out
        out = batch_broadcast_and_squash(out, batch_dims_broadcasted, out.shape[-2:])

        # NOTE: this function will ALWAYS create a view
        def tile_to_blocksize(t, blocksize):
            *rest, m, n = t.shape
            new_shape = rest + [
                m // blocksize[0],
                blocksize[0],
                n // blocksize[1],
                blocksize[1],
            ]
            return t.reshape(new_shape).transpose(-3, -2)
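        # Shape example for the helper above: a (..., 8, 9) tensor tiled with
        # blocksize (2, 3) becomes a (..., 4, 3, 2, 3) view, i.e. a 4 x 3
        # grid of 2 x 3 blocks.
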
# "Blockify" the row dimension of dense with blocksize[1]
# since dense is on the rhs of matmul
dense = tile_to_blocksize(dense, blocksize[::-1])
# "Blockify" the row dimension of out with blocksize[0]
# which is inherited from the bsr input.
# NOTE: tile_to_blocksize will create a view.
# NOTE: out.blocksize[-1] == dense.blocksize[-1],
# so it could be any value in [1, dense.shape[-1]).
# We need to probably use the largest possible blocksize
# so that it fits into SRAM.
out = tile_to_blocksize(out, (blocksize[0], blocksize[0]))
# Launch kernel
kernel = _run_dense_rowspace_kernel
kernel(blocksize, values, crow_indices, col_indices, dense, out, max_grid)
return out_backup
else:
    bsr_dense_mm = None  # type: ignore[assignment]