enable deterministic path for index_put with accumulate=False on CPU and CUDA (#57839)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57839
When deterministic algorithms are enabled, `index_put_` with `accumulate=False` now dispatches to the sort-based CUDA kernel (and to serial execution on CPU) instead of raising a nondeterministic alert. To do this, we reuse `index_put_accum_kernel`, rename it to `index_put_with_sort_kernel`, and add a bool `accumulate` parameter that is threaded through to `indexing_backward_kernel`.
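As a quick illustration of the user-visible effect (a minimal sketch, not taken from the test plan; the duplicate-index example and the printed result are illustrative):

```python
import torch

# With deterministic algorithms enabled, index_put_ with accumulate=False no
# longer raises a nondeterministic alert; when indices contain duplicates,
# the value written last wins, matching the serial semantics checked by the
# new test below.
torch.use_deterministic_algorithms(True)

x = torch.zeros(5)
indices = (torch.tensor([1, 1, 3]),)
values = torch.tensor([10.0, 20.0, 30.0])

x.index_put_(indices, values, accumulate=False)
print(x)  # expected: tensor([ 0., 20.,  0., 30.,  0.])
```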
Test Plan:
buck test mode/opt //caffe2/test:torch -- test_index_put_non_accumulate_deterministic
✓ Pass: caffe2/test:torch - test_index_put_non_accumulate_deterministic_cpu (test_torch.TestTorchDeviceTypeCPU) (5.120)
Summary
Pass: 1
Skip: 1
↻ caffe2/test:torch - test_index_put_non_accumulate_deterministic_meta (test_torch.TestTorchDeviceTypeMETA)
ListingSuccess: 1
buck test mode/opt //caffe2/test:torch_cuda -- test_index_put_non_accumulate_deterministic
✓ ListingSuccess: caffe2/test:torch_cuda - main (6.397)
✓ Pass: caffe2/test:torch_cuda - test_index_put_non_accumulate_deterministic_cuda (test_torch.TestTorchDeviceTypeCUDA) (26.030)
✓ Pass: caffe2/test:torch_cuda - main (26.030)
Summary
Pass: 2
ListingSuccess: 1
Reviewed By: ngimel
Differential Revision: D28290699
fbshipit-source-id: df8bbe7af2e72017566161b05b85737fda4ceb3f
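For reviewers, a rough Python reference of what the sort-based path computes for a 1-D destination (a conceptual sketch only; the real work happens in the CUDA kernel below, and the helper name and the use of a stable `torch.sort` here are assumptions for illustration):

```python
import torch

def index_put_with_sort_reference(dst, index, values, accumulate):
    # Conceptual model of the sort-based kernel: sorting the flattened
    # indices makes the result independent of thread scheduling. A stable
    # sort keeps duplicate indices in their original order.
    sorted_idx, perm = torch.sort(index, stable=True)
    sorted_vals = values[perm]
    out = dst.clone()
    if accumulate:
        # Duplicates sum, as before.
        out.index_add_(0, sorted_idx, sorted_vals)
    else:
        # Only the last element of each run of duplicates is written, i.e.
        # the latest occurrence in the original order ("last duplicate wins").
        last_of_run = torch.ones_like(sorted_idx, dtype=torch.bool)
        last_of_run[:-1] = sorted_idx[:-1] != sorted_idx[1:]
        out[sorted_idx[last_of_run]] = sorted_vals[last_of_run]
    return out
```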
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 7aba311..8eb2a74 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -79,7 +79,7 @@
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(index_put_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-DEFINE_DISPATCH(index_put_accum_stub);
+DEFINE_DISPATCH(index_put_with_sort_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(put_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -87,7 +87,7 @@
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(masked_fill_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);
+REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub, index_put_with_sort_fn);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(masked_select_serial_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -402,10 +402,10 @@
}
}
- if (accumulate && self.device().type() == DeviceType::CUDA) {
+ if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) {
TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ",
value.device(), " for value tensor");
- index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
+ index_put_with_sort_stub(self.device().type(), self, indices, value, accumulate, unsafe);
return self;
}
@@ -456,11 +456,6 @@
}
Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
- if (!accumulate) {
- // See note [Writing Nondeterministic Operations]
- // Nondeterministic when index contains duplicate entries
- at::globalContext().alertNotDeterministic("index_put_ with accumulate=False");
- }
return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
}
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h
index 2d20c86..cd2835a 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -17,7 +17,7 @@
using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
-using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
+using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
using put_fn = void(*)(TensorIterator & iter, const Tensor& self, const bool accumulate);
using take_fn = void(*)(TensorIterator & iter, const Tensor& input);
@@ -37,7 +37,7 @@
DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
DECLARE_DISPATCH(index_put_fn, index_put_stub);
-DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
+DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
DECLARE_DISPATCH(put_fn, put_stub);
DECLARE_DISPATCH(take_fn, take_stub);
DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 38b0b98..f2be56b 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -231,11 +231,11 @@
// NOTE: duplicate indices are only supported if accumulate is true.
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
iter.dtype(), "index_put", [&] {
+ // See Note [Enabling Deterministic Operations]
+ // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
+ // must enable serial execution if deterministic algorithms are enabled.
+ const bool is_deterministic = at::globalContext().deterministicAlgorithms();
if (accumulate) {
- // See Note [Enabling Deterministic Operations]
- // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
- // must enable serial execution if deterministic algorithms are enabled.
- bool is_deterministic = at::globalContext().deterministicAlgorithms();
bool use_parallel_for = (!is_deterministic) && (
(iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
if (use_parallel_for && iter.dtype() == ScalarType::Float) {
@@ -252,7 +252,7 @@
} else {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)(dst + offset) = *(scalar_t*)src;
- });
+ }, /*serial_execution=*/is_deterministic);
}
});
}
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 69846e9..97ee47c 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -29,7 +29,7 @@
template <typename scalar_t, int SZ>
__global__ void indexing_backward_kernel(
int64_t* sorted_indices, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
- int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
+ int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
//numel is total number of flattened indices, not expanded to dimensions that are not indexed.
//stride is the cumulative size of the not-indexed last dimensions
//stride_before is the stride of the dimension immediately preceding first indexed dimension
@@ -55,6 +55,11 @@
&& (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
do {
int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
+ // if not accumulate, we only keep the last duplicate index so skip those before it
+ if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
+ idx++;
+ continue;
+ }
const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
const accscalar_t scale = (accscalar_t)1.0;
@@ -68,13 +73,19 @@
int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
if (feature_dim < stride) {
gradient[ii] = static_cast<accscalar_t>(grad_output[grad_row + feature_dim]);
- weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
+ if (accumulate) {
+ weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
+ }
}
}
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
- weight[ii] += gradient[ii] * scale;
+ if (accumulate) {
+ weight[ii] += gradient[ii] * scale;
+ } else {
+ weight[ii] = gradient[ii] * scale;
+ }
}
#pragma unroll
@@ -183,7 +194,7 @@
}
-void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
+void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
namespace {
@@ -195,7 +206,7 @@
return result;
}
-void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
+void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
if (indices.size() > (size_t)self.dim()) {
TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
}
@@ -224,7 +235,7 @@
// this bug is fixed in CUDA 11.3
#if defined(CUDA_VERSION) && CUDA_VERSION < 11030
if (num_indices < 50000) {
- index_put_accum_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
+ index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
} else
#endif
{
@@ -257,7 +268,8 @@
num_indices,
sliceSize,
strideBefore,
- nElemBefore);
+ nElemBefore,
+ accumulate);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -266,7 +278,7 @@
}
}
-REGISTER_CUDA_DISPATCH(index_put_accum_stub, &index_put_accum_kernel);
+REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
} //anonymous
diff --git a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
index 1f9a0cc..d6dc83f 100644
--- a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
+++ b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
@@ -7,7 +7,7 @@
namespace at { namespace native {
-void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
+void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
sorted_indices.copy_(linearIndex);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
diff --git a/test/test_torch.py b/test/test_torch.py
index 3cc6f52..79f1d55 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3966,25 +3966,6 @@
test_func(torch.Tensor.scatter_add)
test_func(torch.scatter_add)
- # Ensures that index_put throws nondeterministic alerts in the correct cases
- @onlyOnCPUAndCUDA
- def test_nondeterministic_alert_index_put(self, device):
- def test_func(op_call):
- a = torch.randn(10, device=device)
- indices = (torch.tensor([0, 0], device=device), )
- values = torch.tensor([0, 1], device=device)
-
- @expectedAlertNondeterministic('index_put_ with accumulate=False')
- def forward_func(slf, device):
- op_call(a, indices, values, accumulate=False)
-
- forward_func(self, device)
-
- test_func(torch.index_put)
- test_func(torch.Tensor.index_put)
- test_func(torch.index_put_)
- test_func(torch.Tensor.index_put_)
-
@onlyOnCPUAndCUDA
def test_nondeterministic_alert_put(self, device):
def test_func(op_call):
@@ -5306,6 +5287,25 @@
y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
+ @onlyOnCPUAndCUDA
+ def test_index_put_non_accumulate_deterministic(self, device) -> None:
+ with DeterministicGuard(True):
+ for i in range(3):
+ m = random.randint(10, 20)
+ elems = random.randint(20000, 30000)
+ values = torch.rand(elems, device=device)
+ indices = torch.randint(m, (elems,), device=device)
+ input = torch.rand(m, device=device)
+ output = input.index_put((indices,), values, accumulate=False)
+
+ input_list = input.tolist()
+ indices_list = indices.tolist()
+ values_list = values.tolist()
+ for i, v in zip(indices_list, values_list):
+ input_list[i] = v
+
+ self.assertEqual(output, input_list)
+
@dtypes(*torch.testing.get_all_dtypes())
def test_index_fill(self, device, dtype):
x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
diff --git a/torch/__init__.py b/torch/__init__.py
index 2efb2ee..1897191 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -375,6 +375,7 @@
* :func:`torch.bmm` when called on sparse-dense CUDA tensors
* :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
and the index is a list of tensors
+ * :func:`torch.Tensor.index_put` with ``accumulate=False``
* :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
tensor
* :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
@@ -415,7 +416,6 @@
``mode='max'``
* :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
* :func:`torch.Tensor.index_copy` when called on a CUDA tensor
- * :func:`torch.Tensor.index_put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=False``
* :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
* :func:`torch.histc` when called on a CUDA tensor