add qr_backward functionality for wide case (#42216)
Summary:
Unblocks implementation of https://github.com/pytorch/pytorch/issues/27036. Note that this PR ***does not*** fix #27036.
Currently, `qr_backward` only supports the square and tall (a.k.a. skinny) cases.
This PR adds support for wide A matrices/tensors, includes three unit tests for the new case,
and restructures `qr_backward` so that the existing Walter formula is reused as a helper.
cc albanD t-vi
I don't have a GPU machine, so I haven't tested on CUDA, but everything passes on my local machine on CPU.
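As a sketch of how the new path can be exercised locally (my own snippet, not part of the PR; the shapes mirror the new wide test entries, and the optional CUDA loop simply repeats the same checks on GPU when one is available):

```python
import torch
from torch.autograd import gradcheck

# Shapes mirroring the new wide test entries: single, batched, many-batched.
cases = [((3, 5), False), ((4, 3, 5), True), ((3, 2, 3, 5), True)]

devices = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])
for device in devices:
    for shape, some in cases:
        a = torch.randn(*shape, dtype=torch.double, device=device,
                        requires_grad=True)
        # torch.qr returns (Q, R); gradcheck accepts tuple-valued functions.
        assert gradcheck(lambda a: torch.qr(a, some=some), (a,))
```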
The basic idea of the PR is noted in the comments in `Functions.cpp`, but I'll state it here too for clarity:

Let $A_{m,n}$ be a matrix with $m < n$ and partition it as $A_{m,n} = [\, X_{m,m} \mid Y_{m,n-m} \,]$. Take the QR decomposition of the square block, $X = QU$; the $Q$ obtained from $X$ is the same as the $Q$ from the QR decomposition of the full matrix $A$. Transforming $Y$ with that $Q$ gives $V = Q^{T}Y$, so that $R = [\, U \mid V \,]$. The gradients partition the same way: writing $\bar{A}$ for `grad_A`, we have $\bar{A} = [\, \bar{X} \mid \bar{Y} \,]$ and $\bar{R} = [\, \bar{U} \mid \bar{V} \,]$, with $\bar{Y} = Q\bar{V}$, where $\bar{V}$ is the `narrow()` of `grad_R`. Finally, $\bar{X}$ is computed very similarly to the original Walter formula (it is exactly the same in the tall and square cases), slightly modified here for wide matrices.
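As a quick numerical illustration of this partition (my own sketch, not part of the PR; it assumes the reduced QR of $A$ and of $X$ are computed by the same Householder-based backend, so the two $Q$ factors coincide):

```python
import torch

m, n = 3, 5                        # wide case: m < n
A = torch.randn(m, n, dtype=torch.double)
X, Y = A[:, :m], A[:, m:]          # A = [X | Y]

Q, R = torch.qr(A, some=True)      # reduced QR of the full matrix
Qx, U = torch.qr(X)                # QR of the square block: X = QU

V = Q.transpose(-2, -1) @ Y        # V = Q^T Y
print(torch.allclose(Q, Qx))                         # same Q factor
print(torch.allclose(R, torch.cat([U, V], dim=-1)))  # R = [U | V]
```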
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42216
Reviewed By: glaringlee
Differential Revision: D23373118
Pulled By: albanD
fbshipit-source-id: 3702ba7e7e23923868c02cdb7e10a96036052344
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 988bb88..4109d1e 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -2018,68 +2018,127 @@
return result.add(result.transpose(-2, -1)).mul_(0.5);
}
-// We refer Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
-// Algebra Functions with Application in Optimum Experimental Design (Extended Version)
-// The derivative for the QR decomposition is adapted from Eq. 42 of the
-// above reference.
+
+
Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
- bool some, const Tensor& Q, const Tensor& R) {
- auto grad_Q = grads[0];
- auto grad_R = grads[1];
- TORCH_CHECK(R.size(-2) == R.size(-1),
- "The derivative when R is non-square is not implemented. ");
+ bool some, const Tensor& q, const Tensor& r){
+ auto square_deep_case_backward = [](const Tensor& grad_Q,
+ const Tensor& grad_R,
+ const Tensor& A,
+ const Tensor& Q,
+ const Tensor& R) -> Tensor {
+ // For square and deep (tall) case we refer
+ // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
+ // Algebra Functions with Application in Optimum Experimental Design
+ // (Extended Version) The derivative for the QR decomposition is adapted
+ // from Eq. 42 of the above reference.
- // Compute R (R')^{T}
- Tensor R_term;
- if (grad_R.defined()) {
- R_term = at::matmul(R, grad_R.transpose(-2, -1));
- } else {
- // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
- R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- }
-
- // Compute Q^{T} Q'
- Tensor Q_term;
- if (grad_Q.defined()) {
- Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
- } else {
- // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
- Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- }
-
- // We want to compute: (rhs_solve_1 . R^{-T})
- // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
- // Since R is upper triangular, we can do this using
- // triangular_solve(rhs_solve_1^{T}, R)^{T}
- auto rhs_solve_1 = R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
- rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
- Tensor solve_soln_1;
- std::tie(solve_soln_1, std::ignore) = at::triangular_solve(rhs_solve_1.transpose(-2, -1), R,
- /*upper=*/true, /*transpose=*/false,
- /*unitriangular=*/false);
- Tensor grad_A;
- if (grad_R.defined()) {
- grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
- } else {
- grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
- }
-
- // Successive computations involve computation of QQ^{T} which is identity when A is square
- if (self.size(-1) != self.size(-2)) {
- Tensor rhs_solve_2;
- // We use the same trick from above for this computation
- if (grad_Q.defined()) {
- rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
+ // Compute R (R')^{T}
+ Tensor R_term;
+ if (grad_R.defined()) {
+ R_term = at::matmul(R, grad_R.transpose(-2, -1));
} else {
- rhs_solve_2 = -at::matmul(Q, Q_term);
+ // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
+ R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
- Tensor solve_soln_2;
- std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
+
+ // Compute Q^{T} Q'
+ Tensor Q_term;
+ if (grad_Q.defined()) {
+ Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
+ } else {
+ // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
+ Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ }
+
+ // We want to compute: (rhs_solve_1 . R^{-T})
+ // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
+ // Since R is upper triangular, we can do this using
+ // triangular_solve(rhs_solve_1^{T}, R)^{T}
+ auto rhs_solve_1 =
+ R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
+ rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
+ Tensor solve_soln_1;
+ std::tie(solve_soln_1, std::ignore) = at::triangular_solve(
+ rhs_solve_1.transpose(-2, -1),
+ R,
+ /*upper=*/true,
+ /*transpose=*/false,
+ /*unitriangular=*/false);
+ Tensor grad_A;
+ if (grad_R.defined()) {
+ grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
+ } else {
+ grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
+ }
+
+ // Successive computations involve computation of QQ^{T} which is identity when A is square
+ if (A.size(-1) != A.size(-2)) {
+ Tensor rhs_solve_2;
+ // We use the same trick from above for this computation
+ if (grad_Q.defined()) {
+ rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
+ } else {
+ rhs_solve_2 = -at::matmul(Q, Q_term);
+ }
+ Tensor solve_soln_2;
+ std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
/*upper=*/true, /*transpose=*/false,
/*unitriangular=*/false);
- grad_A.add_(solve_soln_2.transpose(-2, -1));
+ grad_A.add_(solve_soln_2.transpose(-2, -1));
+ }
+ return grad_A;
+ };
+
+ auto m = self.size(-2);
+ auto n = self.size(-1);
+
+ TORCH_CHECK(
+ ((m <= n && (!some)) || some),
+ "The derivative is not implemented when nrows > ncols and complete QR. ");
+
+ auto grad_Q = grads[0];
+ auto grad_R = grads[1];
+
+ if (m >= n) {
+ return square_deep_case_backward(grad_Q, grad_R, self, q, r);
+ } else {
+ // For wide (m < n) input matrices A, partition A = [X|Y] and R = [U|V]
+ // X and U are square full rank matrices. We will partition grads,
+ // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
+ // To obtain grad_X we reuse the gradient formula from the square case.
+ // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
+ // where grad_Q_prime = grad_Q + Y @ grad_V.T
+ // and grad_Y = Q @ grad_V.
+ // Then concatenate grads to get grad_A = [grad_X | grad_Y].
+
+ auto Y = self.narrow(-1, m, n - m);
+ auto U = r.narrow(-1, 0, m);
+ Tensor grad_Y, grad_X, grad_V, grad_Q_prime;
+
+ if (grad_R.defined()) {
+ grad_V = grad_R.narrow(-1, m, n - m);
+ // reuse grad_R to store grad_U
+ grad_R = grad_R.narrow(-1, 0, m);
+ // grad_Q_prime starts with the value of Y @ grad_V.T
+ grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1));
+ } else {
+ // when grad_R is not defined then grad_V and grad_Q_prime
+ // get initialized with zeros
+ grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ }
+
+ if (grad_Q.defined()) {
+ // add the grad_Q term into grad_Q_prime when defined, otherwise it stays 0
+ grad_Q_prime = grad_Q_prime + grad_Q;
+ }
+ // Calculate grad_X using the helper. Grad_R contains the grad_U value
+ grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U);
+ grad_Y = at::matmul(q, grad_V);
+ // Concatenate grad_X and grad_Y to get grad_A.
+ return at::cat({grad_X, grad_Y}, -1);
}
- return grad_A;
}
// Invertible case is derived from Jacobi's formula, and also can be found at:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7820c33..40f611b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1014,10 +1014,13 @@
lambda usv: (usv[0][..., :, :(S - 2)], usv[1], usv[2])),
('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('qr', (S, S - 2), (True,), 'tall_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+ ('qr', (S - 2, S), (False,), 'wide_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('qr', (3, S, S), (False,), 'square_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('qr', (3, S, S - 2), (True,), 'tall_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+ ('qr', (3, S - 2, S), (True,), 'wide_batched' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('qr', (3, 2, S, S), (False,), 'square_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('qr', (3, 2, S, S - 2), (True,), 'tall_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+ ('qr', (3, 2, S - 2, S), (True,), 'wide_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
S, silent=True),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
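For completeness, the wide-case composition implemented above can be sanity-checked against autograd in Python. The sketch below is mine (the helper names `square_case_qr_backward` and `wide_case_qr_backward` are hypothetical, not PyTorch API); it mirrors the square/deep helper and the wide-case formulas from the diff, and assumes a build that already contains this PR so `torch.autograd.grad` exercises the new path:

```python
import torch

def square_case_qr_backward(grad_Q, grad_R, Q, R):
    # Mirrors the square/deep helper (Walter & Lehmann, Eq. 42).
    R_term = R @ grad_R.transpose(-2, -1)
    Q_term = Q.transpose(-2, -1) @ grad_Q
    rhs = torch.tril(R_term - R_term.transpose(-2, -1)
                     + Q_term - Q_term.transpose(-2, -1), diagonal=-1)
    # rhs @ R^{-T} is computed as (R^{-1} @ rhs^T)^T via a triangular solve.
    sol = torch.triangular_solve(rhs.transpose(-2, -1), R, upper=True).solution
    return Q @ (sol.transpose(-2, -1) + grad_R)

def wide_case_qr_backward(grad_Q, grad_R, A, Q, R):
    # A = [X | Y], R = [U | V]; grad_A = [grad_X | Q @ grad_V].
    m, n = A.size(-2), A.size(-1)
    Y = A.narrow(-1, m, n - m)
    U = R.narrow(-1, 0, m)
    grad_U = grad_R.narrow(-1, 0, m)
    grad_V = grad_R.narrow(-1, m, n - m)
    grad_Q_prime = grad_Q + Y @ grad_V.transpose(-2, -1)
    grad_X = square_case_qr_backward(grad_Q_prime, grad_U, Q, U)
    grad_Y = Q @ grad_V
    return torch.cat([grad_X, grad_Y], dim=-1)

A = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
Q, R = torch.qr(A, some=True)
grad_Q, grad_R = torch.randn_like(Q), torch.randn_like(R)
expected, = torch.autograd.grad((Q, R), A, (grad_Q, grad_R))
got = wide_case_qr_backward(grad_Q, grad_R, A.detach(), Q.detach(), R.detach())
print(torch.allclose(expected, got))  # expected: True
```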