| #pragma once |
| |
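/*
  Provides a subset of MKL Sparse BLAS functions as templates:

    create_csr<scalar_t>(A, indexing, rows, cols, rows_start, rows_end,
                         col_indx, values)
    create_bsr<scalar_t>(A, indexing, block_layout, rows, cols, block_size,
                         rows_start, rows_end, col_indx, values)
    mv<scalar_t>(operation, alpha, A, descr, x, beta, y)
    add<scalar_t>(operation, A, alpha, B, C)
    export_csr<scalar_t>(source, indexing, rows, cols, rows_start, rows_end,
                         col_indx, values)
    mm<scalar_t>(operation, alpha, A, descr, layout, B, columns, ldb, beta,
                 C, ldc)
    spmmd<scalar_t>(operation, A, B, layout, C, ldc)
    trsv<scalar_t>(operation, alpha, A, descr, x, y)
    trsm<scalar_t>(operation, alpha, A, descr, layout, x, columns, ldx, y, ldy)

  where scalar_t is double, float, c10::complex<double> or c10::complex<float>.
  Instantiating any of them with another scalar_t compiles, but fails with an
  internal assertion at runtime. trsv and trsm return a sparse_status_t; the
  other functions return void.
  The functions are available in the at::mkl::sparse namespace.
*/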
| |
#include <typeinfo>

#include <c10/util/Exception.h>
#include <c10/util/complex.h>

#include <mkl_spblas.h>
| |
| namespace at::mkl::sparse { |
| |
| #define MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t) \ |
| sparse_matrix_t *A, const sparse_index_base_t indexing, const MKL_INT rows, \ |
| const MKL_INT cols, MKL_INT *rows_start, MKL_INT *rows_end, \ |
| MKL_INT *col_indx, scalar_t *values |
| |
| template <typename scalar_t> |
| inline void create_csr(MKL_SPARSE_CREATE_CSR_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::create_csr: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void create_csr<float>(MKL_SPARSE_CREATE_CSR_ARGTYPES(float)); |
| template <> |
| void create_csr<double>(MKL_SPARSE_CREATE_CSR_ARGTYPES(double)); |
| template <> |
| void create_csr<c10::complex<float>>( |
| MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<float>)); |
| template <> |
| void create_csr<c10::complex<double>>( |
| MKL_SPARSE_CREATE_CSR_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t) \ |
| sparse_matrix_t *A, const sparse_index_base_t indexing, \ |
| const sparse_layout_t block_layout, const MKL_INT rows, \ |
| const MKL_INT cols, MKL_INT block_size, MKL_INT *rows_start, \ |
| MKL_INT *rows_end, MKL_INT *col_indx, scalar_t *values |
| |
| template <typename scalar_t> |
| inline void create_bsr(MKL_SPARSE_CREATE_BSR_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::create_bsr: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void create_bsr<float>(MKL_SPARSE_CREATE_BSR_ARGTYPES(float)); |
| template <> |
| void create_bsr<double>(MKL_SPARSE_CREATE_BSR_ARGTYPES(double)); |
| template <> |
| void create_bsr<c10::complex<float>>( |
| MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<float>)); |
| template <> |
| void create_bsr<c10::complex<double>>( |
| MKL_SPARSE_CREATE_BSR_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_MV_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const scalar_t alpha, \ |
| const sparse_matrix_t A, const struct matrix_descr descr, \ |
| const scalar_t *x, const scalar_t beta, scalar_t *y |
| |
| template <typename scalar_t> |
| inline void mv(MKL_SPARSE_MV_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::mv: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void mv<float>(MKL_SPARSE_MV_ARGTYPES(float)); |
| template <> |
| void mv<double>(MKL_SPARSE_MV_ARGTYPES(double)); |
| template <> |
| void mv<c10::complex<float>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<float>)); |
| template <> |
| void mv<c10::complex<double>>(MKL_SPARSE_MV_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_ADD_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const sparse_matrix_t A, \ |
| const scalar_t alpha, const sparse_matrix_t B, sparse_matrix_t *C |
| |
| template <typename scalar_t> |
| inline void add(MKL_SPARSE_ADD_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::add: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void add<float>(MKL_SPARSE_ADD_ARGTYPES(float)); |
| template <> |
| void add<double>(MKL_SPARSE_ADD_ARGTYPES(double)); |
| template <> |
| void add<c10::complex<float>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<float>)); |
| template <> |
| void add<c10::complex<double>>(MKL_SPARSE_ADD_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t) \ |
| const sparse_matrix_t source, sparse_index_base_t *indexing, MKL_INT *rows, \ |
| MKL_INT *cols, MKL_INT **rows_start, MKL_INT **rows_end, \ |
| MKL_INT **col_indx, scalar_t **values |
| |
| template <typename scalar_t> |
| inline void export_csr(MKL_SPARSE_EXPORT_CSR_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::export_csr: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void export_csr<float>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(float)); |
| template <> |
| void export_csr<double>(MKL_SPARSE_EXPORT_CSR_ARGTYPES(double)); |
| template <> |
| void export_csr<c10::complex<float>>( |
| MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<float>)); |
| template <> |
| void export_csr<c10::complex<double>>( |
| MKL_SPARSE_EXPORT_CSR_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_MM_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const scalar_t alpha, \ |
| const sparse_matrix_t A, const struct matrix_descr descr, \ |
| const sparse_layout_t layout, const scalar_t *B, const MKL_INT columns, \ |
| const MKL_INT ldb, const scalar_t beta, scalar_t *C, const MKL_INT ldc |
| |
| template <typename scalar_t> |
| inline void mm(MKL_SPARSE_MM_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::mm: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void mm<float>(MKL_SPARSE_MM_ARGTYPES(float)); |
| template <> |
| void mm<double>(MKL_SPARSE_MM_ARGTYPES(double)); |
| template <> |
| void mm<c10::complex<float>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<float>)); |
| template <> |
| void mm<c10::complex<double>>(MKL_SPARSE_MM_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_SPMMD_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const sparse_matrix_t A, \ |
| const sparse_matrix_t B, const sparse_layout_t layout, scalar_t *C, \ |
| const MKL_INT ldc |
| |
| template <typename scalar_t> |
| inline void spmmd(MKL_SPARSE_SPMMD_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::spmmd: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| void spmmd<float>(MKL_SPARSE_SPMMD_ARGTYPES(float)); |
| template <> |
| void spmmd<double>(MKL_SPARSE_SPMMD_ARGTYPES(double)); |
| template <> |
| void spmmd<c10::complex<float>>(MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<float>)); |
| template <> |
| void spmmd<c10::complex<double>>( |
| MKL_SPARSE_SPMMD_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_TRSV_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const scalar_t alpha, \ |
| const sparse_matrix_t A, const struct matrix_descr descr, \ |
| const scalar_t *x, scalar_t *y |
| |
| template <typename scalar_t> |
| inline sparse_status_t trsv(MKL_SPARSE_TRSV_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::trsv: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| sparse_status_t trsv<float>(MKL_SPARSE_TRSV_ARGTYPES(float)); |
| template <> |
| sparse_status_t trsv<double>(MKL_SPARSE_TRSV_ARGTYPES(double)); |
| template <> |
| sparse_status_t trsv<c10::complex<float>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<float>)); |
| template <> |
| sparse_status_t trsv<c10::complex<double>>(MKL_SPARSE_TRSV_ARGTYPES(c10::complex<double>)); |
| |
| #define MKL_SPARSE_TRSM_ARGTYPES(scalar_t) \ |
| const sparse_operation_t operation, const scalar_t alpha, \ |
| const sparse_matrix_t A, const struct matrix_descr descr, \ |
| const sparse_layout_t layout, const scalar_t *x, const MKL_INT columns, \ |
| const MKL_INT ldx, scalar_t *y, const MKL_INT ldy |
| |
| template <typename scalar_t> |
| inline sparse_status_t trsm(MKL_SPARSE_TRSM_ARGTYPES(scalar_t)) { |
| TORCH_INTERNAL_ASSERT( |
| false, |
| "at::mkl::sparse::trsm: not implemented for ", |
| typeid(scalar_t).name()); |
| } |
| |
| template <> |
| sparse_status_t trsm<float>(MKL_SPARSE_TRSM_ARGTYPES(float)); |
| template <> |
| sparse_status_t trsm<double>(MKL_SPARSE_TRSM_ARGTYPES(double)); |
| template <> |
| sparse_status_t trsm<c10::complex<float>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<float>)); |
| template <> |
| sparse_status_t trsm<c10::complex<double>>(MKL_SPARSE_TRSM_ARGTYPES(c10::complex<double>)); |
| |
| } // namespace at::mkl::sparse |