fix searchsorted output type (#42933)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/41389
Make sure that searchsorted, which returns an integer-typed tensor, does not produce outputs that require gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42933
Reviewed By: gchanan
Differential Revision: D23109583
Pulled By: albanD
fbshipit-source-id: 5af300b2f7f3c140d39fd7f7d87799f7b93a79c1
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ba3ec2f..9323332 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4434,6 +4434,20 @@
self.assertFalse(out.dtype.is_floating_point)
self.assertFalse(out.requires_grad)
+ out = inp.argmin()
+ self.assertFalse(out.dtype.is_floating_point)
+ self.assertFalse(out.requires_grad)
+
+ out = inp.argsort()
+ self.assertFalse(out.dtype.is_floating_point)
+ self.assertFalse(out.requires_grad)
+
+ val = torch.rand((), requires_grad=True)
+
+ out = torch.searchsorted(inp, val)
+ self.assertFalse(out.dtype.is_floating_point)
+ self.assertFalse(out.requires_grad)
+
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 1ab895b..c2b8688 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -138,7 +138,7 @@
# Quantize functions should not record gradients
'quantize_per_tensor', 'quantize_per_channel',
# Functions that return integers should not have output that require gradients
- 'argmax', 'argmin', 'argsort',
+ 'argmax', 'argmin', 'argsort', 'searchsorted',
}
# Some operators invalidate the grad_accumulator. Let's reset it.