make ATen/native/cuda/SegmentReduce.cu data_ptr-correct (#99163)
make ATen/native/cuda/SegmentReduce.cu data_ptr-correct
Test Plan: Rely on CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99163
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu
index 1917666..74a05a8 100644
--- a/aten/src/ATen/native/cuda/SegmentReduce.cu
+++ b/aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -70,8 +70,8 @@
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
- auto* lengths_data_ptr = lengths.data_ptr<index_t>();
- auto* offsets_data_ptr = offsets.data_ptr<index_t>();
+ auto* lengths_data_ptr = lengths.const_data_ptr<index_t>();
+ auto* offsets_data_ptr = offsets.mutable_data_ptr<index_t>();
at::cuda::cub::inclusive_sum(
lengths_data_ptr,
offsets_data_ptr + 1,
@@ -105,7 +105,7 @@
__global__ void segment_reduce_forward_kernel(
ReductionType reduction,
scalar_t* output_data,
- scalar_t* values_data,
+ const scalar_t* values_data,
const index_t* lengths_data,
const index_t* lengths_cumsum_data,
const int64_t segment_count,
@@ -175,8 +175,8 @@
__global__ void segment_reduce_backward_kernel(
ReductionType reduction,
scalar_t* grad_input_data,
- scalar_t* grad_data,
- scalar_t* output_data,
+ const scalar_t* grad_data,
+ const scalar_t* output_data,
const scalar_t* values_data,
const index_t* lengths_data,
const index_t* lengths_cumsum_data,
@@ -328,8 +328,8 @@
AT_DISPATCH_INDEX_TYPES(
lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
- const auto* lengths_data = lengths.data_ptr<index_t>();
- auto* offsets_data = offsets.data_ptr<index_t>();
+ const auto* lengths_data = lengths.const_data_ptr<index_t>();
+ auto* offsets_data = offsets.const_data_ptr<index_t>();
// TODO: Switch to TensorIterator for better maintainablility and
// readability
@@ -339,10 +339,10 @@
data_contig.scalar_type(),
"_segment_reduce_cpu",
([&]() {
- auto* output_data = output_contig.data_ptr<scalar_t>();
- auto* grad_data = grad_contig.data_ptr<scalar_t>();
- auto* grad_input_data = grad_input.data_ptr<scalar_t>();
- const auto* values_data = data_contig.data_ptr<scalar_t>();
+ auto* output_data = output_contig.const_data_ptr<scalar_t>();
+ auto* grad_data = grad_contig.const_data_ptr<scalar_t>();
+ auto* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
+ const auto* values_data = data_contig.const_data_ptr<scalar_t>();
scalar_t initial_prod_value;
if (initial.has_value()) {
@@ -458,16 +458,16 @@
AT_DISPATCH_INDEX_TYPES(
lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
- auto* offsets_data_ptr = offsets.data_ptr<index_t>();
- auto* lengths_data_ptr = lengths.data_ptr<index_t>();
+ auto* offsets_data_ptr = offsets.const_data_ptr<index_t>();
+ auto* lengths_data_ptr = lengths.const_data_ptr<index_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
data.scalar_type(),
"segment_reduce_cuda",
[&]() {
- auto* data_data_ptr = data.data_ptr<scalar_t>();
- auto* output_data_ptr = output.data_ptr<scalar_t>();
+ auto* data_data_ptr = data.const_data_ptr<scalar_t>();
+ auto* output_data_ptr = output.mutable_data_ptr<scalar_t>();
// initialize starting value
scalar_t initial_value;