blob: 97d89cf50798c4b417bd6df3be49b14557917ea7 [file] [log] [blame]
#pragma once
/*
Provides a subset of MKL Sparse BLAS functions as templates:
mv<scalar_t>(operation, alpha, A, descr, x, beta, y)
where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
The functions are available in at::mkl::sparse namespace.
*/
#include <c10/util/Exception.h>
#include <c10/util/complex.h>
#include <mkl_spblas.h>
namespace at::mkl::sparse {
#define MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t) \
sparse_matrix_t *A, const sparse_index_base_t indexing, const MKL_INT rows, \
const MKL_INT cols, MKL_INT *rows_start, MKL_INT *rows_end, \
MKL_INT *col_indx, scalar_t *values
template <typename scalar_t>
inline void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::create_csr: not implemented for ",
typeid(scalar_t).name());
}
template <>
void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float));
template <>
void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double));
template <>
void create_csr<c10::complex<float>>(
MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>));
template <>
void create_csr<c10::complex<double>>(
MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t) \
sparse_matrix_t *A, const sparse_index_base_t indexing, \
const sparse_layout_t block_layout, const MKL_INT rows, \
const MKL_INT cols, MKL_INT block_size, MKL_INT *rows_start, \
MKL_INT *rows_end, MKL_INT *col_indx, scalar_t *values
template <typename scalar_t>
inline void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::create_bsr: not implemented for ",
typeid(scalar_t).name());
}
template <>
void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float));
template <>
void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double));
template <>
void create_bsr<c10::complex<float>>(
MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>));
template <>
void create_bsr<c10::complex<double>>(
MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_MV_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const scalar_t alpha, \
const sparse_matrix_t A, const struct matrix_descr descr, \
const scalar_t *x, const scalar_t beta, scalar_t *y
template <typename scalar_t>
inline void mv(MKL_SPARSE_MV_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::mv: not implemented for ",
typeid(scalar_t).name());
}
template <>
void mv<float>(MKL_SPARSE_MV_ARGTYPES(float));
template <>
void mv<double>(MKL_SPARSE_MV_ARGTYPES(double));
template <>
void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>));
template <>
void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_ADD_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const sparse_matrix_t A, \
const scalar_t alpha, const sparse_matrix_t B, sparse_matrix_t *C
template <typename scalar_t>
inline void add(MKL_SPARSE_ADD_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::add: not implemented for ",
typeid(scalar_t).name());
}
template <>
void add<float>(MKL_SPARSE_ADD_ARGTYPES(float));
template <>
void add<double>(MKL_SPARSE_ADD_ARGTYPES(double));
template <>
void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>));
template <>
void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t) \
const sparse_matrix_t source, sparse_index_base_t *indexing, MKL_INT *rows, \
MKL_INT *cols, MKL_INT **rows_start, MKL_INT **rows_end, \
MKL_INT **col_indx, scalar_t **values
template <typename scalar_t>
inline void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::export_csr: not implemented for ",
typeid(scalar_t).name());
}
template <>
void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float));
template <>
void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double));
template <>
void export_csr<c10::complex<float>>(
MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>));
template <>
void export_csr<c10::complex<double>>(
MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_MM_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const scalar_t alpha, \
const sparse_matrix_t A, const struct matrix_descr descr, \
const sparse_layout_t layout, const scalar_t *B, const MKL_INT columns, \
const MKL_INT ldb, const scalar_t beta, scalar_t *C, const MKL_INT ldc
template <typename scalar_t>
inline void mm(MKL_SPARSE_MM_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::mm: not implemented for ",
typeid(scalar_t).name());
}
template <>
void mm<float>(MKL_SPARSE_MM_ARGTYPES(float));
template <>
void mm<double>(MKL_SPARSE_MM_ARGTYPES(double));
template <>
void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>));
template <>
void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_SPMMD_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const sparse_matrix_t A, \
const sparse_matrix_t B, const sparse_layout_t layout, scalar_t *C, \
const MKL_INT ldc
template <typename scalar_t>
inline void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::spmmd: not implemented for ",
typeid(scalar_t).name());
}
template <>
void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float));
template <>
void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double));
template <>
void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>));
template <>
void spmmd<c10::complex<double>>(
MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_TRSV_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const scalar_t alpha, \
const sparse_matrix_t A, const struct matrix_descr descr, \
const scalar_t *x, scalar_t *y
template <typename scalar_t>
inline sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::trsv: not implemented for ",
typeid(scalar_t).name());
}
template <>
sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float));
template <>
sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double));
template <>
sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>));
template <>
sparse_status_t trsv<c10::complex<double>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>));
#define MKL_SPARSE_TRSM_ARGTYPES(scalar_t) \
const sparse_operation_t operation, const scalar_t alpha, \
const sparse_matrix_t A, const struct matrix_descr descr, \
const sparse_layout_t layout, const scalar_t *x, const MKL_INT columns, \
const MKL_INT ldx, scalar_t *y, const MKL_INT ldy
template <typename scalar_t>
inline sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(scalar_t)) {
TORCH_INTERNAL_ASSERT(
false,
"at::mkl::sparse::trsm: not implemented for ",
typeid(scalar_t).name());
}
template <>
sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float));
template <>
sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double));
template <>
sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>));
template <>
sparse_status_t trsm<c10::complex<double>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>));
} // namespace at::mkl::sparse