Update internal code for at::_lu_with_info (#56612)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56612
The goal of this refactoring is to make `torch.linalg.solve`
a composition of calls to `lu_stub` and `lu_solve_stub`.
Once `lu_stub` and `lu_solve_stub` have a cuSOLVER-based codepath,
`torch.linalg.solve` will have it as well.
Replaced `lu_with_info_{cpu, cuda}` with a single function that calls
`lu_stub`.
Split MAGMA-based `apply_lu` into `apply_lu_looped_magma`
and `apply_lu_batched_magma`. This simplifies the future switch to
cuSOLVER and cuBLAS libraries.
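To illustrate the composition this refactoring targets, here is a minimal standalone sketch (not part of this patch) that reproduces a solve by chaining the `at::_lu_with_info` and `at::lu_solve` entry points, which now route through `lu_stub` and `lu_solve_stub`. The `main` harness, tensor shapes, and the `allclose` check are illustrative assumptions, not code from this PR:

```cpp
#include <ATen/ATen.h>
#include <tuple>

int main() {
  // Illustrative 3x3 system A x = b with two right-hand sides
  at::Tensor A = at::randn({3, 3}, at::kDouble);
  at::Tensor b = at::randn({3, 2}, at::kDouble);

  // LU factorization; with this patch both CPU and CUDA route through lu_stub
  at::Tensor lu, pivots, infos;
  std::tie(lu, pivots, infos) =
      at::_lu_with_info(A, /*pivot=*/true, /*check_errors=*/true);

  // Back-substitution with the LU factors routes through lu_solve_stub
  at::Tensor x = at::lu_solve(b, lu, pivots);

  // The two-step path should match a direct solve of A x = b
  TORCH_CHECK(at::allclose(A.matmul(x), b));
  return 0;
}
```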
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D28248756
Pulled By: mruberry
fbshipit-source-id: 40e02b5be4ff5f78885bcc95685aba581043e096
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 25c9f25..5901c97 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -212,9 +212,6 @@
void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
template<class scalar_t>
-void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
-
-template<class scalar_t>
void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info);
template<class scalar_t>
@@ -1429,31 +1426,9 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-template<typename scalar_t>
-static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
-#ifndef USE_LAPACK
- AT_ERROR("lu: LAPACK library not found in compilation");
-#else
- auto self_data = self.data_ptr<scalar_t>();
- auto pivots_data = pivots.data_ptr<int>();
- auto infos_data = infos.data_ptr<int>();
- auto self_matrix_stride = matrixStride(self);
- auto pivots_matrix_stride = pivots.size(-1);
- auto batch_size = batchCount(self);
- auto m = self.size(-2);
- auto n = self.size(-1);
+DEFINE_DISPATCH(lu_stub);
- for (const auto i : c10::irange(batch_size)) {
- scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
- int* infos_working_ptr = &infos_data[i];
- lapackLu<scalar_t>(m, n, self_working_ptr, m, pivots_working_ptr, infos_working_ptr);
- }
-#endif
-}
-
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
- TORCH_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
TORCH_CHECK(self.dim() >= 2,
"expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
@@ -1466,15 +1441,11 @@
req_size.pop_back();
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
- Tensor self_working_copy;
- if (self.numel() == 0) {
- self_working_copy = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- } else {
- self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_cpu", [&]{
- apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
- });
- }
+ // lu_stub (apply_lu) requires batched column major (Fortran-contiguous) tensors
+ // 'lu' tensor is modified in-place and must be a copy of 'self'
+ Tensor lu = cloneBatchedColumnMajor(self);
+ lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
+
if (check_errors) {
if (self.dim() > 2) {
batchCheckErrors(infos_tensor, "lu", /*allow_singular=*/true);
@@ -1482,7 +1453,7 @@
singleCheckErrors(infos_tensor.item<int64_t>(), "lu", /*allow_singular=*/true);
}
}
- return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
+ return std::make_tuple(lu, pivots_tensor, infos_tensor);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h
index ec2ae8b..2ceb3ac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.h
+++ b/aten/src/ATen/native/BatchLinearAlgebra.h
@@ -162,6 +162,9 @@
template <class scalar_t>
void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
+template <class scalar_t>
+void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
+
#endif
using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
@@ -216,6 +219,13 @@
bool /*unitriangular*/);
DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
+using lu_fn = void (*)(
+ const Tensor& /*input*/,
+ const Tensor& /*pivots*/,
+ const Tensor& /*infos*/,
+ bool /*compute_pivots*/);
+DECLARE_DISPATCH(lu_fn, lu_stub);
+
using lu_solve_fn = void (*)(
const Tensor& /*b*/,
const Tensor& /*lu*/,
diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
index 1c6a64d..a858099 100644
--- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -847,6 +847,55 @@
}
/*
+  Computes the LU decomposition of an m×n matrix or batch of matrices in the 'input' tensor.
+ This is an in-place routine, content of 'input', 'pivots', and 'infos' is overwritten.
+
+ Args:
+ * `input` - [in] the input matrix for LU decomposition
+ [out] the LU decomposition
+ * `pivots` - [out] the pivot indices
+ * `infos` - [out] error codes, positive values indicate singular matrices
+ * `compute_pivots` - should always be true (can be false only for CUDA)
+
+ For further details, please see the LAPACK documentation for GETRF.
+*/
+template <typename scalar_t>
+void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+#ifndef USE_LAPACK
+ TORCH_CHECK(
+ false,
+ "Calling torch.lu on a CPU tensor requires compiling ",
+ "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
+#else
+ TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
+
+ auto input_data = input.data_ptr<scalar_t>();
+ auto pivots_data = pivots.data_ptr<int>();
+ auto infos_data = infos.data_ptr<int>();
+ auto input_matrix_stride = matrixStride(input);
+ auto pivots_stride = pivots.size(-1);
+ auto batch_size = batchCount(input);
+ auto m = input.size(-2);
+ auto n = input.size(-1);
+ auto leading_dimension = std::max<int64_t>(1, m);
+
+ for (const auto i : c10::irange(batch_size)) {
+ scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+ int* pivots_working_ptr = &pivots_data[i * pivots_stride];
+ int* infos_working_ptr = &infos_data[i];
+ lapackLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
+ }
+#endif
+}
+
+// This is a type dispatching helper function for 'apply_lu'
+void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
+ apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
+ });
+}
+
+/*
Solves the matrix equation A X = B
X and B are n-by-nrhs matrices, A is represented using the LU factorization.
This is an in-place routine, content of `b` is overwritten.
@@ -985,6 +1034,11 @@
REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
+REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
+REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
+
REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index d02ea53..4aba94f 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1790,101 +1790,150 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/*
+  Computes the LU decomposition of an m×n matrix or batch of matrices in the 'input' tensor.
+  This is an in-place routine, content of 'input', 'pivots', and 'infos' is overwritten.
+  This is a "looped" variant that calls the single-matrix MAGMA function on batched input.
+
+ Args:
+ * `input` - [in] the input matrix for LU decomposition
+ [out] the LU decomposition
+ * `pivots` - [out] the pivot indices
+ * `infos` - [out] error codes, positive values indicate singular matrices
+ * `compute_pivots` - controls whether LU is computed with or without pivoting
+
+ For further details, please see the MAGMA documentation for magma_dgetrf_gpu.
+*/
template <typename scalar_t>
-static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
+static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
#ifndef USE_MAGMA
-AT_ERROR("lu: MAGMA library not found in "
- "compilation. Please rebuild with MAGMA.");
+ TORCH_CHECK(
+ false,
+ "Calling torch.lu on a CUDA tensor requires compiling ",
+ "PyTorch with MAGMA. lease rebuild with MAGMA.");
#else
- auto self_data = self.data_ptr<scalar_t>();
- magma_int_t m = magma_int_cast(self.size(-2), "m");
- magma_int_t n = magma_int_cast(self.size(-1), "n");
- magma_int_t k = std::min(m, n);
+ // magmaLu and magmaLuNoPiv require infos and pivots tensor to be on CPU
+ // the data is later copied back to the appropriate output tensor
+ Tensor infos_cpu = at::empty_like(infos, infos.options().device(kCPU).pinned_memory(true));
- if (self.dim() == 2) {
- // If `pivots` is defined, then we have to compute them.
- // magmaLu and magmaLuNoPiv use a hybrid CPU-GPU algorithm to compute
- // the partially-pivoted LU decomposition with / without pivots.
- // The driver routines magma_(d/s)getrf_(nopiv_)gpu accepts a tensor on the CPU for pivots.
- // The data is later copied back to the appropriate output tensor.
- Tensor info_tmp = at::zeros({}, at::kInt);
- if (get_pivots) {
- Tensor piv_tmp = at::empty({k}, at::kInt);
- magmaLu<scalar_t>(
- m, n, self_data, m, piv_tmp.data_ptr<magma_int_t>(), info_tmp.data_ptr<magma_int_t>());
- pivots.copy_(piv_tmp);
- } else {
- magmaLuNoPiv<scalar_t>(m, n, self_data, m, info_tmp.data_ptr<magma_int_t>());
+ auto input_data = input.data_ptr<scalar_t>();
+ auto infos_data = infos_cpu.data_ptr<magma_int_t>();
+ auto input_matrix_stride = matrixStride(input);
+ auto pivots_stride = pivots.size(-1);
+ auto batch_size = batchCount(input);
+ magma_int_t m = magma_int_cast(input.size(-2), "m");
+ magma_int_t n = magma_int_cast(input.size(-1), "n");
+ auto leading_dimension = std::max<magma_int_t>(1, m);
+
+ if (compute_pivots) {
+ Tensor pivots_cpu = at::empty_like(pivots, pivots.options().device(kCPU).pinned_memory(true));
+ auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();
+ for (decltype(batch_size) i = 0; i < batch_size; i++) {
+ scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+ int* pivots_working_ptr = &pivots_data[i * pivots_stride];
+ int* infos_working_ptr = &infos_data[i];
+ magmaLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
}
- infos.copy_(info_tmp);
+ pivots.copy_(pivots_cpu, /*non_blocking=*/true);
} else {
- auto self_matrix_stride = matrixStride(self);
- magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
+ for (decltype(batch_size) i = 0; i < batch_size; i++) {
+ scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+ int* infos_working_ptr = &infos_data[i];
+ magmaLuNoPiv<scalar_t>(m, n, input_working_ptr, leading_dimension, infos_working_ptr);
+ }
- scalar_t** self_array;
- ALLOCATE_ARRAY(self_array, scalar_t*, batch_size);
+ // fill the pivots tensor with indices using 1-based (Fortran) indexing
+ auto k = std::min(m, n);
+ Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots);
+ pivots.copy_(pivots_tmp);
+ }
+ infos.copy_(infos_cpu, /*non_blocking=*/true);
+#endif
+}
- // Set up the created arrays
+/*
+  Computes the LU decomposition of an m×n matrix or batch of matrices in the 'input' tensor.
+  This is an in-place routine, content of 'input', 'pivots', and 'infos' is overwritten.
+  This is a specialized batched variant; it is expected to be faster than the "looped" version only for small inputs.
+
+ Args:
+ * `input` - [in] the input matrix for LU decomposition
+ [out] the LU decomposition
+ * `pivots` - [out] the pivot indices
+ * `infos` - [out] error codes, positive values indicate singular matrices
+ * `compute_pivots` - controls whether LU is computed with or without pivoting
+
+ For further details, please see the MAGMA documentation for magma_dgetrf_batched.
+*/
+template <typename scalar_t>
+static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+#ifndef USE_MAGMA
+ TORCH_CHECK(
+ false,
+ "Calling torch.lu on a CUDA tensor requires compiling ",
+ "PyTorch with MAGMA. lease rebuild with MAGMA.");
+#else
+ auto input_data = input.data_ptr<scalar_t>();
+ auto infos_data = infos.data_ptr<magma_int_t>();
+ auto input_matrix_stride = matrixStride(input);
+ magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount");
+
+ // magmaLuBatched doesn't work with zero batch dimensions
+ // it gives CUDA error: invalid configuration argument
+ if (batch_size == 0) {
+ infos.fill_(0);
+ return;
+ }
+
+ magma_int_t m = magma_int_cast(input.size(-2), "m");
+ magma_int_t n = magma_int_cast(input.size(-1), "n");
+ auto leading_dimension = std::max<magma_int_t>(1, m);
+
+ scalar_t** input_array;
+ ALLOCATE_ARRAY(input_array, scalar_t*, batch_size);
+
+ // Set up array of pointers to matrices
+ for (int64_t i = 0; i < batch_size; i++) {
+ input_array[i] = &input_data[i * input_matrix_stride];
+ }
+
+ MAGMAQueue magma_queue(input.get_device());
+
+ if (compute_pivots) {
+ auto pivots_data = pivots.data_ptr<magma_int_t>();
+ auto pivots_stride = pivots.size(-1);
+ magma_int_t** pivots_array;
+ ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
for (int64_t i = 0; i < batch_size; i++) {
- self_array[i] = &self_data[i * self_matrix_stride];
+ pivots_array[i] = &pivots_data[i * pivots_stride];
}
+ magmaLuBatched<scalar_t>(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue);
+ } else {
+ magmaLuNoPivBatched<scalar_t>(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue);
- MAGMAQueue magma_queue(self.get_device());
-
- // Same comment as in the case of single matrix above.
- if (get_pivots) {
- auto pivots_data = pivots.data_ptr<magma_int_t>();
- auto pivots_matrix_stride = pivots.size(-1);
- magma_int_t** pivots_array;
- ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
- for (int64_t i = 0; i < batch_size; i++) {
- pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
- }
- magmaLuBatched<scalar_t>(
- m, n, self_array, m, pivots_array,
- infos.data_ptr<magma_int_t>(), batch_size, magma_queue);
- } else {
- magmaLuNoPivBatched<scalar_t>(
- m, n, self_array, m, infos.data_ptr<magma_int_t>(),
- batch_size, magma_queue);
- }
+ // fill the pivots tensor with indices using 1-based (Fortran) indexing
+ auto k = std::min(m, n);
+ Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots);
+ pivots.copy_(pivots_tmp);
}
#endif
}
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
- TORCH_CHECK(self.dim() >= 2,
- "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
- " instead");
- auto m = self.size(-2);
- auto n = self.size(-1);
- auto k = std::min(m, n);
- auto req_size = self.sizes().vec();
- req_size.pop_back();
- req_size.back() = k;
- Tensor pivots_tensor = at::arange(1, k + 1, self.options().dtype(at::kInt)).expand(req_size).contiguous();
- req_size.pop_back();
- auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
-
- Tensor self_working_copy;
- if (self.numel() == 0) {
- self_working_copy = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+static void lu_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+ // TODO: compare performance and use the best performing option based on input's sizes
+ if (input.dim() == 2) {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma", [&]{
+ apply_lu_looped_magma<scalar_t>(input, pivots, infos, compute_pivots);
+ });
} else {
- self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_cuda", [&]{
- apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma", [&]{
+ apply_lu_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
});
}
- if (check_errors) {
- if (self.dim() == 2) {
- singleCheckErrors(infos_tensor.item<int64_t>(), "lu", /*allow_singular=*/true);
- } else {
- batchCheckErrors(infos_tensor, "lu", /*allow_singular=*/true);
- }
- }
- return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}
+REGISTER_DISPATCH(lu_stub, &lu_magma);
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index aa1fdd0..8aedd82 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6332,8 +6332,7 @@
- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
- CPU: _lu_with_info_cpu
- CUDA: _lu_with_info_cuda
+ CPU, CUDA: _lu_with_info
- func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
dispatch: