Update internal code for at::_lu_with_info (#56612)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56612

The goal of this refactoring is to make `torch.linalg.solve`
a composition of calls to `lu_stub` and `lu_solve_stub`.
Once `lu_stub` and `lu_solve_stub` have a cuSOLVER-based codepath,
`torch.linalg.solve` will have it as well.

Replaced `lu_with_info_{cpu, cuda}` with a single `_lu_with_info`
function that calls `lu_stub`.
Split the MAGMA-based `apply_lu` into `apply_lu_looped_magma`
and `apply_lu_batched_magma`. This simplifies the future switch to
the cuSOLVER and cuBLAS libraries.
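
For illustration, a minimal Python-level sketch of the composition this
enables (a sketch of the intended behavior, not the internal C++
implementation); it uses the public `torch.lu` / `torch.lu_solve` entry
points, which are backed by `lu_stub` and `lu_solve_stub`, and assumes a
square, non-singular `A`:

```python
import torch

A = torch.randn(3, 4, 4)  # batch of square (almost surely non-singular) matrices
b = torch.randn(3, 4, 2)

# factorize once, then solve: the same two-step structure as lu_stub + lu_solve_stub
LU, pivots = torch.lu(A)           # goes through _lu_with_info -> lu_stub
x = torch.lu_solve(b, LU, pivots)  # goes through lu_solve_stub

# torch.linalg.solve should give the same answer for non-singular A
assert torch.allclose(x, torch.linalg.solve(A, b), atol=1e-4)
```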

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D28248756

Pulled By: mruberry

fbshipit-source-id: 40e02b5be4ff5f78885bcc95685aba581043e096
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 25c9f25..5901c97 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -212,9 +212,6 @@
 void lapackSolve(int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
 
 template<class scalar_t>
-void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
-
-template<class scalar_t>
 void lapackGetri(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int *info);
 
 template<class scalar_t>
@@ -1429,31 +1426,9 @@
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-template<typename scalar_t>
-static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
-#ifndef USE_LAPACK
-  AT_ERROR("lu: LAPACK library not found in compilation");
-#else
-  auto self_data = self.data_ptr<scalar_t>();
-  auto pivots_data = pivots.data_ptr<int>();
-  auto infos_data = infos.data_ptr<int>();
-  auto self_matrix_stride = matrixStride(self);
-  auto pivots_matrix_stride = pivots.size(-1);
-  auto batch_size = batchCount(self);
-  auto m = self.size(-2);
-  auto n = self.size(-1);
+DEFINE_DISPATCH(lu_stub);
 
-  for (const auto i : c10::irange(batch_size)) {
-    scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
-    int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
-    int* infos_working_ptr = &infos_data[i];
-    lapackLu<scalar_t>(m, n, self_working_ptr, m, pivots_working_ptr, infos_working_ptr);
-  }
-#endif
-}
-
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
-  TORCH_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool check_errors) {
   TORCH_CHECK(self.dim() >= 2,
            "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
            " instead");
@@ -1466,15 +1441,11 @@
   req_size.pop_back();
   auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
 
-  Tensor self_working_copy;
-  if (self.numel() == 0) {
-    self_working_copy = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  } else {
-    self_working_copy = cloneBatchedColumnMajor(self);
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_cpu", [&]{
-      apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
-    });
-  }
+  // lu_stub (apply_lu) requires batched column-major (Fortran-contiguous) tensors
+  // the 'lu' tensor is modified in-place and must be a copy of 'self'
+  Tensor lu = cloneBatchedColumnMajor(self);
+  lu_stub(self.device().type(), lu, pivots_tensor, infos_tensor, compute_pivots);
+
   if (check_errors) {
     if (self.dim() > 2) {
       batchCheckErrors(infos_tensor, "lu", /*allow_singular=*/true);
@@ -1482,7 +1453,7 @@
       singleCheckErrors(infos_tensor.item<int64_t>(), "lu", /*allow_singular=*/true);
     }
   }
-  return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
+  return std::make_tuple(lu, pivots_tensor, infos_tensor);
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h
index ec2ae8b..2ceb3ac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.h
+++ b/aten/src/ATen/native/BatchLinearAlgebra.h
@@ -162,6 +162,9 @@
 template <class scalar_t>
 void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info);
 
+template <class scalar_t>
+void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info);
+
 #endif
 
 using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
@@ -216,6 +219,13 @@
     bool /*unitriangular*/);
 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
 
+using lu_fn = void (*)(
+    const Tensor& /*input*/,
+    const Tensor& /*pivots*/,
+    const Tensor& /*infos*/,
+    bool /*compute_pivots*/);
+DECLARE_DISPATCH(lu_fn, lu_stub);
+
 using lu_solve_fn = void (*)(
     const Tensor& /*b*/,
     const Tensor& /*lu*/,
diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
index 1c6a64d..a858099 100644
--- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -847,6 +847,55 @@
 }
 
 /*
+  Computes the LU decomposition of an m×n matrix or a batch of matrices in the 'input' tensor.
+  This is an in-place routine: the contents of 'input', 'pivots', and 'infos' are overwritten.
+
+  Args:
+  * `input` - [in] the input matrix for LU decomposition
+              [out] the LU decomposition
+  * `pivots` - [out] the pivot indices
+  * `infos` - [out] error codes, positive values indicate singular matrices
+  * `compute_pivots` - should always be true (can be false only for CUDA)
+
+  For further details, please see the LAPACK documentation for GETRF.
+*/
+template <typename scalar_t>
+void apply_lu(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+#ifndef USE_LAPACK
+  TORCH_CHECK(
+      false,
+      "Calling torch.lu on a CPU tensor requires compiling ",
+      "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
+#else
+  TORCH_CHECK(compute_pivots, "lu without pivoting is not implemented on the CPU");
+
+  auto input_data = input.data_ptr<scalar_t>();
+  auto pivots_data = pivots.data_ptr<int>();
+  auto infos_data = infos.data_ptr<int>();
+  auto input_matrix_stride = matrixStride(input);
+  auto pivots_stride = pivots.size(-1);
+  auto batch_size = batchCount(input);
+  auto m = input.size(-2);
+  auto n = input.size(-1);
+  auto leading_dimension = std::max<int64_t>(1, m);
+
+  for (const auto i : c10::irange(batch_size)) {
+    scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+    int* pivots_working_ptr = &pivots_data[i * pivots_stride];
+    int* infos_working_ptr = &infos_data[i];
+    lapackLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
+  }
+#endif
+}
+
+// This is a type dispatching helper function for 'apply_lu'
+void lu_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
+    apply_lu<scalar_t>(input, pivots, infos, compute_pivots);
+  });
+}
+
+/*
   Solves the matrix equation A X = B
   X and B are n-by-nrhs matrices, A is represented using the LU factorization.
   This is an in-place routine, content of `b` is overwritten.
@@ -985,6 +1034,11 @@
 REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
 
+REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel);
+REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel);
+REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel);
+
 REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
 REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
 REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index d02ea53..4aba94f 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -1790,101 +1790,150 @@
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
+/*
+  Computes the LU decomposition of an m×n matrix or a batch of matrices in the 'input' tensor.
+  This is an in-place routine: the contents of 'input', 'pivots', and 'infos' are overwritten.
+  This is a "looped" variant that calls the single-matrix MAGMA routine once per matrix of the batched input.
+
+  Args:
+  * `input` - [in] the input matrix for LU decomposition
+              [out] the LU decomposition
+  * `pivots` - [out] the pivot indices
+  * `infos` - [out] error codes, positive values indicate singular matrices
+  * `compute_pivots` - controls whether LU is computed with or without pivoting
+
+  For further details, please see the MAGMA documentation for magma_dgetrf_gpu.
+*/
 template <typename scalar_t>
-static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
+static void apply_lu_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
 #ifndef USE_MAGMA
-AT_ERROR("lu: MAGMA library not found in "
-    "compilation. Please rebuild with MAGMA.");
+  TORCH_CHECK(
+      false,
+      "Calling torch.lu on a CUDA tensor requires compiling ",
+      "PyTorch with MAGMA. lease rebuild with MAGMA.");
 #else
-  auto self_data = self.data_ptr<scalar_t>();
-  magma_int_t m = magma_int_cast(self.size(-2), "m");
-  magma_int_t n = magma_int_cast(self.size(-1), "n");
-  magma_int_t k = std::min(m, n);
+  // magmaLu and magmaLuNoPiv require the 'infos' and 'pivots' tensors to be on CPU
+  // the data is later copied back to the appropriate output tensor
+  Tensor infos_cpu = at::empty_like(infos, infos.options().device(kCPU).pinned_memory(true));
 
-  if (self.dim() == 2) {
-    // If `pivots` is defined, then we have to compute them.
-    // magmaLu and magmaLuNoPiv use a hybrid CPU-GPU algorithm to compute
-    // the partially-pivoted LU decomposition with / without pivots.
-    // The driver routines magma_(d/s)getrf_(nopiv_)gpu accepts a tensor on the CPU for pivots.
-    // The data is later copied back to the appropriate output tensor.
-    Tensor info_tmp = at::zeros({}, at::kInt);
-    if (get_pivots) {
-      Tensor piv_tmp = at::empty({k}, at::kInt);
-      magmaLu<scalar_t>(
-        m, n, self_data, m, piv_tmp.data_ptr<magma_int_t>(), info_tmp.data_ptr<magma_int_t>());
-      pivots.copy_(piv_tmp);
-    } else {
-      magmaLuNoPiv<scalar_t>(m, n, self_data, m, info_tmp.data_ptr<magma_int_t>());
+  auto input_data = input.data_ptr<scalar_t>();
+  auto infos_data = infos_cpu.data_ptr<magma_int_t>();
+  auto input_matrix_stride = matrixStride(input);
+  auto pivots_stride = pivots.size(-1);
+  auto batch_size = batchCount(input);
+  magma_int_t m = magma_int_cast(input.size(-2), "m");
+  magma_int_t n = magma_int_cast(input.size(-1), "n");
+  auto leading_dimension = std::max<magma_int_t>(1, m);
+
+  if (compute_pivots) {
+    Tensor pivots_cpu = at::empty_like(pivots, pivots.options().device(kCPU).pinned_memory(true));
+    auto pivots_data = pivots_cpu.data_ptr<magma_int_t>();
+    for (decltype(batch_size) i = 0; i < batch_size; i++) {
+      scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+      int* pivots_working_ptr = &pivots_data[i * pivots_stride];
+      int* infos_working_ptr = &infos_data[i];
+      magmaLu<scalar_t>(m, n, input_working_ptr, leading_dimension, pivots_working_ptr, infos_working_ptr);
     }
-    infos.copy_(info_tmp);
+    pivots.copy_(pivots_cpu, /*non_blocking=*/true);
   } else {
-    auto self_matrix_stride = matrixStride(self);
-    magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
+    for (decltype(batch_size) i = 0; i < batch_size; i++) {
+      scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
+      int* infos_working_ptr = &infos_data[i];
+      magmaLuNoPiv<scalar_t>(m, n, input_working_ptr, leading_dimension, infos_working_ptr);
+    }
 
-    scalar_t** self_array;
-    ALLOCATE_ARRAY(self_array, scalar_t*, batch_size);
+    // fill the pivots tensor with indices using 1-based (Fortran) indexing
+    auto k = std::min(m, n);
+    Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots);
+    pivots.copy_(pivots_tmp);
+  }
+  infos.copy_(infos_cpu, /*non_blocking=*/true);
+#endif
+}
 
-    // Set up the created arrays
+/*
+  Computes the LU decomposition of an m×n matrix or a batch of matrices in the 'input' tensor.
+  This is an in-place routine: the contents of 'input', 'pivots', and 'infos' are overwritten.
+  This is a specialized batched variant; it is expected to be faster than the "looped" version only for small inputs.
+
+  Args:
+  * `input` - [in] the input matrix for LU decomposition
+              [out] the LU decomposition
+  * `pivots` - [out] the pivot indices
+  * `infos` - [out] error codes, positive values indicate singular matrices
+  * `compute_pivots` - controls whether LU is computed with or without pivoting
+
+  For further details, please see the MAGMA documentation for magma_dgetrf_batched.
+*/
+template <typename scalar_t>
+static void apply_lu_batched_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+#ifndef USE_MAGMA
+  TORCH_CHECK(
+      false,
+      "Calling torch.lu on a CUDA tensor requires compiling ",
+      "PyTorch with MAGMA. lease rebuild with MAGMA.");
+#else
+  auto input_data = input.data_ptr<scalar_t>();
+  auto infos_data = infos.data_ptr<magma_int_t>();
+  auto input_matrix_stride = matrixStride(input);
+  magma_int_t batch_size = magma_int_cast(batchCount(input), "batchCount");
+
+  // magmaLuBatched doesn't work with zero batch dimensions
+  // it gives CUDA error: invalid configuration argument
+  if (batch_size == 0) {
+    infos.fill_(0);
+    return;
+  }
+
+  magma_int_t m = magma_int_cast(input.size(-2), "m");
+  magma_int_t n = magma_int_cast(input.size(-1), "n");
+  auto leading_dimension = std::max<magma_int_t>(1, m);
+
+  scalar_t** input_array;
+  ALLOCATE_ARRAY(input_array, scalar_t*, batch_size);
+
+  // Set up array of pointers to matrices
+  for (int64_t i = 0; i < batch_size; i++) {
+    input_array[i] = &input_data[i * input_matrix_stride];
+  }
+
+  MAGMAQueue magma_queue(input.get_device());
+
+  if (compute_pivots) {
+    auto pivots_data = pivots.data_ptr<magma_int_t>();
+    auto pivots_stride = pivots.size(-1);
+    magma_int_t** pivots_array;
+    ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
     for (int64_t i = 0; i < batch_size; i++) {
-      self_array[i] = &self_data[i * self_matrix_stride];
+      pivots_array[i] = &pivots_data[i * pivots_stride];
     }
+    magmaLuBatched<scalar_t>(m, n, input_array, leading_dimension, pivots_array, infos_data, batch_size, magma_queue);
+  } else {
+    magmaLuNoPivBatched<scalar_t>(m, n, input_array, leading_dimension, infos_data, batch_size, magma_queue);
 
-    MAGMAQueue magma_queue(self.get_device());
-
-    // Same comment as in the case of single matrix above.
-    if (get_pivots) {
-      auto pivots_data = pivots.data_ptr<magma_int_t>();
-      auto pivots_matrix_stride = pivots.size(-1);
-      magma_int_t** pivots_array;
-      ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size);
-      for (int64_t i = 0; i < batch_size; i++) {
-        pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
-      }
-      magmaLuBatched<scalar_t>(
-        m, n, self_array, m, pivots_array,
-        infos.data_ptr<magma_int_t>(), batch_size, magma_queue);
-    } else {
-      magmaLuNoPivBatched<scalar_t>(
-        m, n, self_array, m, infos.data_ptr<magma_int_t>(),
-        batch_size, magma_queue);
-    }
+    // fill the pivots tensor with indices using 1-based (Fortran) indexing
+    auto k = std::min(m, n);
+    Tensor pivots_tmp = at::arange(1, k + 1, input.options().dtype(at::kInt)).expand_as(pivots);
+    pivots.copy_(pivots_tmp);
   }
 #endif
 }
 
-std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
-  TORCH_CHECK(self.dim() >= 2,
-           "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
-           " instead");
-  auto m = self.size(-2);
-  auto n = self.size(-1);
-  auto k = std::min(m, n);
-  auto req_size = self.sizes().vec();
-  req_size.pop_back();
-  req_size.back() = k;
-  Tensor pivots_tensor = at::arange(1, k + 1, self.options().dtype(at::kInt)).expand(req_size).contiguous();
-  req_size.pop_back();
-  auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
-
-  Tensor self_working_copy;
-  if (self.numel() == 0) {
-    self_working_copy = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+static void lu_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
+  // TODO: compare performance and use the best performing option based on input's sizes
+  if (input.dim() == 2) {
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma", [&]{
+      apply_lu_looped_magma<scalar_t>(input, pivots, infos, compute_pivots);
+    });
   } else {
-    self_working_copy = cloneBatchedColumnMajor(self);
-    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "lu_cuda", [&]{
-        apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
+    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_magma", [&]{
+      apply_lu_batched_magma<scalar_t>(input, pivots, infos, compute_pivots);
     });
   }
-  if (check_errors) {
-    if (self.dim() == 2) {
-      singleCheckErrors(infos_tensor.item<int64_t>(), "lu", /*allow_singular=*/true);
-    } else {
-      batchCheckErrors(infos_tensor, "lu", /*allow_singular=*/true);
-    }
-  }
-  return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
 }
 
+REGISTER_DISPATCH(lu_stub, &lu_magma);
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename scalar_t>
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index aa1fdd0..8aedd82 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6332,8 +6332,7 @@
 - func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
   variants: function
   dispatch:
-    CPU: _lu_with_info_cpu
-    CUDA: _lu_with_info_cuda
+    CPU, CUDA: _lu_with_info
 
 - func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
   dispatch: