Follow up to adaptive_max_pool3d() port (#19748)
Summary:
This is a follow up PR for #19547.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19748
Differential Revision: D15103230
Pulled By: ezyang
fbshipit-source-id: e7ce925faeadea502f77ed42d52e247c8c6571d8
diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
index 46e3a9e..2586123 100644
--- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
+++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
@@ -22,7 +22,7 @@
// 5d tensor B x D x T x H x W
template <typename scalar_t>
-static void adaptive_max_pool3d_out_frame(
+static void adaptive_max_pool3d_single_out_frame(
scalar_t *input_p,
scalar_t *output_p,
int64_t *ind_p,
@@ -99,6 +99,39 @@
}
}
+template <typename scalar_t>
+static void adaptive_max_pool3d_out_frame(
+ scalar_t *input_data,
+ scalar_t *output_data,
+ int64_t *indices_data,
+ int64_t sizeB,
+ int64_t sizeD,
+ int64_t isizeT,
+ int64_t isizeH,
+ int64_t isizeW,
+ int64_t osizeT,
+ int64_t osizeH,
+ int64_t osizeW,
+ int64_t istrideB,
+ int64_t istrideD,
+ int64_t istrideT,
+ int64_t istrideH,
+ int64_t istrideW)
+{
+ int64_t b;
+#pragma omp parallel for private(b)
+ for (b = 0; b < sizeB; b++)
+ {
+ adaptive_max_pool3d_single_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
+ indices_data+b*sizeD*osizeT*osizeH*osizeW,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW,
+ istrideD, istrideT,
+ istrideH, istrideW);
+ }
+}
+
void adaptive_max_pool3d_out_cpu_template(
Tensor& output,
Tensor& indices,
@@ -172,47 +205,43 @@
auto output_data = output.data<scalar_t>();
auto indices_data = indices.data<int64_t>();
- adaptive_max_pool3d_out_frame<scalar_t>(input_data, output_data,
- indices_data,
- sizeD,
- isizeT, isizeH, isizeW,
- osizeT, osizeH, osizeW,
- istrideD, istrideT,
- istrideH, istrideW);
+ adaptive_max_pool3d_single_out_frame<scalar_t>(input_data, output_data,
+ indices_data,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW,
+ istrideD, istrideT,
+ istrideH, istrideW);
}
);
}
else
{
- int64_t b;
-
output.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});
/* indices will contain max input locations for each output point */
indices.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});
-#pragma omp parallel for private(b)
- for (b = 0; b < sizeB; b++)
- {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
- auto input_data = input.data<scalar_t>();
- auto output_data = output.data<scalar_t>();
- auto indices_data = indices.data<int64_t>();
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
+ auto input_data = input.data<scalar_t>();
+ auto output_data = output.data<scalar_t>();
+ auto indices_data = indices.data<int64_t>();
- adaptive_max_pool3d_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
- indices_data+b*sizeD*osizeT*osizeH*osizeW,
- sizeD,
- isizeT, isizeH, isizeW,
- osizeT, osizeH, osizeW,
- istrideD, istrideT,
- istrideH, istrideW);
- }
- );
- }
+ adaptive_max_pool3d_out_frame<scalar_t>(input_data, output_data,
+ indices_data,
+ sizeB,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW,
+ istrideB,
+ istrideD, istrideT,
+ istrideH, istrideW);
+ }
+ );
}
}
template <typename scalar_t>
-static void adaptive_max_pool3d_backward_out_frame(
+static void adaptive_max_pool3d_backward_single_out_frame(
scalar_t *gradInput_p,
scalar_t *gradOutput_p,
int64_t *ind_p,
@@ -251,6 +280,32 @@
}
}
+template <typename scalar_t>
+static void adaptive_max_pool3d_backward_out_frame(
+ scalar_t *gradInput_data,
+ scalar_t *gradOutput_data,
+ int64_t *indices_data,
+ int64_t sizeB,
+ int64_t sizeD,
+ int64_t isizeT,
+ int64_t isizeH,
+ int64_t isizeW,
+ int64_t osizeT,
+ int64_t osizeH,
+ int64_t osizeW)
+{
+ int64_t b;
+#pragma omp parallel for private(b)
+ for (b = 0; b < sizeB; b++)
+ {
+ adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
+ indices_data+b*sizeD*osizeT*osizeH*osizeW,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW);
+ }
+}
+
Tensor& adaptive_max_pool3d_backward_out_cpu_template(
Tensor& gradInput,
const Tensor& gradOutput_,
@@ -305,36 +360,32 @@
scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
int64_t *indices_data = indices.data<int64_t>();
- adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data, gradOutput_data,
- indices_data,
- sizeD,
- isizeT, isizeH, isizeW,
- osizeT, osizeH, osizeW);
+ adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data, gradOutput_data,
+ indices_data,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW);
}
);
}
else
{
- int64_t b;
-#pragma omp parallel for private(b)
- for (b = 0; b < sizeB; b++)
- {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
- "adaptive_max_pool3d_backward",
- [&] {
- /* get raw pointers */
- scalar_t *gradInput_data = gradInput.data<scalar_t>();
- scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
- int64_t *indices_data = indices.data<int64_t>();
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+ "adaptive_max_pool3d_backward",
+ [&] {
+ /* get raw pointers */
+ scalar_t *gradInput_data = gradInput.data<scalar_t>();
+ scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+ int64_t *indices_data = indices.data<int64_t>();
- adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
- indices_data+b*sizeD*osizeT*osizeH*osizeW,
- sizeD,
- isizeT, isizeH, isizeW,
- osizeT, osizeH, osizeW);
- }
- );
- }
+ adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data, gradOutput_data,
+ indices_data,
+ sizeB,
+ sizeD,
+ isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW);
+ }
+ );
}
return gradInput;
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
index 49f6151..fcdf68d 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
@@ -1,30 +1,31 @@
-#include "ATen/ATen.h"
-#include "ATen/cuda/CUDAApplyUtils.cuh"
-#include "ATen/cuda/CUDAContext.h"
-#include "ATen/NativeFunctions.h"
-#include "ATen/TensorUtils.h"
-#include "ATen/Utils.h"
-#include "c10/util/Exception.h"
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <c10/util/Exception.h>
#include <THC/THCGeneral.h>
-#include "THC/THCNumerics.cuh"
+#include <THC/THCNumerics.cuh>
#include <algorithm>
#include <cfloat>
#include <cmath>
-#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
-
-#define START_IND(a,b,c) (int)std::floor((float)(a * c) / b)
-#define END_IND(a,b,c) (int)std::ceil((float)((a + 1) * c) / b)
-// #define START_IND(a,b,c) a * c / b
-// #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0
-
namespace at {
namespace native {
namespace {
+__device__ inline int start_index(int a, int b, int c) {
+ return (int)std::floor((float)(a * c) / b);
+}
+
+__device__ inline int end_index(int a, int b, int c) {
+ return (int)std::ceil((float)((a + 1) * c) / b);
+}
+
// 5d tensor B x D x T x H x W
/*
@@ -58,8 +59,8 @@
int d = o_plane / osizeT; // slice/feature
// input frame/time ramge is fixed.
- int istartT = START_IND(ot, osizeT, isizeT);
- int iendT = END_IND(ot, osizeT, isizeT);
+ int istartT = start_index(ot, osizeT, isizeT);
+ int iendT = end_index(ot, osizeT, isizeT);
int kT = iendT - istartT;
// input offset by slice/feature and earliest relevant frame/time
@@ -72,14 +73,14 @@
// For all output pixels...
for(oh = ostartH; oh < oendH; oh += ostepH) {
- int istartH = START_IND(oh, osizeH, isizeH);
- int iendH = END_IND(oh, osizeH, isizeH);
+ int istartH = start_index(oh, osizeH, isizeH);
+ int iendH = end_index(oh, osizeH, isizeH);
int kH = iendH - istartH;
for(ow = ostartW; ow < oendW; ow += ostepW) {
- int istartW = START_IND(ow, osizeW, isizeW);
- int iendW = END_IND(ow, osizeW, isizeW);
+ int istartW = start_index(ow, osizeW, isizeW);
+ int iendW = end_index(ow, osizeW, isizeW);
int kW = iendW - istartW;
// Compute the average pooling from corresponding input pixels
@@ -109,6 +110,33 @@
}
}
+template <typename scalar_t>
+void adaptivemaxpool_loop(
+ scalar_t *input_data,
+ scalar_t *output_data,
+ int64_t *indices_data,
+ int64_t totalZ,
+ int isizeT, int isizeH, int isizeW,
+ int osizeT, int osizeH, int osizeW,
+ int64_t istrideD,
+ int64_t istrideT, int64_t istrideH, int64_t istrideW)
+{
+ int64_t offsetZ = 0;
+ dim3 threads(32, 8);
+ // each H*W plane is processed by blocksH thread blocks
+ int blocksH = std::max((int)(16L / totalZ), 1);
+ while (totalZ > 0) {
+ dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+ adaptivemaxpool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+ input_data, output_data, indices_data, isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ THCudaCheck(cudaGetLastError());
+ }
+}
+
/*
* Description:
* This function computes the gradInput from gradOutput.
@@ -162,6 +190,30 @@
}
}
+template <typename scalar_t>
+void adaptivemaxgradinput_loop(
+ scalar_t *gradInput_data,
+ scalar_t *gradOutput_data,
+ int64_t *indices_data,
+ int64_t totalZ,
+ int isizeT, int isizeH, int isizeW,
+ int osizeT, int osizeH, int osizeW)
+{
+ int64_t offsetZ = 0;
+ dim3 threads(32, 8);
+ // each H*W plane is processed by blocksH thread blocks
+ int blocksH = std::max((int)(16L / totalZ), 1);
+ while (totalZ > 0) {
+ dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+ adaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+ gradInput_data, gradOutput_data, indices_data,
+ isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ THCudaCheck(cudaGetLastError());
+ }
+}
/*
* Description:
@@ -215,6 +267,31 @@
}
}
+template <typename scalar_t>
+void atomicadaptivemaxgradinput_loop(
+ scalar_t *gradInput_data,
+ scalar_t *gradOutput_data,
+ int64_t *indices_data,
+ int64_t totalZ,
+ int isizeT, int isizeH, int isizeW,
+ int osizeT, int osizeH, int osizeW)
+{
+ int64_t offsetZ = 0;
+ dim3 threads(32, 8);
+ // each H*W plane is processed by blocksH thread blocks
+ int blocksH = std::max((int)(16L / totalZ), 1);
+ while (totalZ > 0) {
+ dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+ atomicadaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+ gradInput_data, gradOutput_data, indices_data,
+ isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ THCudaCheck(cudaGetLastError());
+ }
+}
+
// 5d tensor B x D x T x H x W
void adaptive_max_pool3d_out_cuda_template(
@@ -286,30 +363,18 @@
totalZ = sizeB * sizeD * osizeT;
}
- int64_t offsetZ = 0;
- dim3 threads(32, 8);
- // each H*W plane is processed by blocksH thread blocks
- int blocksH = std::max((int)(16L / totalZ), 1);
- while (totalZ > 0) {
- dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
- "adaptive_max_pool3d_cuda",
- [&] {
- scalar_t *input_data = input.data<scalar_t>();
- scalar_t *output_data = output.data<scalar_t>();
- int64_t *indices_data = indices.data<int64_t>();
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+ "adaptive_max_pool3d_cuda",
+ [&] {
+ scalar_t *input_data = input.data<scalar_t>();
+ scalar_t *output_data = output.data<scalar_t>();
+ int64_t *indices_data = indices.data<int64_t>();
- adaptivemaxpool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
- input_data, output_data, indices_data, isizeT, isizeH, isizeW,
- osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ
- );
-
- totalZ -= 65535;
- offsetZ += 65535;
- }
- );
- THCudaCheck(cudaGetLastError());
- }
+ adaptivemaxpool_loop(
+ input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
+ osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
+ }
+ );
}
void adaptive_max_pool3d_backward_out_cuda_template(
@@ -364,47 +429,34 @@
totalZ = sizeB * sizeD * osizeT;
}
- int64_t offsetZ = 0;
- dim3 threads(32, 8);
- // each H*W plane is processed by blocksH thread blocks
- int blocksH = std::max((int)(16L / totalZ), 1);
- while (totalZ > 0) {
- dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+ if (atomic) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+ "adaptive_max_pool3d_backward_cuda",
+ [&] {
+ scalar_t *gradInput_data = gradInput.data<scalar_t>();
+ scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+ int64_t *indices_data = indices.data<int64_t>();
- if (atomic)
- {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
- "adaptive_max_pool3d_backward_cuda",
- [&] {
- scalar_t *gradInput_data = gradInput.data<scalar_t>();
- scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
- int64_t *indices_data = indices.data<int64_t>();
+ atomicadaptivemaxgradinput_loop(
+ gradInput_data, gradOutput_data, indices_data,
+ totalZ,
+ isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
+ }
+ );
+ } else {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+ "adaptive_max_pool3d_backward_cuda",
+ [&] {
+ scalar_t *gradInput_data = gradInput.data<scalar_t>();
+ scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+ int64_t *indices_data = indices.data<int64_t>();
- atomicadaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
- gradInput_data, gradOutput_data, indices_data,
- isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ
- );
- }
- );
- } else {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
- "adaptive_max_pool3d_backward_cuda",
- [&] {
- scalar_t *gradInput_data = gradInput.data<scalar_t>();
- scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
- int64_t *indices_data = indices.data<int64_t>();
-
- adaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
- gradInput_data, gradOutput_data, indices_data,
- isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ
- );
- }
- );
- }
-
- totalZ -= 65535;
- offsetZ += 65535;
- THCudaCheck(cudaGetLastError());
+ adaptivemaxgradinput_loop(
+ gradInput_data, gradOutput_data, indices_data,
+ totalZ,
+ isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
+ }
+ );
}
}
@@ -468,7 +520,3 @@
} // at::native
} // at
-
-#undef CUDA_MAX_THREADS
-#undef START_IND
-#undef END_IND