[ROCm] Enable double __shfl_down (#34103)

Summary:
This allows us to re-enable some double-based pdist tests that previously ran into accumulated error from casting down to float.
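
For context, the pdist kernels reduce partial distances with WARP_SHFL_DOWN, which for doubles can now forward to HIP's native double `__shfl_down` overload instead of bouncing through float. A minimal sketch of that kind of warp reduction, assuming a ROCm toolchain where the double overload is available (`warp_sum_double` is an illustrative name, not part of the PyTorch source):

```cpp
#include <hip/hip_runtime.h>

// Hedged sketch (not the PyTorch implementation): warp-level sum reduction
// over doubles using HIP's native double __shfl_down, avoiding the old lossy
// (float) round trip that this PR removes.
__device__ __forceinline__ double warp_sum_double(double value) {
  // Repeatedly fold the upper half of the active lanes onto the lower half.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    value += __shfl_down(value, offset);
  }
  return value;  // lane 0 ends up holding the sum across the warp
}
```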

Addresses https://github.com/pytorch/pytorch/issues/33128
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34103

Differential Revision: D20343279

Pulled By: ezyang

fbshipit-source-id: a2da768259fab34ef326976283b7a15bebbbb979
diff --git a/aten/src/THC/THCDeviceUtils.cuh b/aten/src/THC/THCDeviceUtils.cuh
index 34bc74f..dea2e1c 100644
--- a/aten/src/THC/THCDeviceUtils.cuh
+++ b/aten/src/THC/THCDeviceUtils.cuh
@@ -91,12 +91,6 @@
 }
 
 #ifdef __HIP_PLATFORM_HCC__
-//To handle ambiguity, add a type double version.
-__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
-{
-  //(HIP doesn't support double)
-  return (double) __shfl_down((float) value, delta, width);
-}
 __device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
 {
   //(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
diff --git a/test/test_nn.py b/test/test_nn.py
index 3a8b061..aae8dc7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6312,7 +6312,6 @@
         input2 = torch.randn(4, 4, requires_grad=True)
         self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
 
-    @skipIfRocm
     def test_pdist(self):
         for device, trans in itertools.product(device_(), [False, True]):
             inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
@@ -6343,7 +6342,6 @@
         inp = torch.randn(4, 5, requires_grad=True)
         gradgradcheck(F.pdist, (inp,))
 
-    @skipIfRocm
     @unittest.expectedFailure
     def test_pdist_cuda_gradgrad_unimplemented(self):
         inp = torch.randn(4, 5, device='cuda', requires_grad=True)
diff --git a/test/test_torch.py b/test/test_torch.py
index b8cb149..9af1f99 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -11287,7 +11287,6 @@
             self._pdist_single((1000, 2), device, 2, dtype, trans=False, grad_check=False)
 
     @slowTest
-    @skipIfRocm
     def test_pdist_norm_backward(self, device):
         for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
             for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: