| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/native/CPUBlas.h> |
| #include <ATen/native/mkl/LinearAlgebra.h> |
| #include <ATen/native/mkldnn/Matmul.h> |
| #include <ATen/Config.h> |
| |
| #include <c10/util/SmallBuffer.h> |
| #include <c10/util/C++17.h> |
| #include <c10/util/irange.h> |
| |
| #include <climits> |
| |
| #if AT_BUILD_WITH_BLAS() |
| #if C10_IOS |
| #include <Accelerate/Accelerate.h> |
| #else |
| extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc); |
| extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc); |
| extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); |
| extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); |
| #ifdef BLAS_HAS_SBGEMM |
| extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k, |
| float *alpha, |
| const at::BFloat16 *a, int *lda, |
| const at::BFloat16 *b, int *ldb, |
| float *beta, |
| float *c, int *ldc); |
| #endif // BLAS_HAS_SBGEMM |
| extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy); |
| extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy); |
| extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy); |
| extern "C" void zcopy_(int *n, const void *x, int *incx, void *y, int *incy); |
| extern "C" void ccopy_(int *n, const void *x, int *incx, void *y, int *incy); |
| extern "C" void daxpy_(int *n, double *a, const double *x, int *incx, double *y, int *incy); |
| extern "C" void saxpy_(int *n, float *a, const float *x, int *incx, float *y, int *incy); |
| extern "C" void caxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy); |
| extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy); |
| #endif // C10_IOS |
| #endif // AT_BUILD_WITH_BLAS |
| |
| #ifdef USE_FBGEMM |
| #include <fbgemm/FbgemmI64.h> |
| #endif // USE_FBGEMM |
| |
| namespace at { |
| namespace native { |
| namespace cpublas { |
| namespace internal { |
| |
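| // When the trailing dimension of a matrix is 1, its leading dimension is never |
| // actually used but BLAS still validates it; promote such degenerate leading |
| // dimensions to the BLAS minimum so use_blas_gemm() does not reject them. |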
| void normalize_last_dims( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| int64_t *lda, int64_t *ldb, int64_t *ldc) { |
| if (n == 1) { |
| *ldc = m; |
| } |
| |
| if(transa != TransposeType::NoTranspose) { |
| if (m == 1) { |
| *lda = k; |
| } |
| } else if(k == 1) { |
| *lda = m; |
| } |
| |
| if(transb != TransposeType::NoTranspose) { |
| if (k == 1) { |
| *ldb = n; |
| } |
| } else if (n == 1) { |
| *ldb = k; |
| } |
| } |
| } // namespace internal |
| |
| namespace { |
| |
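| // The external BLAS path is only valid when every size and leading dimension |
| // fits in a 32-bit int and each leading dimension is at least the number of |
| // rows of the corresponding matrix as stored (column-major). |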
| bool use_blas_gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| int64_t lda, int64_t ldb, int64_t ldc) { |
| const bool transa_ = transa != TransposeType::NoTranspose; |
| const bool transb_ = transb != TransposeType::NoTranspose; |
| return ( |
| (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && |
| (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) && |
| (lda >= std::max(int64_t{1}, (transa_ ? k : m))) && |
| (ldb >= std::max(int64_t{1}, (transb_ ? n : k))) && |
| (ldc >= std::max(int64_t{1}, m))); |
| } |
| |
| #ifdef USE_FBGEMM |
| fbgemm::matrix_op_t to_fbgemm(TransposeType trans) { |
| switch (trans) { |
| case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose; |
| case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose; |
| case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm"); |
| } |
| TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); |
| } |
| #endif // USE_FBGEMM |
| |
| #if (AT_BUILD_WITH_BLAS() && C10_IOS) |
| CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) { |
| switch (trans) { |
| case TransposeType::Transpose: return CblasTrans; |
| case TransposeType::NoTranspose: return CblasNoTrans; |
| case TransposeType::ConjTranspose: return CblasConjTrans; |
| } |
| TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); |
| } |
| #endif |
| |
| } // namespace (anonymous) |
| |
| DEFINE_DISPATCH(gemm_stub); |
| |
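| // gemm computes C = alpha * op(A) * op(B) + beta * C with column-major (FORTRAN) |
| // storage, dispatching to an external BLAS routine when available and applicable, |
| // and to the ATen gemm_stub kernel otherwise. |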
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const double alpha, |
| const double *a, int64_t lda, |
| const double *b, int64_t ldb, |
| const double beta, |
| double *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| double alpha_ = alpha, beta_ = beta; |
| #if C10_IOS |
| CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); |
| CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); |
| cblas_dgemm(CblasColMajor, |
| transa_, transb_, |
| m_, n_, k_, |
| alpha_, |
| a, lda_, |
| b, ldb_, |
| beta_, |
| c, ldc_); |
| #else |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| dgemm_( |
| &transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| c, &ldc_); |
| #endif |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kDouble, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const float alpha, |
| const float *a, int64_t lda, |
| const float *b, int64_t ldb, |
| const float beta, |
| float *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| float alpha_ = alpha, beta_ = beta; |
| #if C10_IOS |
| CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); |
| CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); |
| cblas_sgemm(CblasColMajor, |
| transa_, transb_, |
| m_, n_, k_, |
| alpha_, |
| a, lda_, |
| b, ldb_, |
| beta_, |
| c, ldc_); |
| #else |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| sgemm_( |
| &transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| c, &ldc_); |
| #endif |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kFloat, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const c10::complex<double> alpha, |
| const c10::complex<double> *a, int64_t lda, |
| const c10::complex<double> *b, int64_t ldb, |
| const c10::complex<double> beta, |
| c10::complex<double> *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| c10::complex<double> alpha_ = alpha, beta_ = beta; |
| #if C10_IOS |
| CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); |
| CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); |
| cblas_zgemm(CblasColMajor, |
| transa_, transb_, |
| m_, n_, k_, |
| &alpha_, |
| a, lda_, |
| b, ldb_, |
| &beta_, |
| c, ldc_); |
| #else |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| zgemm_( |
| &transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| c, &ldc_); |
| #endif |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kComplexDouble, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const c10::complex<float> alpha, |
| const c10::complex<float> *a, int64_t lda, |
| const c10::complex<float> *b, int64_t ldb, |
| const c10::complex<float> beta, |
| c10::complex<float> *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| c10::complex<float> alpha_ = alpha, beta_ = beta; |
| #if C10_IOS |
| CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa); |
| CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb); |
| cblas_cgemm(CblasColMajor, |
| transa_, transb_, |
| m_, n_, k_, |
| &alpha_, |
| a, lda_, |
| b, ldb_, |
| &beta_, |
| c, ldc_); |
| #else |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| cgemm_( |
| &transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| c, &ldc_); |
| #endif |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kComplexFloat, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
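| // bfloat16 gemm: prefer OpenBLAS sbgemm when built with it, then oneDNN (mkldnn), |
| // and finally the generic gemm_stub kernel. |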
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const float alpha, |
| const at::BFloat16 *a, int64_t lda, |
| const at::BFloat16 *b, int64_t ldb, |
| const float beta, |
| at::BFloat16 *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM) |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| float alpha_ = alpha, beta_ = beta; |
| int64_t c_size = static_cast<int64_t>(n_) * ldc_; |
| // The C matrix in OpenBLAS sbgemm is of type "float", so we have to convert, copy and copy back. |
| std::vector<float> float_v(c, c + c_size); |
| sbgemm_(&transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| float_v.data(), &ldc_); |
| for (auto cv: float_v) { |
| *(c++) = c10::convert<at::BFloat16>(cv); |
| } |
| return; |
| } |
| #endif |
| #if AT_MKLDNN_ENABLED() |
| if (mkldnn_bf16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kBFloat16, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
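| // float16 gemm: no reference BLAS routine is used here, so try oneDNN (mkldnn) |
| // first and fall back to gemm_stub. |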
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const float alpha, |
| const at::Half *a, int64_t lda, |
| const at::Half *b, int64_t ldb, |
| const float beta, |
| at::Half *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_MKLDNN_ENABLED() |
| if (mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { |
| return; |
| } |
| #endif |
| gemm_stub( |
| at::kCPU, at::kHalf, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
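| // Mixed-precision gemm: bfloat16 inputs with a float32 output. Prefer OpenBLAS |
| // sbgemm, then MKL's bf16bf16f32 kernel; otherwise compute in bfloat16 with |
| // beta = 0 and accumulate beta * c in float precision. |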
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const float alpha, |
| const at::BFloat16 *a, int64_t lda, |
| const at::BFloat16 *b, int64_t ldb, |
| const float beta, |
| float *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM) |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| char transa_ = to_blas(transa), transb_ = to_blas(transb); |
| float alpha_ = alpha, beta_ = beta; |
| sbgemm_(&transa_, &transb_, |
| &m_, &n_, &k_, |
| &alpha_, |
| a, &lda_, |
| b, &ldb_, |
| &beta_, |
| c, &ldc_); |
| return; |
| } |
| #endif |
| #ifdef MKL_HAS_SBGEMM |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; |
| mkl_gemm_bf16bf16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_); |
| return; |
| } |
| #endif |
| // For the fallback path, first compute the gemm into a bfloat16 buffer with beta = 0, |
| // then accumulate beta * c in full float precision. |
| int64_t c_size = n * m; |
| std::vector<at::BFloat16> bfloat_c(c_size, 0.f); |
| gemm_stub( |
| at::kCPU, at::kBFloat16, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m); |
| for (const auto j : c10::irange(n)) { |
| for (const auto i : c10::irange(m)) { |
| auto offset = j * ldc + i; |
| // When beta == 0, do not read c: it may hold NaN or uninitialized values that must not propagate. |
| if (beta == 0.f) { |
| c[offset] = c10::convert<float>(bfloat_c[j * m + i]); |
| } else { |
| c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]); |
| } |
| } |
| } |
| } |
| |
| void gemm( |
| TransposeType transa, TransposeType transb, |
| int64_t m, int64_t n, int64_t k, |
| const int64_t alpha, |
| const int64_t *a, int64_t lda, |
| const int64_t *b, int64_t ldb, |
| const int64_t beta, |
| int64_t *c, int64_t ldc) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| #ifdef USE_FBGEMM |
| if (alpha == 1 && (beta == 0 || beta == 1)) { |
| // FBGEMM assumes row-major ordering, whereas this function follows the |
| // column-major (FORTRAN) convention of the BLAS interface: transa_ and |
| // transb_ let us choose the layout of A and B, but the layout of C cannot |
| // be changed through this FORTRAN-style interface. |
| // |
| // The workaround is to compute |
| // C^T (n x m) = B^T (n x k) * A^T (k x m) instead, |
| // so that C^T can be treated as row-major when passed to FBGEMM. |
| fbgemm::cblas_gemm_i64_i64acc( |
| to_fbgemm(transb), |
| to_fbgemm(transa), |
| n, |
| m, |
| k, |
| b, |
| ldb, |
| a, |
| lda, |
| beta == 1, |
| c, |
| ldc); |
| return; |
| } |
| #endif |
| |
| gemm_stub( |
| kCPU, kLong, |
| transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
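| // MKL's batched gemm takes the batch count as an int, so very large batches are |
| // processed in chunks of at most INT_MAX entries. |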
| template <typename scalar_t> |
| static void gemm_batched_mkl_impl( |
| TransposeType transa, TransposeType transb, |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, |
| scalar_t alpha, |
| const scalar_t **a, int64_t lda, |
| const scalar_t **b, int64_t ldb, |
| scalar_t beta, |
| scalar_t **c, int64_t ldc) { |
| for (int64_t i = 0; i < batch_size;) { |
| int sub_batch = std::min(batch_size - i, int64_t{INT_MAX}); |
| mkl_gemm_batched(transa, transb, sub_batch, m, n, k, alpha, |
| &a[i], lda, &b[i], ldb, beta, &c[i], ldc); |
| i += sub_batch; |
| } |
| } |
| |
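| // Scalar types for which an external BLAS/MKL routine exists; all other types |
| // go through the generic per-batch fallback. |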
| template <typename scalar_t> |
| using is_blas_library_type = std::integral_constant<bool, |
| std::is_same<scalar_t, double>::value || |
| std::is_same<scalar_t, float>::value || |
| std::is_same<scalar_t, c10::complex<double>>::value || |
| std::is_same<scalar_t, c10::complex<float>>::value>; |
| |
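| // Fallback: perform one gemm() call per batch entry. |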
| template <typename scalar_t> |
| void gemm_batched_generic( |
| TransposeType transa, TransposeType transb, |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, |
| scalar_t alpha, |
| const scalar_t **a, int64_t lda, |
| const scalar_t **b, int64_t ldb, |
| scalar_t beta, |
| scalar_t **c, int64_t ldc) { |
| for (const auto batch : c10::irange(batch_size)) { |
| gemm(transa, transb, m, n, k, alpha, a[batch], lda, b[batch], ldb, beta, c[batch], ldc); |
| } |
| } |
| |
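| // Batched gemm over arrays of matrix pointers. Uses MKL's batched kernel for |
| // BLAS-supported types when the arguments fit its constraints, and the generic |
| // per-batch loop otherwise. |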
| template <typename scalar_t> |
| void gemm_batched( |
| TransposeType transa, TransposeType transb, |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, |
| scalar_t alpha, |
| const scalar_t **a, int64_t lda, |
| const scalar_t **b, int64_t ldb, |
| scalar_t beta, |
| scalar_t **c, int64_t ldc) { |
| if (batch_size == 1) { |
| return gemm(transa, transb, m, n, k, alpha, a[0], lda, b[0], ldb, beta, c[0], ldc); |
| } |
| |
| if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| gemm_batched_mkl_impl( |
| transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } else { |
| gemm_batched_generic( |
| transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| } else { |
| gemm_batched_generic( |
| transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| } |
| |
| template <typename scalar_t> |
| void gemm_batched_with_stride_generic( |
| TransposeType transa, TransposeType transb, |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, |
| scalar_t alpha, |
| const scalar_t *a, int64_t lda, int64_t batch_stride_a, |
| const scalar_t *b, int64_t ldb, int64_t batch_stride_b, |
| scalar_t beta, |
| scalar_t *c, int64_t ldc, int64_t batch_stride_c) { |
| for (const auto batch : c10::irange(batch_size)) { |
| const auto a_batch = a + batch_stride_a * batch; |
| const auto b_batch = b + batch_stride_b * batch; |
| const auto c_batch = c + batch_stride_c * batch; |
| gemm(transa, transb, m, n, k, alpha, a_batch, lda, b_batch, ldb, beta, c_batch, ldc); |
| } |
| } |
| |
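| // Batched gemm where consecutive matrices are separated by a fixed stride. For |
| // the MKL path the per-batch pointers are materialized into small buffers; |
| // otherwise each batch entry is handled by a separate gemm() call. |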
| template <typename scalar_t> |
| void gemm_batched_with_stride( |
| TransposeType transa, TransposeType transb, |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, |
| scalar_t alpha, |
| const scalar_t *a, int64_t lda, int64_t batch_stride_a, |
| const scalar_t *b, int64_t ldb, int64_t batch_stride_b, |
| scalar_t beta, |
| scalar_t *c, int64_t ldc, int64_t batch_stride_c) { |
| if (batch_size == 1) { |
| return gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); |
| } |
| |
| if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) { |
| internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); |
| if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) { |
| c10::SmallBuffer<const scalar_t*, 16> a_ptrs(batch_size); |
| c10::SmallBuffer<const scalar_t*, 16> b_ptrs(batch_size); |
| c10::SmallBuffer<scalar_t*, 16> c_ptrs(batch_size); |
| |
| for (const auto batch : c10::irange(batch_size)) { |
| a_ptrs[batch] = a + batch_stride_a * batch; |
| b_ptrs[batch] = b + batch_stride_b * batch; |
| c_ptrs[batch] = c + batch_stride_c * batch; |
| } |
| gemm_batched_mkl_impl( |
| transa, transb, batch_size, m, n, k, alpha, a_ptrs.data(), lda, |
| b_ptrs.data(), ldb, beta, c_ptrs.data(), ldc); |
| } else { |
| gemm_batched_with_stride_generic( |
| transa, transb, batch_size, m, n, k, alpha, a, lda, batch_stride_a, |
| b, ldb, batch_stride_b, beta, c, ldc, batch_stride_c); |
| } |
| } else { |
| gemm_batched_with_stride_generic(transa, transb, batch_size, m, n, k, alpha, |
| a, lda, batch_stride_a, b, ldb, batch_stride_b, |
| beta, c, ldc, batch_stride_c); |
| } |
| } |
| |
| #define INSTANTIATE_BATCHED_GEMM(scalar_t, DType) \ |
| template void gemm_batched( \ |
| TransposeType transa, TransposeType transb, \ |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, \ |
| scalar_t alpha, \ |
| const scalar_t **a, int64_t lda, \ |
| const scalar_t **b, int64_t ldb, \ |
| scalar_t beta, \ |
| scalar_t **c, int64_t ldc); \ |
| template void gemm_batched_with_stride( \ |
| TransposeType transa, TransposeType transb, \ |
| int64_t batch_size, int64_t m, int64_t n, int64_t k, \ |
| scalar_t alpha, \ |
| const scalar_t *a, int64_t lda, int64_t batch_stride_a, \ |
| const scalar_t *b, int64_t ldb, int64_t batch_stride_b, \ |
| scalar_t beta, \ |
| scalar_t *c, int64_t ldc, int64_t batch_stride_c); |
| |
| AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(INSTANTIATE_BATCHED_GEMM) |
| |
| DEFINE_DISPATCH(axpy_stub); |
| |
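| // axpy computes y := a * x + y over n elements with strides incx / incy, using |
| // the BLAS routine when the sizes fit in a 32-bit int and the axpy_stub kernel |
| // otherwise. |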
| void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) |
| { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_daxpy(i_n, a, x, i_incx, y, i_incy); |
| #else |
| daxpy_(&i_n, &a, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| axpy_stub( |
| kCPU, at::kDouble, |
| n, a, x, incx, y, incy); |
| } |
| |
| void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) |
| { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_saxpy(i_n, a, x, i_incx, y, i_incy); |
| #else |
| saxpy_(&i_n, &a, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| axpy_stub( |
| kCPU, at::kFloat, |
| n, a, x, incx, y, incy); |
| } |
| |
| void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) |
| { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy); |
| #else |
| zaxpy_(&i_n, &a, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| axpy_stub( |
| kCPU, at::kComplexDouble, |
| n, a, x, incx, y, incy); |
| } |
| |
| void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) |
| { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_caxpy(i_n, &a, x, i_incx, y, i_incy); |
| #else |
| caxpy_(&i_n, &a, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| axpy_stub( |
| kCPU, at::kComplexFloat, |
| n, a, x, incx, y, incy); |
| } |
| |
| DEFINE_DISPATCH(copy_stub); |
| |
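| // copy performs y := x over n elements with strides incx / incy, using the BLAS |
| // routine when the sizes fit in a 32-bit int and copy_stub otherwise. |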
| void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_dcopy(i_n, x, i_incx, y, i_incy); |
| #else |
| dcopy_(&i_n, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| copy_stub( |
| kCPU, at::kDouble, |
| n, x, incx, y, incy); |
| } |
| |
| void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_scopy(i_n, x, i_incx, y, i_incy); |
| #else |
| scopy_(&i_n, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| copy_stub( |
| kCPU, at::kFloat, |
| n, x, incx, y, incy); |
| } |
| |
| void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_zcopy(i_n, x, i_incx, y, i_incy); |
| #else |
| zcopy_(&i_n, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| copy_stub( |
| kCPU, at::kComplexDouble, |
| n, x, incx, y, incy); |
| } |
| |
| void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) { |
| if(n == 1) |
| { |
| incx = 1; |
| incy = 1; |
| } |
| #if AT_BUILD_WITH_BLAS() |
| if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { |
| int i_n = (int)n; |
| int i_incx = (int)incx; |
| int i_incy = (int)incy; |
| #if C10_IOS |
| cblas_ccopy(i_n, x, i_incx, y, i_incy); |
| #else |
| ccopy_(&i_n, x, &i_incx, y, &i_incy); |
| #endif |
| return; |
| } |
| #endif |
| copy_stub( |
| kCPU, at::kComplexFloat, |
| n, x, incx, y, incy); |
| } |
| |
| }}} // namespace at::native::cpublas |