Set proper output differentiability for unique function (#47930)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/47851
Since the definitions of these functions in `native_functions.yaml` have special dispatch, we were already generating the proper `NotImplemented` behavior for them, but we were incorrectly marking all of the outputs as requiring gradients.
Added entries in `derivatives.yaml` that allow us to specify which outputs are differentiable and which are not.
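As a quick sanity check of the intended behavior, here is a minimal sketch (assuming a PyTorch build that includes this patch): only the values output of the `unique` family should require gradients; the integer inverse-indices and counts outputs are non-differentiable.
```python
import torch

# Minimal sketch (assumes a build with this patch applied): only the values
# output of torch.unique should require grad; the inverse indices and counts
# are integer outputs and are marked non-differentiable.
inp = torch.rand(5, 5, requires_grad=True)
values, inverse, counts = torch.unique(
    inp, sorted=True, return_inverse=True, return_counts=True)
print(values.requires_grad)   # True (backward still raises not_implemented)
print(inverse.requires_grad)  # False (was True before this fix)
print(counts.requires_grad)   # False (was True before this fix)
```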
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47930
Reviewed By: smessmer
Differential Revision: D24960667
Pulled By: albanD
fbshipit-source-id: 19e5bb3029cf0d020b31e2fa264b3a03dd86ec10
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 46a0835..d989e87 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4856,12 +4856,60 @@
self.assertFalse(out.dtype.is_floating_point)
self.assertFalse(out.requires_grad)
- bins = torch.linspace(0, 1.0, requires_grad=True)
+ bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
vals = torch.rand(5, 5, requires_grad=True)
out = torch.bucketize(vals, bins)
self.assertFalse(out.dtype.is_floating_point)
self.assertFalse(out.requires_grad)
+ def assert_only_first_requires_grad(res):
+ if not isinstance(res, tuple):
+ res = (res,)
+ self.assertTrue(res[0].requires_grad)
+ for out in res[1:]:
+ if out is not None:
+ self.assertFalse(out.requires_grad)
+
+ for sort in [True, False]:
+ for return_inverse in [True, False]:
+ for return_counts in [True, False]:
+ res = torch.unique(inp, sorted=sort, return_inverse=return_inverse,
+ return_counts=return_counts)
+ assert_only_first_requires_grad(res)
+
+ res = torch.unique(inp, sorted=sort, return_inverse=return_inverse,
+ return_counts=return_counts, dim=0)
+ assert_only_first_requires_grad(res)
+
+ res = torch.unique_consecutive(inp, return_inverse=return_inverse,
+ return_counts=return_counts)
+ assert_only_first_requires_grad(res)
+
+ res = torch.unique_consecutive(inp, return_inverse=return_inverse,
+ return_counts=return_counts, dim=0)
+ assert_only_first_requires_grad(res)
+
+ # Here we test the internal functions to make sure all of them are
+ # covered on top of the public API
+ res = torch._unique(inp, sorted=sort, return_inverse=return_inverse)
+ assert_only_first_requires_grad(res)
+
+ # This looks public but is actually manually deleted from the
+ # torch namespace in torch/functional.py
+ res = torch._VF.unique_dim(inp, dim=0, sorted=sort, return_inverse=return_inverse,
+ return_counts=return_counts)
+ assert_only_first_requires_grad(res)
+
+ # We don't test `unique_dim_consecutive` here.
+ # It looks public but the python binding is actually manually disabled in
+ # tools/autograd/gen_python_functions.py
+
+ res = torch._unique2(inp, sorted=sort, return_inverse=return_inverse,
+ return_counts=return_counts)
+ assert_only_first_requires_grad(res)
+
+
+
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index aeec025..476580b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1109,8 +1109,25 @@
self: zeros_like(grad)
- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+ output_differentiability: [True, False]
self: not_implemented("_unique")
+- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+ output_differentiability: [True, False, False]
+ self: not_implemented("unique_dim")
+
+- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+ output_differentiability: [True, False, False]
+ self: not_implemented("unique_consecutive")
+
+- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+ output_differentiability: [True, False, False]
+ self: not_implemented("unique_dim_consecutive")
+
+- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+ output_differentiability: [True, False, False]
+ self: not_implemented("_unique2")
+
- name: _unsafe_view(Tensor self, int[] size) -> Tensor
self: grad.reshape(self.sizes())
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 4fde7d5..38e3ebe 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -829,7 +829,6 @@
# set_flags has to appear after version_counter, because rebase_history
# requires that the counter is incremented before it is called
body.append(emit_history())
- if requires_derivative:
body.append(emit_save_outputs())
body.extend(emit_check_if_in_complex_autograd_allowlist())
if base_name in RESET_GRAD_ACCUMULATOR: