Revert "Dispatch the auxiliary frobenius_norm and nuclear_norm to better implementations and deprecate them (#81763)"

This reverts commit 122245985a544d9d74d7b5037493541f5e525498.

Reverted https://github.com/pytorch/pytorch/pull/81763 on behalf of https://github.com/mehtanirav due to Internal breakages
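
For context, the reverted PR dispatched the legacy `frobenius_norm` and `nuclear_norm` entry points toward their `torch.linalg` counterparts and deprecated them; this revert restores the original hand-rolled implementations below. A minimal sketch of the numerical equivalence the reverted change relied on, for illustration only (the shapes and dtype here are arbitrary and not taken from the PR):

```python
# Illustrative sketch: compares the legacy reductions restored by this revert
# against the torch.linalg replacements that the reverted PR dispatched to.
import torch

A = torch.randn(3, 4, 5, dtype=torch.float64)

# Frobenius norm over two dims: legacy torch.norm path vs. linalg.vector_norm
legacy_fro = torch.norm(A, p="fro", dim=(1, 2))
vector_fro = torch.linalg.vector_norm(A, ord=2, dim=(1, 2))
assert torch.allclose(legacy_fro, vector_fro)

# Nuclear norm over the last two dims: legacy torch.norm path vs. linalg.matrix_norm
legacy_nuc = torch.norm(A, p="nuc", dim=(1, 2))
matrix_nuc = torch.linalg.matrix_norm(A, ord="nuc", dim=(1, 2))
assert torch.allclose(legacy_nuc, matrix_nuc)
```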
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index bb10c74..8cbda26 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -2797,84 +2797,94 @@
 
 ////////////////////////////////////////////////////////////////////////////////
 //                              Frobenius Norm                                //
+//              Just used in torch.norm. It should not be removed.            //
 ////////////////////////////////////////////////////////////////////////////////
 
 Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
-  auto device = self.device();
-  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
-    TORCH_WARN_ONCE(
-      "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
-      "It will be removed in a future PyTorch release. Please use ",
-      "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
-    );
+  TORCH_CHECK(
+      dim.size() <= 2,
+      "Expected at most 2 dimensions, but got ",
+      dim.size(),
+      " dimensions instead.");
+  Tensor result;
+  if (dim.size() == 1 || dim.size() == 0) {
+    result = at::norm(self, 2, dim, keepdim);
+  } else {
+    auto dim_ = dim.vec();
+    maybe_wrap_dims(dim_, self.dim());
+    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
+    if (self.is_complex()) {
+      result = at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
+    } else {
+      result = at::sqrt(at::sum((self * self), dim_, keepdim));
+    }
   }
-  // This frobenius norm is just wrong, but well
-  TORCH_CHECK(dim.size() <= 2,
-              "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
-  // Dispatch to at::norm as it is implemented for Sparse and MPS backends
-  // TODO Make the backends implement vector_norm and matrix_norm
-  return at::norm(self, 2., dim, keepdim);
+  TORCH_INTERNAL_ASSERT(result.scalar_type() == toRealValueType(self.scalar_type()));
+  TORCH_INTERNAL_ASSERT(result.layout() == c10::Layout::Strided);
+  return result;
 }
 
 Tensor &frobenius_norm_out(const Tensor& self,
     IntArrayRef dim,
     bool keepdim,
     Tensor& result) {
-  auto device = self.device();
-  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
-    TORCH_WARN_ONCE(
-      "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
-      "It will be removed in a future PyTorch release. Please use ",
-      "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
-    );
-  }
-  TORCH_CHECK(dim.size() <= 2,
-              "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
-  return at::norm_out(result, self, 2., dim, keepdim);
+  auto result_ = at::native::frobenius_norm(self, dim, keepdim);
+  // NOTE: It would be better to avoid resize and copy by using norm_out and sqrt_out in frobenius_norm.
+  //    However, norm_out and sqrt_out do not support automatic differentiation.
+  //    More details here: https://github.com/pytorch/pytorch/pull/44095#discussion_r486673947
+  at::native::resize_output(result, result_.sizes());
+  result.copy_(result_);
+  return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 //                                Nuclear Norm                                //
+//              Just used in torch.norm. It should not be removed.            //
 ////////////////////////////////////////////////////////////////////////////////
 
 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
-  return at::native::nuclear_norm(self, IntArrayRef({-2, -1}), keepdim);
+  TORCH_CHECK(
+      self.dim() == 2,
+      "Expected a tensor with 2 dimensions, but got a tensor with ",
+      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
+  return at::native::nuclear_norm(self, IntArrayRef({0, 1}), keepdim);
 }
 
 Tensor &nuclear_norm_out(const Tensor& self, bool keepdim, Tensor& result) {
-  auto device = self.device();
-  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
-    TORCH_WARN_ONCE(
-      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
-      "It will be removed in a future PyTorch release. Please use ",
-      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
-    );
-  }
-  return at::linalg_matrix_norm_out(result, self, "nuc", IntArrayRef({-2, -1}), keepdim);
+  TORCH_CHECK(
+      self.dim() == 2,
+      "Expected a tensor with 2 dimensions, but got a tensor with ",
+      self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
+  return at::native::nuclear_norm_out(self, IntArrayRef({0, 1}), keepdim, result);
 }
 
-Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
-  auto device = self.device();
-  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
-    TORCH_WARN_ONCE(
-      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
-      "It will be removed in a future PyTorch release. Please use ",
-      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
-    );
+namespace {
+Tensor nuclear_norm_impl(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
+  auto dim_ = dim.vec();
+  maybe_wrap_dims(dim_, self.dim());
+
+  auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim());
+  Tensor p = self.permute(permutation);
+  Tensor result_ = at::sum(at::linalg_svdvals(p), -1, keepdim);
+  if (keepdim) {
+    result_.unsqueeze_(-1);
+    auto permutation_reverse = create_reverse_permutation(std::move(permutation));
+    result_ = result_.permute(permutation_reverse);
   }
-  return at::linalg_matrix_norm(self, "nuc", dim, keepdim);
+  return result_;
+}
+} // anonymous namespace
+
+Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  return nuclear_norm_impl(self, dim, keepdim).to(toRealValueType(self.scalar_type()));
 }
 
 Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
-  auto device = self.device();
-  if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
-    TORCH_WARN_ONCE(
-      "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
-      "It will be removed in a future PyTorch release. Please use ",
-      "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
-    );
-  }
-  return at::linalg_matrix_norm_out(result, self, "nuc", dim, keepdim);
+  auto result_ = nuclear_norm_impl(self, dim, keepdim);
+  at::native::resize_output(result, result_.sizes());
+  result.copy_(result_);
+  return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index d920cec..05a90e3 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -66,7 +66,6 @@
 #include <ATen/ops/gradient_native.h>
 #include <ATen/ops/imag.h>
 #include <ATen/ops/isnan_native.h>
-#include <ATen/ops/linalg_vector_norm.h>
 #include <ATen/ops/logcumsumexp.h>
 #include <ATen/ops/logcumsumexp_native.h>
 #include <ATen/ops/logical_xor.h>
@@ -1463,10 +1462,34 @@
     bool keepdim,
     optional<ScalarType> opt_dtype,
     const Tensor& result) {
-  // Left this implementation without deprecating it as it is called in a number of places
-  // in the codebase. We should swap those by linalg_vector_norm
   auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
-  at::linalg_vector_norm_out(const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
+  auto in_dtype = opt_dtype.value_or(self.scalar_type());
+  auto out_dtype = result.scalar_type();
+
+  // See the note [Reductions do not use vectorized ops]
+  Tensor self_;
+  if (self.is_cpu() && self.is_complex() && std::abs(p.toDouble()) == INFINITY) {
+    if (opt_dtype.has_value()) {
+      self_ = self.to(*opt_dtype).abs();
+    } else {
+      self_ = self.abs();
+    }
+  } else {
+    self_ = self;
+  }
+
+
+  // omit in_dtype in the following call, to avoid make_reduction explicitly
+  // casting input to out_dtype
+  auto iter = isComplexType(self_.scalar_type())
+      ? meta::make_reduction(self_, result, dim, keepdim, in_dtype)
+      : meta::make_reduction_from_out_ty(self_, result, dim, keepdim, out_dtype);
+
+  if (iter.numel() == 0) {
+    result.zero_();
+  } else {
+    norm_stub(iter.device_type(), iter, p);
+  }
 }
 
 TORCH_IMPL_FUNC(norm_out)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 543e24e..e3a9bcf 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -16072,13 +16072,6 @@
            check_batched_forward_grad=False,
            supports_fwgrad_bwgrad=True,
            skips=(
-               # MPS has some mild accuracy issues for float16. We divide the tolerances by 10
-               DecorateInfo(
-                   toleranceOverride({torch.float16: tol(atol=1e-4, rtol=0.01)}),
-                   'TestConsistency',
-                   'test_output_match',
-
-               ),
                # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479
                DecorateInfo(
                    unittest.skip("Skipped!"),