Improve torch.linalg.qr (#50046)
Summary:
This is a follow-up to PR https://github.com/pytorch/pytorch/issues/47764, fixing the remaining details.
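
For reference, here is a minimal sketch of the user-facing behavior these changes cover (this snippet is not part of the diff below and assumes a build that includes this change):

```python
import torch

a = torch.randn(5, 7, requires_grad=True)

# mode='r' returns only R; Q comes back as an empty tensor and backward is unsupported.
q, r = torch.linalg.qr(a, mode='r')
print(q.shape)  # torch.Size([0])
try:
    r.sum().backward()
except RuntimeError as e:
    print(e)  # "The derivative of qr is not implemented when mode='r'. ..."

# An unrecognized mode now reports the accepted values.
try:
    torch.linalg.qr(a, mode='hello')
except RuntimeError as e:
    print(e)  # "qr received unrecognized mode 'hello' but expected one of 'reduced' (default), 'r', or 'complete'"
```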
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50046
Reviewed By: zou3519
Differential Revision: D25825557
Pulled By: mruberry
fbshipit-source-id: b8e335e02265e73484a99b0189e4cc042828e0a9
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index f0b36d0..146275e 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1011,13 +1011,13 @@
std::tuple<Tensor,Tensor> linalg_qr(const Tensor& self, std::string mode) {
TORCH_CHECK(self.dim() >= 2,
- "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+ "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
return at::_linalg_qr_helper(self, mode);
}
std::tuple<Tensor&,Tensor&> linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) {
TORCH_CHECK(self.dim() >= 2,
- "self should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
+ "qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Tensor Q_tmp, R_tmp;
std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode);
at::native::resize_output(Q, Q_tmp.sizes());
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index e97637d..4322c4c 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -206,7 +206,8 @@
compute_q = false;
reduced = true; // this is actually irrelevant in this mode
} else {
- TORCH_CHECK(false, "Unrecognized mode '", mode, "'");
+ TORCH_CHECK(false, "qr received unrecognized mode '", mode,
+ "' but expected one of 'reduced' (default), 'r', or 'complete'");
}
return std::make_tuple(compute_q, reduced);
}
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 9f59252..fb4394b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4941,17 +4941,6 @@
return_counts=return_counts)
assert_only_first_requires_grad(res)
- def test_linalg_qr_r(self):
- # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
- # without 'q' you cannot compute the backward pass. Check that
- # linalg_qr_backward complains cleanly in that case.
- inp = torch.randn((5, 7), requires_grad=True)
- q, r = torch.linalg.qr(inp, mode='r')
- assert q.shape == (0,) # empty tensor
- b = torch.sum(r)
- with self.assertRaisesRegex(RuntimeError,
- "linalg_qr_backward: cannot compute backward"):
- b.backward()
def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple):
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 5f3a128..b39d64e 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -3059,14 +3059,36 @@
exp_r = np.linalg.qr(np_t, mode='r')
q, r = torch.linalg.qr(t, mode='r')
# check that q is empty
- assert q.shape == (0,)
- assert q.dtype == t.dtype
- assert q.device == t.device
+ self.assertEqual(q.shape, (0,))
+ self.assertEqual(q.dtype, t.dtype)
+ self.assertEqual(q.device, t.device)
# check r
self.assertEqual(r, exp_r)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
+ @dtypes(torch.float)
+ def test_linalg_qr_autograd_errors(self, device, dtype):
+ # torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
+ # without 'q' you cannot compute the backward pass. Check that
+ # linalg_qr_backward complains cleanly in that case.
+ inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
+ q, r = torch.linalg.qr(inp, mode='r')
+ self.assertEqual(q.shape, (0,)) # empty tensor
+ b = torch.sum(r)
+ with self.assertRaisesRegex(RuntimeError,
+ "The derivative of qr is not implemented when mode='r'"):
+ b.backward()
+ #
+ inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
+ q, r = torch.linalg.qr(inp, mode='complete')
+ b = torch.sum(r)
+ with self.assertRaisesRegex(RuntimeError,
+ "The derivative of qr is not implemented when mode='complete' and nrows > ncols"):
+ b.backward()
+
+ @skipCUDAIfNoMagma
+ @skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_qr_batched(self, device, dtype):
"""
@@ -3078,10 +3100,17 @@
all_q = []
all_r = []
for matrix in a:
- q, r = np.linalg.qr(matrix, mode=mode)
- all_q.append(q)
- all_r.append(r)
- return np.array(all_q), np.array(all_r)
+ result = np.linalg.qr(matrix, mode=mode)
+ if mode == 'r':
+ all_r.append(result)
+ else:
+ q, r = result
+ all_q.append(q)
+ all_r.append(r)
+ if mode == 'r':
+ return np.array(all_r)
+ else:
+ return np.array(all_q), np.array(all_r)
t = torch.randn((3, 7, 5), device=device, dtype=dtype)
np_t = t.cpu().numpy()
@@ -3090,6 +3119,15 @@
q, r = torch.linalg.qr(t, mode=mode)
self.assertEqual(q, exp_q)
self.assertEqual(r, exp_r)
+ # for mode='r' we need special logic because numpy returns only r
+ exp_r = np_qr_batched(np_t, mode='r')
+ q, r = torch.linalg.qr(t, mode='r')
+ # check that q is empty
+ self.assertEqual(q.shape, (0,))
+ self.assertEqual(q.dtype, t.dtype)
+ self.assertEqual(q.device, t.device)
+ # check r
+ self.assertEqual(r, exp_r)
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@@ -3112,11 +3150,22 @@
out = (torch.empty((0), dtype=dtype, device=device),
torch.empty((0), dtype=dtype, device=device))
q2, r2 = torch.linalg.qr(t, mode=mode, out=out)
- assert q2 is out[0]
- assert r2 is out[1]
+ self.assertIs(q2, out[0])
+ self.assertIs(r2, out[1])
self.assertEqual(q2, q)
self.assertEqual(r2, r)
+ @skipCUDAIfNoMagma
+ @skipCPUIfNoLapack
+ @dtypes(torch.float)
+ def test_qr_error_cases(self, device, dtype):
+ t1 = torch.randn(5, device=device, dtype=dtype)
+ with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'):
+ torch.linalg.qr(t1)
+ t2 = torch.randn((5, 7), device=device, dtype=dtype)
+ with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
+ torch.linalg.qr(t2, mode='hello')
+
@dtypes(torch.double, torch.cdouble)
def test_einsum(self, device, dtype):
def check(equation, *operands):
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d204afd..b08db84 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -6676,11 +6676,10 @@
If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization.
Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization.
-.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :meth:`~torch.linalg.qr`
- instead, which provides a better compatibility with
- ``numpy.linalg.qr``.
+.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr`
+ instead.
- **Differences with** ``torch.linalg.`` :meth:`~torch.linalg.qr`:
+ **Differences with** ``torch.linalg.qr``:
* ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``:
@@ -6698,21 +6697,21 @@
.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
and may produce different (valid) decompositions on different device types
- and different platforms, depending on the precise version of the
- underlying library.
+ or different platforms.
Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of matrices of dimension :math:`m \times n`.
some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for
- complete QR decomposition.
+ complete QR decomposition. If `k = min(m, n)` then:
+
+ * ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default)
+
+ * ``some=False``: returns `(Q, R)` with dimensions (m, m), (m, n)
Keyword args:
- out (tuple, optional): tuple of `Q` and `R` tensors
- satisfying :code:`input = torch.matmul(Q, R)`.
- The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)`
- respectively, where :math:`k = \min(m, n)` if :attr:`some:` is ``True`` and
- :math:`k = m` otherwise.
+ out (tuple, optional): tuple of `Q` and `R` tensors.
+ The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above.
Example::
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 6558295..79d195d 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -2078,7 +2078,7 @@
std::string mode, const Tensor& q, const Tensor& r){
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
- TORCH_CHECK(compute_q, "linalg_qr_backward: cannot compute backward if mode='r'. "
+ TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
"Please use torch.linalg.qr(..., mode='reduced')");
auto square_deep_case_backward = [](const Tensor& grad_Q,
@@ -2145,7 +2145,7 @@
TORCH_CHECK(
((m <= n && (!reduced)) || reduced),
- "The derivative is not implemented when nrows > ncols and complete QR. ");
+ "The derivative of qr is not implemented when mode='complete' and nrows > ncols.");
auto grad_Q = grads[0];
auto grad_R = grads[1];
diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py
index 4c724b0..de5fcb5 100644
--- a/torch/linalg/__init__.py
+++ b/torch/linalg/__init__.py
@@ -731,15 +731,15 @@
.. note::
Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead.
- If you plan to backpropagate through QR, note that the current backward implementation
- is only well-defined when the first :math:`\min(input.size(-1), input.size(-2))`
- columns of :attr:`input` are linearly independent.
- This behavior may change in the future.
+ Backpropagation is also not supported if the first
+ :math:`\min(input.size(-1), input.size(-2))` columns of any matrix
+ in :attr:`input` are not linearly independent. While no error will
+ be thrown when this occurs, the values of the "gradient" produced may
+ be anything. This behavior may change in the future.
.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
and may produce different (valid) decompositions on different device types
- and different platforms, depending on the precise version of the
- underlying library.
+ or different platforms.
Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
@@ -753,11 +753,8 @@
* ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n)
Keyword args:
- out (tuple, optional): tuple of `Q` and `R` tensors
- satisfying :code:`input = torch.matmul(Q, R)`.
- The dimensions of `Q` and `R` are :math:`(*, m, k)` and :math:`(*, k, n)`
- respectively, where :math:`k = \min(m, n)` if :attr:`mode` is `'reduced'` and
- :math:`k = m` if :attr:`mode` is `'complete'`.
+ out (tuple, optional): tuple of `Q` and `R` tensors.
+ The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above.
Example::
@@ -779,6 +776,11 @@
tensor([[ 1., 0., 0.],
[ 0., 1., -0.],
[ 0., -0., 1.]])
+ >>> q2, r2 = torch.linalg.qr(a, mode='r')
+ >>> q2
+ tensor([])
+ >>> torch.equal(r, r2)
+ True
>>> a = torch.randn(3, 4, 5)
>>> q, r = torch.linalg.qr(a, mode='complete')
>>> torch.allclose(torch.matmul(q, r), a)