Evenly distribute output grad into all matching inputs for min/max/median (#43519)
Summary:
Previously, the backward for the full reductions `min`, `max`, and `median` routed the entire output gradient to the first input element equal to the result (`select_first_equal_backward`). This change replaces that with `evenly_distribute_backward`, which splits the gradient evenly across all input elements matching the result.

cc: ngimel mruberry
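
As an illustration of the new behavior (mirroring the updated test below; this snippet is a sketch of what the change produces, not an excerpt from the patch):

```python
import torch

# Three elements are tied for the maximum, so each now receives 1/3 of the
# output gradient instead of the first match receiving all of it.
x = torch.tensor([1., 0., 1., 0., 1., 0.], requires_grad=True)
torch.max(x).backward()
print(x.grad)  # tensor([0.3333, 0.0000, 0.3333, 0.0000, 0.3333, 0.0000])
```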
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43519
Reviewed By: albanD
Differential Revision: D23312235
Pulled By: ngimel
fbshipit-source-id: 678bda54996df7f29acf96add928bb7042fc2069
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 777f071..04ff222 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5846,12 +5846,13 @@
# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):
- def test_min_max_median_backprops_to_single_value(self, device):
+ def test_min_max_median_backprops_to_all_values(self, device):
for f in [torch.min, torch.max, torch.median]:
x = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True)
y = f(x)
y.backward()
self.assertEqual(x.grad.sum(), 1.)
+ self.assertEqual((x.grad == 1 / 3).sum(), 3)
# skip this test if running on rocm, because in cdist
# we use __shfl_down_sync on CUDA for fast reduction
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index e785e54..4dbb3b7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -670,7 +670,7 @@
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: max(Tensor self) -> Tensor
- self: select_first_equal_backward(grad, self, result)
+ self: evenly_distribute_backward(grad, self, result)
- name: max.other(Tensor self, Tensor other) -> Tensor
self: grad.clone().masked_fill_(self <= other, 0)
@@ -683,7 +683,7 @@
self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type()) / _safe_size(self.sizes(), dim)
- name: median(Tensor self) -> Tensor
- self: select_first_equal_backward(grad, self, result)
+ self: evenly_distribute_backward(grad, self, result)
# This is in theory incorrect in the following case:
# sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
@@ -706,7 +706,7 @@
self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
- name: min(Tensor self) -> Tensor
- self: select_first_equal_backward(grad, self, result)
+ self: evenly_distribute_backward(grad, self, result)
- name: min.other(Tensor self, Tensor other) -> Tensor
self: grad.clone().masked_fill_(self >= other, 0)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 683ed76..69a9452 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -726,20 +726,15 @@
}
}
-Tensor select_first_equal_backward(Tensor grad, const Tensor & input, const Tensor & value) {
- auto grad_input = at::zeros_like(input);
-
- // find indices of the first element for which input[idx] == value
- auto first_value_idx = (input == value).nonzero().select(0, 0);
-
- if (grad_input.dim() == 0) {
- grad_input.copy_(grad);
+Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
+ auto mask = (input == value);
+ auto count = mask.sum();
+ auto grad_input = grad / count;
+ if (input.is_cuda()) {
+ return mask * grad_input;
+ } else {
+ return at::zeros_like(input).masked_fill_(mask, grad_input);
}
- else {
- grad_input.index_put_(at::chunk(first_value_idx, grad_input.dim()), grad);
- }
-
- return grad_input;
}
Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {
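
For reference, the `evenly_distribute_backward` added above boils down to roughly the following Python-level sketch (illustrative only, not the exact ATen code; it assumes `grad` is the 0-dim gradient of the scalar reduction output, and the helper name is made up for this note):

```python
import torch

def evenly_distribute_backward_sketch(grad, inp, value):
    # Every input element equal to the reduced value gets an equal share.
    mask = (inp == value)
    share = grad / mask.sum()
    # The CUDA branch above uses `mask * share`; the CPU branch writes the
    # share into a zeros tensor via masked_fill_. Both yield the same result.
    return torch.zeros_like(inp).masked_fill_(mask, share)

x = torch.tensor([1., 0., 1., 0., 1., 0.])
print(evenly_distribute_backward_sketch(torch.tensor(1.), x, x.max()))
# tensor([0.3333, 0.0000, 0.3333, 0.0000, 0.3333, 0.0000])
```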