Add a 'dim' argument to nuclear norm (#21022)
Summary:
Addresses #18275.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21022
Differential Revision: D15743515
Pulled By: ezyang
fbshipit-source-id: e4aaea0bd7f863a2abad45c4322d6a9fb02a88e3
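
A quick usage sketch of the new argument (the shapes below are arbitrary and
purely illustrative; this mirrors what the new tests exercise):

    import torch

    x = torch.randn(5, 3, 4)
    # nuclear norm of each 3x4 matrix in the batch -> result of shape (5,)
    torch.norm(x, "nuc", dim=(1, 2))
    # matches numpy's np.linalg.norm(x.numpy(), "nuc", axis=(1, 2))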
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 70e4cf1..2b165bd 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -527,21 +527,89 @@
Tensor nuclear_norm(const Tensor& self, bool keepdim) {
TORCH_CHECK(
self.dim() == 2,
- "Expected a tensor with 2 dimensions, but got a ",
- self.dim(),
- " dimensions tensor instead.");
+ "Expected a tensor with 2 dimensions, but got a tensor with ",
+ self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
}
Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
TORCH_CHECK(
self.dim() == 2,
- "Expected a tensor with 2 dimensions, but got a ",
- self.dim(),
- " dimensions tensor instead.");
+ "Expected a tensor with 2 dimensions, but got a tensor with ",
+ self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
}
+// Non-optimized batched svd implementation. This can be merged with at::svd
+// once at::svd has been ported to ATen.
+static std::tuple<Tensor, Tensor, Tensor>
+_batch_svd(const Tensor& self, bool some, bool compute_uv)
+{
+ const int64_t ndim = self.ndimension();
+
+ TORCH_CHECK(
+ ndim >= 2,
+ "Expected a tensor with at least 2 dimensions, but got a tensor with ",
+ self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
+
+ if (ndim == 2) {
+ return at::svd(self, some, compute_uv);
+ }
+
+ const int64_t n = self.size(-2);
+ const int64_t m = self.size(-1);
+ const int64_t k = std::min<int64_t>(n, m);
+ const int64_t nn = (some && compute_uv) ? k : n;
+ const int64_t mm = (some && compute_uv) ? k : m;
+ const int64_t p = batchCount(self);
+
+ Tensor t = self.reshape({p, n, m});
+
+ Tensor s = at::empty({p, k}, self.options());
+ Tensor u, v;
+ if (compute_uv) {
+ u = at::empty({p, n, nn}, self.options());
+ v = at::empty({p, m, mm}, self.options());
+ }
+
+ for (int64_t i = 0; i < p; i++) {
+ auto tuple = at::svd(t[i], some, compute_uv);
+ s[i] = std::get<1>(tuple);
+ if (compute_uv) {
+ u[i] = std::get<0>(tuple);
+ v[i] = std::get<2>(tuple);
+ }
+ }
+
+ std::vector<int64_t> shape = self.sizes().slice(0, ndim-1).vec();
+ shape[ndim-2] = k;
+ s = s.reshape(shape);
+
+ shape[ndim-2] = n;
+ shape.push_back(nn);
+ u = compute_uv ? u.reshape(shape) : at::zeros(shape, self.options());
+
+ shape[ndim-2] = m;
+ shape[ndim-1] = mm;
+ v = compute_uv ? v.reshape(shape) : at::zeros(shape, self.options());
+
+ return std::tuple<Tensor, Tensor, Tensor>(u, s, v);
+}
+
+Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+ TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+
+ Tensor p = _move_to_end(self, dim);
+ return at::sum(std::get<1>(_batch_svd(p, true, false)), -1, keepdim);
+}
+
+Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
+ TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+
+ Tensor p = _move_to_end(self, dim);
+ return at::sum_out(result, std::get<1>(_batch_svd(p, true, false)), -1, keepdim);
+}
+
static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
if (i == j)
return matrices[i];
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 35a25d6..1e32d18 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -188,6 +188,28 @@
return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
+// Returns the input tensor permuted so that the given axes are moved to the end.
+static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
+ const std::vector<int64_t> a = axes.vec();
+ const int64_t ndim = self.ndimension();
+ std::vector<int64_t> perm;
+
+ for (int64_t i = 0; i < ndim; i++) {
+ auto it = std::find(a.begin(), a.end(), i);
+ if (it == a.end()) {
+ perm.push_back(i);
+ }
+ }
+ for (auto i : a) {
+ perm.push_back(i);
+ }
+
+ TORCH_CHECK(static_cast<int64_t>(perm.size()) == ndim,
+ "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
+
+ return self.permute(perm);
+}
+
// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
static inline std::tuple<std::vector<int64_t>,
std::vector<int64_t>,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index eb2d0f3..eca28c4 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2192,6 +2192,12 @@
- func: nuclear_norm(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
+- func: nuclear_norm(Tensor self, int[2] dim, bool keepdim=False) -> Tensor
+ variants: function
+
+- func: nuclear_norm(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+ variants: function
+
- func: clone(Tensor self) -> Tensor
variants: function, method
dispatch:
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 686233b..687a2dd 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2674,6 +2674,15 @@
def test_norm(self):
_TestTorchMixin._test_norm(self, device='cuda')
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+ @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+ def test_nuclear_norm_axes_small_brute_force(self):
+ _TestTorchMixin._test_nuclear_norm_axes(self, device='cuda')
+
+ @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+ def test_nuclear_norm_exceptions(self):
+ _TestTorchMixin._test_nuclear_norm_exceptions(self, device='cuda')
+
def test_dist(self):
_TestTorchMixin._test_dist(self, device='cuda')
diff --git a/test/test_torch.py b/test/test_torch.py
index 57b582b..dbaa22c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -20,8 +20,9 @@
from torch.utils.dlpack import from_dlpack, to_dlpack
from torch._utils import _rebuild_tensor
from torch._six import inf, nan, string_classes, istuple
-from itertools import product, combinations, combinations_with_replacement
+from itertools import product, combinations, combinations_with_replacement, permutations
from functools import reduce
+from random import randrange
from torch import multiprocessing as mp
from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
_compare_trilu_indices
@@ -979,6 +980,100 @@
self._test_norm(self, device='cpu')
@staticmethod
+ def _test_nuclear_norm_axes(self, device='cpu'):
+ def check_single_nuclear_norm(x, axes):
+ if x.is_cuda and randrange(100) < 95:
+ return # too many cpu <==> gpu copies
+
+ a = np.array(x.cpu(), copy=False)
+ expected = np.linalg.norm(a, "nuc", axis=axes)
+
+ ans = torch.norm(x, "nuc", dim=axes)
+ self.assertTrue(ans.is_contiguous())
+ self.assertEqual(ans.shape, expected.shape)
+ self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
+
+ out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device)
+ ans = torch.norm(x, "nuc", dim=axes, out=out)
+ self.assertIs(ans, out)
+ self.assertTrue(ans.is_contiguous())
+ self.assertEqual(ans.shape, expected.shape)
+ self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03))
+
+ for n in range(1, 3):
+ for m in range(1, 3):
+ for axes in permutations([0, 1], 2):
+ # 2d, inner dimensions C
+ x = torch.randn(n, m, device=device)
+ check_single_nuclear_norm(x, axes)
+
+ # 2d, inner dimensions Fortran
+ x = torch.randn(m, n, device=device).transpose(-1, -2)
+ check_single_nuclear_norm(x, axes)
+
+ # 2d, inner dimensions non-contiguous
+ x = torch.randn(n, 2 * m, device=device)[:, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ # 2d, all dimensions non-contiguous
+ x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ for o in range(1, 3):
+ for axes in permutations([0, 1, 2], 2):
+ # 3d, inner dimensions C
+ x = torch.randn(o, n, m, device=device)
+ check_single_nuclear_norm(x, axes)
+
+ # 3d, inner dimensions Fortran
+ x = torch.randn(o, n, m, device=device).transpose(-1, -2)
+ check_single_nuclear_norm(x, axes)
+
+ # 3d, inner dimensions non-contiguous
+ x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ # 3d, all dimensions non-contiguous
+ x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ for r in range(1, 3):
+ for axes in permutations([0, 1, 2, 3], 2):
+ # 4d, inner dimensions C
+ x = torch.randn(r, o, n, m, device=device)
+ check_single_nuclear_norm(x, axes)
+
+ # 4d, inner dimensions Fortran
+ x = torch.randn(r, o, n, m, device=device).transpose(-1, -2)
+ check_single_nuclear_norm(x, axes)
+
+ # 4d, inner dimensions non-contiguous
+ x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ # 4d, all dimensions non-contiguous
+ x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2]
+ check_single_nuclear_norm(x, axes)
+
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+ def test_nuclear_norm_axes_small_brute_force(self):
+ self._test_nuclear_norm_axes(self)
+
+ @staticmethod
+ def _test_nuclear_norm_exceptions(self, device='cpu'):
+ for lst in [], [1], [1, 2]:
+ for axes in (), (0,), (0, 1):
+ x = torch.tensor(lst, dtype=torch.double, device=device)
+ self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes)
+
+ x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device)
+ self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0))
+ self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2))
+
+ def test_nuclear_norm_exceptions(self):
+ self._test_nuclear_norm_exceptions(self)
+
+ @staticmethod
def _test_dist(self, device):
def run_test(x, y):
for p in [0, 1, 2, 3, 4, inf, -inf]:
diff --git a/torch/functional.py b/torch/functional.py
index 72e1de3..d1909b1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -696,9 +696,11 @@
elif p == "nuc":
if dtype is not None:
raise ValueError("dtype argument is not supported in nuclear norm")
- if out is None:
- torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
- return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+ if dim is None:
+ if out is None:
+ return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
+ return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+ return torch._C._VariableFunctions.nuclear_norm(input, dim, keepdim=keepdim, out=out)
else:
if dim is None:
dim = tuple(range(ndim))