#include "ATen/Context.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"
#include "ATen/PinnedMemoryAllocator.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
#include "ATen/native/LinearAlgebraUtils.h"
#include "ATen/native/Gesv.h"
#include "THC.h" // for USE_MAGMA
#ifdef USE_MAGMA
#include <magma.h>
#include <magma_types.h>
#endif
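// Batched solver for systems of linear equations (gesv) on CUDA, implemented
// on top of MAGMA's batched LU-based solver. If ATen was built without MAGMA,
// the solver raises an error at runtime instead.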
namespace at {
namespace native {
#ifdef USE_MAGMA
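// Dispatch helper: the generic template rejects unsupported dtypes, and the
// float/double specializations below forward to MAGMA's batched solvers
// (magma_sgesv_batched / magma_dgesv_batched).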
template<class scalar_t>
void magmaGesvBatched(
    magma_int_t n, magma_int_t nrhs, scalar_t** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, scalar_t** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  AT_ERROR("gesv only takes float or double Tensors");
}
template<>
void magmaGesvBatched<float>(
    magma_int_t n, magma_int_t nrhs, float** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, float** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_sgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}
template<>
void magmaGesvBatched<double>(
    magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
    magma_int_t** dipiv_array, double** dB_array, magma_int_t lddb,
    magma_int_t* dinfo_array, magma_int_t batch_count, magma_queue_t queue) {
  magma_dgesv_batched(
      n, nrhs, dA_array, ldda, dipiv_array,
      dB_array, lddb, dinfo_array, batch_count, queue);
}
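// A MAGMA queue wraps an existing CUDA stream together with cuBLAS and
// cuSPARSE handles; here they are taken from ATen's THC state so MAGMA work
// runs on the same stream as the rest of the tensor's operations.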
static magma_queue_t createMagmaQueue(const Tensor& tensor) {
  auto& context = tensor.type().get_context();
  magma_queue_t magma_queue;
  magma_queue_create_from_cuda(
      tensor.get_device(),
      context.getCurrentCUDAStream(),
      THCState_getCurrentBlasHandle(context.thc_state),
      THCState_getCurrentSparseHandle(context.thc_state),
      &magma_queue);
  return magma_queue;
}
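// MAGMA sizes are magma_int_t (typically 32-bit); refuse to silently truncate
// 64-bit ATen sizes that do not fit.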
static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  if (static_cast<int64_t>(result) != value) {
    AT_ERROR("magma: The value of %s (%lld) is too large to fit into a magma_int_t (%llu bytes)",
             varname, (long long)value, sizeof(magma_int_t));
  }
  return result;
}
#endif
// Creates an array of size elements of type T, backed by pinned memory
// wrapped in a Storage
template<class T>
static inline std::unique_ptr<Storage> pin_memory(int64_t size, Tensor dummy) {
  int64_t adjusted_size = size * sizeof(T);
  auto allocator = std::unique_ptr<Allocator>(new PinnedMemoryAllocator());
  auto& backend = dummy.type().toBackend(kCPU).toScalarType(kByte);
  return backend.storageWithAllocator(adjusted_size, std::move(allocator));
}
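// Declares `name` as a pointer into freshly allocated pinned-memory-backed
// storage. The Storage itself (storage_##name) stays alive as a local in the
// enclosing scope, so the pointer remains valid for the duration of the solve.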
#define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
  auto storage_##name = pin_memory<type>(size, dummy_tensor); \
  name = reinterpret_cast<type*>(storage_##name->data());
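// Solves the batch of linear systems A_i x_i = b_i in place with MAGMA's
// batched LU solver: on return b holds the solutions, A holds the LU factors,
// and infos receives one LAPACK-style status code per batch element.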
template <typename scalar_t>
static void applyGesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
#ifndef USE_MAGMA
AT_ERROR("gesv: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
  auto A_data = A.data<scalar_t>();
  auto b_data = b.data<scalar_t>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);

  magma_int_t batch_size = magma_int_cast(batchCount(A), "batchCount");
  magma_int_t n = magma_int_cast(A.size(-2), "A.size(-2)");
  magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");

  magma_int_t* info_array;
  magma_int_t* ipiv_data;
  magma_int_t** ipiv_array;
  scalar_t** A_array;
  scalar_t** b_array;

  ALLOCATE_ARRAY(info_array, magma_int_t, batch_size, b);
  ALLOCATE_ARRAY(ipiv_data, magma_int_t, batch_size * n, b);
  ALLOCATE_ARRAY(ipiv_array, magma_int_t*, batch_size, b);
  ALLOCATE_ARRAY(A_array, scalar_t*, batch_size, b);
  ALLOCATE_ARRAY(b_array, scalar_t*, batch_size, b);
  // Set up the created arrays
  for (int64_t i = 0; i < batch_size; i++) {
    A_array[i] = &A_data[i * A_mat_stride];
    b_array[i] = &b_data[i * b_mat_stride];
    ipiv_array[i] = &ipiv_data[i * n];
  }
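  // The working copies are square, column-major batches (see
  // cloneBatchedColumnMajor below), so both leading dimensions are n.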
  magmaGesvBatched<scalar_t>(
      n, nrhs, A_array, n, ipiv_array, b_array, n,
      info_array, batch_size, createMagmaQueue(b));

  for (int64_t i = 0; i < batch_size; i++) {
    infos[i] = info_array[i];
  }
#endif
}
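// CUDA backend entry point for gesv (typically reached through at::gesv /
// Tensor::gesv). Both inputs are cloned into batched column-major layout,
// since MAGMA expects Fortran-ordered matrices; the clones are mutated by
// applyGesv and returned as (solution, LU factors).
//
// Example (hypothetical caller, assuming the usual dispatch from at::gesv):
//   std::tie(solution, lu) = at::gesv(b, A);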
std::tuple<Tensor,Tensor> _gesv_helper_cuda(const Tensor& self, const Tensor& A) {
  std::vector<int64_t> infos(batchCount(A), 0);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto b_working_copy = cloneBatchedColumnMajor(self);
  AT_DISPATCH_FLOATING_TYPES(self.type(), "gesv", [&]{
    applyGesv<scalar_t>(b_working_copy, A_working_copy, infos);
  });
  checkErrors(infos);
  return std::tuple<Tensor,Tensor>(b_working_copy, A_working_copy);
}
}} // namespace at::native
#undef ALLOCATE_ARRAY