Add CUTLASS-based MM for structured sparse linear operator (#100485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100485
Approved by: https://github.com/cpuhrsch
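
This adds a private operator, torch._structured_sparse_linear, with a CUDA-only dispatch, that multiplies a 2:4-compressed sparse matrix by a dense matrix through a CUTLASS sparse GEMM kernel. A minimal usage sketch, assuming an SM 8.x GPU and following the shapes and dtypes exercised by the test added below (tensor names and sizes are illustrative only):

    import torch

    m, k, n = 32, 128, 32
    dtype = torch.half

    # Dense weight with a 2:4 sparsity pattern; the boolean mask marks the
    # elements that are kept.
    a = torch.randn(m, k, dtype=dtype, device="cuda")
    mask = torch.tensor([1, 0, 1, 0] * (m * k // 4),
                        dtype=torch.bool, device="cuda").view(m, k)
    a_sparse = a.masked_select(mask).view(m, k // 2)  # compressed operand
    b = torch.randn(k, n, dtype=dtype, device="cuda")

    # First call: pass the mask; the metadata matrix is computed and returned.
    out, meta = torch._structured_sparse_linear(a_sparse, b, mask)
    # Later calls: reuse the returned metadata instead of the mask.
    out, _ = torch._structured_sparse_linear(a_sparse, b, meta)
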
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7442670..306af51 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3169,6 +3169,10 @@
dispatch:
CompositeExplicitAutograd: linear_out
+- func: _structured_sparse_linear(Tensor a, Tensor b, Tensor mask_or_meta) -> (Tensor, Tensor)
+ dispatch:
+ CUDA: _structured_sparse_linear
+
- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
python_module: nn
dispatch:
diff --git a/aten/src/ATen/native/sparse/cuda/StructuredSparseLinearCUTLASS.cu b/aten/src/ATen/native/sparse/cuda/StructuredSparseLinearCUTLASS.cu
new file mode 100644
index 0000000..d83befe
--- /dev/null
+++ b/aten/src/ATen/native/sparse/cuda/StructuredSparseLinearCUTLASS.cu
@@ -0,0 +1,590 @@
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <ATen/Dispatch.h>
+
+#ifndef USE_ROCM
+#include <cutlass/cutlass.h>
+#include <cutlass/gemm/device/gemm_sparse.h>
+#endif
+
+#include <type_traits>
+#include <tuple>
+
+#ifndef USE_ROCM
+#define CUTLASS_STATUS_CHECK(status) \
+ { \
+ TORCH_CHECK(status == cutlass::Status::kSuccess, \
+ "Got CUTLASS error: ", cutlassGetStatusString(status)); \
+ }
+#endif
+
+namespace at {
+namespace native {
+
+#ifndef USE_ROCM
+// This kernel creates 2:4 sparse matrix metadata for a given "mask"
+// matrix corresponding to the original dense matrix. The "mask"
+// matrix contains true values at positions where dense matrix
+// elements are kept, and false values at positions that are zeroed
+// out. The "mask" matrix has "length_m" rows and "length_k" columns,
+// and it is assumed that this matrix is in row-major format, with row
+// stride "mask_stride" (and column stride equal to 1). The kernel
+// stores the metadata in the "meta" matrix, which is also assumed to
+// be in row-major format, with row stride "meta_stride" (and column
+// stride equal to 1). If the "mask" matrix is not in 2:4 sparse
+// format, the kernel sets the value pointed to by "error" to 1.
+//
+// This kernel could be improved for efficiency, but it should be
+// called only once for a given sparse operand, so it should not
+// affect performance much.
+template<typename T>
+__global__ void two_four_create_meta_kernel(
+ const int length_m, const int length_k, const int mask_stride,
+ const bool* mask, const int meta_stride, T* meta, int* error) {
+ const auto k = blockDim.x * blockIdx.x + threadIdx.x;
+ const auto m = blockDim.y * blockIdx.y + threadIdx.y;
+
+ const auto in_range = m < length_m && k < length_k / 4;
+ unsigned active_mask = __ballot_sync(0xffffffff, in_range);
+ if (!in_range) {
+ return;
+ }
+
+ T val = 0;
+ const auto pos0 = mask[m * mask_stride + k * 4];
+ const auto pos1 = mask[m * mask_stride + k * 4 + 1];
+ const auto pos2 = mask[m * mask_stride + k * 4 + 2];
+ const auto pos3 = mask[m * mask_stride + k * 4 + 3];
+ const auto pos_tuple = std::make_tuple(pos0, pos1, pos2, pos3);
+
+ // See
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-sparse-matrix-storage
+ // There are only 6 valid configurations (4 choose 2), and each one
+ // maps to a specific 4-bit metadata value.
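+ // For example: if the elements at positions 0 and 2 of a group of
+ // four are kept (mask values 1, 0, 1, 0), the 2-bit indices of the
+ // kept elements are 0 and 2, packed into the 4-bit value 0b1000 = 8.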
+ if (pos_tuple == std::make_tuple(1, 1, 0, 0)) {
+ val = 4; // 0100
+ } else if (pos_tuple == std::make_tuple(1, 0, 1, 0)) {
+ val = 8; // 1000
+ } else if (pos_tuple == std::make_tuple(0, 1, 1, 0)) {
+ val = 9; // 1001
+ } else if (pos_tuple == std::make_tuple(1, 0, 0, 1)) {
+ val = 12; // 1100
+ } else if (pos_tuple == std::make_tuple(0, 1, 0, 1)) {
+ val = 13; // 1101
+ } else if (pos_tuple == std::make_tuple(0, 0, 1, 1)) {
+ val = 14; // 1110
+ } else {
+ atomicExch(error, 1);
+ }
+
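+ // Each thread now holds a 4-bit code for its group of four mask
+ // elements; warp shuffles pack the codes of 2 * sizeof(T) consecutive
+ // groups into a single metadata element, which is stored by the
+ // lowest-indexed thread of each tile.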
+ auto tile_size = 2 * sizeof(T);
+ for (auto i = 1; i < tile_size; i *= 2) {
+ val |= __shfl_down_sync(active_mask, val, i) << (4 * i);
+ }
+ if (k % tile_size == 0) {
+ meta[m * meta_stride + k / tile_size] = val;
+ }
+}
+
+// This kernel reimplements the reorder_meta() function from the
+// tools/util/include/cutlass/util/host_reorder.h file of the CUTLASS
+// source distribution. The purpose of having a CUDA version of this
+// function is to avoid copying the meta matrix to the CPU and back,
+// as CUTLASS for now supplies only a host version of this function.
+//
+// Like the above kernel, this kernel should be called only once for a
+// given sparse operand, so not much effort is put into optimization
+// (hopefully, CUTLASS may provide its own CUDA version at some
+// point).
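+//
+// For example, with 16-bit metadata elements (group = 32, interweave
+// = 4), the source element at position (9, 0) is first mapped to row
+// 0 * 32 + 1 * 4 + 1 = 5, column 0, and the subsequent Z-to-N swizzle
+// of 2x2 blocks (odd row, even column) moves it to destination
+// position (4, 1).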
+template <typename Element, typename LayoutSrc, typename LayoutDest>
+__global__ void two_four_reorder_meta_kernel(
+ const int length_m, const int length_k,
+ const cutlass::TensorRef<Element, LayoutSrc> src,
+ cutlass::TensorRef<Element, LayoutDest> dst) {
+ const int k = blockDim.x * blockIdx.x + threadIdx.x;
+ const int m = blockDim.y * blockIdx.y + threadIdx.y;
+
+ if (m >= length_m || k >= length_k) {
+ return;
+ }
+
+ // First reorder the rows.
+ int group = (sizeof(Element) == 2) ? 32 : 16;
+ int interweave = (sizeof(Element) == 2) ? 4 : 2;
+
+ int dst_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
+ int dst_col = k;
+
+ // Next swizzle the 2x2 blocks from Z to N.
+ if (((dst_row % 2) == 0) && ((dst_col % 2) == 1)) {
+ ++dst_row;
+ --dst_col;
+ } else if (((dst_row % 2) == 1) && ((dst_col % 2) == 0)) {
+ --dst_row;
+ ++dst_col;
+ }
+ dst.at({dst_row, dst_col}) = src.at({m, k});
+}
+
+// Wrapper function for the CUTLASS sparse GEMM implementation, used
+// solely to simplify dispatching from the _structured_sparse_linear()
+// function below.
+template <
+ typename ElementInputA,
+ typename ElementInputB,
+ typename ElementOutput,
+ typename ElementAccumulator,
+ typename ElementComputeEpilogue,
+ typename ThreadblockShape,
+ typename WarpShape,
+ typename InstructionShape,
+ typename EpilogueOp,
+ typename LayoutInputA,
+ typename LayoutInputB>
+std::tuple<Tensor, Tensor> two_four_sgemm_cutlass(
+ const Tensor& tensor_a,
+ const at::IntArrayRef::value_type& tensor_a_stride,
+ const Tensor& tensor_b,
+ const at::IntArrayRef::value_type& tensor_b_stride,
+ const Tensor& mask_or_meta) {
+ // Fix the CUTLASS sparse GEMM template arguments that are not
+ // provided as template arguments of this function, and create an
+ // alias for the particular instantiation of this template.
+ using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
+ using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
+ using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
+ using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
+ constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
+ using Gemm = cutlass::gemm::device::SparseGemm<
+ ElementInputA,
+ LayoutInputA,
+ ElementInputB,
+ LayoutInputB,
+ ElementOutput,
+ LayoutOutput,
+ ElementAccumulator,
+ MMAOp,
+ SmArch,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ SwizzleThreadBlock,
+ NumStages>;
+
+ // Datatype and layout of metadata matrix are inferred from sparse
+ // GEMM template.
+ using ElementInputE = typename Gemm::ElementE;
+ using LayoutInputE = cutlass::layout::RowMajor;
+ using ReorderedLayoutInputE = typename Gemm::LayoutE;
+
+ constexpr auto kSparse = Gemm::kSparse;
+ constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
+
+ // Operand sizes.
+ const int length_m = tensor_a.size(0);
+ const int length_k = tensor_b.size(0);
+ const int length_n = tensor_b.size(1);
+ const auto meta_ncols = length_k / kSparse / kElementsPerElementE;
+
+ // Check for current CUTLASS limitations w.r.t. input sizes.
+ constexpr auto input_a_is_half =
+ std::is_same<ElementInputA, cutlass::half_t>::value;
+ TORCH_CHECK(length_m % 32 == 0,
+ "torch._structured_sparse_linear: Number of rows of sparse matrix must "
+ "be divisible by 32");
+ TORCH_CHECK(length_k % (input_a_is_half ? 64 : 128) == 0,
+ "torch._structured_sparse_linear: Number of rows of dense matrix must "
+ "be divisible by ", (input_a_is_half ? 64 : 128));
+ TORCH_CHECK(length_n % (input_a_is_half ? 8 : 16) == 0,
+ "torch._structured_sparse_linear: Number of columns of dense matrix "
+ "must be divisible by ", (input_a_is_half ? 8 : 16));
+
+ // Determine PyTorch datatype for the output matrix.
+ auto tensor_d_dtype = at::kChar;
+ if (std::is_same<ElementOutput, int32_t>::value) {
+ tensor_d_dtype = at::kInt;
+ }
+ else if (std::is_same<ElementOutput, cutlass::half_t>::value) {
+ tensor_d_dtype = at::kHalf;
+ }
+ else {
+ AT_ERROR("torch._structured_sparse_linear: invalid sparse GEMM output "
+ "datatype encountered");
+ }
+
+ // Create output matrix.
+ auto tensor_d =
+ tensor_a.new_empty({length_m, length_n},
+ at::TensorOptions().dtype(tensor_d_dtype));
+
+ // If a mask matrix is passed as an argument, create the metadata
+ // matrix. CUTLASS requires the metadata matrix in a shuffled order,
+ // so perform the reordering in that case too.
+ Tensor meta_reordered;
+ if (mask_or_meta.dtype() == at::kBool) {
+ auto mask = mask_or_meta;
+
+ // Check dimensions and format of the mask matrix.
+ TORCH_CHECK(mask.layout() == Layout::Strided,
+ "torch._structured_sparse_linear: Expected mask argument to be "
+ "strided, but got layout ", mask.layout());
+ TORCH_CHECK(mask.dim() == 2,
+ "torch._structured_sparse_linear: Expected mask argument to be 2D "
+ "tensor, got ", mask.dim(), " dims");
+ const auto strides_mask = mask.strides();
+ TORCH_CHECK(strides_mask[1] == 1,
+ "torch._structured_sparse_linear: Invalid strides for mask_or_meta "
+ "argument: row stride = ", strides_mask[0], ", column stride = ",
+ strides_mask[1]);
+
+ // Determine PyTorch datatype for the metadata matrix.
+ auto meta_dtype = at::kChar;
+ switch (sizeof(ElementInputE)) {
+ case 1:
+ break;
+ case 2:
+ meta_dtype = at::kShort;
+ break;
+ case 4:
+ meta_dtype = at::kInt;
+ break;
+ default:
+ AT_ERROR("torch._structured_sparse_linear: invalid size of meta "
+ "tensor datatype encountered");
+ }
+
+ // Create a tensor for the metadata matrix, and run a CUDA kernel
+ // to build this matrix from the mask matrix.
+ auto meta = mask.new_empty({length_m, meta_ncols},
+ at::TensorOptions().dtype(meta_dtype));
+ auto error = mask.new_zeros({1}, at::TensorOptions().dtype(at::kInt));
+ two_four_create_meta_kernel<<<
+ dim3((length_k + 63) / 64, (length_m + 15) / 16),
+ dim3(16, 16),
+ 0,
+ at::cuda::getCurrentCUDAStream()>>> (
+ length_m, length_k, strides_mask[0], (bool*)mask.data_ptr(),
+ meta.stride(0), (ElementInputE*)meta.data_ptr(),
+ (int*)error.data_ptr());
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ TORCH_CHECK(error.item().equal(0),
+ "torch._structured_sparse_linear: Mask matrix is not 2:4 "
+ "sparse");
+
+ // Create a tensor for the reordered metadata matrix, and run a CUDA
+ // kernel to build it from the metadata matrix calculated above.
+ meta_reordered = meta.new_empty(meta.sizes());
+ auto meta_device_ref =
+ cutlass::TensorRef<ElementInputE, cutlass::layout::RowMajor>(
+ (ElementInputE*)meta.data_ptr(),
+ LayoutInputE::packed({length_m, meta_ncols}));
+ auto meta_reordered_device_ref =
+ cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
+ (ElementInputE*)meta_reordered.data_ptr(),
+ ReorderedLayoutInputE::packed({length_m, meta_ncols}));
+ two_four_reorder_meta_kernel<<<
+ dim3((meta_ncols + 15) / 16, (length_m + 15) / 16),
+ dim3(16, 16),
+ 0,
+ at::cuda::getCurrentCUDAStream()>>> (
+ length_m, meta_ncols, meta_device_ref,
+ meta_reordered_device_ref);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }
+ else {
+ meta_reordered = mask_or_meta;
+ }
+
+ // Prepare arguments for CUTLASS sparse GEMM kernel.
+ cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
+ LayoutInputA layout_a(tensor_a_stride);
+ LayoutInputB layout_b(tensor_b_stride);
+ LayoutOutput layout_d(tensor_d.stride(0));
+ auto tensor_a_device_ref =
+ cutlass::TensorRef<ElementInputA, LayoutInputA>(
+ (ElementInputA*)tensor_a.data_ptr(), layout_a);
+ auto tensor_b_device_ref =
+ cutlass::TensorRef<ElementInputB, LayoutInputB>(
+ (ElementInputB*)tensor_b.data_ptr(), layout_b);
+ auto tensor_d_device_ref =
+ cutlass::TensorRef<ElementOutput, LayoutOutput>(
+ (ElementOutput*)tensor_d.data_ptr(), layout_d);
+ auto tensor_e_reordered_device_ref =
+ cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
+ (ElementInputE*)meta_reordered.data_ptr(),
+ ReorderedLayoutInputE::packed({length_m, meta_ncols}));
+ ElementComputeEpilogue alpha(1);
+ ElementComputeEpilogue beta(0);
+ constexpr int split_k_slices = 1;
+
+ // Create a tuple of CUTLASS sparse GEMM kernel arguments.
+ typename Gemm::Arguments arguments{
+ problem_size,
+ tensor_a_device_ref,
+ tensor_b_device_ref,
+ tensor_d_device_ref,
+ tensor_d_device_ref,
+ tensor_e_reordered_device_ref,
+ {alpha, beta},
+ split_k_slices};
+
+ cutlass::Status status;
+
+ // Create CUTLASS sparse GEMM kernel object.
+ Gemm gemm_op;
+
+ // Verify that sparse GEMM operation with given arguments can be
+ // performed by CUTLASS.
+ status = gemm_op.can_implement(arguments);
+ CUTLASS_STATUS_CHECK(status);
+
+ // Allocate workspace for CUTLASS sparse GEMM kernel.
+ const auto workspace_size = Gemm::get_workspace_size(arguments);
+ auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
+ at::TensorOptions().dtype(at::kByte));
+
+ // Initialize CUTLASS sparse GEMM object.
+ status = gemm_op.initialize(arguments, workspace.data_ptr(),
+ at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ // Perform sparse GEMM operation.
+ status = gemm_op.run(at::cuda::getCurrentCUDAStream());
+ CUTLASS_STATUS_CHECK(status);
+
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+ return std::make_tuple(tensor_d, meta_reordered);
+}
+#endif
+
+// Perform a GEMM operation between a matrix with 2:4 sparsity pattern
+// and a dense matrix, using the corresponding CUTLASS sparse GEMM
+// kernel. The sparse matrix is given as argument "tensor_a" and the
+// dense matrix is given as argument "tensor_b". It is assumed that
+// the matrices are supplied either in row-major or column-major
+// format (the matrices could be in different formats, but not all
+// combinations of formats are supported for some datatypes of these
+// matrices). The "mask_or_meta" argument contains either a mask
+// matrix, corresponding to the original dense matrix with 2:4
+// sparsity pattern from which the sparse matrix "tensor_a" is
+// compressed, or the corresponding metadata matrix. The function
+// differentiates between these two cases by the datatype of the
+// "mask_or_meta" tensor: if it is of boolean datatype, it is assumed
+// that the mask matrix is passed, otherwise it is assumed that the
+// metadata matrix is passed. In the first case, the metadata matrix
+// is calculated from the mask matrix. The function returns a tuple
+// with the product of the "tensor_a" and "tensor_b" matrices, and the
+// metadata matrix (either calculated by this function, in case a mask
+// is passed as the "mask_or_meta" argument, or the same one that is
+// passed to this function otherwise).
+//
+// There are numerous limitations of the CUTLASS sparse GEMM kernel
+// with regard to the sizes and alignments of input tensors, their
+// layouts and datatypes, and so on; this is the reason for the large
+// number of checks throughout the code.
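+//
+// For example (an illustrative shape sketch only): for Half inputs, a
+// 32 x 64 row-major "tensor_a", compressed from a 32 x 128 dense
+// matrix, a 128 x 32 row-major "tensor_b" and a 32 x 128 boolean
+// "mask_or_meta" satisfy the size and layout requirements checked
+// below; the first returned tensor is then the 32 x 32 product.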
+std::tuple<Tensor, Tensor> _structured_sparse_linear(
+ const Tensor& tensor_a, const Tensor& tensor_b,
+ const Tensor& mask_or_meta) {
+#ifndef USE_ROCM
+ // No need to check that all tensors are on CUDA device, as this
+ // is provided by dispatch.
+
+ // For now, only CC 8.x devices are supported.
+ const auto dprops = at::cuda::getCurrentDeviceProperties();
+ const auto is_sm8x = dprops->major == 8;
+ TORCH_CHECK(is_sm8x,
+ "torch._structured_sparse_linear: Supported only on GPUs with "
+ "compute capability 8.x");
+
+ // Validate layouts of input tensors.
+ TORCH_CHECK(tensor_a.layout() == Layout::Strided,
+ "torch._structured_sparse_linear: Expected tensor_a argument "
+ "to be strided, but got layout ", tensor_a.layout());
+ TORCH_CHECK(tensor_a.dim() == 2,
+ "torch._structured_sparse_linear: Expected tensor_a argument "
+ "to be 2D tensor, got ", tensor_a.dim(), " dims");
+ const auto strides_a = tensor_a.strides();
+ TORCH_CHECK((strides_a[0] == 1 || strides_a[1] == 1) && strides_a[0] != strides_a[1],
+ "torch._structured_sparse_linear: Invalid strides for tensor_a "
+ "argument: row stride = ", strides_a[0], ", column stride = ",
+ strides_a[1]);
+ TORCH_CHECK(tensor_b.layout() == Layout::Strided,
+ "torch._structured_sparse_linear: Expected tensor_b argument "
+ "to be strided, but got layout ", tensor_b.layout());
+ TORCH_CHECK(tensor_b.dim() == 2,
+ "torch._structured_sparse_linear: Expected tensor_b argument "
+ "to be 2D tensor, got ", tensor_b.dim(), " dims");
+ const auto strides_b = tensor_b.strides();
+ TORCH_CHECK((strides_b[0] == 1 || strides_b[1] == 1) && strides_b[0] != strides_b[1],
+ "torch._structured_sparse_linear: Invalid strides for tensor_b "
+ "argument: row stride = ", strides_b[0], ", column stride = ",
+ strides_b[1]);
+
+ // Determine layout (row-major or column-major) of input tensors.
+ auto tensor_a_row_major = strides_a[1] == 1;
+ auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
+ auto tensor_b_row_major = strides_b[1] == 1;
+ auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
+
+ // Call the wrapper function for CUTLASS sparse GEMM, dispatching on
+ // the input datatype, and then on the layouts of the input tensors.
+ // According to the datatypes and layouts of the input tensors, the
+ // corresponding template arguments are supplied for instantiating
+ // the wrapper function. The tile size template arguments are
+ // selected based on CUTLASS profiler results over a number of runs.
+ std::tuple<Tensor, Tensor> result;
+ AT_DISPATCH_SWITCH(
+ tensor_a.scalar_type(),
+ "_structured_sparse_linear",
+ AT_DISPATCH_CASE(
+ at::ScalarType::Char,
+ [&]() {
+ using ElementInputA = int8_t;
+ using ElementInputB = int8_t;
+ using ElementOutput = int32_t;
+ using ElementAccumulator = int32_t;
+ using ElementComputeEpilogue = int32_t;
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
+ ElementOutput,
+ 128 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementComputeEpilogue>;
+ if (tensor_a_row_major && !tensor_b_row_major) {
+ result = two_four_sgemm_cutlass<
+ ElementInputA,
+ ElementInputB,
+ ElementOutput,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ cutlass::layout::RowMajor,
+ cutlass::layout::ColumnMajor>(
+ tensor_a,
+ tensor_a_stride,
+ tensor_b,
+ tensor_b_stride,
+ mask_or_meta);
+ return;
+ }
+ AT_ERROR("torch._structured_sparse_linear: Combination of "
+ "tensor_a in ",
+ tensor_a_row_major ? "row-major" : "column-major",
+ " layout and tensor_b in ",
+ tensor_b_row_major ? "row-major" : "column-major",
+ " layout is not supported");
+ })
+ AT_DISPATCH_CASE(
+ at::ScalarType::Half,
+ [&]() {
+ using ElementInputA = cutlass::half_t;
+ using ElementInputB = cutlass::half_t;
+ using ElementOutput = cutlass::half_t;
+ using ElementAccumulator = float;
+ using ElementComputeEpilogue = cutlass::half_t;
+ using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
+ ElementOutput,
+ 128 / cutlass::sizeof_bits<ElementOutput>::value,
+ ElementAccumulator,
+ ElementComputeEpilogue>;
+ if (tensor_a_row_major && tensor_b_row_major) {
+ result = two_four_sgemm_cutlass<
+ ElementInputA,
+ ElementInputB,
+ ElementOutput,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ cutlass::layout::RowMajor,
+ cutlass::layout::RowMajor>(
+ tensor_a,
+ tensor_a_stride,
+ tensor_b,
+ tensor_b_stride,
+ mask_or_meta);
+ return;
+ }
+ if (tensor_a_row_major && !tensor_b_row_major) {
+ result = two_four_sgemm_cutlass<
+ ElementInputA,
+ ElementInputB,
+ ElementOutput,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ cutlass::layout::RowMajor,
+ cutlass::layout::ColumnMajor>(
+ tensor_a,
+ tensor_a_stride,
+ tensor_b,
+ tensor_b_stride,
+ mask_or_meta);
+ return;
+ }
+ if (!tensor_a_row_major && tensor_b_row_major) {
+ result = two_four_sgemm_cutlass<
+ ElementInputA,
+ ElementInputB,
+ ElementOutput,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ cutlass::layout::ColumnMajor,
+ cutlass::layout::RowMajor>(
+ tensor_a,
+ tensor_a_stride,
+ tensor_b,
+ tensor_b_stride,
+ mask_or_meta);
+ return;
+ }
+ if (!tensor_a_row_major && !tensor_b_row_major) {
+ result = two_four_sgemm_cutlass<
+ ElementInputA,
+ ElementInputB,
+ ElementOutput,
+ ElementAccumulator,
+ ElementComputeEpilogue,
+ ThreadblockShape,
+ WarpShape,
+ InstructionShape,
+ EpilogueOp,
+ cutlass::layout::ColumnMajor,
+ cutlass::layout::ColumnMajor>(
+ tensor_a,
+ tensor_a_stride,
+ tensor_b,
+ tensor_b_stride,
+ mask_or_meta);
+ return;
+ }
+ }));
+ return result;
+#else
+ AT_ERROR("torch._structured_sparse_linear: ROCm doesn't support CUTLASS");
+ return std::make_tuple(Tensor{}, Tensor{});
+#endif
+}
+
+} // namespace native
+} // namespace at
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 5aed10c..778e435 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -488,6 +488,7 @@
aten::_standard_gamma.out
aten::_standard_gamma_grad
aten::_standard_gamma_grad.out
+aten::_structured_sparse_linear
aten::_test_autograd_multiple_dispatch.fullcoverage
aten::_test_autograd_multiple_dispatch.fullcoverage_out
aten::_test_autograd_multiple_dispatch_view
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 7eacc4b..3652d58 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4804,6 +4804,61 @@
self.assertEqual(result, dense)
+ @onlyCUDA
+ @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
+ @dtypes(torch.int8, torch.half)
+ def test_structured_sparse_linear(self, device, dtype):
+ def make_tensor(shape, dtype):
+ if dtype.is_complex:
+ return torch.zeros(shape, dtype=dtype)
+ elif dtype.is_floating_point:
+ return torch.randn(shape, dtype=dtype) / 10
+ else:
+ return torch.randint(-5, 5, shape, dtype=dtype)
+
+ def random_mask_choice(i=None):
+ choices = [
+ [1, 1, 0, 0],
+ [1, 0, 1, 0],
+ [1, 0, 0, 1],
+ [0, 1, 1, 0],
+ [0, 1, 0, 1],
+ [0, 0, 1, 1]
+ ]
+ if i is None:
+ i = random.randint(0, len(choices) - 1)
+ return choices[i]
+
+ def run_test(m, n, k, device, dtype):
+ a = make_tensor((m, k), dtype).to(device)
+ b = make_tensor((n, k), dtype).to(device).T
+
+ for meta_choice in (list(range(6)) + [None]):
+ mask_entries = [random_mask_choice(meta_choice) for i in range(m * (k // 4))]
+ mask = torch.tensor(mask_entries, dtype=torch.bool).view(m, k).to(device)
+
+ a_sparse = a.masked_select(mask).view(m, k // 2)
+ a_dense = a.masked_fill(~mask, 0)
+
+ dtype_dense = torch.float
+ c1 = torch.mm(a_dense.to(dtype_dense), b.to(dtype_dense))
+
+ c0, meta = torch._structured_sparse_linear(a_sparse, b, mask)
+ torch.testing.assert_close(c0.to(dtype_dense), c1, rtol=1e-3, atol=1e-3)
+
+ c0, _ = torch._structured_sparse_linear(a_sparse, b, meta)
+ torch.testing.assert_close(c0.to(dtype_dense), c1, rtol=1e-3, atol=1e-3)
+
+ is_sm8x = torch.cuda.get_device_capability(0)[0] == 8
+ if not is_sm8x:
+ return
+ for (m, n, k) in itertools.product(range(4), range(4), range(4)):
+ m = (m + 1) * 32
+ n = (n + 1) * 32
+ k = (k + 1) * 128
+ run_test(m, n, k, device, dtype)
+
+
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')