Fix nuclear norm with requires_grad=True (#26303)
Summary:
Changelog:
- Selectively set compute_uv in the at::svd call used internally in the implementation of at::nuclear_norm (repro sketch below)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26303
Test Plan:
- Add tests in common_methods_invocations.py
Refixes: https://github.com/pytorch/pytorch/issues/18275
Differential Revision: D17605357
Pulled By: ezyang
fbshipit-source-id: d87d60afe678e2546dca6992ea66f2daeb6b0346
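For context, the user-visible failure this change addresses: taking the nuclear norm of a tensor that requires grad and then calling backward(). A minimal repro sketch via the public torch.norm entry point (shapes are illustrative, not taken from the original report):

    import torch

    x = torch.randn(3, 4, 5, requires_grad=True)
    # Before this change, the dim-variant of nuclear_norm always called svd
    # with compute_uv=False, so svd_backward raised once .backward() ran.
    loss = torch.norm(x, p='nuc', dim=(1, 2)).sum()
    loss.backward()  # succeeds with this change applied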
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 06a88f56..93a8843 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -6,6 +6,7 @@
#include <ATen/TensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/LegacyTHFunctionsCPU.h>
+#include <ATen/core/grad_mode.h>
#include <functional>
#include <numeric>
#include <vector>
@@ -549,7 +550,11 @@
self.dim() == 2,
"Expected a tensor with 2 dimensions, but got a tensor with ",
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
- return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
+ // svd_backward errors out when U and V haven't been computed, so the backward
+ // pass for nuclear_norm would throw if we unconditionally skipped them.
+ // Hence, compute U and V only when gradients are actually needed.
+ return at::sum(std::get<1>(at::svd(self, /*some=*/true,
+ /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), 0, keepdim);
}
Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
@@ -557,14 +562,19 @@
self.dim() == 2,
"Expected a tensor with 2 dimensions, but got a tensor with ",
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
- return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
+ return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
}
Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
Tensor p = _move_to_end(self, dim);
- return at::sum(std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+ // svd_backward errors out when U and V haven't been computed, so the backward
+ // pass for nuclear_norm would throw if we unconditionally skipped them.
+ // Hence, compute U and V only when gradients are actually needed.
+ return at::sum(std::get<1>(at::svd(p, /*some=*/true,
+ /*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), -1, keepdim);
}
Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
@@ -572,6 +582,7 @@
Tensor p = _move_to_end(self, dim);
return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
}
static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
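A Python-level sketch of why the conditional above is needed, using torch.is_grad_enabled() as the counterpart of at::GradMode::is_enabled() (tensor shape and variable names are illustrative):

    import torch

    a = torch.randn(4, 4, requires_grad=True)

    # svd_backward needs U and V; with compute_uv=False they are not computed,
    # so differentiating through the singular values raises:
    _, s, _ = torch.svd(a, some=True, compute_uv=False)
    # s.sum().backward()  # RuntimeError from svd_backward

    # Mirror of the C++ condition: request U and V only when grads are needed.
    needs_grad = torch.is_grad_enabled() and a.requires_grad
    _, s, _ = torch.svd(a, some=True, compute_uv=needs_grad)
    s.sum().backward()  # ok

The out= overloads keep compute_uv=False unconditionally, since out= variants do not participate in autograd and computing the singular vectors there would be wasted work.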
diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py
index 93515aa..72cb533 100644
--- a/test/common_methods_invocations.py
+++ b/test/common_methods_invocations.py
@@ -534,6 +534,7 @@
('norm', (S, S), ('fro',), 'fro_default'),
('norm', (S, S), ('fro', [0, 1],), 'fro'),
('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
+ ('norm', (S, S, S), ('nuc', [1, 2]), 'nuc_batched', (), NO_ARGS, [skipIfNoLapack]),
('norm', (S, S), (-1,), 'neg_1'),
('norm', (S, S), (-2,), 'neg_2'),
('norm', (S, S), (-0.5,), 'neg_0_5'),
diff --git a/test/test_jit.py b/test/test_jit.py
index ca6ed59..795e61c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -16334,6 +16334,7 @@
'test_norm_fro',
'test_norm_fro_default',
'test_norm_nuc',
+ 'test_norm_nuc_batched',
# aten op has additional cudnn argument
'test_nn_unfold',
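Roughly what the new nuc_batched entry exercises, as a standalone gradcheck sketch (the harness derives shapes from (S, S, S); the ones below are illustrative):

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(3, 3, 3, dtype=torch.double, requires_grad=True)
    # Compares analytical gradients of the batched nuclear norm against
    # numerical finite differences; raises on mismatch.
    assert gradcheck(lambda t: torch.norm(t, p='nuc', dim=(1, 2)), (x,))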