#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
// Required for checking whether Triton kernels are available
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <ATen/core/dispatch/Dispatcher.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/_triton_bsr_dense_mm.h>
#endif
namespace at {
namespace native {
namespace sparse {
namespace impl {
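/*
Computes a sparse compressed matrix-strided matrix product defined as
result <- compressed @ strided
Args:
* `compressed` - CSR or BSR Tensor storing a 2D sparse m x k matrix.
* `strided` - strided Tensor of shape (..., k, n).
* `result` - [in] strided Tensor of shape (..., m, n), assumed to be properly resized.
              [out] result of the operation.
*/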
Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
const auto compressed_layout = compressed.layout();
const auto compressed_layout_str = at::sparse_csr::layoutToString(compressed_layout);
// Device restrictions
TORCH_CHECK(compressed.device() == strided.device()
&& compressed.device() == result.device(),
"spmm_out(): all input arguments are expected to be on the same device.");
// Layout restrictions.
TORCH_CHECK(compressed_layout == kSparseCsr || compressed_layout == kSparseBsr,
"spmm(", compressed_layout_str, ", Strided): only Csr and Bsr formats are supported for the sparse argument.");
TORCH_CHECK(result.layout() == kStrided,
"spmm_out(): out argument is expected to be strided.");
// Dtype restrictions.
TORCH_CHECK(compressed.scalar_type() == strided.scalar_type(),
"spmm(", compressed_layout_str, ", Strided): arguments expected to have the same dtype.");
// Dim restrictions.
TORCH_CHECK(compressed.dim() == 2,
"spmm(", compressed_layout_str, ", Strided): sparse arguments which are not 2D are not supported.");
TORCH_CHECK(strided.dim() >= 2,
"spmm(", compressed_layout_str, ", Strided): expects strided inputs to be at least 2D.");
const auto m = compressed.sizes()[0];
const auto k = compressed.sizes()[1];
const auto n = strided.size(-1);
// Matrix product size compatibility.
TORCH_CHECK(strided.size(-2) == k,
"spmm(", compressed_layout_str, "Strided): argument sizes are not compatible for matrix multiplication. ",
"Got ", compressed_layout_str, ".sizes(-1) == ", k, " is not equal to ",
"Strided.sizes(-2) == ", strided.size(-2), ".");
// We assume that result is properly resized.
auto result_expected_size = at::DimVector(strided.sizes().slice(0, strided.dim() - 2));
result_expected_size.push_back(m);
result_expected_size.push_back(n);
TORCH_CHECK(result.sizes() == result_expected_size,
"spmm_out(): out argument has wrong size. ",
"Expected (", result_expected_size, ") but got (", result.sizes(), ").");
auto values = compressed.values();
using Blocksize = std::array<int64_t, 2>;
// We refer to these as (b0, b1) in the comments below.
Blocksize blocksize = {1, 1};
if (compressed_layout == kSparseBsr) {
blocksize = {values.size(-2), values.size(-1)};
}
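// For CSR the default (1, 1) blocksize is kept: every stored value is
// treated as a 1 x 1 block, so the block-wise code below covers both layouts.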
// No stable support for ROCM in Triton yet.
#ifndef USE_ROCM
// Triton works only with blocksizes which are powers of 2.
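// v & (v - 1) clears the lowest set bit, so the expression is zero exactly
// for powers of two (v == 0 would also pass, but block sizes are at least 1).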
const auto is_power_of_2 = [](int64_t v) -> bool {
return !(v & (v - 1));
};
// Dtype and blocksize checks for potential Triton usage.
if ((strided.scalar_type() == ScalarType::Half
|| strided.scalar_type() == ScalarType::BFloat16)
&& is_power_of_2(blocksize[0]) && is_power_of_2(blocksize[1])
&& (blocksize[0] >= 16) && (blocksize[1] >= 16)
// lhs is retiled to (b0, b1) while rhs is retiled to (b1, b0),
// so the result is tiled to (b0, b0) and we need to make
// sure that strided.size(-1) is divisible by b0.
&& n % blocksize[0] == 0) {
const auto triton_kernel = c10::Dispatcher::singleton()
.findOp(torch::jit::parseName("aten::_triton_bsr_dense_mm"));
// Call the Triton kernel only if one is registered for the SparseCsrCUDA dispatch key.
if (triton_kernel->hasKernelForDispatchKey(c10::DispatchKey::SparseCsrCUDA)) {
return at::_triton_bsr_dense_mm_out(result, compressed, strided);
}
}
#endif
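// Fallback: implement the block-sparse matmul with dense batched matrix
// products over tiles of the strided input and of the result.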
// (..., r, c) -> (..., r / b0, c / b1, b0, b1)
// NOTE: this function ALWAYS creates a view upon successful execution.
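// For example, with blocksize (2, 3) a (4, 9) tensor is reshaped to
// (2, 2, 3, 3) and then transposed to (2, 3, 2, 3), i.e. a 2 x 3 grid of 2 x 3 blocks.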
const auto tile_tensor = [compressed_layout](
const Tensor& t, Blocksize blocksize) -> Tensor {
if (compressed_layout == kSparseCsr) {
return t.unsqueeze(-1).unsqueeze_(-1);
}
else {
const auto size_neg_2_blocked = t.size(-2) / blocksize[0];
const auto size_neg_1_blocked = t.size(-1) / blocksize[1];
auto tiled_sizes = at::DimVector(t.sizes().slice(0, t.dim() - 2));
tiled_sizes.push_back(size_neg_2_blocked);
tiled_sizes.push_back(blocksize[0]);
tiled_sizes.push_back(size_neg_1_blocked);
tiled_sizes.push_back(blocksize[1]);
return t.reshape(tiled_sizes).transpose(-3, -2);
}
};
// Note that sparse values are (..., b0, b1). This means that
// the strided input has to be "tilable" to (..., b1, x) with
// any x >= 1 such that all the shapes are (block) matrix product
// compatible. The matrix product will then have shape (..., b0, x).
// This in turn means that the result has to be "tilable" to
// (..., b0, x).
//
// These observations imply the following restrictions:
// 1. strided.size(-2) has to be divisible by b1.
// 2. result.size(-2) has to be divisible by b0.
// 3. both strided.size(-1) and result.size(-1)
// have to be divisible by x.
//
// Restrictions 1 and 2 are trivially satisfied.
// Regarding restriction 3:
// it makes sense to take the largest possible x for better
// performance, since it is very likely that the last dimension
// is contiguous. That largest value is exactly
// x = strided.size(-1): because strided.size(-1) == result.size(-1) == n,
// restriction 3 is then trivially satisfied.
const auto outer_blocksize = n;
Blocksize strided_blocksize = {blocksize[1], outer_blocksize};
const auto strided_tiled = tile_tensor(strided, strided_blocksize);
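// strided_tiled has shape (..., k / b1, 1, b1, n): a single column of (b1, n) tiles.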
// Left argument is (..., b0, b1) and right is (..., b1, x).
// This naturally implies the result should be "tilable" as
// (..., b0, x).
Blocksize result_blocksize = {blocksize[0], outer_blocksize};
auto result_tiled = tile_tensor(result, result_blocksize);
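// result_tiled has shape (..., m / b0, 1, b0, n).
// For CSR, reshape values from (nnz,) to (nnz, 1, 1) below so that each value
// acts as a 1 x 1 block and the block-wise code applies unchanged.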
if (compressed_layout == kSparseCsr) {
values.unsqueeze_(-1).unsqueeze_(-1);
}
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(compressed);
// Select the block rows of the strided input that intersect with the block columns of the sparse input.
auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices);
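// strided_tiled_selected_rows has shape (..., nnz, 1, b1, n): one tile of the
// strided input per stored block of the sparse input.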
// Promote to float if output is half or bfloat16 for better precision
const auto mm_dtype = (result.scalar_type() == kHalf || result.scalar_type() == kBFloat16)
? kFloat : result.scalar_type();
// Now that we know which block rows intersect with which block columns,
// we can perform matrix products between pairs of blocks.
// NOTE: .to is a no-op when result.scalar_type() == mm_dtype.
const auto pairwise_block_mm = values.unsqueeze(-3).to(mm_dtype)
.matmul(strided_tiled_selected_rows.to(mm_dtype));
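// values.unsqueeze(-3) has shape (nnz, 1, b0, b1), so pairwise_block_mm has
// shape (..., nnz, 1, b0, n): one (b0, n) product per stored block.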
// Having pairwise block matrix products stored in pairwise_block_mm,
// it is sufficient to sum all the block products that share the same row
// encoded in the sparse index. Since the reduction step is done via
// advanced indexing methods, the compressed index ought to get converted
// to the COO format.
const auto compressed_indices_coo = at::_convert_indices_from_csr_to_coo(
compressed_indices,
plain_indices,
compressed_indices.scalar_type() == kInt).select(0, 0);
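// compressed_indices_coo holds the (block) row index of every stored block;
// row i of result_tiled accumulates all pairwise products whose blocks live in row i.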
// Reduction step.
// If result is neither half nor bfloat16, do everything in-place.
if (result.scalar_type() == mm_dtype) {
// Zero out and sum over the blocks that share the same row indices.
result_tiled.zero_();
result_tiled.index_add_(
/*dim=*/-4,
/*index=*/compressed_indices_coo,
/*source=*/pairwise_block_mm);
}
// Otherwise accumulate into a buffer and then copy.
else {
// No need to zero out, sum over the blocks goes into a buffer
// followed by a copy into result.
auto promoted_result_tiled = at::zeros(
result_tiled.sizes(),
result_tiled.options().dtype(mm_dtype));
promoted_result_tiled.index_add_(
/*dim=*/-4,
/*index=*/compressed_indices_coo,
/*source=*/pairwise_block_mm);
result_tiled.copy_(promoted_result_tiled);
}
return result;
}
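/*
Computes an addmm with a sparse compressed matrix argument defined as
result <- beta*self + alpha*(mat1 @ mat2)
Args:
* `self` - strided Tensor; may alias `result`.
* `mat1` - CSR or BSR Tensor storing a 2D sparse matrix.
* `mat2` - strided Tensor.
* `result` - [in] strided Tensor, assumed to be properly resized.
              [out] result of the operation.
*/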
Tensor& _compressed_row_strided_addmm_out(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
// If result is not the same as self, it could always be used as out argument to mm.
if (!result.is_same(self)) {
_compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
// Process beta
if (beta.toComplexDouble() != 0.) {
result.add_(self.mul(beta));
}
}
// Otherwise we need to allocate external memory for mm if beta != 0.
else {
// Process beta
if (beta.toComplexDouble() != 0.) {
result.mul_(beta);
auto mm = at::empty_like(result);
_compressed_row_strided_mm_out(mat1, mat2, mm);
mm.mul_(alpha);
result.add_(mm);
}
else {
_compressed_row_strided_mm_out(mat1, mat2, result).mul_(alpha);
}
}
return result;
}
namespace cpu {
/*
Computes a sparse matrix-dense vector product defined as
y <- alpha*op(A)*x + beta*y
Args:
* `mat` - Tensor storing sparse m x n matrix A.
* `vec` - Tensor storing dense vector x of size n.
* `result` - [in] Tensor storing dense vector y of size m.
[out] result of the operation.
*/
void addmv_out_sparse_csr(
const Tensor& mat,
const Tensor& vec,
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling addmv on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
sparse::impl::mkl::addmv_out_sparse_csr(mat, vec, beta, alpha, result);
#endif
}
/*
Computes a sum of two sparse matrices defined as
result <- mat1 + alpha*mat2
Args:
* `mat1` - CSR Tensor storing sparse m x n matrix.
* `mat2` - CSR Tensor storing sparse m x n matrix.
* `result` - [in] CSR Tensor storing sparse m x n matrix.
[out] result of the operation.
*/
void add_out_sparse_csr(
const Tensor& mat1,
const Tensor& mat2,
const Scalar& alpha,
const Tensor& result) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling add on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
sparse::impl::mkl::add_out_sparse_csr(mat1, mat2, alpha, result);
#endif
}
void triangular_solve_out_sparse_csr(
const Tensor& A,
const Tensor& B,
const Tensor& X,
bool upper,
bool transpose,
bool unitriangular) {
#if !AT_MKL_ENABLED()
TORCH_CHECK(
false,
"Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. ",
"Please use PyTorch built MKL support.");
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(A.layout() == kSparseCsr || A.layout() == kSparseBsr);
sparse::impl::mkl::triangular_solve_out_sparse_csr(A, B, X, upper, transpose, unitriangular);
#endif
}
} // namespace cpu
} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at