[Array API] Add linalg.diagonal (#70599)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70599

This PR adds `linalg.diagonal` following the Array API:
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-diagonal-x-axis1-0-axis2-1-offset-0

Fixes https://github.com/pytorch/pytorch/issues/62813

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33760506

Pulled By: mruberry

fbshipit-source-id: e32c3490321d8c3f31b3bb538bc1f72b39bd2854
(cherry picked from commit 44f41f8e3922892ca2f86c9c05249336de40e9ee)
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 5115f47..2dc014e 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1737,6 +1737,11 @@
   return at::native::matmul_out(tensor1, tensor2, result);
 }
 
+// torch.linalg.diagonal, alias for torch.diagonal with dim1=-2, dim2=-1 as defaults
+Tensor linalg_diagonal(const Tensor& A, int64_t offset, int64_t dim1, int64_t dim2) {
+  return A.diagonal(offset, dim1, dim2);
+}
+
 // helper methods for matrix_exp
 namespace {
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d94388a..0a493c1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1610,6 +1610,10 @@
   dispatch:
     CompositeExplicitAutograd: diagonal
 
+- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
+  python_module: linalg
+  variants: function
+
 - func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
   variants: function, method
 
diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst
index 2e2a31a..f7b2324 100644
--- a/docs/source/linalg.rst
+++ b/docs/source/linalg.rst
@@ -19,6 +19,7 @@
     norm
     vector_norm
     matrix_norm
+    diagonal
     det
     slogdet
     cond
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index fa8d933..e70c86e 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -1353,6 +1353,12 @@
 Alias for :func:`torch.matmul`
 """)
 
+diagonal = _add_docstr(_linalg.linalg_diagonal, r"""
+linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor
+
+Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`.
+""")
+
 multi_dot = _add_docstr(_linalg.linalg_multi_dot, r"""
 linalg.multi_dot(tensors, *, out=None)
 
diff --git a/torch/overrides.py b/torch/overrides.py
index b816a44..9ea6932 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -435,6 +435,7 @@
         torch.diagflat: lambda input, offset=0: -1,
         torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
         torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
+        torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
         torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
         torch.digamma: lambda input, out=None: -1,
         torch.dist: lambda input, other, p=2: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 72acef7..e726c9f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6022,11 +6022,13 @@
     # Shapes for 3D Tensors
     shapes_3d = ((M, M, M),)
 
-    args_2d = ((), (2,), (-2,), (1,))
-    args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))
+    kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1))
+    kwargs_3d = (dict(offset=1, dim1=1, dim2=2),
+                 dict(offset=2, dim1=0, dim2=1),
+                 dict(offset=-2, dim1=0, dim2=1))
 
-    for shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
-        yield SampleInput(make_arg(shape), args=arg)
+    for shape, kwarg in chain(product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)):
+        yield SampleInput(make_arg(shape), kwargs=kwarg)
 
 
 def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
@@ -9309,6 +9311,9 @@
            supports_fwgrad_bwgrad=True,
            sample_inputs_func=sample_inputs_diagonal_diag_embed),
     OpInfo('diagonal',
+           # They are not strictly aliases as they have diverging defaults, but we can see them as aliases for testing purposes
+           # If we add tests that test the function against the alias, make linalg.diagonal into its own OpInfo
+           aliases=('linalg.diagonal',),
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
            supports_forward_ad=True,