blob: 8830fe732fdcdb8a3bbd5d437e90378a89e4a586 [file] [log] [blame]
#include <ATen/Context.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDASolver.h>
#include <c10/cuda/CUDACachingAllocator.h>
#ifdef CUDART_VERSION
namespace at {
namespace cuda {
namespace solver {
// LU-factorizes the m x n double matrix dA (leading dimension ldda) in
// place via cuSOLVER. First asks cuSOLVER how large a scratch workspace
// the factorization needs, allocates it from the CUDA caching allocator,
// then runs the factorization. Per the cuSOLVER API, ipiv receives the
// pivot indices and info the status code (both device pointers).
// The workspace DataPtr is released at scope exit; the caching allocator
// is assumed to make that safe w.r.t. the asynchronous cuSOLVER call.
template <>
void getrf<double>(
    cusolverDnHandle_t handle, int m, int n, double* dA, int ldda, int* ipiv, int* info) {
  int workspace_elems = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnDgetrf_bufferSize(
      handle, m, n, dA, ldda, &workspace_elems));
  auto workspace = ::c10::cuda::CUDACachingAllocator::get()->allocate(
      sizeof(double) * workspace_elems);
  TORCH_CUSOLVER_CHECK(cusolverDnDgetrf(
      handle, m, n, dA, ldda, static_cast<double*>(workspace.get()), ipiv, info));
}
// Single-precision counterpart of getrf<double>: LU-factorizes the m x n
// float matrix dA (leading dimension ldda) in place via cuSOLVER.
// Queries the required scratch size, allocates it from the CUDA caching
// allocator, then factorizes. Per the cuSOLVER API, ipiv receives the
// pivot indices and info the status code (both device pointers).
template <>
void getrf<float>(
    cusolverDnHandle_t handle, int m, int n, float* dA, int ldda, int* ipiv, int* info) {
  int workspace_elems = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnSgetrf_bufferSize(
      handle, m, n, dA, ldda, &workspace_elems));
  auto workspace = ::c10::cuda::CUDACachingAllocator::get()->allocate(
      sizeof(float) * workspace_elems);
  TORCH_CUSOLVER_CHECK(cusolverDnSgetrf(
      handle, m, n, dA, ldda, static_cast<float*>(workspace.get()), ipiv, info));
}
// Solves A * X = B in double precision using a factorization previously
// produced by getrf<double>: dA/lda hold the LU factors, ipiv the pivots,
// and ret/ldb holds B on entry and the n x nrhs solution X on exit.
// info receives the cuSOLVER status code.
template <>
void getrs<double>(
    cusolverDnHandle_t handle, int n, int nrhs, double* dA, int lda, int* ipiv, double* ret, int ldb, int* info) {
  // Solve with A exactly as factorized (no transpose).
  constexpr cublasOperation_t trans = CUBLAS_OP_N;
  TORCH_CUSOLVER_CHECK(
      cusolverDnDgetrs(handle, trans, n, nrhs, dA, lda, ipiv, ret, ldb, info));
}
// Single-precision counterpart of getrs<double>: solves A * X = B using
// the LU factors in dA/lda and pivots in ipiv; ret/ldb holds B on entry
// and the n x nrhs solution X on exit. info receives the cuSOLVER status.
template <>
void getrs<float>(
    cusolverDnHandle_t handle, int n, int nrhs, float* dA, int lda, int* ipiv, float* ret, int ldb, int* info) {
  // Solve with A exactly as factorized (no transpose).
  constexpr cublasOperation_t trans = CUBLAS_OP_N;
  TORCH_CUSOLVER_CHECK(
      cusolverDnSgetrs(handle, trans, n, nrhs, dA, lda, ipiv, ret, ldb, info));
}
} // namespace solver
} // namespace cuda
} // namespace at
#endif // CUDART_VERSION