Enable non-synchronizing cub scan for cum* operations (#42036)
Summary:
This uses cub for cum* operations because, unlike thrust, cub is non-synchronizing.
Cub does not support tensors with more than `2**31` elements out of the box (in fact, due to cub bugs the cutoff point is even smaller),
so to support larger tensors I split the tensor into `2**30`-element chunks and modify the first value of the second and subsequent chunks to contain the cumulative result of the previous chunks. Since this modification is done in place on the source tensor, if something goes wrong and we error out before the source tensor is restored to its original state, the source tensor will be corrupted; in most cases, however, such an error invalidates the whole CUDA context anyway.
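For illustration, the chunking scheme is roughly equivalent to the following sketch (plain PyTorch for a 1-D `cumsum`; `chunked_cumsum` and its `chunk` argument are illustrative names only, and unlike this sketch the real kernel modifies the chunk's first element in place on the source tensor rather than cloning):

```python
import torch

def chunked_cumsum(x, chunk=2**30):
    # scan each chunk separately, folding the running total of the
    # previous chunks into the first element of the current chunk
    out = torch.empty_like(x)
    for i in range(0, x.numel(), chunk):
        seg = x[i:i + chunk]
        if i > 0:
            seg = seg.clone()  # the real kernel edits the source in place here
            seg[0] = seg[0] + out[i - 1]
        out[i:i + chunk] = torch.cumsum(seg, 0)
    return out
```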
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42036
Reviewed By: ajtulloch
Differential Revision: D22749945
Pulled By: ngimel
fbshipit-source-id: 9fc9b54d466df9c8885e79c4f4f8af81e3f224ef
diff --git a/aten/src/ATen/native/cuda/ScanKernels.cu b/aten/src/ATen/native/cuda/ScanKernels.cu
index 77f1b3a..a86987e 100644
--- a/aten/src/ATen/native/cuda/ScanKernels.cu
+++ b/aten/src/ATen/native/cuda/ScanKernels.cu
@@ -8,6 +8,7 @@
#include <thrust/execution_policy.h>
#include <thrust/device_ptr.h>
#include <thrust/scan.h>
+#include <cub/device/device_scan.cuh>
namespace at { namespace native {
@@ -446,6 +447,11 @@
AT_CUDA_CHECK(cudaGetLastError());
}
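+// helper kernel: writes binary_op(*a, *b) into *out; used below to fold the
+// scan result of the previous chunk into the first element of the next chunk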
+template<typename scalar_t, class func_t>
+__global__ void transform_vals(scalar_t * a, scalar_t * b, scalar_t * out, func_t binary_op){
+ *out = binary_op(*a, *b);
+}
+
#ifdef __HIP_PLATFORM_HCC__
template<typename T>
struct ROCm_Bug {
@@ -454,9 +460,9 @@
#endif
template<typename scalar_t, typename BinaryFunction>
-void scan_thrust(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction binary_op) {
- auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
+void scan_thrust_or_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction binary_op) {
#ifdef __HIP_PLATFORM_HCC__
+ auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
using rocm_bug_t = ROCm_Bug<scalar_t>;
thrust::device_ptr<rocm_bug_t> src_data(reinterpret_cast<rocm_bug_t *>(self.data_ptr<scalar_t>()));
thrust::device_ptr<rocm_bug_t> dst_data(reinterpret_cast<rocm_bug_t *>(result.data_ptr<scalar_t>()));
@@ -471,14 +477,56 @@
src_data, src_data + size, dst_data,
rocm_bug_binary_op);
#else
- thrust::device_ptr<scalar_t> src_data(self.data_ptr<scalar_t>());
- thrust::device_ptr<scalar_t> dst_data(result.data_ptr<scalar_t>());
- ptrdiff_t size = self.numel();
- thrust::inclusive_scan(
- thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
- src_data, src_data + size, dst_data,
- binary_op);
- #endif
+ int64_t size = self.numel();
+ // non-synchronizing cub call
+ // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
+ // so split at int_max/2
+ constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
+ for (int64_t i = 0; i < size; i += max_cub_size) {
+ int size_cub = std::min<int64_t>(size - i, max_cub_size);
+ Tensor first_elem; // saved so it can be restored; needed for all iterations other than the first
+ if (i > 0) {
+ // need to temporarily transform first element of the range we are
+ // operating on; self might be multi-d, but we need to index a single
+ // element
+ auto self_view = at::_unsafe_view(self, -1);
+ first_elem = self_view[i].clone();
+ transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
+ self.data_ptr<scalar_t>() + i,
+ result.data_ptr<scalar_t>() + i - 1,
+ self.data_ptr<scalar_t>() + i,
+ binary_op);
+ }
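+ // cub is called twice: the first call with a null temp-storage pointer only
+ // computes temp_storage_bytes, the second call performs the actual scan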
+ size_t temp_storage_bytes = 0;
+ AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan(
+ nullptr,
+ temp_storage_bytes,
+ self.data_ptr<scalar_t>() + i,
+ result.data_ptr<scalar_t>() + i,
+ binary_op,
+ size_cub,
+ at::cuda::getCurrentCUDAStream()));
+ auto temp_storage = at::native::empty_cuda(
+ {static_cast<int64_t>(temp_storage_bytes)},
+ self.options().dtype(kByte));
+ AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan(
+ temp_storage.data_ptr(),
+ temp_storage_bytes,
+ self.data_ptr<scalar_t>() + i,
+ result.data_ptr<scalar_t>() + i,
+ binary_op,
+ size_cub,
+ at::cuda::getCurrentCUDAStream()));
+ if (i > 0) {
+ if (self.data_ptr<scalar_t>() != result.data_ptr<scalar_t>()) {
+ // restore modified first element only if it's not an inplace operation
+ auto self_view = at::_unsafe_view(self, -1);
+ self_view[i].copy_(first_elem, /*non_blocking=*/true);
+ }
+ }
+ }
+
+#endif
}
template<typename scalar_t, typename BinaryFunction>
@@ -486,14 +534,18 @@
int64_t dim, scalar_t init, BinaryFunction binary_op) {
int ndim = self.dim();
Tensor self_ = self.contiguous();
- result = result.contiguous();
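+ // if the provided result tensor is not contiguous, scan into a contiguous
+ // copy and copy the values back into result at the end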
+ bool copy_result = !result.is_contiguous();
+ Tensor result_ = result.contiguous();
if (self.numel() == self.size(dim)) {
- scan_thrust<scalar_t>(self_, result, init, binary_op);
+ scan_thrust_or_cub<scalar_t>(self_, result_, init, binary_op);
} else if (dim == ndim - 1) {
- scan_innermost_dim<scalar_t>(self_, result, init, binary_op);
+ scan_innermost_dim<scalar_t>(self_, result_, init, binary_op);
} else {
- scan_outer_dim<scalar_t>(self_, result, dim, init, binary_op);
+ scan_outer_dim<scalar_t>(self_, result_, dim, init, binary_op);
+ }
+ if (copy_result) {
+ result.copy_(result_);
}
}
diff --git a/test/test_torch.py b/test/test_torch.py
index e74dd7f..aa01f05 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -11980,6 +11980,40 @@
'expected scalar_type Double but found Float'):
torch.logcumsumexp(b, axis, out=inplace_out)
+ def _test_large_cum_fn_helper(self, x, fn):
+ x_cpu = x.cpu().float()
+ expected = fn(x_cpu)
+ actual = fn(x).cpu().float()
+ self.assertEqual(expected, actual)
+
+ @onlyCUDA
+ @dtypesIfCUDA(torch.half) # use a small dtype to avoid OOM
+ def test_large_cumsum(self, device, dtype):
+ # values chosen to avoid overflow and half-precision issues
+ x = torch.empty(2**30 + 200, device=device, dtype=dtype)
+ x[::3] = -3
+ x[1::3] = 2
+ x[2::3] = 1
+ self._test_large_cum_fn_helper(x, lambda x: torch.cumsum(x, 0))
+
+ @onlyCUDA
+ @dtypesIfCUDA(torch.half) # use a small dtype to avoid OOM
+ def test_large_cumprod(self, device, dtype):
+ # values chosen to avoid overflow and half-precision issues
+ x = torch.empty(2**30 + 200, device=device, dtype=dtype)
+ x[::3] = 8
+ x[1::3] = .25
+ x[2::3] = .5
+ self._test_large_cum_fn_helper(x, lambda x: torch.cumprod(x, 0))
+
+ def test_discontiguous_out_cumsum(self, device):
+ x = torch.randn(4, 8, device=device)
+ y = torch.empty(4, 16, device=device)[:, ::2]
+ out = torch.cumsum(x, 0)
+ torch.cumsum(x, 0, out=y)
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(out, y, atol=0., rtol=0.)
+
def test_std_mean(self, device):
x = torch.rand(100, 50, 20, device=device)
for dim in range(x.dim()):