Improve `torch.norm` functionality, errors, and tests (#41956)

Summary:
**BC-Breaking Note:**
BC-breaking changes apply when keepdim=True. Before this change, calling `torch.norm` with keepdim=True and p='fro' or a numeric p, leaving all other optional arguments at their default values, silently ignored the keepdim argument. Also, any call to `torch.norm` with p='nuc' produced a result with one fewer dimension than the input, and the dimensions could be out of order depending on which dimensions were being reduced. After this change, in each of these cases the result has the same number and order of dimensions as the input.
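
For example, a minimal sketch of the new behavior (the shapes shown are for illustration):

```python
import torch

x = torch.randn(3, 4)

# Previously the keepdim argument was silently ignored here and the result was 0-D;
# now both dimensions are kept:
torch.norm(x, p='fro', keepdim=True).shape   # torch.Size([1, 1])

# p='nuc' now also preserves the number and order of dimensions:
y = torch.randn(3, 4, 5)
torch.norm(y, p='nuc', dim=(0, 2), keepdim=True).shape  # torch.Size([1, 4, 1])
```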

**PR Summary:**

* Fix keepdim behavior
* Throw descriptive errors for unsupported sparse norm args (see the sketch after this list)
* Increase unit test coverage for these cases and for complex inputs
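
The new sparse norm errors can be exercised as follows (a sketch; the sparse tensor construction is only illustrative, and the messages match the new unit tests):

```python
import torch

i = torch.tensor([[0, 1], [0, 1]])
v = torch.tensor([1.0, 2.0])
s = torch.sparse_coo_tensor(i, v, (2, 2))

s.norm()              # OK: full reductions are supported
s.norm(keepdim=True)  # RuntimeError: norm_sparse currently does not support keepdim=True
s.norm(dim=0)         # RuntimeError: norm_sparse currently only supports full reductions
```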

These changes were taken from part of PR https://github.com/pytorch/pytorch/issues/40924. That PR is not going to be merged because it overrides `torch.norm`'s interface, which we want to avoid. But these improvements are still useful.

Issue https://github.com/pytorch/pytorch/issues/24802

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41956

Reviewed By: albanD

Differential Revision: D22837455

Pulled By: mruberry

fbshipit-source-id: 509ecabfa63b93737996f48a58c7188b005b7217
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 1c6cfef..bc795da 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -724,17 +724,19 @@
 }
 
 Tensor frobenius_norm(const Tensor& self) {
+  TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors");
   return at::norm(self);
 }
 
 Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors");
   TORCH_CHECK(
       dim.size() <= 2,
       "Expected at most 2 dimensions, but got ",
       dim.size(),
       " dimensions instead.");
-  if (dim.size() == 1) {
-    return at::norm(self, 2, dim, keepdim, self.scalar_type());
+  if (dim.size() == 1 || dim.size() == 0) {
+    return at::norm(self, 2, dim, keepdim);
   }
   if (self.is_complex()){
     return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim));
@@ -748,12 +750,13 @@
     const Tensor& self,
     IntArrayRef dim,
     bool keepdim) {
+  TORCH_CHECK(!self.is_complex(), "frobenius norm not supported for complex tensors");
   TORCH_CHECK(
       dim.size() <= 2,
       "Expected at most 2 dimensions, but got ",
       dim.size(),
       " dimensions instead.");
-  if (dim.size() == 1) {
+  if (dim.size() == 1 || dim.size() == 0) {
     return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type());
   }
   if (self.is_complex()){
@@ -771,8 +774,14 @@
   // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
   // would end up throwing an error as a result if U and V aren't computed.
   // Due to this, we have to compute U and V conditionally.
-  return at::sum(std::get<1>(at::svd(self, /*some=*/true,
+  Tensor result = at::sum(std::get<1>(at::svd(self, /*some=*/true,
                  /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), 0, keepdim);
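+  // at::svd returns 1-D singular values for a 2-D input, so even with keepdim=true
+  // the sum above is at most 1-D; unsqueeze so the result matches the input's 2-D shape.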
+  if (keepdim) {
+    result.unsqueeze_(0);
+  }
+  return result;
 }
 
 Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
@@ -780,27 +789,46 @@
       self.dim() == 2,
       "Expected a tensor with 2 dimensions, but got a tensor with ",
       self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
-
+  at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
+  if (keepdim) {
+    result.unsqueeze_(0);
+  }
+  return result;
 }
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
 
-  Tensor p = _move_to_end(self, dim);
+  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation_reverse = create_reverse_permutation(permutation);
+  Tensor p = self.permute(permutation);
   // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
   // would end up throwing an error as a result if U and V aren't computed.
   // Due to this, we have to compute U and V conditionally.
-  return at::sum(std::get<1>(at::svd(p, /*some=*/true,
+  Tensor result = at::sum(std::get<1>(at::svd(p, /*some=*/true,
                  /*compute_uv=*/at::GradMode::is_enabled() && self.requires_grad())), -1, keepdim);
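+  // The SVD collapses the two permuted-to-the-end dimensions into one dimension of
+  // singular values, so restore the missing dimension and undo the permutation.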
+  if (keepdim) {
+    result.unsqueeze_(-1);
+    result = result.permute(permutation_reverse);
+  }
+  return result;
 }
 
 Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
 
-  Tensor p = _move_to_end(self, dim);
-  return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+  auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim());
+  auto permutation_reverse = create_reverse_permutation(permutation);
 
+  Tensor p = self.permute(permutation);
+  at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+  if (keepdim) {
+    result.unsqueeze_(-1);
+    result = result.permute(permutation_reverse);
+  }
+  return result;
 }
 
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h
index 90a9e1a..0e89eab 100644
--- a/aten/src/ATen/native/LinearAlgebraUtils.h
+++ b/aten/src/ATen/native/LinearAlgebraUtils.h
@@ -278,4 +278,43 @@
   return strided_to;
 }
 
+// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
+// the two specified dimensions to the end of a tensor, without changing the order of
+// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
+// placed just to the left of it.
+//
+// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
+// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
+// be `vec(0, 2, 1, 3)`.
+static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
+  TORCH_CHECK(
+    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
+    "duplicate or invalid dimensions");
+  std::vector<int64_t> permutation(ndim);
+  int64_t cur_permuted_dim = 0;
+  for (int64_t dim_ind = 0; dim_ind < ndim; dim_ind++) {
+    if ((dim_ind != dim0) && (dim_ind != dim1)) {
+      permutation[cur_permuted_dim++] = dim_ind;
+    }
+  }
+  permutation[cur_permuted_dim++] = dim0;
+  permutation[cur_permuted_dim] = dim1;
+  return permutation;
+}
+
+// Creates a dimension permutation array that can be given to `at::permute()`, which
+// will reverse a given permutation.
+// The reverse permutation array is created by swapping the indices and their
+// associated values from the given permutation array.
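+// For instance, the reverse of the permutation `vec(1, 2, 0)` is `vec(2, 0, 1)`:
+// applying the latter after the former restores the original dimension order.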
+static inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
+  int64_t ndim = permutation.size();
+  std::vector<int64_t> reverse_permutation(ndim);
+  for (int64_t dim_ind = 0; dim_ind < ndim; dim_ind++) {
+    reverse_permutation[permutation[dim_ind]] = dim_ind;
+  }
+  return reverse_permutation;
+}
+
 }}  // namespace at::native
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index b1a9e58..0de7933 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -432,6 +432,7 @@
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
                                IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
   auto p = opt_p.value_or(2.0);
+  TORCH_CHECK(!(p.toDouble() == 2 && self.is_complex()), "norm with p=2 not supported for complex tensors");
   TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
               "norm only supports CPU AND CUDA device type, got: ", self.device().type());
   TORCH_CHECK(self.layout() == Layout::Strided,
@@ -456,6 +457,8 @@
 
 static inline Tensor _norm(const Tensor &self, Scalar p) {
   if (self.is_sparse()) {
+    // Sparse tensors need a different implementation because their values
+    // are accessed with a different API than strided tensors
     return at::native_norm(self, p);
   } else {
     TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA,
@@ -480,8 +483,14 @@
 
 static Tensor norm(const Tensor& self, optional<Scalar> p, IntArrayRef dim, bool keepdim,
             optional<ScalarType> opt_dtype) {
-  Tensor result;
-  return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
+  if (self.is_sparse()) {
+    // Sparse tensors need a different implementation because their values
+    // are accessed with a different API than strided tensors
+    return at::native_norm(self, p, dim, keepdim, opt_dtype);
+  } else {
+    Tensor result;
+    return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
+  }
 }
 
 Tensor norm(const Tensor& self, optional<Scalar> p, IntArrayRef dim, bool keepdim, ScalarType dtype) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f46c9f0..d2804d5 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3217,6 +3217,11 @@
   dispatch:
     SparseCPU, SparseCUDA: norm_sparse
 
+- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
+  use_c10_dispatcher: full
+  dispatch:
+    SparseCPU, SparseCUDA: norm_sparse
+
 # TODO: reduce signatures down to one when optional args is available
 - func: _sparse_sum(Tensor self) -> Tensor
   use_c10_dispatcher: full
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 5ac706a..4cedc8a 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -346,10 +346,37 @@
 // --------------------------------------------------------------------
 
 // Only supports floating point, FYI
-Tensor norm_sparse(const SparseTensor& self, Scalar value) {
+Tensor norm_sparse(const SparseTensor& self, Scalar p) {
   AT_ASSERT(self.is_sparse());
+  return norm_sparse(self, p, IntArrayRef{}, false, c10::nullopt);
+}
 
-  return self.coalesce()._values().norm(value);
+Tensor norm_sparse(const SparseTensor& self, optional<Scalar> p, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) {
+  AT_ASSERT(self.is_sparse());
+  if (dim.size() > 0) {
+    // Only full reductions are supported, so check if that is the case
+    int64_t ndim = self.dim();
+    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
+    if (passed_full_reduction_check) {
+      auto dim_ = dim.vec();
+      maybe_wrap_dims(dim_, ndim);
+      std::vector<bool> dims_check(ndim, false);
+      // Need to check for duplicates, and fail if any are found
+      for (auto dim_ind : dim_) {
+        if (dims_check[dim_ind]) {
+          passed_full_reduction_check = false;
+          break;
+        }
+        dims_check[dim_ind] = true;
+      }
+    }
+    TORCH_CHECK(passed_full_reduction_check,
+      "norm_sparse currently only supports full reductions, so 'dim' must either be empty or contain all dimensions of the input");
+  }
+  TORCH_CHECK(keepdim == false, "norm_sparse currently does not support keepdim=True");
+  TORCH_CHECK(!dtype.has_value(), "norm_sparse currently does not support 'dtype' argument");
+  auto p_ = p.value_or(2.0);
+  return self.coalesce()._values().norm(p_);
 }
 
 // --------------------------------------------------------------------
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 6939dfb..eaa2748 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -1248,6 +1248,23 @@
         test_shape(4, 10, [100, 100, 100, 5, 5, 5, 0])
         test_shape(4, 0, [0, 0, 100, 5, 5, 5, 0])
 
+        # Unsupported arguments should error
+        kwarg_error_pairs = [
+            ({'keepdim': True},
+             RuntimeError, r'norm_sparse currently does not support keepdim=True'),
+            ({'dim': 0},
+             RuntimeError, r'norm_sparse currently only supports full reductions'),
+            ({'dtype': torch.double, 'p': 'fro'},
+             ValueError, r'dtype argument is not supported in frobenius norm'),
+            ({'dtype': torch.double, 'p': 0},
+             RuntimeError, r"norm_sparse currently does not support 'dtype' argument") 
+        ]
+        x = self._gen_sparse(3, 10, 100)[0]
+        for kwargs, err, msg in kwarg_error_pairs:
+            with self.assertRaisesRegex(err, msg):
+                x.norm(**kwargs)
+
+
     @skipIfRocm
     def test_sparse_sum(self):
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 0889a28..709e564 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10091,32 +10091,106 @@
     @skipCPUIfNoLapack
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_norm(self, device):
-        # full reduction
-        x = torch.randn(25, device=device)
-        xn = x.cpu().numpy()
-        for p in [0, 1, 2, 3, 4, inf, -inf]:
-            res = x.norm(p).item()
-            expected = np.linalg.norm(xn, p)
-            self.assertEqual(res, expected, atol=1e-5, rtol=0, msg="full reduction failed for {}-norm".format(p))
+        def gen_error_message(input_size, p, keepdim, dim=None):
+            return "norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % (
+                input_size, p, keepdim, dim)
 
-        # one dimension
-        x = torch.randn(25, 25, device=device)
-        xn = x.cpu().numpy()
-        for p in [0, 1, 2, 3, 4, inf, -inf]:
-            res = x.norm(p, 1).cpu()
-            expected = np.linalg.norm(xn, p, 1)
-            self.assertEqual(res.shape, expected.shape)
-            self.assertEqual(res, expected, msg="dim reduction failed for {}-norm".format(p))
+        for keepdim in [False, True]:
+            # full reduction
+            x = torch.randn(25, device=device)
+            xn = x.cpu().numpy()
+            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3, 1.5]:
+                res = x.norm(p, keepdim=keepdim).cpu()
+                expected = np.linalg.norm(xn, p, keepdims=keepdim)
+                self.assertEqual(res, expected, atol=1e-5, rtol=0, msg=gen_error_message(x.size(), p, keepdim))
 
-        # matrix norm
-        for p in ['fro', 'nuc']:
-            res = x.norm(p).cpu()
-            expected = np.linalg.norm(xn, p)
-            self.assertEqual(res.shape, expected.shape)
-            self.assertEqual(res, expected, msg="dim reduction failed for {}-norm".format(p))
+            # one dimension
+            x = torch.randn(25, 25, device=device)
+            xn = x.cpu().numpy()
+            for p in [0, 1, 2, 3, 4, inf, -inf, -1, -2, -3]:
+                dim = 1
+                res = x.norm(p, dim, keepdim=keepdim).cpu()
+                expected = np.linalg.norm(xn, p, dim, keepdims=keepdim)
+                msg = gen_error_message(x.size(), p, keepdim, dim)
+                self.assertEqual(res.shape, expected.shape, msg=msg)
+                self.assertEqual(res, expected, msg=msg)
 
-        # larger tensor sanity check
-        self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000)))
+            # matrix norm
+            for p in ['fro', 'nuc']:
+                res = x.norm(p, keepdim=keepdim).cpu()
+                expected = np.linalg.norm(xn, p, keepdims=keepdim)
+                msg = gen_error_message(x.size(), p, keepdim)
+                self.assertEqual(res.shape, expected.shape, msg=msg)
+                self.assertEqual(res, expected, msg=msg)
+
+            # zero dimensions
+            x = torch.randn((), device=device)
+            xn = x.cpu().numpy()
+            res = x.norm(keepdim=keepdim).cpu()
+            expected = np.linalg.norm(xn, keepdims=keepdim)
+            msg = gen_error_message(x.size(), None, keepdim)
+            self.assertEqual(res.shape, expected.shape, msg=msg)
+            self.assertEqual(res, expected, msg=msg)
+
+            # larger tensor sanity check
+            self.assertEqual(
+                2 * torch.norm(torch.ones(10000), keepdim=keepdim),
+                torch.norm(torch.ones(40000), keepdim=keepdim))
+
+            # matrix norm with non-square >2-D tensors, all combinations of reduction dims
+            x = torch.randn(5, 6, 7, 8, device=device)
+            xn = x.cpu().numpy()
+            for p in ['fro', 'nuc']:
+                for dim in product(*[list(range(4))] * 2):
+                    if dim[0] == dim[1]:
+                        continue
+                    res = x.norm(p=p, dim=dim, keepdim=keepdim).cpu()
+                    expected = np.linalg.norm(xn, ord=p, axis=dim, keepdims=keepdim)
+                    msg = gen_error_message(x.size(), p, keepdim, dim)
+                    self.assertEqual(res.shape, expected.shape, msg=msg)
+                    self.assertEqual(res, expected, msg=msg)
+
+    @skipCUDAIfNoMagma
+    @skipCPUIfNoLapack
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_norm_complex(self, device):
+        def gen_error_message(input_size, p, keepdim, dim=None):
+            return "complex norm failed for input size %s, p=%s, keepdim=%s, dim=%s" % (
+                input_size, p, keepdim, dim)
+
+        if device == 'cpu':
+            for keepdim in [False, True]:
+                # vector norm
+                x = torch.randn(25, device=device) + 1j * torch.randn(25, device=device)
+                xn = x.cpu().numpy()
+                for p in [0, 1, 3, inf, -1, -2, -3, -inf]:
+                    res = x.norm(p, keepdim=keepdim).cpu()
+                    expected = np.linalg.norm(xn, p, keepdims=keepdim)
+                    msg = gen_error_message(x.size(), p, keepdim)
+                    self.assertEqual(res.shape, expected.shape, msg=msg)
+                    self.assertEqual(res, expected, msg=msg)
+
+                # matrix norm
+                x = torch.randn(25, 25, device=device) + 1j * torch.randn(25, 25, device=device)
+                xn = x.cpu().numpy()
+                for p in ['nuc']:
+                    res = x.norm(p, keepdim=keepdim).cpu()
+                    expected = np.linalg.norm(xn, p, keepdims=keepdim)
+                    msg = gen_error_message(x.size(), p, keepdim)
+                    self.assertEqual(res.shape, expected.shape, msg=msg)
+                    self.assertEqual(res, expected, msg=msg)
+
+            # TODO: remove error test and add functionality test above when 2-norm support is added
+            with self.assertRaisesRegex(RuntimeError, r'norm with p=2 not supported for complex tensors'):
+                x = torch.randn(2, device=device, dtype=torch.complex64).norm(p=2)
+
+            # TODO: remove error test and add functionality test above when frobenius support is added
+            with self.assertRaisesRegex(RuntimeError, r'frobenius norm not supported for complex tensors'):
+                x = torch.randn(2, 2, device=device, dtype=torch.complex64).norm(p='fro')
+
+        elif device == 'cuda':
+            with self.assertRaisesRegex(RuntimeError, r'"norm_cuda" not implemented for \'ComplexFloat\''):
+                (1j * torch.randn(25)).norm()
 
     @skipCUDAIfNoMagma
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
diff --git a/torch/_overrides.py b/torch/_overrides.py
index 11f80bb..026a7f2 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -433,6 +433,7 @@
         torch.native_layer_norm: lambda input, weight, bias, M, N, eps: -1,
         torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
         torch.native_norm: lambda input, p=2: -1,
+        torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
         torch.ne: lambda input, other, out=None: -1,
         torch.neg: lambda input, out=None: -1,
         torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
diff --git a/torch/functional.py b/torch/functional.py
index 9401c33..2ec28e6 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1190,9 +1190,10 @@
     if dim is None and out is None and dtype is None and p is not None:
         if isinstance(p, str):
             if p == "fro":
-                return _VF.frobenius_norm(input)
+                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
         if not isinstance(p, str):
-            return _VF.norm(input, p)
+            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)
 
     # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
     # remove the overloads where dim is an int and replace with BraodcastingList1