Fix torch.cdist backward CUDA error due to illegal gridDim setting (#51569)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/49928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51569

Reviewed By: mruberry

Differential Revision: D26215694

Pulled By: ngimel

fbshipit-source-id: 0710417e6a802424e2dcada325f27452c95d042f
diff --git a/test/test_torch.py b/test/test_torch.py
index 585180f..8170104 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3831,6 +3831,41 @@
                             expected = self._brute_cdist(x, y, p=p)
                             self.assertEqual(expected, actual)
 
+    @onlyCUDA
+    def test_cdist_cuda_backward(self, device):
+        # Compare the gradients produced by torch.cdist with those of the
+        # brute-force reference (self._brute_cdist) for inputs of shape
+        # (4, l1, 32) and (4, l2, 32) and for several values of p.
+        for l1 in [1, 511, 513]:
+            for l2 in [1, 511, 513]:
+                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                    # x2/y2 are detached leaf copies of x1/y1, so the cdist
+                    # path and the reference path collect separate gradients.
+                    x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
+                    x2 = x1.clone().detach_().requires_grad_()
+                    y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
+                    y2 = y1.clone().detach_().requires_grad_()
+                    if p == 2:
+                        # p=2 is checked under both explicit compute modes.
+                        # .grad accumulates over the two iterations, on the
+                        # cdist tensors and the reference tensors alike.
+                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                            z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
+                            z2 = self._brute_cdist(x2, y2, p=2).mean()
+                            z1.backward()
+                            z2.backward()
+                            self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
+                            self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
+                    else:
+                        z1 = torch.cdist(x1, y1, p=p).mean()
+                        z2 = self._brute_cdist(x2, y2, p=p).mean()
+                        # Run backward so that .grad is populated; without
+                        # these calls no gradient is ever computed and the
+                        # assertions below would not compare any gradients.
+                        z1.backward()
+                        z2.backward()
+                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
+                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
+
     @tf32_on_and_off(0.005)
     def test_cdist_large(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: