Updated linalg.lstsq with NumPy-compatible kwarg rcond (#54723)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54723

Renamed "cond" -> "rcond" to be NumPy compatible. The default value for
rcond was changed to match non-legacy NumPy behavior.
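
As a rough illustration (not part of this patch; the example tensors, shapes,
and the CPU-only 'gelsd' driver choice are assumptions), rcond=None now uses
eps * max(m, n) as the cutoff, matching numpy.linalg.lstsq's non-legacy default:

    import torch

    A = torch.randn(5, 3, dtype=torch.float64)
    B = torch.randn(5, 2, dtype=torch.float64)

    # With rcond=None the effective-rank cutoff is eps * max(m, n),
    # matching numpy.linalg.lstsq's non-legacy behavior.
    res = torch.linalg.lstsq(A, B, rcond=None, driver='gelsd')

    # Illustrative equivalent with the cutoff spelled out explicitly.
    eps = torch.finfo(A.dtype).eps
    rcond = eps * max(A.shape[-2], A.shape[-1])
    res2 = torch.linalg.lstsq(A, B, rcond=rcond, driver='gelsd')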

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D27993741

Pulled By: mruberry

fbshipit-source-id: a4baf25aca6a8272f1af2f963600866bfda56fb3
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 81d391c..2f490f8 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -2881,7 +2881,10 @@
     }
     return *this;
   }
-  self_type& set_rcond(double cond) { this->rcond = static_cast<value_t>(cond); return *this; }
+  self_type& set_rcond(double rcond) {
+    this->rcond = static_cast<value_t>(rcond);
+    return *this;
+  }
   self_type& set_rank(Tensor& rank) {
     // only `?gels` is not rank-revealing
     if (LapackLstsqDriverType::Gels != driver_type) {
@@ -3000,7 +3003,7 @@
 #endif
 
 Tensor& _lstsq_helper_cpu(
-    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+    Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
 #ifndef USE_LAPACK
   TORCH_CHECK(false, "torch.linalg.lstsq: LAPACK library not found in compilation");
 #else
@@ -3039,7 +3042,7 @@
       .set_b(b)
       .set_ldb(std::max<int64_t>(1, std::max(m, n)))
       .set_jpvt()
-      .set_rcond(cond)
+      .set_rcond(rcond)
       .set_rank(rank)
       .set_s(singular_values)
       .set_infos(infos)
@@ -3331,10 +3334,9 @@
   std::string driver_name = get_default_lstsq_driver(driver, input);
 
   // set default rcond value
-  // TODO: Change this to match non-legacy NumPy behaviour
-  double rcond_value = rcond.has_value() && (rcond.value() > 0)
+  double rcond_value = rcond.has_value()
     ? rcond.value()
-    : _get_epsilon(c10::toValueType(input.scalar_type()));
+    : _get_epsilon(c10::toValueType(input.scalar_type())) * std::max<int64_t>(input.size(-2), input.size(-1));
 
   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
 
diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
index b2fb23f..75d11cd 100644
--- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2735,7 +2735,7 @@
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor& _lstsq_helper_cuda(
-  Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double cond, std::string driver_name) {
+  Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, const Tensor& a, double rcond, std::string driver_name) {
 #ifndef USE_MAGMA
 TORCH_CHECK(false, "torch.linalg.lstsq: MAGMA library not found in "
     "compilation. Please rebuild with MAGMA.");
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 606109f..0771e3c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8658,19 +8658,19 @@
 - func: det(Tensor self) -> Tensor
   variants: function, method
 
-- func: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
   python_module: linalg
   variants: function
   dispatch:
     CompositeExplicitAutograd: linalg_lstsq
 
-- func: linalg_lstsq.out(Tensor self, Tensor b, float? cond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
   python_module: linalg
   variants: function
   dispatch:
     CPU, CUDA: linalg_lstsq_out
 
-- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float cond, str driver_name) -> Tensor(a!)
+- func: _lstsq_helper_(Tensor(a!) self, Tensor(b!) rank, Tensor(c!) singular_values, Tensor(d!) infos, Tensor a, float rcond, str driver_name) -> Tensor(a!)
   variants: function
   dispatch:
     CPU: _lstsq_helper_cpu
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 0d2517d..b757480 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -41,6 +41,7 @@
     ("aten::irfft", datetime.date(2021, 1, 31)),
     ("aten::rfft", datetime.date(2021, 1, 31)),
     ("aten::_lstsq_helper", datetime.date(9999, 1, 1)),
+    ("aten::linalg_lstsq", datetime.date(2021, 5, 1)),
     ("aten::_svd_helper", datetime.date(2021, 1, 31)),
     ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
     ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 30eb4db..03bb899 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -125,7 +125,7 @@
         else:
             drivers = ('gels', None)
 
-        def check_correctness(a, b, sol):
+        def check_solution_correctness(a, b, sol):
             sol2 = a.pinverse() @ b
             self.assertEqual(sol, sol2, atol=1e-5, rtol=1e-5)
 
@@ -196,29 +196,22 @@
                     self.assertEqual(res.singular_values.shape, (0, ))
 
         def check_correctness_scipy(a, b, res, driver, cond):
-            if TEST_SCIPY and driver not in (None, 'gels'):
+            # SciPy provides 3 driver options: gelsd, gelss, gelsy
+            if TEST_SCIPY and driver in ('gelsd', 'gelss', 'gelsy'):
                 import scipy.linalg
 
                 def scipy_ref(a, b):
                     return scipy.linalg.lstsq(a, b, lapack_driver=driver, cond=cond)
                 check_correctness_ref(a, b, res, scipy_ref, driver=driver)
 
-        def check_correctness_numpy(a, b, res, driver, cond):
-            if driver in ('gelsd', 'gelss'):
-                import numpy.linalg
+        def check_correctness_numpy(a, b, res, driver, rcond):
+            # NumPy uses only gelsd routine
+            if driver == 'gelsd':
 
                 def numpy_ref(a, b):
-                    return numpy.linalg.lstsq(a, b, rcond=-1 if cond is None else cond)
+                    return np.linalg.lstsq(a, b, rcond=rcond)
                 check_correctness_ref(a, b, res, numpy_ref)
 
-        def check_ranks(a, ranks, cond=1e-7):
-            ranks2 = torch.matrix_rank(a, tol=cond)
-            self.assertEqual(ranks, ranks2)
-
-        def check_singular_values(a, sv):
-            sv2 = a.svd()[1]
-            self.assertEqual(sv, sv2)
-
         ms = [2 ** i for i in range(5)]
         m_ge_n_sizes = [(m, m // 2) for m in ms] + [(m, m) for m in ms]
         # cases m < n are only supported on CPU
@@ -229,32 +222,44 @@
         # that is why we use `cond=1.0`, the mean to cut roughly half of all
         # the singular values and compare whether torch.linalg.lstsq agrees with
         # SciPy and NumPy.
-        cond = (None, 1.0)
+        # if rcond is True then set value for it based on the used algorithm
+        # rcond == -1 or any other negative value forces LAPACK to use machine precision tolerance
+        rconds = (None, True, -1)
 
-        for batch, matrix_size, driver, cond in itertools.product(batches, matrix_sizes, drivers, cond):
+        for batch, matrix_size, driver, rcond in itertools.product(batches, matrix_sizes, drivers, rconds):
+            # keep the rcond value if it is None or -1, set the driver specific value if it is True
+            if rcond and rcond != -1:
+                if driver in ('gelss', 'gelsd'):
+                    # SVD based algorithm; set to zero roughly half of all the singular values
+                    rcond = 1.0
+                else:
+                    # driver == 'gelsy'
+                    # QR based algorithm; setting the value too high might lead to non-unique solutions and flaky tests
+                    rcond = 1e-4
+
+            # specifying rcond value has no effect for gels driver so no need to run the tests again
+            if driver == 'gels' and rcond is not None:
+                continue
+
             shape = batch + matrix_size
             a = random_well_conditioned_matrix(*shape, dtype=dtype, device=device)
             b = torch.rand(*shape, dtype=dtype, device=device)
 
-            cond = 1e-7
             m = a.size(-2)
             n = a.size(-1)
-            res = torch.linalg.lstsq(a, b, cond=cond, driver=driver)
-            sol = res.solution.narrow(-2, 0, n)
+            res = torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
+            sol = res.solution
 
-            check_correctness_scipy(a, b, res, driver, cond)
-            check_correctness_numpy(a, b, res, driver, cond)
+            # Only checks gelsd, gelss, gelsy drivers
+            check_correctness_scipy(a, b, res, driver, rcond)
 
-            check_correctness(a, b, sol)
-            if self.device_type == 'cpu' and driver != 'gels':
-                # rank-revealing drivers are only available for the CPU.
-                # `gels` is not rank-revealing and is only for full
-                # rank inputs.
-                check_ranks(a, res.rank, cond)
-            if self.device_type == 'cpu' and driver in ('gelsd', 'gelss'):
-                # SVD-based drivers are only available for the CPU.
-                # These are only `gelsd` and `gelss`.
-                check_singular_values(a, res.singular_values)
+            # Only checks gelsd driver
+            check_correctness_numpy(a, b, res, driver, rcond)
+
+            # gels driver is not checked by comparing to NumPy or SciPy implementation
+            # because NumPy and SciPy do not implement this driver
+            if driver == 'gels' and rcond is None:
+                check_solution_correctness(a, b, sol)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index d09f455..3e45f72 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -704,7 +704,7 @@
   self: not_implemented("lstsq")
   A: not_implemented("lstsq")
 
-- name: linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
   self: not_implemented("linalg_lstsq")
   b: not_implemented("linalg_lstsq")
   output_differentiability: [True, True]
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 2f3d53b..72c928b 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -648,7 +648,7 @@
 """)
 
 lstsq = _add_docstr(_linalg.linalg_lstsq, r"""
-torch.linalg.lstsq(A, B, cond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
+torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor)
 
 Computes a solution to the least squares problem  of a system of linear equations.
 
@@ -714,16 +714,16 @@
     computations separately.
 
 .. warning::
-    The default value of :attr:`cond` may change in the future.
+    The default value of :attr:`rcond` may change in a future PyTorch release.
     It is therefore recommended to use a fixed value to avoid potential
     breaking changes.
 
 Args:
     A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.
     B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions.
-    cond (float, optional): used to determine the effective rank of :attr:`A`.
-                            If :attr:`cond`\ `= None`, :attr:`cond` is set to the machine
-                            precision of the dtype of :attr:`A`. Default: `None`.
+    rcond (float, optional): used to determine the effective rank of :attr:`A`.
+                             If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine
+                             precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`.
 
 Keyword args:
     driver (str, optional): name of the LAPACK/MAGMA method to be used.