[core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)

This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
  sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:

```
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=linear.weight.bool())

```

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
diff --git a/benchmarks/sparse/benchmark_semi_structured_sparsity.py b/benchmarks/sparse/benchmark_semi_structured_sparsity.py
new file mode 100644
index 0000000..c6753a9
--- /dev/null
+++ b/benchmarks/sparse/benchmark_semi_structured_sparsity.py
@@ -0,0 +1,245 @@
+import random
+import torch
+import torch.utils.benchmark as benchmark
+from torch import nn
+from tqdm import tqdm
+import pandas as pd
+import argparse
+from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+
+
+torch.set_printoptions(
+    precision=2,
+    threshold=None,
+    edgeitems=16,
+    linewidth=480,
+    profile=None,
+    sci_mode=False,
+)
+
+
+# helper model definition for pruner
+class Model(nn.Module):
+    def __init__(self, m, k, dtype=None):
+        super().__init__()
+        # transposed so reversed
+        self.linear = nn.Linear(k, m)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+def rand_sparse_semi_structured_mask(
+    r, c, dtype=torch.float16, device="cuda", choice=None
+):
+    """
+    This function returns a 1:2 sparse matrix of size (r, c).
+    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
+    """
+
+    choices = [[0, 1], [1, 0]]
+    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
+
+    return (
+        torch.tensor(mask_entries, dtype=dtype, device=device)
+        .reshape(r, c)
+        .contiguous()
+    )
+
+
+def test_linear(m, k, n, dtype, contiguous, backend):
+    SparseSemiStructuredTensor.fuse_transpose = contiguous
+    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
+    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
+    input_tensor = torch.zeros(n, k).to(dtype).cuda()
+    model = Model(m, k).to(dtype).cuda().eval()
+
+    dense_measurement = benchmark.Timer(
+        stmt="model(input_tensor)",
+        globals=locals(),
+    ).blocked_autorange()
+
+    dense_output = model(input_tensor)
+
+    # sparsify weights
+    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight, mask=mask.bool()))
+
+    sparse_output = model(input_tensor)
+
+    sparse_measurement = benchmark.Timer(
+        stmt="model(input_tensor)",
+        globals=locals(),
+    ).blocked_autorange()
+
+    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
+
+    return {
+        "test_function": "linear",
+        "m": m,
+        "k": k,
+        "n": n,
+        "dtype": str(dtype),
+        "backend": backend,
+        "sparse_latency (ms)": sparse_measurement.median * 1000,
+        "dense_latency (ms)": dense_measurement.median * 1000,
+        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
+        "correct": correct,
+        "contiguous": sparse_output.is_contiguous(),
+    }
+
+
+def test_tensor(m, k, n, dtype, contiguous, backend):
+    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
+    B = torch.zeros(k, n).to(dtype).cuda()
+    bias = torch.rand(n).to(dtype).cuda()
+
+    sA = to_sparse_semi_structured(A, mask=A.bool())
+
+    # torch.mm calculation
+    if dtype is not torch.int8:
+        dense_output = torch.mm(A, B)
+
+        dense_measurement = benchmark.Timer(
+            stmt="torch.mm(A, B)",
+            globals=locals(),
+        ).blocked_autorange()
+
+    else:
+        print("int8 baseline not supported")
+        dense_output = torch.mm(sA, B)
+
+        dense_measurement = benchmark.Timer(
+            stmt="torch.mm(sA, B)",
+            globals=locals(),
+        ).blocked_autorange()
+
+    sparse_output = torch.mm(sA, B)
+    sparse_measurement = benchmark.Timer(
+        stmt="torch.mm(sA, B)",
+        globals=locals(),
+    ).blocked_autorange()
+
+    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)
+
+    return {
+        "test_function": "tensor",
+        "m": m,
+        "k": k,
+        "n": n,
+        "dtype": str(dtype),
+        "backend": backend,
+        "sparse_latency (ms)": sparse_measurement.median * 1000,
+        "dense_latency (ms)": dense_measurement.median * 1000,
+        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
+        "correct": correct,
+        "contiguous": sparse_output.is_contiguous(),
+    }
+
+
+if __name__ == "__main__":
+    dtype_lookup = {
+        "int8": torch.int8,
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "fp32": torch.float32,
+    }
+
+    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
+    parser.add_argument(
+        "--mode",
+        type=str,
+        choices=[
+            "nvidia-bert",
+            "nvidia-fixed-k",
+            "nvidia-fixed-mn",
+        ],
+    )
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        choices=dtype_lookup.keys(),
+        default="fp16",
+    )
+    parser.add_argument(
+        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
+    )
+    parser.add_argument("-contiguous", action="store_true")
+    parser.add_argument("-e2e", action="store_true")
+    parser.add_argument("-save", action="store_true")
+    args = parser.parse_args()
+
+    if args.e2e:
+        eval_fn = test_linear
+    else:
+        eval_fn = test_tensor
+
+    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
+    dtype = dtype_lookup[args.dtype]
+
+    if args.mode == "nvidia-bert":
+        bert_shapes = [
+            (3072, 1024, 16384),
+            (4096, 1024, 16384),
+            (1024, 1024, 16384),
+            (1024, 4096, 16384),
+        ]
+        results = (
+            eval_fn(m, k, n, dtype, args.contiguous, args.backend)
+            for (m, k, n) in tqdm(bert_shapes)
+        )
+
+    elif args.mode == "nvidia-fixed-k":
+        mn_vals = [
+            3072,
+            4096,
+            5120,
+            6144,
+            7168,
+            8192,
+            9216,
+            10240,
+            11264,
+            12288,
+            13312,
+            14336,
+            15360,
+            16384,
+            17408,
+            18432,
+            19456,
+            20480,
+        ]
+        results = (
+            eval_fn(mn, 10240, mn, dtype, args.contiguous, args.backend)
+            for mn in tqdm(mn_vals)
+        )
+
+    elif args.mode == "nvidia-fixed-mn":
+        k_vals = [
+            2560,
+            3840,
+            5120,
+            6400,
+            7680,
+            8960,
+            10240,
+            11520,
+            12800,
+            14080,
+            15360,
+            16640,
+            17920,
+            19200,
+            20480,
+        ]
+        results = (
+            eval_fn(10240, k, 10240, dtype, args.contiguous, args.backend)
+            for k in tqdm(k_vals)
+        )
+
+    df = pd.DataFrame.from_records(results)
+    if args.save:
+        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
+        df.to_csv(save_file)
+        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
+    print(df)
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index c273f74..364d457 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -24,7 +24,7 @@
 mostly zero valued*. We recognize these are important applications and aim
 to provide performance optimizations for these use cases via sparse storage formats.
 
-Various sparse storage formats such as COO, CSR/CSC, LIL, etc. have been
+Various sparse storage formats such as COO, CSR/CSC, semi-structured, LIL, etc. have been
 developed over the years. While they differ in exact layouts, they all
 compress data through efficient representation of zero valued elements.
 We call the uncompressed values *specified* in contrast to *unspecified*,
@@ -67,6 +67,8 @@
 
 PyTorch currently supports :ref:`COO<sparse-coo-docs>`, :ref:`CSR<sparse-csr-docs>`,
 :ref:`CSC<sparse-csc-docs>`, :ref:`BSR<sparse-bsr-docs>`, and :ref:`BSC<sparse-bsc-docs>`.
+
+We also have a prototype implementation to support :ref: `semi-structured sparsity<sparse-semi-structured-docs>`.
 Please see the references for more details.
 
 Note that we provide slight generalizations of these formats.
@@ -167,6 +169,147 @@
 and recognize it is an important feature to plan a more optimal path of execution for
 any given model.
 
+.. _sparse-semi-structured-docs:
+
+Sparse Semi-Structured Tensors
+++++++++++++++++++++++++++++++
+
+.. warning::
+
+   Sparse semi-sturctured tensors are currently a prototype feature and subject to change. Please feel free to open an issue to report a bug or if you have feedback to share.
+
+Semi-Structured sparsity is a sparse data layout that was first introduced in NVIDIA's Ampere architecture. It is also referred to as **fine-grained structured sparsity** or **2:4 structured sparsity**.
+
+This sparse layout stores `n` elements out of every `2n` elements, with `n` being determined by the width of the Tensor's data type (dtype). The most frequently used dtype is float16, where `n=2`, thus the term "2:4 structured sparsity."
+
+Semi-structured sparsity is explained in greater detail in `this NVIDIA blog post <https://developer.nvidia.com/blog/exploiting-ampere-structured-sparsity-with-cusparselt>`_.
+
+In PyTorch, semi-structured sparsity is implemented via a Tensor subclass.
+By subclassing, we can override ``__torch_dispatch__`` , allowing us to use faster sparse kernels when performing matrix multiplication.
+We can also store the tensor in it's compressed form inside the subclass to reduce memory overhead.
+
+In this compressed form, the sparse tensor is stored by retaining only the *specified* elements and some metadata, which encodes the mask.
+
+.. note::
+    The specified elements and metadata mask of a semi-structured sparse tensor are stored together in a single
+    flat compressed tensor. They are appended to each other to form a contiguous chunk of memory.
+
+    compressed tensor = [ specified elements of original tensor |   metadata_mask ]
+
+    For an original tensor of size `(r, c)` we expect the first `m * k // 2` elements to be the kept elements
+    and the rest of the tensor is metadata.
+
+    In order to make it easier for the user to view the specified elements
+    and mask, one can use ``.indices()`` and ``.values()`` to access the mask and specified elements respectively.
+
+
+    - ``.values()`` returns the specified elements in a tensor of size `(r, c//2)` and with the same dtype as the dense matrix.
+
+    - ``.indices()`` returns the metadata_mask in a tensor of size `(r, c//2 )` and with element type ``torch.int16`` if dtype is torch.float16 and element type ``torch.int32`` if dtype is torch.int8.
+
+
+For 2:4 sparse tensors, the metadata overhead is minor - just 2 bits per specified element.
+
+.. note::
+  It's important to note that ``torch.float32`` is only supported for 1:2 sparsity. Therefore, it does not follow the same formula as above.
+
+Here, we break down how to calculate the compression ratio ( size dense / size sparse) of a 2:4 sparse tensor.
+
+Let `(r, c) = tensor.shape` and `e = bitwidth(tensor.dtype)`, so `e = 16` for ``torch.float16`` and ``torch.bfloat16`` and `e = 8` for ``torch.int8``.
+
+.. math::
+  M_{dense} = r \times c \times e \\
+  M_{sparse} = M_{specified} + M_{metadata} = r \times \frac{c}{2} \times e + r \times \frac{c}{2} \times 2 = \frac{rce}{2} + rc =rce(\frac{1}{2} +\frac{1}{e})
+
+Using these calculations, we can determine the total memory footprint for both the original dense and the new sparse representation.
+
+This gives us a simple formula for the compression ratio, which is dependent only on the bitwidth of the tensor datatype.
+
+.. math::
+  C = \frac{M_{sparse}}{M_{dense}} =  \frac{1}{2} + \frac{1}{e}
+
+By using this formula, we find that the compression ratio is 56.25% for ``torch.float16`` and 62.5% for ``torch.int8``.
+
+Constructing Sparse Semi-Structured Tensors
+-------------------------------------------
+
+You can transform a dense tensor into a sparse semi-structured tensor by using the ``torch.sparse.to_sparse_semi_structured`` function.
+
+Please also note that we only support CUDA tensors since hardware compatibility for semi-structured sparsity is limited to NVIDIA GPUs.
+
+
+The following datatypes are supported for semi-structured sparsity. Note that each datatype has its own shape constraints and compression factor.
+
+.. csv-table::
+   :header: "PyTorch dtype", "Shape Constraints", "Compression Factor", "Sparsity Pattern"
+   :widths: 15, 45, 10, 10
+   :delim: ;
+
+   ``torch.float16``; Tensor must be 2D and (r, c) must both be a positive multiple of 64;9/16;2:4
+   ``torch.int8``; Tensor must be 2D and (r, c) must both be a positive multiple of 128;10/16;2:4
+
+
+To construct a semi-structured sparse tensor, start by creating a regular dense tensor that adheres to a 2:4 (or semi-structured) sparse format.
+To do this we  tile a small 1x4 strip to create a 16x16 dense float16 tensor.
+Afterwards, we can call ``to_sparse_semi_structured`` on this matrix to compress it for accelerated inference.
+
+    >>> from torch.sparse import to_sparse_semi_structured
+    >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
+    tensor([[0., 0., 1.,  ..., 0., 1., 1.],
+            [0., 0., 1.,  ..., 0., 1., 1.],
+            [0., 0., 1.,  ..., 0., 1., 1.],
+            ...,
+            [0., 0., 1.,  ..., 0., 1., 1.],
+            [0., 0., 1.,  ..., 0., 1., 1.],
+            [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
+    >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+    SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
+            [1., 1., 1.,  ..., 1., 1., 1.],
+            [1., 1., 1.,  ..., 1., 1., 1.],
+            ...,
+            [1., 1., 1.,  ..., 1., 1., 1.],
+            [1., 1., 1.,  ..., 1., 1., 1.],
+            [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+            [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+            [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+            ...,
+            [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+            [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+            [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
+    dtype=torch.int16))
+
+Sparse Semi-Structured Tensor Operations
+----------------------------------------
+
+Currently, the following operations are supported for semi-structured sparse tensors:
+
+- torch.addmm(bias, dense, sparse.t())
+- torch.mm(dense, sparse)
+- torch.mm(sparse, dense)
+- aten.linear.default(dense, sparse, bias)
+- aten.t.default(sparse)
+- aten.t.detach(sparse)
+
+To use these ops, simply pass the output of ``to_sparse_semi_structured(tensor)``  instead of using ``tensor`` once your tensor has 0s in a semi-structured sparse format, like this:
+
+    >>> a = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).half().cuda()
+    >>> b = torch.rand(64, 64).half().cuda()
+    >>> c = torch.mm(a, b)
+    >>> a_sparse = to_sparse_semi_structured(a, mask=a.bool())
+    >>> torch.allclose(c, torch.mm(a_sparse, b))
+    True
+
+Under the hood, SparseSemiStructuredTensor will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.
+
+Accelerating nn.Linear with semi-structured sparsity
+----------------------------------------------------
+You can accelerate the linear layers in your model if the weights are already semi-structured sparse with just a few lines of code:
+
+    >>> input = torch.rand(64, 64).half().cuda()
+    >>> mask = torch.Tensor([0, 0, 1, 1]).tile((64, 16)).cuda().bool()
+    >>> linear = nn.Linear(64, 64).half().cuda()
+    >>> linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight, mask=mask))
+
 
 .. _sparse-coo-docs:
 
@@ -992,12 +1135,18 @@
    :func:`torch.mv`;no; ``M[sparse_csr] @ V[strided] -> V[strided]``
    :func:`torch.matmul`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
    :func:`torch.matmul`; no; ``M[sparse_csr] @ M[strided] -> M[strided]``
+   :func:`torch.matmul`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
+   :func:`torch.matmul`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
    :func:`torch.mm`; no; ``M[sparse_coo] @ M[strided] -> M[strided]``
+   :func:`torch.mm`; no; ``M[SparseSemiStructured] @ M[strided] -> M[strided]``
+   :func:`torch.mm`; no; ``M[strided] @ M[SparseSemiStructured] -> M[strided]``
    :func:`torch.sparse.mm`; yes; ``M[sparse_coo] @ M[strided] -> M[strided]``
    :func:`torch.smm`; no; ``M[sparse_coo] @ M[strided] -> M[sparse_coo]``
    :func:`torch.hspmm`; no; ``M[sparse_coo] @ M[strided] -> M[hybrid sparse_coo]``
    :func:`torch.bmm`; no; ``T[sparse_coo] @ T[strided] -> T[strided]``
    :func:`torch.addmm`; no; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
+   :func:`torch.addmm`; no; ``f * M[strided] + f * (M[SparseSemiStructured] @ M[strided]) -> M[strided]``
+   :func:`torch.addmm`; no; ``f * M[strided] + f * (M[strided] @ M[SparseSemiStructured]) -> M[strided]``
    :func:`torch.sparse.addmm`; yes; ``f * M[strided] + f * (M[sparse_coo] @ M[strided]) -> M[strided]``
    :func:`torch.sspaddmm`; no; ``f * M[sparse_coo] + f * (M[sparse_coo] @ M[strided]) -> M[sparse_coo]``
    :func:`torch.lobpcg`; no; ``GENEIG(M[sparse_coo]) -> M[strided], M[strided]``
diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
new file mode 100644
index 0000000..7f2c813
--- /dev/null
+++ b/test/test_sparse_semi_structured.py
@@ -0,0 +1,227 @@
+# Owner(s): ["module: sparse"]
+import random
+import unittest
+
+import torch
+from torch import nn
+
+from torch.sparse.semi_structured import (
+    _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
+    SparseSemiStructuredTensor,
+    to_sparse_semi_structured,
+)
+
+from torch.testing._internal.common_device_type import (
+    dtypes,
+    instantiate_device_type_tests,
+)
+
+from torch.testing._internal.common_dtype import all_types_and_complex
+
+from torch.testing._internal.common_utils import (
+    parametrize,
+    run_tests,
+    subtest,
+    TestCase,
+)
+
+SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
+
+_IS_SM8X = False
+if torch.cuda.is_available():
+    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
+
+def rand_sparse_semi_structured_mask(
+    r, c, dtype=torch.float16, device="cuda", choice=None
+):
+    """
+    This function returns a 1:2 sparse matrix of size (r, c).
+    Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
+    """
+
+    choices = [[0, 1], [1, 0]]
+    mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
+
+    return (
+        torch.tensor(mask_entries, dtype=dtype, device=device)
+        .reshape(r, c)
+        .contiguous()
+    )
+
+
+class TestSparseSemiStructured(TestCase):
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_to_sparse_semi_structured(self, dtype):
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+        assert A.shape == A_sparse.shape
+        assert A.device == A_sparse.device
+        assert A.dtype == A_sparse.dtype
+
+        assert isinstance(A, torch.Tensor)
+        assert isinstance(A_sparse, SparseSemiStructuredTensor)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            "You must pass in a mask to to_sparse_semi_structured, currently mask=None.",
+        ):
+            A_sparse = to_sparse_semi_structured(A)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_mm_sparse_first_NT(self, dtype, device):
+        """
+        Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
+        Ensure torch.mm(A_sparse, B.t()) is correct
+        """
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+
+        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
+        if dtype is torch.int8:
+            # This should fail
+            with self.assertRaisesRegex(RuntimeError, "_structured_sparse_linear"):
+                sparse_result = torch.mm(A_sparse, B)
+
+            # test transpose
+            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
+            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
+            sparse_result = torch.mm(A_sparse, B.t())
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        else:
+            dense_result = torch.mm(A, B)
+            sparse_result = torch.mm(A_sparse, B)
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+            # test transpose
+            dense_result = torch.mm(A, B.t())
+            sparse_result = torch.mm(A_sparse, B.t())
+            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_mm_sparse_first_T(self, dtype, device):
+        """
+        Ensure torch.mm(A_sparse.t(), B) throws error
+        """
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+        B = torch.rand((128, 128), device=A_sparse.device).to(dtype)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
+        ):
+            torch.mm(A_sparse.t(), B)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_mm_sparse_second_T(self, dtype, device):
+        """
+        Ensure torch.mm(A, B_sparse.t()) is correct
+        """
+        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        B_sparse = to_sparse_semi_structured(B, mask=B.bool())
+
+        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
+
+        # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
+        if dtype is torch.int8:
+            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32)
+            sparse_result = torch.mm(A, B_sparse.t())
+        else:
+            dense_result = torch.mm(A, B.t())
+            sparse_result = torch.mm(A, B_sparse.t())
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_mm_sparse_second_NT(self, dtype, device):
+        """
+        Ensure torch.mm(A, B_sparse) throws error
+        """
+        B = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
+        B_sparse = to_sparse_semi_structured(B, mask=B.bool())
+
+        A = torch.rand((128, 128), device=B_sparse.device).to(dtype)
+
+        with self.assertRaisesRegex(
+            NotImplementedError,
+            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
+        ):
+            sparse_result = torch.mm(A, B_sparse)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @parametrize("inference_mode", [subtest(False), subtest(True)])
+    def test_linear(self, inference_mode, device):
+        """
+        Test nn.Linear has the same numerics
+        """
+        input = torch.rand(128, 128, device=device).half()
+        model = nn.Linear(128, 128).to(device).half()
+        m, n = model.weight.shape
+        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
+        # set masked weight
+        model.weight = nn.Parameter(model.weight * mask)
+
+        dense_result = model(input)
+        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight, mask=mask))
+
+        if inference_mode:
+            with torch.inference_mode():
+                sparse_result = model(input)
+        else:
+            sparse_result = model(input)
+
+        assert torch.allclose(dense_result, sparse_result, rtol=1e-5, atol=1e-5)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    def test_values(self):
+        A = rand_sparse_semi_structured_mask(128, 128)
+        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+        assert A_sparse.values().shape == (128, 64)
+        assert (A_sparse.values() == 1).all()
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    def test_indices(self):
+        A = rand_sparse_semi_structured_mask(128, 128)
+        A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+        assert A_sparse.indices().shape == (128, 8)
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
+    def test_unsupported_shape(self, dtype, device):
+        A = rand_sparse_semi_structured_mask(4, 4, dtype=dtype, device=device)
+        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
+            A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    @dtypes(*all_types_and_complex())
+    def test_unsupported_dtype(self, dtype, device):
+        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)
+
+        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
+            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
+                A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+        else:
+            A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+    @unittest.skipIf(not _IS_SM8X, "semi-structured sparsity not supported on this library version")
+    def test_unsupported_dim(self, device):
+        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)
+
+        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
+            A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+
+
+instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 6f05dfb..b1a91fb 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -5,6 +5,9 @@
 from torch._C import _add_docstr, _sparse  # type: ignore[attr-defined]
 from torch import Tensor
 
+# Semi structured sparsity support
+from .semi_structured import SparseSemiStructuredTensor, to_sparse_semi_structured
+
 # A workaround to support both TorchScript and MyPy:
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -23,9 +26,10 @@
     'sum',
     'softmax',
     'log_softmax',
+    'SparseSemiStructuredTensor',
+    'to_sparse_semi_structured',
 ]
 
-
 addmm = _add_docstr(_sparse._sparse_addmm, r"""
 sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
 
diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py
new file mode 100644
index 0000000..1bf13c0
--- /dev/null
+++ b/torch/sparse/semi_structured.py
@@ -0,0 +1,389 @@
+import warnings
+from collections import namedtuple
+from typing import Any, Optional
+
+import torch
+
+
+__all__ = [
+    "to_sparse_semi_structured",
+    "SparseSemiStructuredTensor",
+]
+
+_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
+    "_SEMI_STRUCTURED_SPARSE_CONFIG", "compression_factor min_size"
+)
+_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = {
+    torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(9, 64),
+    torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(10, 128),
+}
+
+_WARNING_SHOWN = False
+
+class SparseSemiStructuredTensor(torch.Tensor):
+    """This class implementes semi-structured sparsity as a Tensor subclass.
+
+    Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
+    depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
+    structured sparsity.
+
+    Currently, this class supports 2:4 sparsity for int8 and float16 dtypes.
+
+    This subclass stores the dense tensor in a compressed form by only storing the specified elemenets and a metadata mask.
+    These two are stored next to each other in one contiguous tensor.
+
+    We choose to store the specified elements and the metadata in a single tensor for future compatibilty with cuSPARSELt.
+
+    compressed tensor = [ specified elements of original tensor |   mask_metadata     ]
+
+    For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
+    The rest of the tensor is metadata.
+
+    This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications
+    via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains.
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        original_tensor: Optional[torch.Tensor],
+        original_shape: Optional[torch.Size] = None,
+        mask: Optional[torch.Tensor] = None,
+        compressed_tensor: Optional[torch.Tensor] = None,
+        transposed: bool = False,
+    ):
+        """
+        Create a new instance of the class.
+
+        When original_tensor is passed in, we compress it and store the compresed representation.
+        We can also create new instance of the class from the compressed representation without the original tensor.
+
+        Args:
+            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
+            original_shape: The shape of the original dense tensor
+            mask: Mask to be applied to the original tensor.
+            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
+            transposed: Whether the tensor is transposed or not.
+
+        Returns:
+            torch.Tensor: A torch.Tensor wrapper subclass.
+
+        Raises:
+            ValueError: If both original_tensor and compressed_tensor are None.
+
+        """
+        if original_tensor is not None:
+            previous_tensor = original_tensor
+            original_shape = original_tensor.shape
+        elif compressed_tensor is not None:
+            previous_tensor = compressed_tensor
+        else:
+            raise ValueError("Both compressed_tensor and original_tensor are None!")
+
+        kwargs = {}
+        kwargs["device"] = previous_tensor.device  # type: ignore[assignment]
+        kwargs["dtype"] = previous_tensor.dtype  # type: ignore[assignment]
+        kwargs["layout"] = previous_tensor.layout  # type: ignore[assignment]
+        kwargs["requires_grad"] = False  # type: ignore[assignment]
+
+        return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        original_tensor: Optional[torch.Tensor],
+        original_shape: Optional[torch.Size] = None,
+        mask: Optional[torch.Tensor] = None,
+        compressed_tensor: Optional[torch.Tensor] = None,
+        transposed: bool = False,
+    ) -> None:
+        """SparseSemiStructuredTensor constructor.
+
+        Args:
+            original_tensor: The original dense tensor, or None, if we have already compressed the tensor.
+            original_shape: The shape of the original dense tensor
+            mask: Mask to be applied to the original tensor.
+            compressed_tensor: A flattened tensor to store the specified elements and mask metadata.
+            transposed: Whether the tensor is transposed or not.
+
+        Returns:
+            None
+
+        Raises:
+            NotImplementedError: If ``mask=None``, as we currently do not support inferring a mask from the dense tensor.
+            RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device.
+        """
+        global _WARNING_SHOWN
+        if not _WARNING_SHOWN:
+            warnings.warn(
+                (
+                    "The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
+                    "and will change in the near future. Please open a Github issue "
+                    "for features requests and see our documentation on the torch.sparse "
+                    "module for further information about the project."
+                ),
+                UserWarning,
+            )
+            _WARNING_SHOWN = True
+
+        # if original tensor is passed in, we need to compress it and store the compressed representation.
+        if original_tensor is not None:
+            # check if mask passed in
+            if mask is None:
+                raise NotImplementedError("You must pass in a mask to to_sparse_semi_structured, currently mask=None.")
+
+            # check device
+            if not original_tensor.is_cuda:
+                raise RuntimeError(
+                    (
+                        f"Error original_tensor.device= {original_tensor.device} is not supported! "
+                        "Only CUDA tensors are currently supported."
+                    )
+                )
+
+            # check dim
+            if original_tensor.dim() != 2:
+                raise RuntimeError(
+                    (
+                        f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
+                        "Only 2d tensors are currently supported."
+                    )
+                )
+
+            # check dtype
+            if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG:
+                raise RuntimeError(
+                    (
+                        f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
+                        "dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}"
+                    )
+                )
+
+            # check shape
+            m, n = original_tensor.shape
+            min_size = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_size
+            if m < min_size or m % min_size or n < min_size or n % min_size:
+                # TODO in the future we can add in padding to support dimensions that aren't perfect multiples
+                raise RuntimeError(
+                    (
+                        f"Error original_tensor.shape {original_tensor.shape} is not supported! "
+                        "Both dimensions must be larger than and a multiple of {min_size}"
+                    )
+                )
+
+            # This code calculates the size of the compressed tensor.
+            # compression factor is different based on dtype it's given by the formula below for 2:4 sparsity:
+            # compression_factor = 1/2 + 1/bitwidth(dtype)
+            original_size = original_tensor.nelement()
+            compression_factor = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[
+                original_tensor.dtype
+            ].compression_factor
+            compressed_size = original_size * compression_factor // 16
+
+            compressed_tensor = torch.empty(
+                (compressed_size,),
+                dtype=original_tensor.dtype,
+                device=original_tensor.device,
+            )
+
+            # TODO This is a temporoary hack to get the mask in compressed form so we can store the compressed tensor.
+            # In the future, we will add in a conversion function from the mask to the meta that we can use instead.
+            placeholder = torch.ones(
+                (128, n), dtype=original_tensor.dtype, device=original_tensor.device
+            )
+            specified = original_tensor.masked_select(mask).view(m, n // 2)
+            _, meta = torch._structured_sparse_linear(placeholder, specified, mask)
+            # set the specified elements
+            compressed_tensor[: m * n // 2] = specified.view(-1)
+            # set the metadata
+            compressed_tensor[m * n // 2 :] = meta.view(original_tensor.dtype).view(-1)
+
+        # set values
+        self.original_tensor = None
+        self.compressed_tensor = compressed_tensor
+        self.transposed = transposed
+
+    def __repr__(self) -> str:
+        """Return string representation of SparseSemiStructuredTensor
+
+        Returns:
+            str: String representation
+
+        Raises:
+            None
+        """
+        return (
+            f"SparseSemiStructuredTensor(shape={self.shape}, "
+            f"transposed={self.transposed}"
+            f"values={self.values()}"
+            f"metadata={self.indices()})"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
+        """Overload __torch_dispatch__ to use torch._structured_sparse_linear.
+
+        `torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
+        In the future we plan to also add in support for cuSPARSELt kernels.
+
+        Args:
+            func: The function being dispatched.
+            types: The types of the arguments.
+            args: The arguments passed to the function.
+            kwargs: The keyword arguments passed to the function.
+
+        Returns:
+            Any: The result of the dispatched operation.
+
+        Raises:
+            NotImplementedError: If the dispatched operation is not implemented.
+        """
+        # Since this code runs below autograd, a detach corresponds to only returning a new object
+        if func is torch.ops.aten.detach.default:
+            return SparseSemiStructuredTensor(
+                args[0].original_tensor,
+                original_shape=args[0].shape,
+                mask=None,
+                compressed_tensor=args[0].compressed_tensor,
+                transposed=args[0].transposed,
+            )
+
+        # Because we cannot go from the compressed representation back to the dense representation currently,
+        # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
+        # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
+        if func is torch.ops.aten.t.default:
+            return SparseSemiStructuredTensor(
+                args[0].original_tensor,
+                original_shape=args[0].shape,
+                mask=None,
+                compressed_tensor=args[0].compressed_tensor,
+                transposed=not args[0].transposed,
+            )
+
+        # handle addmm
+        if func is torch.ops.aten.addmm.default:
+            bias, input_A, input_B = args
+
+            # Currently, we only support the first matrix being sparse for addmm/mm in cuSPARSELT and CUTLASS.
+            # CUTLASS only supports the first input to be sparse for a given matmul.
+            # cuSPARSELt does not have this limitation, although our implementation is only for sparse first.
+
+            # We support second matrix sparse matmul by taking advantage of some transpose properties:
+            # This is also why we want an odd number of transposed for second matrix sparse vs an even number
+            # of transpose calss for first matrix sparse.
+            # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
+            #        = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
+            if isinstance(input_B, cls) and input_B.transposed:
+                result, _ = torch._structured_sparse_linear(
+                    input_A, input_B.values(), input_B.indices(), bias=bias
+                )
+                return result
+
+        # handle mm
+        if func is torch.ops.aten.mm.default:
+            input_A, input_B = args
+
+            if isinstance(input_A, cls) and not input_A.transposed:
+                transposed_result, _ = torch._structured_sparse_linear(
+                    input_B.t(), input_A.values(), input_A.indices()
+                )
+                return transposed_result.t()
+
+            elif isinstance(input_B, cls) and input_B.transposed:
+                result, _ = torch._structured_sparse_linear(
+                    input_A, input_B.values(), input_B.indices()
+                )
+                return result
+
+        # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(),
+        # so we must match the aten.linear op.
+        # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here.
+        if func is torch.ops.aten.linear.default:
+            input_tensor, weight, bias = args
+            if isinstance(weight, cls):
+                result, _ = torch._structured_sparse_linear(
+                    input_tensor, weight.values(), weight.indices(), bias=bias
+                )
+                return result
+
+        # handle values
+        if func is torch.ops.aten.values.default:
+            m, k = args[0].shape
+            num_kept_elements = m * k // 2
+            return args[0].compressed_tensor[:num_kept_elements].view(m, k // 2)
+
+        # handle indices
+        if func is torch.ops.aten.indices.default:
+            m, k = args[0].shape
+            num_kept_elements = m * k // 2
+            metadata = args[0].compressed_tensor[num_kept_elements:].view(m, -1)
+
+            # the metadata is expected to be in different datatypes for fp16/int8 respectively for CUTLASS.
+            if args[0].dtype is torch.int8:
+                return metadata.view(torch.int32)
+            elif args[0].dtype is torch.float16:
+                return metadata.view(torch.int16)
+
+        error_string = "\n".join(
+            [f"func {func} with args: "]
+            + [f"arg{i}: {arg}" for i, arg in enumerate(args)]
+        )
+        raise NotImplementedError(error_string)
+
+
+def to_sparse_semi_structured(
+    original_tensor: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    transposed: bool = False,
+) -> SparseSemiStructuredTensor:
+    """
+    This function converts a dense tensor into a sparse semi-structured tensor.
+    It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
+
+    This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
+    We currently only support semi-structured sparse tensors for 2d CUDA tensors.
+    Additionally, your tensor must be a positive multiple of a block size given the dtype
+
+    - torch.float16  (r, c) must be >= and a multiple of 64
+    - torch.int8     (r, c) must be >= and a multiple of 128
+
+    Args:
+        original_tensor (Tensor): the dense tensor to convert
+        mask (Optional BoolTensor): boolean mask to apply to the original tensor
+        transposed (bool, optional): whether the dense tensor is transposed
+
+    Returns:
+        SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor and mask
+
+    Raises:
+        None
+
+    Example:
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
+        tensor([[0., 0., 1.,  ..., 0., 1., 1.],
+                [0., 0., 1.,  ..., 0., 1., 1.],
+                [0., 0., 1.,  ..., 0., 1., 1.],
+                ...,
+                [0., 0., 1.,  ..., 0., 1., 1.],
+                [0., 0., 1.,  ..., 0., 1., 1.],
+                [0., 0., 1.,  ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
+        >>> A_sparse = to_sparse_semi_structured(A, mask=A.bool())
+        SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1.,  ..., 1., 1., 1.],
+                [1., 1., 1.,  ..., 1., 1., 1.],
+                [1., 1., 1.,  ..., 1., 1., 1.],
+                ...,
+                [1., 1., 1.,  ..., 1., 1., 1.],
+                [1., 1., 1.,  ..., 1., 1., 1.],
+                [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
+            metadata=tensor([[-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+                ...,
+                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+                [-4370, -4370, -4370,  ..., -4370, -4370, -4370],
+                [-4370, -4370, -4370,  ..., -4370, -4370, -4370]], device='cuda:0',
+       dtype=torch.int16))
+    """
+    return SparseSemiStructuredTensor(original_tensor, mask=mask, transposed=transposed)