enable deterministic path for index_put with accumulate=False on CPU and CUDA (#57839)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57839

We reuse the existing CUDA `index_put_accum_kernel`, renaming it to `index_put_with_sort_kernel`, and thread a `bool accumulate` parameter through `indexing_backward_kernel` so the same sort-based path handles both the accumulating and the non-accumulating case. On CPU, the non-accumulating `cpu_index_kernel` now runs serially when deterministic algorithms are enabled.
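
As a quick illustration (a minimal sketch, not part of the patch), the newly covered path can be exercised like this once deterministic algorithms are turned on:

    import torch

    # Opt in to deterministic algorithms; index_put_ with accumulate=False
    # now takes the sort-based (CUDA) / serial (CPU) deterministic path
    # instead of raising a nondeterministic-operation alert.
    torch.use_deterministic_algorithms(True)

    x = torch.zeros(5, device="cuda")  # same behavior on "cpu"
    indices = (torch.tensor([0, 0, 3], device="cuda"),)
    values = torch.tensor([1., 2., 3.], device="cuda")

    # With duplicate indices and accumulate=False the last written value wins,
    # matching the serial reference loop used in the new test.
    x.index_put_(indices, values, accumulate=False)
    # x is now [2., 0., 0., 3., 0.]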

Test Plan:
buck test mode/opt //caffe2/test:torch -- test_index_put_non_accumulate_deterministic

    ✓ Pass: caffe2/test:torch - test_index_put_non_accumulate_deterministic_cpu (test_torch.TestTorchDeviceTypeCPU) (5.120)
Summary
  Pass: 1
  Skip: 1
    ↻ caffe2/test:torch - test_index_put_non_accumulate_deterministic_meta (test_torch.TestTorchDeviceTypeMETA)
  ListingSuccess: 1

buck test mode/opt //caffe2/test:torch_cuda -- test_index_put_non_accumulate_deterministic

    ✓ ListingSuccess: caffe2/test:torch_cuda - main (6.397)
    ✓ Pass: caffe2/test:torch_cuda - test_index_put_non_accumulate_deterministic_cuda (test_torch.TestTorchDeviceTypeCUDA) (26.030)
    ✓ Pass: caffe2/test:torch_cuda - main (26.030)
Summary
  Pass: 2
  ListingSuccess: 1

Reviewed By: ngimel

Differential Revision: D28290699

fbshipit-source-id: df8bbe7af2e72017566161b05b85737fda4ceb3f
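
For reference, below is a hedged plain-Python sketch of the semantics the renamed sort-based kernel implements (the helper name and list-based layout are illustrative, not PyTorch code): indices are stably sorted, and for accumulate=False only the last duplicate of each index performs a write, reproducing a serial last-write-wins loop.

    def index_put_sorted_reference(dst, indices, values, accumulate):
        # Hypothetical reference helper, not PyTorch code.
        # Stable sort of the indices, keeping original positions, mirrors the
        # index sort performed before launching indexing_backward_kernel.
        order = sorted(range(len(indices)), key=lambda k: indices[k])
        for pos, k in enumerate(order):
            last_dup = (pos == len(order) - 1) or (indices[order[pos + 1]] != indices[k])
            if accumulate:
                dst[indices[k]] += values[k]
            elif last_dup:
                # accumulate=False: skip earlier duplicates, keep only the last
                # one, which is what the new `if (!accumulate && ...)` branch
                # in indexing_backward_kernel does.
                dst[indices[k]] = values[k]
        return dst

    # index_put_sorted_reference([0., 0., 0., 0., 0.], [0, 0, 3], [1., 2., 3.], accumulate=False)
    # -> [2., 0., 0., 3., 0.]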
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 7aba311..8eb2a74 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -79,7 +79,7 @@
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(index_put_stub);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-DEFINE_DISPATCH(index_put_accum_stub);
+DEFINE_DISPATCH(index_put_with_sort_stub);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(put_stub);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -87,7 +87,7 @@
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(masked_fill_stub);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
-REGISTER_NO_CPU_DISPATCH(index_put_accum_stub, index_put_accum_fn);
+REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub, index_put_with_sort_fn);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(masked_select_serial_stub);
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
@@ -402,10 +402,10 @@
     }
   }
 
-  if (accumulate && self.device().type() == DeviceType::CUDA) {
+  if (self.device().type() == DeviceType::CUDA && (accumulate || globalContext().deterministicAlgorithms())) {
       TORCH_CHECK(value.device() == self.device(), "expected device ", self.device(), " but got device ",
       value.device(), " for value tensor");
-      index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
+      index_put_with_sort_stub(self.device().type(), self, indices, value, accumulate, unsafe);
       return self;
   }
 
@@ -456,11 +456,6 @@
 }
 
 Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
-  if (!accumulate) {
-    // See note [Writing Nondeterministic Operations]
-    // Nondeterministic when index contains duplicate entries
-    at::globalContext().alertNotDeterministic("index_put_ with accumulate=False");
-  }
   return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
 }
 
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.h b/aten/src/ATen/native/TensorAdvancedIndexing.h
index 2d20c86..cd2835a 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.h
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.h
@@ -17,7 +17,7 @@
 using index_fill_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride, const Scalar& source);
 using index_copy_fn = void(*)(TensorIterator & iter, int64_t dim, int64_t self_dim_size, int64_t self_dim_stride);
 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
-using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
+using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
 using masked_fill_fn = void(*)(TensorIterator &, const Scalar& scalar);
 using put_fn = void(*)(TensorIterator & iter, const Tensor& self, const bool accumulate);
 using take_fn = void(*)(TensorIterator & iter, const Tensor& input);
@@ -37,7 +37,7 @@
 DECLARE_DISPATCH(index_fill_fn, index_fill_stub);
 DECLARE_DISPATCH(index_copy_fn, index_copy_stub);
 DECLARE_DISPATCH(index_put_fn, index_put_stub);
-DECLARE_DISPATCH(index_put_accum_fn, index_put_accum_stub);
+DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
 DECLARE_DISPATCH(put_fn, put_stub);
 DECLARE_DISPATCH(take_fn, take_stub);
 DECLARE_DISPATCH(masked_fill_fn, masked_fill_stub);
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 38b0b98..f2be56b 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -231,11 +231,11 @@
   // NOTE: duplicate indices are only supported if accumulate is true.
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
     iter.dtype(), "index_put", [&] {
+    // See Note [Enabling Deterministic Operations]
+    // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
+    // must enable serial execution if deterministic algorithms are enabled.
+    const bool is_deterministic = at::globalContext().deterministicAlgorithms();
     if (accumulate) {
-      // See Note [Enabling Deterministic Operations]
-      // Parallel cpu_index_kernel with accumulation is nondeterministic, so we
-      // must enable serial execution if deterministic algorithms are enabled.
-      bool is_deterministic = at::globalContext().deterministicAlgorithms();
       bool use_parallel_for = (!is_deterministic) && (
         (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
       if (use_parallel_for && iter.dtype() == ScalarType::Float) {
@@ -252,7 +252,7 @@
     } else {
       cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
         *(scalar_t*)(dst + offset) = *(scalar_t*)src;
-      });
+      }, /*serial_execution=*/is_deterministic);
     }
   });
 }
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 69846e9..97ee47c 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -29,7 +29,7 @@
 template <typename scalar_t, int SZ>
 __global__ void indexing_backward_kernel(
   int64_t* sorted_indices, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
-  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) {
+  int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
 //numel is total number of flattened indices, not expanded to dimensions that are not indexed.
 //stride is the cumulative size of the not-indexed last dimensions
 //stride_before is the stride of the dimension immediately preceding first indexed dimension
@@ -55,6 +55,11 @@
         && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
       do {
         int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
+        // if not accumulate, we only keep the last duplicate index so skip those before it
+        if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
+          idx++;
+          continue;
+        }
         const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
         const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
         const accscalar_t scale = (accscalar_t)1.0;
@@ -68,13 +73,19 @@
             int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
             if (feature_dim < stride) {
               gradient[ii] = static_cast<accscalar_t>(grad_output[grad_row + feature_dim]);
-              weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
+              if (accumulate) {
+                weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
+              }
             }
           }
 
           #pragma unroll
           for (int ii = 0; ii < SZ; ii++) {
-            weight[ii] += gradient[ii] * scale;
+            if (accumulate) {
+              weight[ii] += gradient[ii] * scale;
+            } else {
+              weight[ii] = gradient[ii] * scale;
+            }
           }
 
           #pragma unroll
@@ -183,7 +194,7 @@
 }
 
 
-void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
+void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
 
 namespace {
 
@@ -195,7 +206,7 @@
   return result;
 }
 
-void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
+void index_put_with_sort_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }
@@ -224,7 +235,7 @@
       // this bug is fixed in CUDA 11.3
 #if defined(CUDA_VERSION) && CUDA_VERSION < 11030
       if (num_indices < 50000) {
-        index_put_accum_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
+        index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
       } else
 #endif
       {
@@ -257,7 +268,8 @@
           num_indices,
           sliceSize,
           strideBefore,
-          nElemBefore);
+          nElemBefore,
+          accumulate);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
 
@@ -266,7 +278,7 @@
   }
 }
 
-REGISTER_CUDA_DISPATCH(index_put_accum_stub, &index_put_accum_kernel);
+REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
 } //anonymous
 
 
diff --git a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
index 1f9a0cc..d6dc83f 100644
--- a/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
+++ b/aten/src/ATen/native/cuda/LegacyThrustHelpers.cu
@@ -7,7 +7,7 @@
 
 namespace at { namespace native {
 
-void index_put_accum_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
+void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
   sorted_indices.copy_(linearIndex);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
diff --git a/test/test_torch.py b/test/test_torch.py
index 3cc6f52..79f1d55 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3966,25 +3966,6 @@
         test_func(torch.Tensor.scatter_add)
         test_func(torch.scatter_add)
 
-    # Ensures that index_put throws nondeterministic alerts in the correct cases
-    @onlyOnCPUAndCUDA
-    def test_nondeterministic_alert_index_put(self, device):
-        def test_func(op_call):
-            a = torch.randn(10, device=device)
-            indices = (torch.tensor([0, 0], device=device), )
-            values = torch.tensor([0, 1], device=device)
-
-            @expectedAlertNondeterministic('index_put_ with accumulate=False')
-            def forward_func(slf, device):
-                op_call(a, indices, values, accumulate=False)
-
-            forward_func(self, device)
-
-        test_func(torch.index_put)
-        test_func(torch.Tensor.index_put)
-        test_func(torch.index_put_)
-        test_func(torch.Tensor.index_put_)
-
     @onlyOnCPUAndCUDA
     def test_nondeterministic_alert_put(self, device):
         def test_func(op_call):
@@ -5306,6 +5287,25 @@
                     y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
                     self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
 
+    @onlyOnCPUAndCUDA
+    def test_index_put_non_accumulate_deterministic(self, device) -> None:
+        with DeterministicGuard(True):
+            for i in range(3):
+                m = random.randint(10, 20)
+                elems = random.randint(20000, 30000)
+                values = torch.rand(elems, device=device)
+                indices = torch.randint(m, (elems,), device=device)
+                input = torch.rand(m, device=device)
+                output = input.index_put((indices,), values, accumulate=False)
+
+                input_list = input.tolist()
+                indices_list = indices.tolist()
+                values_list = values.tolist()
+                for i, v in zip(indices_list, values_list):
+                    input_list[i] = v
+
+                self.assertEqual(output, input_list)
+
     @dtypes(*torch.testing.get_all_dtypes())
     def test_index_fill(self, device, dtype):
         x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
diff --git a/torch/__init__.py b/torch/__init__.py
index 2efb2ee..1897191 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -375,6 +375,7 @@
         * :func:`torch.bmm` when called on sparse-dense CUDA tensors
         * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
           and the index is a list of tensors
+        * :func:`torch.Tensor.index_put` with ``accumulate=False``
         * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
           tensor
         * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
@@ -415,7 +416,6 @@
           ``mode='max'``
         * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor
         * :func:`torch.Tensor.index_copy` when called on a CUDA tensor
-        * :func:`torch.Tensor.index_put_` when ``accumulate=False``
         * :func:`torch.Tensor.put_` when ``accumulate=False``
         * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
         * :func:`torch.histc` when called on a CUDA tensor