[ROCm] Enable double __shfl_down (#34103)
Summary:
This allows us to enable some double-based pdist tests that previously ran into accumulated error from casting down to float.
Addresses https://github.com/pytorch/pytorch/issues/33128
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34103
Differential Revision: D20343279
Pulled By: ezyang
fbshipit-source-id: a2da768259fab34ef326976283b7a15bebbbb979
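
For context, the precision issue caused by the removed workaround can be reproduced with a small warp reduction. The sketch below is illustrative only and not part of this change: it assumes a CUDA toolchain (using __shfl_down_sync; on HIP the unsynchronized __shfl_down used in THCDeviceUtils.cuh plays the same role) and a 32-lane warp, and the helper names shfl_down_double / shfl_down_via_float are made up for the example.

#include <cstdio>

// Full-precision double shuffle (what this change relies on being available on ROCm,
// and what CUDA already provides).
__device__ __forceinline__ double shfl_down_double(double value, unsigned int delta) {
  return __shfl_down_sync(0xffffffff, value, delta);
}

// The removed HIP workaround: round-trip through float, discarding the low
// ~29 bits of the double mantissa on every shuffle.
__device__ __forceinline__ double shfl_down_via_float(double value, unsigned int delta) {
  return (double) __shfl_down_sync(0xffffffff, (float) value, delta);
}

// Tree-reduce 32 doubles within one warp, once with a true double shuffle and
// once through the float round-trip, so the accumulated error is visible.
__global__ void warp_sum(const double* in, double* out_exact, double* out_lossy) {
  double exact = in[threadIdx.x];
  double lossy = exact;
  for (unsigned int delta = 16; delta > 0; delta >>= 1) {  // assumes warpSize == 32
    exact += shfl_down_double(exact, delta);
    lossy += shfl_down_via_float(lossy, delta);
  }
  if (threadIdx.x == 0) { *out_exact = exact; *out_lossy = lossy; }
}

int main() {
  double h_in[32];
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0 + 1e-9 * i;  // differences below float precision
  double *d_in, *d_exact, *d_lossy;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_exact, sizeof(double));
  cudaMalloc(&d_lossy, sizeof(double));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_sum<<<1, 32>>>(d_in, d_exact, d_lossy);
  double exact, lossy;
  cudaMemcpy(&exact, d_exact, sizeof(double), cudaMemcpyDeviceToHost);
  cudaMemcpy(&lossy, d_lossy, sizeof(double), cudaMemcpyDeviceToHost);
  printf("double shuffle: %.12f  via float: %.12f\n", exact, lossy);
  cudaFree(d_in); cudaFree(d_exact); cudaFree(d_lossy);
  return 0;
}

The second sum loses the sub-float-precision part of every shuffled contribution, which is the kind of accumulated error that made the double pdist gradient checks fail on ROCm; the tests re-enabled below pass once the shuffle stays in double precision.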
diff --git a/aten/src/THC/THCDeviceUtils.cuh b/aten/src/THC/THCDeviceUtils.cuh
index 34bc74f..dea2e1c 100644
--- a/aten/src/THC/THCDeviceUtils.cuh
+++ b/aten/src/THC/THCDeviceUtils.cuh
@@ -91,12 +91,6 @@
}
#ifdef __HIP_PLATFORM_HCC__
-//To handle ambiguity, add a type double version.
-__device__ __forceinline__ double WARP_SHFL_DOWN(double value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
-{
- //(HIP doesn't support double)
- return (double) __shfl_down((float) value, delta, width);
-}
__device__ __forceinline__ int64_t WARP_SHFL_DOWN(int64_t value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
//(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
diff --git a/test/test_nn.py b/test/test_nn.py
index 3a8b061..aae8dc7 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6312,7 +6312,6 @@
input2 = torch.randn(4, 4, requires_grad=True)
self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
- @skipIfRocm
def test_pdist(self):
for device, trans in itertools.product(device_(), [False, True]):
inp = torch.randn(4, 5, dtype=torch.double, device=device, requires_grad=True)
@@ -6343,7 +6342,6 @@
inp = torch.randn(4, 5, requires_grad=True)
gradgradcheck(F.pdist, (inp,))
- @skipIfRocm
@unittest.expectedFailure
def test_pdist_cuda_gradgrad_unimplemented(self):
inp = torch.randn(4, 5, device='cuda', requires_grad=True)
diff --git a/test/test_torch.py b/test/test_torch.py
index b8cb149..9af1f99 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -11287,7 +11287,6 @@
self._pdist_single((1000, 2), device, 2, dtype, trans=False, grad_check=False)
@slowTest
- @skipIfRocm
def test_pdist_norm_backward(self, device):
for shape in [(4, 5), (3, 2), (2, 1), (1500, 1)]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: