Better and more consistent error messages in torch.linalg (#62734)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62734
Following https://github.com/pytorch/pytorch/pull/62715#discussion_r682610788
- squareCheckInputs takes a string with the name of the function
- We reuse more functions when checking the inputs
The state of the errors in torch.linalg is far from great though. We
leave a more comprehensive clean-up for the future.
cc jianyuh nikitaved pearu mruberry walterddr IvanYashchuk xwang233 Lezcano
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D31823230
Pulled By: mruberry
fbshipit-source-id: eccd531f10d590eb5f9d04a957b7cdcb31c72ea4
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index b68aa1c..9e08bac 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -801,15 +801,14 @@
// Solves a system of linear equations matmul(input, x) = other in-place
// LAPACK/MAGMA error codes are saved in 'infos' tensor, they are not checked here
static Tensor& linalg_solve_out_info(Tensor& result, Tensor& infos, const Tensor& input, const Tensor& other) {
- checkSameDevice("linalg_solve", result, input);
- checkSameDevice("linalg_solve", other, input, "other");
- checkLinalgCompatibleDtype("linalg_solve", result, input);
+ checkSameDevice("linalg.solve", result, input);
+ checkSameDevice("linalg.solve", other, input, "other");
+ checkLinalgCompatibleDtype("linalg.solve", result, input);
TORCH_CHECK(input.scalar_type() == other.scalar_type(),
"input dtype ", input.scalar_type(), " does not match other dtype ", other.scalar_type());
- TORCH_CHECK(input.dim() >= 2,
- "input should have at least 2 dimensions, but has ", input.dim(), " dimensions instead");
+ squareCheckInputs(input, "linalg.solve");
TORCH_CHECK(other.dim() >= 1,
"other should have at least 1 dimension, but has ", other.dim(), " dimensions instead");
@@ -856,7 +855,7 @@
// _linalg_broadcast_batch_dims also includes linearSolveCheckInputs
// it checks for squareness of 'input' and 'shape' compatibility of 'other' and 'input'
Tensor other_broadcasted, input_broadcasted;
- std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg_solve");
+ std::tie(other_broadcasted, input_broadcasted) = _linalg_broadcast_batch_dims(other_, input, "linalg.solve");
auto squeezed_other_broadcasted = at::squeeze(other_broadcasted, -1);
auto squeezed_result_shape = squeezed_other_broadcasted.sizes();
@@ -928,9 +927,9 @@
// batchCheckErrors(Tensor, char*) calls 'infos = infos.to(kCPU)'
bool vector_case = linalg_solve_is_vector_rhs(input, other);
if (vector_case ? result.dim() > 1 : result.dim() > 2) {
- batchCheckErrors(infos, "linalg_solve");
+ batchCheckErrors(infos, "linalg.solve");
} else {
- singleCheckErrors(infos.item().toInt(), "linalg_solve");
+ singleCheckErrors(infos.item().toInt(), "linalg.solve");
}
return result;
@@ -1015,7 +1014,7 @@
if (self.numel() == 0) {
return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
- squareCheckInputs(self);
+ squareCheckInputs(self, "inverse");
return at::_inverse_helper(self);
}
@@ -1042,9 +1041,9 @@
// Computes the inverse matrix of 'input', it is is saved to 'result' in-place
// LAPACK/MAGMA/cuSOLVER error codes are saved in 'infos' tensors, they are not checked here
static Tensor& linalg_inv_out_info(Tensor& result, Tensor& infos_lu, Tensor& infos_getri, const Tensor& input) {
- squareCheckInputs(input);
- checkSameDevice("linalg_inv", result, input);
- checkLinalgCompatibleDtype("linalg_inv", result, input);
+ squareCheckInputs(input, "linalg.inv");
+ checkSameDevice("linalg.inv", result, input);
+ checkLinalgCompatibleDtype("linalg.inv", result, input);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos_lu.scalar_type() == kInt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos_getri.scalar_type() == kInt);
@@ -1137,11 +1136,11 @@
// Now check LAPACK/MAGMA/cuSOLVER error codes
if (result.dim() > 2) {
- batchCheckErrors(infos_lu, "linalg_inv_lu");
- batchCheckErrors(infos_getri, "linalg_inv_getri");
+ batchCheckErrors(infos_lu, "linalg.inv");
+ batchCheckErrors(infos_getri, "linalg.inv");
} else {
- singleCheckErrors(infos_lu.item().toInt(), "linalg_inv_lu");
- singleCheckErrors(infos_getri.item().toInt(), "linalg_inv_getri");
+ singleCheckErrors(infos_lu.item().toInt(), "linalg.inv");
+ singleCheckErrors(infos_getri.item().toInt(), "linalg.inv");
}
return result;
@@ -1164,7 +1163,7 @@
}
std::tuple<Tensor&, Tensor&> linalg_inv_ex_out(const Tensor& input, bool check_errors, Tensor& inverse, Tensor& info) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.inv_ex");
ScalarType info_output_type = ScalarType::Int;
TORCH_CHECK(
info.scalar_type() == info_output_type,
@@ -1191,7 +1190,7 @@
}
std::tuple<Tensor, Tensor> linalg_inv_ex(const Tensor& input, bool check_errors) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.inv_ex");
Tensor inverse = at::empty(input.sizes(), input.options(), MemoryFormat::Contiguous);
inverse.transpose_(-2, -1); // make `inverse` tensor with batched column major format
auto info_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2); // input.shape[:-2]
@@ -1287,7 +1286,7 @@
if (self.numel() == 0) {
return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
- squareCheckInputs(self);
+ squareCheckInputs(self, "cholesky");
auto raw_cholesky_output = cloneBatchedColumnMajor(self);
auto info_shape = IntArrayRef(
@@ -1375,7 +1374,7 @@
}
std::tuple<Tensor&, Tensor&> linalg_cholesky_ex_out(const Tensor& input, bool upper, bool check_errors, Tensor& L, Tensor& info) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.cholesky_ex");
checkSameDevice("torch.linalg.cholesky_ex", L, input, "L");
checkLinalgCompatibleDtype("torch.linalg.cholesky_ex", L, input, "L");
checkSameDevice("torch.linalg.cholesky_ex", info, input, "info");
@@ -1510,7 +1509,7 @@
}
Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "cholesky_inverse");
checkSameDevice("cholesky_inverse", result, input);
checkLinalgCompatibleDtype("cholesky_inverse", result, input);
@@ -2032,14 +2031,7 @@
tau.scalar_type(),
" does not match input dtype ",
input.scalar_type());
- TORCH_CHECK(
- input.device() == tau.device(),
- "torch.linalg.householder_product: Expected input and tau to be on the same device, but found input on ",
- input.device(),
- " and tau on ",
- tau.device(),
- " instead.");
-
+ checkSameDevice("torch.linalg.householder_product", tau, input, "tau");
checkSameDevice("torch.linalg.householder_product", result, input);
checkLinalgCompatibleDtype("torch.linalg.householder_product", result, input);
@@ -2312,7 +2304,7 @@
}
std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& input, c10::string_view uplo) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.eigh");
checkUplo(uplo);
ScalarType real_dtype = toValueType(input.scalar_type());
Tensor values = at::empty({0}, input.options().dtype(real_dtype));
@@ -2369,7 +2361,7 @@
ScalarType real_dtype = toValueType(input.scalar_type());
checkLinalgCompatibleDtype("torch.linalg.eigvalsh", result.scalar_type(), real_dtype);
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.eigvalsh");
checkUplo(uplo);
auto expected_result_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
@@ -2498,7 +2490,7 @@
"should be replaced with\n",
"L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L')"
);
- squareCheckInputs(self);
+ squareCheckInputs(self, "linalg.symeig");
return at::_symeig_helper(self, eigenvectors, upper);
}
@@ -2708,7 +2700,7 @@
}
std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.eig");
// unlike NumPy for real-valued inputs the output is always complex-valued
checkLinalgCompatibleDtype("torch.linalg.eig", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
@@ -2805,7 +2797,7 @@
}
Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
- squareCheckInputs(input);
+ squareCheckInputs(input, "linalg.eigvals");
// unlike NumPy for real-valued inputs the output is always complex-valued
checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
@@ -3094,7 +3086,6 @@
Tensor Vh = V.mH();
return std::make_tuple(U, S, Vh);
-
}
static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst) {
@@ -3107,11 +3098,11 @@
checkSameDevice("svd", U, self, "U");
checkSameDevice("svd", S, self, "S");
checkSameDevice("svd", Vh, self, "Vh");
- checkLinalgCompatibleDtype("linalg_svd", U, self, "U");
- checkLinalgCompatibleDtype("linalg_svd", Vh, self, "Vh");
+ checkLinalgCompatibleDtype("linalg.svd", U, self, "U");
+ checkLinalgCompatibleDtype("linalg.svd", Vh, self, "Vh");
// singular values are always real-valued here
ScalarType real_dtype = toValueType(self.scalar_type());
- checkLinalgCompatibleDtype("linalg_svd", S.scalar_type(), real_dtype, "S");
+ checkLinalgCompatibleDtype("linalg.svd", S.scalar_type(), real_dtype, "S");
Tensor U_tmp, S_tmp, Vh_tmp;
std::tie(U_tmp, S_tmp, Vh_tmp) = at::native::linalg_svd(self, full_matrices);
svd_resize_and_copy("U", U_tmp, U);
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index ce7d176..03e14b7 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -24,6 +24,7 @@
#include <functional>
#include <limits>
#include <numeric>
+#include <string>
namespace at {
namespace meta {
@@ -153,33 +154,28 @@
return at::linalg_det(self);
}
-Tensor& linalg_det_out(const Tensor& self, Tensor& out) {
- checkSameDevice("torch.linalg.det", out, self, "out");
- checkLinalgCompatibleDtype("torch.linalg.det", out, self, "out");
- squareCheckInputs(self);
- TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
- "Expected a floating point or complex tensor as input");
-
- IntArrayRef out_sizes(self.sizes().data(), self.dim() - 2);
- at::native::resize_output(out, out_sizes);
-
- auto det = std::get<0>(at::native::_det_lu_based_helper(self));
- out.copy_(det);
- return out;
-}
-
Tensor linalg_det(const Tensor& self) {
- squareCheckInputs(self);
- TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
- "Expected a floating point or complex tensor as input");
+ squareCheckInputs(self, "linalg.det");
+ checkFloatingOrComplex(self, "linalg.det");
return std::get<0>(at::_det_lu_based_helper(self));
}
+Tensor& linalg_det_out(const Tensor& self, Tensor& out) {
+ checkSameDevice("torch.linalg.det", out, self, "out");
+ checkLinalgCompatibleDtype("torch.linalg.det", out, self, "out");
+
+ IntArrayRef out_sizes(self.sizes().data(), self.dim() - 2);
+ at::native::resize_output(out, out_sizes);
+
+ auto det = at::native::linalg_det(self);
+ out.copy_(det);
+ return out;
+}
+
Tensor logdet(const Tensor& self) {
- squareCheckInputs(self);
- TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())),
- "Expected a floating point tensor as input");
+ squareCheckInputs(self, "logdet");
+ checkFloatingOrComplex(self, "logdet");
c10::ExclusivelyOwned<Tensor> det_P, diag_U;
std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
@@ -200,10 +196,10 @@
}
std::tuple<Tensor, Tensor> linalg_slogdet(const Tensor& self) {
- squareCheckInputs(self);
+ squareCheckInputs(self, "linalg.slogdet");
ScalarType t = self.scalar_type();
TORCH_CHECK(t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble,
- "linalg_slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t);
+ "linalg.slogdet: expected a tensor of float, double, cfloat or cdouble types but got ", t);
c10::ExclusivelyOwned<Tensor> det_P, diag_U;
std::tie(det_P, diag_U) = _lu_det_P_diag_U(self);
@@ -218,12 +214,12 @@
// TODO: implement _out variant avoiding copy and using already allocated storage directly
std::tuple<Tensor&, Tensor&> linalg_slogdet_out(const Tensor& input, Tensor& sign, Tensor& logabsdet) {
- checkSameDevice("linalg_slogdet", sign, input, "sign");
- checkSameDevice("linalg_slogdet", logabsdet, input, "logabsdet");
- checkLinalgCompatibleDtype("linalg_slogdet", sign, input, "sign");
+ checkSameDevice("linalg.slogdet", sign, input, "sign");
+ checkSameDevice("linalg.slogdet", logabsdet, input, "logabsdet");
+ checkLinalgCompatibleDtype("linalg.slogdet", sign, input, "sign");
ScalarType real_dtype = toValueType(input.scalar_type());
// logabsdet is always real-valued here
- checkLinalgCompatibleDtype("linalg_slogdet", logabsdet.scalar_type(), real_dtype, "logabsdet");
+ checkLinalgCompatibleDtype("linalg.slogdet", logabsdet.scalar_type(), real_dtype, "logabsdet");
Tensor sign_tmp, logabsdet_tmp;
std::tie(sign_tmp, logabsdet_tmp) = at::linalg_slogdet(input);
@@ -301,7 +297,7 @@
ScalarType t = input.scalar_type();
TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
&& input.dim() >= 2,
- "linalg_pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
+ "linalg.pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
"of float, double, cfloat or cdouble types");
Tensor atol, rtol;
@@ -366,8 +362,8 @@
const optional<Tensor>& rtol,
bool hermitian,
Tensor& result) {
- checkSameDevice("linalg_pinv", result, input);
- checkLinalgCompatibleDtype("linalg_pinv", result, input);
+ checkSameDevice("linalg.pinv", result, input);
+ checkLinalgCompatibleDtype("linalg.pinv", result, input);
Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
@@ -380,8 +376,8 @@
optional<double> rtol,
bool hermitian,
Tensor& result) {
- checkSameDevice("linalg_pinv", result, input);
- checkLinalgCompatibleDtype("linalg_pinv", result, input);
+ checkSameDevice("linalg.pinv", result, input);
+ checkLinalgCompatibleDtype("linalg.pinv", result, input);
Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
@@ -389,8 +385,8 @@
}
Tensor& linalg_pinv_out(const Tensor& input, const Tensor& rcond, bool hermitian, Tensor& result) {
- checkSameDevice("linalg_pinv", result, input);
- checkLinalgCompatibleDtype("linalg_pinv", result, input);
+ checkSameDevice("linalg.pinv", result, input);
+ checkLinalgCompatibleDtype("linalg.pinv", result, input);
Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
at::native::resize_output(result, result_tmp.sizes());
@@ -427,7 +423,7 @@
c10::optional<Tensor> _out) {
auto out = _out.value_or(Tensor());
- squareCheckInputs(self);
+ squareCheckInputs(self, "linalg.matrix_power");
if (_out.has_value()) {
checkSameDevice("matrix_power", out, self);
checkLinalgCompatibleDtype("matrix_power", out, self);
@@ -2251,9 +2247,8 @@
// Mathematics 2019, 7, 1174.
//
Tensor linalg_matrix_exp(const Tensor& a) {
- squareCheckInputs(a);
- TORCH_CHECK((at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())),
- "Expected a floating point or complex tensor as input. Got: ", a.scalar_type());
+ squareCheckInputs(a, "linalg.matrix_exp");
+ checkFloatingOrComplex(a, "matrix_exp");
NoTF32Guard disable_tf32;
@@ -2524,11 +2519,8 @@
"but got ", opt_dtype.value());
}
+ checkFloatingOrComplex(self, "linalg.vector_norm");
ScalarType in_dtype = opt_dtype.value_or(self.scalar_type());
- TORCH_CHECK(
- at::isFloatingType(in_dtype) || at::isComplexType(in_dtype),
- "linalg.vector_norm only supports floating point and complex dtypes, but got: ",
- toString(in_dtype));
IntArrayRef dim = opt_dim.value_or(IntArrayRef{});
@@ -2714,21 +2706,20 @@
Scalar* ord = c10::get_if<Scalar>(&ord_variant);
double abs_ord = std::abs(ord->toDouble());
TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
- "linalg_cond got an invalid norm type: ", ord->toDouble());
+ "linalg.cond got an invalid norm type: ", ord->toDouble());
} else if (ord_variant.index() == 1) {
c10::string_view* ord = c10::get_if<c10::string_view>(&ord_variant);
TORCH_CHECK(*ord == "fro" || *ord == "nuc",
- "linalg_cond got an invalid norm type: ", *ord);
+ "linalg.cond got an invalid norm type: ", *ord);
} else {
TORCH_CHECK(false,
- "linalg_cond: something went wrong while checking the norm type");
+ "linalg.cond: something went wrong while checking the norm type");
}
}
// Numerical or None norms
Tensor linalg_cond(const Tensor& self, const optional<Scalar>& opt_ord) {
- TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
- self.dim(), " dimensions.");
+ TORCH_CHECK(self.dim() >= 2, "linalg.cond: The input tensor must have at least 2 dimensions.");
// The default case is using 2-norm
Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
@@ -2759,19 +2750,18 @@
}
// ord == ±1 ord == ±inf
- // since at::inverse is used in the implementation, self has to be a tensor consisting of square matrices
- // the same check as squareCheckInputs(self) but with a slightly more informative error message
- TORCH_CHECK(self.size(-1) == self.size(-2),
- "linalg_cond with ±1 or ±inf norm types only supports square matrices or batches of square matrices "
- "but got ", self.size(-1), " by ", self.size(-2), " matrices");
-
+ if (ord.isFloatingPoint()) { // ord == ±1
+ squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<double>()) + ")").c_str());
+ } else { // ord == ±inf
+ squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<int64_t>()) + ")").c_str());
+ }
return _linalg_cond_helper(self, ord_variant);
}
Tensor& linalg_cond_out(const Tensor& self, const optional<Scalar>& opt_ord, Tensor& result) {
- checkSameDevice("linalg_cond", result, self);
+ checkSameDevice("linalg.cond", result, self);
ScalarType real_dtype = toValueType(self.scalar_type());
- checkLinalgCompatibleDtype("linalg_cond", result.scalar_type(), real_dtype);
+ checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
Tensor result_tmp = at::linalg_cond(self, opt_ord);
at::native::resize_output(result, result_tmp.sizes());
@@ -2781,13 +2771,7 @@
// Frobenius or nuclear norms
Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
- // the same checks as squareCheckInputs(self) but with a slightly more informative error message
- TORCH_CHECK(self.dim() >= 2, "linalg_cond only supports matrices or batches of matrices, but got a tensor with ",
- self.dim(), " dimensions.");
- TORCH_CHECK(self.size(-1) == self.size(-2),
- "linalg_cond with frobenius or nuclear norm types only supports square matrices or batches of square matrices "
- "but got ", self.size(-1), " by ", self.size(-2), " matrices");
-
+ squareCheckInputs(self, ("linalg.cond(ord=" + std::string(ord) + ")").c_str());
c10::variant<Scalar, c10::string_view> ord_variant = ord;
_linalg_cond_check_ord(ord_variant);
@@ -2809,9 +2793,9 @@
// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) {
- checkSameDevice("linalg_cond", result, self);
+ checkSameDevice("linalg.cond", result, self);
ScalarType real_dtype = toValueType(self.scalar_type());
- checkLinalgCompatibleDtype("linalg_cond", result.scalar_type(), real_dtype);
+ checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
Tensor result_tmp = at::linalg_cond(self, ord);
at::native::resize_output(result, result_tmp.sizes());
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index c495fc8..e5a7b92 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -225,13 +225,19 @@
}
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
-static inline void squareCheckInputs(const Tensor& self) {
- TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions. ");
+static inline void squareCheckInputs(const Tensor& self, const char* const f_name) {
+ TORCH_CHECK(self.dim() >= 2, f_name, ": The input tensor must have at least 2 dimensions.");
TORCH_CHECK(self.size(-1) == self.size(-2),
- "A must be batches of square matrices, "
+ f_name,
+ ": A must be batches of square matrices, "
"but they are ", self.size(-1), " by ", self.size(-2), " matrices");
}
+static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name) {
+ TORCH_CHECK((at::isFloatingType(t.scalar_type()) || at::isComplexType(t.scalar_type())),
+ f_name, ": Expected a floating point or complex tensor as input. Got ", toString(t.scalar_type()));
+}
+
/*
* Given a info int, obtained after a single operation, this function check if the computation
* has been successful (info = 0) or not, and report in case of the latter.
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 826b2cf..44b1822 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -1661,14 +1661,14 @@
# cond expects the input to be at least 2-dimensional
a = torch.ones(3, dtype=dtype, device=device)
for p in norm_types:
- with self.assertRaisesRegex(RuntimeError, r'supports matrices or batches of matrices'):
+ with self.assertRaisesRegex(RuntimeError, r'at least 2 dimensions'):
torch.linalg.cond(a, p)
# for some norm types cond expects the input to be square
a = torch.ones(3, 2, dtype=dtype, device=device)
norm_types = [1, -1, inf, -inf, 'fro', 'nuc']
for p in norm_types:
- with self.assertRaisesRegex(RuntimeError, r'supports square matrices or batches of square matrices'):
+ with self.assertRaisesRegex(RuntimeError, r'must be batches of square matrices'):
torch.linalg.cond(a, p)
# if non-empty out tensor with wrong shape is passed a warning is given
@@ -1714,7 +1714,7 @@
# check invalid norm type
a = torch.ones(3, 3, dtype=dtype, device=device)
for p in ['wrong_norm', 5]:
- with self.assertRaisesRegex(RuntimeError, f"linalg_cond got an invalid norm type: {p}"):
+ with self.assertRaisesRegex(RuntimeError, f"linalg.cond got an invalid norm type: {p}"):
torch.linalg.cond(a, p)
# This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments