Port avg_pool3d() to ATen (#21732)

Summary:
This will need conflict resolution once the avg_pool2d() port has been merged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21732

Differential Revision: D15824923

Pulled By: ezyang

fbshipit-source-id: 83341e0209b660aecf788272079d8135d78b6ff1
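
For reference, the operator's signature is unchanged by this port. A minimal
smoke test against the new kernels (a sketch for illustration, not part of
this change) might look like:

    #include <ATen/ATen.h>

    int main() {
      // NCDHW input; both the 4D and the 5D (batch) paths are ported.
      at::Tensor input = at::randn({1, 3, 8, 8, 8});
      // stride defaults to kernel_size and padding to 0.
      at::Tensor out = at::avg_pool3d(input, /*kernel_size=*/{2, 2, 2});
      // out.sizes() == [1, 3, 4, 4, 4]
      return 0;
    }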
diff --git a/aten/src/ATen/native/AveragePool3d.cpp b/aten/src/ATen/native/AveragePool3d.cpp
new file mode 100644
index 0000000..ff9fdf6
--- /dev/null
+++ b/aten/src/ATen/native/AveragePool3d.cpp
@@ -0,0 +1,488 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/Pool.h>
+#include <tuple>
+
+namespace at {
+namespace native {
+
+namespace {
+
+template <typename scalar_t>
+static void avg_pool3d_out_frame(
+          scalar_t *input_p,
+          scalar_t *output_p,
+          int64_t nslices,
+          int64_t itime,
+          int64_t iwidth,
+          int64_t iheight,
+          int64_t otime,
+          int64_t owidth,
+          int64_t oheight,
+          int kT,
+          int kW,
+          int kH,
+          int dT,
+          int dW,
+          int dH,
+          int padT,
+          int padW,
+          int padH,
+          bool count_include_pad)
+{
+  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
+    for (auto k = start; k < end; k++)
+    {
+      int64_t i, j, ti;
+
+      /* local pointers. */
+      scalar_t *ip = input_p + k * itime * iwidth * iheight;
+      scalar_t *op = output_p + k * otime * owidth * oheight;
+      for (i = 0; i < otime * oheight * owidth; ++i)
+        *(op + i) = 0;
+
+      /* loop over output */
+      for (ti = 0; ti < otime; ti++)
+      {
+        for (i = 0; i < oheight; i++)
+        {
+          for (j = 0; j < owidth; j++)
+          {
+            /* compute pool range. */
+            int64_t tstart = ti * dT - padT;
+            int64_t hstart = i  * dH - padH;
+            int64_t wstart = j  * dW - padW;
+            int64_t tend = std::min(tstart + kT, itime + padT);
+            int64_t hend = std::min(hstart + kH, iheight + padH);
+            int64_t wend = std::min(wstart + kW, iwidth + padW);
+            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+            tstart = std::max(tstart, (int64_t) 0);
+            hstart = std::max(hstart, (int64_t) 0);
+            wstart = std::max(wstart, (int64_t) 0);
+            tend = std::min(tend, itime);
+            hend = std::min(hend, iheight);
+            wend = std::min(wend, iwidth);
+
+            int64_t divide_factor;
+            if (count_include_pad)
+              divide_factor = pool_size;
+            else
+              divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
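+            // pool_size was computed before clamping and therefore counts
+            // padded cells; the else-branch product, computed after clamping,
+            // counts only real input cells.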
+
+            /* compute local sum: */
+            scalar_t sum = 0.0;
+            int64_t x, y, z;
+
+            for (z = tstart; z < tend; z++)
+            {
+              for (y = hstart; y < hend; y++)
+              {
+                for (x = wstart; x < wend; x++)
+                {
+                  sum += *(ip + z * iheight * iwidth + y * iwidth + x);
+                }
+              }
+            }
+
+            /* set output to local average */
+            *op++ += sum / divide_factor;
+          }
+        }
+      }
+    }
+  });
+}
+
+void avg_pool3d_out_cpu_template(
+  Tensor& output,
+  const Tensor& input_,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  // #20866 [JIT] stride.empty() is passed through
+  // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+  TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+                        (stride.empty() || stride.size() == 3) &&
+                        (padding.size() == 1 || padding.size() == 3),
+    "avg_pool3d: all IntArrayRef sizes must be 3");
+
+  TORCH_CHECK((input_.ndimension() == 4 || input_.ndimension() == 5),
+    "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+  const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+  const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+  const int padT = safe_downcast<int, int64_t>(padding[0]);
+  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+  const int64_t nslices = input_.size(-4);
+  const int64_t itime = input_.size(-3);
+  const int64_t iheight = input_.size(-2);
+  const int64_t iwidth = input_.size(-1);
+
+  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+  pool3d_shape_check(
+    input_,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    padT, padH, padW,
+    1, 1, 1,
+    itime, iheight, iwidth,
+    otime, oheight, owidth,
+    /*check_input_size=*/ true);
+
+  /* get contiguous input */
+  Tensor input = input_.contiguous();
+
+  if (input.ndimension() == 4) /* non-batch mode */
+  {
+    /* resize output */
+    output.resize_({nslices, otime, oheight, owidth});
+
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+      "avg_pool3d_out_frame",
+      [&] {
+        scalar_t *input_data = input.data<scalar_t>();
+        scalar_t *output_data = output.data<scalar_t>();
+
+        avg_pool3d_out_frame(
+          input_data, output_data, nslices,
+          itime, iwidth, iheight,
+          otime, owidth, oheight,
+          kT, kW, kH,
+          dT, dW, dH,
+          padT, padW, padH,
+          count_include_pad);
+    });
+  }
+  else  /* batch mode */
+  {
+    const int64_t nbatch = input.size(0);
+    const int64_t istride = nslices * itime * iwidth * iheight;
+    const int64_t ostride = nslices * otime * owidth * oheight;
+
+    /* resize output */
+    output.resize_({nbatch, nslices, otime, oheight, owidth});
+
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+      "avg_pool3d_out_frame",
+      [&] {
+        scalar_t *input_data = input.data<scalar_t>();
+        scalar_t *output_data = output.data<scalar_t>();
+
+        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
+          for (auto p = start; p < end; p++) {
+            avg_pool3d_out_frame(
+              input_data + p * istride, output_data + p * ostride, nslices,
+              itime, iwidth, iheight,
+              otime, owidth, oheight,
+              kT, kW, kH,
+              dT, dW, dH,
+              padT, padW, padH,
+              count_include_pad
+            );
+          }
+        });
+    });
+  }
+}
+
+template <typename scalar_t>
+static void avg_pool3d_backward_out_frame(
+          scalar_t *gradInput_p,
+          scalar_t *gradOutput_p,
+          int64_t nslices,
+          int64_t itime,
+          int64_t iwidth,
+          int64_t iheight,
+          int64_t otime,
+          int64_t owidth,
+          int64_t oheight,
+          int kT,
+          int kW,
+          int kH,
+          int dT,
+          int dW,
+          int dH,
+          int padT,
+          int padW,
+          int padH,
+          bool count_include_pad)
+{
+  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
+    for (auto k = start; k < end; k++)
+    {
+      int64_t i, j, ti;
+
+      /* local pointers */
+      scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
+      scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
+      for (i = 0; i < itime * iwidth * iheight; i++)
+        *(ip + i) = 0;
+
+      /* loop over output */
+      for (ti = 0; ti < otime; ti++)
+      {
+        for (i = 0; i < oheight; i++)
+        {
+          for (j = 0; j < owidth; j++)
+          {
+            int64_t tstart = ti * dT - padT;
+            int64_t hstart = i  * dH - padH;
+            int64_t wstart = j  * dW - padW;
+            int64_t tend = std::min(tstart + kT, itime + padT);
+            int64_t hend = std::min(hstart + kH, iheight + padH);
+            int64_t wend = std::min(wstart + kW, iwidth + padW);
+            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+            tstart = std::max(tstart, (int64_t) 0);
+            hstart = std::max(hstart, (int64_t) 0);
+            wstart = std::max(wstart, (int64_t) 0);
+            tend = std::min(tend, itime);
+            hend = std::min(hend, iheight);
+            wend = std::min(wend, iwidth);
+
+            int64_t divide_factor;
+            if (count_include_pad)
+              divide_factor = pool_size;
+            else
+              divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
+
+            /* scatter gradients out to footprint: */
+            scalar_t val = *op++;
+
+            int64_t x, y, z;
+            for (z = tstart; z < tend; z++)
+            {
+              for (y = hstart; y < hend; y++)
+              {
+                for (x = wstart; x < wend; x++)
+                {
+                  *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  });
+}
+
+Tensor& avg_pool3d_backward_out_cpu_template(
+  Tensor& gradInput,
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  // #20866 [JIT] stride.empty() is passed through
+  // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+  TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+                        (stride.empty() || stride.size() == 3) &&
+                        (padding.size() == 1 || padding.size() == 3),
+    "avg_pool3d: all IntArrayRef sizes must be 3");
+
+  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+    "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+  const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+  const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+  const int padT = safe_downcast<int, int64_t>(padding[0]);
+  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+  const int64_t nslices = input.size(-4);
+  const int64_t itime = input.size(-3);
+  const int64_t iheight = input.size(-2);
+  const int64_t iwidth = input.size(-1);
+
+  /* get contiguous gradOutput */
+  Tensor gradOutput = gradOutput_.contiguous();
+
+  const int64_t otime = gradOutput.size(-3);
+  const int64_t oheight = gradOutput.size(-2);
+  const int64_t owidth = gradOutput.size(-1);
+
+  /* XXX shape check behavior from TH */
+  const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+  const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+  const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+  avg_pool3d_backward_shape_check(
+    input,
+    gradOutput_,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    padT, padH, padW,
+    itime, iheight, iwidth,
+    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check);
+
+  /* resize */
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  /* backprop */
+  if (input.ndimension() == 4) /* non-batch mode */
+  {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+      "avg_pool3d_backward_out_frame",
+      [&] {
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+        avg_pool3d_backward_out_frame(
+          gradInput_data, gradOutput_data,
+          nslices,
+          itime, iwidth, iheight,
+          otime, owidth, oheight,
+          kT, kW, kH,
+          dT, dW, dH,
+          padT, padW, padH,
+          count_include_pad);
+    });
+  }
+  else /* batch mode */
+  {
+    const int64_t nbatch = input.size(0);
+    const int64_t istride = nslices * itime * iwidth * iheight;
+    const int64_t ostride = nslices * otime * owidth * oheight;
+
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+      "avg_pool3d_backward_out_frame",
+      [&] {
+        scalar_t *gradInput_data = gradInput.data<scalar_t>();
+        scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
+          for (auto p = start; p < end; p++)
+          {
+            avg_pool3d_backward_out_frame(
+              gradInput_data + p * istride, gradOutput_data + p * ostride, nslices,
+              itime, iwidth, iheight,
+              otime, owidth, oheight,
+              kT, kW, kH,
+              dT, dW, dH,
+              padT, padW, padH,
+              count_include_pad
+            );
+          }
+        });
+    });
+  }
+
+  return gradInput;
+}
+
+} // namespace
+
+Tensor& avg_pool3d_out_cpu(
+  Tensor& output,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  avg_pool3d_out_cpu_template(
+    output,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return output;
+}
+
+Tensor avg_pool3d_cpu(
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  Tensor output = at::empty({0}, input.options());
+  avg_pool3d_out_cpu_template(
+    output,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return output;
+}
+
+Tensor& avg_pool3d_backward_out_cpu(
+  Tensor& gradInput,
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  avg_pool3d_backward_out_cpu_template(
+    gradInput,
+    gradOutput_,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return gradInput;
+}
+
+Tensor avg_pool3d_backward_cpu(
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  auto gradInput = at::zeros_like(input);
+  avg_pool3d_backward_out_cpu_template(
+    gradInput,
+    gradOutput_,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return gradInput;
+}
+
+} // at::native
+} // at
diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp
index 5d4ffb3..c3b3290 100644
--- a/aten/src/ATen/native/DilatedMaxPool3d.cpp
+++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp
@@ -184,7 +184,7 @@
   const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
   const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);
 
-  max_pool3d_with_indices_shape_check(
+  pool3d_shape_check(
     input_,
     nslices,
     kT, kH, kW,
@@ -396,7 +396,7 @@
   const int64_t oheight = gradOutput.size(-2);
   const int64_t owidth = gradOutput.size(-1);
 
-  max_pool3d_with_indices_shape_check(
+  max_pool3d_backward_shape_check(
     input,
     gradOutput,
     indices,
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index c204722..d9169ef 100644
--- a/aten/src/ATen/native/Pool.h
+++ b/aten/src/ATen/native/Pool.h
@@ -132,8 +132,9 @@
   check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
 }
 
+// AveragePool3d/DilatedMaxPool3d (forward)
 static inline void
-max_pool3d_with_indices_shape_check(
+pool3d_shape_check(
   const Tensor& input,
   int64_t nslices,
   int kT, int kH, int kW,
@@ -141,7 +142,8 @@
   int pT, int pH, int pW,
   int dilationT, int dilationH, int dilationW,
   int64_t itime, int64_t iheight, int64_t iwidth,
-  int64_t otime, int64_t oheight, int64_t owidth)
+  int64_t otime, int64_t oheight, int64_t owidth,
+  bool check_input_size=false)
 {
   const int64_t ndim = input.ndimension();
 
@@ -158,6 +160,12 @@
   TORCH_CHECK(input.numel() > 0 && (ndim == 4 || ndim == 5),
               "non-empty 4D or 5D (batch mode) tensor expected for input, but got ndim: ", ndim);
 
+  if (check_input_size) { // AveragePool3d
+    TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
+                "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
+                "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
+  }
+
   TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
               "pad should be smaller than half of kernel size, but got "
               "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
@@ -171,7 +179,7 @@
 }
 
 static inline void
-max_pool3d_with_indices_shape_check(
+max_pool3d_backward_shape_check(
   const Tensor& input,
   const Tensor& gradOutput,
   const Tensor& indices,
@@ -185,7 +193,7 @@
 {
   const int64_t ndim = input.ndimension();
 
-  max_pool3d_with_indices_shape_check(
+  pool3d_shape_check(
     input,
     nslices,
     kT, kH, kW,
@@ -206,6 +214,36 @@
   check_dim_size(indices, ndim, ndim-1, owidth);
 }
 
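+// AveragePool3d (backward)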
+static inline void
+avg_pool3d_backward_shape_check(
+  const Tensor& input,
+  const Tensor& gradOutput,
+  int64_t nslices,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int pT, int pH, int pW,
+  int64_t itime, int64_t iheight, int64_t iwidth,
+  int64_t otime, int64_t oheight, int64_t owidth)
+{
+  const int64_t ndim = input.ndimension();
+
+  pool3d_shape_check(
+    input,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    pT, pH, pW,
+    1, 1, 1,
+    itime, iheight, iwidth,
+    otime, oheight, owidth,
+    true);
+
+  check_dim_size(gradOutput, ndim, ndim-4, nslices);
+  check_dim_size(gradOutput, ndim, ndim-3, otime);
+  check_dim_size(gradOutput, ndim, ndim-2, oheight);
+  check_dim_size(gradOutput, ndim, ndim-1, owidth);
+}
+
 } // namespace
 
 } // at::native
diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu
new file mode 100644
index 0000000..439d0ac
--- /dev/null
+++ b/aten/src/ATen/native/cuda/AveragePool3d.cu
@@ -0,0 +1,675 @@
+#include <ATen/AccumulateType.h>
+#include <ATen/native/Pool.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
+
+namespace at {
+namespace native {
+namespace {
+
+__device__ inline int min(int a, int b) {
+  return a <= b ? a : b;
+}
+
+__device__ inline int max(int a, int b) {
+  return a >= b ? a : b;
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_output(
+  PackedTensorAccessor<scalar_t, 4> input,
+  PackedTensorAccessor<scalar_t, 4> output,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int padT, int padH, int padW,
+  bool count_include_pad,
+  int offsetZ)
+{
+  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
+  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
+  int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time
+  int slice  = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature
+
+  if (oRow < output.size(2) && oCol < output.size(3))
+  {
+    accscalar_t sum = 0.0;
+
+    int tstart = oFrame * dT - padT;
+    int hstart = oRow   * dH - padH;
+    int wstart = oCol   * dW - padW;
+    int tend = min(tstart + kT, input.size(1) + padT);
+    int hend = min(hstart + kH, input.size(2) + padH);
+    int wend = min(wstart + kW, input.size(3) + padW);
+    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+    tstart = max(tstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    tend = min(tend, input.size(1));
+    hend = min(hend, input.size(2));
+    wend = min(wend, input.size(3));
+
+    accscalar_t divide_factor;
+    if (count_include_pad)
+      divide_factor = static_cast<accscalar_t>(pool_size);
+    else
+      divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+    int ti, hi, wi;
+    for (ti = tstart; ti < tend; ++ti)
+    {
+      for (hi = hstart; hi < hend; ++hi)
+      {
+        for (wi = wstart; wi < wend; ++wi)
+        {
+          scalar_t val = input[slice][ti][hi][wi];
+          sum += val;
+        }
+      }
+    }
+
+    output[slice][oFrame][oRow][oCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum / divide_factor);
+  }
+}
+
+// Inner-most loop size (kW) passed as template parameter for
+// performance reasons.
+//
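+// (A compile-time kW lets the compiler fully unroll the inner-most loop;
+// the launch macro below switches over widths 1-7 and falls back to the
+// generic kernel otherwise.)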
+template<int KERNEL_WIDTH, typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_output(
+  PackedTensorAccessor<scalar_t, 4> input,
+  PackedTensorAccessor<scalar_t, 4> output,
+  int kT, int kH,
+  int dT, int dH, int dW,
+  int padT, int padH, int padW,
+  bool count_include_pad,
+  int offsetZ)
+{
+  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
+  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
+  int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time
+  int slice  = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature
+
+  if (oRow < output.size(2) && oCol < output.size(3))
+  {
+    accscalar_t sum = 0.0;
+
+    int tstart = oFrame * dT - padT;
+    int hstart = oRow   * dH - padH;
+    int wstart = oCol   * dW - padW;
+    int tend = min(tstart + kT, input.size(1) + padT);
+    int hend = min(hstart + kH, input.size(2) + padH);
+    int wend = min(wstart + KERNEL_WIDTH, input.size(3) + padW);
+    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+    tstart = max(tstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    tend = min(tend, input.size(1));
+    hend = min(hend, input.size(2));
+    wend = min(wend, input.size(3));
+
+    accscalar_t divide_factor;
+    if (count_include_pad)
+      divide_factor = static_cast<accscalar_t>(pool_size);
+    else
+      divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+    int ti, hi, wi;
+    for (ti = tstart; ti < tend; ++ti)
+    {
+      for (hi = hstart; hi < hend; ++hi)
+      {
+        for (wi = wstart; wi < wend; ++wi)
+        {
+          scalar_t val = input[slice][ti][hi][wi];
+          sum += val;
+        }
+      }
+    }
+
+    output[slice][oFrame][oRow][oCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum / divide_factor);
+  }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_single_backward_out_frame_stride1(
+  PackedTensorAccessor<scalar_t, 4> gradOutput,
+  PackedTensorAccessor<scalar_t, 4> gradInput,
+  int kT, int kH, int kW,
+  accscalar_t normFactor,
+  int offsetZ)
+{
+  int iCol   = blockIdx.x * blockDim.x + threadIdx.x;
+  int iRow   = blockIdx.y * blockDim.y + threadIdx.y;
+  int iFrame = (blockIdx.z + offsetZ) % gradInput.size(1); // input frame/time
+  int slice  = (blockIdx.z + offsetZ) / gradInput.size(1); // input slice/feature
+
+  // guard against over-tiled threads
+  if (iRow < gradInput.size(2) && iCol < gradInput.size(3))
+  {
+    accscalar_t sum = 0.0;
+    scalar_t *gOut = &gradOutput[slice][max(0, iFrame - kT + 1)]
+      [max(0, iRow - kH + 1)][max(0, iCol - kW + 1)];
+    int frameOffset = 0;
+    for (int oFrame  = max(0, iFrame - kT + 1);
+         oFrame < min(iFrame + 1, gradOutput.size(1));
+         ++oFrame)
+    {
+      int rowOffset = frameOffset;
+      for (int oRow = max(0, iRow - kH + 1);
+           oRow < min(iRow + 1, gradOutput.size(2));
+           ++oRow)
+      {
+        int colOffset = rowOffset;
+        for (int oCol = max(0, iCol - kW + 1);
+             oCol < min(iCol + 1, gradOutput.size(3));
+             ++oCol)
+        {
+          sum += gOut[colOffset];
+          ++colOffset;
+        }
+        rowOffset += gradOutput.size(3);
+      }
+      frameOffset += gradOutput.size(2) * gradOutput.size(3);
+    }
+    gradInput[slice][iFrame][iRow][iCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum * normFactor);
+  }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_grad_input_atomic(
+  PackedTensorAccessor<scalar_t, 4> gradOutput,
+  PackedTensorAccessor<scalar_t, 4> gradInput,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int padT, int padH, int padW,
+  bool count_include_pad,
+  int offsetZ)
+{
+  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
+  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
+  int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time
+  int slice  = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature
+
+  // guard against over-tiled threads
+  if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3))
+  {
+    int tstart = oFrame * dT - padT;
+    int hstart = oRow   * dH - padH;
+    int wstart = oCol   * dW - padW;
+    int tend = min(tstart + kT, gradInput.size(1) + padT);
+    int hend = min(hstart + kH, gradInput.size(2) + padH);
+    int wend = min(wstart + kW, gradInput.size(3) + padW);
+    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+    tstart = max(tstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    tend = min(tend, gradInput.size(1));
+    hend = min(hend, gradInput.size(2));
+    wend = min(wend, gradInput.size(3));
+
+    accscalar_t divide_factor;
+    if (count_include_pad)
+      divide_factor = static_cast<accscalar_t>(pool_size);
+    else
+      divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+    scalar_t val = ScalarConvert<accscalar_t, scalar_t>::to(
+      ScalarConvert<scalar_t, accscalar_t>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
+    for (int iFrame = tstart; iFrame < tend; ++iFrame)
+    {
+      for (int iRow = hstart; iRow < hend; ++iRow)
+      {
+        for (int iCol = wstart; iCol < wend; ++iCol)
+        {
+          atomicAdd(&gradInput[slice][iFrame][iRow][iCol], val);
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_grad_input(
+  PackedTensorAccessor<scalar_t, 4> gradOutput,
+  PackedTensorAccessor<scalar_t, 4> gradInput,
+  int kT, int kH, int kW,
+  int dT, int dH, int dW,
+  int padT, int padH, int padW,
+  bool count_include_pad, int offsetZ)
+{
+  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
+  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
+  int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time
+  int slice  = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature
+
+  // guard against over-tiled threads
+  if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3))
+  {
+    int tstart = oFrame * dT - padT;
+    int hstart = oRow   * dH - padH;
+    int wstart = oCol   * dW - padW;
+    int tend = min(tstart + kT, gradInput.size(1) + padT);
+    int hend = min(hstart + kH, gradInput.size(2) + padH);
+    int wend = min(wstart + kW, gradInput.size(3) + padW);
+    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+    tstart = max(tstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    tend = min(tend, gradInput.size(1));
+    hend = min(hend, gradInput.size(2));
+    wend = min(wend, gradInput.size(3));
+
+    accscalar_t divide_factor;
+    if (count_include_pad)
+      divide_factor = static_cast<accscalar_t>(pool_size);
+    else
+      divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+    scalar_t val = ScalarConvert<accscalar_t, scalar_t>::to(
+      ScalarConvert<scalar_t, accscalar_t>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
+    for (int iFrame = tstart; iFrame < tend; ++iFrame)
+    {
+      for (int iRow = hstart; iRow < hend; ++iRow)
+      {
+        for (int iCol = wstart; iCol < wend; ++iCol)
+        {
+          gradInput[slice][iFrame][iRow][iCol] = val;
+        }
+      }
+    }
+  }
+}
+
+#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
+  avg_pool3d_cuda_update_output<KW, scalar_t, accscalar_t>  \
+    <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( \
+       work_input.packed_accessor<scalar_t, 4>(),           \
+       work_output.packed_accessor<scalar_t, 4>(),          \
+       kT, kH,                                              \
+       dT, dH, dW,                                          \
+       padT, padH, padW,                                    \
+       count_include_pad,                                   \
+       offsetZ);                                            \
+  break
+
+void avg_pool3d_out_cuda_template(
+  Tensor& output,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  TensorArg output_arg{ output, "output", 1 };
+  TensorArg input_arg{ input, "input", 2 };
+
+  checkAllSameGPU("avg_pool3d_out_cuda", {output_arg, input_arg});
+
+  // #20866 [JIT] stride.empty() is passed through
+  // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+  TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+                        (stride.empty() || stride.size() == 3) &&
+                        (padding.size() == 1 || padding.size() == 3),
+    "avg_pool3d: all IntArrayRef sizes must be 3");
+
+  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+    "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+  const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+  const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+  const int padT = safe_downcast<int, int64_t>(padding[0]);
+  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
+  const int64_t nslices = input.size(-4);
+  const int64_t itime = input.size(-3);
+  const int64_t iheight = input.size(-2);
+  const int64_t iwidth = input.size(-1);
+
+  const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+  const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+  const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+  pool3d_shape_check(
+    input,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    padT, padH, padW,
+    1, 1, 1,
+    itime, iheight, iwidth,
+    otime, oheight, owidth,
+    /*check_input_size=*/ true);
+
+  if (input.ndimension() == 4) {
+    output.resize_({nslices, otime, oheight, owidth});
+  }
+  else {
+    output.resize_({nbatch, nslices, otime, oheight, owidth});
+  }
+
+  Tensor work_input = input.contiguous();
+  Tensor work_output = output;
+  if (input.ndimension() == 5) {
+    // Collapse batch and feature dimensions.
+    work_input = work_input.reshape({nbatch * nslices, itime, iheight, iwidth});
+    work_output = work_output.reshape({nbatch * nslices, otime, oheight, owidth});
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    input.scalar_type(),
+    "avg_pool3d_out_cuda",
+    [&] {
+      using accscalar_t = acc_type<scalar_t, true>;
+      int64_t totalZ = otime * nslices * nbatch;
+      int64_t offsetZ = 0;
+      dim3 block(32, 8);
+
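+      // gridDim.z is limited to 65535 blocks, so the (time x slice x batch)
+      // extent is processed in 65535-plane chunks, with offsetZ locating the
+      // current chunk.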
+      while (totalZ > 0) {
+        dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+                  cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+                  totalZ > 65535 ? 65535 : totalZ);
+
+        switch (kW) {
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
+          LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
+        default:
+          avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
+            <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+               work_input.packed_accessor<scalar_t, 4>(),
+               work_output.packed_accessor<scalar_t, 4>(),
+               kT, kH, kW,
+               dT, dH, dW,
+               padT, padH, padW,
+               count_include_pad,
+               offsetZ);
+            break;
+        }
+
+        cudaError_t err = cudaGetLastError();
+        TORCH_CHECK(err == cudaSuccess,
+          "avg_pool3d_out_cuda failed with error code ",
+          err);
+
+        totalZ -= 65535;
+        offsetZ += 65535;
+      }
+    }
+  );
+}
+
+#undef LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH
+
+void avg_pool3d_backward_out_cuda_template(
+  Tensor& gradInput,
+  const Tensor& gradOutput,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
+  TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 };
+  TensorArg input_arg{ input, "input", 3 };
+
+  checkAllSameGPU("avg_pool3d_backward_out_cuda",
+                  {gradInput_arg, gradOutput_arg, input_arg});
+
+  // #20866 [JIT] stride.empty() is passed through
+  // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+  TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+                        (stride.empty() || stride.size() == 3) &&
+                        (padding.size() == 1 || padding.size() == 3),
+    "avg_pool3d: all IntArrayRef sizes must be 3");
+
+  TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+    "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+  TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
+    "non-empty 4D or 5D (batch mode) tensor expected for gradOutput");
+
+  // Resize and initialize result tensor.
+  gradInput.resize_as_(input);
+  gradInput.zero_();
+
+  const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+  const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+  const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+  const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+  const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+  const int padT = safe_downcast<int, int64_t>(padding[0]);
+  const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+  const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
+  const int64_t nslices = input.size(-4);
+  const int64_t itime = input.size(-3);
+  const int64_t iheight = input.size(-2);
+  const int64_t iwidth = input.size(-1);
+
+  const int64_t otime = gradOutput.size(-3);
+  const int64_t oheight = gradOutput.size(-2);
+  const int64_t owidth = gradOutput.size(-1);
+
+  /* XXX shape check behavior from TH */
+  const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+  const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+  const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+  const bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW);
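+  // Overlapping windows (stride < kernel in any dimension) scatter into the
+  // same gradInput cells, so that case needs the atomicAdd kernel; otherwise
+  // plain stores are race-free.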
+
+  avg_pool3d_backward_shape_check(
+    input,
+    gradOutput,
+    nslices,
+    kT, kH, kW,
+    dT, dH, dW,
+    padT, padH, padW,
+    itime, iheight, iwidth,
+    otime, oheight, owidth);
+
+  Tensor work_grad_input = gradInput;
+  Tensor work_grad_output = gradOutput.contiguous();
+
+  if (input.ndimension() == 5) {
+    // Collapse batch and feature dimensions.
+    work_grad_input = work_grad_input.reshape({nbatch * nslices, itime, iheight, iwidth});
+    work_grad_output = work_grad_output.reshape({nbatch * nslices, otime, oheight, owidth});
+  }
+
+  // Optimizing for stride 1 is probably only of limited value, but this
+  // specialization yields a 3x speedup over the atomicAdd implementation.
+  // Padding must be 0; otherwise, the pool size may change.
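+  // With unit stride and zero padding every window holds exactly kT*kH*kW
+  // elements, so the per-window divide_factor collapses to one normFactor and
+  // the gradient can be gathered per input element without atomics.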
+  if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+      "avg_pool3d_backward_out_frame_stride1",
+      [&] {
+        using accscalar_t = acc_type<scalar_t, true>;
+        int64_t totalZ = itime * nslices * nbatch;
+        int64_t offsetZ = 0;
+        dim3 block(32, 8);
+
+        while (totalZ > 0) {
+          dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
+                    cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
+                    totalZ > 65535 ? 65535 : totalZ);
+
+          avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
+            <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+              work_grad_output.packed_accessor<scalar_t, 4>(),
+              work_grad_input.packed_accessor<scalar_t, 4>(),
+              kT, kH, kW,
+              1.0f/(kT * kH * kW),
+              offsetZ);
+
+          cudaError_t err = cudaGetLastError();
+          TORCH_CHECK(err == cudaSuccess,
+            "avg_pool3d_backward_out_frame failed with error code ",
+            err);
+
+          totalZ -= 65535;
+          offsetZ += 65535;
+        }
+      }
+    );
+  }
+  else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+      "avg_pool3d_backward_out_frame",
+      [&] {
+        using accscalar_t = acc_type<scalar_t, true>;
+        int64_t totalZ = otime * nslices * nbatch;
+        int64_t offsetZ = 0;
+        dim3 block(32, 8);
+
+        while (totalZ > 0) {
+          dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+                    cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+                    totalZ > 65535 ? 65535 : totalZ);
+
+          if (kernelsOverlap) {
+            avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
+              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                 work_grad_output.packed_accessor<scalar_t, 4>(),
+                 work_grad_input.packed_accessor<scalar_t, 4>(),
+                 kT, kH, kW,
+                 dT, dH, dW,
+                 padT, padH, padW,
+                 count_include_pad,
+                 offsetZ);
+          }
+          else {
+            avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
+              <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+                 work_grad_output.packed_accessor<scalar_t, 4>(),
+                 work_grad_input.packed_accessor<scalar_t, 4>(),
+                 kT, kH, kW,
+                 dT, dH, dW,
+                 padT, padH, padW,
+                 count_include_pad,
+                 offsetZ);
+          }
+
+          cudaError_t err = cudaGetLastError();
+          TORCH_CHECK(err == cudaSuccess,
+            "avg_pool3d_backward_out_frame failed with error code ",
+            err);
+
+          totalZ -= 65535;
+          offsetZ += 65535;
+        }
+      }
+    );
+  }
+}
+
+} // namespace
+
+Tensor& avg_pool3d_out_cuda(
+  Tensor& output,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  avg_pool3d_out_cuda_template(
+    output,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return output;
+}
+
+Tensor avg_pool3d_cuda(
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  Tensor output = at::empty({0}, input.options());
+  avg_pool3d_out_cuda_template(
+    output,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return output;
+}
+
+Tensor& avg_pool3d_backward_out_cuda(
+  Tensor& gradInput,
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  avg_pool3d_backward_out_cuda_template(
+    gradInput,
+    gradOutput_,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return gradInput;
+}
+
+Tensor avg_pool3d_backward_cuda(
+  const Tensor& gradOutput_,
+  const Tensor& input,
+  IntArrayRef kernel_size,
+  IntArrayRef stride,
+  IntArrayRef padding,
+  bool ceil_mode,
+  bool count_include_pad)
+{
+  auto gradInput = at::zeros_like(input);
+  avg_pool3d_backward_out_cuda_template(
+    gradInput,
+    gradOutput_,
+    input,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad);
+  return gradInput;
+}
+
+} // at::native
+} // at
diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
index 9305966..e6b39b0 100644
--- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
+++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
@@ -327,7 +327,7 @@
   const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
   const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);
 
-  max_pool3d_with_indices_shape_check(
+  pool3d_shape_check(
     input,
     nslices,
     kT, kH, kW,
@@ -440,7 +440,7 @@
   const int64_t iheight = gradInput.size(-2);
   const int64_t iwidth = gradInput.size(-1);
 
-  max_pool3d_with_indices_shape_check(
+  max_pool3d_backward_shape_check(
     input,
     gradOutput,
     indices,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1e44157..6cbc84c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4737,26 +4737,26 @@
 - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   dispatch:
-    CPU: legacy::cpu::_thnn_avg_pool3d_forward_out
-    CUDA: legacy::cuda::_thnn_avg_pool3d_forward_out
+    CPU: avg_pool3d_out_cpu
+    CUDA: avg_pool3d_out_cuda
 
 - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
   python_module: nn
   dispatch:
-    CPU: legacy::cpu::_thnn_avg_pool3d_forward
-    CUDA: legacy::cuda::_thnn_avg_pool3d_forward
+    CPU: avg_pool3d_cpu
+    CUDA: avg_pool3d_cuda
 
 - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
   dispatch:
-    CPU: legacy::cpu::_thnn_avg_pool3d_backward_out
-    CUDA: legacy::cuda::_thnn_avg_pool3d_backward_out
+    CPU: avg_pool3d_backward_out_cpu
+    CUDA: avg_pool3d_backward_out_cuda
 
 - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad) -> Tensor
   python_module: nn
   dispatch:
-    CPU: legacy::cpu::_thnn_avg_pool3d_backward
-    CUDA: legacy::cuda::_thnn_avg_pool3d_backward
+    CPU: avg_pool3d_backward_cpu
+    CUDA: avg_pool3d_backward_cuda
 
 # Return: (Tensor output, Tensor indices)
 - func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index f99ca2b..e925b04 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -116,16 +116,6 @@
     output: 'false'
     grad_input: 'false'
 
-# Pooling
-
-- name: _thnn_avg_pool3d(Tensor self, IntArrayRef[3] kernel_size, IntArrayRef[3] stride={}, IntArrayRef[3] padding=0, bool ceil_mode=false, bool count_include_pad=true)
-  cname: VolumetricAveragePooling
-  default_init:
-    stride: kernel_size
-  scalar_check:
-    output: 'false'
-    grad_input: 'false'
-
 # Private functions. These also exist in TH, but we want the backwards functions
 # to implement derivatives.
 
diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt
index ecb5ee4..2915ef4 100644
--- a/aten/src/THCUNN/CMakeLists.txt
+++ b/aten/src/THCUNN/CMakeLists.txt
@@ -39,7 +39,6 @@
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalMaxPooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/TemporalRowConvolution.cu
-${CMAKE_CURRENT_SOURCE_DIR}/VolumetricAveragePooling.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricDilatedConvolution.cu
 ${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullConvolution.cu
diff --git a/aten/src/THCUNN/VolumetricAveragePooling.cu b/aten/src/THCUNN/VolumetricAveragePooling.cu
deleted file mode 100644
index 56e1d69..0000000
--- a/aten/src/THCUNN/VolumetricAveragePooling.cu
+++ /dev/null
@@ -1,279 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateOutput(
-  THCDeviceTensor<Dtype, 4> input,
-  THCDeviceTensor<Dtype, 4> output,
-  int kT, int kH, int kW,
-  int dT, int dH, int dW,
-  int padT, int padH, int padW,
-  bool count_include_pad, int offsetZ)
-{
-  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
-  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
-  int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
-  int slice  = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
-
-  if (oRow < output.getSize(2) && oCol < output.getSize(3))
-  {
-    Acctype sum = 0.0;
-
-    int tstart = oFrame * dT - padT;
-    int hstart = oRow   * dH - padH;
-    int wstart = oCol   * dW - padW;
-    int tend = min(tstart + kT, input.getSize(1) + padT);
-    int hend = min(hstart + kH, input.getSize(2) + padH);
-    int wend = min(wstart + kW, input.getSize(3) + padW);
-    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-    tstart = max(tstart, 0);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    tend = min(tend, input.getSize(1));
-    hend = min(hend, input.getSize(2));
-    wend = min(wend, input.getSize(3));
-
-    Acctype divide_factor;
-    if (count_include_pad)
-      divide_factor = static_cast<Acctype>(pool_size);
-    else
-      divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
-    int ti, hi, wi;
-    for (ti = tstart; ti < tend; ++ti)
-    {
-      for (hi = hstart; hi < hend; ++hi)
-      {
-        for (wi = wstart; wi < wend; ++wi)
-        {
-          Dtype val = input[slice][ti][hi][wi];
-          sum += val;
-        }
-      }
-    }
-
-    output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor);
-  }
-}
-
-// Inner-most loop size (kW) passed as template parameter for
-// performance reasons.
-//
-template<int KERNEL_WIDTH, typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateOutput_fixedKW(
-  THCDeviceTensor<Dtype, 4> input,
-  THCDeviceTensor<Dtype, 4> output,
-  int kT, int kH,
-  int dT, int dH, int dW,
-  int padT, int padH, int padW,
-  bool count_include_pad, int offsetZ)
-{
-  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
-  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
-  int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
-  int slice  = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
-
-  if (oRow < output.getSize(2) && oCol < output.getSize(3))
-  {
-    Acctype sum = 0.0;
-
-    int tstart = oFrame * dT - padT;
-    int hstart = oRow   * dH - padH;
-    int wstart = oCol   * dW - padW;
-    int tend = min(tstart + kT, input.getSize(1) + padT);
-    int hend = min(hstart + kH, input.getSize(2) + padH);
-    int wend = min(wstart + KERNEL_WIDTH, input.getSize(3) + padW);
-    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-    tstart = max(tstart, 0);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    tend = min(tend, input.getSize(1));
-    hend = min(hend, input.getSize(2));
-    wend = min(wend, input.getSize(3));
-
-    Acctype divide_factor;
-    if (count_include_pad)
-      divide_factor = static_cast<Acctype>(pool_size);
-    else
-      divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
-    int ti, hi, wi;
-    for (ti = tstart; ti < tend; ++ti)
-    {
-      for (hi = hstart; hi < hend; ++hi)
-      {
-        for (wi = wstart; wi < wend; ++wi)
-        {
-          Dtype val = input[slice][ti][hi][wi];
-          sum += val;
-        }
-      }
-    }
-
-    output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor);
-  }
-}
-
-#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
-  cuda_VolumetricAveragePooling_updateOutput_fixedKW<KW, scalar_t, accreal> \
-    <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
-      cudaInput, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, count_include_pad, offsetZ); \
-  break
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1(
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  THCDeviceTensor<Dtype, 4> gradInput,
-  int kT, int kH, int kW,
-  Acctype normFactor, int offsetZ)
-{
-  int iCol   = blockIdx.x * blockDim.x + threadIdx.x;
-  int iRow   = blockIdx.y * blockDim.y + threadIdx.y;
-  int iFrame = (blockIdx.z + offsetZ) % gradInput.getSize(1); // input frame/time
-  int slice  = (blockIdx.z + offsetZ) / gradInput.getSize(1); // input slice/feature
-
-  // guard against over-tiled threads
-  if (iRow < gradInput.getSize(2) && iCol < gradInput.getSize(3))
-  {
-    Acctype sum = 0.0;
-    Dtype *gOut = &gradOutput[slice][max(0, iFrame - kT + 1)]
-      [max(0, iRow - kH + 1)][max(0, iCol - kW + 1)];
-    int frameOffset = 0;
-    for (int oFrame  = max(0, iFrame - kT + 1);
-         oFrame < min(iFrame + 1, gradOutput.getSize(1));
-         ++oFrame)
-    {
-      int rowOffset = frameOffset;
-      for (int oRow = max(0, iRow - kH + 1);
-           oRow < min(iRow + 1, gradOutput.getSize(2));
-           ++oRow)
-      {
-        int colOffset = rowOffset;
-        for (int oCol = max(0, iCol - kW + 1);
-             oCol < min(iCol + 1, gradOutput.getSize(3));
-             ++oCol)
-        {
-          sum += gOut[colOffset];
-          ++colOffset;
-        }
-        rowOffset += gradOutput.getSize(3);
-      }
-      frameOffset += gradOutput.getSize(2) * gradOutput.getSize(3);
-    }
-    gradInput[slice][iFrame][iRow][iCol] = ScalarConvert<Acctype, Dtype>::to(sum * normFactor);
-  }
-}
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd(
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  THCDeviceTensor<Dtype, 4> gradInput,
-  int kT, int kH, int kW,
-  int dT, int dH, int dW,
-  int padT, int padH, int padW,
-  bool count_include_pad, int offsetZ)
-{
-  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
-  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
-  int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
-  int slice  = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
-
-  // guard against over-tiled threads
-  if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
-  {
-    int tstart = oFrame * dT - padT;
-    int hstart = oRow   * dH - padH;
-    int wstart = oCol   * dW - padW;
-    int tend = min(tstart + kT, gradInput.getSize(1) + padT);
-    int hend = min(hstart + kH, gradInput.getSize(2) + padH);
-    int wend = min(wstart + kW, gradInput.getSize(3) + padW);
-    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-    tstart = max(tstart, 0);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    tend = min(tend, gradInput.getSize(1));
-    hend = min(hend, gradInput.getSize(2));
-    wend = min(wend, gradInput.getSize(3));
-
-    Acctype divide_factor;
-    if (count_include_pad)
-      divide_factor = static_cast<Acctype>(pool_size);
-    else
-      divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
-    Dtype val = ScalarConvert<Acctype, Dtype>::to(
-      ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
-    for (int iFrame = tstart; iFrame < tend; ++iFrame)
-    {
-      for (int iRow = hstart; iRow < hend; ++iRow)
-      {
-        for (int iCol = wstart; iCol < wend; ++iCol)
-        {
-          atomicAdd(&gradInput[slice][iFrame][iRow][iCol], val);
-        }
-      }
-    }
-  }
-}
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput(
-  THCDeviceTensor<Dtype, 4> gradOutput,
-  THCDeviceTensor<Dtype, 4> gradInput,
-  int kT, int kH, int kW,
-  int dT, int dH, int dW,
-  int padT, int padH, int padW,
-  bool count_include_pad, int offsetZ)
-{
-  int oCol   = blockIdx.x * blockDim.x + threadIdx.x;
-  int oRow   = blockIdx.y * blockDim.y + threadIdx.y;
-  int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
-  int slice  = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
-
-  // guard against over-tiled threads
-  if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
-  {
-    int tstart = oFrame * dT - padT;
-    int hstart = oRow   * dH - padH;
-    int wstart = oCol   * dW - padW;
-    int tend = min(tstart + kT, gradInput.getSize(1) + padT);
-    int hend = min(hstart + kH, gradInput.getSize(2) + padH);
-    int wend = min(wstart + kW, gradInput.getSize(3) + padW);
-    int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-    tstart = max(tstart, 0);
-    hstart = max(hstart, 0);
-    wstart = max(wstart, 0);
-    tend = min(tend, gradInput.getSize(1));
-    hend = min(hend, gradInput.getSize(2));
-    wend = min(wend, gradInput.getSize(3));
-
-    Acctype divide_factor;
-    if (count_include_pad)
-      divide_factor = static_cast<Acctype>(pool_size);
-    else
-      divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
-    Dtype val = ScalarConvert<Acctype, Dtype>::to(
-      ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
-    for (int iFrame = tstart; iFrame < tend; ++iFrame)
-    {
-      for (int iRow = hstart; iRow < hend; ++iRow)
-      {
-        for (int iCol = wstart; iCol < wend; ++iCol)
-        {
-          gradInput[slice][iFrame][iRow][iCol] = val;
-        }
-      }
-    }
-  }
-}
-
-#include <THCUNN/generic/VolumetricAveragePooling.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index 7eed4f9..a685994 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -867,27 +867,6 @@
                   bool featFirst,
                   accreal scale);
 
-THC_API void THNN_(VolumetricAveragePooling_updateOutput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *output,
-                  int kT, int kW, int kH,
-                  int dT, int dW, int dH,
-                  int padT, int padW, int padH,
-                  bool ceil_mode,
-                  bool count_include_pad);
-
-THC_API void THNN_(VolumetricAveragePooling_updateGradInput)(
-                  THCState *state,
-                  THCTensor *input,
-                  THCTensor *gradOutput,
-                  THCTensor *gradInput,
-                  int kT, int kW, int kH,
-                  int dT, int dW, int dH,
-                  int padT, int padW, int padH,
-                  bool ceil_mode,
-                  bool count_include_pad);
-
 // VolumetricConvolution is legacy and purposefully not bound by ATen
 THC_API void THNN_(VolumetricConvolution_updateOutput)(
                   THCState *state,
diff --git a/aten/src/THCUNN/generic/VolumetricAveragePooling.cu b/aten/src/THCUNN/generic/VolumetricAveragePooling.cu
deleted file mode 100644
index 47939ef..0000000
--- a/aten/src/THCUNN/generic/VolumetricAveragePooling.cu
+++ /dev/null
@@ -1,337 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/VolumetricAveragePooling.cu"
-#else
-
-#include <THCUNN/generic/pooling_shape.h>
-
-static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
-                         THCState *state,
-                         THCTensor *input,
-                         THCTensor *gradOutput,
-                         int kT, int kW, int kH,
-                         int dT, int dW, int dH,
-                         int padT, int padW, int padH,
-                         bool ceil_mode)
-{
-  int inputSlices;
-  int inputTime;
-  int inputHeight;
-  int inputWidth;
-
-  int ndim = input->dim();
-  int dimN = 0;
-  int dimt = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  if (input->dim() == 5)
-  {
-    dimN++;
-    dimt++;
-    dimh++;
-    dimw++;
-  }
-
-  if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
-  {
-    THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
-               && input->size(dimt) >= kT, 2,
-               "input image (T: %d H: %d W: %d) smaller than "
-               "kernel size (kT: %d kH: %d kW: %d)",
-               input->size(dimt), input->size(dimh), input->size(dimw),
-               kT, kH, kW);
-
-    /* sizes */
-    inputSlices = THCTensor_(size)(state, input, 0);
-    inputTime   = THCTensor_(size)(state, input, 1);
-    inputHeight = THCTensor_(size)(state, input, 2);
-    inputWidth  = THCTensor_(size)(state, input, 3);
-  }
-  else if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5)
-  {
-    THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
-               && input->size(dimt) >= kT, 2,
-               "input image (T: %d H: %d W: %d) smaller than "
-               "kernel size (kT: %d kH: %d kW: %d)",
-               input->size(dimt), input->size(dimh), input->size(dimw),
-               kT, kH, kW);
-
-    /* sizes */
-    inputSlices = THCTensor_(size)(state, input, 1);
-    inputTime   = THCTensor_(size)(state, input, 2);
-    inputHeight = THCTensor_(size)(state, input, 3);
-    inputWidth  = THCTensor_(size)(state, input, 4);
-  }
-  else
-  {
-    AT_ERROR("non-empty 4D or 5D tensor expected, but got size: ", input->sizes());
-  }
-
-  // The second THArgCheck argument (11) is the argument index of padH.
-  THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11,
-             "pad should not be greater than half of kernel size, but got "
-             "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d",
-             padT, padW, padH, kT, kW, kH);
-
-  int outputTime = pooling_output_shape<int>(inputTime, kT, padT, dT, 1, ceil_mode);
-  int outputHeight = pooling_output_shape<int>(inputHeight, kH, padH, dH, 1, ceil_mode);
-  int outputWidth = pooling_output_shape<int>(inputWidth, kW, padW, dW, 1, ceil_mode);
-
-  if (gradOutput != NULL)
-  {
-     THCUNN_check_dim_size(state, gradOutput, ndim, dimN, inputSlices);
-     THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime);
-     THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
-     THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
-  }
-}
-
-void THNN_(VolumetricAveragePooling_updateOutput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *output,
-           int kT, int kW, int kH,
-           int dT, int dW, int dH,
-           int padT, int padW, int padH,
-           bool ceil_mode,
-           bool count_include_pad)
-{
-  int batchSize;
-  int inputSlices;
-  int inputTime;
-  int inputHeight;
-  int inputWidth;
-
-  int dimt = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  bool fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
-  if (fiveDimensionalInput)
-  {
-    dimt++;
-    dimh++;
-    dimw++;
-  }
-
-  THNN_(VolumetricAveragePooling_shapeCheck)
-       (state, input, NULL, kT, kW, kH, dT, dW, dH,
-        padT, padW, padH, ceil_mode);
-
-  if (!fiveDimensionalInput) /* 4D */
-  {
-    /* sizes */
-    batchSize   = 1;
-    inputSlices = THCTensor_(size)(state, input, 0);
-    inputTime   = THCTensor_(size)(state, input, 1);
-    inputHeight = THCTensor_(size)(state, input, 2);
-    inputWidth  = THCTensor_(size)(state, input, 3);
-  }
-  else /* 5D */
-  {
-    /* sizes */
-    batchSize   = THCTensor_(size)(state, input, 0);
-    inputSlices = THCTensor_(size)(state, input, 1);
-    inputTime   = THCTensor_(size)(state, input, 2);
-    inputHeight = THCTensor_(size)(state, input, 3);
-    inputWidth  = THCTensor_(size)(state, input, 4);
-  }
-
-  int outputTime = pooling_output_shape<int>(inputTime, kT, padT, dT, 1, ceil_mode);
-  int outputHeight = pooling_output_shape<int>(inputHeight, kH, padH, dH, 1, ceil_mode);
-  int outputWidth = pooling_output_shape<int>(inputWidth, kW, padW, dW, 1, ceil_mode);
-
-  if (!fiveDimensionalInput) /* 4D */
-  {
-    /* resize output */
-    THCTensor_(resize4d)(state, output, inputSlices,
-                         outputTime, outputHeight, outputWidth);
-  }
-  else /* 5D */
-  {
-    THCTensor_(resize5d)(state, output, batchSize, inputSlices,
-                         outputTime, outputHeight, outputWidth);
-  }
-
-  input = THCTensor_(newContiguous)(state, input);
-  if (fiveDimensionalInput) {
-    // Collapse batch and feature dimensions
-    output = THCTensor_(newFoldBatchDim)(state, output);
-
-    THCTensor *old_input = input;
-    input = THCTensor_(newFoldBatchDim)(state, input);
-    THCTensor_(free)(state, old_input);
-  } else {
-    THCTensor_(retain)(state, output);
-  }
-
-  THCDeviceTensor<scalar_t, 4> cudaInput;
-  THCDeviceTensor<scalar_t, 4> cudaOutput;
-  cudaInput  = toDeviceTensor<scalar_t, 4>(state, input);
-  cudaOutput = toDeviceTensor<scalar_t, 4>(state, output);
-
-  int totalZ = outputTime * inputSlices * batchSize;
-  int offsetZ = 0;
-  dim3 block(32, 8);
-  while (totalZ > 0) {
-    dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
-              THCCeilDiv(outputHeight, static_cast<int>(block.y)),
-              totalZ > 65535 ? 65535 : totalZ);
-
-    switch (kW)
-      {
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
-        LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
-      default:
-        cuda_VolumetricAveragePooling_updateOutput<scalar_t, accreal>
-          <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-            cudaInput,
-            cudaOutput,
-            kT, kH, kW,
-            dT, dH, dW,
-            padT, padH, padW,
-            count_include_pad,
-            offsetZ);
-        break;
-      }
-    totalZ -= 65535;
-    offsetZ += 65535;
-    THCudaCheck(cudaGetLastError());
-  }
-
-  THCTensor_(free)(state, input);
-  THCTensor_(free)(state, output);
-}
-
-void THNN_(VolumetricAveragePooling_updateGradInput)(
-           THCState *state,
-           THCTensor *input,
-           THCTensor *gradOutput,
-           THCTensor *gradInput,
-           int kT, int kW, int kH,
-           int dT, int dW, int dH,
-           int padT, int padW, int padH,
-           bool ceil_mode,
-           bool count_include_pad)
-{
-  THNN_(VolumetricAveragePooling_shapeCheck)
-       (state, input, gradOutput, kT, kW, kH, dT, dW, dH,
-        padT, padW, padH, ceil_mode);
-  bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW);
-
-  // Resize and initialize result tensor.
-  THCTensor_(resizeAs)(state, gradInput, input);
-  THCTensor_(zero)(state, gradInput);
-
-  int batchSize;
-  int inputSlices;
-  int inputTime;
-  int inputHeight;
-  int inputWidth;
-
-  int outputTime;
-  int outputHeight;
-  int outputWidth;
-
-  bool fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
-  if (!fiveDimensionalInput) /* 4D */
-  {
-    batchSize = 1;
-    inputSlices  = THCTensor_(size)(state, input, 0);
-    inputTime    = THCTensor_(size)(state, input, 1);
-    inputHeight  = THCTensor_(size)(state, input, 2);
-    inputWidth   = THCTensor_(size)(state, input, 3);
-
-    outputTime   = THCTensor_(size)(state, gradOutput, 1);
-    outputHeight = THCTensor_(size)(state, gradOutput, 2);
-    outputWidth  = THCTensor_(size)(state, gradOutput, 3);
-  }
-  else
-  {
-    batchSize    = THCTensor_(size)(state, input, 0);
-    inputSlices  = THCTensor_(size)(state, input, 1);
-    inputTime    = THCTensor_(size)(state, input, 2);
-    inputHeight  = THCTensor_(size)(state, input, 3);
-    inputWidth   = THCTensor_(size)(state, input, 4);
-
-    outputTime   = THCTensor_(size)(state, gradOutput, 2);
-    outputHeight = THCTensor_(size)(state, gradOutput, 3);
-    outputWidth  = THCTensor_(size)(state, gradOutput, 4);
-  }
-
-  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
-  if (fiveDimensionalInput) {
-    // Collapse batch and feature dimensions
-    gradInput = THCTensor_(newFoldBatchDim)(state, gradInput);
-
-    THCTensor *old_gradOutput = gradOutput;
-    gradOutput = THCTensor_(newFoldBatchDim)(state, gradOutput);
-    THCTensor_(free)(state, old_gradOutput);
-  } else {
-    THCTensor_(retain)(state, gradInput);
-  }
-
-  THCDeviceTensor<scalar_t, 4> cudaGradInput;
-  THCDeviceTensor<scalar_t, 4> cudaGradOutput;
-  cudaGradInput  = toDeviceTensor<scalar_t, 4>(state, gradInput);
-  cudaGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
-
-  dim3 block(32, 8);
-
-  // Optimizing for stride 1 is probably only of limited value, but this
-  // specialization yields a 3x speedup over the atomicAdd implementation.
-  // Padding must be 0; otherwise, the pool size may vary across positions.
-  if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0)
-  {
-    int totalZ = inputTime * inputSlices * batchSize;
-    int offsetZ = 0;
-    while (totalZ > 0) {
-      dim3 grid(THCCeilDiv(inputWidth, static_cast<int>(block.x)),
-                THCCeilDiv(inputHeight, static_cast<int>(block.y)),
-                totalZ > 65535 ? 65535 : totalZ);
-      cuda_VolumetricAveragePooling_updateGradInput_Stride1<scalar_t, accreal>
-        <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-          cudaGradOutput, cudaGradInput, kT, kH, kW, 1.0f/(kT * kH * kW), offsetZ);
-      THCudaCheck(cudaGetLastError());
-      totalZ -= 65535;
-      offsetZ += 65535;
-    }
-  }
-  else
-  {
-    int totalZ = outputTime * inputSlices * batchSize;
-    int offsetZ = 0;
-    while (totalZ > 0) {
-      dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
-                THCCeilDiv(outputHeight, static_cast<int>(block.y)),
-                totalZ > 65535 ? 65535 : totalZ);
-      if (kernelsOverlap)
-      {
-        cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<scalar_t, accreal>
-          <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-            cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW,
-            padT, padH, padW, count_include_pad, offsetZ);
-      }
-      else
-      {
-        cuda_VolumetricAveragePooling_updateGradInput<scalar_t, accreal>
-          <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
-            cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW,
-            padT, padH, padW, count_include_pad, offsetZ);
-      }
-      THCudaCheck(cudaGetLastError());
-      totalZ -= 65535;
-      offsetZ += 65535;
-    }
-  }
-
-  THCTensor_(free)(state, gradInput);
-  THCTensor_(free)(state, gradOutput);
-}
-
-#endif
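
Two launch-time details in the file removed above are worth noting. First, the
backward pass selects the atomicAdd kernel only when pooling windows can
overlap (dT < kT || dH < kH || dW < kW), because overlapping windows make
several output cells scatter into the same gradInput cell; with
non-overlapping windows plain stores are safe. Second, both launchers flatten
time x feature x batch into grid.z and walk it in slabs of 65535, the CUDA
limit on gridDim.z; each kernel recovers its frame and slice from
blockIdx.z + offsetZ. A hedged sketch of that chunking pattern (the kernel and
launcher names here are illustrative, not from this diff):

#include <cuda_runtime.h>

// Each block's z coordinate, plus the running offsetZ, names one
// (slice, frame) plane of the flattened z range.
__global__ void pool_plane_kernel(int sizeZ, int offsetZ) {
  int frame = (blockIdx.z + offsetZ) % sizeZ;  // time index
  int slice = (blockIdx.z + offsetZ) / sizeZ;  // feature index (batch folded in)
  (void)frame; (void)slice;  // a real kernel would pool this plane's tile
}

// Launch over totalZ = time * slices * batch in slabs of at most 65535.
void launch_chunked(int width, int height, int totalZ, int sizeZ,
                    cudaStream_t stream) {
  dim3 block(32, 8);
  int offsetZ = 0;
  while (totalZ > 0) {
    dim3 grid((width + block.x - 1) / block.x,
              (height + block.y - 1) / block.y,
              totalZ > 65535 ? 65535 : totalZ);  // gridDim.z hardware cap
    pool_plane_kernel<<<grid, block, 0, stream>>>(sizeZ, offsetZ);
    totalZ -= 65535;
    offsetZ += 65535;
  }
}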
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 6cde3bc..ff76b71 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -489,24 +489,6 @@
           int inputWidth, int inputHeight,
           int outputWidth, int outputHeight);
 
-TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int kT, int kW, int kH,
-          int dT, int dW, int dH,
-          int padT, int padW, int padH,
-          bool ceil_mode, bool count_include_pad);
-TH_API void THNN_(VolumetricAveragePooling_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int kT, int kW, int kH,
-          int dT, int dW, int dH,
-          int padT, int padW, int padH,
-          bool ceil_mode, bool count_include_pad);
-
 TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
           THNNState *state,
           THTensor *input,
diff --git a/aten/src/THNN/generic/VolumetricAveragePooling.c b/aten/src/THNN/generic/VolumetricAveragePooling.c
deleted file mode 100644
index 3671413..0000000
--- a/aten/src/THNN/generic/VolumetricAveragePooling.c
+++ /dev/null
@@ -1,465 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/VolumetricAveragePooling.c"
-#else
-
-#include <THNN/generic/pooling_shape.h>
-#include <algorithm>
-
-#include <ATen/Parallel.h>
-
-static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
-                         THNNState *state,
-                         THTensor *input,
-                         THTensor *gradOutput,
-                         int kT,
-                         int kW,
-                         int kH,
-                         int dT,
-                         int dW,
-                         int dH,
-                         int padT,
-                         int padW,
-                         int padH,
-                         bool ceil_mode)
-{
-  int64_t nslices;
-  int64_t itime;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t otime;
-  int64_t oheight;
-  int64_t owidth;
-  int ndim = input->dim();
-  int dimN = 0;
-  int dimt = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  if (input->dim() == 5)
-  {
-    dimN++;
-    dimt++;
-    dimh++;
-    dimw++;
-  }
-
-  THArgCheck(kT > 0 && kW > 0 && kH > 0, 5,
-             "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d",
-             kT, kH, kW);
-  THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
-             "stride should be greater than zero, but got dT: %d dH: %d dW: %d",
-             dT, dH, dW);
-  THNN_ARGCHECK(!input->is_empty() && (input->dim() == 4 || input->dim() == 5), 2, input,
-                "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
-
-  THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
-             && input->size(dimt) >= kT, 2,
-             "input image (T: %d H: %d W: %d) smaller than "
-             "kernel size (kT: %d kH: %d kW: %d)",
-             input->size(dimt), input->size(dimh), input->size(dimw),
-             kT, kH, kW);
-
-  // The second THArgCheck argument is the argument number being checked; 11 is the index of padH.
-  THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11,
-            "pad should not be greater than half of kernel size, but got "
-            "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d",
-            padT, padW, padH, kT, kW, kH);
-
-  /* sizes */
-  nslices = input->size(dimN);
-  itime   = input->size(dimt);
-  iheight = input->size(dimh);
-  iwidth  = input->size(dimw);
-
-  otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
-  oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
-  owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
-
-  if (otime < 1 || owidth < 1 || oheight < 1)
-    THError("Given input size: (%dx%dx%dx%d). "
-            "Calculated output size: (%dx%dx%dx%d). Output size is too small",
-            nslices, itime, iheight, iwidth, nslices, otime, oheight, owidth);
-
-  if (gradOutput != NULL) {
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimN, nslices);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimt, otime);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, oheight);
-    THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, owidth);
-  }
-}
-
-static void THNN_(VolumetricAveragePooling_updateOutput_frame)(
-          scalar_t *input_p,
-          scalar_t *output_p,
-          int64_t nslices,
-          int64_t itime,
-          int64_t iwidth,
-          int64_t iheight,
-          int64_t otime,
-          int64_t owidth,
-          int64_t oheight,
-          int kT,
-          int kW,
-          int kH,
-          int dT,
-          int dW,
-          int dH,
-          int padT,
-          int padW,
-          int padH,
-          bool count_include_pad)
-{
-  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
-    for (auto k = start; k < end; k++)
-    {
-      int64_t i, j, ti;
-
-      /* local pointers. */
-      scalar_t *ip = input_p + k * itime * iwidth * iheight;
-      scalar_t *op = output_p + k * otime * owidth * oheight;
-      for (i = 0; i < otime * oheight * owidth; ++i)
-        *(op + i) = 0;
-
-      /* loop over output */
-      for (ti = 0; ti < otime; ti++)
-      {
-        for (i = 0; i < oheight; i++)
-        {
-          for (j = 0; j < owidth; j++)
-          {
-            /* compute pool range. */
-            int64_t tstart = ti * dT - padT;
-            int64_t hstart = i  * dH - padH;
-            int64_t wstart = j  * dW - padW;
-            int64_t tend = std::min(tstart + kT, itime + padT);
-            int64_t hend = std::min(hstart + kH, iheight + padH);
-            int64_t wend = std::min(wstart + kW, iwidth + padW);
-            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-            tstart = std::max(tstart, (int64_t) 0);
-            hstart = std::max(hstart, (int64_t) 0);
-            wstart = std::max(wstart, (int64_t) 0);
-            tend = std::min(tend, itime);
-            hend = std::min(hend, iheight);
-            wend = std::min(wend, iwidth);
-
-            int64_t divide_factor;
-            if (count_include_pad)
-              divide_factor = pool_size;
-            else
-              divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
-
-            /* compute local sum: */
-            scalar_t sum = 0.0;
-            int64_t x, y, z;
-
-            for (z = tstart; z < tend; z++)
-            {
-              for (y = hstart; y < hend; y++)
-              {
-                for (x = wstart; x < wend; x++)
-                {
-                  sum +=  *(ip + z * iwidth * iheight + y * iwidth + x);
-                }
-              }
-            }
-
-            /* set output to local average */
-            *op++ += sum / divide_factor;
-          }
-        }
-      }
-    }
-  });
-}
-
-void THNN_(VolumetricAveragePooling_updateOutput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *output,
-          int kT,
-          int kW,
-          int kH,
-          int dT,
-          int dW,
-          int dH,
-          int padT,
-          int padW,
-          int padH,
-          bool ceil_mode,
-          bool count_include_pad)
-{
-  int64_t nslices;
-  int64_t itime;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t otime;
-  int64_t oheight;
-  int64_t owidth;
-  scalar_t *input_data;
-  scalar_t *output_data;
-
-  THNN_(VolumetricAveragePooling_shapeCheck)(
-        state, input, NULL, kT, kW, kH,
-        dT, dW, dH, padT, padW, padH, ceil_mode);
-
-  int dimN = 0;
-  int dimt = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  if (input->dim() == 5)
-  {
-    dimN++;
-    dimt++;
-    dimh++;
-    dimw++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimN);
-  itime   = input->size(dimt);
-  iheight = input->size(dimh);
-  iwidth  = input->size(dimw);
-  otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
-  oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
-  owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
-
-  /* get contiguous input */
-  input = THTensor_(newContiguous)(input);
-
-  if (input->dim() == 4) /* non-batch mode */
-  {
-    /* resize output */
-    THTensor_(resize4d)(output, nslices, otime, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-    THNN_(VolumetricAveragePooling_updateOutput_frame)(
-      input_data, output_data, nslices,
-      itime, iwidth, iheight,
-      otime, owidth, oheight,
-      kT, kW, kH,
-      dT, dW, dH,
-      padT, padW, padH,
-      count_include_pad
-    );
-  }
-  else  /* batch mode */
-  {
-    int64_t nBatch = input->size(0);
-
-    int64_t istride = nslices * itime * iwidth * iheight;
-    int64_t ostride = nslices * otime * owidth * oheight;
-
-    /* resize output */
-    THTensor_(resize5d)(output, nBatch, nslices, otime, oheight, owidth);
-
-    input_data = input->data<scalar_t>();
-    output_data = output->data<scalar_t>();
-
-    at::parallel_for(0, nBatch, 0, [&](int64_t start, int64_t end) {
-      for (auto p = start; p < end; p++)
-      {
-        THNN_(VolumetricAveragePooling_updateOutput_frame)(
-          input_data + p * istride, output_data + p * ostride, nslices,
-          itime, iwidth, iheight,
-          otime, owidth, oheight,
-          kT, kW, kH,
-          dT, dW, dH,
-          padT, padW, padH,
-          count_include_pad
-        );
-      }
-    });
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(VolumetricAveragePooling_updateGradInput_frame)(
-          scalar_t *gradInput_p,
-          scalar_t *gradOutput_p,
-          int64_t nslices,
-          int64_t itime,
-          int64_t iwidth,
-          int64_t iheight,
-          int64_t otime,
-          int64_t owidth,
-          int64_t oheight,
-          int kT,
-          int kW,
-          int kH,
-          int dT,
-          int dW,
-          int dH,
-          int padT,
-          int padW,
-          int padH,
-          bool count_include_pad)
-{
-  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
-    for (auto k = start; k < end; k++)
-    {
-      int64_t i, j, ti;
-
-      /* local pointers */
-      scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
-      scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
-      for (i = 0; i < itime * iwidth * iheight; i++)
-        *(ip + i) = 0;
-
-      /* loop over output */
-      for (ti = 0; ti < otime; ti++)
-      {
-        for (i = 0; i < oheight; i++)
-        {
-          for (j = 0; j < owidth; j++)
-          {
-            int64_t tstart = ti * dT - padT;
-            int64_t hstart = i  * dH - padH;
-            int64_t wstart = j  * dW - padW;
-            int64_t tend = std::min(tstart + kT, itime + padT);
-            int64_t hend = std::min(hstart + kH, iheight + padH);
-            int64_t wend = std::min(wstart + kW, iwidth + padW);
-            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
-            tstart = std::max(tstart, (int64_t) 0);
-            hstart = std::max(hstart, (int64_t) 0);
-            wstart = std::max(wstart, (int64_t) 0);
-            tend = std::min(tend, itime);
-            hend = std::min(hend, iheight);
-            wend = std::min(wend, iwidth);
-
-            int64_t divide_factor;
-            if (count_include_pad)
-              divide_factor = pool_size;
-            else
-              divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
-
-            /* scatter gradients out to footprint: */
-            scalar_t val = *op++;
-
-            int64_t x, y, z;
-            for (z = tstart; z < tend; z++)
-            {
-              for (y = hstart; y < hend; y++)
-              {
-                for (x = wstart; x < wend; x++)
-                {
-                  *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
-                }
-              }
-            }
-          }
-        }
-      }
-    }
-  });
-}
-
-void THNN_(VolumetricAveragePooling_updateGradInput)(
-          THNNState *state,
-          THTensor *input,
-          THTensor *gradOutput,
-          THTensor *gradInput,
-          int kT,
-          int kW,
-          int kH,
-          int dT,
-          int dW,
-          int dH,
-          int padT,
-          int padW,
-          int padH,
-          bool ceil_mode,
-          bool count_include_pad)
-{
-  int64_t nslices;
-  int64_t itime;
-  int64_t iheight;
-  int64_t iwidth;
-  int64_t otime;
-  int64_t oheight;
-  int64_t owidth;
-  scalar_t *gradInput_data;
-  scalar_t *gradOutput_data;
-
-  int dimN = 0;
-  int dimt = 1;
-  int dimh = 2;
-  int dimw = 3;
-
-  THNN_(VolumetricAveragePooling_shapeCheck)(
-        state, input, gradOutput, kT, kW, kH,
-        dT, dW, dH, padT, padW, padH, ceil_mode);
-
-  /* get contiguous gradOutput */
-  gradOutput = THTensor_(newContiguous)(gradOutput);
-
-  /* resize */
-  THTensor_(resizeAs)(gradInput, input);
-  THTensor_(zero)(gradInput);
-
-  if (input->dim() == 5)
-  {
-    dimN++;
-    dimt++;
-    dimh++;
-    dimw++;
-  }
-
-  /* sizes */
-  nslices = input->size(dimN);
-  itime = input->size(dimt);
-  iheight = input->size(dimh);
-  iwidth = input->size(dimw);
-  otime = gradOutput->size(dimt);
-  oheight = gradOutput->size(dimh);
-  owidth = gradOutput->size(dimw);
-
-  /* get raw pointers */
-  gradInput_data = gradInput->data<scalar_t>();
-  gradOutput_data = gradOutput->data<scalar_t>();
-
-  /* backprop */
-  if (input->dim() == 4) /* non-batch mode */
-  {
-    THNN_(VolumetricAveragePooling_updateGradInput_frame)(
-      gradInput_data, gradOutput_data, nslices,
-      itime, iwidth, iheight,
-      otime, owidth, oheight,
-      kT, kW, kH,
-      dT, dW, dH,
-      padT, padW, padH,
-      count_include_pad
-    );
-  }
-  else /* batch mode */
-  {
-    int64_t nBatch = input->size(0);
-
-    int64_t istride = nslices * itime * iwidth * iheight;
-    int64_t ostride = nslices * otime * owidth * oheight;
-
-    at::parallel_for(0, nBatch, 0, [&](int64_t start, int64_t end) {
-      for (auto p = start; p < end; p++)
-      {
-        THNN_(VolumetricAveragePooling_updateGradInput_frame)(
-          gradInput_data  + p * istride, gradOutput_data + p * ostride, nslices,
-          itime, iwidth, iheight,
-          otime, owidth, oheight,
-          kT, kW, kH,
-          dT, dW, dH,
-          padT, padW, padH,
-          count_include_pad
-        );
-      }
-    });
-  }
-
-  /* cleanup */
-  c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
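
Both deleted implementations size the output through
pooling_output_shape<T>(size, kernel, pad, stride, 1, ceil_mode). A hedged
reconstruction of that rule with dilation fixed to 1 (see
aten/src/ATen/native/Pool.h for the canonical template; this standalone
function assumes a non-negative numerator, which the shape checks above
guarantee):

#include <cstdint>

int64_t pool_out_size(int64_t input, int64_t kernel, int64_t pad,
                      int64_t stride, bool ceil_mode) {
  int64_t out =
      (input + 2 * pad - kernel + (ceil_mode ? stride - 1 : 0)) / stride + 1;
  // In ceil mode, drop a trailing window that would start entirely
  // inside the right/bottom padding.
  if (ceil_mode && (out - 1) * stride >= input + pad)
    --out;
  return out;
}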
diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp
index 679a4ec..cf86038 100644
--- a/aten/src/THNN/init.cpp
+++ b/aten/src/THNN/init.cpp
@@ -136,9 +136,6 @@
 #include <THNN/generic/SpatialDilatedConvolution.c>
 #include <TH/THGenerateFloatTypes.h>
 
-#include <THNN/generic/VolumetricAveragePooling.c>
-#include <TH/THGenerateFloatTypes.h>
-
 #include <THNN/generic/VolumetricConvolutionMM.c>
 #include <TH/THGenerateFloatTypes.h>
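
The two include lines removed above were the instantiation point: TH compiles
the generic VolumetricAveragePooling.c once per floating type by re-including
it with scalar_t and the THNN_() name macro rebound. A single-file, hedged
sketch of that re-inclusion trick (the macro and function names below are
invented for the demo, not the real TH ones):

/* generic_demo.cpp: the file includes itself twice, producing
   avg_Float() and avg_Double() from one generic body. */
#ifndef GENERIC_DEMO_TOP
#define GENERIC_DEMO_TOP

#include <stdio.h>

#define CONCAT2(a, b) a##b
#define CONCAT(a, b) CONCAT2(a, b)

#define scalar_t float
#define TYPE_SUFFIX _Float
#include __FILE__  /* first pass over the generic body */
#undef scalar_t
#undef TYPE_SUFFIX

#define scalar_t double
#define TYPE_SUFFIX _Double
#include __FILE__  /* second pass */
#undef scalar_t
#undef TYPE_SUFFIX

int main(void) {
  printf("%f %f\n", avg_Float(1.0f, 2.0f), avg_Double(1.0, 2.0));
  return 0;
}

#else
/* The "generic" body, compiled once per scalar type. */
static scalar_t CONCAT(avg, TYPE_SUFFIX)(scalar_t a, scalar_t b) {
  return (a + b) / 2;
}
#endif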
 
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index 1b53acf..d3dff4c 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -277,7 +277,6 @@
         'SpatialConvolutionMM',
         'TemporalConvolution',
         'SpatialMaxUnpooling',
-        'VolumetricAveragePooling',
         'VolumetricMaxUnpooling',
         'VolumetricConvolution',
         'VolumetricFullConvolution',
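
With the THNN/THCUNN paths gone, the operator is reachable only through the
ATen entry points. A quick illustrative check via the ATen C++ API, showing
how count_include_pad changes the divisor at a padded corner (the expected
values in the comments follow from the divisor rule sketched earlier):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A 4x4x4 volume of ones, in N, C, T, H, W layout.
  at::Tensor x = at::ones({1, 1, 4, 4, 4});
  // kernel 2, stride 2, padding 1: the corner window covers one real
  // element and seven padding cells.
  at::Tensor incl = at::avg_pool3d(x, /*kernel_size=*/{2, 2, 2},
                                   /*stride=*/{2, 2, 2}, /*padding=*/{1, 1, 1},
                                   /*ceil_mode=*/false,
                                   /*count_include_pad=*/true);
  at::Tensor excl = at::avg_pool3d(x, {2, 2, 2}, {2, 2, 2}, {1, 1, 1},
                                   false, /*count_include_pad=*/false);
  std::cout << incl[0][0][0][0][0].item<float>() << "\n";  // 1/8 = 0.125
  std::cout << excl[0][0][0][0][0].item<float>() << "\n";  // 1/1 = 1
  return 0;
}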