Follow up to adaptive_max_pool3d() port (#19748)

Summary:
This is a follow-up PR to #19547.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19748

Differential Revision: D15103230

Pulled By: ezyang

fbshipit-source-id: e7ce925faeadea502f77ed42d52e247c8c6571d8
diff --git a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
index 46e3a9e..2586123 100644
--- a/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
+++ b/aten/src/ATen/native/AdaptiveMaxPooling3d.cpp
@@ -22,7 +22,7 @@
 // 5d tensor B x D x T x H x W
 
 template <typename scalar_t>
-static void adaptive_max_pool3d_out_frame(
+static void adaptive_max_pool3d_single_out_frame(
           scalar_t *input_p,
           scalar_t *output_p,
           int64_t *ind_p,
@@ -99,6 +99,39 @@
   }
 }
 
+template <typename scalar_t>
+static void adaptive_max_pool3d_out_frame(
+          scalar_t *input_data,
+          scalar_t *output_data,
+          int64_t *indices_data,
+          int64_t sizeB,
+          int64_t sizeD,
+          int64_t isizeT,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeT,
+          int64_t osizeH,
+          int64_t osizeW,
+          int64_t istrideB,
+          int64_t istrideD,
+          int64_t istrideT,
+          int64_t istrideH,
+          int64_t istrideW)
+{
+  int64_t b;
+#pragma omp parallel for private(b)
+  for (b = 0; b < sizeB; b++)
+  {
+    adaptive_max_pool3d_single_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
+                                                   indices_data+b*sizeD*osizeT*osizeH*osizeW,
+                                                   sizeD,
+                                                   isizeT, isizeH, isizeW,
+                                                   osizeT, osizeH, osizeW,
+                                                   istrideD, istrideT,
+                                                   istrideH, istrideW);
+  }
+}
+
 void adaptive_max_pool3d_out_cpu_template(
           Tensor& output,
           Tensor& indices,
@@ -172,47 +205,43 @@
       auto output_data = output.data<scalar_t>();
       auto indices_data = indices.data<int64_t>();
 
-      adaptive_max_pool3d_out_frame<scalar_t>(input_data, output_data,
-                                              indices_data,
-                                              sizeD,
-                                              isizeT, isizeH, isizeW,
-                                              osizeT, osizeH, osizeW,
-                                              istrideD, istrideT,
-                                              istrideH, istrideW);
+      adaptive_max_pool3d_single_out_frame<scalar_t>(input_data, output_data,
+                                                     indices_data,
+                                                     sizeD,
+                                                     isizeT, isizeH, isizeW,
+                                                     osizeT, osizeH, osizeW,
+                                                     istrideD, istrideT,
+                                                     istrideH, istrideW);
       }
     );
   }
   else
   {
-    int64_t b;
-
     output.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});
     /* indices will contain max input locations for each output point */
     indices.resize_({sizeB, sizeD, osizeT, osizeH, osizeW});
 
-#pragma omp parallel for private(b)
-    for (b = 0; b < sizeB; b++)
-    {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
-        auto input_data = input.data<scalar_t>();
-        auto output_data = output.data<scalar_t>();
-        auto indices_data = indices.data<int64_t>();
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "adaptive_max_pool3d_cpu", [&] {
+      auto input_data = input.data<scalar_t>();
+      auto output_data = output.data<scalar_t>();
+      auto indices_data = indices.data<int64_t>();
 
-        adaptive_max_pool3d_out_frame<scalar_t>(input_data+b*istrideB, output_data+b*sizeD*osizeT*osizeH*osizeW,
-                                                indices_data+b*sizeD*osizeT*osizeH*osizeW,
-                                                sizeD,
-                                                isizeT, isizeH, isizeW,
-                                                osizeT, osizeH, osizeW,
-                                                istrideD, istrideT,
-                                                istrideH, istrideW);
-        }
-      );
-    }
+      adaptive_max_pool3d_out_frame<scalar_t>(input_data, output_data,
+                                              indices_data,
+                                              sizeB,
+                                              sizeD,
+                                              isizeT, isizeH, isizeW,
+                                              osizeT, osizeH, osizeW,
+                                              istrideB,
+                                              istrideD, istrideT,
+                                              istrideH, istrideW);
+      }
+    );
   }
 }
 
 template <typename scalar_t>
-static void adaptive_max_pool3d_backward_out_frame(
+static void adaptive_max_pool3d_backward_single_out_frame(
           scalar_t *gradInput_p,
           scalar_t *gradOutput_p,
           int64_t *ind_p,
@@ -251,6 +280,32 @@
   }
 }
 
+template <typename scalar_t>
+static void adaptive_max_pool3d_backward_out_frame(
+          scalar_t *gradInput_data,
+          scalar_t *gradOutput_data,
+          int64_t *indices_data,
+          int64_t sizeB,
+          int64_t sizeD,
+          int64_t isizeT,
+          int64_t isizeH,
+          int64_t isizeW,
+          int64_t osizeT,
+          int64_t osizeH,
+          int64_t osizeW)
+{
+  int64_t b;
+#pragma omp parallel for private(b)
+  for (b = 0; b < sizeB; b++)
+  {
+    adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
+                                                            indices_data+b*sizeD*osizeT*osizeH*osizeW,
+                                                            sizeD,
+                                                            isizeT, isizeH, isizeW,
+                                                            osizeT, osizeH, osizeW);
+  }
+}
+
 Tensor& adaptive_max_pool3d_backward_out_cpu_template(
           Tensor& gradInput,
           const Tensor& gradOutput_,
@@ -305,36 +360,32 @@
         scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
         int64_t *indices_data = indices.data<int64_t>();
 
-        adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data, gradOutput_data,
-                                                         indices_data,
-                                                         sizeD,
-                                                         isizeT, isizeH, isizeW,
-                                                         osizeT, osizeH, osizeW);
+        adaptive_max_pool3d_backward_single_out_frame<scalar_t>(gradInput_data, gradOutput_data,
+                                                                indices_data,
+                                                                sizeD,
+                                                                isizeT, isizeH, isizeW,
+                                                                osizeT, osizeH, osizeW);
       }
     );
   }
   else
   {
-    int64_t b;
-#pragma omp parallel for private(b)
-    for (b = 0; b < sizeB; b++)
-    {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
-        "adaptive_max_pool3d_backward",
-        [&] {
-          /* get raw pointers */
-          scalar_t *gradInput_data = gradInput.data<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
-          int64_t *indices_data = indices.data<int64_t>();
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+      "adaptive_max_pool3d_backward",
+      [&] {
+        /* get raw pointers */
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+        int64_t *indices_data = indices.data<int64_t>();
 
-          adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data+b*sizeD*isizeT*isizeH*isizeW, gradOutput_data+b*sizeD*osizeT*osizeH*osizeW,
-                                                           indices_data+b*sizeD*osizeT*osizeH*osizeW,
-                                                           sizeD,
-                                                           isizeT, isizeH, isizeW,
-                                                           osizeT, osizeH, osizeW);
-        }
-      );
-    }
+        adaptive_max_pool3d_backward_out_frame<scalar_t>(gradInput_data, gradOutput_data,
+                                                         indices_data,
+                                                         sizeB,
+                                                         sizeD,
+                                                         isizeT, isizeH, isizeW,
+                                                         osizeT, osizeH, osizeW);
+      }
+    );
   }
 
   return gradInput;
diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
index 49f6151..fcdf68d 100644
--- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
+++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling3d.cu
@@ -1,30 +1,31 @@
-#include "ATen/ATen.h"
-#include "ATen/cuda/CUDAApplyUtils.cuh"
-#include "ATen/cuda/CUDAContext.h"
-#include "ATen/NativeFunctions.h"
-#include "ATen/TensorUtils.h"
-#include "ATen/Utils.h"
-#include "c10/util/Exception.h"
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/Utils.h>
+#include <c10/util/Exception.h>
 #include <THC/THCGeneral.h>
-#include "THC/THCNumerics.cuh"
+#include <THC/THCNumerics.cuh>
 
 #include <algorithm>
 #include <cfloat>
 #include <cmath>
 
-#define CUDA_MAX_THREADS 1024   // this is safe, in reality 256 is our limit
-
-#define START_IND(a,b,c) (int)std::floor((float)(a * c) / b)
-#define END_IND(a,b,c) (int)std::ceil((float)((a + 1) * c) / b)
-// #define START_IND(a,b,c) a * c / b
-// #define END_IND(a,b,c)  (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0
-
 
 namespace at {
 namespace native {
 
 namespace {
 
+__device__ inline int start_index(int a, int b, int c) {
+  return (int)std::floor((float)(a * c) / b);
+}
+
+__device__ inline int end_index(int a, int b, int c) {
+  return (int)std::ceil((float)((a + 1) * c) / b);
+}
+
 // 5d tensor B x D x T x H x W
 
 /*
@@ -58,8 +59,8 @@
   int d = o_plane / osizeT;  // slice/feature
 
   // input frame/time ramge is fixed.
-  int istartT = START_IND(ot, osizeT, isizeT);
-  int iendT = END_IND(ot, osizeT, isizeT);
+  int istartT = start_index(ot, osizeT, isizeT);
+  int iendT = end_index(ot, osizeT, isizeT);
   int kT = iendT - istartT;
 
   // input offset by slice/feature and earliest relevant frame/time
@@ -72,14 +73,14 @@
   // For all output pixels...
   for(oh = ostartH; oh < oendH; oh += ostepH) {
 
-    int istartH = START_IND(oh, osizeH, isizeH);
-    int iendH   = END_IND(oh, osizeH, isizeH);
+    int istartH = start_index(oh, osizeH, isizeH);
+    int iendH   = end_index(oh, osizeH, isizeH);
     int kH = iendH - istartH;
 
     for(ow = ostartW; ow < oendW; ow += ostepW) {
 
-      int istartW = START_IND(ow, osizeW, isizeW);
-      int iendW   = END_IND(ow, osizeW, isizeW);
+      int istartW = start_index(ow, osizeW, isizeW);
+      int iendW   = end_index(ow, osizeW, isizeW);
       int kW = iendW - istartW;
 
       // Compute the average pooling from corresponding input pixels
@@ -109,6 +110,33 @@
   }
 }
 
+template <typename scalar_t>
+void adaptivemaxpool_loop(
+                        scalar_t *input_data,
+                        scalar_t *output_data,
+                        int64_t *indices_data,
+                        int64_t totalZ,
+                        int isizeT, int isizeH, int isizeW,
+                        int osizeT, int osizeH, int osizeW,
+                        int64_t istrideD,
+                        int64_t istrideT, int64_t istrideH, int64_t istrideW)
+{
+  int64_t offsetZ = 0;
+  dim3 threads(32, 8);
+  // each H*W plane is processed by blocksH thread blocks
+  int blocksH = std::max((int)(16L / totalZ), 1);
+  while (totalZ > 0) {
+    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+    adaptivemaxpool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+      input_data, output_data, indices_data, isizeT, isizeH, isizeW,
+      osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ);
+
+    totalZ -= 65535;
+    offsetZ += 65535;
+    THCudaCheck(cudaGetLastError());
+  }
+}
+
 /*
  * Description:
  *    This function computes the gradInput from gradOutput.
@@ -162,6 +190,30 @@
   }
 }
 
+template <typename scalar_t>
+void adaptivemaxgradinput_loop(
+  scalar_t *gradInput_data,
+  scalar_t *gradOutput_data,
+  int64_t *indices_data,
+  int64_t totalZ,
+  int isizeT, int isizeH, int isizeW,
+  int osizeT, int osizeH, int osizeW)
+{
+  int64_t offsetZ = 0;
+  dim3 threads(32, 8);
+  // each H*W plane is processed by blocksH thread blocks
+  int blocksH = std::max((int)(16L / totalZ), 1);
+  while (totalZ > 0) {
+    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+    adaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+      gradInput_data, gradOutput_data, indices_data,
+      isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ);
+
+    totalZ -= 65535;
+    offsetZ += 65535;
+    THCudaCheck(cudaGetLastError());
+  }
+}
 
 /*
  * Description:
@@ -215,6 +267,31 @@
   }
 }
 
+template <typename scalar_t>
+void atomicadaptivemaxgradinput_loop(
+  scalar_t *gradInput_data,
+  scalar_t *gradOutput_data,
+  int64_t *indices_data,
+  int64_t totalZ,
+  int isizeT, int isizeH, int isizeW,
+  int osizeT, int osizeH, int osizeW)
+{
+  int64_t offsetZ = 0;
+  dim3 threads(32, 8);
+  // each H*W plane is processed by blocksH thread blocks
+  int blocksH = std::max((int)(16L / totalZ), 1);
+  while (totalZ > 0) {
+    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+    atomicadaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+      gradInput_data, gradOutput_data, indices_data,
+      isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ);
+
+    totalZ -= 65535;
+    offsetZ += 65535;
+    THCudaCheck(cudaGetLastError());
+  }
+}
+
 // 5d tensor B x D x T x H x W
 
 void adaptive_max_pool3d_out_cuda_template(
@@ -286,30 +363,18 @@
     totalZ = sizeB * sizeD * osizeT;
   }
 
-  int64_t offsetZ = 0;
-  dim3 threads(32, 8);
-  // each H*W plane is processed by blocksH thread blocks
-  int blocksH = std::max((int)(16L / totalZ), 1);
-  while (totalZ > 0) {
-    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
-      "adaptive_max_pool3d_cuda",
-      [&] {
-        scalar_t *input_data = input.data<scalar_t>();
-        scalar_t *output_data = output.data<scalar_t>();
-        int64_t *indices_data = indices.data<int64_t>();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+    "adaptive_max_pool3d_cuda",
+    [&] {
+      scalar_t *input_data = input.data<scalar_t>();
+      scalar_t *output_data = output.data<scalar_t>();
+      int64_t *indices_data = indices.data<int64_t>();
 
-        adaptivemaxpool<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-          input_data, output_data, indices_data, isizeT, isizeH, isizeW,
-          osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW, offsetZ
-      );
-
-      totalZ -= 65535;
-      offsetZ += 65535;
-      }
-    );
-    THCudaCheck(cudaGetLastError());
-  }
+      adaptivemaxpool_loop(
+        input_data, output_data, indices_data, totalZ, isizeT, isizeH, isizeW,
+        osizeT, osizeH, osizeW, istrideD, istrideT, istrideH, istrideW);
+    }
+  );
 }
 
 void adaptive_max_pool3d_backward_out_cuda_template(
@@ -364,47 +429,34 @@
     totalZ = sizeB * sizeD * osizeT;
   }
 
-  int64_t offsetZ = 0;
-  dim3 threads(32, 8);
-  // each H*W plane is processed by blocksH thread blocks
-  int blocksH = std::max((int)(16L / totalZ), 1);
-  while (totalZ > 0) {
-    dim3 blocks(totalZ > 65535 ? 65535 : totalZ, blocksH);
+  if (atomic) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+      "adaptive_max_pool3d_backward_cuda",
+      [&] {
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+        int64_t *indices_data = indices.data<int64_t>();
 
-    if (atomic)
-    {
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
-        "adaptive_max_pool3d_backward_cuda",
-        [&] {
-          scalar_t *gradInput_data = gradInput.data<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
-          int64_t *indices_data = indices.data<int64_t>();
+        atomicadaptivemaxgradinput_loop(
+          gradInput_data, gradOutput_data, indices_data,
+          totalZ,
+          isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
+      }
+    );
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+      "adaptive_max_pool3d_backward_cuda",
+      [&] {
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+        int64_t *indices_data = indices.data<int64_t>();
 
-          atomicadaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-            gradInput_data, gradOutput_data, indices_data,
-            isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ
-          );
-        }
-      );
-    } else {
-      AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
-        "adaptive_max_pool3d_backward_cuda",
-        [&] {
-          scalar_t *gradInput_data = gradInput.data<scalar_t>();
-          scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
-          int64_t *indices_data = indices.data<int64_t>();
-
-          adaptivemaxgradinput<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-            gradInput_data, gradOutput_data, indices_data,
-            isizeT, isizeH, isizeW, osizeT, osizeH, osizeW, offsetZ
-          );
-        }
-      );
-    }
-
-    totalZ -= 65535;
-    offsetZ += 65535;
-    THCudaCheck(cudaGetLastError());
+        adaptivemaxgradinput_loop(
+          gradInput_data, gradOutput_data, indices_data,
+          totalZ,
+          isizeT, isizeH, isizeW, osizeT, osizeH, osizeW);
+      }
+    );
   }
 }
 
@@ -468,7 +520,3 @@
 
 } // at::native
 } // at
-
-#undef CUDA_MAX_THREADS
-#undef START_IND
-#undef END_IND