make ATen/native/cuda/AdaptiveAveragePooling3d.cu data_ptr-correct (#99324)
make ATen/native/cuda/AdaptiveAveragePooling3d.cu data_ptr-correct
Test Plan: Rely on CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99324
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu
index 4bc45f3..0de39a7 100644
--- a/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveAveragePooling3d.cu
@@ -50,7 +50,7 @@
*/
template <typename scalar_t, typename accscalar_t>
__global__ void adaptiveaveragepool(
- scalar_t *input, scalar_t *output,
+ const scalar_t *input, scalar_t *output,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
int64_t istrideD,
@@ -78,7 +78,7 @@
int kT = iendT - istartT;
// input offset by slice/feature and earliest relevant frame/time
- scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
+ const scalar_t *input_dt = input + d*istrideD + istartT*istrideT;
// output offset by slice/feature and frame/time
scalar_t *output_dt = output + o_plane*osizeH*osizeW;
@@ -94,7 +94,7 @@
int kW = iendW - istartW;
// Compute the average pooling from corresponding input pixels
- scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
+ const scalar_t *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
scalar_t *ptr_output = output_dt + oh*osizeW + ow;
accscalar_t sum = static_cast<accscalar_t>(0);
@@ -117,7 +117,7 @@
template <typename scalar_t, typename accscalar_t>
void adaptiveaveragepool_loop(
- scalar_t *input_data, scalar_t *output_data,
+ const scalar_t *input_data, scalar_t *output_data,
int64_t totalZ,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
@@ -151,7 +151,7 @@
*/
template <typename scalar_t, typename accscalar_t>
__global__ void adaptiveaveragegradinput(
- scalar_t *gradInput, scalar_t *gradOutput,
+ scalar_t *gradInput, const scalar_t *gradOutput,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
int64_t offsetZ)
@@ -179,7 +179,7 @@
// gradInput offset by slice/feature and frame/time.
scalar_t *gradInput_dt = gradInput + i_plane*isizeH*isizeW;
// gradOutput offset by slice/feature and earliest relevant frame/time
- scalar_t *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW;
+ const scalar_t *gradOutput_dt = gradOutput + (d*osizeT + ostartT)*osizeH*osizeW;
// For all input pixels...
for (ih = istartH; ih < iendH; ih += istepH) {
@@ -192,7 +192,7 @@
// Compute the gradients from corresponding output pixels
scalar_t *ptr_gradInput = gradInput_dt + ih*isizeW + iw;
- scalar_t *ptr_gradOutput = gradOutput_dt;
+ const scalar_t *ptr_gradOutput = gradOutput_dt;
// for all relevant output pixels
int ot, oh, ow;
@@ -215,7 +215,7 @@
template <typename scalar_t, typename accscalar_t>
void adaptiveaveragegradinput_loop(
- scalar_t *gradInput_data, scalar_t *gradOutput_data,
+ scalar_t *gradInput_data, const scalar_t *gradOutput_data,
int64_t totalZ,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW) {
@@ -249,7 +249,7 @@
*/
template <typename scalar_t>
__global__ void atomicadaptiveaveragegradinput(
- scalar_t *gradInput, scalar_t *gradOutput,
+ scalar_t *gradInput, const scalar_t *gradOutput,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW,
int64_t offsetZ)
@@ -278,7 +278,7 @@
// gradInput offset by slice/feature and earliest relevant frame/time
scalar_t *gradInput_nt = gradInput + (d*isizeT + istartT)*isizeH*isizeW;
// gradOutput offset by slice/feature and frame/time
- scalar_t *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW;
+ const scalar_t *gradOutput_nt = gradOutput + o_plane*osizeH*osizeW;
// For all output pixels...
for (oh = ostartH; oh < oendH; oh += ostepH) {
@@ -293,7 +293,7 @@
// Compute the gradients from corresponding input pixels
scalar_t *ptr_gradInput = gradInput_nt + istartH*isizeW + istartW;
- scalar_t *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow;
+ const scalar_t *ptr_gradOutput = gradOutput_nt + oh*osizeW + ow;
scalar_t grad_delta = *ptr_gradOutput / kT / kH / kW;
int it, ih, iw;
@@ -311,7 +311,7 @@
template <typename scalar_t>
void atomicadaptiveaveragegradinput_loop(
- scalar_t* gradInput_data, scalar_t* gradOutput_data,
+ scalar_t* gradInput_data, const scalar_t* gradOutput_data,
int64_t totalZ,
int isizeT, int isizeH, int isizeW,
int osizeT, int osizeH, int osizeW) {
@@ -407,8 +407,8 @@
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool3d_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
- scalar_t* input_data = input.data_ptr<scalar_t>();
- scalar_t* output_data = output.data_ptr<scalar_t>();
+ const scalar_t* input_data = input.const_data_ptr<scalar_t>();
+ scalar_t* output_data = output.mutable_data_ptr<scalar_t>();
adaptiveaveragepool_loop<scalar_t, accscalar_t>(
input_data, output_data,
@@ -478,8 +478,8 @@
if (atomic) {
AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
- scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
- scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
+ scalar_t* gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
+ const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
atomicadaptiveaveragegradinput_loop(
gradInput_data, gradOutput_data,
@@ -492,8 +492,8 @@
input.scalar_type(), "adaptive_avg_pool3d_backward_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
- scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
- scalar_t* gradOutput_data = gradOutput.data_ptr<scalar_t>();
+ scalar_t* gradInput_data = gradInput.mutable_data_ptr<scalar_t>();
+ const scalar_t* gradOutput_data = gradOutput.const_data_ptr<scalar_t>();
adaptiveaveragegradinput_loop<scalar_t, accscalar_t>(
gradInput_data, gradOutput_data,