Make torch.linalg.eigvalsh differentiable (#57189)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57189
`torch.linalg.eigvalsh` now supports autograd. This is achieved by
computing the eigenvectors internally when the input requires grad;
otherwise the eigenvectors are not computed and the operation is faster.
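A minimal usage sketch of the new behavior (assumes a PyTorch build that includes this change):

```python
import torch

# A small real symmetric (hence Hermitian) matrix with autograd enabled.
A = torch.randn(3, 3, dtype=torch.float64)
A = A + A.transpose(-2, -1)
A.requires_grad_(True)

# eigvalsh now records an autograd graph and can be backpropagated through.
w = torch.linalg.eigvalsh(A)
w.sum().backward()
print(A.grad)
```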
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D28199708
Pulled By: albanD
fbshipit-source-id: 12ac56f50137398613e186abd49f82c8ab83532e
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index a9d5d4c..5ef887f 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2088,6 +2088,14 @@
}
Tensor linalg_eigvalsh(const Tensor& input, std::string uplo) {
+ // If the input requires grad, we must compute the eigenvectors to make this function differentiable.
+ // The eigenvectors are not exposed to the user.
+ if (at::GradMode::is_enabled() && input.requires_grad()) {
+ Tensor values;
+ std::tie(values, std::ignore) = at::linalg_eigh(input, uplo);
+ return values;
+ }
+
squareCheckInputs(input);
checkUplo(uplo);
ScalarType real_dtype = toValueType(input.scalar_type());
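For readers skimming the C++, here is a hypothetical Python-level sketch of the branch added above (`eigvalsh_sketch` is an illustrative name, not part of this PR):

```python
import torch

def eigvalsh_sketch(A, uplo="L"):
    # Mirrors the C++ branch: when autograd is active and the input requires
    # grad, compute the full decomposition via linalg.eigh so its existing
    # backward formula can be reused; the eigenvectors never reach the caller.
    if torch.is_grad_enabled() and A.requires_grad:
        values, _ = torch.linalg.eigh(A, UPLO=uplo)
        return values
    # Otherwise take the faster eigenvalues-only path; the no_grad call below
    # merely stands in for the non-differentiable eigenvalues-only kernel
    # used in the C++ implementation.
    with torch.no_grad():
        return torch.linalg.eigvalsh(A, UPLO=uplo)
```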
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 38f8bda..5dfa4db 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9502,13 +9502,9 @@
- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
python_module: linalg
variants: function
- dispatch:
- CompositeExplicitAutograd: linalg_eigvalsh
- func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
python_module: linalg
- dispatch:
- CompositeExplicitAutograd: linalg_eigvalsh_out
- func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
python_module: linalg
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index e933a5f..c19c685 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1154,9 +1154,6 @@
- name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
self: symeig_backward(grads, self, /*eigenvectors=*/true, /*upper=*/true, eigenvalues, eigenvectors)
-- name: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
- self: non_differentiable
-
- name: linalg_eigvals(Tensor self) -> Tensor
self: non_differentiable
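With the explicit dispatch and the `non_differentiable` entry removed, `linalg_eigvalsh` becomes a composite op whose gradient flows through its internal `linalg_eigh` call. A quick way to sanity-check the new gradients outside the test suite (a sketch, not part of the PR, mimicking the idea behind the test suite's `gradcheck_wrapper_hermitian_input`):

```python
import torch
from torch.autograd import gradcheck

A = torch.randn(4, 4, dtype=torch.complex128, requires_grad=True)

# gradcheck perturbs A freely, so symmetrize inside the wrapper to keep the
# effective input Hermitian before calling eigvalsh.
def fn(X):
    return torch.linalg.eigvalsh(X + X.conj().transpose(-2, -1))

assert gradcheck(fn, (A,))
```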
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 1ee5cac..efdfac2 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -502,8 +502,8 @@
.. seealso::
- :func:`torch.linalg.eigvalsh` computes only the eigenvalues.
- However, that function is not differentiable.
+ :func:`torch.linalg.eigvalsh` computes only the eigenvalues,
+ but its gradients are always numerically stable.
:func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix.
The Cholesky decomposition gives less information about the matrix but is much faster
@@ -586,8 +586,9 @@
.. note:: For CUDA inputs, this function synchronizes that device with the CPU.
-.. note:: This function is not differentiable. If you need differentiability use
- :func:`torch.linalg.eigh` instead, which also computes the eigenvectors.
+.. seealso::
+
+ :func:`torch.linalg.eigh` computes the full eigenvalue decomposition.
Args:
A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 82f58c8..bf3e467 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2106,10 +2106,15 @@
def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
"""
- This function generates input for torch.linalg.eigh with UPLO="U" or "L" keyword argument.
+ This function generates inputs for torch.linalg.eigh/eigvalsh with the UPLO="U" or "L" keyword argument.
"""
def out_fn(output):
- return output[0], abs(output[1])
+ if isinstance(output, tuple):
+ # eigh function
+ return output[0], abs(output[1])
+ else:
+ # eigvalsh function
+ return output
samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
for sample in samples:
@@ -4415,6 +4420,13 @@
# see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
),
+ OpInfo('linalg.eigvalsh',
+ aten_name='linalg_eigvalsh',
+ dtypes=floating_and_complex_types(),
+ check_batched_gradgrad=False,
+ sample_inputs_func=sample_inputs_linalg_eigh,
+ gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+ decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
OpInfo('linalg.householder_product',
aten_name='linalg_householder_product',
op=torch.linalg.householder_product,