Improve torch.linalg.qr (#50046)

Summary:
This is a follow-up to PR https://github.com/pytorch/pytorch/issues/47764, fixing the remaining details: clearer error messages for too-few-dimensional inputs and unrecognized modes, clean autograd errors for mode='r' and for mode='complete' when nrows > ncols, expanded tests (batched mode='r', out= identity checks, error cases), and documentation fixes for torch.qr and torch.linalg.qr.
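
A minimal sketch of the user-facing behavior this PR pins down (an
illustrative Python session, not part of the patch; tensor values are
random, so printed outputs will vary):

    import torch

    t = torch.randn(5, 7)

    # mode='r' computes only R; Q comes back as an empty tensor.
    q, r = torch.linalg.qr(t, mode='r')
    assert q.shape == (0,)

    # Unrecognized modes now raise a descriptive error.
    try:
        torch.linalg.qr(t, mode='hello')
    except RuntimeError as e:
        # qr received unrecognized mode 'hello' but expected one of
        # 'reduced' (default), 'r', or 'complete'
        print(e)

    # Inputs with fewer than 2 dimensions are rejected up front.
    try:
        torch.linalg.qr(torch.randn(5))
    except RuntimeError as e:
        # qr input should have at least 2 dimensions, but has 1
        # dimensions instead
        print(e)

    # Backward is cleanly rejected for mode='r' (and for mode='complete'
    # when nrows > ncols).
    inp = torch.randn(5, 7, requires_grad=True)
    q, r = torch.linalg.qr(inp, mode='r')
    try:
        r.sum().backward()
    except RuntimeError as e:
        # The derivative of qr is not implemented when mode='r'.
        # Please use torch.linalg.qr(..., mode='reduced')
        print(e)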

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50046

Reviewed By: zou3519

Differential Revision: D25825557

Pulled By: mruberry

fbshipit-source-id: b8e335e02265e73484a99b0189e4cc042828e0a9
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index f0b36d0..146275e 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1011,13 +1011,13 @@
 
 std::tuple<Tensor,Tensor> linalg_qr(const Tensor& self, std::string mode) {
   TORCH_CHECK(self.dim() >= 2,
-              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   return at::_linalg_qr_helper(self, mode);
 }
 
 std::tuple<Tensor&,Tensor&> linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) {
   TORCH_CHECK(self.dim() >= 2,
-              "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+              "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
   Tensor Q_tmp, R_tmp;
   std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode);
   at::native::resize_output(Q, Q_tmp.sizes());
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index e97637d..4322c4c 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -206,7 +206,8 @@
     compute_q = false;
     reduced = true; // this is actually irrelevant in this mode
   } else {
-    TORCH_CHECK(false, "Unrecognized mode '", mode, "'");
+    TORCH_CHECK(false, "qr received unrecognized mode '", mode,
+                "' but expected one of 'reduced' (default), 'r', or 'complete'");
   }
   return std::make_tuple(compute_q, reduced);
 }
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 9f59252..fb4394b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4941,17 +4941,6 @@
                                          return_counts=return_counts)
                     assert_only_first_requires_grad(res)
 
-    def test_linalg_qr_r(self):
-        # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
-        # without 'q' you cannot compute the backward pass. Check that
-        # linalg_qr_backward complains cleanly in that case.
-        inp = torch.randn((5, 7), requires_grad=True)
-        q, r = torch.linalg.qr(inp, mode='r')
-        assert q.shape == (0,)  # empty tensor
-        b = torch.sum(r)
-        with self.assertRaisesRegex(RuntimeError,
-                                    "linalg_qr_backward: cannot compute backward"):
-            b.backward()
 
 def index_perm_variable(shape, max_indices):
     if not isinstance(shape, tuple):
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 5f3a128..b39d64e 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3059,14 +3059,36 @@
             exp_r = np.linalg.qr(np_t, mode='r')
             q, r = torch.linalg.qr(t, mode='r')
             # check that q is empty
-            assert q.shape == (0,)
-            assert q.dtype == t.dtype
-            assert q.device == t.device
+            self.assertEqual(q.shape, (0,))
+            self.assertEqual(q.dtype, t.dtype)
+            self.assertEqual(q.device, t.device)
             # check r
             self.assertEqual(r, exp_r)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
+    @dtypes(torch.float)
+    def test_linalg_qr_autograd_errors(self, device, dtype):
+        # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
+        # without 'q' you cannot compute the backward pass. Check that
+        # linalg_qr_backward complains cleanly in that case.
+        inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
+        q, r = torch.linalg.qr(inp, mode='r')
+        self.assertEqual(q.shape, (0,))  # empty tensor
+        b = torch.sum(r)
+        with self.assertRaisesRegex(RuntimeError,
+                                    "The derivative of qr is not implemented when mode='r'"):
+            b.backward()
+        # Similarly, mode='complete' does not support backward when nrows > ncols.
+        inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
+        q, r = torch.linalg.qr(inp, mode='complete')
+        b = torch.sum(r)
+        with self.assertRaisesRegex(RuntimeError,
+                                    "The derivative of qr is not implemented when mode='complete' and nrows > ncols"):
+            b.backward()
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
     @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_qr_batched(self, device, dtype):
         """
@@ -3078,10 +3100,17 @@
             all_q = []
             all_r = []
             for matrix in a:
-                q, r = np.linalg.qr(matrix, mode=mode)
-                all_q.append(q)
-                all_r.append(r)
-            return np.array(all_q), np.array(all_r)
+                result = np.linalg.qr(matrix, mode=mode)
+                if mode == 'r':
+                    all_r.append(result)
+                else:
+                    q, r = result
+                    all_q.append(q)
+                    all_r.append(r)
+            if mode == 'r':
+                return np.array(all_r)
+            else:
+                return np.array(all_q), np.array(all_r)
 
         t = torch.randn((3, 7, 5), device=device, dtype=dtype)
         np_t = t.cpu().numpy()
@@ -3090,6 +3119,15 @@
             q, r = torch.linalg.qr(t, mode=mode)
             self.assertEqual(q, exp_q)
             self.assertEqual(r, exp_r)
+        # for mode='r' we need special logic because NumPy returns only r
+        exp_r = np_qr_batched(np_t, mode='r')
+        q, r = torch.linalg.qr(t, mode='r')
+        # check that q is empty
+        self.assertEqual(q.shape, (0,))
+        self.assertEqual(q.dtype, t.dtype)
+        self.assertEqual(q.device, t.device)
+        # check r
+        self.assertEqual(r, exp_r)
 
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
@@ -3112,11 +3150,22 @@
                 out = (torch.empty((0), dtype=dtype, device=device),
                        torch.empty((0), dtype=dtype, device=device))
                 q2, r2 = torch.linalg.qr(t, mode=mode, out=out)
-                assert q2 is out[0]
-                assert r2 is out[1]
+                self.assertIs(q2, out[0])
+                self.assertIs(r2, out[1])
                 self.assertEqual(q2, q)
                 self.assertEqual(r2, r)
 
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @dtypes(torch.float)
+    def test_qr_error_cases(self, device, dtype):
+        t1 = torch.randn(5, device=device, dtype=dtype)
+        with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'):
+            torch.linalg.qr(t1)
+        t2 = torch.randn((5, 7), device=device, dtype=dtype)
+        with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
+            torch.linalg.qr(t2, mode='hello')
+
     @dtypes(torch.double, torch.cdouble)
     def test_einsum(self, device, dtype):
         def check(equation, *operands):
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d204afd..b08db84 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6676,11 +6676,10 @@
 If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization.
 Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization.
 
-.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :meth:`~torch.linalg.qr`
-             instead, which provides a better compatibility with
-             ``numpy.linalg.qr``.
+.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr`
+             instead.
 
-             **Differences with** ``torch.linalg.`` :meth:`~torch.linalg.qr`:
+             **Differences with** ``torch.linalg.qr``:
 
              * ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``:
 
@@ -6698,21 +6697,21 @@
 
 .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
           and may produce different (valid) decompositions on different device types
-          and different platforms, depending on the precise version of the
-          underlying library.
+          or different platforms.
 
 Args:
     input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
                 batch dimensions consisting of matrices of dimension :math:`m \times n`.
     some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for
-                complete QR decomposition.
+                complete QR decomposition. If `k = min(m, n)` then:
+
+                  * ``some=True``: returns `(Q, R)` with dimensions (m, k), (k, n) (default)
+
+                  * ``some=False``: returns `(Q, R)` with dimensions (m, m), (m, n)
 
 Keyword args:
-    out (tuple, optional): tuple of `Q` and `R` tensors
-                satisfying :code:`input = torch.matmul(Q, R)`.
-                The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)`
-                respectively, where :math:`k = \min(m, n)` if :attr:`some:` is ``True`` and
-                :math:`k = m` otherwise.
+    out (tuple, optional): tuple of `Q` and `R` tensors.
+                The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above.
 
 Example::
 
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 6558295..79d195d 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2078,7 +2078,7 @@
                           std::string mode, const Tensor& q, const Tensor& r){
   bool compute_q, reduced;
   std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
-  TORCH_CHECK(compute_q, "linalg_qr_backward: cannot compute backward if mode='r'. "
+  TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
                          "Please use torch.linalg.qr(..., mode='reduced')");
 
   auto square_deep_case_backward = [](const Tensor& grad_Q,
@@ -2145,7 +2145,7 @@
 
   TORCH_CHECK(
       ((m <= n && (!reduced)) || reduced),
-      "The derivative is not implemented when nrows > ncols and complete QR. ");
+      "The derivative of qr is not implemented when mode='complete' and nrows > ncols.");
 
   auto grad_Q = grads[0];
   auto grad_R = grads[1];
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 4c724b0..de5fcb5 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -731,15 +731,15 @@
 .. note::
           Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead.
 
-          If you plan to backpropagate through QR, note that the current backward implementation
-          is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))`
-          columns of :attr:`input` are linearly independent.
-          This behavior may change in the future.
+          Backpropagation is also not supported if the first
+          :math:`\min(input.size(-1), input.size(-2))` columns of any matrix
+          in :attr:`input` are not linearly independent. While no error will
+          be thrown when this occurs, the values of the "gradient" produced
+          may be anything. This behavior may change in the future.
 
 .. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
           and may produce different (valid) decompositions on different device types
-          and different platforms, depending on the precise version of the
-          underlying library.
+          or different platforms.
 
 Args:
     input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
@@ -753,11 +753,8 @@
           * ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n)
 
 Keyword args:
-    out (tuple, optional): tuple of `Q` and `R` tensors
-                satisfying :code:`input = torch.matmul(Q, R)`.
-                The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)`
-                respectively, where :math:`k = \min(m, n)` if :attr:`mode` is `'reduced'` and
-                :math:`k = m` if :attr:`mode` is `'complete'`.
+    out (tuple, optional): tuple of `Q` and `R` tensors.
+                The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above.
 
 Example::
 
@@ -779,6 +776,11 @@
     tensor([[ 1.,  0.,  0.],
             [ 0.,  1., -0.],
             [ 0., -0.,  1.]])
+    >>> q2, r2 = torch.linalg.qr(a, mode='r')
+    >>> q2
+    tensor([])
+    >>> torch.equal(r, r2)
+    True
     >>> a = torch.randn(3, 4, 5)
     >>> q, r = torch.linalg.qr(a, mode='complete')
     >>> torch.allclose(torch.matmul(q, r), a)