Revert "Updates and simplifies nonzero as_tuple behavior"
This reverts commit 8b143771d0f0bcd93d925263adc8b0d6b235b398.
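This restores the two-overload Python binding for torch.nonzero: the
plain overload (optionally taking out=) is marked deprecated again, and
as_tuple becomes a separate keyword-only overload. A rough sketch of the
user-visible behavior the re-added test exercises (illustrative only;
the deprecation warning may be emitted only once per process, which is
why the test uses maybeWarnsRegex):

    import torch

    x = torch.randn(2, 3)

    # Matches the "|deprecated" signature and may warn with
    # "This overload of nonzero is deprecated".
    idx = torch.nonzero(x)                        # int64 tensor, shape (nnz, 2)

    # Passing as_tuple selects the second signature instead.
    rows, cols = torch.nonzero(x, as_tuple=True)  # tuple of 1-D index tensors
    idx2 = torch.nonzero(x, as_tuple=False)       # same as the plain call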
diff --git a/test/test_torch.py b/test/test_torch.py
index 9f1c16b..6c875e6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -58,7 +58,7 @@
SIZE = 100
-AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
+AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
# Wrap base test class into a class to hide it from testing
# See https://stackoverflow.com/a/25695512
@@ -10782,6 +10782,15 @@
self.assertEqual(1, len(z))
self.assertEqual(torch.empty(0, dtype=torch.long), z[0])

+    @onlyOnCPUAndCUDA
+    def test_nonzero_deprecated(self, device):
+        x = torch.randn((2, 3), device=device)
+        with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"):
+            x.nonzero()
+
+        with self.maybeWarnsRegex(UserWarning, "This overload of nonzero is deprecated"):
+            torch.nonzero(x)
+
# TODO: add torch.complex64, torch.complex128
@dtypes(torch.float, torch.double)
def test_normal(self, device, dtype):
@@ -13050,15 +13059,6 @@
self.assertEqual(tup1, np_result, atol=0, rtol=0)
self.assertEqual(tup2, np_result, atol=0, rtol=0)

-    def test_nonzero_astuple_out(self, device):
-        t = torch.randn((3, 3, 3), device=device)
-        out = torch.empty_like(t, dtype=torch.long)
-
-        with self.assertRaises(RuntimeError):
-            torch.nonzero(t, as_tuple=True, out=out)
-
-        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
-
@onlyOnCPUAndCUDA
def test_nonzero_discontiguous(self, device):
shape = (4, 4)
@@ -19825,7 +19825,7 @@
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types,
True, [], 0, True),
('addmv', 'scalar', _medium_1d,
- lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
+ lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
[_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
('addmv', 'two_scalars', _medium_1d,
@@ -20065,7 +20065,7 @@
('sigmoid', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('logit', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes()),
('sqrt', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, torch.testing.get_all_fp_dtypes(), [torch.bfloat16]),
- ('tanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5,
+ ('tanh', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5,
torch.testing.get_all_fp_dtypes() + _complex_types, [torch.bfloat16]),
('asin', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('atan', '', _small_3d, lambda t, d: [], 1e-3, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 85d78ae..673af99 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -583,28 +583,29 @@
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
-    "nonzero(Tensor input, *, bool as_tuple=False, Tensor out=None)",
+    "nonzero(Tensor input, *, Tensor out=None)|deprecated",
+    "nonzero(Tensor input, *, bool as_tuple)",
  });
-  ParsedArgs<3> parsed_args;
+  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

-  if (r.has_torch_function()){
+  if(r.has_torch_function()){
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

-  const auto as_tuple = r.toBool(1);
-  const auto has_out = !r.isNone(2);
-
-  if (as_tuple) {
-    TORCH_CHECK(!has_out, "nonzero does not support the out kwarg when as_tuple is True");
-    return wrap(dispatch_nonzero_numpy(r.tensor(0)));
+  if (r.idx == 0) {
+    if (r.isNone(1)) {
+      return wrap(dispatch_nonzero(r.tensor(0)));
+    } else {
+      return wrap(dispatch_nonzero(r.tensor(0), r.tensor(1)));
+    }
+  } else {
+    if (r.toBool(1)) {
+      return wrap(dispatch_nonzero_numpy(r.tensor(0)));
+    } else {
+      return wrap(dispatch_nonzero(r.tensor(0)));
+    }
  }
-
-  if (has_out) {
-    return wrap(dispatch_nonzero(r.tensor(0), r.tensor(2)));
-  }
-
-  return wrap(dispatch_nonzero(r.tensor(0)));
  END_HANDLE_TH_ERRORS
}
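
For reference, the overload resolution restored above routes calls roughly as
follows. This is only a simplified Python sketch of the C++ branches (argument
validation is omitted; neither restored signature accepts both as_tuple and
out, so the parser should reject that combination). The dispatch_* stand-ins
below are illustrative stubs, not the helpers defined in the template:

    import torch

    # Stand-ins for the dispatch_* helpers: the first returns the
    # (nnz, ndim) index matrix, the second a per-dimension index tuple.
    def dispatch_nonzero(input, out=None):
        return torch.nonzero(input) if out is None else torch.nonzero(input, out=out)

    def dispatch_nonzero_numpy(input):
        return input.nonzero(as_tuple=True)

    def nonzero(input, *, as_tuple=None, out=None):
        if as_tuple is None:
            # Signature 1 (deprecated): nonzero(Tensor input, *, Tensor out=None)
            if out is None:
                return dispatch_nonzero(input)       # r.isNone(1) branch
            return dispatch_nonzero(input, out)      # out= provided
        # Signature 2: nonzero(Tensor input, *, bool as_tuple)
        if as_tuple:
            return dispatch_nonzero_numpy(input)     # tuple of 1-D index tensors
        return dispatch_nonzero(input)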