add qr_backward functionality for wide case (#42216)

Summary:
Unblocks implementation of https://github.com/pytorch/pytorch/issues/27036. Note that this PR ***does not*** fix #27036.
Currently, the QR decomposition backward only supports the square and tall (a.k.a. skinny) cases.
This PR adds support for wide A matrices/tensors, includes three unit tests for the new case,
and restructures `qr_backward` so that the existing Walther formula is factored out into a helper shared by both paths.

cc albanD t-vi

I don't have a GPU machine, so I haven't tested on CUDA, but everything passes on CPU on my local machine.
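
For anyone who wants to reproduce this locally, a minimal sanity check along these lines should pass on CPU (this is not one of the unit tests added in the PR; it just exercises the new wide-case path via `gradcheck`):

```python
import torch
from torch.autograd import gradcheck

# Minimal local sanity check (not part of this PR's test suite): run gradcheck
# through torch.qr with some=True on a wide (3 x 5) double-precision input,
# which now goes through the wide-case qr_backward.
A = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda a: torch.qr(a, some=True), (A,))
```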

The basic idea of the PR is described in the comments in `Functions.cpp`, but I'll note it here too for clarity:

Let A be an m x n matrix with m < n. Partition it as A = [X | Y], where X is m x m and Y is m x (n - m), and take the QR decomposition of X, X = QU. The Q obtained from X is the same as the Q from the QR of the entire A matrix. Transforming Y with that rotation gives V = Q^T Y, so that R = [U | V]. The gradients partition the same way: grad_A = [grad_X | grad_Y] and grad_R = [grad_U | grad_V], where grad_V is a `narrow()` of `grad_R` and grad_Y = Q @ grad_V. grad_X is computed very much like the original Walther formula (it is exactly the same in the tall and square cases), with a slight modification for the wide case.
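
To make the data flow concrete, here is a rough Python sketch of the same computation using public torch ops. It is an illustration only: `wide_qr_backward` and `_square_deep_qr_backward` are illustrative names rather than PyTorch API, both output grads are assumed to be defined, and the era-appropriate `torch.triangular_solve` is used. The actual implementation is the C++ in `Functions.cpp` below.

```python
import torch

def _square_deep_qr_backward(grad_Q, grad_R, A, Q, R):
    # Walther & Lehmann, Eq. 42 (mirrors the square/deep C++ helper below).
    R_term = R @ grad_R.transpose(-2, -1)
    Q_term = Q.transpose(-2, -1) @ grad_Q
    rhs = torch.tril(R_term - R_term.transpose(-2, -1)
                     + Q_term - Q_term.transpose(-2, -1), -1)
    # (rhs . R^{-T}) via triangular_solve(rhs^T, R)^T, since R is upper triangular
    sol = torch.triangular_solve(rhs.transpose(-2, -1), R, upper=True).solution
    grad_A = Q @ (sol.transpose(-2, -1) + grad_R)
    if A.size(-1) != A.size(-2):
        # QQ^T is not the identity for deep (tall) A, so add the second term
        rhs2 = grad_Q - Q @ Q_term
        sol2 = torch.triangular_solve(rhs2.transpose(-2, -1), R, upper=True).solution
        grad_A = grad_A + sol2.transpose(-2, -1)
    return grad_A

def wide_qr_backward(grad_Q, grad_R, A, Q, R):
    # Wide case (m < n): A = [X | Y], R = [U | V]; reuse the square formula for X.
    m = A.size(-2)
    X, Y = A[..., :, :m], A[..., :, m:]
    U, grad_U, grad_V = R[..., :, :m], grad_R[..., :, :m], grad_R[..., :, m:]
    grad_Q_prime = grad_Q + Y @ grad_V.transpose(-2, -1)
    # (the C++ passes the full A to the helper; passing the square block X is
    #  equivalent since QQ^T = I when X is square)
    grad_X = _square_deep_qr_backward(grad_Q_prime, grad_U, X, Q, U)
    grad_Y = Q @ grad_V
    return torch.cat([grad_X, grad_Y], dim=-1)
```

Factoring the square/deep formula into a helper is exactly what lets the wide case reuse it with the arguments (grad_Q_prime, grad_U, X, Q, U).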

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42216

Reviewed By: glaringlee

Differential Revision: D23373118

Pulled By: albanD

fbshipit-source-id: 3702ba7e7e23923868c02cdb7e10a96036052344
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 988bb88..4109d1e 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -2018,68 +2018,127 @@
   return result.add(result.transpose(-2, -1)).mul_(0.5);
 }
 
-// We refer Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
-// Algebra Functions with Application in Optimum Experimental Design (Extended Version)
-// The derivative for the QR decomposition is adapted from Eq. 42 of the
-// above reference.
+
+
 Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
-                   bool some, const Tensor& Q, const Tensor& R) {
-  auto grad_Q = grads[0];
-  auto grad_R = grads[1];
-  TORCH_CHECK(R.size(-2) == R.size(-1),
-              "The derivative when R is non-square is not implemented. ");
+                   bool some, const Tensor& q, const Tensor& r){
+  auto square_deep_case_backward = [](const Tensor& grad_Q,
+                                      const Tensor& grad_R,
+                                      const Tensor& A,
+                                      const Tensor& Q,
+                                      const Tensor& R) -> Tensor {
+    // For square and deep (tall) case we refer
+    // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
+    // Algebra Functions with Application in Optimum Experimental Design
+    // (Extended Version) The derivative for the QR decomposition is adapted
+    // from Eq. 42 of the above reference.
 
-  // Compute R (R')^{T}
-  Tensor R_term;
-  if (grad_R.defined()) {
-    R_term = at::matmul(R, grad_R.transpose(-2, -1));
-  } else {
-    // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
-    R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  }
-
-  // Compute Q^{T} Q'
-  Tensor Q_term;
-  if (grad_Q.defined()) {
-    Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
-  } else {
-    // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
-    Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  }
-
-  // We want to compute: (rhs_solve_1 . R^{-T})
-  // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
-  // Since R is upper triangular, we can do this using
-  // triangular_solve(rhs_solve_1^{T}, R)^{T}
-  auto rhs_solve_1 = R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
-  rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
-  Tensor solve_soln_1;
-  std::tie(solve_soln_1, std::ignore) = at::triangular_solve(rhs_solve_1.transpose(-2, -1), R,
-                                                             /*upper=*/true, /*transpose=*/false,
-                                                             /*unitriangular=*/false);
-  Tensor grad_A;
-  if (grad_R.defined()) {
-    grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
-  } else {
-    grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
-  }
-
-  // Successive computations involve computation of QQ^{T} which is identity when A is square
-  if (self.size(-1) != self.size(-2)) {
-    Tensor rhs_solve_2;
-    // We use the same trick from above for this computation
-    if (grad_Q.defined()) {
-      rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
+    // Compute R (R')^{T}
+    Tensor R_term;
+    if (grad_R.defined()) {
+      R_term = at::matmul(R, grad_R.transpose(-2, -1));
     } else {
-      rhs_solve_2 = -at::matmul(Q, Q_term);
+      // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
+      R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
     }
-    Tensor solve_soln_2;
-    std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
+
+    // Compute Q^{T} Q'
+    Tensor Q_term;
+    if (grad_Q.defined()) {
+      Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
+    } else {
+      // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
+      Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    }
+
+    // We want to compute: (rhs_solve_1 . R^{-T})
+    // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
+    // Since R is upper triangular, we can do this using
+    // triangular_solve(rhs_solve_1^{T}, R)^{T}
+    auto rhs_solve_1 =
+        R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
+    rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
+    Tensor solve_soln_1;
+    std::tie(solve_soln_1, std::ignore) = at::triangular_solve(
+        rhs_solve_1.transpose(-2, -1),
+        R,
+        /*upper=*/true,
+        /*transpose=*/false,
+        /*unitriangular=*/false);
+    Tensor grad_A;
+    if (grad_R.defined()) {
+      grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
+    } else {
+      grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
+    }
+
+    // Successive computations involve computation of QQ^{T} which is identity when A is square
+    if (A.size(-1) != A.size(-2)) {
+      Tensor rhs_solve_2;
+      // We use the same trick from above for this computation
+      if (grad_Q.defined()) {
+        rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
+      } else {
+        rhs_solve_2 = -at::matmul(Q, Q_term);
+      }
+      Tensor solve_soln_2;
+      std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
                                                                /*upper=*/true, /*transpose=*/false,
                                                                /*unitriangular=*/false);
-    grad_A.add_(solve_soln_2.transpose(-2, -1));
+      grad_A.add_(solve_soln_2.transpose(-2, -1));
+    }
+    return grad_A;
+  };
+
+  auto m = self.size(-2);
+  auto n = self.size(-1);
+
+  TORCH_CHECK(
+      ((m <= n && (!some)) || some),
+      "The derivative is not implemented when nrows > ncols and complete QR. ");
+
+  auto grad_Q = grads[0];
+  auto grad_R = grads[1];
+
+  if (m >= n) {
+    return square_deep_case_backward(grad_Q, grad_R, self, q, r);
+  } else {
+    // For wide (m < n) input matrices A,  partition A = [X|Y] and R = [U|V]
+    // X and U are square full rank matrices. We will partition grads,
+    // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
+    // To obtain grad_X we reuse the gradient formula from the square case.
+    // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
+    // where grad_Q_prime = grad_Q + Y @ grad_V.T
+    // and grad_Y = Q @ grad_V.
+    // Then concatenate grads to get grad_A = [grad_X | grad_Y].
+
+    auto Y = self.narrow(-1, m, n - m);
+    auto U = r.narrow(-1, 0, m);
+    Tensor grad_Y, grad_X, grad_V, grad_Q_prime;
+
+    if (grad_R.defined()) {
+      grad_V = grad_R.narrow(-1, m, n - m);
+      // reuse grad_R to store grad_U
+      grad_R = grad_R.narrow(-1, 0, m);
+      // grad_Q_prime starts with the value of Y @ grad_V.T
+      grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1));
+    } else {
+      // when grad_R is not defined then grad_V and grad_Q_prime
+      // get initialized with zeros
+      grad_V = at::zeros_like(Y, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+      grad_Q_prime = at::zeros_like(q, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    }
+
+    if (grad_Q.defined()) {
+      // add the grad_Q term into grad_Q_prime when defined; otherwise it is 0
+      grad_Q_prime = grad_Q_prime + grad_Q;
+    }
+    // Calculate grad_X using the helper; grad_R now holds the grad_U value
+    grad_X = square_deep_case_backward(grad_Q_prime, grad_R, self, q, U);
+    grad_Y = at::matmul(q, grad_V);
+    // Concatenate grad_X and grad_Y to get grad_A.
+    return at::cat({grad_X, grad_Y}, -1);
   }
-  return grad_A;
 }
 
 // Invertible case is derived from Jacobi's formula, and also can be found at:
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7820c33..40f611b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1014,10 +1014,13 @@
          lambda usv: (usv[0][..., :, :(S - 2)], usv[1], usv[2])),
         ('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('qr', (S, S - 2), (True,), 'tall_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+        ('qr', (S - 2, S), (False,), 'wide_single' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('qr', (3, S, S), (False,), 'square_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('qr', (3, S, S - 2), (True,), 'tall_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+        ('qr', (3, S - 2, S), (True,), 'wide_batched' , (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('qr', (3, 2, S, S), (False,), 'square_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('qr', (3, 2, S, S - 2), (True,), 'tall_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
+        ('qr', (3, 2, S - 2, S), (True,), 'wide_many_batched', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
             S, silent=True),), '', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]),
         ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),