| import warnings |
| from collections import namedtuple |
| from typing import Any, Optional, Tuple, List, Callable, Dict |
| |
| import torch |
| from torch.sparse._semi_structured_conversions import ( |
| sparse_semi_structured_from_dense_cutlass, |
| sparse_semi_structured_to_dense_cutlass, |
| ) |
| from torch.sparse._semi_structured_ops import ( |
| fallback_dispatcher, |
| semi_sparse_values, |
| semi_sparse_indices, |
| semi_sparse_detach, |
| semi_sparse_t, |
| semi_sparse_view, |
| semi_sparse_mm, |
| semi_sparse_addmm, |
| semi_sparse_linear, |
| ) |
| |
| __all__ = [ |
| "SparseSemiStructuredTensor", |
| "SparseSemiStructuredTensorCUTLASS", |
| "SparseSemiStructuredTensorCUSPARSELT", |
| "to_sparse_semi_structured", |
| ] |
| |
| _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( |
| "_SEMI_STRUCTURED_SPARSE_CONFIG", |
| "sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols", |
| ) |
| |
| |
| class SparseSemiStructuredTensor(torch.Tensor): |
| """ |
    This class implements semi-structured sparsity as a Tensor subclass.

    Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements
    are zero, with n depending on the datatype. It is also referred to as 2:4 sparsity or
    fine-grained structured sparsity.

    There are two backends available for semi-structured sparsity, either cuSPARSELt or CUTLASS.
    This class is meant to serve as a base class for both implementations. SparseSemiStructuredTensorCUTLASS
    and SparseSemiStructuredTensorCUSPARSELT both inherit from this class and define three backend-specific items.
    Note that as such, this class cannot be instantiated directly.

    - `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend-specific dense/sparse min shape constraints
    - `def from_dense()` - backend-specific compression routines
    - `def _mm()` - backend-specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
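
    A minimal usage sketch (a hedged example, assuming a CUDA device with 2:4 sparse kernel
    support; the (64, 64) float16 shape satisfies the constraints of both backends):

        >>> # xdoctest: +SKIP
        >>> A = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
        >>> A_sparse = to_sparse_semi_structured(A)
        >>> B = torch.rand((64, 64), dtype=torch.float16, device="cuda")
        >>> dense_result = torch.mm(A, B)
        >>> sparse_result = torch.mm(A_sparse, B)
        >>> torch.allclose(dense_result, sparse_result, atol=1e-1)
        True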
| """ |
| |
| _DEFAULT_ALG_ID: int = 0 |
| _DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG] |
| _FORCE_CUTLASS: bool = True |
| _FUSE_TRANSPOSE: bool = False |
| _PROTOTYPE_WARNING_SHOWN: bool = False |
| |
| SPARSE_DISPATCH: Dict[Callable, Callable] |
| |
| packed: Optional[torch.Tensor] |
| meta: Optional[torch.Tensor] |
| packed_t: Optional[torch.Tensor] |
| meta_t: Optional[torch.Tensor] |
| threads_masks: Optional[torch.Tensor] |
| fuse_transpose_cusparselt: bool |
| alg_id_cusparselt: int |
| |
| __slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"] |
| |
| @staticmethod |
| def __new__( # noqa: PYI034 |
| cls, |
| shape: torch.Size, |
| packed: Optional[torch.Tensor], |
| meta: Optional[torch.Tensor], |
| packed_t: Optional[torch.Tensor], |
| meta_t: Optional[torch.Tensor], |
| threads_masks: Optional[torch.Tensor], |
| fuse_transpose_cusparselt: bool = False, |
| alg_id_cusparselt: int = 0, |
| requires_grad: bool = False, |
| ): |
| """ |
| Create a new instance of the tensor subclass from the compressed sparse representation. |
| |
        We have the option to create the subclass with the compressed representations of both X and X', for training.
        For inference, we only need a single representation (either X or X'), while the other set of fields will be None.

        Depending on the backend selected (CUTLASS vs cuSPARSELt), certain fields will be set to None.
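
        For instance (a summary of the backend-specific from_dense() implementations below), in the
        inference-only case:
            CUTLASS:    packed and meta are set; packed_t, meta_t and threads_masks are None
            cuSPARSELt: only packed is set; all other tensor fields are None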
| |
| Args: |
| shape: The shape of the original dense tensor |
| packed: The compressed representation of the original dense tensor |
| meta: The metadata of the original dense tensor, if it is stored separately |
| packed_t: The compressed representation of the transposed original dense tensor |
| meta_t: The metadata of the transposed original dense tensor, if it is stored separately |
| threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation. |
| Used for pointwise ops. |
| fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition |
| with a matmul, which is useful in the case of 2:4 sparse training. |
            alg_id_cusparselt: The algorithm id to use when using cuSPARSELt; this affects matmul performance
| |
| Returns: |
| torch.Tensor: A torch.Tensor wrapper subclass. |
| |
| Raises: |
| ValueError: If all of the tensor arguments are None. |
| """ |
| if not cls._PROTOTYPE_WARNING_SHOWN: |
| warnings.warn( |
| ( |
| "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " |
| "and will change in the near future. Please open a Github issue " |
| "for features requests and see our documentation on the torch.sparse " |
| "module for further information about the project." |
| ), |
| UserWarning, |
| ) |
| cls._PROTOTYPE_WARNING_SHOWN = True |
| |
            # Because this only runs once, we also load the dispatch table here.
            # We can't define the dispatch table at class definition time because of torch.ops import errors,
            # so we do it lazily instead. This also allows users to override the dispatch table for debugging / testing.
| cls._load_dispatch_table() |
| |
| if packed is not None: |
| previous_tensor = packed |
| elif packed_t is not None: |
| previous_tensor = packed_t |
| else: |
| raise ValueError("At least one of packed or packed_t must be provided") |
| |
| kwargs = { |
| "device": previous_tensor.device, |
| "dtype": previous_tensor.dtype, |
| "layout": previous_tensor.layout, |
| "requires_grad": requires_grad, |
| } |
| tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| |
| tensor.packed = packed |
| tensor.meta = meta |
| tensor.packed_t = packed_t |
| tensor.meta_t = meta_t |
| tensor.threads_masks = threads_masks |
| tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt |
| tensor.alg_id_cusparselt = alg_id_cusparselt |
| return tensor |
| |
| def __repr__(self) -> str: # type: ignore[override] |
| assert hasattr(self, "shape") |
| return f"{self.__class__.__name__}(shape={self.shape})" |
| |
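    # __tensor_flatten__ / __tensor_unflatten__ implement the traceable tensor-subclass protocol
    # (used by torch.compile / FakeTensor): the subclass is decomposed into its inner tensors plus
    # non-tensor metadata, and reconstructed from them.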
| def __tensor_flatten__( |
| self, |
| ) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]: |
| inner_tensors = list( |
| filter(lambda x: getattr(self, x) is not None, self.__slots__) |
| ) |
| tensor_meta = ( |
| self.shape, |
| self.fuse_transpose_cusparselt, |
| self.alg_id_cusparselt, |
| self.requires_grad, |
| ) |
| return inner_tensors, tensor_meta |
| |
| @classmethod |
| def __tensor_unflatten__( |
| cls, |
| inner_tensors, |
        tensor_meta: Tuple[torch.Size, bool, int, bool],
| outer_size, |
| outer_stride, |
| ) -> torch.Tensor: |
| shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta |
| return cls( |
| shape=shape, |
| packed=inner_tensors.get("packed", None), |
| meta=inner_tensors.get("meta", None), |
| packed_t=inner_tensors.get("packed_t", None), |
| meta_t=inner_tensors.get("meta_t", None), |
| threads_masks=inner_tensors.get("threads_masks", None), |
| fuse_transpose_cusparselt=fuse_transpose_cusparselt, |
| alg_id_cusparselt=alg_id_cusparselt, |
| requires_grad=requires_grad, |
| ) |
| |
| __torch_function__ = torch._C._disabled_torch_function_impl |
| |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: |
| if func._overloadpacket not in cls.SPARSE_DISPATCH: |
| raise NotImplementedError( |
| f"{cls.__name__} only supports a specific set of operations, " |
| f"can't perform requested op ({func.__name__})" |
| ) |
| return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs) |
| |
| @classmethod |
| def _load_dispatch_table(cls, custom_dispatch_table=None) -> None: |
| """ |
| Loads the op overload sparse dispatch table for the current class. |
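
        A hedged sketch of overriding an entry for debugging / testing (``my_mm_handler`` is a
        hypothetical callable taking the same ``(func, types, args, kwargs)`` arguments as the
        handlers registered below):

            >>> # xdoctest: +SKIP
            >>> SparseSemiStructuredTensorCUTLASS._load_dispatch_table(
            ...     {torch.ops.aten.mm: my_mm_handler}
            ... )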
| """ |
| if getattr(cls, "SPARSE_DISPATCH", None) is None: |
| cls.SPARSE_DISPATCH = { |
| torch.ops.aten.values: semi_sparse_values, |
| torch.ops.aten.indices: semi_sparse_indices, |
| torch.ops.aten.is_same_size: fallback_dispatcher, |
| torch.ops.aten.detach_: fallback_dispatcher, |
| torch.ops.aten.detach: semi_sparse_detach, |
| torch.ops.aten.t: semi_sparse_t, |
| torch.ops.aten.view: semi_sparse_view, |
| torch.ops.aten.mm: semi_sparse_mm, |
| torch.ops.aten.matmul: semi_sparse_mm, |
| torch.ops.aten.addmm: semi_sparse_addmm, |
| torch.ops.aten.linear: semi_sparse_linear, |
| } |
| if custom_dispatch_table is not None: |
| cls.SPARSE_DISPATCH.update(custom_dispatch_table) |
| |
| @classmethod |
| def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None: |
| """ |
| Assert that the given tensor is valid for semi-structured sparse compression. |
| """ |
| # check device |
| if not original_tensor.is_cuda: |
| raise RuntimeError( |
| f"Error original_tensor.device= {original_tensor.device} is not supported! " |
| "Only CUDA tensors are currently supported." |
| ) |
| |
| # check dim |
| if original_tensor.dim() != 2: |
| raise RuntimeError( |
| f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " |
| "Only 2d tensors are currently supported." |
| ) |
| |
| # check contiguous |
| if not original_tensor.is_contiguous(): |
| raise RuntimeError( |
| "Error original_tensor is not contiguous!" |
| "Only contiguous tensors are currently supported." |
| ) |
| |
| # check dtype |
| if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS: |
| raise RuntimeError( |
| f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! " |
| "dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}" |
| ) |
| |
| # check shape |
| m, n = original_tensor.shape |
| min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows |
| min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols |
| if m < min_rows or m % min_rows or n < min_cols or n % min_cols: |
| # TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples |
| raise RuntimeError( |
| f"Error original_tensor.shape {original_tensor.shape} is not supported! " |
| f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})" |
| ) |
| |
| @classmethod |
| def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: |
| """ |
| Calculates padding for dense tensor and pads tensor if necessary. |
| If padding is not required, this function returns the original tensor. |
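
        For example (a worked illustration using the float16 dense constraints of (8, 8) from
        _DTYPE_SHAPE_CONSTRAINTS): a (10, 8) input needs -10 % 8 = 6 padding rows and no padding
        columns, so the returned tensor has shape (16, 8).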
| """ |
| # only 2d matmul |
| assert dense_input.dim() == 2 |
| |
| # check shape |
| m, n = dense_input.shape |
| min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows |
| min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols |
| |
| # calculate padding |
| to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 |
        to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0
| if to_pad_m or to_pad_n: |
| return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) |
| else: |
| return dense_input |
| |
| def to_dense(self): |
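        # Generic fallback: right-multiplying by the identity routes through the sparse matmul
        # path and materializes a dense copy of this tensor.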
| col = self.shape[-1] |
| return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) |
| |
| @classmethod |
| def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor": |
| raise NotImplementedError |
| |
| def _mm( |
| self, |
| B: torch.Tensor, |
| *, |
| bias: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> torch.Tensor: |
| raise NotImplementedError |
| |
| |
| def to_sparse_semi_structured( |
| original_tensor: torch.Tensor, |
| transposed: bool = False, |
| ) -> SparseSemiStructuredTensor: |
| """ |
| This function converts a dense tensor into a sparse semi-structured tensor. |
| It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. |
| |
| This function will check to ensure the dense tensor has the right dtype, size, dims, and device. |
| We currently only support semi-structured sparse tensors for 2d CUDA tensors. |
    Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
    `_DTYPE_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
| |
| Args: |
| original_tensor (Tensor): the dense tensor to convert |
| transposed (bool, optional): deprecated arg to be removed in another release. Do not use. |
| Returns: |
| SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor |
    Raises:
        RuntimeError: If the dense tensor fails the device / dim / dtype / shape validation performed in from_dense().
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() |
| tensor([[0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| ..., |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) |
| >>> A_sparse = to_sparse_semi_structured(A) |
| SparseSemiStructuredTensor(shape=torch.Size([128, 128])) |
| >>> A_sparse.values() |
| tensor([[1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| ..., |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), |
| >>> A_sparse.indices() |
| tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| ..., |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16)) |
| """ |
| if transposed: |
| raise DeprecationWarning( |
| "Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release." |
| "SparseSemiStructuredTensor only support contiguous input tensors. " |
| ) |
| |
| sparse_subclass = ( |
| torch.sparse.SparseSemiStructuredTensorCUTLASS |
| if SparseSemiStructuredTensor._FORCE_CUTLASS |
| else torch.sparse.SparseSemiStructuredTensorCUSPARSELT |
| ) |
| return sparse_subclass.from_dense(original_tensor) |
| |
| |
| class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): |
| """ |
| This class implements semi-structured sparsity for the CUTLASS backend. |
| |
    In this implementation, the specified elements and metadata are stored separately,
    in packed and meta respectively.
| |
    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into
    torch._sparse_semi_structured_linear for sparse matmul and sparse_semi_structured_from_dense_cutlass
    for conversion to the compressed format.
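
    As a rough illustration (an assumed example, not exact layout documentation): for a
    (128, 128) float16 input, packed holds the kept values with shape (128, 64) and meta holds
    the 2:4 selection metadata as an integer tensor, matching the values() / indices() output
    shown in the to_sparse_semi_structured docstring above.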
| """ |
| |
| _DTYPE_SHAPE_CONSTRAINTS = { |
| torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16), |
| torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), |
| torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8), |
| torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4), |
| } |
| |
| @classmethod |
| def from_dense( |
| cls, original_tensor: torch.Tensor |
| ) -> "SparseSemiStructuredTensorCUTLASS": |
| cls._validate_device_dim_dtype_shape(original_tensor) |
| ( |
| sparse_tensor_cutlass, |
| meta_tensor_cutlass, |
| ) = sparse_semi_structured_from_dense_cutlass(original_tensor) |
| return cls( |
| original_tensor.shape, |
| packed=sparse_tensor_cutlass, |
| meta=meta_tensor_cutlass, |
| packed_t=None, |
| meta_t=None, |
| threads_masks=None, |
| requires_grad=original_tensor.requires_grad, |
| ) |
| |
| def to_dense(self): |
| assert self.meta is not None and self.packed is not None |
| return ( |
| sparse_semi_structured_to_dense_cutlass( |
| self.packed, |
| self.meta, |
| ) |
| if self.meta.ndim == 2 |
| else super().to_dense() |
| ) |
| |
| def _mm( |
| self, |
| B: torch.Tensor, |
| *, |
| bias: Optional[torch.Tensor] = None, |
| **kwargs |
| ) -> torch.Tensor: |
| if isinstance(B, SparseSemiStructuredTensor): |
| raise ValueError( |
| "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" |
| ) |
| cls_name = self.__class__.__name__ |
| if self.ndim != 2 or B.ndim != 2: |
| raise NotImplementedError( |
| f"`{cls_name}` matmul: Broadcasting is not implemented" |
| ) |
| if self.packed is None or self.meta is None: |
| raise NotImplementedError( |
| f"`{cls_name}` matmul: operation is not supported" |
| ) |
| else: |
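            # A sketch of the computation: torch._sparse_semi_structured_linear(input, weight, meta)
            # computes `input @ weight.t()` with (weight, meta) as the compressed sparse operand, so
            # passing B.t() yields (A @ B).t(). We transpose back and keep only the first
            # self.shape[0] rows in case the kernel padded its output.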
| res = torch._sparse_semi_structured_linear( |
| B.t(), self.packed, self.meta, bias=bias |
| ).t() |
| return res[: self.shape[0]] |
| |
| |
| class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): |
| """ |
| The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor: |
| packed = [ specified elements of original tensor | metadata ] |
    For an original tensor of size (m, k), we expect the first m * k // 2 elements to be the kept elements.
    The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
    attributes respectively.
| |
| cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well |
| as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes. |
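
    As a rough illustration: for a (64, 64) float16 input, the first 64 * 64 // 2 = 2048 elements
    of the flattened packed tensor are the kept values, and the remaining elements hold the
    cuSPARSELt metadata.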
| """ |
| |
| _DTYPE_SHAPE_CONSTRAINTS = { |
| torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16), |
| torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), |
| torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8), |
| torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4), |
| } |
| |
| @classmethod |
| def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT": |
| cls._validate_device_dim_dtype_shape(original_tensor) |
| return cls( |
| shape=original_tensor.shape, |
| packed=torch._cslt_compress(original_tensor), |
| meta=None, |
| packed_t=None, |
| meta_t=None, |
| threads_masks=None, |
| fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE, |
| alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID, |
| requires_grad=original_tensor.requires_grad, |
| ) |
| |
| def _mm( |
| self, |
| B: torch.Tensor, |
| *, |
| bias: Optional[torch.Tensor] = None, |
| **kwargs |
| ) -> torch.Tensor: |
| if isinstance(B, SparseSemiStructuredTensor): |
| raise ValueError( |
| "`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware" |
| ) |
| if self.ndim != 2 or B.ndim != 2: |
| raise NotImplementedError( |
| f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented" |
| ) |
| if B.dtype != self.dtype: |
| raise NotImplementedError( |
| f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, " |
| f"with A.dtype={self.dtype} and B.dtype={B.dtype}. " |
| "This operation is only supported when A and B have the same data type." |
| ) |
| if bias is not None and bias.dtype != self.dtype: |
| raise NotImplementedError( |
| f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, " |
| "with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. " |
| "This operation is only supported when A, B and C have the same data type." |
| ) |
| if self.packed is None: |
| raise NotImplementedError( |
| f"`{self.__class__.__name__}` matmul: operation is not supported" |
| ) |
| else: |
| res = torch._cslt_sparse_mm( |
| self.packed, |
| B, |
| bias=bias, |
| transpose_result=self.fuse_transpose_cusparselt, |
| alg_id=self.alg_id_cusparselt, |
| ) |
| return res.t() if self.fuse_transpose_cusparselt else res |