Fix torch.cdist backward CUDA error due to illegal gridDim setting (#51569)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/49928
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51569
Reviewed By: mruberry
Differential Revision: D26215694
Pulled By: ngimel
fbshipit-source-id: 0710417e6a802424e2dcada325f27452c95d042f
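
For reference, the failure this patch addresses can be exercised from Python alone. The snippet below is a minimal sketch in the spirit of the new test, not part of this PR: it assumes a CUDA device is available, uses shapes similar to those the test iterates over, and defines a hypothetical brute_cdist reference helper via broadcasting. Before the kernel-launch fix, the cdist backward pass on CUDA could fail for large enough inputs with an invalid launch configuration.

# Standalone sketch (not part of this PR): compares torch.cdist gradients on CUDA
# against a broadcast-based reference, mirroring what test_cdist_cuda_backward checks.
import torch

def brute_cdist(x, y, p=2.0):
    # Hypothetical reference helper: pairwise distances via broadcasting.
    return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

def check_cdist_backward(l1=513, l2=513, p=2.0, device='cuda'):
    x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
    y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
    x2 = x1.detach().clone().requires_grad_()
    y2 = y1.detach().clone().requires_grad_()

    # Previously, this backward() could hit the illegal gridDim configuration on CUDA.
    torch.cdist(x1, y1, p=p).mean().backward()
    brute_cdist(x2, y2, p=p).mean().backward()

    assert torch.allclose(x1.grad, x2.grad, atol=1e-3)
    assert torch.allclose(y1.grad, y2.grad, atol=1e-3)

if __name__ == '__main__':
    check_cdist_backward()
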
diff --git a/test/test_torch.py b/test/test_torch.py
index 585180f..8170104 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3831,6 +3831,31 @@
                             expected = self._brute_cdist(x, y, p=p)
                             self.assertEqual(expected, actual)
 
+    @onlyCUDA
+    def test_cdist_cuda_backward(self, device):
+        for l1 in [1, 511, 513]:
+            for l2 in [1, 511, 513]:
+                for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                    x1 = torch.randn(4, l1, 32, device=device, requires_grad=True)
+                    x2 = x1.clone().detach_().requires_grad_()
+                    y1 = torch.randn(4, l2, 32, device=device, requires_grad=True)
+                    y2 = y1.clone().detach_().requires_grad_()
+                    if p == 2:
+                        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                            z1 = torch.cdist(x1, y1, p=2, compute_mode=cm).mean()
+                            z2 = self._brute_cdist(x2, y2, p=2).mean()
+                            z1.backward()
+                            z2.backward()
+                            self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
+                            self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
+                    else:
+                        z1 = torch.cdist(x1, y1, p=p).mean()
+                        z2 = self._brute_cdist(x2, y2, p=p).mean()
+                        z1.backward()
+                        z2.backward()
+                        self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001)
+                        self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001)
+
     @tf32_on_and_off(0.005)
     def test_cdist_large(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: