ATen lu_unpack. Required for making `torch.lu_solve` differentiable. (#46913)

Summary:
Backward methods for `torch.lu` and `torch.lu_solve` require the `torch.lu_unpack` method.
`torch.lu` is a Python wrapper over a native function, so its gradient can be implemented via `autograd.Function`;
`torch.lu_solve`, however, is a native function, and it cannot call `torch.lu_unpack` while the latter is implemented in Python.

Hence this PR presents a native (ATen) `lu_unpack`. This function also makes it possible to update the gradients for `torch.lu` so that backward+JIT is supported (`autograd.Function` is not scriptable).

~~The interface for this method differs from the original `torch.lu_unpack`, so it was decided to keep it hidden.~~
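For reference, a minimal usage sketch of the native op (same signature and semantics as the old Python version):

```python
import torch

A = torch.randn(2, 3, 3)
LU, pivots = A.lu()                    # packed factorization + int32 pivots
P, L, U = torch.lu_unpack(LU, pivots)  # P has A's dtype so P @ L @ U works
assert torch.allclose(P @ L @ U, A, atol=1e-6)
```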

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46913

Reviewed By: albanD

Differential Revision: D28355725

Pulled By: mruberry

fbshipit-source-id: 281260f3b6e93c15b08b2ba66d5a221314b00e78
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 3ffe47a..2acdfd0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -22,6 +22,8 @@
 #include <functional>
 #include <limits>
 #include <numeric>
+#include <ATen/NamedTensorUtils.h>
+#include <ATen/native/TensorIterator.h>
 
 namespace at {
 namespace native {
@@ -2722,6 +2724,142 @@
 };
 }
 
+DEFINE_DISPATCH(unpack_pivots_stub);
+
+std::tuple<Tensor, Tensor, Tensor> lu_unpack(
+    const Tensor& LU_data,
+    const Tensor& LU_pivots,
+    bool unpack_data,
+    bool unpack_pivots
+    ) {
+  TORCH_CHECK(LU_pivots.is_contiguous() && (LU_pivots.scalar_type() == at::kInt),
+      "lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype."
+      "Note: this function is intended to be used with the output produced by torch{.linalg}.lu");
+
+  // trivial case
+  if (!unpack_data && !unpack_pivots) {
+    return std::make_tuple(Tensor(), Tensor(), Tensor());
+  }
+
+  Tensor L, U;
+  // In the generalized LU factorization, the following shape relations hold:
+  // A.shape[-2:] == (m, n),
+  // P.shape[-2:] == (m, m),
+  // L.shape[-2:] == (m, k),
+  // U.shape[-2:] == (k, n),
+  // where k = min(m, n)
+  int64_t m = LU_data.size(-2);
+  int64_t n = LU_data.size(-1);
+  int64_t k = std::min(m, n);
+
+  if (unpack_data) {
+    U = LU_data.triu();
+    if (m != k) {
+      U = U.narrow(-2, 0, k);
+    }
+
+    L = LU_data.tril();
+    if (k != n) {
+      L = L.narrow(-1, 0, k);
+    }
+    L.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
+  }
+
+  if (!unpack_pivots) {
+    return std::make_tuple(Tensor(), L, U);
+  }
+
+  auto unpacked_pivots_sizes = LU_pivots.sizes().vec();
+  unpacked_pivots_sizes[LU_pivots.dim() - 1] = m;
+  auto unpacked_pivots = at::empty(
+    unpacked_pivots_sizes,
+    LU_pivots.options().memory_format(at::MemoryFormat::Contiguous)
+  );
+
+  // Fill `unpacked_pivots` with identity permutation
+  auto id_perm = at::arange(m, LU_pivots.options());
+  unpacked_pivots.copy_(id_perm);
+
+  // WARNING: we assume the pivots come unmodified from LAPACK.
+  // Since LAPACK relies on Fortran's 1-based indexing,
+  // we subtract 1 to convert the pivots to C-style 0-based indexing.
+  // This behaviour could change in the future.
+  auto LU_pivots_zero_idx = LU_pivots - 1;
+
+  auto iter = TensorIteratorConfig()
+    .set_check_mem_overlap(false)
+    .check_all_same_dtype(false)
+    .resize_outputs(false)
+    .declare_static_shape(LU_pivots.sizes(), /*squash_dim=*/LU_pivots.dim() - 1)
+    .add_output(unpacked_pivots)
+    .add_input(LU_pivots_zero_idx)
+    .build();
+
+  unpack_pivots_stub(
+    LU_pivots.device().type(),
+    iter,
+    LU_pivots.size(-1)
+  );
+
+  // The permutation matrix is converted to LU_data.dtype
+  // because `matmul` does not work with integer matrices.
+  unpacked_pivots_sizes.push_back(m);
+  auto permutation_matrix = at::zeros(
+    unpacked_pivots_sizes,
+    LU_data.options().memory_format(at::MemoryFormat::Contiguous)
+  );
+
+  // now that we know the final permutation,
+  // scatter 1s at proper locations.
+  permutation_matrix.scatter_(
+    -2,
+    unpacked_pivots.unsqueeze(-2).to(at::kLong),
+    at::ones({1}, permutation_matrix.options()).expand(permutation_matrix.sizes())
+  );
+
+  return std::make_tuple(permutation_matrix, L, U);
+}
+
+using TupleTensorRefs3 = std::tuple<Tensor&, Tensor&, Tensor&>;
+
+TupleTensorRefs3 lu_unpack_out(
+    const Tensor& LU_data,
+    const Tensor& LU_pivots,
+    bool unpack_data,
+    bool unpack_pivots,
+    Tensor& P,
+    Tensor& L,
+    Tensor& U
+    ) {
+  Tensor P_tmp, L_tmp, U_tmp;
+  std::tie(P_tmp, L_tmp, U_tmp) = at::lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots);
+
+  if (unpack_pivots) {
+    checkSameDevice("lu_unpack", P, LU_data, "P");
+    // Note that lu_unpack returns P such that P.dtype == LU_data.dtype,
+    // because otherwise we cannot use P in matrix products (no int -> float promotion)
+    checkLinalgCompatibleDtype("lu_unpack", P, LU_data, "P");
+
+    at::native::resize_output(P, P_tmp.sizes());
+    P.copy_(P_tmp);
+  }
+
+  if (unpack_data) {
+    checkSameDevice("lu_unpack", L, LU_data, "L");
+    checkSameDevice("lu_unpack", U, LU_data, "U");
+    checkLinalgCompatibleDtype("lu_unpack", L, LU_data, "L");
+    checkLinalgCompatibleDtype("lu_unpack", U, LU_data, "U");
+
+    at::native::resize_output(L, L_tmp.sizes());
+    at::native::resize_output(U, U_tmp.sizes());
+    L.copy_(L_tmp);
+    U.copy_(U_tmp);
+  }
+
+  return TupleTensorRefs3(P, L, U);
+}
+
 /*
 Calculates the Kronecker product between two Tensors.
 */
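For a single unbatched matrix, the pivot path above (identity permutation, sequential LAPACK swaps, then a scatter of 1s) boils down to the following pure-Python sketch; `unpack_pivots_reference` is an illustrative name, not part of the API:

```python
import torch

def unpack_pivots_reference(pivots: torch.Tensor, m: int) -> torch.Tensor:
    # Apply the 1-based LAPACK row interchanges to the identity permutation.
    perm = list(range(m))
    for i, p in enumerate((pivots - 1).tolist()):
        perm[i], perm[p] = perm[p], perm[i]
    # Mirror the scatter_ call above: column j gets a 1 in row perm[j].
    P = torch.zeros(m, m)
    P[perm, torch.arange(m)] = 1.0
    return P

A = torch.randn(5, 5)
LU, pivots = A.lu()
P, _, _ = torch.lu_unpack(LU, pivots)
assert torch.equal(unpack_pivots_reference(pivots, 5), P)
```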
diff --git a/aten/src/ATen/native/LinearAlgebra.h b/aten/src/ATen/native/LinearAlgebra.h
index 821a9eb..c7095af 100644
--- a/aten/src/ATen/native/LinearAlgebra.h
+++ b/aten/src/ATen/native/LinearAlgebra.h
@@ -13,4 +13,11 @@
 using linalg_vector_norm_fn = void(*)(TensorIterator &, Scalar);
 DECLARE_DISPATCH(linalg_vector_norm_fn, linalg_vector_norm_stub);
 
+using unpack_pivots_fn = void(*)(
+  TensorIterator& iter,
+  int64_t dim_size
+);
+DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
index 4502246..cef8d96 100644
--- a/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
@@ -123,11 +123,46 @@
   });
 }
 
+void unpack_pivots_cpu_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
+    auto* unpacked_pivots_ptr = data[0];
+    const auto* pivots_ptr = data[1];
+
+    for (int64_t elem = 0; elem < nelems; ++elem) {
+      // WARNING: torch.lu returns int32 pivots;
+      // this behavior could change in the future.
+      auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(unpacked_pivots_ptr);
+      auto* pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr);
+
+      for (int64_t i = 0; i < dim_size; ++i) {
+        std::swap(
+          unpacked_pivots_data[i],
+          unpacked_pivots_data[pivots_data[i]]
+        );
+      }
+
+      unpacked_pivots_ptr += strides[0];
+      pivots_ptr += strides[1];
+    }
+  };
+
+  iter.for_each(loop);
+}
+
 } // anonymous namespace
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(addr_stub, &addr_kernel);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cpu);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index 79f8be2..0e79df6 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -575,9 +575,88 @@
   });
 }
 
+template <int n_threads, int n_elems_per_thread, typename func_t>
+C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
+__global__ void _elementwise_kernel(int total_n_elems, func_t f) {
+  constexpr int total_work_block = n_threads * n_elems_per_thread;
+  int idx = total_work_block * blockIdx.x + threadIdx.x;
+
+  #pragma unroll
+  for (int i = 0; i < n_elems_per_thread; ++i) {
+    if (idx < total_n_elems) {
+      f(idx);
+      idx += n_threads;
+    }
+  }
+}
+
+template <int n_threads, int n_elems_per_thread, typename func_t>
+static void _launch_kernel(int total_n_elems, func_t f) {
+  TORCH_INTERNAL_ASSERT(
+    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
+  );
+
+  dim3 block(n_threads);
+  constexpr int total_work_block = n_threads * n_elems_per_thread;
+  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  _elementwise_kernel<n_threads, n_elems_per_thread, func_t>
+    <<<grid, block, 0, stream>>>(total_n_elems, f);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void _unpack_pivots_internal_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      _unpack_pivots_internal_kernel(sub_iter, dim_size);
+    }
+    return;
+  }
+
+  auto offset_calculator = make_offset_calculator<2>(iter);
+
+  char* unpacked_pivots_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
+  const char* const __restrict__ pivots_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));
+
+  auto loop = [=]C10_DEVICE(int i) {
+    auto offsets = offset_calculator.get(i);
+
+    auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(
+      unpacked_pivots_ptr + offsets[0]);
+    const auto* const __restrict__ pivots_data = reinterpret_cast<const int32_t*>(
+      pivots_ptr + offsets[1]);
+
+    // QUESTION: can we mix 64bit offsets with 32bit Iterator indexing?
+    for (int64_t j = 0; j < dim_size; ++j) {
+      thrust::swap(
+        unpacked_pivots_data[j],
+        unpacked_pivots_data[pivots_data[j]]
+      );
+    }
+  };
+
+  _launch_kernel<num_threads, thread_work_size>(iter.numel(), loop);
+}
+
+void unpack_pivots_cuda_kernel(
+  TensorIterator& iter,
+  int64_t dim_size
+) {
+  _unpack_pivots_internal_kernel(iter, dim_size);
+}
+
 } // anonymous namespace
 
 REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
 REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cuda);
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cuda_kernel);
 
 }}
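`_launch_kernel` sizes a one-dimensional grid so that each block covers `n_threads * n_elems_per_thread` elements; the arithmetic, as a small Python sketch (128 threads and 4 elements per thread are assumed values; the real `num_threads`/`thread_work_size` constants come from PyTorch's CUDA loop headers):

```python
def grid_size(total_n_elems: int, n_threads: int = 128, n_elems_per_thread: int = 4) -> int:
    # Ceiling division by the number of elements one block processes.
    total_work_block = n_threads * n_elems_per_thread
    return (total_n_elems + total_work_block - 1) // total_work_block

assert grid_size(512) == 1   # exactly one block's worth of work
assert grid_size(513) == 2
```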
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 20a56a5..f89fcb3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6417,6 +6417,16 @@
   dispatch:
     CompositeExplicitAutograd: lu_solve
 
+- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+  variants: function
+  dispatch:
+    CPU, CUDA: lu_unpack
+
+- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+  variants: function
+  dispatch:
+    CPU, CUDA: lu_unpack_out
+
 # TODO: remove dispatch section when porting TH CUDA to ATen
 - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
   dispatch:
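A usage sketch of the new `out=` overload (the out tensors are resized via `resize_output`, so empty tensors of matching dtype and device are enough):

```python
import torch

A = torch.randn(4, 4)
LU, pivots = A.lu()

P, L, U = torch.empty(0), torch.empty(0), torch.empty(0)
torch.lu_unpack(LU, pivots, out=(P, L, U))  # resized and filled in place
assert torch.allclose(P @ L @ U, A, atol=1e-6)
```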
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3bbb3e6..91d369b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -8234,32 +8234,6 @@
         gradcheck(lambda x: x.logcumsumexp(2), a)
         gradgradcheck(lambda x: x.logcumsumexp(2), a)
 
-    @slowTest
-    def test_lu_backward(self, device):
-        def run_test(*sizes):
-            x = torch.rand(*sizes, device=device, dtype=torch.double).requires_grad_(True)
-
-            gradcheck(lambda x: x.lu(get_infos=True), x)
-            gradgradcheck(lambda x: x.lu(get_infos=True), x)
-
-            gradcheck(lambda x: x.lu(get_infos=False), x)
-            gradgradcheck(lambda x: x.lu(get_infos=False), x)
-
-            # there is no pivot-less LU factorization on CPU
-            if x.device.type == 'cuda':
-                gradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
-                gradgradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
-
-                gradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
-                gradgradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
-
-        run_test(3, 3)
-        run_test(3, 3, 3)
-        run_test(3, 3, 3, 3)
-        run_test(5, 5)
-        run_test(3, 5, 5)
-        run_test(3, 3, 5, 5)
-
     def test_strided_leaf_grad_layout(self, device):
         # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
         for fmt_a in (torch.contiguous_format, torch.channels_last):
diff --git a/test/test_linalg.py b/test/test_linalg.py
index c7175e5..e48d51f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5413,8 +5413,9 @@
 
     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma
-    @dtypes(torch.double)
+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
     @skipCUDAIfRocm
+    @precisionOverride({torch.float: 1e-3})
     def test_lu_unpack(self, device, dtype):
         def run_test(pivot):
             for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
@@ -5422,12 +5423,43 @@
                 a_lu, p = torch.lu(a, pivot=pivot)
                 p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
                 self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
+            for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3),
+                          (3, 5), (5, 3), (3, 3, 5), (3, 5, 3),
+                          (7, 5, 3, 5, 3), (7, 5, 3, 3, 5),
+                          # empty tensors
+                          (0, 0), (0, 0, 0), (0, 3, 3)
+                          ):
+                a = make_tensor(shape, dtype=dtype, device=device, low=-0.1, high=+0.1)
+                a_lu, p = torch.lu(a, pivot=pivot)
+                p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
+                self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
 
         run_test(True)
 
         if self.device_type == 'cuda':
             run_test(False)
 
+    @skipCPUIfNoLapack
+    @skipCUDAIfNoMagma
+    @dtypes(torch.double)
+    @skipCUDAIfRocm
+    def test_lu_unpack_check_input(self, device, dtype):
+        x = torch.rand(5, 5, 5, device=device, dtype=dtype)
+        lu_data, lu_pivots = torch.lu(x, pivot=True)
+
+        with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
+            torch.lu_unpack(lu_data, lu_pivots.long())
+        with self.assertRaisesRegex(RuntimeError, "contiguous tensor"):
+            torch.lu_unpack(lu_data, lu_pivots.transpose(-1, -2))
+
+        # check that when the flags are unset, Nones are returned
+        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
+        self.assertTrue((l == u) and l is None)
+        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
+        self.assertTrue(p is None)
+        p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
+        self.assertTrue((p == l == u) and p is None)
+
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.double)
diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index 14dca20..ec20b4a 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -16,7 +16,7 @@
     'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_unpack_dual", 'linalg_qr',
     '_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
     'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq', 'linalg_eig', 'linalg_cholesky_ex',
-    'frexp'
+    'frexp', 'lu_unpack'
 }
 
 
@@ -83,6 +83,9 @@
             op(operators=['_unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
             op(operators=['linalg_lstsq'], input=(a,), names=('solution', 'residuals', 'rank', 'singular_values'), hasout=False),
             op(operators=['frexp'], input=(), names=('mantissa', 'exponent'), hasout=True),
+            op(operators=['lu_unpack'],
+               input=(torch.tensor([3, 2, 1, 4, 5], dtype=torch.int32), True, True),
+               names=('P', 'L', 'U'), hasout=True),
         ]
 
         def get_func(f):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 834c537..58a18d2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -744,6 +744,10 @@
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self: not_implemented("lu_solve")
 
+- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+  LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
+  LU_pivots: non_differentiable
+
 - name: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
   self: grad.clone().masked_fill_(mask, 0)
   mask: non_differentiable
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index f859b7a..bd1b8e0 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -100,7 +100,7 @@
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
-    'index', 'masked_fill', 'cross'
+    'index', 'masked_fill', 'cross', 'lu_unpack'
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index c450650..9b266c9 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5017,6 +5017,68 @@
     tensor([[False, False], [True, False]])
 """.format(**common_args))
 
+add_docstr(torch.lu_unpack, r"""
+lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor)
+
+Unpacks the data and pivots from an LU factorization of a tensor into tensors ``L`` and ``U`` and a permutation tensor ``P``
+such that ``LU_data, LU_pivots = (P @ L @ U).lu()``.
+
+Returns a tuple of tensors ``(P, L, U)``, where ``P`` is the permutation matrix, and ``L`` and ``U`` are the lower- and upper-triangular factors.
+
+.. note:: ``P.dtype == LU_data.dtype`` and ``P.dtype`` is not an integer type so that matrix products with ``P``
+          are possible without casting it to a floating type.
+
+Args:
+    LU_data (Tensor): the packed LU factorization data
+    LU_pivots (Tensor): the packed LU factorization pivots
+    unpack_data (bool): flag indicating if the data should be unpacked.
+                        If ``False``, then the returned ``L`` and ``U`` are ``None``.
+                        Default: ``True``
+    unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``.
+                          If ``False``, then the returned ``P`` is  ``None``.
+                          Default: ``True``
+    out (tuple, optional): a tuple of three tensors to use for the outputs ``(P, L, U)``.
+
+Examples::
+
+    >>> A = torch.randn(2, 3, 3)
+    >>> A_LU, pivots = A.lu()
+    >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
+    >>>
+    >>> # can recover A from factorization
+    >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+
+    >>> # LU factorization of a rectangular matrix:
+    >>> A = torch.randn(2, 3, 2)
+    >>> A_LU, pivots = A.lu()
+    >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
+    >>> P
+    tensor([[[1., 0., 0.],
+             [0., 1., 0.],
+             [0., 0., 1.]],
+
+            [[0., 0., 1.],
+             [0., 1., 0.],
+             [1., 0., 0.]]])
+    >>> A_L
+    tensor([[[ 1.0000,  0.0000],
+             [ 0.4763,  1.0000],
+             [ 0.3683,  0.1135]],
+
+            [[ 1.0000,  0.0000],
+             [ 0.2957,  1.0000],
+             [-0.9668, -0.3335]]])
+    >>> A_U
+    tensor([[[ 2.1962,  1.0881],
+             [ 0.0000, -0.8681]],
+
+            [[-1.0947,  0.3736],
+             [ 0.0000,  0.5718]]])
+    >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+    >>> torch.norm(A_ - A)
+    tensor(2.9802e-08)
+""".format(**common_args))
+
 add_docstr(torch.less, r"""
 less(input, other, *, out=None) -> Tensor
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 7b87e43..0c5e1ed 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -3329,6 +3329,38 @@
   return std::make_tuple(grad_abs, grad_angle);
 }
 
+Tensor lu_unpack_backward(
+  const std::vector<torch::autograd::Variable>& grads,
+  const Tensor& LU_data,
+  bool unpack_data
+) {
+  auto L_grad = grads[1];
+  auto U_grad = grads[2];
+
+  auto m = LU_data.size(-2);
+  auto n = LU_data.size(-1);
+  auto k = std::min(m, n);
+
+  TORCH_CHECK(unpack_data, "lu_unpack_backward: cannot compute gradients unless unpack_data=True");
+
+  auto res = at::zeros(LU_data.sizes(), LU_data.options());
+
+  Tensor L_grad_contrib;
+  if (L_grad.defined()) {
+    L_grad_contrib = L_grad.tril();
+    L_grad_contrib.diagonal(0, -2, -1).fill_(0);
+    res.narrow(-2, 0, m).narrow(-1, 0, k).add_(L_grad_contrib);
+  }
+
+  Tensor U_grad_contrib;
+  if (U_grad.defined()) {
+    U_grad_contrib = U_grad.triu();
+    res.narrow(-2, 0, k).narrow(-1, 0, n).add_(U_grad_contrib);
+  }
+
+  return res;
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
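Since `L` and `U` are masked copies of `LU_data` (with `L`'s unit diagonal receiving no gradient), the backward simply packs the strictly lower triangle of `L_grad` and the upper triangle of `U_grad` back into the `LU_data` layout. A Python sketch of the same formula for the case where both grads are defined (`lu_unpack_backward_reference` is an illustrative name):

```python
import torch

def lu_unpack_backward_reference(L_grad: torch.Tensor, U_grad: torch.Tensor) -> torch.Tensor:
    # L_grad is (m, k), U_grad is (k, n), with k = min(m, n).
    m, n = L_grad.size(-2), U_grad.size(-1)
    k = min(m, n)
    res = torch.zeros(m, n, dtype=L_grad.dtype)
    res[:, :k] += L_grad.tril(-1)  # the unit diagonal of L carries no gradient
    res[:k, :] += U_grad.triu()
    return res
```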
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index ffd643b..a9fa01e 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -230,6 +230,11 @@
 std::tuple<Tensor, Tensor> polar_backward(
     const Tensor& grad,
     const Tensor& result);
+Tensor lu_unpack_backward(
+  const std::vector<torch::autograd::Variable>& grads,
+  const Tensor& LU_data,
+  bool unpack_data
+);
 
 } // namespace details
 } // namespace generated
diff --git a/torch/functional.py b/torch/functional.py
index 3a49c54..fc561b9 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -30,7 +30,6 @@
     'einsum',
     'istft',
     'lu',
-    'lu_unpack',
     'norm',
     'meshgrid',
     'pca_lowrank',
@@ -184,115 +183,6 @@
     return out
 
 
-def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
-    # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])
-    r"""Unpacks the data and pivots from a LU factorization of a tensor.
-
-    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
-
-    Args:
-        LU_data (Tensor): the packed LU factorization data
-        LU_pivots (Tensor): the packed LU factorization pivots
-        unpack_data (bool): flag indicating if the data should be unpacked
-        unpack_pivots (bool): flag indicating if the pivots should be unpacked
-
-    Examples::
-
-        >>> A = torch.randn(2, 3, 3)
-        >>> A_LU, pivots = A.lu()
-        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
-        >>>
-        >>> # can recover A from factorization
-        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
-
-        >>> # LU factorization of a rectangular matrix:
-        >>> A = torch.randn(2, 3, 2)
-        >>> A_LU, pivots = A.lu()
-        >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
-        >>> P
-        tensor([[[1., 0., 0.],
-                 [0., 1., 0.],
-                 [0., 0., 1.]],
-
-                [[0., 0., 1.],
-                 [0., 1., 0.],
-                 [1., 0., 0.]]])
-        >>> A_L
-        tensor([[[ 1.0000,  0.0000],
-                 [ 0.4763,  1.0000],
-                 [ 0.3683,  0.1135]],
-
-                [[ 1.0000,  0.0000],
-                 [ 0.2957,  1.0000],
-                 [-0.9668, -0.3335]]])
-        >>> A_U
-        tensor([[[ 2.1962,  1.0881],
-                 [ 0.0000, -0.8681]],
-
-                [[-1.0947,  0.3736],
-                 [ 0.0000,  0.5718]]])
-        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
-        >>> torch.norm(A_ - A)
-        tensor(2.9802e-08)
-    """
-    if has_torch_function_variadic(LU_data, LU_pivots):
-        return handle_torch_function(
-            lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots,
-            unpack_data=unpack_data,
-            unpack_pivots=unpack_pivots)
-    shape = LU_data.shape
-    # In generalized LU factorization, the following shape relations hold:
-    #   A.shape[-2:] == (m, n)
-    #   P.shape[-2:] == (m, m)
-    #   L.shape[-2:] == (m, k)
-    #   U.shape[-2:] == (k, n)
-    # where k = min(m, n)
-    m, n = shape[-2:]
-    k = min(m, n)
-    if unpack_data:
-        U: Optional[Tensor] = LU_data.triu()
-        assert U is not None
-        if m != k:
-            U = U.narrow(-2, 0, k)
-        L: Optional[Tensor] = LU_data.tril()
-        assert L is not None
-        if k != n:
-            L = L.narrow(-1, 0, k)
-        L.diagonal(dim1=-2, dim2=-1).fill_(1)
-    else:
-        L = U = None
-
-    if unpack_pivots:
-        LU_pivots_zero_idx = LU_pivots - 1
-        if LU_data.dim() > 2:
-            P: Optional[Tensor] = torch.eye(m, device=LU_data.device,
-                                            dtype=LU_data.dtype) \
-                .expand(shape[:-1] + (m,)) \
-                .clone(memory_format=torch.contiguous_format)
-            assert P is not None
-
-            # TODO: rewrite when TorchScript supports product and map as
-            # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
-            indices = _indices_product(shape[:-2])
-            for idx in indices:
-                final_order = list(range(m))
-                for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):
-                    final_order[k], final_order[j] = final_order[j], final_order[k]
-                # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list
-                p_idx = _index_tensor_with_indices_list(P, idx)
-                p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))
-        else:
-            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)
-            final_order = list(range(m))
-            for k, j, in enumerate(LU_pivots_zero_idx):
-                final_order[k], final_order[j] = final_order[j], final_order[k]
-            P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
-    else:
-        P = None
-
-    return P, L, U
-
-
 def einsum(equation, *operands):
     r"""einsum(equation, *operands) -> Tensor
 
diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py
index 28839a7..10a0b15 100644
--- a/torch/jit/_builtins.py
+++ b/torch/jit/_builtins.py
@@ -113,7 +113,7 @@
     # but we are currently only able to compile some of the functions. additionally,
     # some functions directly map to their aten:: implementations.
     # TODO: add support for more ops
-    ops = ["stft", "istft", "lu", "lu_unpack", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
+    ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
     return set(getattr(torch.functional, name) for name in ops)
 
 _functional_registered_ops = _gen_torch_functional_registered_ops()
diff --git a/torch/overrides.py b/torch/overrides.py
index 5cfc121..59e4b27 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -445,7 +445,7 @@
         torch.frac: lambda input, out=None: -1,
         torch.frexp: lambda input, out=None: -1,
         torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
-        torch.functional.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
+        torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
         torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
         torch.gcd: lambda input, other, out=None: -1,
         torch.ge: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 035bc2b..2c6449c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2278,6 +2278,31 @@
     return list(generate_samples())
 
 
+def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
+    # not needed once OpInfo tests support Iterables
+    def generate_samples():
+        for lu_sample in sample_inputs_lu(op_info, device, dtype, requires_grad, **kwargs):
+            lu_data, pivots = lu_sample.input.lu()
+            yield SampleInput(lu_data, args=(pivots,))
+
+            # generate rectangular inputs
+            lu_data_shape = lu_data.shape
+            batch_shape = lu_data_shape[:-2]
+            n = lu_data_shape[-2]
+
+            for shape_inc in ((1, 0), (0, 1)):
+                lu_data, pivots = make_tensor(
+                    batch_shape + (n + shape_inc[0], n + shape_inc[1]),
+                    device, dtype,
+                    requires_grad=False,
+                    low=None, high=None
+                ).lu()
+                lu_data.requires_grad_(requires_grad)
+                yield SampleInput(lu_data, args=(pivots,))
+
+    return list(generate_samples())
+
+
 def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -4774,6 +4799,21 @@
                # Skip operator schema test because this is a functional and not an operator
                SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
+    OpInfo('lu_unpack',
+           op=torch.lu_unpack,
+           dtypes=floating_and_complex_types(),
+           supports_inplace_autograd=False,
+           # the implementation relies on in-place operations that cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
+           supports_out=True,
+           sample_inputs_func=sample_inputs_lu_unpack,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+           )),
     OpInfo('masked_fill',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_masked_fill,