Fix nuclear norm with requires_grad=True (#26303)

Summary:
Changelog:
- Selectively assign compute_uv in the at::svd used internally in the implementation of at::nuclear_norm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26303

Test Plan:
- Add tests in common_method_invocations.py

Refixes: https://github.com/pytorch/pytorch/issues/18275

Differential Revision: D17605357

Pulled By: ezyang

fbshipit-source-id: d87d60afe678e2546dca6992ea66f2daeb6b0346
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 06a88f56..93a8843 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -6,6 +6,7 @@
 #include <ATen/TensorUtils.h>
 #include <ATen/Parallel.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
+#include <ATen/core/grad_mode.h>
 #include <functional>
 #include <numeric>
 #include <vector>
@@ -549,7 +550,11 @@
       self.dim() == 2,
       "Expected a tensor with 2 dimensions, but got a tensor with ",
       self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
+  // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
+  // would end up throwing an error as a result if U and V aren't computed.
+  // Due to this, we have to compute U and V conditionally.
+  return at::sum(std::get<1>(at::svd(self, /*some=*/true,
+                 /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), 0, keepdim);
 }
 
 Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
@@ -557,14 +562,19 @@
       self.dim() == 2,
       "Expected a tensor with 2 dimensions, but got a tensor with ",
       self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-  return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
+  return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
+
 }
 
 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
   TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
 
   Tensor p = _move_to_end(self, dim);
-  return at::sum(std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+  // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
+  // would end up throwing an error as a result if U and V aren't computed.
+  // Due to this, we have to compute U and V conditionally.
+  return at::sum(std::get<1>(at::svd(p, /*some=*/true,
+                 /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), -1, keepdim);
 }
 
 Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
@@ -572,6 +582,7 @@
 
   Tensor p = _move_to_end(self, dim);
   return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+
 }
 
 static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index 93515aa..72cb533 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -534,6 +534,7 @@
         ('norm', (S, S), ('fro',), 'fro_default'),
         ('norm', (S, S), ('fro', [0, 1],), 'fro'),
         ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
+        ('norm', (S, S, S), ('nuc', [1, 2]), 'nuc_batched', (), NO_ARGS, [skipIfNoLapack]),
         ('norm', (S, S), (-1,), 'neg_1'),
         ('norm', (S, S), (-2,), 'neg_2'),
         ('norm', (S, S), (-0.5,), 'neg_0_5'),
diff --git a/test/test_jit.py b/test/test_jit.py
index ca6ed59..795e61c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -16334,6 +16334,7 @@
     'test_norm_fro',
     'test_norm_fro_default',
     'test_norm_nuc',
+    'test_norm_nuc_batched',
 
     # aten op has additional cudnn argument
     'test_nn_unfold',