Port avg_pool3d() to ATen (#21732)
Summary:
This will need conflict resolution once avg_pool2d() has been merged.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21732
Differential Revision: D15824923
Pulled By: ezyang
fbshipit-source-id: 83341e0209b660aecf788272079d8135d78b6ff1
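
A quick sanity check (an editorial sketch, not part of this patch): the ported CPU and CUDA kernels are reached through the public ATen ops declared in native_functions.yaml below. The standalone main(), shapes, and argument values here are illustrative assumptions only.

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // 5D (batch mode) input: N x C x T x H x W.
      at::Tensor input = at::randn({2, 3, 8, 8, 8});

      // Forward: 2x2x2 window; stride defaults to the kernel size, no padding.
      at::Tensor out = at::avg_pool3d(input, /*kernel_size=*/{2, 2, 2});

      // Backward: push an all-ones gradient back through the pooling windows.
      at::Tensor grad_out = at::ones_like(out);
      at::Tensor grad_in = at::avg_pool3d_backward(
          grad_out, input, /*kernel_size=*/{2, 2, 2}, /*stride=*/{},
          /*padding=*/{0, 0, 0}, /*ceil_mode=*/false, /*count_include_pad=*/true);

      // Expected: [2, 3, 4, 4, 4] and [2, 3, 8, 8, 8].
      std::cout << out.sizes() << " " << grad_in.sizes() << std::endl;
      return 0;
    }
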
diff --git a/aten/src/ATen/native/AveragePool3d.cpp b/aten/src/ATen/native/AveragePool3d.cpp
new file mode 100644
index 0000000..ff9fdf6
--- /dev/null
+++ b/aten/src/ATen/native/AveragePool3d.cpp
@@ -0,0 +1,488 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/native/Pool.h>
+#include <tuple>
+
+
+namespace at {
+namespace native {
+
+namespace {
+
+template <typename scalar_t>
+static void avg_pool3d_out_frame(
+ scalar_t *input_p,
+ scalar_t *output_p,
+ int64_t nslices,
+ int64_t itime,
+ int64_t iwidth,
+ int64_t iheight,
+ int64_t otime,
+ int64_t owidth,
+ int64_t oheight,
+ int kT,
+ int kW,
+ int kH,
+ int dT,
+ int dW,
+ int dH,
+ int padT,
+ int padW,
+ int padH,
+ bool count_include_pad)
+{
+ at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
+ for (auto k = start; k < end; k++)
+ {
+ int64_t i, j, ti;
+
+ /* local pointers. */
+ scalar_t *ip = input_p + k * itime * iwidth * iheight;
+ scalar_t *op = output_p + k * otime * owidth * oheight;
+ for (i = 0; i < otime * oheight * owidth; ++i)
+ *(op + i) = 0;
+
+ /* loop over output */
+ for (ti = 0; ti < otime; ti++)
+ {
+ for (i = 0; i < oheight; i++)
+ {
+ for (j = 0; j < owidth; j++)
+ {
+ /* compute pool range. */
+          int64_t tstart = ti * dT - padT;
+          int64_t hstart = i * dH - padH;
+          int64_t wstart = j * dW - padW;
+          int64_t tend = std::min(tstart + kT, itime + padT);
+          int64_t hend = std::min(hstart + kH, iheight + padH);
+          int64_t wend = std::min(wstart + kW, iwidth + padW);
+          int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+          tstart = std::max(tstart, (int64_t) 0);
+          hstart = std::max(hstart, (int64_t) 0);
+          wstart = std::max(wstart, (int64_t) 0);
+          tend = std::min(tend, itime);
+          hend = std::min(hend, iheight);
+          wend = std::min(wend, iwidth);
+
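+          /* divisor: the full (padding-inclusive) window when count_include_pad is true,
+             otherwise only the in-bounds part of the window */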
+ int divide_factor;
+ if (count_include_pad)
+ divide_factor = pool_size;
+ else
+ divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
+
+ /* compute local sum: */
+ scalar_t sum = 0.0;
+ int64_t x, y, z;
+
+ for (z = tstart; z < tend; z++)
+ {
+ for (y = hstart; y < hend; y++)
+ {
+ for (x = wstart; x < wend; x++)
+ {
+ sum += *(ip + z * iwidth * iheight + y * iwidth + x);
+ }
+ }
+ }
+
+          /* set output to local average */
+ *op++ += sum / divide_factor;
+ }
+ }
+ }
+ }
+ });
+}
+
+void avg_pool3d_out_cpu_template(
+ Tensor& output,
+ const Tensor& input_,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ // #20866 [JIT] stride.empty() is passed through
+ // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+ TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+ (stride.empty() || stride.size() == 3) &&
+ (padding.size() == 1 || padding.size() == 3),
+ "avg_pool3d: all IntArrayRef sizes must be 3");
+
+ TORCH_CHECK((input_.ndimension() == 4 || input_.ndimension() == 5),
+ "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+ const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+ const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+ const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+ const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+ const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+ const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+ const int padT = safe_downcast<int, int64_t>(padding[0]);
+ const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+ const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+ const int64_t nslices = input_.size(-4);
+ const int64_t itime = input_.size(-3);
+ const int64_t iheight = input_.size(-2);
+ const int64_t iwidth = input_.size(-1);
+
+ const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+ const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+ const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
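+  // pooling_output_shape is roughly floor((in + 2*pad - kernel) / stride) + 1
+  // (ceil instead of floor when ceil_mode), with the last window dropped in
+  // ceil_mode if it would otherwise start entirely inside the padding.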
+
+ pool3d_shape_check(
+ input_,
+ nslices,
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ 1, 1, 1,
+ itime, iheight, iwidth,
+ otime, oheight, owidth,
+ /*check_input_size=*/ true);
+
+ /* get contiguous input */
+ Tensor input = input_.contiguous();
+
+ if (input.ndimension() == 4) /* non-batch mode */
+ {
+ /* resize output */
+ output.resize_({nslices, otime, oheight, owidth});
+
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+ "avg_pool3d_out_frame",
+ [&] {
+ scalar_t *input_data = input.data<scalar_t>();
+ scalar_t *output_data = output.data<scalar_t>();
+
+ avg_pool3d_out_frame(
+ input_data, output_data, nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH,
+ dT, dW, dH,
+ padT, padW, padH,
+ count_include_pad);
+ });
+ }
+ else /* batch mode */
+ {
+ const int64_t nbatch = input.size(0);
+ const int64_t istride = nslices * itime * iwidth * iheight;
+ const int64_t ostride = nslices * otime * owidth * oheight;
+
+ /* resize output */
+ output.resize_({nbatch, nslices, otime, oheight, owidth});
+
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+ "avg_pool3d_out_frame",
+ [&] {
+ scalar_t *input_data = input.data<scalar_t>();
+ scalar_t *output_data = output.data<scalar_t>();
+
+ at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
+ for (auto p = start; p < end; p++) {
+ avg_pool3d_out_frame(
+ input_data + p * istride, output_data + p * ostride, nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH,
+ dT, dW, dH,
+ padT, padW, padH,
+ count_include_pad
+ );
+ }
+ });
+ });
+ }
+}
+
+template <typename scalar_t>
+static void avg_pool3d_backward_out_frame(
+ scalar_t *gradInput_p,
+ scalar_t *gradOutput_p,
+ int64_t nslices,
+ int64_t itime,
+ int64_t iwidth,
+ int64_t iheight,
+ int64_t otime,
+ int64_t owidth,
+ int64_t oheight,
+ int kT,
+ int kW,
+ int kH,
+ int dT,
+ int dW,
+ int dH,
+ int padT,
+ int padW,
+ int padH,
+ bool count_include_pad)
+{
+ at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
+ for (auto k = start; k < end; k++)
+ {
+ int64_t i, j, ti;
+
+ /* local pointers */
+ scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
+ scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
+      for (i = 0; i < itime * iwidth * iheight; i++)
+ *(ip + i) = 0;
+
+ /* loop over output */
+ for (ti = 0; ti < otime; ti++)
+ {
+ for (i = 0; i < oheight; i++)
+ {
+ for (j = 0; j < owidth; j++)
+ {
+ int64_t tstart = ti * dT - padT;
+ int64_t hstart = i * dH - padH;
+ int64_t wstart = j * dW - padW;
+ int64_t tend = std::min(tstart + kT, itime + padT);
+ int64_t hend = std::min(hstart + kH, iheight + padH);
+ int64_t wend = std::min(wstart + kW, iwidth + padW);
+          int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+ tstart = std::max(tstart, (int64_t) 0);
+ hstart = std::max(hstart, (int64_t) 0);
+ wstart = std::max(wstart, (int64_t) 0);
+ tend = std::min(tend, itime);
+ hend = std::min(hend, iheight);
+ wend = std::min(wend, iwidth);
+
+ int64_t divide_factor;
+ if (count_include_pad)
+ divide_factor = pool_size;
+ else
+ divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
+
+ /* scatter gradients out to footprint: */
+ scalar_t val = *op++;
+
+          int64_t x, y, z;
+ for (z = tstart; z < tend; z++)
+ {
+ for (y = hstart; y < hend; y++)
+ {
+ for (x = wstart; x < wend; x++)
+ {
+ *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ });
+}
+
+Tensor& avg_pool3d_backward_out_cpu_template(
+ Tensor& gradInput,
+ const Tensor& gradOutput_,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ // #20866 [JIT] stride.empty() is passed through
+ // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+ TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+ (stride.empty() || stride.size() == 3) &&
+ (padding.size() == 1 || padding.size() == 3),
+ "avg_pool3d: all IntArrayRef sizes must be 3");
+
+ TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+ "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+ const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+ const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+ const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+ const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+ const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+ const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+ const int padT = safe_downcast<int, int64_t>(padding[0]);
+ const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+ const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+ const int64_t nslices = input.size(-4);
+ const int64_t itime = input.size(-3);
+ const int64_t iheight = input.size(-2);
+ const int64_t iwidth = input.size(-1);
+
+ /* get contiguous gradOutput */
+ Tensor gradOutput = gradOutput_.contiguous();
+
+ const int64_t otime = gradOutput.size(-3);
+ const int64_t oheight = gradOutput.size(-2);
+ const int64_t owidth = gradOutput.size(-1);
+
+ /* XXX shape check behavior from TH */
+ const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+ const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+ const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+ avg_pool3d_backward_shape_check(
+ input,
+ gradOutput_,
+ nslices,
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ itime, iheight, iwidth,
+ otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check);
+
+ /* resize */
+ gradInput.resize_as_(input);
+ gradInput.zero_();
+
+ /* backprop */
+ if (input.ndimension() == 4) /* non-batch mode*/
+ {
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+ "avg_pool3d_backward_out_frame",
+ [&] {
+ scalar_t *gradInput_data = gradInput.data<scalar_t>();
+ scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+ avg_pool3d_backward_out_frame(
+ gradInput_data, gradOutput_data,
+ nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH,
+ dT, dW, dH,
+ padT, padW, padH,
+ count_include_pad);
+ });
+ }
+ else /* batch mode */
+ {
+ const int64_t nbatch = input.size(0);
+ const int64_t istride = nslices * itime * iwidth * iheight;
+ const int64_t ostride = nslices * otime * owidth * oheight;
+
+ AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
+ "avg_pool3d_backward_out_frame",
+ [&] {
+ scalar_t *gradInput_data = gradInput.data<scalar_t>();
+ scalar_t *gradOutput_data = gradOutput.data<scalar_t>();
+
+ at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
+ for (auto p = start; p < end; p++)
+ {
+ avg_pool3d_backward_out_frame(
+ gradInput_data + p * istride, gradOutput_data + p * ostride, nslices,
+ itime, iwidth, iheight,
+ otime, owidth, oheight,
+ kT, kW, kH,
+ dT, dW, dH,
+ padT, padW, padH,
+ count_include_pad
+ );
+ }
+ });
+ });
+ }
+
+ return gradInput;
+}
+
+} // namespace
+
+Tensor& avg_pool3d_out_cpu(
+ Tensor& output,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ avg_pool3d_out_cpu_template(
+ output,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return output;
+}
+
+Tensor avg_pool3d_cpu(
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ Tensor output = at::empty({0}, input.options());
+ avg_pool3d_out_cpu_template(
+ output,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return output;
+}
+
+Tensor& avg_pool3d_backward_out_cpu(
+ Tensor& gradInput,
+ const Tensor& gradOutput_,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ avg_pool3d_backward_out_cpu_template(
+ gradInput,
+ gradOutput_,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return gradInput;
+}
+
+Tensor avg_pool3d_backward_cpu(
+ const Tensor& gradOutput_,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ auto gradInput = at::zeros_like(input);
+ avg_pool3d_backward_out_cpu_template(
+ gradInput,
+ gradOutput_,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return gradInput;
+}
+
+} // at::native
+} // at
diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp
index 5d4ffb3..c3b3290 100644
--- a/aten/src/ATen/native/DilatedMaxPool3d.cpp
+++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp
@@ -184,7 +184,7 @@
const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);
- max_pool3d_with_indices_shape_check(
+ pool3d_shape_check(
input_,
nslices,
kT, kH, kW,
@@ -396,7 +396,7 @@
const int64_t oheight = gradOutput.size(-2);
const int64_t owidth = gradOutput.size(-1);
- max_pool3d_with_indices_shape_check(
+ max_pool3d_backward_shape_check(
input,
gradOutput,
indices,
diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h
index c204722..d9169ef 100644
--- a/aten/src/ATen/native/Pool.h
+++ b/aten/src/ATen/native/Pool.h
@@ -132,8 +132,9 @@
check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
}
+// AveragePool3d/DilatedMaxPool3d (forward)
static inline void
-max_pool3d_with_indices_shape_check(
+pool3d_shape_check(
const Tensor& input,
int64_t nslices,
int kT, int kH, int kW,
@@ -141,7 +142,8 @@
int pT, int pH, int pW,
int dilationT, int dilationH, int dilationW,
int64_t itime, int64_t iheight, int64_t iwidth,
- int64_t otime, int64_t oheight, int64_t owidth)
+ int64_t otime, int64_t oheight, int64_t owidth,
+ bool check_input_size=false)
{
const int64_t ndim = input.ndimension();
@@ -158,6 +160,12 @@
TORCH_CHECK(input.numel() > 0 && (ndim == 4 || ndim == 5),
"non-empty 4D or 5D (batch mode) tensor expected for input, but got ndim: ", ndim);
+ if (check_input_size) { // AveragePool3d
+ TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
+ "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
+ "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
+ }
+
TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
"pad should be smaller than half of kernel size, but got "
"kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
@@ -171,7 +179,7 @@
}
static inline void
-max_pool3d_with_indices_shape_check(
+max_pool3d_backward_shape_check(
const Tensor& input,
const Tensor& gradOutput,
const Tensor& indices,
@@ -185,7 +193,7 @@
{
const int64_t ndim = input.ndimension();
- max_pool3d_with_indices_shape_check(
+ pool3d_shape_check(
input,
nslices,
kT, kH, kW,
@@ -206,6 +214,36 @@
check_dim_size(indices, ndim, ndim-1, owidth);
}
+static inline void
+avg_pool3d_backward_shape_check(
+ const Tensor& input,
+ const Tensor& gradOutput,
+ int64_t nslices,
+ int kT, int kH, int kW,
+ int dT, int dH, int dW,
+ int pT, int pH, int pW,
+ int64_t itime, int64_t iheight, int64_t iwidth,
+ int64_t otime, int64_t oheight, int64_t owidth)
+{
+ const int64_t ndim = input.ndimension();
+
+ pool3d_shape_check(
+ input,
+ nslices,
+ kT, kH, kW,
+ dT, dH, dW,
+ pT, pH, pW,
+ 1, 1, 1,
+ itime, iheight, iwidth,
+ otime, oheight, owidth,
+ true);
+
+ check_dim_size(gradOutput, ndim, ndim-4, nslices);
+ check_dim_size(gradOutput, ndim, ndim-3, otime);
+ check_dim_size(gradOutput, ndim, ndim-2, oheight);
+ check_dim_size(gradOutput, ndim, ndim-1, owidth);
+}
+
} // namespace
} // at::native
diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu
new file mode 100644
index 0000000..439d0ac
--- /dev/null
+++ b/aten/src/ATen/native/cuda/AveragePool3d.cu
@@ -0,0 +1,675 @@
+#include <ATen/AccumulateType.h>
+#include <ATen/native/Pool.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <THC/THCNumerics.cuh>
+#include <c10/macros/Macros.h>
+
+
+namespace at {
+namespace native {
+namespace {
+
+__device__ inline int min(int a, int b) {
+ return a <= b ? a : b;
+}
+
+__device__ inline int max(int a, int b) {
+ return a >= b ? a : b;
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_output(
+ PackedTensorAccessor<scalar_t, 4> input,
+ PackedTensorAccessor<scalar_t, 4> output,
+ int kT, int kH, int kW,
+ int dT, int dH, int dW,
+ int padT, int padH, int padW,
+ bool count_include_pad,
+ int offsetZ)
+{
+ int oCol = blockIdx.x * blockDim.x + threadIdx.x;
+ int oRow = blockIdx.y * blockDim.y + threadIdx.y;
+ int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature
+
+ if (oRow < output.size(2) && oCol < output.size(3))
+ {
+ accscalar_t sum = 0.0;
+
+ int tstart = oFrame * dT - padT;
+ int hstart = oRow * dH - padH;
+ int wstart = oCol * dW - padW;
+ int tend = min(tstart + kT, input.size(1) + padT);
+ int hend = min(hstart + kH, input.size(2) + padH);
+ int wend = min(wstart + kW, input.size(3) + padW);
+ int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+ tstart = max(tstart, 0);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ tend = min(tend, input.size(1));
+ hend = min(hend, input.size(2));
+ wend = min(wend, input.size(3));
+
+ accscalar_t divide_factor;
+ if (count_include_pad)
+ divide_factor = static_cast<accscalar_t>(pool_size);
+ else
+ divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+ int ti, hi, wi;
+ for (ti = tstart; ti < tend; ++ti)
+ {
+ for (hi = hstart; hi < hend; ++hi)
+ {
+ for (wi = wstart; wi < wend; ++wi)
+ {
+ scalar_t val = input[slice][ti][hi][wi];
+ sum += val;
+ }
+ }
+ }
+
+ output[slice][oFrame][oRow][oCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum / divide_factor);
+ }
+}
+
+// Inner-most loop size (kW) passed as template parameter for
+// performance reasons.
+//
+template<int KERNEL_WIDTH, typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_output(
+ PackedTensorAccessor<scalar_t, 4> input,
+ PackedTensorAccessor<scalar_t, 4> output,
+ int kT, int kH,
+ int dT, int dH, int dW,
+ int padT, int padH, int padW,
+ bool count_include_pad,
+ int offsetZ)
+{
+ int oCol = blockIdx.x * blockDim.x + threadIdx.x;
+ int oRow = blockIdx.y * blockDim.y + threadIdx.y;
+ int oFrame = (blockIdx.z + offsetZ) % output.size(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.size(1); // output slice/feature
+
+ if (oRow < output.size(2) && oCol < output.size(3))
+ {
+ accscalar_t sum = 0.0;
+
+ int tstart = oFrame * dT - padT;
+ int hstart = oRow * dH - padH;
+ int wstart = oCol * dW - padW;
+ int tend = min(tstart + kT, input.size(1) + padT);
+ int hend = min(hstart + kH, input.size(2) + padH);
+ int wend = min(wstart + KERNEL_WIDTH, input.size(3) + padW);
+ int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+ tstart = max(tstart, 0);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ tend = min(tend, input.size(1));
+ hend = min(hend, input.size(2));
+ wend = min(wend, input.size(3));
+
+ accscalar_t divide_factor;
+ if (count_include_pad)
+ divide_factor = static_cast<accscalar_t>(pool_size);
+ else
+ divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+ int ti, hi, wi;
+ for (ti = tstart; ti < tend; ++ti)
+ {
+ for (hi = hstart; hi < hend; ++hi)
+ {
+ for (wi = wstart; wi < wend; ++wi)
+ {
+ scalar_t val = input[slice][ti][hi][wi];
+ sum += val;
+ }
+ }
+ }
+
+ output[slice][oFrame][oRow][oCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum / divide_factor);
+ }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_single_backward_out_frame_stride1(
+ PackedTensorAccessor<scalar_t, 4> gradOutput,
+ PackedTensorAccessor<scalar_t, 4> gradInput,
+ int kT, int kH, int kW,
+ accscalar_t normFactor,
+ int offsetZ)
+{
+ int iCol = blockIdx.x * blockDim.x + threadIdx.x;
+ int iRow = blockIdx.y * blockDim.y + threadIdx.y;
+ int iFrame = (blockIdx.z + offsetZ) % gradInput.size(1); // input frame/time
+ int slice = (blockIdx.z + offsetZ) / gradInput.size(1); // input slice/feature
+
+ // guard against over-tiled threads
+ if (iRow < gradInput.size(2) && iCol < gradInput.size(3))
+ {
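+    // With stride 1 and no padding every window has the full kT*kH*kW volume, so each
+    // overlapping gradOutput element contributes equally; normFactor (1/(kT*kH*kW))
+    // is applied once to the accumulated sum at the end.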
+ accscalar_t sum = 0.0;
+ scalar_t *gOut = &gradOutput[slice][max(0, iFrame - kT + 1)]
+ [max(0, iRow - kH + 1)][max(0, iCol - kW + 1)];
+ int frameOffset = 0;
+ for (int oFrame = max(0, iFrame - kT + 1);
+ oFrame < min(iFrame + 1, gradOutput.size(1));
+ ++oFrame)
+ {
+ int rowOffset = frameOffset;
+ for (int oRow = max(0, iRow - kH + 1);
+ oRow < min(iRow + 1, gradOutput.size(2));
+ ++oRow)
+ {
+ int colOffset = rowOffset;
+ for (int oCol = max(0, iCol - kW + 1);
+ oCol < min(iCol + 1, gradOutput.size(3));
+ ++oCol)
+ {
+ sum += gOut[colOffset];
+ ++colOffset;
+ }
+ rowOffset += gradOutput.size(3);
+ }
+ frameOffset += gradOutput.size(2) * gradOutput.size(3);
+ }
+ gradInput[slice][iFrame][iRow][iCol] = ScalarConvert<accscalar_t, scalar_t>::to(sum * normFactor);
+ }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_grad_input_atomic(
+ PackedTensorAccessor<scalar_t, 4> gradOutput,
+ PackedTensorAccessor<scalar_t, 4> gradInput,
+ int kT, int kH, int kW,
+ int dT, int dH, int dW,
+ int padT, int padH, int padW,
+ bool count_include_pad,
+ int offsetZ)
+{
+ int oCol = blockIdx.x * blockDim.x + threadIdx.x;
+ int oRow = blockIdx.y * blockDim.y + threadIdx.y;
+ int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time
+ int slice = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature
+
+ // guard against over-tiled threads
+ if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3))
+ {
+ int tstart = oFrame * dT - padT;
+ int hstart = oRow * dH - padH;
+ int wstart = oCol * dW - padW;
+ int tend = min(tstart + kT, gradInput.size(1) + padT);
+ int hend = min(hstart + kH, gradInput.size(2) + padH);
+ int wend = min(wstart + kW, gradInput.size(3) + padW);
+ int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+ tstart = max(tstart, 0);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ tend = min(tend, gradInput.size(1));
+ hend = min(hend, gradInput.size(2));
+ wend = min(wend, gradInput.size(3));
+
+ accscalar_t divide_factor;
+ if (count_include_pad)
+ divide_factor = static_cast<accscalar_t>(pool_size);
+ else
+ divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+ scalar_t val = ScalarConvert<accscalar_t, scalar_t>::to(
+ ScalarConvert<scalar_t, accscalar_t>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
+ for (int iFrame = tstart; iFrame < tend; ++iFrame)
+ {
+ for (int iRow = hstart; iRow < hend; ++iRow)
+ {
+ for (int iCol = wstart; iCol < wend; ++iCol)
+ {
+ atomicAdd(&gradInput[slice][iFrame][iRow][iCol], val);
+ }
+ }
+ }
+ }
+}
+
+template <typename scalar_t, typename accscalar_t>
+__global__ void avg_pool3d_cuda_update_grad_input(
+ PackedTensorAccessor<scalar_t, 4> gradOutput,
+ PackedTensorAccessor<scalar_t, 4> gradInput,
+ int kT, int kH, int kW,
+ int dT, int dH, int dW,
+ int padT, int padH, int padW,
+ bool count_include_pad, int offsetZ)
+{
+ int oCol = blockIdx.x * blockDim.x + threadIdx.x;
+ int oRow = blockIdx.y * blockDim.y + threadIdx.y;
+ int oFrame = (blockIdx.z + offsetZ) % gradOutput.size(1); // gradOutput frame/time
+ int slice = (blockIdx.z + offsetZ) / gradOutput.size(1); // gradOutput slice/feature
+
+ // guard against over-tiled threads
+ if (oRow < gradOutput.size(2) && oCol < gradOutput.size(3))
+ {
+ int tstart = oFrame * dT - padT;
+ int hstart = oRow * dH - padH;
+ int wstart = oCol * dW - padW;
+ int tend = min(tstart + kT, gradInput.size(1) + padT);
+ int hend = min(hstart + kH, gradInput.size(2) + padH);
+ int wend = min(wstart + kW, gradInput.size(3) + padW);
+ int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
+ tstart = max(tstart, 0);
+ hstart = max(hstart, 0);
+ wstart = max(wstart, 0);
+ tend = min(tend, gradInput.size(1));
+ hend = min(hend, gradInput.size(2));
+ wend = min(wend, gradInput.size(3));
+
+ accscalar_t divide_factor;
+ if (count_include_pad)
+ divide_factor = static_cast<accscalar_t>(pool_size);
+ else
+ divide_factor = static_cast<accscalar_t>((tend - tstart) * (hend - hstart) * (wend - wstart));
+
+ scalar_t val = ScalarConvert<accscalar_t, scalar_t>::to(
+ ScalarConvert<scalar_t, accscalar_t>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
+ for (int iFrame = tstart; iFrame < tend; ++iFrame)
+ {
+ for (int iRow = hstart; iRow < hend; ++iRow)
+ {
+ for (int iCol = wstart; iCol < wend; ++iCol)
+ {
+ gradInput[slice][iFrame][iRow][iCol] = val;
+ }
+ }
+ }
+ }
+}
+
+#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
+ avg_pool3d_cuda_update_output<KW, scalar_t, accscalar_t> \
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( \
+ work_input.packed_accessor<scalar_t, 4>(), \
+ work_output.packed_accessor<scalar_t, 4>(), \
+ kT, kH, \
+ dT, dH, dW, \
+ padT, padH, padW, \
+ count_include_pad, \
+ offsetZ); \
+ break
+
+void avg_pool3d_out_cuda_template(
+ Tensor& output,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ TensorArg output_arg{ output, "output", 1 };
+ TensorArg input_arg{ input, "input", 2 };
+
+ checkAllSameGPU("avg_pool3d_out_cuda", {output_arg, input_arg});
+
+ // #20866 [JIT] stride.empty() is passed through
+ // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+ TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+ (stride.empty() || stride.size() == 3) &&
+ (padding.size() == 1 || padding.size() == 3),
+ "avg_pool3d: all IntArrayRef sizes must be 3");
+
+ TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+ "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+ const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+ const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+ const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+ const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+ const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+ const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+ const int padT = safe_downcast<int, int64_t>(padding[0]);
+ const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+ const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+ const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
+ const int64_t nslices = input.size(-4);
+ const int64_t itime = input.size(-3);
+ const int64_t iheight = input.size(-2);
+ const int64_t iwidth = input.size(-1);
+
+ const int64_t otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+ const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+ const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+ pool3d_shape_check(
+ input,
+ nslices,
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ 1, 1, 1,
+ itime, iheight, iwidth,
+ otime, oheight, owidth,
+ /*check_input_size=*/ true);
+
+ if (input.ndimension() == 4) {
+    output.resize_({nslices, otime, oheight, owidth});
+ }
+ else {
+ output.resize_({nbatch, nslices, otime, oheight, owidth});
+ }
+
+ Tensor work_input = input.contiguous();
+ Tensor work_output = output;
+ if (input.ndimension() == 5) {
+ // Collapse batch and feature dimensions.
+ work_input = work_input.reshape({nbatch * nslices, itime, iheight, iwidth});
+ work_output = work_output.reshape({nbatch * nslices, otime, oheight, owidth});
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(),
+ "avg_pool3d_out_cuda",
+ [&] {
+ using accscalar_t = acc_type<scalar_t, true>;
+ int64_t totalZ = otime * nslices * nbatch;
+ int64_t offsetZ = 0;
+ dim3 block(32, 8);
+
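+      // gridDim.z is capped at 65535, so the flattened (time x slice x batch) planes
+      // are processed in chunks of at most 65535, advancing offsetZ each pass.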
+ while (totalZ > 0) {
+ dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+ cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+
+ switch (kW) {
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
+ default:
+ avg_pool3d_cuda_update_output<scalar_t, accscalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ work_input.packed_accessor<scalar_t, 4>(),
+ work_output.packed_accessor<scalar_t, 4>(),
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ count_include_pad,
+ offsetZ);
+ break;
+ }
+
+      cudaError_t err = cudaGetLastError();
+      TORCH_CHECK(err == cudaSuccess,
+        "avg_pool3d_out_cuda failed with error code ", err);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ }
+ }
+ );
+}
+
+#undef LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH
+
+void avg_pool3d_backward_out_cuda_template(
+ Tensor& gradInput,
+ const Tensor& gradOutput,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ TensorArg gradInput_arg{ gradInput, "gradInput", 1 };
+ TensorArg gradOutput_arg{ gradOutput, "gradOutput", 2 };
+ TensorArg input_arg{ input, "input", 3 };
+
+ checkAllSameGPU("avg_pool3d_backward_out_cuda",
+ {gradInput_arg, gradOutput_arg, input_arg});
+
+ // #20866 [JIT] stride.empty() is passed through
+ // #20866 [LIBTORCH] IntegrationTest.MNIST: padding.size() == 1
+ TORCH_INTERNAL_ASSERT(kernel_size.size() == 3 &&
+ (stride.empty() || stride.size() == 3) &&
+ (padding.size() == 1 || padding.size() == 3),
+ "avg_pool3d: all IntArrayRef sizes must be 3");
+
+ TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5),
+ "non-empty 4D or 5D (batch mode) tensor expected for input");
+
+ TORCH_CHECK((gradOutput.ndimension() == 4 || gradOutput.ndimension() == 5),
+ "non-empty 4D or 5D (batch mode) tensor expected for gradOutput");
+
+ // Resize and initialize result tensor.
+ gradInput.resize_as_(input);
+ gradInput.zero_();
+
+ const int kT = safe_downcast<int, int64_t>(kernel_size[0]);
+ const int kH = safe_downcast<int, int64_t>(kernel_size[1]);
+ const int kW = safe_downcast<int, int64_t>(kernel_size[2]);
+
+ const int dT = stride.empty() ? kT : safe_downcast<int, int64_t>(stride[0]);
+ const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[1]);
+ const int dW = stride.empty() ? kW : safe_downcast<int, int64_t>(stride[2]);
+
+ const int padT = safe_downcast<int, int64_t>(padding[0]);
+ const int padH = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[1]);
+ const int padW = padding.size() == 1 ? padT : safe_downcast<int, int64_t>(padding[2]);
+
+ const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
+ const int64_t nslices = input.size(-4);
+ const int64_t itime = input.size(-3);
+ const int64_t iheight = input.size(-2);
+ const int64_t iwidth = input.size(-1);
+
+ const int64_t otime = gradOutput.size(-3);
+ const int64_t oheight = gradOutput.size(-2);
+ const int64_t owidth = gradOutput.size(-1);
+
+ /* XXX shape check behavior from TH */
+ const int64_t otime_for_shape_check = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
+ const int64_t oheight_for_shape_check = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
+  const int64_t owidth_for_shape_check = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
+
+ const bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW);
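+  // Overlapping windows (stride < kernel) make several gradOutput elements scatter into
+  // the same gradInput location, so the backward pass must use the atomicAdd kernel.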
+
+ avg_pool3d_backward_shape_check(
+ input,
+ gradOutput,
+ nslices,
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ itime, iheight, iwidth,
+    otime_for_shape_check, oheight_for_shape_check, owidth_for_shape_check);
+
+ Tensor work_grad_input = gradInput;
+ Tensor work_grad_output = gradOutput.contiguous();
+
+ if (input.ndimension() == 5) {
+ // Collapse batch and feature dimensions.
+ work_grad_input = work_grad_input.reshape({nbatch * nslices, itime, iheight, iwidth});
+ work_grad_output = work_grad_output.reshape({nbatch * nslices, otime, oheight, owidth});
+ }
+
+
+ // Optimizing for stride 1 is probably only of limited value, but this
+ // specialization yields 3x speedup over the atomicAdd implementation.
+  // Padding must be 0; otherwise the pool size may change.
+ if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0) {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+ "avg_pool3d_backward_out_frame_stride1",
+ [&] {
+ using accscalar_t = acc_type<scalar_t, true>;
+ int64_t totalZ = itime * nslices * nbatch;
+ int64_t offsetZ = 0;
+ dim3 block(32, 8);
+
+ while (totalZ > 0) {
+ dim3 grid(cuda::ATenCeilDiv(iwidth, static_cast<int64_t>(block.x)),
+ cuda::ATenCeilDiv(iheight, static_cast<int64_t>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+
+ avg_pool3d_single_backward_out_frame_stride1<scalar_t, accscalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ work_grad_output.packed_accessor<scalar_t, 4>(),
+ work_grad_input.packed_accessor<scalar_t, 4>(),
+ kT, kH, kW,
+ 1.0f/(kT * kH * kW),
+ offsetZ);
+
+        cudaError_t err = cudaGetLastError();
+        TORCH_CHECK(err == cudaSuccess,
+          "avg_pool3d_backward_out_frame failed with error code ", err);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ }
+ }
+ );
+ }
+ else {
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(),
+ "avg_pool3d_backward_out_frame",
+ [&] {
+ using accscalar_t = acc_type<scalar_t, true>;
+ int64_t totalZ = otime * nslices * nbatch;
+ int64_t offsetZ = 0;
+ dim3 block(32, 8);
+
+ while (totalZ > 0) {
+ dim3 grid(cuda::ATenCeilDiv(owidth, static_cast<int64_t>(block.x)),
+ cuda::ATenCeilDiv(oheight, static_cast<int64_t>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+
+ if (kernelsOverlap) {
+ avg_pool3d_cuda_update_grad_input_atomic<scalar_t, accscalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ work_grad_output.packed_accessor<scalar_t, 4>(),
+ work_grad_input.packed_accessor<scalar_t, 4>(),
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ count_include_pad,
+ offsetZ);
+ }
+ else {
+ avg_pool3d_cuda_update_grad_input<scalar_t, accscalar_t>
+ <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ work_grad_output.packed_accessor<scalar_t, 4>(),
+ work_grad_input.packed_accessor<scalar_t, 4>(),
+ kT, kH, kW,
+ dT, dH, dW,
+ padT, padH, padW,
+ count_include_pad,
+ offsetZ);
+ }
+
+        cudaError_t err = cudaGetLastError();
+        TORCH_CHECK(err == cudaSuccess,
+          "avg_pool3d_backward_out_frame failed with error code ", err);
+
+ totalZ -= 65535;
+ offsetZ += 65535;
+ }
+ }
+ );
+ }
+}
+
+} // namespace
+
+Tensor& avg_pool3d_out_cuda(
+ Tensor& output,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ avg_pool3d_out_cuda_template(
+ output,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return output;
+}
+
+Tensor avg_pool3d_cuda(
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ Tensor output = at::empty({0}, input.options());
+ avg_pool3d_out_cuda_template(
+ output,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return output;
+}
+
+Tensor& avg_pool3d_backward_out_cuda(
+ Tensor& gradInput,
+ const Tensor& gradOutput_,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ avg_pool3d_backward_out_cuda_template(
+ gradInput,
+ gradOutput_,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return gradInput;
+}
+
+Tensor avg_pool3d_backward_cuda(
+ const Tensor& gradOutput_,
+ const Tensor& input,
+ IntArrayRef kernel_size,
+ IntArrayRef stride,
+ IntArrayRef padding,
+ bool ceil_mode,
+ bool count_include_pad)
+{
+ auto gradInput = at::zeros_like(input);
+ avg_pool3d_backward_out_cuda_template(
+ gradInput,
+ gradOutput_,
+ input,
+ kernel_size,
+ stride,
+ padding,
+ ceil_mode,
+ count_include_pad);
+ return gradInput;
+}
+
+} // at::native
+} // at
diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
index 9305966..e6b39b0 100644
--- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
+++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu
@@ -327,7 +327,7 @@
const int64_t oheight = pooling_output_shape<int64_t>(iheight, kH, pH, dH, dilationH, ceil_mode);
const int64_t owidth = pooling_output_shape<int64_t>(iwidth, kW, pW, dW, dilationW, ceil_mode);
- max_pool3d_with_indices_shape_check(
+ pool3d_shape_check(
input,
nslices,
kT, kH, kW,
@@ -440,7 +440,7 @@
const int64_t iheight = gradInput.size(-2);
const int64_t iwidth = gradInput.size(-1);
- max_pool3d_with_indices_shape_check(
+ max_pool3d_backward_shape_check(
input,
gradOutput,
indices,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1e44157..6cbc84c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4737,26 +4737,26 @@
- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
- CPU: legacy::cpu::_thnn_avg_pool3d_forward_out
- CUDA: legacy::cuda::_thnn_avg_pool3d_forward_out
+ CPU: avg_pool3d_out_cpu
+ CUDA: avg_pool3d_out_cuda
- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
python_module: nn
dispatch:
- CPU: legacy::cpu::_thnn_avg_pool3d_forward
- CUDA: legacy::cuda::_thnn_avg_pool3d_forward
+ CPU: avg_pool3d_cpu
+ CUDA: avg_pool3d_cuda
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, *, Tensor(a!) grad_input) -> Tensor(a!)
python_module: nn
dispatch:
- CPU: legacy::cpu::_thnn_avg_pool3d_backward_out
- CUDA: legacy::cuda::_thnn_avg_pool3d_backward_out
+ CPU: avg_pool3d_backward_out_cpu
+ CUDA: avg_pool3d_backward_out_cuda
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad) -> Tensor
python_module: nn
dispatch:
- CPU: legacy::cpu::_thnn_avg_pool3d_backward
- CUDA: legacy::cuda::_thnn_avg_pool3d_backward
+ CPU: avg_pool3d_backward_cpu
+ CUDA: avg_pool3d_backward_cuda
# Return: (Tensor output, Tensor indices)
- func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index f99ca2b..e925b04 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -116,16 +116,6 @@
output: 'false'
grad_input: 'false'
-# Pooling
-
-- name: _thnn_avg_pool3d(Tensor self, IntArrayRef[3] kernel_size, IntArrayRef[3] stride={}, IntArrayRef[3] padding=0, bool ceil_mode=false, bool count_include_pad=true)
- cname: VolumetricAveragePooling
- default_init:
- stride: kernel_size
- scalar_check:
- output: 'false'
- grad_input: 'false'
-
# Private functions. These also exist in TH, but we want the backwards functions
# to implement derivatives.
diff --git a/aten/src/THCUNN/CMakeLists.txt b/aten/src/THCUNN/CMakeLists.txt
index ecb5ee4..2915ef4 100644
--- a/aten/src/THCUNN/CMakeLists.txt
+++ b/aten/src/THCUNN/CMakeLists.txt
@@ -39,7 +39,6 @@
${CMAKE_CURRENT_SOURCE_DIR}/TemporalConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/TemporalMaxPooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/TemporalRowConvolution.cu
-${CMAKE_CURRENT_SOURCE_DIR}/VolumetricAveragePooling.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricDilatedConvolution.cu
${CMAKE_CURRENT_SOURCE_DIR}/VolumetricFullConvolution.cu
diff --git a/aten/src/THCUNN/VolumetricAveragePooling.cu b/aten/src/THCUNN/VolumetricAveragePooling.cu
deleted file mode 100644
index 56e1d69..0000000
--- a/aten/src/THCUNN/VolumetricAveragePooling.cu
+++ /dev/null
@@ -1,279 +0,0 @@
-#include <THCUNN/THCUNN.h>
-#include <THC/THCTensor.hpp>
-#include <THCUNN/common.h>
-#include <THC/THCDeviceTensor.cuh>
-#include <THC/THCDeviceTensorUtils.cuh>
-#include <THC/THCDeviceUtils.cuh>
-#include <TH/THHalf.h>
-#include <THCUNN/THCHalfAutoNumerics.cuh>
-#include <THC/THCAtomics.cuh>
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateOutput(
- THCDeviceTensor<Dtype, 4> input,
- THCDeviceTensor<Dtype, 4> output,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int padT, int padH, int padW,
- bool count_include_pad, int offsetZ)
-{
- int oCol = blockIdx.x * blockDim.x + threadIdx.x;
- int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
- int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
-
- if (oRow < output.getSize(2) && oCol < output.getSize(3))
- {
- Acctype sum = 0.0;
-
- int tstart = oFrame * dT - padT;
- int hstart = oRow * dH - padH;
- int wstart = oCol * dW - padW;
- int tend = min(tstart + kT, input.getSize(1) + padT);
- int hend = min(hstart + kH, input.getSize(2) + padH);
- int wend = min(wstart + kW, input.getSize(3) + padW);
- int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = max(tstart, 0);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- tend = min(tend, input.getSize(1));
- hend = min(hend, input.getSize(2));
- wend = min(wend, input.getSize(3));
-
- Acctype divide_factor;
- if (count_include_pad)
- divide_factor = static_cast<Acctype>(pool_size);
- else
- divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
- int ti, hi, wi;
- for (ti = tstart; ti < tend; ++ti)
- {
- for (hi = hstart; hi < hend; ++hi)
- {
- for (wi = wstart; wi < wend; ++wi)
- {
- Dtype val = input[slice][ti][hi][wi];
- sum += val;
- }
- }
- }
-
- output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor);
- }
-}
-
-// Inner-most loop size (kW) passed as template parameter for
-// performance reasons.
-//
-template<int KERNEL_WIDTH, typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateOutput_fixedKW(
- THCDeviceTensor<Dtype, 4> input,
- THCDeviceTensor<Dtype, 4> output,
- int kT, int kH,
- int dT, int dH, int dW,
- int padT, int padH, int padW,
- bool count_include_pad, int offsetZ)
-{
- int oCol = blockIdx.x * blockDim.x + threadIdx.x;
- int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
- int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
-
- if (oRow < output.getSize(2) && oCol < output.getSize(3))
- {
- Acctype sum = 0.0;
-
- int tstart = oFrame * dT - padT;
- int hstart = oRow * dH - padH;
- int wstart = oCol * dW - padW;
- int tend = min(tstart + kT, input.getSize(1) + padT);
- int hend = min(hstart + kH, input.getSize(2) + padH);
- int wend = min(wstart + KERNEL_WIDTH, input.getSize(3) + padW);
- int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = max(tstart, 0);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- tend = min(tend, input.getSize(1));
- hend = min(hend, input.getSize(2));
- wend = min(wend, input.getSize(3));
-
- Acctype divide_factor;
- if (count_include_pad)
- divide_factor = static_cast<Acctype>(pool_size);
- else
- divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
- int ti, hi, wi;
- for (ti = tstart; ti < tend; ++ti)
- {
- for (hi = hstart; hi < hend; ++hi)
- {
- for (wi = wstart; wi < wend; ++wi)
- {
- Dtype val = input[slice][ti][hi][wi];
- sum += val;
- }
- }
- }
-
- output[slice][oFrame][oRow][oCol] = ScalarConvert<Acctype, Dtype>::to(sum / divide_factor);
- }
-}
-
-#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
- cuda_VolumetricAveragePooling_updateOutput_fixedKW<KW, scalar_t, accreal> \
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
- cudaInput, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, count_include_pad, offsetZ); \
- break
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1(
- THCDeviceTensor<Dtype, 4> gradOutput,
- THCDeviceTensor<Dtype, 4> gradInput,
- int kT, int kH, int kW,
- Acctype normFactor, int offsetZ)
-{
- int iCol = blockIdx.x * blockDim.x + threadIdx.x;
- int iRow = blockIdx.y * blockDim.y + threadIdx.y;
- int iFrame = (blockIdx.z + offsetZ) % gradInput.getSize(1); // input frame/time
- int slice = (blockIdx.z + offsetZ) / gradInput.getSize(1); // input slice/feature
-
- // guard against over-tiled threads
- if (iRow < gradInput.getSize(2) && iCol < gradInput.getSize(3))
- {
- Acctype sum = 0.0;
- Dtype *gOut = &gradOutput[slice][max(0, iFrame - kT + 1)]
- [max(0, iRow - kH + 1)][max(0, iCol - kW + 1)];
- int frameOffset = 0;
- for (int oFrame = max(0, iFrame - kT + 1);
- oFrame < min(iFrame + 1, gradOutput.getSize(1));
- ++oFrame)
- {
- int rowOffset = frameOffset;
- for (int oRow = max(0, iRow - kH + 1);
- oRow < min(iRow + 1, gradOutput.getSize(2));
- ++oRow)
- {
- int colOffset = rowOffset;
- for (int oCol = max(0, iCol - kW + 1);
- oCol < min(iCol + 1, gradOutput.getSize(3));
- ++oCol)
- {
- sum += gOut[colOffset];
- ++colOffset;
- }
- rowOffset += gradOutput.getSize(3);
- }
- frameOffset += gradOutput.getSize(2) * gradOutput.getSize(3);
- }
- gradInput[slice][iFrame][iRow][iCol] = ScalarConvert<Acctype, Dtype>::to(sum * normFactor);
- }
-}
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd(
- THCDeviceTensor<Dtype, 4> gradOutput,
- THCDeviceTensor<Dtype, 4> gradInput,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int padT, int padH, int padW,
- bool count_include_pad, int offsetZ)
-{
- int oCol = blockIdx.x * blockDim.x + threadIdx.x;
- int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
- int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
-
- // guard against over-tiled threads
- if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
- {
- int tstart = oFrame * dT - padT;
- int hstart = oRow * dH - padH;
- int wstart = oCol * dW - padW;
- int tend = min(tstart + kT, gradInput.getSize(1) + padT);
- int hend = min(hstart + kH, gradInput.getSize(2) + padH);
- int wend = min(wstart + kW, gradInput.getSize(3) + padW);
- int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = max(tstart, 0);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- tend = min(tend, gradInput.getSize(1));
- hend = min(hend, gradInput.getSize(2));
- wend = min(wend, gradInput.getSize(3));
-
- Acctype divide_factor;
- if (count_include_pad)
- divide_factor = static_cast<Acctype>(pool_size);
- else
- divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
- Dtype val = ScalarConvert<Acctype, Dtype>::to(
- ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
- for (int iFrame = tstart; iFrame < tend; ++iFrame)
- {
- for (int iRow = hstart; iRow < hend; ++iRow)
- {
- for (int iCol = wstart; iCol < wend; ++iCol)
- {
- atomicAdd(&gradInput[slice][iFrame][iRow][iCol], val);
- }
- }
- }
- }
-}
-
-template <typename Dtype, typename Acctype>
-__global__ void cuda_VolumetricAveragePooling_updateGradInput(
- THCDeviceTensor<Dtype, 4> gradOutput,
- THCDeviceTensor<Dtype, 4> gradInput,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int padT, int padH, int padW,
- bool count_include_pad, int offsetZ)
-{
- int oCol = blockIdx.x * blockDim.x + threadIdx.x;
- int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
- int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
-
- // guard against over-tiled threads
- if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
- {
- int tstart = oFrame * dT - padT;
- int hstart = oRow * dH - padH;
- int wstart = oCol * dW - padW;
- int tend = min(tstart + kT, gradInput.getSize(1) + padT);
- int hend = min(hstart + kH, gradInput.getSize(2) + padH);
- int wend = min(wstart + kW, gradInput.getSize(3) + padW);
- int pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = max(tstart, 0);
- hstart = max(hstart, 0);
- wstart = max(wstart, 0);
- tend = min(tend, gradInput.getSize(1));
- hend = min(hend, gradInput.getSize(2));
- wend = min(wend, gradInput.getSize(3));
-
- Acctype divide_factor;
- if (count_include_pad)
- divide_factor = static_cast<Acctype>(pool_size);
- else
- divide_factor = static_cast<Acctype>((tend - tstart) * (hend - hstart) * (wend - wstart));
-
- Dtype val = ScalarConvert<Acctype, Dtype>::to(
- ScalarConvert<Dtype, Acctype>::to(gradOutput[slice][oFrame][oRow][oCol]) / divide_factor);
- for (int iFrame = tstart; iFrame < tend; ++iFrame)
- {
- for (int iRow = hstart; iRow < hend; ++iRow)
- {
- for (int iCol = wstart; iCol < wend; ++iCol)
- {
- gradInput[slice][iFrame][iRow][iCol] = val;
- }
- }
- }
- }
-}
-
-#include <THCUNN/generic/VolumetricAveragePooling.cu>
-#include <THC/THCGenerateFloatTypes.h>
diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h
index 7eed4f9..a685994 100644
--- a/aten/src/THCUNN/generic/THCUNN.h
+++ b/aten/src/THCUNN/generic/THCUNN.h
@@ -867,27 +867,6 @@
bool featFirst,
accreal scale);
-THC_API void THNN_(VolumetricAveragePooling_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode,
- bool count_include_pad);
-
-THC_API void THNN_(VolumetricAveragePooling_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode,
- bool count_include_pad);
-
// VolumetricConvolution is legacy and purposefully not bound by ATen
THC_API void THNN_(VolumetricConvolution_updateOutput)(
THCState *state,
diff --git a/aten/src/THCUNN/generic/VolumetricAveragePooling.cu b/aten/src/THCUNN/generic/VolumetricAveragePooling.cu
deleted file mode 100644
index 47939ef..0000000
--- a/aten/src/THCUNN/generic/VolumetricAveragePooling.cu
+++ /dev/null
@@ -1,337 +0,0 @@
-#ifndef THC_GENERIC_FILE
-#define THC_GENERIC_FILE "THCUNN/generic/VolumetricAveragePooling.cu"
-#else
-
-#include <THCUNN/generic/pooling_shape.h>
-
-static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode)
-{
- int inputSlices;
- int inputTime;
- int inputHeight;
- int inputWidth;
-
- int ndim = input->dim();
- int dimN = 0;
- int dimt = 1;
- int dimh = 2;
- int dimw = 3;
-
- if (input->dim() == 5)
- {
- dimN++;
- dimt++;
- dimh++;
- dimw++;
- }
-
- if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 4)
- {
- THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
- && input->size(dimt) >= kT, 2,
- "input image (T: %d H: %d W: %d) smaller than "
- "kernel size (kT: %d kH: %d kW: %d)",
- input->size(dimt), input->size(dimh), input->size(dimw),
- kT, kH, kW);
-
- /* sizes */
- inputSlices = THCTensor_(size)(state, input, 0);
- inputTime = THCTensor_(size)(state, input, 1);
- inputHeight = THCTensor_(size)(state, input, 2);
- inputWidth = THCTensor_(size)(state, input, 3);
- }
- else if (!input->is_empty() && THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5)
- {
- THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
- && input->size(dimt) >= kT, 2,
- "input image (T: %d H: %d W: %d) smaller than "
- "kernel size (kT: %d kH: %d kW: %d)",
- input->size(dimt), input->size(dimh), input->size(dimw),
- kT, kH, kW);
-
- /* sizes */
- inputSlices = THCTensor_(size)(state, input, 1);
- inputTime = THCTensor_(size)(state, input, 2);
- inputHeight = THCTensor_(size)(state, input, 3);
- inputWidth = THCTensor_(size)(state, input, 4);
- }
- else
- {
- AT_ERROR("non-empty 4D or 5D tensor expected, but got size: ", input->sizes());
- }
-
- // The second argument is the index of padH.
- THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11,
- "pad should not be greater than half of kernel size, but got "
- "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d",
- padT, padW, padH, kT, kW, kH);
-
- int outputTime = pooling_output_shape<int>(inputTime, kT, padT, dT, 1, ceil_mode);
- int outputHeight = pooling_output_shape<int>(inputHeight, kH, padH, dH, 1, ceil_mode);
- int outputWidth = pooling_output_shape<int>(inputWidth, kW, padW, dW, 1, ceil_mode);
-
- if (gradOutput != NULL)
- {
- THCUNN_check_dim_size(state, gradOutput, ndim, dimN, inputSlices);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimt, outputTime);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimh, outputHeight);
- THCUNN_check_dim_size(state, gradOutput, ndim, dimw, outputWidth);
- }
-}
-
-void THNN_(VolumetricAveragePooling_updateOutput)(
- THCState *state,
- THCTensor *input,
- THCTensor *output,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode,
- bool count_include_pad)
-{
- int batchSize;
- int inputSlices;
- int inputTime;
- int inputHeight;
- int inputWidth;
-
- int dimt = 1;
- int dimh = 2;
- int dimw = 3;
-
- int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
- if (fiveDimensionalInput)
- {
- dimt++;
- dimh++;
- dimw++;
- }
-
- THNN_(VolumetricAveragePooling_shapeCheck)
- (state, input, NULL, kT, kW, kH, dT, dW, dH,
- padT, padW, padH, ceil_mode);
-
- if (!fiveDimensionalInput) /* 4D */
- {
- /* sizes */
- batchSize = 1;
- inputSlices = THCTensor_(size)(state, input, 0);
- inputTime = THCTensor_(size)(state, input, 1);
- inputHeight = THCTensor_(size)(state, input, 2);
- inputWidth = THCTensor_(size)(state, input, 3);
- }
- else /* 5D */
- {
- /* sizes */
- batchSize = THCTensor_(size)(state, input, 0);
- inputSlices = THCTensor_(size)(state, input, 1);
- inputTime = THCTensor_(size)(state, input, 2);
- inputHeight = THCTensor_(size)(state, input, 3);
- inputWidth = THCTensor_(size)(state, input, 4);
- }
-
- int outputTime = pooling_output_shape<int>(inputTime, kT, padT, dT, 1, ceil_mode);
- int outputHeight = pooling_output_shape<int>(inputHeight, kH, padH, dH, 1, ceil_mode);
- int outputWidth = pooling_output_shape<int>(inputWidth, kW, padW, dW, 1, ceil_mode);
-
- if (!fiveDimensionalInput) /* 4D */
- {
- /* resize output */
- THCTensor_(resize4d)(state, output, inputSlices,
- outputTime, outputHeight, outputWidth);
- }
- else /* 5D */
- {
- THCTensor_(resize5d)(state, output, batchSize, inputSlices,
- outputTime, outputHeight, outputWidth);
- }
-
- input = THCTensor_(newContiguous)(state, input);
- if (fiveDimensionalInput) {
- // Collapse batch and feature dimensions
- output = THCTensor_(newFoldBatchDim)(state, output);
-
- THCTensor *old_input = input;
- input = THCTensor_(newFoldBatchDim)(state, input);
- THCTensor_(free)(state, old_input);
- } else {
- THCTensor_(retain)(state, output);
- }
-
- THCDeviceTensor<scalar_t, 4> cudaInput;
- THCDeviceTensor<scalar_t, 4> cudaOutput;
- cudaInput = toDeviceTensor<scalar_t, 4>(state, input);
- cudaOutput = toDeviceTensor<scalar_t, 4>(state, output);
-
- int totalZ = outputTime * inputSlices * batchSize;
- int offsetZ = 0;
- dim3 block(32, 8);
- while (totalZ > 0) {
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- totalZ > 65535 ? 65535 : totalZ);
-
- switch (kW)
- {
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
- default:
- cuda_VolumetricAveragePooling_updateOutput<scalar_t, accreal>
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- cudaInput,
- cudaOutput,
- kT, kH, kW,
- dT, dH, dW,
- padT, padH, padW,
- count_include_pad,
- offsetZ);
- break;
- }
- totalZ -= 65535;
- offsetZ += 65535;
- THCudaCheck(cudaGetLastError());
- }
-
- THCTensor_(free)(state, input);
- THCTensor_(free)(state, output);
-}
-
-void THNN_(VolumetricAveragePooling_updateGradInput)(
- THCState *state,
- THCTensor *input,
- THCTensor *gradOutput,
- THCTensor *gradInput,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode,
- bool count_include_pad)
-{
- THNN_(VolumetricAveragePooling_shapeCheck)
- (state, input, gradOutput, kT, kW, kH, dT, dW, dH,
- padT, padW, padH, ceil_mode);
- bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW);
-
- // Resize and initialize result tensor.
- THCTensor_(resizeAs)(state, gradInput, input);
- THCTensor_(zero)(state, gradInput);
-
- int batchSize;
- int inputSlices;
- int inputTime;
- int inputHeight;
- int inputWidth;
-
- int outputTime;
- int outputHeight;
- int outputWidth;
-
- int fiveDimensionalInput = THCTensor_(nDimensionLegacyNoScalars)(state, input) == 5;
- if (!fiveDimensionalInput) /* 4D */
- {
- batchSize = 1;
- inputSlices = THCTensor_(size)(state, input, 0);
- inputTime = THCTensor_(size)(state, input, 1);
- inputHeight = THCTensor_(size)(state, input, 2);
- inputWidth = THCTensor_(size)(state, input, 3);
-
- outputTime = THCTensor_(size)(state, gradOutput, 1);
- outputHeight = THCTensor_(size)(state, gradOutput, 2);
- outputWidth = THCTensor_(size)(state, gradOutput, 3);
- }
- else
- {
- batchSize = THCTensor_(size)(state, input, 0);
- inputSlices = THCTensor_(size)(state, input, 1);
- inputTime = THCTensor_(size)(state, input, 2);
- inputHeight = THCTensor_(size)(state, input, 3);
- inputWidth = THCTensor_(size)(state, input, 4);
-
- outputTime = THCTensor_(size)(state, gradOutput, 2);
- outputHeight = THCTensor_(size)(state, gradOutput, 3);
- outputWidth = THCTensor_(size)(state, gradOutput, 4);
- }
-
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- if (fiveDimensionalInput) {
- // Collapse batch and feature dimensions
- gradInput = THCTensor_(newFoldBatchDim)(state, gradInput);
-
- THCTensor *old_gradOutput = gradOutput;
- gradOutput = THCTensor_(newFoldBatchDim)(state, gradOutput);
- THCTensor_(free)(state, old_gradOutput);
- } else {
- THCTensor_(retain)(state, gradInput);
- }
-
- THCDeviceTensor<scalar_t, 4> cudaGradInput;
- THCDeviceTensor<scalar_t, 4> cudaGradOutput;
- cudaGradInput = toDeviceTensor<scalar_t, 4>(state, gradInput);
- cudaGradOutput = toDeviceTensor<scalar_t, 4>(state, gradOutput);
-
- dim3 block(32, 8);
-
- // Optimizing for stride 1 is probably only of limited value, but this
- // specialization yields 3x speedup over the atomicAdd implementation.
-  // Padding must be 0; otherwise the pool size may change.
- if (dT == 1 && dH == 1 && dW == 1 && padT == 0 && padH == 0 && padW == 0)
- {
- int totalZ = inputTime * inputSlices * batchSize;
- int offsetZ = 0;
- while (totalZ > 0) {
- dim3 grid(THCCeilDiv(inputWidth, static_cast<int>(block.x)),
- THCCeilDiv(inputHeight, static_cast<int>(block.y)),
- totalZ > 65535 ? 65535 : totalZ);
- cuda_VolumetricAveragePooling_updateGradInput_Stride1<scalar_t, accreal>
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, 1.0f/(kT * kH * kW), offsetZ);
- THCudaCheck(cudaGetLastError());
- totalZ -= 65535;
- offsetZ += 65535;
- }
- }
- else
- {
- int totalZ = outputTime * inputSlices * batchSize;
- int offsetZ = 0;
- while (totalZ > 0) {
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- totalZ > 65535 ? 65535 : totalZ);
- if (kernelsOverlap)
- {
- cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<scalar_t, accreal>
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW,
- padT, padH, padW, count_include_pad, offsetZ);
- }
- else
- {
- cuda_VolumetricAveragePooling_updateGradInput<scalar_t, accreal>
- <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW,
- padT, padH, padW, count_include_pad, offsetZ);
- }
- THCudaCheck(cudaGetLastError());
- totalZ -= 65535;
- offsetZ += 65535;
- }
- }
-
- THCTensor_(free)(state, gradInput);
- THCTensor_(free)(state, gradOutput);
-}
-
-#endif
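
Both the CUDA kernels above and the CPU implementation removed below size their outputs through pooling_output_shape<T>(inputSize, kernelSize, pad, stride, 1, ceil_mode). The standalone sketch below reproduces that arithmetic for dilation 1 so the otime/oheight/owidth values in these hunks can be checked by hand; the helper body here is an assumption reconstructed from the call sites, not code copied from this patch.

#include <cstdint>
#include <cstdio>

// Sketch of the pooling output-size arithmetic, assuming dilation == 1 and
// non-negative sizes: floor-divide by the stride normally, ceil-divide when
// ceil_mode is set, and drop a trailing window that would start entirely
// inside the padding.
static int64_t pooling_output_size(
    int64_t input, int64_t kernel, int64_t pad, int64_t stride, bool ceil_mode) {
  int64_t numerator = input + 2 * pad - kernel + (ceil_mode ? stride - 1 : 0);
  int64_t output = numerator / stride + 1;
  if (ceil_mode && (output - 1) * stride >= input + pad) {
    --output;
  }
  return output;
}

int main() {
  // e.g. itime = 6, kT = 3, padT = 0, dT = 2
  std::printf("%lld\n", (long long)pooling_output_size(6, 3, 0, 2, /*ceil_mode=*/false));  // 2
  std::printf("%lld\n", (long long)pooling_output_size(6, 3, 0, 2, /*ceil_mode=*/true));   // 3
  return 0;
}

For that example, floor mode yields 2 output positions and ceil_mode yields 3, which is the only effect the ceil_mode flag has on the shape checks in these files.
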
diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h
index 6cde3bc..ff76b71 100644
--- a/aten/src/THNN/generic/THNN.h
+++ b/aten/src/THNN/generic/THNN.h
@@ -489,24 +489,6 @@
int inputWidth, int inputHeight,
int outputWidth, int outputHeight);
-TH_API void THNN_(VolumetricAveragePooling_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode, bool count_include_pad);
-TH_API void THNN_(VolumetricAveragePooling_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- int kT, int kW, int kH,
- int dT, int dW, int dH,
- int padT, int padW, int padH,
- bool ceil_mode, bool count_include_pad);
-
TH_API void THNN_(VolumetricDilatedConvolution_updateOutput)(
THNNState *state,
THTensor *input,
diff --git a/aten/src/THNN/generic/VolumetricAveragePooling.c b/aten/src/THNN/generic/VolumetricAveragePooling.c
deleted file mode 100644
index 3671413..0000000
--- a/aten/src/THNN/generic/VolumetricAveragePooling.c
+++ /dev/null
@@ -1,465 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "THNN/generic/VolumetricAveragePooling.c"
-#else
-
-#include <THNN/generic/pooling_shape.h>
-#include <algorithm>
-
-#include <ATen/Parallel.h>
-
-static inline void THNN_(VolumetricAveragePooling_shapeCheck)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- int kT,
- int kW,
- int kH,
- int dT,
- int dW,
- int dH,
- int padT,
- int padW,
- int padH,
- bool ceil_mode)
-{
- int64_t nslices;
- int64_t itime;
- int64_t iheight;
- int64_t iwidth;
- int64_t otime;
- int64_t oheight;
- int64_t owidth;
- int ndim = input->dim();
- int dimN = 0;
- int dimt = 1;
- int dimh = 2;
- int dimw = 3;
-
- if (input->dim() == 5)
- {
- dimN++;
- dimt++;
- dimh++;
- dimw++;
- }
-
- THArgCheck(kT > 0 && kW > 0 && kH > 0, 5,
- "kernel size should be greater than zero, but got kT: %d kH: %d kW: %d",
- kT, kH, kW);
- THArgCheck(dT > 0 && dW > 0 && dH > 0, 8,
- "stride should be greater than zero, but got dT: %d dH: %d dW: %d",
- dT, dH, dW);
- THNN_ARGCHECK(!input->is_empty() && (input->dim() == 4 || input->dim() == 5), 2, input,
- "non-empty 4D or 5D (batch mode) tensor expected for input, but got: %s");
-
- THArgCheck(input->size(dimw) >= kW && input->size(dimh) >= kH
- && input->size(dimt) >= kT, 2,
- "input image (T: %d H: %d W: %d) smaller than "
- "kernel size (kT: %d kH: %d kW: %d)",
- input->size(dimt), input->size(dimh), input->size(dimw),
- kT, kH, kW);
-
-  // The second argument to THArgCheck is argNumber; here it is the index of padH.
- THArgCheck(kT/2 >= padT && kW/2 >= padW && kH/2 >= padH, 11,
- "pad should not be greater than half of kernel size, but got "
- "padT = %d, padW = %d, padH = %d, kT = %d, kW = %d, kH = %d",
- padT, padW, padH, kT, kW, kH);
-
- /* sizes */
- nslices = input->size(dimN);
- itime = input->size(dimt);
- iheight = input->size(dimh);
- iwidth = input->size(dimw);
-
- otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
- oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
- owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
-
- if (otime < 1 || owidth < 1 || oheight < 1)
- THError("Given input size: (%dx%dx%dx%d). "
- "Calculated output size: (%dx%dx%dx%d). Output size is too small",
- nslices,itime,iheight,iwidth,nslices,otime,oheight,owidth);
-
- if (gradOutput != NULL) {
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimN, nslices);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimt, otime);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimh, oheight);
- THNN_CHECK_DIM_SIZE(gradOutput, ndim, dimw, owidth);
- }
-}
-
-static void THNN_(VolumetricAveragePooling_updateOutput_frame)(
- scalar_t *input_p,
- scalar_t *output_p,
- int64_t nslices,
- int64_t itime,
- int64_t iwidth,
- int64_t iheight,
- int64_t otime,
- int64_t owidth,
- int64_t oheight,
- int kT,
- int kW,
- int kH,
- int dT,
- int dW,
- int dH,
- int padT,
- int padW,
- int padH,
- bool count_include_pad)
-{
- at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
- for (auto k = start; k < end; k++)
- {
- int64_t i, j, ti;
-
- /* local pointers. */
- scalar_t *ip = input_p + k * itime * iwidth * iheight;
- scalar_t *op = output_p + k * otime * owidth * oheight;
- for (i = 0; i < otime * oheight * owidth; ++i)
- *(op + i) = 0;
-
- /* loop over output */
- for (ti = 0; ti < otime; ti++)
- {
- for (i = 0; i < oheight; i++)
- {
- for (j = 0; j < owidth; j++)
- {
- /* compute pool range. */
- int64_t tstart = ti * dT - padT;
- int64_t hstart = i * dH - padH;
- int64_t wstart = j * dW - padW;
- int64_t tend = std::min(tstart + kT, itime + padT);
- int64_t hend = std::min(hstart + kH, iheight + padH);
- int64_t wend = std::min(wstart + kW, iwidth + padW);
- int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = std::max(tstart, (int64_t) 0);
- hstart = std::max(hstart, (int64_t) 0);
- wstart = std::max(wstart, (int64_t) 0);
- tend = std::min(tend, itime);
- hend = std::min(hend, iheight);
- wend = std::min(wend, iwidth);
-
- int divide_factor;
- if (count_include_pad)
- divide_factor = pool_size;
- else
- divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
-
- /* compute local sum: */
- scalar_t sum = 0.0;
- int64_t x, y, z;
-
- for (z = tstart; z < tend; z++)
- {
- for (y = hstart; y < hend; y++)
- {
- for (x = wstart; x < wend; x++)
- {
- sum += *(ip + z * iwidth * iheight + y * iwidth + x);
- }
- }
- }
-
-          /* set output to local average */
- *op++ += sum / divide_factor;
- }
- }
- }
- }
- });
-}
-
-void THNN_(VolumetricAveragePooling_updateOutput)(
- THNNState *state,
- THTensor *input,
- THTensor *output,
- int kT,
- int kW,
- int kH,
- int dT,
- int dW,
- int dH,
- int padT,
- int padW,
- int padH,
- bool ceil_mode,
- bool count_include_pad)
-{
- int64_t nslices;
- int64_t itime;
- int64_t iheight;
- int64_t iwidth;
- int64_t otime;
- int64_t oheight;
- int64_t owidth;
- scalar_t *input_data;
- scalar_t *output_data;
-
- THNN_(VolumetricAveragePooling_shapeCheck)(
- state, input, NULL, kT, kW, kH,
- dT, dW, dH, padT, padW, padH, ceil_mode);
-
- int dimN = 0;
- int dimt = 1;
- int dimh = 2;
- int dimw = 3;
-
- if (input->dim() == 5)
- {
- dimN++;
- dimt++;
- dimh++;
- dimw++;
- }
-
- /* sizes */
- nslices = input->size(dimN);
- itime = input->size(dimt);
- iheight = input->size(dimh);
- iwidth = input->size(dimw);
- otime = pooling_output_shape<int64_t>(itime, kT, padT, dT, 1, ceil_mode);
- oheight = pooling_output_shape<int64_t>(iheight, kH, padH, dH, 1, ceil_mode);
- owidth = pooling_output_shape<int64_t>(iwidth, kW, padW, dW, 1, ceil_mode);
-
- /* get contiguous input */
- input = THTensor_(newContiguous)(input);
-
- if (input->dim() == 4) /* non-batch mode */
- {
- /* resize output */
- THTensor_(resize4d)(output, nslices, otime, oheight, owidth);
-
- input_data = input->data<scalar_t>();
- output_data = output->data<scalar_t>();
-
- THNN_(VolumetricAveragePooling_updateOutput_frame)(
- input_data, output_data, nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- kT, kW, kH,
- dT, dW, dH,
- padT, padW, padH,
- count_include_pad
- );
- }
- else /* batch mode */
- {
- int64_t nBatch = input->size(0);
-
- int64_t istride = nslices * itime * iwidth * iheight;
- int64_t ostride = nslices * otime * owidth * oheight;
-
- /* resize output */
- THTensor_(resize5d)(output, nBatch, nslices, otime, oheight, owidth);
-
- input_data = input->data<scalar_t>();
- output_data = output->data<scalar_t>();
-
- at::parallel_for(0, nBatch, 0, [&](int64_t start, int64_t end) {
- for (auto p = start; p < end; p++)
- {
- THNN_(VolumetricAveragePooling_updateOutput_frame)(
- input_data + p * istride, output_data + p * ostride, nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- kT, kW, kH,
- dT, dW, dH,
- padT, padW, padH,
- count_include_pad
- );
- }
- });
- }
-
- /* cleanup */
- c10::raw::intrusive_ptr::decref(input);
-}
-
-static void THNN_(VolumetricAveragePooling_updateGradInput_frame)(
- scalar_t *gradInput_p,
- scalar_t *gradOutput_p,
- int64_t nslices,
- int64_t itime,
- int64_t iwidth,
- int64_t iheight,
- int64_t otime,
- int64_t owidth,
- int64_t oheight,
- int kT,
- int kW,
- int kH,
- int dT,
- int dW,
- int dH,
- int padT,
- int padW,
- int padH,
- bool count_include_pad)
-{
- at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
- for (auto k = start; k < end; k++)
- {
- int64_t i, j, ti;
-
- /* local pointers */
- scalar_t *ip = gradInput_p + k * itime * iwidth * iheight;
- scalar_t *op = gradOutput_p + k * otime * owidth * oheight;
- for (i = 0; i < itime*iwidth*iheight; i++)
- *(ip + i) = 0;
-
- /* loop over output */
- for (ti = 0; ti < otime; ti++)
- {
- for (i = 0; i < oheight; i++)
- {
- for (j = 0; j < owidth; j++)
- {
- int64_t tstart = ti * dT - padT;
- int64_t hstart = i * dH - padH;
- int64_t wstart = j * dW - padW;
- int64_t tend = std::min(tstart + kT, itime + padT);
- int64_t hend = std::min(hstart + kH, iheight + padH);
- int64_t wend = std::min(wstart + kW, iwidth + padW);
-            int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
- tstart = std::max(tstart, (int64_t) 0);
- hstart = std::max(hstart, (int64_t) 0);
- wstart = std::max(wstart, (int64_t) 0);
- tend = std::min(tend, itime);
- hend = std::min(hend, iheight);
- wend = std::min(wend, iwidth);
-
- int64_t divide_factor;
- if (count_include_pad)
- divide_factor = pool_size;
- else
- divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
-
- /* scatter gradients out to footprint: */
- scalar_t val = *op++;
-
- int64_t x,y,z;
- for (z = tstart; z < tend; z++)
- {
- for (y = hstart; y < hend; y++)
- {
- for (x = wstart; x < wend; x++)
- {
- *(ip + z * iheight * iwidth + y * iwidth + x) += val / divide_factor;
- }
- }
- }
- }
- }
- }
- }
- });
-}
-
-void THNN_(VolumetricAveragePooling_updateGradInput)(
- THNNState *state,
- THTensor *input,
- THTensor *gradOutput,
- THTensor *gradInput,
- int kT,
- int kW,
- int kH,
- int dT,
- int dW,
- int dH,
- int padT,
- int padW,
- int padH,
- bool ceil_mode,
- bool count_include_pad)
-{
- int64_t nslices;
- int64_t itime;
- int64_t iheight;
- int64_t iwidth;
- int64_t otime;
- int64_t oheight;
- int64_t owidth;
- scalar_t *gradInput_data;
- scalar_t *gradOutput_data;
-
- int dimN = 0;
- int dimt = 1;
- int dimh = 2;
- int dimw = 3;
-
- THNN_(VolumetricAveragePooling_shapeCheck)(
- state, input, gradOutput, kT, kW, kH,
- dT, dW, dH, padT, padW, padH, ceil_mode);
-
- /* get contiguous gradOutput */
- gradOutput = THTensor_(newContiguous)(gradOutput);
-
- /* resize */
- THTensor_(resizeAs)(gradInput, input);
- THTensor_(zero)(gradInput);
-
- if (input->dim() == 5)
- {
- dimN++;
- dimt++;
- dimh++;
- dimw++;
- }
-
- /* sizes */
- nslices = input->size(dimN);
- itime = input->size(dimt);
- iheight = input->size(dimh);
- iwidth = input->size(dimw);
- otime = gradOutput->size(dimt);
- oheight = gradOutput->size(dimh);
- owidth = gradOutput->size(dimw);
-
- /* get raw pointers */
- gradInput_data = gradInput->data<scalar_t>();
- gradOutput_data = gradOutput->data<scalar_t>();
-
- /* backprop */
- if (input->dim() == 4) /* non-batch mode*/
- {
- THNN_(VolumetricAveragePooling_updateGradInput_frame)(
- gradInput_data, gradOutput_data, nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- kT, kW, kH,
- dT, dW, dH,
- padT, padW, padH,
- count_include_pad
- );
- }
- else /* batch mode */
- {
- int64_t nBatch = input->size(0);
-
- int64_t istride = nslices * itime * iwidth * iheight;
- int64_t ostride = nslices * otime * owidth * oheight;
-
- at::parallel_for(0, nBatch, 0, [&](int64_t start, int64_t end) {
- for (auto p = start; p < end; p++)
- {
- THNN_(VolumetricAveragePooling_updateGradInput_frame)(
- gradInput_data + p * istride, gradOutput_data + p * ostride, nslices,
- itime, iwidth, iheight,
- otime, owidth, oheight,
- kT, kW, kH,
- dT, dW, dH,
- padT, padW, padH,
- count_include_pad
- );
- }
- });
- }
-
- /* cleanup */
- c10::raw::intrusive_ptr::decref(gradOutput);
-}
-
-#endif
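
The subtle part of the deleted *_frame helpers is the count_include_pad handling: pool_size is measured against the padded extent before the window is clipped, while the alternative divisor is recomputed after clipping to the real input. The sketch below re-derives that divisor for a single 1-D border window; every name in it is local to the example.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Divide-factor logic from the deleted frame helpers, reduced to one border
// window along a single dimension.
int main() {
  const int64_t isize = 5, k = 3, pad = 1, stride = 2;
  const int64_t out_index = 0;                       // first output element
  int64_t start = out_index * stride - pad;          // -1: starts in the padding
  int64_t end   = std::min(start + k, isize + pad);  //  2
  const int64_t pool_size = end - start;             //  3, counts the padded cell
  start = std::max(start, (int64_t)0);               //  clip to the real input
  end   = std::min(end, isize);                      //  still 2
  const int64_t valid = end - start;                 //  2, padded cell excluded

  // count_include_pad == true divides the window sum by pool_size (3);
  // count_include_pad == false divides by valid (2), so border outputs are
  // averaged only over elements that actually exist in the input.
  std::printf("pool_size=%lld valid=%lld\n", (long long)pool_size, (long long)valid);
  return 0;
}
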
diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp
index 679a4ec..cf86038 100644
--- a/aten/src/THNN/init.cpp
+++ b/aten/src/THNN/init.cpp
@@ -136,9 +136,6 @@
#include <THNN/generic/SpatialDilatedConvolution.c>
#include <TH/THGenerateFloatTypes.h>
-#include <THNN/generic/VolumetricAveragePooling.c>
-#include <TH/THGenerateFloatTypes.h>
-
#include <THNN/generic/VolumetricConvolutionMM.c>
#include <TH/THGenerateFloatTypes.h>
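
For context on this hunk: including a generic .c file under TH/THGenerateFloatTypes.h stamps out a float and a double instantiation of every THNN_(...) function in it, which is why deleting VolumetricAveragePooling.c also removes its include pair here. A toy sketch of that pattern, using illustrative macro names rather than the real TH ones:

#include <cstdio>

// The same "generic" body is compiled once per scalar type by redefining
// scalar_t and a name-mangling macro before each inclusion.
#define CONCAT_IMPL(a, b) a##b
#define CONCAT(a, b) CONCAT_IMPL(a, b)

#define scalar_t float
#define MYNN_(name) CONCAT(Float_, name)
static scalar_t MYNN_(halve)(scalar_t x) { return x / 2; }
#undef MYNN_
#undef scalar_t

#define scalar_t double
#define MYNN_(name) CONCAT(Double_, name)
static scalar_t MYNN_(halve)(scalar_t x) { return x / 2; }
#undef MYNN_
#undef scalar_t

int main() {
  std::printf("%f %f\n", Float_halve(2.0f), Double_halve(3.0));
  return 0;
}
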
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index 1b53acf..d3dff4c 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -277,7 +277,6 @@
'SpatialConvolutionMM',
'TemporalConvolution',
'SpatialMaxUnpooling',
- 'VolumetricAveragePooling',
'VolumetricMaxUnpooling',
'VolumetricConvolution',
'VolumetricFullConvolution',
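
With the THCUNN and THNN bindings gone, callers reach the new kernels through the ATen avg_pool3d operator. A minimal usage sketch, assuming the operator takes (kernel_size, stride, padding, ceil_mode, count_include_pad) in that order as the ported kernels do; it is not a test taken from this PR.

#include <ATen/ATen.h>

// Call the ported operator on a 5D (batch mode) input.
int main() {
  at::Tensor input = at::randn({2, 3, 8, 8, 8});  // N, C, T, H, W
  at::Tensor out = at::avg_pool3d(
      input,
      /*kernel_size=*/{2, 2, 2},
      /*stride=*/{2, 2, 2},
      /*padding=*/{0, 0, 0},
      /*ceil_mode=*/false,
      /*count_include_pad=*/true);
  // Expected output shape: {2, 3, 4, 4, 4}.
  return 0;
}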