Enable non-synchronizing cub scan for cum* operations (#42036)
Summary:
This uses cub for cum* operations because, unlike thrust, cub is non-synchronizing.
Cub does not support tensors with more than `2**31` elements out of the box (in fact, due to cub bugs the cutoff point is even smaller),
so to support larger tensors I split the tensor into `2**30`-element chunks and modify the first value of the second and subsequent chunks to hold the cumulative result of the previous chunks. Since this modification is done in place on the source tensor, erroring out before the source tensor is reverted to its original state would leave it corrupted; in most cases, however, such errors invalidate the whole CUDA context anyway.
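As a rough, Python-level sketch of the chunking idea only (the helper name, the `clone()`, and the explicit output buffer are illustrative; the actual implementation lives in the CUDA code and performs the first-element modification in place on the source tensor, reverting it afterwards):
```python
import torch

def chunked_cumsum_sketch(x, chunk_size=2**30):
    # Hypothetical helper illustrating the chunking scheme for a 1-D tensor;
    # not the kernel added by this PR.
    x = x.clone()              # the real path edits the source in place and reverts it
    out = torch.empty_like(x)
    n = x.numel()
    for start in range(0, n, chunk_size):
        end = min(start + chunk_size, n)
        if start > 0:
            # Fold the cumsum of all previous chunks into the first element of
            # this chunk, so a plain per-chunk cumsum yields the global result.
            x[start] += out[start - 1]
        torch.cumsum(x[start:end], 0, out=out[start:end])
    return out
```
With `chunk_size = 2**30`, each per-chunk scan stays comfortably below cub's `2**31` element cutoff.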
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42036
Reviewed By: ajtulloch
Differential Revision: D22749945
Pulled By: ngimel
fbshipit-source-id: 9fc9b54d466df9c8885e79c4f4f8af81e3f224ef
diff --git a/test/test_torch.py b/test/test_torch.py
index e74dd7f..aa01f05 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -11980,6 +11980,40 @@
'expected scalar_type Double but found Float'):
torch.logcumsumexp(b, axis, out=inplace_out)
+ def _test_large_cum_fn_helper(self, x, fn):
+ x_cpu = x.cpu().float()
+ expected = fn(x_cpu)
+ actual = fn(x).cpu().float()
+ self.assertEqual(expected, actual)
+
+ @onlyCUDA
+ @dtypesIfCUDA(torch.half)  # use the smallest dtype to avoid OOM
+ def test_large_cumsum(self, device, dtype):
+ # initialization to avoid overflow and half caveats
+ x = torch.empty(2**30 + 200, device=device, dtype=dtype)
+ x[::3] = -3
+ x[1::3] = 2
+ x[2::3] = 1
+ self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0))
+
+ @onlyCUDA
+ @dtypesIfCUDA(torch.half)  # use the smallest dtype to avoid OOM
+ def test_large_cumprod(self, device, dtype):
+ # initialization to avoid overflow and half caveats
+ x = torch.empty(2**30 + 200, device=device, dtype=dtype)
+ x[::3] = 8
+ x[1::3] = .25
+ x[2::3] = .5
+ self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0))
+
+ def test_discontiguous_out_cumsum(self, device):
+ x = torch.randn(4, 8, device=device)
+ y = torch.empty(4, 16, device=device)[:, ::2]
+ out = torch.cumsum(x, 0)
+ torch.cumsum(x, 0, out=y)
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(out, y, atol=0., rtol=0.)
+
def test_std_mean(self, device):
x = torch.rand(100, 50, 20, device=device)
for dim in range(x.dim()):