Evenly distribute output grad into all matching inputs for min/max/median (#43519)

Summary:
cc: ngimel mruberry
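
Previously the backward pass for the full reductions `min`, `max`, and `median` routed the entire incoming gradient to the first input element equal to the result (`select_first_equal_backward`). With this change, `evenly_distribute_backward` splits the gradient evenly across every element that matches the result (via `masked_fill_` on CPU and a mask multiply on CUDA). A minimal sketch of the new behavior, mirroring the updated test below:

```python
import torch

x = torch.tensor([1., 0., 1., 0., 1., 0.], requires_grad=True)
y = torch.min(x)      # minimum value 0. is attained at indices 1, 3, 5
y.backward()
print(x.grad)         # each of the three matching elements receives 1/3
# tensor([0.0000, 0.3333, 0.0000, 0.3333, 0.0000, 0.3333])
print(x.grad.sum())   # total gradient mass is still 1.
```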

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43519

Reviewed By: albanD

Differential Revision: D23312235

Pulled By: ngimel

fbshipit-source-id: 678bda54996df7f29acf96add928bb7042fc2069
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 777f071..04ff222 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5846,12 +5846,13 @@
 # Generic device type autograd tests.
 class TestAutogradDeviceType(TestCase):
 
-    def test_min_max_median_backprops_to_single_value(self, device):
+    def test_min_max_median_backprops_to_all_values(self, device):
         for f in [torch.min, torch.max, torch.median]:
             x = torch.tensor([1., 0., 1., 0., 1., 0.], device=device, requires_grad=True)
             y = f(x)
             y.backward()
             self.assertEqual(x.grad.sum(), 1.)
+            self.assertEqual((x.grad == 1 / 3).sum(), 3)
 
     # skip this test if running on rocm, because in cdist
     # we use __shfl_down_sync on CUDA for fast reduction
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index e785e54..4dbb3b7 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -670,7 +670,7 @@
   self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
 
 - name: max(Tensor self) -> Tensor
-  self: select_first_equal_backward(grad, self, result)
+  self: evenly_distribute_backward(grad, self, result)
 
 - name: max.other(Tensor self, Tensor other) -> Tensor
   self: grad.clone().masked_fill_(self <= other, 0)
@@ -683,7 +683,7 @@
   self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type()) / _safe_size(self.sizes(), dim)
 
 - name: median(Tensor self) -> Tensor
-  self: select_first_equal_backward(grad, self, result)
+  self: evenly_distribute_backward(grad, self, result)
 
 # This is in theory incorrect in the following case:
 #   sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
@@ -706,7 +706,7 @@
   self: index_select_backward(grad, dim, indices, self.sizes(), keepdim)
 
 - name: min(Tensor self) -> Tensor
-  self: select_first_equal_backward(grad, self, result)
+  self: evenly_distribute_backward(grad, self, result)
 
 - name: min.other(Tensor self, Tensor other) -> Tensor
   self: grad.clone().masked_fill_(self >= other, 0)
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 683ed76..69a9452 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -726,20 +726,15 @@
   }
 }
 
-Tensor select_first_equal_backward(Tensor grad, const Tensor & input, const Tensor & value) {
-  auto grad_input = at::zeros_like(input);
-
-  // find indices of the first element for which input[idx] == value
-  auto first_value_idx = (input == value).nonzero().select(0, 0);
-
-  if (grad_input.dim() == 0) {
-    grad_input.copy_(grad);
+Tensor evenly_distribute_backward(Tensor grad, const Tensor & input, const Tensor & value) {
+  auto mask = (input == value);
+  auto count = mask.sum();
+  auto grad_input = grad / count;
+  if (input.is_cuda()) {
+    return mask * grad_input;
+  } else {
+    return at::zeros_like(input).masked_fill_(mask, grad_input);
   }
-  else {
-    grad_input.index_put_(at::chunk(first_value_idx, grad_input.dim()), grad);
-  }
-
-  return grad_input;
 }
 
 Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) {