ATen lu_unpack. Required for making `torch.lu_solve` differentiable. (#46913)
Summary:
The backward methods for `torch.lu` and `torch.lu_solve` require `torch.lu_unpack`.
`torch.lu` is a Python wrapper over a native function, so its gradient can be implemented via `autograd.Function`; `torch.lu_solve`, however, is a native function, so its backward cannot access `torch.lu_unpack` while that function is implemented only in Python.
Hence this PR adds a native (ATen) version of `lu_unpack`. With it, the gradients of `torch.lu` can also be updated so that backward+JIT is supported (JIT does not work with `autograd.Function`).
~~The interface for this method is different from the original `torch.lu_unpack`, so it was decided to keep it hidden.~~
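
For context, a minimal usage sketch (illustrative only, not part of the patch) of the contract the native op satisfies and of the autograd path it enables:

```python
import torch

# a batch of square systems; double precision keeps the reconstruction check tight
A = torch.randn(2, 4, 4, dtype=torch.double, requires_grad=True)

LU, pivots = A.lu()                    # packed factorization + int32 pivots
P, L, U = torch.lu_unpack(LU, pivots)  # P @ L @ U reconstructs A

assert torch.allclose(P @ L @ U, A)    # up to floating-point rounding

# With the new derivatives.yaml entry, gradients flow through the unpacking
# step, which is what the lu / lu_solve backward formulas build on.
(P @ L @ U).sum().backward()
assert A.grad is not None
```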
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46913
Reviewed By: albanD
Differential Revision: D28355725
Pulled By: mruberry
fbshipit-source-id: 281260f3b6e93c15b08b2ba66d5a221314b00e78
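
For reference, the CPU and CUDA kernels in this patch reconstruct the row permutation by replaying the LAPACK-style swaps on an identity permutation. A self-contained Python sketch of that logic (the helper name is made up for illustration):

```python
from typing import List

def unpack_pivots(pivots: List[int], m: int) -> List[int]:
    """Turn 1-based LAPACK pivots into a 0-based row permutation of length m.

    LAPACK records, for each step i, the row that was swapped with row i during
    factorization; replaying those swaps on the identity permutation yields the
    permutation that the returned P matrix encodes.
    """
    perm = list(range(m))
    for i, piv in enumerate(pivots):
        j = piv - 1                        # convert FORTRAN 1-based index to 0-based
        perm[i], perm[j] = perm[j], perm[i]
    return perm

# step 0 swaps rows 0 and 2, step 1 swaps rows 1 and 2, step 2 is a no-op
print(unpack_pivots([3, 3, 3], m=3))       # [2, 0, 1]
```

The native op then scatters ones into a zero matrix of `LU_data`'s dtype at these positions to build `P`, so that `P` can participate in `matmul` without an explicit cast.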
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 3ffe47a..2acdfd0 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -22,6 +22,8 @@
#include <functional>
#include <limits>
#include <numeric>
+#include <ATen/NamedTensorUtils.h>
+#include <ATen/native/TensorIterator.h>
namespace at {
namespace native {
@@ -2722,6 +2724,142 @@
};
}
+DEFINE_DISPATCH(unpack_pivots_stub);
+
+std::tuple<Tensor, Tensor, Tensor> lu_unpack(
+ const Tensor& LU_data,
+ const Tensor& LU_pivots,
+ bool unpack_data,
+ bool unpack_pivots
+ ) {
+ TORCH_CHECK(LU_pivots.is_contiguous() && (LU_pivots.scalar_type() == at::kInt),
+ "lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype."
+ "Note: this function is intended to be used with the output produced by torch{.linalg}.lu");
+
+ // trivial case
+ if (!unpack_data && !unpack_pivots) {
+ return std::make_tuple(Tensor(), Tensor(), Tensor());
+ }
+
+ Tensor L, U;
+ // In the generalized LU factorization, the following shape relations hold:
+ // A.shape[-2:] == (m, n),
+ // P.shape[-2:] == (m, m),
+ // L.shape[-2:] == (m, k),
+ // U.shape[-2:] == (k, n),
+ // where k = min(m, n)
+ int64_t m = LU_data.size(-2);
+ int64_t n = LU_data.size(-1);
+ int64_t k = std::min(m, n);
+
+ if (unpack_data) {
+ U = LU_data.triu();
+ if (m != k) {
+ U = U.narrow(-2, 0, k);
+ }
+
+ L = LU_data.tril();
+ if (k != n) {
+ L = L.narrow(-1, 0, k);
+ }
+ L.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(1);
+ }
+
+ if (!unpack_pivots) {
+ return std::make_tuple(Tensor(), L, U);
+ }
+
+ auto unpacked_pivots_sizes = LU_pivots.sizes().vec();
+ unpacked_pivots_sizes[LU_pivots.dim() - 1] = m;
+ auto unpacked_pivots = at::empty(
+ unpacked_pivots_sizes,
+ LU_pivots.options().memory_format(at::MemoryFormat::Contiguous)
+ );
+
+ // Fill `unpacked_pivots` with identity permutation
+ auto id_perm = at::arange(m, LU_pivots.options());
+ unpacked_pivots.copy_(id_perm);
+
+ // WARNING: we assume that unchanged LAPACK pivots are provided.
+ // Since LAPACK relies on the FORTRAN's 1-based indexing,
+ // we subtract 1 to convert the pivots to the C-style 0-based indexing.
+ // This behaviour could change in the future.
+ auto LU_pivots_zero_idx = LU_pivots - 1;
+
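+ // The iterator walks over the batch dimensions only: the last (pivot)
+ // dimension is squashed, so each step of the kernel receives one whole
+ // pivot vector together with the matching row of `unpacked_pivots`.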
+ auto iter = TensorIteratorConfig()
+ .set_check_mem_overlap(false)
+ .check_all_same_dtype(false)
+ .resize_outputs(false)
+ .declare_static_shape(LU_pivots.sizes(), /*squash_dim=*/LU_pivots.dim() - 1)
+ .add_output(unpacked_pivots)
+ .add_input(LU_pivots_zero_idx)
+ .build();
+
+ unpack_pivots_stub(
+ LU_pivots.device().type(),
+ iter,
+ LU_pivots.size(-1)
+ );
+
+ // The permutation matrix is converted to LU_data.dtype
+ // because `matmul` does not work with integer matrices.
+ unpacked_pivots_sizes.push_back(m);
+ auto permutation_matrix = at::zeros(
+ unpacked_pivots_sizes,
+ LU_data.options().memory_format(at::MemoryFormat::Contiguous)
+ );
+
+ // now that we know the final permutation,
+ // scatter 1s at proper locations.
+ permutation_matrix.scatter_(
+ -2,
+ unpacked_pivots.unsqueeze(-2).to(at::kLong),
+ at::ones({1}, permutation_matrix.options()).expand(permutation_matrix.sizes())
+ );
+
+ return std::make_tuple(permutation_matrix, L, U);
+}
+
+using TupleTensorRefs3 = std::tuple<Tensor&, Tensor&, Tensor&>;
+
+TupleTensorRefs3 lu_unpack_out(
+ const Tensor& LU_data,
+ const Tensor& LU_pivots,
+ bool unpack_data,
+ bool unpack_pivots,
+ Tensor& P,
+ Tensor& L,
+ Tensor& U
+ ) {
+ Tensor P_tmp, L_tmp, U_tmp;
+ std::tie(P_tmp, L_tmp, U_tmp) = at::lu_unpack(LU_data, LU_pivots, unpack_data, unpack_pivots);
+
+ if (unpack_pivots) {
+ checkSameDevice("lu_unpack", P, LU_data, "P");
+ // Note that lu_unpack returns P such that P.dtype == LU_data.dtype,
+ // because otherwise we cannot use P in matrix products (no int -> float promotion)
+ checkLinalgCompatibleDtype("lu_unpack", P, LU_data, "L");
+
+ at::native::resize_output(P, P_tmp.sizes());
+ P.copy_(P_tmp);
+ }
+
+ if (unpack_data) {
+ checkSameDevice("lu_unpack", L, LU_data, "L");
+ checkSameDevice("lu_unpack", U, LU_data, "U");
+ checkLinalgCompatibleDtype("lu_unpack", L, LU_data, "L");
+ checkLinalgCompatibleDtype("lu_unpack", U, LU_data, "U");
+
+ at::native::resize_output(L, L_tmp.sizes());
+ at::native::resize_output(U, U_tmp.sizes());
+ L.copy_(L_tmp);
+ U.copy_(U_tmp);
+ }
+
+ return TupleTensorRefs3(P, L, U);
+}
+
/*
Calculates the Kronecker product between two Tensors.
*/
diff --git a/aten/src/ATen/native/LinearAlgebra.h b/aten/src/ATen/native/LinearAlgebra.h
index 821a9eb..c7095af 100644
--- a/aten/src/ATen/native/LinearAlgebra.h
+++ b/aten/src/ATen/native/LinearAlgebra.h
@@ -13,4 +13,11 @@
using linalg_vector_norm_fn = void(*)(TensorIterator &, Scalar);
DECLARE_DISPATCH(linalg_vector_norm_fn, linalg_vector_norm_stub);
+using unpack_pivots_fn = void(*)(
+ TensorIterator& iter,
+ int64_t dim_size
+);
+DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
+
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
index 4502246..cef8d96 100644
--- a/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
+++ b/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp
@@ -123,11 +123,46 @@
});
}
+void unpack_pivots_cpu_kernel(
+ TensorIterator& iter,
+ int64_t dim_size
+) {
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
+ auto* unpacked_pivots_ptr = data[0];
+ const auto* pivots_ptr = data[1];
+
+ for (int64_t elem = 0; elem < nelems; ++elem) {
+ // WARNING: torch.lu returns int32 pivots,
+ // this behavior could change in the future.
+ auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(unpacked_pivots_ptr);
+ auto* pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr);
+
+ for (int64_t i = 0; i < dim_size; ++i) {
+ std::swap(
+ unpacked_pivots_data[i],
+ unpacked_pivots_data[pivots_data[i]]
+ );
+ }
+
+ unpacked_pivots_ptr += strides[0];
+ pivots_ptr += strides[1];
+ }
+ };
+
+ iter.for_each(loop);
+}
+
} // anonymous namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(addr_stub, &addr_kernel);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cpu);
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index 79f8be2..0e79df6 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -575,9 +575,88 @@
});
}
+template <int n_threads, int n_elems_per_thread, typename func_t>
+C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
+__global__ void _elementwise_kernel(int total_n_elems, func_t f) {
+ constexpr int total_work_block = n_threads * n_elems_per_thread;
+ int idx = total_work_block * blockIdx.x + threadIdx.x;
+
+ #pragma unroll
+ for (int i = 0; i < n_elems_per_thread; ++i) {
+ if (idx < total_n_elems) {
+ f(idx);
+ idx += n_threads;
+ }
+ }
+}
+
+template <int n_threads, int n_elems_per_thread, typename func_t>
+static void _launch_kernel(int total_n_elems, func_t f) {
+ TORCH_INTERNAL_ASSERT(
+ total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
+ );
+
+ dim3 block(n_threads);
+ constexpr int total_work_block = n_threads * n_elems_per_thread;
+ dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);
+
+ auto stream = at::cuda::getCurrentCUDAStream();
+ _elementwise_kernel<n_threads, n_elems_per_thread, func_t>
+ <<<grid, block, 0, stream>>>(total_n_elems, f);
+ AT_CUDA_CHECK(cudaGetLastError());
+}
+
+void _unpack_pivots_internal_kernel(
+ TensorIterator& iter,
+ int64_t dim_size
+) {
+ if (iter.numel() == 0) {
+ return;
+ }
+
+ if (!iter.can_use_32bit_indexing()) {
+ for (auto& sub_iter : iter.with_32bit_indexing()) {
+ _unpack_pivots_internal_kernel(sub_iter, dim_size);
+ }
+ return;
+ }
+
+ auto offset_calculator = make_offset_calculator<2>(iter);
+
+ char* unpacked_pivots_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
+ const char* const __restrict__ pivots_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));
+
+ auto loop = [=]C10_DEVICE(int i) {
+ auto offsets = offset_calculator.get(i);
+
+ auto* unpacked_pivots_data = reinterpret_cast<int32_t*>(
+ unpacked_pivots_ptr + offsets[0]);
+ const auto* const __restrict__ pivots_data = reinterpret_cast<const int32_t*>(
+ pivots_ptr + offsets[1]);
+
+ // QUESTION: can we mix 64bit offsets with 32bit Iterator indexing?
+ for (int64_t i = 0; i < dim_size; ++i) {
+ thrust::swap(
+ unpacked_pivots_data[i],
+ unpacked_pivots_data[pivots_data[i]]
+ );
+ }
+ };
+
+ _launch_kernel<num_threads, thread_work_size>(iter.numel(), loop);
+}
+
+void unpack_pivots_cuda_kernel(
+ TensorIterator& iter,
+ int64_t dim_size
+) {
+ _unpack_pivots_internal_kernel(iter, dim_size);
+}
+
} // anonymous namespace
REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
REGISTER_DISPATCH(linalg_vector_norm_stub, &linalg_vector_norm_kernel_cuda);
+REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cuda_kernel);
}}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 20a56a5..f89fcb3 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6417,6 +6417,16 @@
dispatch:
CompositeExplicitAutograd: lu_solve
+- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+ variants: function
+ dispatch:
+ CPU, CUDA: lu_unpack
+
+- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+ variants: function
+ dispatch:
+ CPU, CUDA: lu_unpack_out
+
# TODO: remove dispatch section when porting TH CUDA to ATen
- func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 3bbb3e6..91d369b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -8234,32 +8234,6 @@
gradcheck(lambda x: x.logcumsumexp(2), a)
gradgradcheck(lambda x: x.logcumsumexp(2), a)
- @slowTest
- def test_lu_backward(self, device):
- def run_test(*sizes):
- x = torch.rand(*sizes, device=device, dtype=torch.double).requires_grad_(True)
-
- gradcheck(lambda x: x.lu(get_infos=True), x)
- gradgradcheck(lambda x: x.lu(get_infos=True), x)
-
- gradcheck(lambda x: x.lu(get_infos=False), x)
- gradgradcheck(lambda x: x.lu(get_infos=False), x)
-
- # there is no pivot-less LU factorization on CPU
- if x.device.type == 'cuda':
- gradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
- gradgradcheck(lambda x: x.lu(pivot=False, get_infos=True), x)
-
- gradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
- gradgradcheck(lambda x: x.lu(pivot=False, get_infos=False), x)
-
- run_test(3, 3)
- run_test(3, 3, 3)
- run_test(3, 3, 3, 3)
- run_test(5, 5)
- run_test(3, 5, 5)
- run_test(3, 3, 5, 5)
-
def test_strided_leaf_grad_layout(self, device):
# (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
for fmt_a in (torch.contiguous_format, torch.channels_last):
diff --git a/test/test_linalg.py b/test/test_linalg.py
index c7175e5..e48d51f 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5413,8 +5413,9 @@
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
- @dtypes(torch.double)
+ @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
@skipCUDAIfRocm
+ @precisionOverride({torch.float: 1e-3})
def test_lu_unpack(self, device, dtype):
def run_test(pivot):
for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)):
@@ -5422,12 +5423,43 @@
a_lu, p = torch.lu(a, pivot=pivot)
p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
+ for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3),
+ (3, 5), (5, 3), (3, 3, 5), (3, 5, 3),
+ (7, 5, 3, 5, 3), (7, 5, 3, 3, 5),
+ # empty tensors
+ (0, 0), (0, 0, 0), (0, 3, 3)
+ ):
+ a = make_tensor(shape, dtype=dtype, device=device, low=-0.1, high=+0.1)
+ a_lu, p = torch.lu(a, pivot=pivot)
+ p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p)
+ self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
run_test(True)
if self.device_type == 'cuda':
run_test(False)
+ @skipCPUIfNoLapack
+ @skipCUDAIfNoMagma
+ @dtypes(torch.double)
+ @skipCUDAIfRocm
+ def test_lu_unpack_check_input(self, device, dtype):
+ x = torch.rand(5, 5, 5, device=device, dtype=dtype)
+ lu_data, lu_pivots = torch.lu(x, pivot=True)
+
+ with self.assertRaisesRegex(RuntimeError, "torch.int32 dtype"):
+ torch.lu_unpack(lu_data, lu_pivots.long())
+ with self.assertRaisesRegex(RuntimeError, "contiguous tensor"):
+ torch.lu_unpack(lu_data, lu_pivots.transpose(-1, -2))
+
+ # check that once the flags are unset, Nones are returned
+ p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False)
+ self.assertTrue((l == u) and l is None)
+ p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_pivots=False)
+ self.assertTrue(p is None)
+ p, l, u = torch.lu_unpack(lu_data, lu_pivots, unpack_data=False, unpack_pivots=False)
+ self.assertTrue((p == l == u) and p is None)
+
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.double)
diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index 14dca20..ec20b4a 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -16,7 +16,7 @@
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_unpack_dual", 'linalg_qr',
'_svd_helper', 'linalg_svd', 'linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
'fake_quantize_per_channel_affine_cachemask', 'linalg_lstsq', 'linalg_eig', 'linalg_cholesky_ex',
- 'frexp'
+ 'frexp', 'lu_unpack'
}
@@ -83,6 +83,9 @@
op(operators=['_unpack_dual'], input=(0,), names=('primal', 'tangent'), hasout=False),
op(operators=['linalg_lstsq'], input=(a,), names=('solution', 'residuals', 'rank', 'singular_values'), hasout=False),
op(operators=['frexp'], input=(), names=('mantissa', 'exponent'), hasout=True),
+ op(operators=['lu_unpack'],
+ input=(torch.tensor([3, 2, 1, 4, 5], dtype=torch.int32), True, True),
+ names=('P', 'L', 'U'), hasout=True),
]
def get_func(f):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 834c537..58a18d2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -744,6 +744,10 @@
- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
self: not_implemented("lu_solve")
+- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+ LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
+ LU_pivots: non_differentiable
+
- name: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
self: grad.clone().masked_fill_(mask, 0)
mask: non_differentiable
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index f859b7a..bd1b8e0 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -100,7 +100,7 @@
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
- 'index', 'masked_fill', 'cross'
+ 'index', 'masked_fill', 'cross', 'lu_unpack'
}
GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index c450650..9b266c9 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -5017,6 +5017,68 @@
tensor([[False, False], [True, False]])
""".format(**common_args))
+add_docstr(torch.lu_unpack, r"""
+lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True, *, out=None) -> (Tensor, Tensor, Tensor)
+
+Unpacks the data and pivots from an LU factorization of a tensor into tensors ``L`` and ``U`` and a permutation tensor ``P``
+such that ``LU_data, LU_pivots = (P @ L @ U).lu()``.
+
+Returns a tuple of tensors as ``(the P tensor (permutation matrix), the L tensor, the U tensor)``.
+
+.. note:: ``P.dtype == LU_data.dtype`` and ``P.dtype`` is not an integer type so that matrix products with ``P``
+ are possible without casting it to a floating type.
+
+Args:
+ LU_data (Tensor): the packed LU factorization data
+ LU_pivots (Tensor): the packed LU factorization pivots
+ unpack_data (bool): flag indicating if the data should be unpacked.
+ If ``False``, then the returned ``L`` and ``U`` are ``None``.
+ Default: ``True``
+ unpack_pivots (bool): flag indicating if the pivots should be unpacked into a permutation matrix ``P``.
+ If ``False``, then the returned ``P`` is ``None``.
+ Default: ``True``
+ out (tuple, optional): a tuple of three tensors to use for the outputs ``(P, L, U)``.
+
+Examples::
+
+ >>> A = torch.randn(2, 3, 3)
+ >>> A_LU, pivots = A.lu()
+ >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
+ >>>
+ >>> # can recover A from factorization
+ >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+
+ >>> # LU factorization of a rectangular matrix:
+ >>> A = torch.randn(2, 3, 2)
+ >>> A_LU, pivots = A.lu()
+ >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
+ >>> P
+ tensor([[[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]],
+
+ [[0., 0., 1.],
+ [0., 1., 0.],
+ [1., 0., 0.]]])
+ >>> A_L
+ tensor([[[ 1.0000, 0.0000],
+ [ 0.4763, 1.0000],
+ [ 0.3683, 0.1135]],
+
+ [[ 1.0000, 0.0000],
+ [ 0.2957, 1.0000],
+ [-0.9668, -0.3335]]])
+ >>> A_U
+ tensor([[[ 2.1962, 1.0881],
+ [ 0.0000, -0.8681]],
+
+ [[-1.0947, 0.3736],
+ [ 0.0000, 0.5718]]])
+ >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
+ >>> torch.norm(A_ - A)
+ tensor(2.9802e-08)
+""".format(**common_args))
+
add_docstr(torch.less, r"""
less(input, other, *, out=None) -> Tensor
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 7b87e43..0c5e1ed 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -3329,6 +3329,38 @@
return std::make_tuple(grad_abs, grad_angle);
}
+Tensor lu_unpack_backward(
+ const std::vector<torch::autograd::Variable>& grads,
+ const Tensor& LU_data,
+ bool unpack_data
+) {
+ auto L_grad = grads[1];
+ auto U_grad = grads[2];
+
+ auto m = LU_data.size(-2);
+ auto n = LU_data.size(-1);
+ auto k = std::min(m, n);
+
+ TORCH_CHECK(unpack_data, "lu_unpack_backward: cannot compute gradients unless unpack_data=True");
+
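+ // The gradient w.r.t. LU_data uses the same packed layout as LU_data itself:
+ // the strictly lower-triangular part comes from L_grad (the unit diagonal of L
+ // carries no gradient), and the upper-triangular part comes from U_grad.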
+ auto res = at::zeros(LU_data.sizes(), LU_data.options());
+
+ Tensor L_grad_contrib;
+ if (L_grad.defined()) {
+ L_grad_contrib = L_grad.tril();
+ L_grad_contrib.diagonal(0, -2, -1).fill_(0);
+ res.narrow(-2, 0, m).narrow(-1, 0, k).add_(L_grad_contrib);
+ }
+
+ Tensor U_grad_contrib;
+ if (U_grad.defined()) {
+ U_grad_contrib = U_grad.triu();
+ res.narrow(-2, 0, k).narrow(-1, 0, n).add_(U_grad_contrib);
+ }
+
+ return res;
+}
+
} // namespace details
} // namespace generated
} // namespace autograd
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index ffd643b..a9fa01e 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -230,6 +230,11 @@
std::tuple<Tensor, Tensor> polar_backward(
const Tensor& grad,
const Tensor& result);
+Tensor lu_unpack_backward(
+ const std::vector<torch::autograd::Variable>& grads,
+ const Tensor& LU_data,
+ bool unpack_data
+);
} // namespace details
} // namespace generated
diff --git a/torch/functional.py b/torch/functional.py
index 3a49c54..fc561b9 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -30,7 +30,6 @@
'einsum',
'istft',
'lu',
- 'lu_unpack',
'norm',
'meshgrid',
'pca_lowrank',
@@ -184,115 +183,6 @@
return out
-def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
- # type: (Tensor, Tensor, bool, bool) -> (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])
- r"""Unpacks the data and pivots from a LU factorization of a tensor.
-
- Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
-
- Args:
- LU_data (Tensor): the packed LU factorization data
- LU_pivots (Tensor): the packed LU factorization pivots
- unpack_data (bool): flag indicating if the data should be unpacked
- unpack_pivots (bool): flag indicating if the pivots should be unpacked
-
- Examples::
-
- >>> A = torch.randn(2, 3, 3)
- >>> A_LU, pivots = A.lu()
- >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
- >>>
- >>> # can recover A from factorization
- >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
-
- >>> # LU factorization of a rectangular matrix:
- >>> A = torch.randn(2, 3, 2)
- >>> A_LU, pivots = A.lu()
- >>> P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
- >>> P
- tensor([[[1., 0., 0.],
- [0., 1., 0.],
- [0., 0., 1.]],
-
- [[0., 0., 1.],
- [0., 1., 0.],
- [1., 0., 0.]]])
- >>> A_L
- tensor([[[ 1.0000, 0.0000],
- [ 0.4763, 1.0000],
- [ 0.3683, 0.1135]],
-
- [[ 1.0000, 0.0000],
- [ 0.2957, 1.0000],
- [-0.9668, -0.3335]]])
- >>> A_U
- tensor([[[ 2.1962, 1.0881],
- [ 0.0000, -0.8681]],
-
- [[-1.0947, 0.3736],
- [ 0.0000, 0.5718]]])
- >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
- >>> torch.norm(A_ - A)
- tensor(2.9802e-08)
- """
- if has_torch_function_variadic(LU_data, LU_pivots):
- return handle_torch_function(
- lu_unpack, (LU_data, LU_pivots), LU_data, LU_pivots,
- unpack_data=unpack_data,
- unpack_pivots=unpack_pivots)
- shape = LU_data.shape
- # In generalized LU factorization, the following shape relations hold:
- # A.shape[-2:] == (m, n)
- # P.shape[-2:] == (m, m)
- # L.shape[-2:] == (m, k)
- # U.shape[-2:] == (k, n)
- # where k = min(m, n)
- m, n = shape[-2:]
- k = min(m, n)
- if unpack_data:
- U: Optional[Tensor] = LU_data.triu()
- assert U is not None
- if m != k:
- U = U.narrow(-2, 0, k)
- L: Optional[Tensor] = LU_data.tril()
- assert L is not None
- if k != n:
- L = L.narrow(-1, 0, k)
- L.diagonal(dim1=-2, dim2=-1).fill_(1)
- else:
- L = U = None
-
- if unpack_pivots:
- LU_pivots_zero_idx = LU_pivots - 1
- if LU_data.dim() > 2:
- P: Optional[Tensor] = torch.eye(m, device=LU_data.device,
- dtype=LU_data.dtype) \
- .expand(shape[:-1] + (m,)) \
- .clone(memory_format=torch.contiguous_format)
- assert P is not None
-
- # TODO: rewrite when TorchScript supports product and map as
- # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
- indices = _indices_product(shape[:-2])
- for idx in indices:
- final_order = list(range(m))
- for k, j in enumerate(_index_tensor_with_indices_list(LU_pivots_zero_idx, idx)):
- final_order[k], final_order[j] = final_order[j], final_order[k]
- # TODO: remove _index_tensor_with_indices_list when TorchScript supports indexing Tensor with list
- p_idx = _index_tensor_with_indices_list(P, idx)
- p_idx.copy_(p_idx.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device)))
- else:
- P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype)
- final_order = list(range(m))
- for k, j, in enumerate(LU_pivots_zero_idx):
- final_order[k], final_order[j] = final_order[j], final_order[k]
- P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
- else:
- P = None
-
- return P, L, U
-
-
def einsum(equation, *operands):
r"""einsum(equation, *operands) -> Tensor
diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py
index 28839a7..10a0b15 100644
--- a/torch/jit/_builtins.py
+++ b/torch/jit/_builtins.py
@@ -113,7 +113,7 @@
# but we are currently only able to compile some of the functions. additionally,
# some functions directly map to their aten:: implementations.
# TODO: add support for more ops
- ops = ["stft", "istft", "lu", "lu_unpack", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
+ ops = ["stft", "istft", "lu", "cdist", "norm", "unique", "unique_consecutive", "tensordot"]
return set(getattr(torch.functional, name) for name in ops)
_functional_registered_ops = _gen_torch_functional_registered_ops()
diff --git a/torch/overrides.py b/torch/overrides.py
index 5cfc121..59e4b27 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -445,7 +445,7 @@
torch.frac: lambda input, out=None: -1,
torch.frexp: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
- torch.functional.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
+ torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
torch.gcd: lambda input, other, out=None: -1,
torch.ge: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 035bc2b..2c6449c 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2278,6 +2278,31 @@
return list(generate_samples())
+def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
+ # not needed once OpInfo tests support Iterables
+ def generate_samples():
+ for lu_sample in sample_inputs_lu(op_info, device, dtype, requires_grad, **kwargs):
+ lu_data, pivots = lu_sample.input.lu()
+ yield SampleInput(lu_data, args=(pivots,))
+
+ # generate rectangular inputs
+ lu_data_shape = lu_data.shape
+ batch_shape = lu_data_shape[:-2]
+ n = lu_data_shape[-2]
+
+ for shape_inc in ((1, 0), (0, 1)):
+ lu_data, pivots = make_tensor(
+ batch_shape + (n + shape_inc[0], n + shape_inc[1]),
+ device, dtype,
+ requires_grad=False,
+ low=None, high=None
+ ).lu()
+ lu_data.requires_grad_(requires_grad)
+ yield SampleInput(lu_data, args=(pivots,))
+
+ return list(generate_samples())
+
+
def sample_inputs_roll(op_info, device, dtype, requires_grad=False, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -4774,6 +4799,21 @@
# Skip operator schema test because this is a functional and not an operator
SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
)),
+ OpInfo('lu_unpack',
+ op=torch.lu_unpack,
+ dtypes=floating_and_complex_types(),
+ supports_inplace_autograd=False,
+ # we use in-place operations which cannot be avoided.
+ # This causes vmap failures, hence we skip batched gradient checks
+ check_batched_grad=False,
+ supports_out=True,
+ sample_inputs_func=sample_inputs_lu_unpack,
+ decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
+ skips=(
+ # cuda gradchecks are slow
+ # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+ SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),
+ )),
OpInfo('masked_fill',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_masked_fill,