Remove deprecated torch.symeig (#70988)
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.symeig`.
- [x] XLA PR: https://github.com/pytorch/xla/pull/4498
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70988
Approved by: https://github.com/lezcano, https://github.com/kit1980, https://github.com/malfet
diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt
index 97cd3f6..fc88b90 100644
--- a/.github/ci_commit_pins/xla.txt
+++ b/.github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-5714e03fdd9d86b9bd9ca684631e95ea2cf65c4f
+021a1cc2173138548481342c1863fcd3f177dca5
diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index ffce89f..9b80468 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -601,7 +601,6 @@
KERNEL_CPU(_lu_with_info, fp32)
KERNEL_CPU(qr, fp32)
KERNEL_CPU(svd, fp32)
- KERNEL_CPU(symeig, fp32)
KERNEL_CPU(triangular_solve, fp32)
KERNEL_CPU(fractional_max_pool2d, fp32)
KERNEL_CPU(fractional_max_pool3d, fp32)
diff --git a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
index cdc60ed..21836fc 100644
--- a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
+++ b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp
@@ -595,7 +595,6 @@
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(geqrf, geqrf);
LINALG_CHECK_MATRIX_UNARY_ONE_OUT(logdet, logdet);
-LINALG_CHECK_MATRIX_UNARY_TWO_OUT(symeig, symeig);
LINALG_CHECK_MATRIX_BINARY_TWO_OUT(triangular_solve, triangular_solve);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_det, linalg.det);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(_linalg_eigh, linalg.eigh);
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index afe1cf9..83613da 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -34,8 +34,6 @@
#include <ATen/ops/_linalg_svd_meta.h>
#include <ATen/ops/_linalg_svd_native.h>
#include <ATen/ops/_lu_with_info_native.h>
-#include <ATen/ops/_symeig_helper.h>
-#include <ATen/ops/_symeig_helper_native.h>
#include <ATen/ops/all.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/cat.h>
@@ -110,8 +108,6 @@
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/svd_native.h>
-#include <ATen/ops/symeig.h>
-#include <ATen/ops/symeig_native.h>
#include <ATen/ops/triangular_solve_meta.h>
#include <ATen/ops/triangular_solve_native.h>
#include <ATen/ops/tril.h>
@@ -289,12 +285,6 @@
extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
-// syev
-extern "C" void zheev_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *info);
-extern "C" void cheev_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *info);
-extern "C" void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
-extern "C" void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
-
// syevd
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
@@ -910,24 +900,6 @@
sormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
}
-template<> void lapackSymeig<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int *info) {
- zheev_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, info);
-}
-
-template<> void lapackSymeig<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int *info) {
- cheev_(&jobz, &uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, w, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, info);
-}
-
-template<> void lapackSymeig<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double* rwork, int *info) {
- (void)rwork; // unused
- dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
-}
-
-template<> void lapackSymeig<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float* rwork, int *info) {
- (void)rwork; // unused
- ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
-}
-
template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
zheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
}
@@ -2815,134 +2787,6 @@
return L;
}
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-template <typename scalar_t>
-static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool upper, int* infos) {
-#if !AT_BUILD_WITH_LAPACK()
- AT_ERROR("symeig: LAPACK library not found in compilation");
-#else
- using value_t = typename c10::scalar_value_type<scalar_t>::type;
- auto self_data = self.data_ptr<scalar_t>();
- auto eigvals_data = eigvals.data_ptr<value_t>();
- auto self_matrix_stride = matrixStride(self);
- auto eigvals_stride = eigvals.size(-1);
- auto batch_size = batchCount(self);
- auto n = self.size(-1);
-
- char uplo = upper ? 'U' : 'L';
- char jobz = eigenvectors ? 'V' : 'N';
-
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int info;
- // Run once, first to get the optimum work size.
- // Since we deal with batches of matrices with the same dimensions, doing this outside
- // the loop saves (batch_size - 1) workspace queries which would provide the same result
- // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
- int lwork = -1;
- scalar_t wkopt;
-
- Tensor rwork;
- value_t* rwork_data = nullptr;
- if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
- int64_t lrwork = std::max(int64_t(1), 3 * n - 2);
- ScalarType dtype = toRealValueType(typeMetaToScalarType(self.dtype()));
- rwork = at::empty({lrwork}, self.options().dtype(dtype));
- rwork_data = rwork.data_ptr<value_t>();
- }
-
- lapackSymeig<scalar_t, value_t>(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, rwork_data, &info);
- lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
- Tensor work = at::empty({lwork}, self.options());
-
- for (const auto i : c10::irange(batch_size)) {
- scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- value_t* eigvals_working_ptr = &eigvals_data[i * eigvals_stride];
-
- // now compute the eigenvalues and the eigenvectors (optionally)
- lapackSymeig<scalar_t, value_t>(jobz, uplo, n, self_working_ptr, n, eigvals_working_ptr, work.data_ptr<scalar_t>(), lwork, rwork_data, &info);
- infos[i] = info;
- if (info != 0) {
- return;
- }
- }
-#endif
-}
-
-std::tuple<Tensor, Tensor> _symeig_helper_cpu(const Tensor& self, bool eigenvectors, bool upper) {
- auto infos = at::zeros({batchCount(self)}, self.options().dtype(kInt));
-
- auto self_sizes = self.sizes().vec();
- self_sizes.pop_back();
- ScalarType dtype = toRealValueType(typeMetaToScalarType(self.dtype()));
- auto eigvals = at::empty(self_sizes, self.options().dtype(dtype));
-
- if (self.numel() == 0) {
- return std::tuple<Tensor, Tensor>(eigvals, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
- }
-
- auto self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "symeig_cpu", [&]{
- apply_symeig<scalar_t>(self_working_copy, eigvals, eigenvectors, upper, infos.data_ptr<int>());
- });
-
- at::_linalg_check_errors(infos, "symeig", self.dim() == 2);
- if (eigenvectors) {
- return std::tuple<Tensor, Tensor>(eigvals, self_working_copy);
- } else {
- return std::tuple<Tensor, Tensor>(eigvals, at::empty({0}, self.options()));
- }
-}
-
-std::tuple<Tensor, Tensor> symeig(const Tensor& self, bool eigenvectors, bool upper) {
- TORCH_WARN_ONCE(
- "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
- "PyTorch release.\n",
- "The default behavior has changed from using the upper triangular portion of the matrix by default ",
- "to using the lower triangular portion.\n",
- "L, _ = torch.symeig(A, upper=upper)\n",
- "should be replaced with\n",
- "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
- "and\n",
- "L, V = torch.symeig(A, eigenvectors=True)\n"
- "should be replaced with\n",
- "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
- );
- squareCheckInputs(self, "linalg.symeig");
- return at::_symeig_helper(self, eigenvectors, upper);
-}
-
-std::tuple<Tensor&, Tensor&> symeig_out(const Tensor& self, bool eigenvectors, bool upper, Tensor& vals, Tensor& vecs) {
- TORCH_WARN_ONCE(
- "torch.symeig is deprecated in favor of torch.linalg.eigh and will be removed in a future ",
- "PyTorch release.\n",
- "The default behavior has changed from using the upper triangular portion of the matrix by default ",
- "to using the lower triangular portion.\n",
- "L, _ = torch.symeig(A, upper=upper)\n",
- "should be replaced with\n",
- "L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')\n",
- "and\n",
- "L, V = torch.symeig(A, eigenvectors=True)\n"
- "should be replaced with\n",
- "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
- );
- checkSameDevice("symeig", vals, self, "eigenvalues");
- checkSameDevice("symeig", vecs, self, "eigenvectors");
- checkLinalgCompatibleDtype("symeig", vecs, self, "eigenvectors");
- // eigenvalues are always real-valued here
- ScalarType real_dtype = toRealValueType(self.scalar_type());
- checkLinalgCompatibleDtype("symeig", vals.scalar_type(), real_dtype, "eigenvalues");
-
- Tensor vals_tmp, vecs_tmp;
- std::tie(vals_tmp, vecs_tmp) = at::symeig(self, eigenvectors, upper);
-
- at::native::resize_output(vals, vals_tmp.sizes());
- at::native::resize_output(vecs, vecs_tmp.sizes());
- vals.copy_(vals_tmp);
- vecs.copy_(vecs_tmp);
- return std::tuple<Tensor&, Tensor&>(vals, vecs);
-}
-
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This function returns complex-valued eigenvectors that is obtained from LAPACK GEEV's real-valued output
diff --git a/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp b/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp
index b445e3a..045bfa8 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp
+++ b/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp
@@ -32,8 +32,7 @@
namespace at::native {
#if defined(BUILD_LAZY_CUDA_LINALG)
namespace {
-cuda::detail::LinalgDispatch disp = {_symeig_helper_cuda,
- _cholesky_solve_helper_cuda};
+cuda::detail::LinalgDispatch disp = {_cholesky_solve_helper_cuda};
at::DynamicLibrary& getTorchLinalgLibrary() {
static at::DynamicLibrary lib("libtorch_cuda_linalg.so", nullptr, true);
@@ -174,12 +173,6 @@
return disp.cholesky_solve_helper(self, A, upper);
}
-std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvectors, bool upper) {
- getTorchLinalgLibrary();
- TORCH_CHECK(disp.symeig_helper != _symeig_helper_cuda, "Can't find _symeig_helper_cuda");
- return disp.symeig_helper(self, eigenvectors, upper);
-}
-
#endif /*defined(BUILD_LAZY_CUDA_LINALG)*/
} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
index 7126299..8726019 100644
--- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp
@@ -24,7 +24,6 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_cholesky_solve_helper_native.h>
-#include <ATen/ops/_symeig_helper_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
@@ -1873,8 +1872,6 @@
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel);
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ symeig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
template <typename scalar_t>
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
#if !AT_MAGMA_ENABLED()
@@ -1949,39 +1946,6 @@
#endif
}
-std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvectors, bool upper) {
- Tensor infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt).device(at::kCPU));
-
- auto eigvals_shape = IntArrayRef(self.sizes().data(), self.dim()-1); // self.shape[:-1]
- ScalarType real_dtype = toRealValueType(self.scalar_type());
-
- // magmaSyevd uses a hybrid CPU-GPU algorithm to compute the eigenvalues and eigenvectors.
- // The driver routine magma_(d/s)syev_gpu accepts a tensor on the CPU for eigvalenvalues.
- // The data is later moved to the appropriate device.
- // In the case where self.numel() == 0, we just return an empty tensor of
- // dimensions on the CUDA (to avoid the unnecessary "to(at::kCUDA)")
- auto eigvals_working_copy = self.numel() == 0
- ? at::empty(eigvals_shape, self.options().dtype(real_dtype))
- : at::empty(eigvals_shape, self.options().dtype(real_dtype).device(at::kCPU));
-
- if (self.numel() == 0) {
- return std::tuple<Tensor, Tensor>(eigvals_working_copy, at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
- }
-
- auto self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "symeig_cuda", [&]{
- apply_magma_eigh<scalar_t>(eigvals_working_copy, self_working_copy, infos, upper, eigenvectors);
- });
-
- at::_linalg_check_errors(infos, "symeig", self.dim() == 2);
-
- if (eigenvectors) {
- return std::tuple<Tensor, Tensor>(eigvals_working_copy.to(self.device()), self_working_copy);
- } else {
- return std::tuple<Tensor, Tensor>(eigvals_working_copy.to(self.device()), at::empty({0}, self.options()));
- }
-}
-
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This is a type dispatch function for 'apply_magma_eigh'
@@ -2796,8 +2760,7 @@
#if defined(BUILD_LAZY_CUDA_LINALG)
struct DispatchInitializer {
DispatchInitializer() {
- cuda::detail::LinalgDispatch disp{ _symeig_helper_cuda,
- _cholesky_solve_helper_cuda};
+ cuda::detail::LinalgDispatch disp{_cholesky_solve_helper_cuda};
cuda::detail::registerLinalgDispatch(disp);
};
} initializer;
diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h
index 532919e..3fdf3eb 100644
--- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h
+++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h
@@ -84,7 +84,6 @@
// This is only used for an old-style dispatches
// Please do not add any new entires to it
struct LinalgDispatch {
- std::tuple<Tensor, Tensor> (*symeig_helper)(const Tensor& self, bool eigenvectors, bool upper);
Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
};
C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7a9382d..125423f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8699,22 +8699,6 @@
- func: linalg_vander(Tensor x, *, int? N=None) -> Tensor
python_module: linalg
-- func: symeig.e(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
- dispatch:
- CompositeExplicitAutograd: symeig_out
-
-- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
- variants: method, function
- dispatch:
- CompositeExplicitAutograd: symeig
-
-- func: _symeig_helper(Tensor self, bool eigenvectors, bool upper) -> (Tensor, Tensor)
- variants: function
- dispatch:
- CPU: _symeig_helper_cpu
- CUDA: _symeig_helper_cuda
- autogen: _symeig_helper.out
-
- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst
index 2700e61..4f6de6f 100644
--- a/docs/source/tensors.rst
+++ b/docs/source/tensors.rst
@@ -650,7 +650,6 @@
Tensor.svd
Tensor.swapaxes
Tensor.swapdims
- Tensor.symeig
Tensor.t
Tensor.t_
Tensor.tensor_split
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index bbec47f..a4f0a2c 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -589,7 +589,6 @@
svd
svd_lowrank
pca_lowrank
- symeig
lobpcg
trapz
trapezoid
diff --git a/test/cpp/lazy/test_lazy_ops.cpp b/test/cpp/lazy/test_lazy_ops.cpp
index 4f48cd8..a098e36 100644
--- a/test/cpp/lazy/test_lazy_ops.cpp
+++ b/test/cpp/lazy/test_lazy_ops.cpp
@@ -1028,39 +1028,6 @@
}
}
-TEST_F(LazyOpsTest, TestSymEig) {
- static const int dims[] = {4, 7};
- for (auto m : dims) {
- for (bool eigenvectors : {true, false}) {
- for (bool upper : {true, false}) {
- torch::Tensor a = torch::rand(
- {m, m},
- torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
- torch::Tensor sym_a = a.mm(a.t());
- auto b = torch::symeig(sym_a, eigenvectors, upper);
- ForEachDevice([&](const torch::Device& device) {
- torch::Tensor lazy_a = CopyToDevice(sym_a, device);
- auto lazy_b = torch::symeig(lazy_a, eigenvectors, upper);
- AllClose(
- std::get<0>(b),
- std::get<0>(lazy_b),
- /*rtol=*/3e-2,
- /*atol=*/1e-2);
- if (eigenvectors) {
- AllClose(
- std::get<1>(b).abs(),
- std::get<1>(lazy_b).abs(),
- /*rtol=*/3e-2,
- /*atol=*/1e-2);
- } else {
- EXPECT_EQ(std::get<1>(b).sizes(), std::get<1>(lazy_b).sizes());
- }
- });
- }
- }
- }
-}
-
TEST_F(LazyOpsTest, TestCholesky) {
static const int dims[] = {4, 7};
for (auto m : dims) {
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 1569702..854c18b 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -476,7 +476,6 @@
xfail("stft"),
xfail("svd"),
xfail("svd_lowrank"),
- xfail("symeig"),
xfail("t"),
xfail("take_along_dim"),
xfail("take"),
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8b6b71c..9ff4d1d 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -472,8 +472,6 @@
aten::_standard_gamma.out
aten::_standard_gamma_grad
aten::_standard_gamma_grad.out
-aten::_symeig_helper
-aten::_symeig_helper.out
aten::_test_autograd_multiple_dispatch.fullcoverage
aten::_test_autograd_multiple_dispatch.fullcoverage_out
aten::_test_autograd_multiple_dispatch_view
@@ -1270,8 +1268,6 @@
aten::squeeze_copy.out
aten::sspaddmm.out
aten::std_mean.correction_out
-aten::symeig
-aten::symeig.e
aten::t_
aten::t_copy
aten::t_copy.out
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 72c43e6..672b0ab 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -118,6 +118,9 @@
("aten::_nested_tensor", datetime.date(9999, 1, 1)),
("prepacked::unpack_prepacked_sizes_conv2d", datetime.date(9999, 1, 1)),
("prepacked::unpack_prepacked_sizes_linear", datetime.date(9999, 1, 1)),
+ ("aten::_symeig_helper", datetime.date(9999, 1, 1)),
+ ("aten::symeig", datetime.date(9999, 1, 1)),
+ ("aten::symeig.e", datetime.date(9999, 1, 1)),
("aten::linalg_solve", datetime.date(2022, 8, 31)),
("aten::linalg_solve.out", datetime.date(2022, 8, 31)),
("aten::quantile", datetime.date(2022, 9, 30)),
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 1bef054..261c886 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2395,7 +2395,6 @@
xfail('sum_to_size', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('svd', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('svd_lowrank', ''), # could not find kernel
- xfail('symeig', ''), # aten.symeig.default - couldn't find symbolic meta function/decomposition
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 5265490..d923ac8 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1340,7 +1340,6 @@
xfail('NumpyCubeNotComposableAutogradFunction'), # not composable
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('ormqr', ''), # NYI: forward AD for ormqr
- xfail('symeig', ''), # NYI: forward AD for symeig
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
xfail('nn.functional.multilabel_soft_margin_loss', ''), # NYI: log_sigmoid_backward
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward
@@ -1507,7 +1506,6 @@
xfail('segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
xfail('sparse.sampled_addmm'), # RuntimeError: Sparse CSR tensors do not have strides
xfail('svd_lowrank'), # calls random op
- xfail('symeig'), # Forward AD not implemented and no decomposition
xfail('take'), # vmap: inplace into regular tensor
xfail('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('to_sparse'), # Forward AD not implemented and no decomposition
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index 632b407..7b7996e 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -20,7 +20,7 @@
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
- skipCUDAIfNoMagma, OpDTypes
+ OpDTypes
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_utils import (
parametrize,
@@ -3261,16 +3261,6 @@
with self.assertRaisesRegex(RuntimeError, r"Attempted to vmap over aten::where"):
vmap(f)(x)
- @skipCUDAIfNoMagma
- @allowVmapFallbackUsage
- def test_symeig(self, device):
- def op(x):
- return torch.symeig(x, eigenvectors=True)[0]
-
- x = torch.randn(3, 3, device=device, requires_grad=True)
- self._batched_grad_test(op, (x,), {})
- self._batched_grad_grad_test(op, (x,), {})
-
def test_threshold(self, device):
x = torch.randn(2, 3, device=device, requires_grad=True)
self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
diff --git a/test/test_autograd.py b/test/test_autograd.py
index f420212..dbe045b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4482,14 +4482,6 @@
out.backward()
- # TODO: update these tests to use the linalg module and move to test_linalg.py
- @skipIfNoLapack
- def test_symeig_no_eigenvectors(self):
- A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
- w, v = torch.symeig(A, eigenvectors=False)
- with self.assertRaisesRegex(RuntimeError, 'is not differentiable'):
- torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])
-
def test_no_grad_copy(self):
# create autograd function that saves grad pointer as class static
class MyFunc(Function):
diff --git a/test/test_legacy_vmap.py b/test/test_legacy_vmap.py
index 61edb1c..56d6e05 100644
--- a/test/test_legacy_vmap.py
+++ b/test/test_legacy_vmap.py
@@ -8,8 +8,7 @@
import functools
import itertools
import warnings
-from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
- skipCUDAIfNoMagma
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
import types
@@ -2415,16 +2414,6 @@
x = torch.randn(2, 3, device=device, requires_grad=True)
self._batched_grad_test(Tensor.trace, (x,))
- @skipCUDAIfNoMagma
- @allowVmapFallbackUsage
- def test_symeig(self, device):
- def op(x):
- return torch.symeig(x, eigenvectors=True)[0]
-
- x = torch.randn(3, 3, device=device, requires_grad=True)
- self._batched_grad_test(op, (x,), {})
- self._batched_grad_grad_test(op, (x,), {})
-
def test_threshold(self, device):
x = torch.randn(2, 3, device=device, requires_grad=True)
self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
diff --git a/test/test_linalg.py b/test/test_linalg.py
index fe2f4c5..bb62e67 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -161,6 +161,13 @@
with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
a.eig()
+ def test_symeig_removed_error(self, device):
+ a = make_tensor(5, 5, device=device, dtype=torch.float32)
+ with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
+ torch.symeig(a)
+ with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
+ a.symeig()
+
def test_lstsq_removed_error(self, device):
a = make_tensor(5, 5, device=device, dtype=torch.float32)
with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
@@ -5095,7 +5102,7 @@
self.assertEqual(E.shape, batches + (k,))
self.assertEqual(V.shape, batches + (m, k))
self.assertEqual(matmul(A, V), mm(V, E.diag_embed()), atol=prec, rtol=0)
- e = torch.symeig(A)[0]
+ e = torch.linalg.eigvalsh(A)
e_smallest = e[..., :k]
self.assertEqual(E, e_smallest)
@@ -6972,98 +6979,6 @@
run_test((1, 1), (1, 1, 1025))
- @precisionOverride({torch.float32: 1e-5, torch.complex64: 1e-5})
- @skipCUDAIfNoMagma
- @skipCPUIfNoLapack
- @dtypes(*floating_and_complex_types())
- def test_symeig(self, device, dtype):
- from torch.testing._internal.common_utils import random_hermitian_matrix
-
- def run_test(dims, eigenvectors, upper):
- x = random_hermitian_matrix(*dims, dtype=dtype, device=device)
- if dtype.is_complex:
- real_dtype = torch.float32 if dtype is torch.complex64 else torch.float64
- else:
- real_dtype = dtype
- oute = torch.empty(dims[1:] + dims[:1], dtype=real_dtype, device=device)
- outv = torch.empty(dims[1:] + dims[:1] * 2, dtype=dtype, device=device)
- torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv))
-
- if eigenvectors:
- outv_ = outv.cpu().numpy()
- x_recon = np.matmul(np.matmul(outv_, torch.diag_embed(oute.to(dtype)).cpu().numpy()),
- outv_.swapaxes(-2, -1).conj())
- self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T')
- else:
- eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
- self.assertEqual(eigvals, oute, msg='Eigenvalues mismatch')
- self.assertEqual(torch.empty(0, device=device, dtype=dtype), outv, msg='Eigenvector matrix not empty')
-
- rese, resv = x.symeig(eigenvectors=eigenvectors, upper=upper)
- self.assertEqual(rese, oute, msg="outputs of symeig and symeig with out don't match")
- self.assertEqual(resv, outv, msg="outputs of symeig and symeig with out don't match")
-
- # test non-contiguous
- x = random_hermitian_matrix(*dims, dtype=dtype, device=device)
- n_dim = len(dims) + 1
- # Reverse the batch dimensions and the matrix dimensions and then concat them
- x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2))
- assert not x.is_contiguous(), "x is intentionally non-contiguous"
- rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper)
- if eigenvectors:
- resv_ = resv.cpu().numpy()
- x_recon = np.matmul(np.matmul(resv_, torch.diag_embed(rese.to(dtype)).cpu().numpy()),
- resv_.swapaxes(-2, -1).conj())
- self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using V @ diag(e) @ V.T')
- else:
- eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper)
- self.assertEqual(eigvals, rese, msg='Eigenvalues mismatch')
- self.assertEqual(torch.empty(0, device=device, dtype=dtype), resv, msg='Eigenvector matrix not empty')
-
- batch_dims_set = [(), (3,), (3, 5), (5, 3, 5)]
- for batch_dims, eigenvectors, upper in itertools.product(batch_dims_set, (True, False), (True, False)):
- run_test((5,) + batch_dims, eigenvectors, upper)
-
- @skipCUDAIfNoMagma
- @skipCPUIfNoLapack
- @dtypes(*floating_and_complex_types())
- def test_symeig_out_errors_and_warnings(self, device, dtype):
- from torch.testing._internal.common_utils import random_hermitian_matrix
-
- # if non-empty out tensor with wrong shape is passed a warning is given
- a = random_hermitian_matrix(3, dtype=dtype, device=device)
- real_dtype = a.real.dtype if dtype.is_complex else dtype
- out_w = torch.empty(7, 7, dtype=real_dtype, device=device)
- out_v = torch.empty(7, 7, dtype=dtype, device=device)
- with warnings.catch_warnings(record=True) as w:
- # Trigger warning
- torch.symeig(a, out=(out_w, out_v))
- self.assertTrue("An output with one or more elements was resized" in str(w[-2].message))
- self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
-
- # dtypes should be safely castable
- out_w = torch.empty(0, dtype=real_dtype, device=device)
- out_v = torch.empty(0, dtype=torch.int, device=device)
- with self.assertRaisesRegex(RuntimeError, "but got eigenvectors with dtype Int"):
- torch.symeig(a, out=(out_w, out_v))
-
- out_w = torch.empty(0, dtype=torch.int, device=device)
- out_v = torch.empty(0, dtype=dtype, device=device)
- with self.assertRaisesRegex(RuntimeError, "but got eigenvalues with dtype Int"):
- torch.symeig(a, out=(out_w, out_v))
-
- # device should match
- if torch.cuda.is_available():
- wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda'
- out_w = torch.empty(0, device=wrong_device, dtype=dtype)
- out_v = torch.empty(0, device=device, dtype=dtype)
- with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
- torch.symeig(a, out=(out_w, out_v))
- out_w = torch.empty(0, device=device, dtype=dtype)
- out_v = torch.empty(0, device=wrong_device, dtype=dtype)
- with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
- torch.symeig(a, out=(out_w, out_v))
-
@skipCUDAIfNoCusolver
@skipCPUIfNoLapack
def test_pca_lowrank(self, device):
diff --git a/test/test_meta.py b/test/test_meta.py
index 16a3886..583d452 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -632,7 +632,6 @@
torch.polar : {f64, f32},
torch.segment_reduce : {f64, f16, bf16, f32},
torch.searchsorted : {f64, i32, i64, f16, u8, i16, bf16, i8, f32},
- torch.symeig : {f64, f32, c128, c64},
torch.cholesky : {f64, f32, c128, c64},
torch.cholesky_inverse : {f64, f32, c128, c64},
torch.cholesky_solve : {f64, f32, c128, c64},
@@ -846,7 +845,6 @@
aten.ormqr.default : {c64, c128, f64, f32},
aten.ormqr.out : {c64, c128, f64, f32},
aten.polar.out : {f32, f64},
- aten.symeig.default : {c64, c128, f64, f32},
aten.take.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.take.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
aten.tensordot.out : {c64, i8, f64, c128, i64, bf16, f32, i32, i16, u8},
diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py
index 4878253..b0a209f 100644
--- a/test/test_namedtuple_return_api.py
+++ b/test/test_namedtuple_return_api.py
@@ -13,7 +13,7 @@
path = os.path.dirname(os.path.realpath(__file__))
aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml')
all_operators_with_namedtuple_return = {
- 'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig',
+ 'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd',
'qr', 'geqrf', 'slogdet', 'sort', 'topk', 'linalg_inv_ex',
'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_linalg_eigh", "_unpack_dual", 'linalg_qr',
'linalg_svd', '_linalg_svd', 'linalg_slogdet', '_linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
@@ -77,7 +77,6 @@
op(operators=['_linalg_slogdet'], input=(), names=('sign', 'logabsdet', 'LU', 'pivots'), hasout=True),
op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True),
op(operators=['geqrf'], input=(), names=('a', 'tau'), hasout=True),
- op(operators=['symeig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True),
op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
op(operators=['linalg_eig'], input=(), names=('eigenvalues', 'eigenvectors'), hasout=True),
op(operators=['linalg_eigh'], input=("L",), names=('eigenvalues', 'eigenvectors'), hasout=True),
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 424cf5f..190e2b3 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1358,7 +1358,6 @@
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('svd_lowrank', ''), # aten.mm.default - couldn't find symbolic meta function/decomposition
- xfail('symeig', ''), # aten.symeig.default - couldn't find symbolic meta function/decomposition
xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition
xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 9ec2bb3..f5b4ab8 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1588,9 +1588,6 @@
full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)"
U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices)
-- name: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)
- self: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors_return, /*is_hermitian=*/true, /*symeig_eigenvector=*/eigenvectors)
-
- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true)
eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true)
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 0361c27..06cb7f0 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -117,7 +117,6 @@
"_cholesky.*",
"_triangular_solve.*",
"_qr.*",
- "_symeig.*",
"_svd.*",
"slice",
"item",
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 4e1ca78..4fea5f7 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -305,7 +305,6 @@
"reflection_pad1d_backward",
"reflection_pad2d_backward",
"reflection_pad3d_backward",
- "symeig",
"_sparse_sparse_matmul",
"replication_pad1d",
"replication_pad2d",
diff --git a/torch/__init__.py b/torch/__init__.py
index ae0c6f3..08eab4b 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1307,6 +1307,7 @@
solve,
lstsq,
)
+from ._linalg_utils import _symeig as symeig # type: ignore[misc]
class _TorchCompileInductorWrapper:
compiler_name = "inductor"
diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py
index bdd22f3..3a81fc6 100644
--- a/torch/_linalg_utils.py
+++ b/torch/_linalg_utils.py
@@ -113,6 +113,14 @@
)
+def _symeig(
+ input, eigenvectors=False, upper=True, *, out=None
+) -> Tuple[Tensor, Tensor]:
+ raise RuntimeError(
+ "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eigh` function instead.",
+ )
+
+
def eig(
self: Tensor, eigenvectors: bool = False, *, e=None, v=None
) -> Tuple[Tensor, Tensor]:
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 7a70653..64e3d06 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -662,6 +662,11 @@
return eig(self, eigenvectors=eigenvectors)
+ def symeig(self, eigenvectors=False):
+ from ._linalg_utils import _symeig
+
+ return _symeig(self, eigenvectors=eigenvectors)
+
def lu(self, pivot=True, get_infos=False):
r"""See :func:`torch.lu`"""
# If get_infos is True, then we don't need to check for errors and vice versa
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 427cd5b..7210acb 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -4917,15 +4917,6 @@
)
add_docstr_all(
- "symeig",
- r"""
-symeig(eigenvectors=False, upper=True) -> (Tensor, Tensor)
-
-See :func:`torch.symeig`
-""",
-)
-
-add_docstr_all(
"swapdims",
r"""
swapdims(dim0, dim1) -> Tensor
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 664b8b1..77404e2 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -11094,104 +11094,6 @@
""",
)
-add_docstr(
- torch.symeig,
- r"""
-symeig(input, eigenvectors=False, upper=True, *, out=None) -> (Tensor, Tensor)
-
-This function returns eigenvalues and eigenvectors
-of a real symmetric or complex Hermitian matrix :attr:`input` or a batch thereof,
-represented by a namedtuple (eigenvalues, eigenvectors).
-
-This function calculates all eigenvalues (and vectors) of :attr:`input`
-such that :math:`\text{input} = V \text{diag}(e) V^T`.
-
-The boolean argument :attr:`eigenvectors` defines computation of
-both eigenvectors and eigenvalues or eigenvalues only.
-
-If it is ``False``, only eigenvalues are computed. If it is ``True``,
-both eigenvalues and eigenvectors are computed.
-
-Since the input matrix :attr:`input` is supposed to be symmetric or Hermitian,
-only the upper triangular portion is used by default.
-
-If :attr:`upper` is ``False``, then lower triangular portion is used.
-
-.. warning::
-
- :func:`torch.symeig` is deprecated in favor of :func:`torch.linalg.eigh`
- and will be removed in a future PyTorch release. The default behavior has changed
- from using the upper triangular portion of the matrix by default to using the
- lower triangular portion.
-
- ``L, _ = torch.symeig(A, upper=upper)`` should be replaced with
-
- .. code :: python
-
- UPLO = "U" if upper else "L"
- L = torch.linalg.eigvalsh(A, UPLO=UPLO)
-
- ``L, V = torch.symeig(A, eigenvectors=True, upper=upper)`` should be replaced with
-
- .. code :: python
-
- UPLO = "U" if upper else "L"
- L, V = torch.linalg.eigh(A, UPLO=UPLO)
-
-.. note:: The eigenvalues are returned in ascending order. If :attr:`input` is a batch of matrices,
- then the eigenvalues of each matrix in the batch is returned in ascending order.
-
-.. note:: Irrespective of the original strides, the returned matrix `V` will
- be transposed, i.e. with strides `V.contiguous().mT.stride()`.
-
-.. warning:: Extra care needs to be taken when backward through outputs. Such
- operation is only stable when all eigenvalues are distinct and becomes
- less stable the smaller :math:`\min_{i \neq j} |\lambda_i - \lambda_j|` is.
-
-Args:
- input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more
- batch dimensions consisting of symmetric or Hermitian matrices.
- eigenvectors(bool, optional): controls whether eigenvectors have to be computed
- upper(bool, optional): controls whether to consider upper-triangular or lower-triangular region
-
-Keyword args:
- out (tuple, optional): the output tuple of (Tensor, Tensor)
-
-Returns:
- (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing
-
- - **eigenvalues** (*Tensor*): Shape :math:`(*, m)`. The eigenvalues in ascending order.
- - **eigenvectors** (*Tensor*): Shape :math:`(*, m, m)`.
- If ``eigenvectors=False``, it's an empty tensor.
- Otherwise, this tensor contains the orthonormal eigenvectors of the ``input``.
-
-Examples::
-
-
- >>> a = torch.randn(5, 5)
- >>> a = a + a.t() # To make a symmetric
- >>> a
- tensor([[-5.7827, 4.4559, -0.2344, -1.7123, -1.8330],
- [ 4.4559, 1.4250, -2.8636, -3.2100, -0.1798],
- [-0.2344, -2.8636, 1.7112, -5.5785, 7.1988],
- [-1.7123, -3.2100, -5.5785, -2.6227, 3.1036],
- [-1.8330, -0.1798, 7.1988, 3.1036, -5.1453]])
- >>> e, v = torch.symeig(a, eigenvectors=True)
- >>> e
- tensor([-13.7012, -7.7497, -2.3163, 5.2477, 8.1050])
- >>> v
- tensor([[ 0.1643, 0.9034, -0.0291, 0.3508, 0.1817],
- [-0.2417, -0.3071, -0.5081, 0.6534, 0.4026],
- [-0.5176, 0.1223, -0.0220, 0.3295, -0.7798],
- [-0.4850, 0.2695, -0.5773, -0.5840, 0.1337],
- [ 0.6415, -0.0447, -0.6381, -0.0193, -0.4230]])
- >>> a_big = torch.randn(5, 2, 2)
- >>> a_big = a_big + a_big.mT # To make a_big symmetric
- >>> e, v = a_big.symeig(eigenvectors=True)
- >>> torch.allclose(torch.matmul(v, torch.matmul(e.diag_embed(), v.mT)), a_big)
- True
-""",
-)
add_docstr(
torch.t,
diff --git a/torch/overrides.py b/torch/overrides.py
index 469fdb8..2fcdb37 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -277,6 +277,7 @@
Tensor.new_full,
Tensor._make_subclass,
Tensor.solve,
+ Tensor.symeig,
Tensor.stride,
Tensor.unflatten,
Tensor.to_sparse_coo,
@@ -1009,7 +1010,6 @@
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
torch.linalg.svdvals: lambda input, out=None: -1,
- torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
torch.swapaxes: lambda input, dim0, dim1: -1,
torch.swapdims: lambda input, axis0, axis1: -1,
torch.special.airy_ai: lambda input: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ba4d409..fbaaffa 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5033,16 +5033,6 @@
other = make_input((*batch, *other_matrix_shape), requires_grad=requires_grad)
yield SampleInput(reflectors, tau, other, left=left, transpose=transpose)
-def sample_inputs_symeig(op_info, device, dtype, requires_grad=False, **kwargs):
- out = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
-
- for o in out:
- o.kwargs = {"upper": bool(np.random.choice([True, False])),
- "eigenvectors": True}
- # A gauge-invariant function
- o.output_process_fn_grad = lambda output: (output[0], abs(output[1]))
- yield o
-
def sample_inputs_cholesky_solve(op_info, device, dtype, requires_grad=False, **kwargs):
cholesky_inverse_samples = sample_inputs_linalg_cholesky_inverse(
@@ -9546,21 +9536,6 @@
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float,)),
)),
- OpInfo('symeig',
- dtypes=floating_and_complex_types(),
- check_batched_grad=False,
- check_batched_gradgrad=False,
- sample_inputs_func=sample_inputs_symeig,
- gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
- skips=(
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out',
- device_type='mps', dtypes=[torch.float32]),
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager',
- device_type='mps', dtypes=[torch.float32]),
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit',
- device_type='mps', dtypes=[torch.float32]),
- ),
- decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off]),
OpInfo('clamp',
aliases=('clip',),
ref=_clamp_numpy,