[Array API] Add linalg.diagonal (#70599)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70599
This PR adds `linalg.diagonal` following the Array API:
https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html#linalg-diagonal-x-axis1-0-axis2-1-offset-0
Fixes https://github.com/pytorch/pytorch/issues/62813
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33760506
Pulled By: mruberry
fbshipit-source-id: e32c3490321d8c3f31b3bb538bc1f72b39bd2854
(cherry picked from commit 44f41f8e3922892ca2f86c9c05249336de40e9ee)
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 5115f47..2dc014e 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1737,6 +1737,11 @@
return at::native::matmul_out(tensor1, tensor2, result);
}
+// torch.linalg.diagonal, alias for torch.diagonal with dim1=-2, dim2=-1 as defaults
+Tensor linalg_diagonal(const Tensor& A, int64_t offset, int64_t dim1, int64_t dim2) {
+ return A.diagonal(offset, dim1, dim2);
+}
+
// helper methods for matrix_exp
namespace {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d94388a..0a493c1 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1610,6 +1610,10 @@
dispatch:
CompositeExplicitAutograd: diagonal
+- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
+ python_module: linalg
+ variants: function
+
- func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
variants: function, method
diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst
index 2e2a31a..f7b2324 100644
--- a/docs/source/linalg.rst
+++ b/docs/source/linalg.rst
@@ -19,6 +19,7 @@
norm
vector_norm
matrix_norm
+ diagonal
det
slogdet
cond
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index fa8d933..e70c86e 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -1353,6 +1353,12 @@
Alias for :func:`torch.matmul`
""")
+diagonal = _add_docstr(_linalg.linalg_diagonal, r"""
+linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor
+
+Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`.
+""")
+
multi_dot = _add_docstr(_linalg.linalg_multi_dot, r"""
linalg.multi_dot(tensors, *, out=None)
diff --git a/torch/overrides.py b/torch/overrides.py
index b816a44..9ea6932 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -435,6 +435,7 @@
torch.diagflat: lambda input, offset=0: -1,
torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
+ torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
torch.digamma: lambda input, out=None: -1,
torch.dist: lambda input, other, p=2: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 72acef7..e726c9f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -6022,11 +6022,13 @@
# Shapes for 3D Tensors
shapes_3d = ((M, M, M),)
- args_2d = ((), (2,), (-2,), (1,))
- args_3d = ((1, 1, 2), (2, 0, 1), (-2, 0, 1))
+ kwargs_2d = (dict(), dict(offset=2), dict(offset=2), dict(offset=1))
+ kwargs_3d = (dict(offset=1, dim1=1, dim2=2),
+ dict(offset=2, dim1=0, dim2=1),
+ dict(offset=-2, dim1=0, dim2=1))
- for shape, arg in chain(product(shapes_2d, args_2d), product(shapes_3d, args_3d)):
- yield SampleInput(make_arg(shape), args=arg)
+ for shape, kwarg in chain(product(shapes_2d, kwargs_2d), product(shapes_3d, kwargs_3d)):
+ yield SampleInput(make_arg(shape), kwargs=kwarg)
def sample_inputs_diagonal_scatter(op_info, device, dtype, requires_grad, **kwargs):
@@ -9309,6 +9311,9 @@
supports_fwgrad_bwgrad=True,
sample_inputs_func=sample_inputs_diagonal_diag_embed),
OpInfo('diagonal',
+ # They are not strictly aliases as they have diverging defaults, but we can see them as aliases for testing purposes
+ # If we add tests that test the function against the alias, make linalg.diagonal into its own OpInfo
+ aliases=('linalg.diagonal',),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
supports_forward_ad=True,