Updated linalg.lstsq with NumPy-compatible kwarg rcond (#54723)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54723
Renamed "cond" -> "rcond" to be NumPy compatible. The default value for
rcond was changed to match non-legacy NumPy behavior.
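A minimal usage sketch of the user-facing change (illustrative only, not part of
this PR or its test plan; the tensor shapes and the tolerance are arbitrary):

```python
import numpy as np
import torch

a = torch.randn(5, 3, dtype=torch.float64)
b = torch.randn(5, 2, dtype=torch.float64)

# New spelling of the keyword argument: rcond (previously cond).
sol = torch.linalg.lstsq(a, b, rcond=None).solution

# With rcond=None the cutoff is now eps(dtype) * max(m, n), which mirrors
# non-legacy NumPy, i.e. np.linalg.lstsq(..., rcond=None).
sol_np, *_ = np.linalg.lstsq(a.numpy(), b.numpy(), rcond=None)
assert torch.allclose(sol, torch.from_numpy(sol_np), atol=1e-6)
```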
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D27993741
Pulled By: mruberry
fbshipit-source-id: a4baf25aca6a8272f1af2f963600866bfda56fb3
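For context on the BatchLinearAlgebra.cpp hunk below, the new default for rcond
corresponds roughly to the following Python sketch (exposition only; the helper name
_default_rcond_sketch is made up, and the real implementation is the C++ change in
the diff):

```python
import torch

def _default_rcond_sketch(input: torch.Tensor, rcond=None) -> float:
    # Old behavior: a missing or non-positive rcond fell back to bare eps.
    # New behavior: only a missing rcond falls back, and the fallback is
    # eps * max(m, n), matching non-legacy NumPy.
    if rcond is not None:
        return rcond
    real_dtype = {torch.complex64: torch.float32,
                  torch.complex128: torch.float64}.get(input.dtype, input.dtype)
    eps = torch.finfo(real_dtype).eps
    return eps * max(input.size(-2), input.size(-1))
```

For a float32 input of shape (100, 50), for example, this gives roughly
1.19e-7 * 100 ≈ 1.2e-5 rather than the old bare eps of about 1.19e-7.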
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 81d391c..2f490f8 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2881,7 +2881,10 @@
}
return *this;
}
- self_type& set_rcond(double cond) { this->rcond = static_cast<value_t>(cond); return *this; }
+ self_type& set_rcond(double rcond) {
+ this->rcond = static_cast<value_t>(rcond);
+ return *this;
+ }
self_type& set_rank(Tensor& rank) {
// only `?gels` is not rank-revealing
if (LapackLstsqDriverType::Gels != driver_type) {
@@ -3000,7 +3003,7 @@
#endif
Tensor& _lstsq_helper_cpu(
- Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+ Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
#ifndef USE_LAPACK
TORCH_CHECK(false, "torch.linalg.lstsq: LAPACK library not found in compilation");
#else
@@ -3039,7 +3042,7 @@
.set_b(b)
.set_ldb(std::max<int64_t>(1, std::max(m, n)))
.set_jpvt()
- .set_rcond(cond)
+ .set_rcond(rcond)
.set_rank(rank)
.set_s(singular_values)
.set_infos(infos)
@@ -3331,10 +3334,9 @@
std::string driver_name = get_default_lstsq_driver(driver, input);
// set default rcond value
- // TODO: Change this to match non-legacy NumPy behaviour
- double rcond_value = rcond.has_value() && (rcond.value() > 0)
+ double rcond_value = rcond.has_value()
? rcond.value()
- : _get_epsilon(c10::toValueType(input.scalar_type()));
+ : _get_epsilon(c10::toValueType(input.scalar_type())) * std::max<int64_t>(input.size(-2), input.size(-1));
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index b2fb23f..75d11cd 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2735,7 +2735,7 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor& _lstsq_helper_cuda(
- Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+ Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
#ifndef USE_MAGMA
TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 606109f..0771e3c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8658,19 +8658,19 @@
- func: det(Tensor self) -> Tensor
variants: function, method
-- func: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
python_module: linalg
variants: function
dispatch:
CompositeExplicitAutograd: linalg_lstsq
-- func: linalg_lstsq.out(Tensor self, Tensor b, float? cond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
python_module: linalg
variants: function
dispatch:
CPU, CUDA: linalg_lstsq_out
-- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float cond, str driver_name) -> Tensor(a!)
+- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float rcond, str driver_name) -> Tensor(a!)
variants: function
dispatch:
CPU: _lstsq_helper_cpu
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 0d2517d..b757480 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -41,6 +41,7 @@
("aten::irfft", datetime.date(2021, 1, 31)),
("aten::rfft", datetime.date(2021, 1, 31)),
("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
+ ("aten::linalg_lstsq", datetime.date(2021, 5, 1)),
("aten::_svd_helper", datetime.date(2021, 1, 31)),
("aten::_syevd_helper", datetime.date(9999, 1, 1)),
("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 30eb4db..03bb899 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -125,7 +125,7 @@
else:
drivers = ('gels', None)
- def check_correctness(a, b, sol):
+ def check_solution_correctness(a, b, sol):
sol2 = a.pinverse() @ b
self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
@@ -196,29 +196,22 @@
self.assertEqual(res.singular_values.shape, (0, ))
def check_correctness_scipy(a, b, res, driver, cond):
- if TEST_SCIPY and driver not in (None, 'gels'):
+ # SciPy provides 3 driver options: gelsd, gelss, gelsy
+ if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
import scipy.linalg
def scipy_ref(a, b):
return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
check_correctness_ref(a, b, res, scipy_ref, driver=driver)
- def check_correctness_numpy(a, b, res, driver, cond):
- if driver in ('gelsd', 'gelss'):
- import numpy.linalg
+ def check_correctness_numpy(a, b, res, driver, rcond):
+ # NumPy uses only gelsd routine
+ if driver == 'gelsd':
def numpy_ref(a, b):
- return numpy.linalg.lstsq(a, b, rcond=-1 if cond is None else cond)
+ return np.linalg.lstsq(a, b, rcond=rcond)
check_correctness_ref(a, b, res, numpy_ref)
- def check_ranks(a, ranks, cond=1e-7):
- ranks2 = torch.matrix_rank(a, tol=cond)
- self.assertEqual(ranks, ranks2)
-
- def check_singular_values(a, sv):
- sv2 = a.svd()[1]
- self.assertEqual(sv, sv2)
-
ms = [2 ** i for i in range(5)]
m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
# cases m < n are only supported on CPU
@@ -229,32 +222,44 @@
# that is why we use `cond=1.0`, the mean to cut roughly half of all
# the singular values and compare whether torch.linalg.lstsq agrees with
# SciPy and NumPy.
- cond = (None, 1.0)
+        # if rcond is True, set a driver-specific value for it below
+        # rcond == -1 (or any other negative value) forces LAPACK to use machine-precision tolerance
+ rconds = (None, True, -1)
- for batch, matrix_size, driver, cond in itertools.product(batches, matrix_sizes, drivers, cond):
+ for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
+ # keep the rcond value if it is None or -1, set the driver specific value if it is True
+ if rcond and rcond != -1:
+ if driver in ('gelss', 'gelsd'):
+ # SVD based algorithm; set to zero roughly half of all the singular values
+ rcond = 1.0
+ else:
+ # driver == 'gelsy'
+ # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
+ rcond = 1e-4
+
+        # specifying an rcond value has no effect for the gels driver, so there is no need to run those tests again
+ if driver == 'gels' and rcond is not None:
+ continue
+
shape = batch + matrix_size
a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
b = torch.rand(*shape, dtype=dtype, device=device)
- cond = 1e-7
m = a.size(-2)
n = a.size(-1)
- res = torch.linalg.lstsq(a, b, cond=cond, driver=driver)
- sol = res.solution.narrow(-2, 0, n)
+ res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
+ sol = res.solution
- check_correctness_scipy(a, b, res, driver, cond)
- check_correctness_numpy(a, b, res, driver, cond)
+ # Only checks gelsd, gelss, gelsy drivers
+ check_correctness_scipy(a, b, res, driver, rcond)
- check_correctness(a, b, sol)
- if self.device_type == 'cpu' and driver != 'gels':
- # rank-revealing drivers are only available for the CPU.
- # `gels` is not rank-revealing and is only for full
- # rank inputs.
- check_ranks(a, res.rank, cond)
- if self.device_type == 'cpu' and driver in ('gelsd', 'gelss'):
- # SVD-based drivers are only available for the CPU.
- # These are only `gelsd` and `gelss`.
- check_singular_values(a, res.singular_values)
+ # Only checks gelsd driver
+ check_correctness_numpy(a, b, res, driver, rcond)
+
+        # the gels driver is not checked against NumPy or SciPy,
+        # because neither library implements this driver
+ if driver == 'gels' and rcond is None:
+ check_solution_correctness(a, b, sol)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index d09f455..3e45f72 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -704,7 +704,7 @@
self: not_implemented("lstsq")
A: not_implemented("lstsq")
-- name: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
self: not_implemented("linalg_lstsq")
b: not_implemented("linalg_lstsq")
output_differentiability: [True, True]
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 2f3d53b..72c928b 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -648,7 +648,7 @@
""")
lstsq = _add_docstr(_linalg.linalg_lstsq, r"""
-torch.linalg.lstsq(A, B, cond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
+torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
Computes a solution to the least squares problem of a system of linear equations.
@@ -714,16 +714,16 @@
computations separately.
.. warning::
- The default value of :attr:`cond` may change in the future.
+ The default value of :attr:`rcond` may change in a future PyTorch release.
It is therefore recommended to use a fixed value to avoid potential
breaking changes.
Args:
A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions.
- cond (float, optional): used to determine the effective rank of :attr:`A`.
- If :attr:`cond`\ `= None`, :attr:`cond` is set to the machine
- precision of the dtype of :attr:`A`. Default: `None`.
+ rcond (float, optional): used to determine the effective rank of :attr:`A`.
+ If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine
+ precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`.
Keyword args:
driver (str, optional): name of the LAPACK/MAGMA method to be used.