Avoid COW materialize in index, reduce, compare, unique, and copy ops (#119504)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119504
Approved by: https://github.com/ezyang
ghstack dependencies: #119501, #119502, #119503
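
All of the changes follow the same pattern: call sites that only read a tensor now go through the const accessors — `const_data_ptr()` instead of `data_ptr()`, and `add_const_input()` / `add_owned_const_input()` instead of `add_input()` / `add_owned_input()` on `TensorIteratorConfig` — so that merely reading a copy-on-write tensor no longer forces it to materialize its own storage. Where a kernel is shared between read-only and mutating callers (e.g. `cpu_take_put_kernel`), a flag selects the accessor and the read-only pointer is `const_cast` back only to satisfy the existing untyped plumbing. A minimal sketch of the idea (the helpers below are illustrative, not code from this PR; the raw-pointer case assumes a contiguous float CPU tensor):

```cpp
// Sketch only: shows the two read-only access patterns used throughout
// this stack. `sum_ro` and `build_copy_iter` are hypothetical helpers.
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Raw-pointer read path: const_data_ptr<T>() only promises reads, so a
// lazily-copied (COW) tensor keeps sharing storage instead of cloning it.
// Assumes a contiguous float tensor for brevity.
double sum_ro(const at::Tensor& t) {
  const float* p = t.const_data_ptr<float>();
  double acc = 0.0;
  for (int64_t i = 0; i < t.numel(); ++i) {
    acc += p[i];
  }
  return acc;
}

// TensorIterator read path: add_const_input() registers the operand as
// read-only; add_input() would reach it through the mutable data pointer
// and trigger materialization even though the kernel never writes to it.
at::TensorIterator build_copy_iter(at::Tensor& dst, const at::Tensor& src) {
  return at::TensorIteratorConfig()
      .add_output(dst)
      .add_const_input(src)
      .resize_outputs(false)
      .check_all_same_dtype(false)
      .build();
}
```
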
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index bdd07ea..ed794d1 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -81,7 +81,7 @@
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));
 
   _AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
-    scalar_t* sp = src.data_ptr<scalar_t>();
+    const scalar_t* sp = src.const_data_ptr<scalar_t>();
     scalar_t* rp = self.data_ptr<scalar_t>();
     scalar_t* bp = buf.data_ptr<scalar_t>();
 
@@ -89,7 +89,7 @@
     int64_t NC = src.size(1);
     for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
       for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
-        scalar_t* spo = sp + R + C * NR;
+        const scalar_t* spo = sp + R + C * NR;
         scalar_t* rpo = rp + C + R * NC;
 
         int nr = std::min(NR - R, BLOCK_SZ);
@@ -156,7 +156,7 @@
         auto* output_ptr =
             reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
         if (self.numel() < at::internal::GRAIN_SIZE) {
-          fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
+          fbgemm::FloatToFloat16_simd(src.const_data_ptr<float>(), output_ptr, self.numel());
         } else {
           at::parallel_for(
               0,
@@ -164,14 +164,14 @@
               at::internal::GRAIN_SIZE,
               [&](int64_t begin, int64_t end) {
                 fbgemm::FloatToFloat16_simd(
-                    src.data_ptr<float>() + begin,
+                    src.const_data_ptr<float>() + begin,
                     output_ptr + begin,
                    end - begin);
               });
         }
       } else {
-        auto in_data = reinterpret_cast<fbgemm::float16*>(
-            src.data_ptr<at::Half>());
+        auto in_data = reinterpret_cast<const fbgemm::float16*>(
+            src.const_data_ptr<at::Half>());
         auto* output_ptr = self.data_ptr<float>();
         if (self.numel() < at::internal::GRAIN_SIZE) {
           fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
@@ -265,7 +265,7 @@
 
   auto iter = TensorIteratorConfig()
     .add_output(self)
-    .add_input(src)
+    .add_const_input(src)
     .resize_outputs(false)
     .check_all_same_dtype(false)
     .check_all_same_device(false)
@@ -335,7 +335,7 @@
   // FIXME: really, overlapping writes should be illegal/an error in Torch
   auto iter = TensorIteratorConfig()
       .add_output(dst)
-      .add_input(src)
+      .add_const_input(src)
       .resize_outputs(false)
       .set_check_mem_overlap(false)
       .check_all_same_dtype(true)
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 5fa3e2b..1c02858 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1250,7 +1250,7 @@
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
     using accscalar_t = at::acc_type<scalar_t, false>;
     accscalar_t sum = 0;
-    const auto* t_data = self.data_ptr<scalar_t>();
+    const auto* t_data = self.const_data_ptr<scalar_t>();
 
     int64_t t_stride_0, t_stride_1, t_diag_size;
 
@@ -1726,7 +1726,7 @@
 
   auto mean = self.mean().item<double>();
   auto iter = TensorIteratorConfig()
-      .add_input(self)
+      .add_const_input(self)
       .build();
 
   auto reduction = [&](int64_t begin, int64_t end, double thread_sum) {
@@ -2197,7 +2197,7 @@
       return true;
     }
     std::atomic<bool> result{true};
-    auto iter = TensorIteratorConfig().add_input(self).build();
+    auto iter = TensorIteratorConfig().add_const_input(self).build();
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
       iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
         if (!result) {
@@ -2218,8 +2218,8 @@
 
   std::atomic<bool> result{true};
   auto iter = TensorIteratorConfig()
-    .add_input(self)
-    .add_input(other)
+    .add_const_input(self)
+    .add_const_input(other)
     .allow_cpu_scalars(true)
     .promote_inputs_to_common_dtype(true)
     .build();
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index ba00cbf..d1660e1 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -408,9 +408,9 @@
   config.set_check_mem_overlap(false)
       .check_all_same_dtype(false)
       .add_output(result)
-      .add_owned_input(info.src);
+      .add_owned_const_input(info.src);
   for (auto& index : info.indices) {
-    config.add_owned_input(index);
+    config.add_owned_const_input(index);
   }
   if (!result.defined()) {
     config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
@@ -614,9 +614,9 @@
   config.resize_outputs(false);
   config.check_all_same_dtype(false);
   config.add_output(info.src);
-  config.add_input(value);
+  config.add_const_input(value);
   for (auto& index : info.indices) {
-    config.add_input(index);
+    config.add_const_input(index);
   }
   return config.build();
 }
@@ -689,8 +689,8 @@
   auto iter = TensorIteratorConfig()
     .set_check_mem_overlap(false)
     .check_all_same_dtype(false)
-    .add_input(source)
-    .add_input(index_reshaped)
+    .add_const_input(source)
+    .add_const_input(index_reshaped)
     .build();
 
   put_stub(iter.device_type(), iter, self, accumulate);
@@ -769,7 +769,7 @@
     .set_check_mem_overlap(false)
     .check_all_same_dtype(false)
     .add_output(out)
-    .add_input(index)
+    .add_const_input(index)
     .build();
 
   // Early return after out has been resized
@@ -848,8 +848,8 @@
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(result_restrided)
-      .add_input(index_restrided)
-      .add_input(source_nonzero)
+      .add_const_input(index_restrided)
+      .add_const_input(source_nonzero)
       .build();
 
     auto result_dim_size = result_nonzero.size(dim);
@@ -943,15 +943,15 @@
     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
 
     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
-      auto index_data = index_contig.data_ptr<index_t>();
+      auto index_data = index_contig.const_data_ptr<index_t>();
       for (const auto i : c10::irange(numel)) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
           auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
-          auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+          auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
           iter.unsafe_replace_operand(0, self_data);
           iter.unsafe_replace_operand(1, self_data);
-          iter.unsafe_replace_operand(2, source_data);
+          iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
           add_stub(iter.device_type(), iter, alpha);
       }
     });
@@ -967,10 +967,10 @@
       auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
       // TODO: Maybe TensorAccessor can be used here?
       auto* result_ptr = result.data_ptr<scalar_t>();
-      auto* source_ptr = source.data_ptr<scalar_t>();
+      auto* source_ptr = source.const_data_ptr<scalar_t>();
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(numel)) {
             auto self_i = index_data[i];
             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
@@ -1040,15 +1040,15 @@
     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
 
     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () {
-      auto index_data = index_contig.data_ptr<index_t>();
+      auto index_data = index_contig.const_data_ptr<index_t>();
       for (const auto i : c10::irange(numel)) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
         auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
-        auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+        auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
         iter.unsafe_replace_operand(0, self_data);
         iter.unsafe_replace_operand(1, self_data);
-        iter.unsafe_replace_operand(2, source_data);
+        iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
 
         switch (op) {
           case ReductionType::PROD :
@@ -1090,11 +1090,11 @@
       auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim);
       // TODO: Maybe TensorAccessor can be used here?
       auto* result_ptr = result.data_ptr<scalar_t>();
-      auto* source_ptr = source.data_ptr<scalar_t>();
+      auto* source_ptr = source.const_data_ptr<scalar_t>();
       auto counts_ptr = counts.data_ptr<scalar_t>();
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_",
         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(numel)) {
             auto self_i = index_data[i];
             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
@@ -1175,7 +1175,7 @@
 
   auto out = static_cast<char*>(result_contig.data_ptr());
 
-  auto src_base = static_cast<const char*>(self_contig.data_ptr());
+  auto src_base = static_cast<const char*>(self_contig.const_data_ptr());
 
   auto self_sizes = self_contig.sizes();
   auto outer_dims_product = c10::size_to_dim_(1, self_sizes);
@@ -1191,7 +1191,7 @@
   AT_DISPATCH_INDEX_TYPES(
     index_contig.scalar_type(), "batch_index_select_compute", [&]() {
 
-      const auto* idxs = index_contig.data_ptr<index_t>();
+      const auto* idxs = index_contig.const_data_ptr<index_t>();
       check_indexarray_range<index_t>(idxs, N, src_indexing_axis_dim);
 
       // Special-case single-float copy for efficiency
@@ -1256,7 +1256,7 @@
                   "index_select(): self indexing axis dim should be positive");
       AT_DISPATCH_INDEX_TYPES(
       index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() {
-        const auto* idxs = index_contig.data_ptr<index_t>();
+        const auto* idxs = index_contig.const_data_ptr<index_t>();
         check_indexarray_range<index_t>(idxs, numel, src_indexing_axis_dim);
       });
       return result;
@@ -1280,7 +1280,7 @@
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(resultSlice)
-      .add_input(selfSlice)
+      .add_const_input(selfSlice)
       .build();
 
     auto grain_size = at::internal::GRAIN_SIZE;
@@ -1293,7 +1293,7 @@
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
         [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
           &resultSlice_data, &result_stride_bytes] () {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(start, end)) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -1322,7 +1322,7 @@
           AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
             [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
               &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
-            auto index_data = index_contig.data_ptr<index_t>();
+            auto index_data = index_contig.const_data_ptr<index_t>();
             for (const auto i : c10::irange(start, end)) {
               auto self_i = index_data[i];
               TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -1344,16 +1344,16 @@
       AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] {
         auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
         auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
-        auto self_data_ptr = self.data_ptr<scalar_t>();
+        auto self_data_ptr = self.const_data_ptr<scalar_t>();
         auto result_data_ptr = result.data_ptr<scalar_t>();
         auto self_numel = self.numel();
         AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_",
           [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
-          auto index_data = index_contig.data_ptr<index_t>();
+          auto index_data = index_contig.const_data_ptr<index_t>();
           for (const auto i : c10::irange(numel)) {
             auto self_i = index_data[i];
             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
-            scalar_t *self_ip = self_data_ptr + self_i * self_stride;
+            const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
             *(result_data_ptr + i * result_stride) = *self_ip;
           }
         });
@@ -1364,16 +1364,16 @@
         auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
         auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
 
-        auto self_data_ptr = self.data_ptr<scalar_t>();
+        auto self_data_ptr = self.const_data_ptr<scalar_t>();
         auto result_data_ptr = result.data_ptr<scalar_t>();
         auto self_numel = self.numel();
         AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
           [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
-          auto index_data = index_contig.data_ptr<index_t>();
+          auto index_data = index_contig.const_data_ptr<index_t>();
           for (const auto i : c10::irange(numel)) {
             auto self_i = index_data[i];
             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
-            scalar_t *self_ip = self_data_ptr + self_i * self_stride;
+            const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
             *(result_data_ptr + i * result_stride) = *self_ip;
           }
         });
@@ -1462,7 +1462,7 @@
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(self_restrided)
-    .add_input(index_restrided)
+    .add_const_input(index_restrided)
     .build();
 
   auto self_dim_size = (self_nonzero_dim.sizes())[dim];
@@ -1924,7 +1924,7 @@
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(self)
-    .add_input(mask)
+    .add_const_input(mask)
     .build();
 
   masked_fill_stub(iter.device_type(), iter, value);
@@ -2017,8 +2017,8 @@
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(result_strided)
-      .add_input(*_self)
-      .add_input(*_mask)
+      .add_const_input(*_self)
+      .add_const_input(*_mask)
       .build();
 
     masked_select_serial_stub(iter.device_type(), iter, orig_stride);
@@ -2041,9 +2041,9 @@
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(result_strided)
-    .add_input(*_self)
-    .add_input(*_mask)
-    .add_input(mask_prefix_sum)
+    .add_const_input(*_self)
+    .add_const_input(*_mask)
+    .add_const_input(mask_prefix_sum)
     .build();
 
   masked_select_stub(iter.device_type(), iter, orig_stride);
@@ -2228,7 +2228,7 @@
 
   // Optimized all-reduce
   auto iter = TensorIteratorConfig()
-      .add_input(self)
+      .add_const_input(self)
       .build();
 
   const auto num_threads = at::get_num_threads();
@@ -2267,7 +2267,7 @@
   at::assert_no_overlap(result, self);
 
   auto iter = TensorIteratorConfig()
-    .add_input(self)
+    .add_const_input(self)
     .enforce_linear_iteration()
     .build();
 
@@ -2495,7 +2495,7 @@
       // order of indexing matters
       .enforce_linear_iteration()
       .add_output(self)
-      .add_input(*b_mask)
+      .add_const_input(*b_mask)
       .build();
 
   masked_scatter_stub(iter.device_type(), iter, src_cont);
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index f7a2d0f..4151257 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -128,17 +128,17 @@
     TensorIteratorConfig()                  \
       .set_check_mem_overlap(true)          \
       .add_output(maybe_get_output())       \
-      .add_input(self)                      \
+      .add_const_input(self)                \
       .promote_inputs_to_common_dtype(true) \
       .cast_common_dtype_to_outputs(true)   \
       .enforce_safe_casting_to_output(true)
 
   if (min && max) {
-    build(CLAMP_CONFIG().add_input(*min).add_input(*max));
+    build(CLAMP_CONFIG().add_const_input(*min).add_const_input(*max));
   } else if (min) {
-    build(CLAMP_CONFIG().add_input(*min));
+    build(CLAMP_CONFIG().add_const_input(*min));
   } else if (max) {
-    build(CLAMP_CONFIG().add_input(*max));
+    build(CLAMP_CONFIG().add_const_input(*max));
   }
 }
 
@@ -535,9 +535,9 @@
   auto iter = at::TensorIteratorConfig()
     .check_all_same_dtype(false)
     .add_output(out)
-    .add_input(condition_)
-    .add_input(self_)
-    .add_input(other_)
+    .add_const_input(condition_)
+    .add_const_input(self_)
+    .add_const_input(other_)
     .build();
   where_kernel(iter.device_type(), iter);
   return out;
diff --git a/aten/src/ATen/native/TensorDimApply.h b/aten/src/ATen/native/TensorDimApply.h
index 65d90f6..4d52446 100644
--- a/aten/src/ATen/native/TensorDimApply.h
+++ b/aten/src/ATen/native/TensorDimApply.h
@@ -10,7 +10,7 @@
   int ndims = self.dim();
   int tensor_dim_apply_has_finished = 0;
   std::vector<int64_t> counter(ndims, 0);
-  T1* self_data = self.data_ptr<T1>();
+  const T1* self_data = self.const_data_ptr<T1>();
   T1* values_data = values.data_ptr<T1>();
   T2* indices_data = indices.data_ptr<T2>();
   int64_t self_stride = self.stride(dim);
diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp
index be220fc..12993e9 100644
--- a/aten/src/ATen/native/Unique.cpp
+++ b/aten/src/ATen/native/Unique.cpp
@@ -37,7 +37,7 @@
     const bool return_inverse,
     const bool return_counts) {
   const Tensor& input = self.contiguous();
-  bool* input_data = input.data_ptr<bool>();
+  const bool* input_data = input.const_data_ptr<bool>();
 
   int64_t numel = input.numel();
   Tensor output = at::empty({0}, self.options());
@@ -270,7 +270,7 @@
     const bool return_inverse,
     const bool return_counts) {
   const Tensor& input = self.contiguous();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
   int64_t numel = input.numel();
   Tensor output = at::empty({numel}, input.options());
   Tensor inverse_indices = at::empty({0}, self.options().dtype(kLong));
@@ -390,7 +390,7 @@
   std::vector<int64_t> indices(input_flat.size(0));
   std::iota(indices.begin(), indices.end(), 0);
   int64_t numel = input_flat.size(1);
-  scalar_t* input_flat_ptr = ((scalar_t*)input_flat.data_ptr());
+  const scalar_t* input_flat_ptr = ((const scalar_t*)input_flat.const_data_ptr());
 
   // sort indices using data
   if (!consecutive) {
diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp
index a815896..906fa89 100644
--- a/aten/src/ATen/native/cpu/CopyKernel.cpp
+++ b/aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -71,7 +71,7 @@
       using Vecs = Vectorized<scalar_t>;
       c10::SmallBuffer<char*, 2> ptrs(2);
       dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
-      scalar_t* input_data = iter.tensor_base(1).data_ptr<scalar_t>();
+      scalar_t* input_data = const_cast<scalar_t*>(iter.tensor_base(1).const_data_ptr<scalar_t>());
       ptrs[0] = reinterpret_cast<char*>(output_data);
       ptrs[1] = reinterpret_cast<char*>(input_data);
 
@@ -139,7 +139,7 @@
       using Vecs = Vectorized<source_t>;
       c10::SmallBuffer<char*, 2> ptrs(2);
       dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
-      source_t* input_data = iter.tensor_base(1).data_ptr<source_t>();
+      source_t* input_data = const_cast<source_t*>(iter.tensor_base(1).const_data_ptr<source_t>());
 
       ptrs[0] = reinterpret_cast<char*>(output_data);
       ptrs[1] = reinterpret_cast<char*>(input_data);
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 416f199..bff7edf 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -54,6 +54,7 @@
 void cpu_take_put_kernel(
     TensorIterator& iter,
     const TensorBase& indexed,
+    bool is_indexed_data_mutated,
     const func_t& f,
     bool serial_execution=false) {
   // This kernel follows the same strategy as `cpu_index_kernel`
@@ -70,7 +71,9 @@
   const auto numel = indexed.numel();
   const auto offset_indexed = IndexToOffset(indexed);
 
-  auto* indexed_data = indexed.data_ptr<scalar_t>();
+  auto* indexed_data = is_indexed_data_mutated ?
+   indexed.data_ptr<scalar_t>()
+   : const_cast<scalar_t*>(indexed.const_data_ptr<scalar_t>());
   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
     auto* iterated_data_bytes = data[0];
     auto* index_data_bytes = data[1];
@@ -115,21 +118,21 @@
       bool use_parallel_for = (!is_deterministic) && (
         (iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1));
       if (use_parallel_for && iter.dtype() == ScalarType::Float) {
-        cpu_take_put_kernel<float>(iter, self,
+        cpu_take_put_kernel<float>(iter, self, true,
             [](float& iterated, float* indexed, const int64_t idx) {
                 cpu_atomic_add_float(indexed+idx, iterated);
               });
       } else {
         // TODO: investigate parallelization of the accumulate kernel.
         // Unlike the non-accumulate case, this needs to be thread-safe.
-        cpu_take_put_kernel<scalar_t>(iter, self,
+        cpu_take_put_kernel<scalar_t>(iter, self, true,
             [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
                 indexed[idx] += iterated;
               },
             /*serial_execution=*/true);
       }
     } else {
-      cpu_take_put_kernel<scalar_t>(iter, self,
+      cpu_take_put_kernel<scalar_t>(iter, self, true,
           [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
               indexed[idx] = iterated;
             });
@@ -142,8 +145,8 @@
   const TensorBase & input) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
     iter.dtype(), "take_cpu", [&] {
-      cpu_take_put_kernel<scalar_t>(iter, input,
-          [](scalar_t& iterated, scalar_t* indexed, const int64_t idx) {
+      cpu_take_put_kernel<scalar_t>(iter, input, false,
+          [](scalar_t& iterated, const scalar_t* indexed, const int64_t idx) {
               iterated = indexed[idx];
             });
     });
@@ -332,7 +335,7 @@
 template <typename scalar_t>
 void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
   std::ptrdiff_t source_cntr = 0;
-  scalar_t* source_ptr = source.data_ptr<scalar_t>();
+  const scalar_t* source_ptr = source.const_data_ptr<scalar_t>();
   auto numel = source.numel();
 
   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
index 125f3ce..04fc88d 100644
--- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
@@ -29,7 +29,7 @@
     vec_func_t vop) {
   using Vec = Vectorized<opmath_type<scalar_t>>;
   const int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   // NOTE: parallel_reduce not support bool type
   scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t /*ident*/) -> scalar_t {
@@ -50,7 +50,7 @@
     const scalar_t ident_v,
     func_t op) {
   const int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
       scalar_t partial_out = ident;
@@ -123,7 +123,7 @@
     func_t2 reduce_acc_func) {
   using scalar_t_pair = std::pair<scalar_t, scalar_t>;
   const int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
       scalar_t_pair partial_out(ident);
@@ -150,7 +150,7 @@
   using Vec = Vectorized<opmath_type<scalar_t>>;
   using scalar_t_pair = std::pair<scalar_t, scalar_t>;
   const int64_t input_numel = input.numel();
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   // NOTE: parallel_reduce not support bool type
   std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
     [&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 24598ee..620e8f1 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -53,7 +53,7 @@
     // NOLINTNEXTLINE(bugprone-argument-comment)
     .declare_static_shape(self.sizes(), /*squash_dim=*/dim)
     .add_output(result)
-    .add_input(self)
+    .add_const_input(self)
     .build();
 
   auto result_dim_stride = ensure_nonempty_stride(result, dim);
diff --git a/aten/src/ATen/native/cpu/ReduceUtils.h b/aten/src/ATen/native/cpu/ReduceUtils.h
index c54dc49..1113b7e 100644
--- a/aten/src/ATen/native/cpu/ReduceUtils.h
+++ b/aten/src/ATen/native/cpu/ReduceUtils.h
@@ -199,7 +199,7 @@
 }
 
 template <typename scalar_t, ReductionType reduce>
-inline void update(scalar_t* out, scalar_t* data, int64_t K) {
+inline void update(scalar_t* out, const scalar_t* data, int64_t K) {
   using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
   map2<scalar_t>(
       [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
@@ -211,7 +211,7 @@
 
 template <typename scalar_t, ReductionType reduce,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline void update(at::opmath_type<scalar_t>* out, scalar_t* data, int64_t K) {
+inline void update(at::opmath_type<scalar_t>* out, const scalar_t* data, int64_t K) {
   using opmath_t = at::opmath_type<scalar_t>;
   using Vec = vec::Vectorized<opmath_t>;
   map_acc<scalar_t, opmath_t>(
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index f014c34..984e600 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -59,7 +59,7 @@
     .declare_static_shape(self.sizes(), /*squash_dims=*/dim)
     .add_output(result1)
     .add_output(result2)
-    .add_input(self)
+    .add_const_input(self)
     .build();
 
   iter.for_each(loop, /* grain_size */ 1);
@@ -320,13 +320,13 @@
 
   auto iter = TensorIteratorConfig()
     .add_output(out)
-    .add_input(promoted_elements)
+    .add_const_input(promoted_elements)
     .check_all_same_dtype(false)
     .build();
   // Dispatch based on promoted type.
   AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() {
     cpu_kernel(iter, [&](scalar_t element_val) -> bool {
-      const auto* test_element_data = test_elements_flat.data_ptr<scalar_t>();
+      const auto* test_element_data = test_elements_flat.const_data_ptr<scalar_t>();
       for (const auto j : c10::irange(test_elements_flat.numel())) {
         if (element_val == *(test_element_data + test_elements_stride * j)) {
           return !invert;
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index 657c0c7..5682ba2 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -333,7 +333,7 @@
     // Cannot use `OpaqueType`, as Tensor::data_ptr<OpaqueType<N>> is not implemented
     AT_DISPATCH_INDEX_TYPES(cuda::detail::canUse32BitIndexMath(input) ? ScalarType::Int : ScalarType::Long,
       "take_cuda_index", [&] {
-         const auto* __restrict__ indexed_ptr = input.template data_ptr<scalar_t>();
+         const auto* __restrict__ indexed_ptr = input.template const_data_ptr<scalar_t>();
          cuda_take_put_kernel<scalar_t, index_t>(iter, input,
             [indexed_ptr] __device__(scalar_t& iterated, const index_t offset) {
                iterated = indexed_ptr[offset];
@@ -385,7 +385,7 @@
       .resize_outputs(false)
       .add_output(self)
       .add_input(self)
-      .add_input(mask_cont)
+      .add_const_input(mask_cont)
       .add_input(maskPrefixSum)
       .build();
 
diff --git a/aten/src/ATen/native/cuda/UniqueCub.cu b/aten/src/ATen/native/cuda/UniqueCub.cu
index feb8d21..f4bb514 100644
--- a/aten/src/ATen/native/cuda/UniqueCub.cu
+++ b/aten/src/ATen/native/cuda/UniqueCub.cu
@@ -158,12 +158,14 @@
     } else {
       sorted = at::empty(self.sizes(), self.options());
     }
-    scalar_t* sorted_data = sorted.mutable_data_ptr<scalar_t>();
 
     Tensor sorted_indices;
     if (!return_inverse) {
       if (!consecutive) {
-        cuda::cub::radix_sort_keys(self.const_data_ptr<scalar_t>(), sorted_data, num_inp);
+        cuda::cub::radix_sort_keys(
+          self.const_data_ptr<scalar_t>(),
+          sorted.mutable_data_ptr<scalar_t>(),
+          num_inp);
       }
     } else {
       if (!consecutive) {
@@ -172,7 +174,7 @@
         sorted_indices = at::empty({num_inp}, options);
         cuda::cub::radix_sort_pairs(
             self.const_data_ptr<scalar_t>(),
-            sorted_data,
+            sorted.mutable_data_ptr<scalar_t>(),
             range.const_data_ptr<int64_t>(),
             sorted_indices.mutable_data_ptr<int64_t>(),
             num_inp);