Better and more consistent error messages in torch.linalg (#62734)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62734

Following https://github.com/pytorch/pytorch/pull/62715#discussion_r682610788
- squareCheckInputs now takes the name of the calling function and prefixes its error messages with it
- We reuse more helper functions when checking the inputs (e.g. checkSameDevice, and the new checkFloatingOrComplex)

The state of the error messages in torch.linalg is still far from great,
though; we leave a more comprehensive clean-up for the future.
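
As a rough sketch of the user-visible effect (the message wording comes from
the updated squareCheckInputs below; the function and shapes are just an
example):

    import torch

    # Before this patch the error was unattributed:
    #   "A must be batches of square matrices, but they are 2 by 3 matrices"
    # After it, the dotted public name of the function is prepended:
    #   "linalg.inv: A must be batches of square matrices, but they are 2 by 3 matrices"
    try:
        torch.linalg.inv(torch.randn(3, 2))
    except RuntimeError as e:
        print(e)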

cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D31823230

Pulled By: mruberry

fbshipit-source-id: eccd531f10d590eb5f9d04a957b7cdcb31c72ea4
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index b68aa1c..9e08bac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -801,15 +801,14 @@
 // Solves a system of linear equations matmul(input, x) = other in-place
 // LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
 static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor& input, const Tensor& other) {
-  checkSameDevice("linalg_solve", result, input);
-  checkSameDevice("linalg_solve", other, input, "other");
-  checkLinalgCompatibleDtype("linalg_solve", result, input);
+  checkSameDevice("linalg.solve", result, input);
+  checkSameDevice("linalg.solve", other, input, "other");
+  checkLinalgCompatibleDtype("linalg.solve", result, input);
 
   TORCH_CHECK(input.scalar_type() == other.scalar_type(),
     "input dtype ", input.scalar_type(), " does not match other dtype ", other.scalar_type());
 
-  TORCH_CHECK(input.dim() >= 2,
-           "input should have at least 2 dimensions, but has ", input.dim(), " dimensions instead");
+  squareCheckInputs(input, "linalg.solve");
   TORCH_CHECK(other.dim() >= 1,
            "other should have at least 1 dimension, but has ", other.dim(), " dimensions instead");
 
@@ -856,7 +855,7 @@
   // _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
 // it checks for squareness of 'input' and shape compatibility of 'other' and 'input'
   Tensor other_broadcasted, input_broadcasted;
-  std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg_solve");
+  std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
 
   auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
   auto squeezed_result_shape = squeezed_other_broadcasted.sizes();
@@ -928,9 +927,9 @@
   // batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
   bool vector_case = linalg_solve_is_vector_rhs(input, other);
   if (vector_case ? result.dim() > 1 : result.dim() > 2) {
-    batchCheckErrors(infos, "linalg_solve");
+    batchCheckErrors(infos, "linalg.solve");
   } else {
-    singleCheckErrors(infos.item().toInt(), "linalg_solve");
+    singleCheckErrors(infos.item().toInt(), "linalg.solve");
   }
 
   return result;
@@ -1015,7 +1014,7 @@
   if (self.numel() == 0) {
     return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  squareCheckInputs(self);
+  squareCheckInputs(self, "inverse");
   return at::_inverse_helper(self);
 }
 
@@ -1042,9 +1041,9 @@
 // Computes the inverse matrix of 'input'; it is saved to 'result' in-place
 // LAPACK/MAGMA/cuSOLVER error codes are saved in 'infos' tensors, they are not checked here
 static Tensor& linalg_inv_out_info(Tensor& result, Tensor& infos_lu, Tensor& infos_getri, const Tensor& input) {
-  squareCheckInputs(input);
-  checkSameDevice("linalg_inv", result, input);
-  checkLinalgCompatibleDtype("linalg_inv", result, input);
+  squareCheckInputs(input, "linalg.inv");
+  checkSameDevice("linalg.inv", result, input);
+  checkLinalgCompatibleDtype("linalg.inv", result, input);
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos_lu.scalar_type() == kInt);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos_getri.scalar_type() == kInt);
@@ -1137,11 +1136,11 @@
 
   // Now check LAPACK/MAGMA/cuSOLVER error codes
   if (result.dim() > 2) {
-    batchCheckErrors(infos_lu, "linalg_inv_lu");
-    batchCheckErrors(infos_getri, "linalg_inv_getri");
+    batchCheckErrors(infos_lu, "linalg.inv");
+    batchCheckErrors(infos_getri, "linalg.inv");
   } else {
-    singleCheckErrors(infos_lu.item().toInt(), "linalg_inv_lu");
-    singleCheckErrors(infos_getri.item().toInt(), "linalg_inv_getri");
+    singleCheckErrors(infos_lu.item().toInt(), "linalg.inv");
+    singleCheckErrors(infos_getri.item().toInt(), "linalg.inv");
   }
 
   return result;
@@ -1164,7 +1163,7 @@
 }
 
 std::tuple<Tensor&, Tensor&> linalg_inv_ex_out(const Tensor& input, bool check_errors, Tensor& inverse, Tensor& info) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.inv_ex");
   ScalarType info_output_type = ScalarType::Int;
   TORCH_CHECK(
       info.scalar_type() == info_output_type,
@@ -1191,7 +1190,7 @@
 }
 
 std::tuple<Tensor, Tensor> linalg_inv_ex(const Tensor& input, bool check_errors) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.inv_ex");
   Tensor inverse = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
   inverse.transpose_(-2, -1); // make `inverse` tensor with batched column major format
   auto info_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2); // input.shape[:-2]
@@ -1287,7 +1286,7 @@
   if (self.numel() == 0) {
     return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   }
-  squareCheckInputs(self);
+  squareCheckInputs(self, "cholesky");
 
   auto raw_cholesky_output = cloneBatchedColumnMajor(self);
   auto info_shape = IntArrayRef(
@@ -1375,7 +1374,7 @@
 }
 
 std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool upper, bool check_errors, Tensor& L, Tensor& info) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.cholesky_ex");
   checkSameDevice("torch.linalg.cholesky_ex", L, input, "L");
   checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L");
   checkSameDevice("torch.linalg.cholesky_ex", info, input, "info");
@@ -1510,7 +1509,7 @@
 }
 
 Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "cholesky_inverse");
   checkSameDevice("cholesky_inverse", result, input);
   checkLinalgCompatibleDtype("cholesky_inverse", result, input);
 
@@ -2032,14 +2031,7 @@
       tau.scalar_type(),
       " does not match input dtype ",
       input.scalar_type());
-  TORCH_CHECK(
-      input.device() == tau.device(),
-      "torch.linalg.householder_product: Expected input and tau to be on the same device, but found input on ",
-      input.device(),
-      " and tau on ",
-      tau.device(),
-      " instead.");
-
+  checkSameDevice("torch.linalg.householder_product", tau, input, "tau");
   checkSameDevice("torch.linalg.householder_product", result, input);
   checkLinalgCompatibleDtype("torch.linalg.householder_product", result, input);
 
@@ -2312,7 +2304,7 @@
 }
 
 std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& input, c10::string_view uplo) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.eigh");
   checkUplo(uplo);
   ScalarType real_dtype = toValueType(input.scalar_type());
   Tensor values = at::empty({0}, input.options().dtype(real_dtype));
@@ -2369,7 +2361,7 @@
   ScalarType real_dtype = toValueType(input.scalar_type());
   checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype);
 
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.eigvalsh");
   checkUplo(uplo);
 
   auto expected_result_shape = IntArrayRef(input.sizes().data(), input.dim()-1);  // input.shape[:-1]
@@ -2498,7 +2490,7 @@
     "should be replaced with\n",
     "L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
   );
-  squareCheckInputs(self);
+  squareCheckInputs(self, "linalg.symeig");
   return at::_symeig_helper(self, eigenvectors, upper);
 }
 
@@ -2708,7 +2700,7 @@
 }
 
 std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.eig");
 
   // unlike NumPy for real-valued inputs the output is always complex-valued
   checkLinalgCompatibleDtype("torch.linalg.eig", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
@@ -2805,7 +2797,7 @@
 }
 
 Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
-  squareCheckInputs(input);
+  squareCheckInputs(input, "linalg.eigvals");
 
   // unlike NumPy for real-valued inputs the output is always complex-valued
   checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
@@ -3094,7 +3086,6 @@
 
     Tensor Vh = V.mH();
     return std::make_tuple(U, S, Vh);
-
 }
 
 static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) {
@@ -3107,11 +3098,11 @@
   checkSameDevice("svd", U, self, "U");
   checkSameDevice("svd", S, self, "S");
   checkSameDevice("svd", Vh, self, "Vh");
-  checkLinalgCompatibleDtype("linalg_svd", U, self, "U");
-  checkLinalgCompatibleDtype("linalg_svd", Vh, self, "Vh");
+  checkLinalgCompatibleDtype("linalg.svd", U, self, "U");
+  checkLinalgCompatibleDtype("linalg.svd", Vh, self, "Vh");
   // singular values are always real-valued here
   ScalarType real_dtype = toValueType(self.scalar_type());
-  checkLinalgCompatibleDtype("linalg_svd", S.scalar_type(), real_dtype, "S");
+  checkLinalgCompatibleDtype("linalg.svd", S.scalar_type(), real_dtype, "S");
   Tensor U_tmp, S_tmp, Vh_tmp;
   std::tie(U_tmp, S_tmp, Vh_tmp) = at::native::linalg_svd(self, full_matrices);
   svd_resize_and_copy("U", U_tmp, U);
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index ce7d176..03e14b7 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -24,6 +24,7 @@
 #include <functional>
 #include <limits>
 #include <numeric>
+#include <string>
 
 namespace at {
 namespace meta {
@@ -153,33 +154,28 @@
   return at::linalg_det(self);
 }
 
-Tensor& linalg_det_out(const Tensor& self, Tensor& out) {
-  checkSameDevice("torch.linalg.det", out, self, "out");
-  checkLinalgCompatibleDtype("torch.linalg.det", out, self, "out");
-  squareCheckInputs(self);
-  TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
-              "Expected a floating point or complex tensor as input");
-
-  IntArrayRef out_sizes(self.sizes().data(), self.dim() - 2);
-  at::native::resize_output(out, out_sizes);
-
-  auto det = std::get<0>(at::native::_det_lu_based_helper(self));
-  out.copy_(det);
-  return out;
-}
-
 Tensor linalg_det(const Tensor& self) {
-  squareCheckInputs(self);
-  TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
-              "Expected a floating point or complex tensor as input");
+  squareCheckInputs(self, "linalg.det");
+  checkFloatingOrComplex(self, "linalg.det");
 
   return std::get<0>(at::_det_lu_based_helper(self));
 }
 
+Tensor& linalg_det_out(const Tensor& self, Tensor& out) {
+  checkSameDevice("torch.linalg.det", out, self, "out");
+  checkLinalgCompatibleDtype("torch.linalg.det", out, self, "out");
+
+  auto det = at::native::linalg_det(self);
+
+  IntArrayRef out_sizes(self.sizes().data(), self.dim() - 2);
+  at::native::resize_output(out, out_sizes);
+  out.copy_(det);
+  return out;
+}
+
 Tensor logdet(const Tensor& self) {
-  squareCheckInputs(self);
-  TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
-              "Expected a floating point tensor as input");
+  squareCheckInputs(self, "logdet");
+  checkFloatingOrComplex(self, "logdet");
 
   c10::ExclusivelyOwned<Tensor> det_P, diag_U;
   std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
@@ -200,10 +196,10 @@
 }
 
 std::tuple<Tensor, Tensor> linalg_slogdet(const Tensor& self) {
-  squareCheckInputs(self);
+  squareCheckInputs(self, "linalg.slogdet");
   ScalarType t = self.scalar_type();
   TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble,
-              "linalg_slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t);
+              "linalg.slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t);
 
   c10::ExclusivelyOwned<Tensor> det_P, diag_U;
   std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
@@ -218,12 +214,12 @@
 
 // TODO: implement _out variant avoiding copy and using already allocated storage directly
 std::tuple<Tensor&, Tensor&> linalg_slogdet_out(const Tensor& input, Tensor& sign, Tensor& logabsdet) {
-  checkSameDevice("linalg_slogdet", sign, input, "sign");
-  checkSameDevice("linalg_slogdet", logabsdet, input, "logabsdet");
-  checkLinalgCompatibleDtype("linalg_slogdet", sign, input, "sign");
+  checkSameDevice("linalg.slogdet", sign, input, "sign");
+  checkSameDevice("linalg.slogdet", logabsdet, input, "logabsdet");
+  checkLinalgCompatibleDtype("linalg.slogdet", sign, input, "sign");
   ScalarType real_dtype = toValueType(input.scalar_type());
   // logabsdet is always real-valued here
-  checkLinalgCompatibleDtype("linalg_slogdet", logabsdet.scalar_type(), real_dtype, "logabsdet");
+  checkLinalgCompatibleDtype("linalg.slogdet", logabsdet.scalar_type(), real_dtype, "logabsdet");
 
   Tensor sign_tmp, logabsdet_tmp;
   std::tie(sign_tmp, logabsdet_tmp) = at::linalg_slogdet(input);
@@ -301,7 +297,7 @@
   ScalarType t = input.scalar_type();
   TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
               && input.dim() >= 2,
-              "linalg_pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
+              "linalg.pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
               "of float, double, cfloat or cdouble types");
 
   Tensor atol, rtol;
@@ -366,8 +362,8 @@
     const optional<Tensor>& rtol,
     bool hermitian,
     Tensor& result) {
-  checkSameDevice("linalg_pinv", result, input);
-  checkLinalgCompatibleDtype("linalg_pinv", result, input);
+  checkSameDevice("linalg.pinv", result, input);
+  checkLinalgCompatibleDtype("linalg.pinv", result, input);
   Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
   at::native::resize_output(result, result_tmp.sizes());
   result.copy_(result_tmp);
@@ -380,8 +376,8 @@
     optional<double> rtol,
     bool hermitian,
     Tensor& result) {
-  checkSameDevice("linalg_pinv", result, input);
-  checkLinalgCompatibleDtype("linalg_pinv", result, input);
+  checkSameDevice("linalg.pinv", result, input);
+  checkLinalgCompatibleDtype("linalg.pinv", result, input);
   Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
   at::native::resize_output(result, result_tmp.sizes());
   result.copy_(result_tmp);
@@ -389,8 +385,8 @@
 }
 
 Tensor& linalg_pinv_out(const Tensor& input, const Tensor& rcond, bool hermitian, Tensor& result) {
-  checkSameDevice("linalg_pinv", result, input);
-  checkLinalgCompatibleDtype("linalg_pinv", result, input);
+  checkSameDevice("linalg.pinv", result, input);
+  checkLinalgCompatibleDtype("linalg.pinv", result, input);
 
   Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
   at::native::resize_output(result, result_tmp.sizes());
@@ -427,7 +423,7 @@
     c10::optional<Tensor> _out) {
   auto out = _out.value_or(Tensor());
 
-  squareCheckInputs(self);
+  squareCheckInputs(self, "linalg.matrix_power");
   if (_out.has_value()) {
     checkSameDevice("matrix_power", out, self);
     checkLinalgCompatibleDtype("matrix_power", out, self);
@@ -2251,9 +2247,8 @@
 // Mathematics 2019, 7, 1174.
 //
 Tensor linalg_matrix_exp(const Tensor& a) {
-  squareCheckInputs(a);
-  TORCH_CHECK((at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())),
-              "Expected a floating point or complex tensor as input. Got: ", a.scalar_type());
+  squareCheckInputs(a, "linalg.matrix_exp");
+  checkFloatingOrComplex(a, "matrix_exp");
 
   NoTF32Guard disable_tf32;
 
@@ -2524,11 +2519,8 @@
       "but got ", opt_dtype.value());
   }
 
+  checkFloatingOrComplex(self, "linalg.vector_norm");
   ScalarType in_dtype = opt_dtype.value_or(self.scalar_type());
-  TORCH_CHECK(
-      at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
-      "linalg.vector_norm only supports floating point and complex dtypes, but got: ",
-      toString(in_dtype));
 
   IntArrayRef dim = opt_dim.value_or(IntArrayRef{});
 
@@ -2714,21 +2706,20 @@
     Scalar* ord = c10::get_if<Scalar>(&ord_variant);
     double abs_ord = std::abs(ord->toDouble());
     TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
-      "linalg_cond got an invalid norm type: ", ord->toDouble());
+      "linalg.cond got an invalid norm type: ", ord->toDouble());
   } else if (ord_variant.index() == 1) {
     c10::string_view* ord = c10::get_if<c10::string_view>(&ord_variant);
     TORCH_CHECK(*ord == "fro" || *ord == "nuc",
-      "linalg_cond got an invalid norm type: ", *ord);
+      "linalg.cond got an invalid norm type: ", *ord);
   } else {
     TORCH_CHECK(false,
-      "linalg_cond: something went wrong while checking the norm type");
+      "linalg.cond: something went wrong while checking the norm type");
   }
 }
 
 // Numerical or None norms
 Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {
-  TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
-    self.dim(), " dimensions.");
+  TORCH_CHECK(self.dim() >= 2, "linalg.cond: The input tensor must have at least 2 dimensions.");
 
   // The default case is using 2-norm
   Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
@@ -2759,19 +2750,18 @@
   }
 
   // ord == ±1 ord == ±inf
-  // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices
-  // the same check as squareCheckInputs(self) but with a slightly more informative error message
-  TORCH_CHECK(self.size(-1) == self.size(-2),
-              "linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices "
-              "but got ", self.size(-1), " by ", self.size(-2), " matrices");
-
+  if (ord.isFloatingPoint()) { // ord == ±inf
+    squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<double>()) + ")").c_str());
+  } else { // ord == ±1
+    squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<int64_t>()) + ")").c_str());
+  }
   return _linalg_cond_helper(self, ord_variant);
 }
 
 Tensor& linalg_cond_out(const Tensor& self, const optional<Scalar>& opt_ord, Tensor& result) {
-  checkSameDevice("linalg_cond", result, self);
+  checkSameDevice("linalg.cond", result, self);
   ScalarType real_dtype = toValueType(self.scalar_type());
-  checkLinalgCompatibleDtype("linalg_cond", result.scalar_type(), real_dtype);
+  checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
 
   Tensor result_tmp = at::linalg_cond(self, opt_ord);
   at::native::resize_output(result, result_tmp.sizes());
@@ -2781,13 +2771,7 @@
 
 // Frobenius or nuclear norms
 Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
-  // the same checks as squareCheckInputs(self) but with a slightly more informative error message
-  TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
-    self.dim(), " dimensions.");
-  TORCH_CHECK(self.size(-1) == self.size(-2),
-              "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices "
-              "but got ", self.size(-1), " by ", self.size(-2), " matrices");
-
+  squareCheckInputs(self, ("linalg.cond(ord=" + std::string(ord) + ")").c_str());
   c10::variant<Scalar, c10::string_view> ord_variant = ord;
   _linalg_cond_check_ord(ord_variant);
 
@@ -2809,9 +2793,9 @@
 
 // TODO: implement _out variant avoiding copy and using already allocated storage directly
 Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) {
-  checkSameDevice("linalg_cond", result, self);
+  checkSameDevice("linalg.cond", result, self);
   ScalarType real_dtype = toValueType(self.scalar_type());
-  checkLinalgCompatibleDtype("linalg_cond", result.scalar_type(), real_dtype);
+  checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
 
   Tensor result_tmp = at::linalg_cond(self, ord);
   at::native::resize_output(result, result_tmp.sizes());
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index c495fc8..e5a7b92 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -225,13 +225,19 @@
 }
 
 // Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
-static inline void squareCheckInputs(const Tensor& self) {
-  TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions. ");
+static inline void squareCheckInputs(const Tensor& self, const char* const f_name) {
+  TORCH_CHECK(self.dim() >= 2, f_name, ": The input tensor must have at least 2 dimensions.");
   TORCH_CHECK(self.size(-1) == self.size(-2),
-              "A must be batches of square matrices, "
+              f_name,
+              ": A must be batches of square matrices, "
               "but they are ", self.size(-1), " by ", self.size(-2), " matrices");
 }
 
+static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name) {
+  TORCH_CHECK((at::isFloatingType(t.scalar_type()) || at::isComplexType(t.scalar_type())),
+              f_name, ": Expected a floating point or complex tensor as input. Got ", toString(t.scalar_type()));
+}
+
 /*
  * Given a info int, obtained after a single operation, this function check if the computation
  * has been successful (info = 0) or not, and report in case of the latter.
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 826b2cf..44b1822 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -1661,14 +1661,14 @@
         # cond expects the input to be at least 2-dimensional
         a = torch.ones(3, dtype=dtype, device=device)
         for p in norm_types:
-            with self.assertRaisesRegex(RuntimeError, r'supports matrices or batches of matrices'):
+            with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
                 torch.linalg.cond(a, p)
 
         # for some norm types cond expects the input to be square
         a = torch.ones(3, 2, dtype=dtype, device=device)
         norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
         for p in norm_types:
-            with self.assertRaisesRegex(RuntimeError, r'supports square matrices or batches of square matrices'):
+            with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
                 torch.linalg.cond(a, p)
 
         # if non-empty out tensor with wrong shape is passed a warning is given
@@ -1714,7 +1714,7 @@
         # check invalid norm type
         a = torch.ones(3, 3, dtype=dtype, device=device)
         for p in ['wrong_norm', 5]:
-            with self.assertRaisesRegex(RuntimeError, f"linalg_cond got an invalid norm type: {p}"):
+            with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
                 torch.linalg.cond(a, p)
 
     # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments