Make torch.linalg.eigvalsh differentiable (#57189)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57189

`torch.linalg.eigvalsh` now supports autograd. This is achieved by
computing the eigenvectors internally when the input requires grad;
otherwise the eigenvectors are not computed and the operation is faster.
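
A minimal usage sketch (the symmetrization below is just one way to build a
valid Hermitian input; before this change the backward call raised an error
because the op was marked non-differentiable):

    import torch

    A = torch.randn(3, 3, dtype=torch.float64)
    A = 0.5 * (A + A.transpose(-2, -1))  # make A symmetric (Hermitian)
    A.requires_grad_(True)

    w = torch.linalg.eigvalsh(A)  # eigenvectors computed internally, not returned
    w.sum().backward()            # now differentiable
    print(A.grad.shape)           # torch.Size([3, 3])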

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D28199708

Pulled By: albanD

fbshipit-source-id: 12ac56f50137398613e186abd49f82c8ab83532e
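
For context, the gradient this enables is the standard eigenvalue gradient of a
Hermitian matrix: with A = V diag(w) V^H and distinct eigenvalues, dw_i/dA =
v_i v_i^H, so A.grad = V diag(grad_w) V^H. An illustrative check (not part of
this PR's test plan; all names are local to the sketch):

    import torch

    torch.manual_seed(0)
    A = torch.randn(4, 4, dtype=torch.float64)
    A = 0.5 * (A + A.transpose(-2, -1))
    A.requires_grad_(True)

    # Weight the eigenvalues so the incoming gradient is non-trivial.
    c = torch.arange(1.0, 5.0, dtype=torch.float64)
    w = torch.linalg.eigvalsh(A)
    (c * w).sum().backward()

    with torch.no_grad():
        lam, V = torch.linalg.eigh(A)
        expected = V @ torch.diag(c) @ V.transpose(-2, -1)
    print(torch.allclose(A.grad, expected))  # True
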
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index a9d5d4c..5ef887f 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2088,6 +2088,14 @@
 }
 
 Tensor linalg_eigvalsh(const Tensor& input, std::string uplo) {
+  // If the input requires grad, we must compute the eigenvectors to make this function differentiable.
+  // The eigenvectors are not exposed to the user.
+  if (at::GradMode::is_enabled() && input.requires_grad()) {
+    Tensor values;
+    std::tie(values, std::ignore) = at::linalg_eigh(input, uplo);
+    return values;
+  }
+
   squareCheckInputs(input);
   checkUplo(uplo);
   ScalarType real_dtype = toValueType(input.scalar_type());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 38f8bda..5dfa4db 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9502,13 +9502,9 @@
 - func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
   python_module: linalg
   variants: function
-  dispatch:
-    CompositeExplicitAutograd: linalg_eigvalsh
 
 - func: linalg_eigvalsh.out(Tensor self, str UPLO='L', *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
-  dispatch:
-    CompositeExplicitAutograd: linalg_eigvalsh_out
 
 - func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
   python_module: linalg
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index e933a5f..c19c685 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1154,9 +1154,6 @@
 - name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
   self: symeig_backward(grads, self, /*eigenvectors=*/true, /*upper=*/true, eigenvalues, eigenvectors)
 
-- name: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
-  self: non_differentiable
-
 - name: linalg_eigvals(Tensor self) -> Tensor
   self: non_differentiable
 
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 1ee5cac..efdfac2 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -502,8 +502,8 @@
 
 .. seealso::
 
-        :func:`torch.linalg.eigvalsh` computes only the eigenvalues.
-        However, that function is not differentiable.
+        :func:`torch.linalg.eigvalsh` computes only the eigenvalues,
+        but its gradients are always numerically stable.
 
         :func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix.
         The Cholesky decomposition gives less information about the matrix but is much faster
@@ -586,8 +586,9 @@
 
 .. note:: For CUDA inputs, this function synchronizes that device with the CPU.
 
-.. note:: This function is not differentiable. If you need differentiability use
-          :func:`torch.linalg.eigh` instead, which also computes the eigenvectors.
+.. seealso::
+
+        :func:`torch.linalg.eigh` computes the full eigenvalue decomposition.
 
 Args:
     A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 82f58c8..bf3e467 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2106,10 +2106,15 @@
 
 def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
     """
-    This function generates input for torch.linalg.eigh with UPLO="U" or "L" keyword argument.
+    This function generates input for torch.linalg.eigh/eigvalsh with the UPLO="U" or "L" keyword argument.
     """
     def out_fn(output):
-        return output[0], abs(output[1])
+        if isinstance(output, tuple):
+            # eigh function
+            return output[0], abs(output[1])
+        else:
+            # eigvalsh function
+            return output
 
     samples = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
     for sample in samples:
@@ -4415,6 +4420,13 @@
                # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
                SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
            ),
+    OpInfo('linalg.eigvalsh',
+           aten_name='linalg_eigvalsh',
+           dtypes=floating_and_complex_types(),
+           check_batched_gradgrad=False,
+           sample_inputs_func=sample_inputs_linalg_eigh,
+           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
     OpInfo('linalg.householder_product',
            aten_name='linalg_householder_product',
            op=torch.linalg.householder_product,