Add a 'dim' argument to nuclear norm (#21022)

Summary:
Adds a 'dim' argument to the nuclear norm so it can be computed over any two dimensions of a higher-dimensional tensor (exposed through torch.norm(input, "nuc", dim=...)); the remaining dimensions are treated as batch dimensions. Addresses #18275.
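
For reference, a minimal usage sketch of the new argument (tensor shapes here are arbitrary and only for illustration):

    import torch

    x = torch.randn(5, 3, 4)                       # a batch of five 3x4 matrices
    per_matrix = torch.norm(x, "nuc", dim=(1, 2))  # one nuclear norm per matrix -> shape (5,)
    single = torch.norm(x[0], "nuc")               # the existing 2-d form still works without 'dim'
    # per_matrix[0] and single agree (each is the sum of the singular values of x[0])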
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21022

Differential Revision: D15743515

Pulled By: ezyang

fbshipit-source-id: e4aaea0bd7f863a2abad45c4322d6a9fb02a88e3
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 70e4cf1..2b165bd 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -527,21 +527,89 @@
 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
   TORCH_CHECK(
       self.dim() == 2,
-      "Expected a tensor with 2 dimensions, but got a ",
-      self.dim(),
-      " dimensions tensor instead.");
+      "Expected a tensor with 2 dimensions, but got a tensor with ",
+      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
   return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
 }
 
 Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
   TORCH_CHECK(
       self.dim() == 2,
-      "Expected a tensor with 2 dimensions, but got a ",
-      self.dim(),
-      " dimensions tensor instead.");
+      "Expected a tensor with 2 dimensions, but got a tensor with ",
+      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
   return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
 }
 
+// Non-optimized batched svd implementation. This can be merged with at::svd
+// once at::svd has been ported to ATen.
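+// Assuming some == true and compute_uv == true, an input of shape (..., n, m)
+// yields U of shape (..., n, k), S of shape (..., k) and V of shape (..., m, k),
+// where k = min(n, m).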
+static std::tuple<Tensor, Tensor, Tensor>
+_batch_svd(const Tensor& self, bool some, bool compute_uv)
+{
+  const int64_t ndim = self.ndimension();
+
+  TORCH_CHECK(
+      ndim >= 2,
+      "Expected a tensor with at least 2 dimensions, but got a tensor with ",
+      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
+
+  if (ndim == 2) {
+    return at::svd(self, some, compute_uv);
+  }
+
+  const int64_t n = self.size(-2);
+  const int64_t m = self.size(-1);
+  const int64_t k = std::min<int64_t>(n, m);
+  const int64_t nn = (some && compute_uv) ? k : n;
+  const int64_t mm = (some && compute_uv) ? k : m;
+  const int64_t p = batchCount(self);
+
+  Tensor t = self.reshape({p, n, m});
+
+  Tensor s = at::empty({p, k}, self.options());
+  Tensor u, v;
+  if (compute_uv) {
+    u = at::empty({p, n, nn}, self.options());
+    v = at::empty({p, m, mm}, self.options());
+  }
+
+  for (int64_t i = 0; i < p; i++) {
+    auto tuple = at::svd(t[i], some, compute_uv);
+    s[i] = std::get<1>(tuple);
+    if (compute_uv) {
+      u[i] = std::get<0>(tuple);
+      v[i] = std::get<2>(tuple);
+    }
+  }
+
+  std::vector<int64_t> shape = self.sizes().slice(0, ndim-1).vec();
+  shape[ndim-2] = k;
+  s = s.reshape(shape);
+
+  shape[ndim-2] = n;
+  shape.push_back(nn);
+  u = compute_uv ? u.reshape(shape) : at::zeros(shape, self.options());
+
+  shape[ndim-2] = m;
+  shape[ndim-1] = mm;
+  v = compute_uv ? v.reshape(shape) : at::zeros(shape, self.options());
+
+  return std::tuple<Tensor, Tensor, Tensor>(u, s, v);
+}
+
+Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+
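+  // Permute the two dimensions in 'dim' to the end, compute the singular values
+  // of each trailing matrix, and sum them along the last dimension.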
+  Tensor p = _move_to_end(self, dim);
+  return at::sum(std::get<1>(_batch_svd(p, true, false)), -1, keepdim);
+}
+
+Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
+  TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+
+  Tensor p = _move_to_end(self, dim);
+  return at::sum_out(result, std::get<1>(_batch_svd(p, true, false)), -1, keepdim);
+}
+
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
   if (i == j)
     return matrices[i];
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 35a25d6..1e32d18 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -188,6 +188,28 @@
   return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
 }
 
+// Return the tensor permuted so that the given axes are moved to the end.
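+// For example, for a 4-d tensor and axes == {1, 2}, the resulting dimension
+// order is {0, 3, 1, 2}.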
+static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
+  const std::vector<int64_t> a = axes.vec();
+  const int64_t ndim = self.ndimension();
+  std::vector<int64_t> perm;
+
+  for (int64_t i = 0; i < ndim; i++) {
+    auto it = std::find(a.begin(), a.end(), i);
+    if (it == a.end()) {
+      perm.push_back(i);
+    }
+  }
+  for (auto i : a) {
+    perm.push_back(i);
+  }
+
+  TORCH_CHECK(perm.size() == (size_t)ndim,
+    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
+
+  return self.permute(perm);
+}
+
 // Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
 static inline std::tuple<std::vector<int64_t>,
                          std::vector<int64_t>,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index eb2d0f3..eca28c4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2192,6 +2192,12 @@
 - func: nuclear_norm(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
 
+- func: nuclear_norm(Tensor self, int[2] dim, bool keepdim=False) -> Tensor
+  variants: function
+
+- func: nuclear_norm(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
 - func: clone(Tensor self) -> Tensor
   variants: function, method
   dispatch:
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 686233b..687a2dd 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2674,6 +2674,15 @@
     def test_norm(self):
         _TestTorchMixin._test_norm(self, device='cuda')
 
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_nuclear_norm_axes_small_brute_force(self):
+        _TestTorchMixin._test_nuclear_norm_axes(self, device='cuda')
+
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_nuclear_norm_exceptions(self):
+        _TestTorchMixin._test_nuclear_norm_exceptions(self, device='cuda')
+
     def test_dist(self):
         _TestTorchMixin._test_dist(self, device='cuda')
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 57b582b..dbaa22c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -20,8 +20,9 @@
 from torch.utils.dlpack import from_dlpack, to_dlpack
 from torch._utils import _rebuild_tensor
 from torch._six import inf, nan, string_classes, istuple
-from itertools import product, combinations, combinations_with_replacement
+from itertools import product, combinations, combinations_with_replacement, permutations
 from functools import reduce
+from random import randrange
 from torch import multiprocessing as mp
 from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
     _compare_trilu_indices
@@ -979,6 +980,100 @@
         self._test_norm(self, device='cpu')
 
     @staticmethod
+    def _test_nuclear_norm_axes(self, device='cpu'):
+        def check_single_nuclear_norm(x, axes):
+            if x.is_cuda and randrange(100) < 95:
+                return  # too many cpu <==> gpu copies
+
+            a = np.array(x.cpu(), copy=False)
+            expected = np.linalg.norm(a, "nuc", axis=axes)
+
+            ans = torch.norm(x, "nuc", dim=axes)
+            self.assertTrue(ans.is_contiguous())
+            self.assertEqual(ans.shape, expected.shape)
+            self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
+
+            out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
+            ans = torch.norm(x, "nuc", dim=axes, out=out)
+            self.assertIs(ans, out)
+            self.assertTrue(ans.is_contiguous())
+            self.assertEqual(ans.shape, expected.shape)
+            self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
+
+        for n in range(1, 3):
+            for m in range(1, 3):
+                for axes in permutations([0, 1], 2):
+                    # 2d, inner dimensions C
+                    x = torch.randn(n, m, device=device)
+                    check_single_nuclear_norm(x, axes)
+
+                    # 2d, inner dimensions Fortran
+                    x = torch.randn(m, n, device=device).transpose(-1, -2)
+                    check_single_nuclear_norm(x, axes)
+
+                    # 2d, inner dimensions non-contiguous
+                    x = torch.randn(n, 2 * m, device=device)[:, ::2]
+                    check_single_nuclear_norm(x, axes)
+
+                    # 2d, all dimensions non-contiguous
+                    x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
+                    check_single_nuclear_norm(x, axes)
+
+                for o in range(1, 3):
+                    for axes in permutations([0, 1, 2], 2):
+                        # 3d, inner dimensions C
+                        x = torch.randn(o, n, m, device=device)
+                        check_single_nuclear_norm(x, axes)
+
+                        # 3d, inner dimensions Fortran
+                        x = torch.randn(o, n, m, device=device).transpose(-1, -2)
+                        check_single_nuclear_norm(x, axes)
+
+                        # 3d, inner dimensions non-contiguous
+                        x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
+                        check_single_nuclear_norm(x, axes)
+
+                        # 3d, all dimensions non-contiguous
+                        x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
+                        check_single_nuclear_norm(x, axes)
+
+                    for r in range(1, 3):
+                        for axes in permutations([0, 1, 2, 3], 2):
+                            # 4d, inner dimensions C
+                            x = torch.randn(r, o, n, m, device=device)
+                            check_single_nuclear_norm(x, axes)
+
+                            # 4d, inner dimensions Fortran
+                            x = torch.randn(r, o, n, m, device=device).transpose(-1, -2)
+                            check_single_nuclear_norm(x, axes)
+
+                            # 4d, inner dimensions non-contiguous
+                            x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
+                            check_single_nuclear_norm(x, axes)
+
+                            # 4d, all dimensions non-contiguous
+                            x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
+                            check_single_nuclear_norm(x, axes)
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_nuclear_norm_axes_small_brute_force(self):
+        self._test_nuclear_norm_axes(self)
+
+    @staticmethod
+    def _test_nuclear_norm_exceptions(self, device='cpu'):
+        for lst in [], [1], [1, 2]:
+            for axes in (), (0,), (0, 1):
+                x = torch.tensor(lst, dtype=torch.double, device=device)
+                self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
+
+        x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
+        self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
+        self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2))
+
+    def test_nuclear_norm_exceptions(self):
+        self._test_nuclear_norm_exceptions(self)
+
+    @staticmethod
     def _test_dist(self, device):
         def run_test(x, y):
             for p in [0, 1, 2, 3, 4, inf, -inf]:
diff --git a/torch/functional.py b/torch/functional.py
index 72e1de3..d1909b1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -696,9 +696,11 @@
     elif p == "nuc":
         if dtype is not None:
             raise ValueError("dtype argument is not supported in nuclear norm")
-        if out is None:
-            torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
-        return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+        if dim is None:
+            if out is None:
+                return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
+            return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+        return torch._C._VariableFunctions.nuclear_norm(input, dim, keepdim=keepdim, out=out)
     else:
         if dim is None:
             dim = tuple(range(ndim))