blob: 4df02b7c1def52ff26e32721b874a46baa6dcd52 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/Pool.h>
#include <tuple>
namespace at {
namespace native {
namespace {
/*
 * Forward max-pooling kernel over `nslices` independent 2-D planes.
 *
 * input_p   : `nslices` planes of `iheight * iwidth` values, stored back to back.
 * output_p  : `nslices` planes of `oheight * owidth` values (written).
 * ind_p     : same layout as output_p; for each output element receives the
 *             flat offset (y * iwidth + x) inside the input plane of the
 *             selected element, or -1 when the window contains no valid
 *             input element (see below).
 * kW/kH     : number of taps of the window along width / height.
 * dW/dH     : stride between consecutive windows.
 * padW/padH : implicit padding; window origin is (i*dH - padH, j*dW - padW).
 * dilationW/dilationH : spacing between taps of one window.
 *
 * Selection rules (unchanged from the original implementation):
 *   - taps that fall outside the input are skipped;
 *   - ties keep the first maximum encountered (strict `>` comparison);
 *   - a NaN always replaces the current maximum, so NaN propagates to the
 *     output.
 *
 * Planes are distributed over threads with at::parallel_for.
 */
template <typename scalar_t>
static void max_pool2d_with_indices_single_out_frame(
    scalar_t *input_p,
    scalar_t *output_p,
    int64_t *ind_p,
    int64_t nslices,
    int64_t iwidth,
    int64_t iheight,
    int64_t owidth,
    int64_t oheight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    int dilationW,
    int dilationH
    )
{
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (auto k = start; k < end; k++)
    {
      scalar_t *ip = input_p + k*iwidth*iheight;

      /* loop over output */
      for (int64_t i = 0; i < oheight; i++)
      {
        for (int64_t j = 0; j < owidth; j++)
        {
          int64_t hstart = i * dH - padH;
          int64_t wstart = j * dW - padW;
          /* one past the last tap, clipped to the input extent; the product
             is widened to 64 bits so it cannot overflow `int` */
          const int64_t hend = std::min(
              hstart + static_cast<int64_t>(kH - 1) * dilationH + 1, iheight);
          const int64_t wend = std::min(
              wstart + static_cast<int64_t>(kW - 1) * dilationW + 1, iwidth);
          /* skip taps that lie in the top / left padding */
          while (hstart < 0)
            hstart += dilationH;
          while (wstart < 0)
            wstart += dilationW;

          /* local pointers */
          scalar_t *op = output_p + k*owidth*oheight + i*owidth + j;
          int64_t *indp = ind_p + k*owidth*oheight + i*owidth + j;

          /* With dilation > 1 every tap of a window can land in the padding
             (e.g. kH=2, padH=1, dilationH=5, iheight=4: taps at -1 and 4).
             The scan below then never runs, and `hstart*iwidth + wstart`
             would point outside the plane. Record -1 instead, which the
             backward kernel in this file already skips. */
          const bool empty_window = (hstart >= hend) || (wstart >= wend);

          /* compute local max: */
          int64_t maxindex = empty_window ? -1 : hstart*iwidth + wstart;
          scalar_t maxval = -std::numeric_limits<scalar_t>::infinity();
          for (int64_t y = hstart; y < hend; y += dilationH)
          {
            for (int64_t x = wstart; x < wend; x += dilationW)
            {
              const int64_t tcntr = y*iwidth + x;
              const scalar_t val = *(ip + tcntr);
              if ((val > maxval) || std::isnan(val))
              {
                maxval = val;
                maxindex = tcntr;
              }
            }
          }

          /* set output to local max */
          *op = maxval;

          /* store location of max */
          *indp = maxindex;
        }
      }
    }
  });
}
/*
 * Batched forward max pooling.
 *
 * The buffers hold `nbatch` samples of `nInputPlane` planes each, stored
 * contiguously, so sample p / plane k starts at slice index
 * (p * nInputPlane + k). The batch and plane dimensions are therefore folded
 * into one run of `nbatch * nInputPlane` slices and handed to the per-plane
 * kernel in a single call.
 *
 * Why not an outer at::parallel_for over the batch (as before)?  The per-plane
 * kernel already calls at::parallel_for; NOTE(review): an at::parallel_for
 * issued from inside another parallel region appears to run on the calling
 * thread only (see ATen/Parallel.h), which limited the usable parallelism to
 * `nbatch` threads — a single thread for a batch of one. Folding the
 * dimensions produces the same output and lets all slices be distributed.
 */
template <typename scalar_t>
static void max_pool2d_with_indices_out_frame(
    scalar_t *input_data,
    scalar_t *output_data,
    int64_t *indices_data,
    int64_t nbatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int kW,
    int kH,
    int dW,
    int dH,
    int padW,
    int padH,
    int dilationW,
    int dilationH)
{
  max_pool2d_with_indices_single_out_frame(
      input_data,
      output_data,
      indices_data,
      nbatch * nInputPlane,
      inputWidth, inputHeight,
      outputWidth, outputHeight,
      kW, kH, dW, dH,
      padW, padH,
      dilationW, dilationH);
}
/*
 * CPU implementation of max_pool2d_with_indices writing into caller-provided
 * tensors.
 *
 * output / indices : resized here to (C, oH, oW) for a 3-D input or
 *                    (N, C, oH, oW) for a 4-D input, then filled.
 * input_           : 3-D or 4-D tensor; a contiguous copy is taken if needed.
 * kernel_size      : one int (used for both dims) or {kH, kW}.
 * stride           : empty (defaults to kernel_size), one int, or {dH, dW}.
 * padding          : one int or {padH, padW}.
 * dilation         : one int or {dilationH, dilationW}.
 * ceil_mode        : forwarded to pooling_output_shape.
 *
 * Argument arity is checked with TORCH_CHECK; the pooling geometry is then
 * validated by pool2d_shape_check (declared in ATen/native/Pool.h).
 */
void max_pool2d_with_indices_out_cpu_template(
    Tensor& output,
    Tensor& indices,
    const Tensor& input_,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK((input_.ndimension() == 3 || input_.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");

  /* sizes */
  const int64_t nbatch = input_.ndimension() == 4 ? input_.size(-4) : 1;
  const int64_t nInputPlane = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);

  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  pool2d_shape_check(
    input_,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight, outputWidth);

  /* get contiguous input: the kernels below index through raw pointers */
  Tensor input = input_.contiguous();

  /* resize output */
  if (input.ndimension() == 3)
  {
    output.resize_({nInputPlane, outputHeight, outputWidth});
    /* indices will contain the locations for each output point */
    indices.resize_({nInputPlane, outputHeight, outputWidth});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
      "max_pool2d_with_indices_cpu",
      [&] {
        /* get raw pointers */
        scalar_t *input_data = input.data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();
        int64_t *indices_data = indices.data_ptr<int64_t>();

        max_pool2d_with_indices_single_out_frame(
          input_data, output_data,
          indices_data,
          nInputPlane,
          inputWidth, inputHeight,
          outputWidth, outputHeight,
          kW, kH, dW, dH,
          padW, padH,
          dilationW, dilationH);
      }
    );
  }
  else
  {
    output.resize_({nbatch, nInputPlane, outputHeight, outputWidth});
    /* indices will contain the locations for each output point */
    indices.resize_({nbatch, nInputPlane, outputHeight, outputWidth});

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
      "max_pool2d_with_indices_cpu",
      [&] {
        /* get raw pointers */
        scalar_t *input_data = input.data_ptr<scalar_t>();
        scalar_t *output_data = output.data_ptr<scalar_t>();
        int64_t *indices_data = indices.data_ptr<int64_t>();

        max_pool2d_with_indices_out_frame(
          input_data,
          output_data,
          indices_data,
          nbatch,
          nInputPlane,
          inputWidth, inputHeight,
          outputWidth, outputHeight,
          kW, kH, dW, dH,
          padW, padH,
          dilationW, dilationH);
      }
    );
  }
}
/*
 * Backward max-pooling kernel over `nInputPlane` independent planes.
 *
 * For every output element the gradient from gradOutput_p is added to the
 * gradInput_p element named by the matching entry of ind_p (a flat offset
 * inside the input plane). Entries equal to -1 are skipped. gradInput_p is
 * accumulated into, not overwritten, so the caller must have initialised it.
 *
 * dW and dH are accepted for signature compatibility but are not read.
 */
template <typename scalar_t>
static void max_pool2d_with_indices_backward_single_out_frame(
    scalar_t *gradInput_p,
    scalar_t *gradOutput_p,
    int64_t *ind_p,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int dW,
    int dH)
{
  at::parallel_for(0, nInputPlane, 0, [&](int64_t first, int64_t last) {
    for (int64_t plane = first; plane < last; ++plane) {
      scalar_t *grad_in = gradInput_p + plane*inputWidth*inputHeight;
      scalar_t *grad_out = gradOutput_p + plane*outputWidth*outputHeight;
      int64_t *argmax = ind_p + plane*outputWidth*outputHeight;

      /* walk the output plane row by row, scattering each gradient */
      for (int64_t row = 0; row < outputHeight; ++row) {
        const int64_t row_base = row*outputWidth;
        for (int64_t col = 0; col < outputWidth; ++col) {
          const int64_t target = argmax[row_base + col];
          if (target == -1) {
            continue;  /* no element was selected for this output */
          }
          grad_in[target] += grad_out[row_base + col];
        }
      }
    }
  });
}
/*
 * Batched backward max pooling.
 *
 * The buffers hold `nbatch` samples of `nInputPlane` planes each, stored
 * contiguously, so sample p / plane k is slice (p * nInputPlane + k). Batch
 * and plane dimensions are folded into one run of `nbatch * nInputPlane`
 * slices and processed by a single call to the per-plane kernel.
 *
 * Why not an outer at::parallel_for over the batch (as before)?  The per-plane
 * kernel already calls at::parallel_for; NOTE(review): an at::parallel_for
 * issued from inside another parallel region appears to run on the calling
 * thread only (see ATen/Parallel.h), which capped parallelism at `nbatch`
 * threads. Each slice only touches its own gradInput plane, so the folded
 * form computes the same result.
 *
 * dW and dH are forwarded unchanged; the per-plane kernel does not read them.
 */
template <typename scalar_t>
static void max_pool2d_with_indices_backward_out_frame(
    scalar_t *gradInput_data,
    scalar_t *gradOutput_data,
    int64_t *indices_data,
    int64_t nbatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t outputWidth,
    int64_t outputHeight,
    int dW,
    int dH)
{
  max_pool2d_with_indices_backward_single_out_frame<scalar_t>(
      gradInput_data,
      gradOutput_data,
      indices_data,
      nbatch * nInputPlane,
      inputWidth, inputHeight,
      outputWidth, outputHeight,
      dW, dH);
}
/*
 * CPU implementation of the max_pool2d_with_indices backward pass.
 *
 * gradInput   : resized like `input`, zeroed, then filled by scattering
 *               gradOutput_ through `indices`; also the return value.
 * gradOutput_ : gradient w.r.t. the pooling output; made contiguous here.
 * input       : the forward input (3-D or 4-D); only its sizes and dtype are
 *               used.
 * indices     : flat per-plane offsets produced by the forward pass; made
 *               contiguous here because the kernels read it through a raw
 *               pointer.
 * kernel_size / stride / padding / dilation / ceil_mode : same conventions as
 *               the forward pass; used only to recompute the expected output
 *               shape for max_pool2d_backward_shape_check.
 *
 * All validation runs before gradInput is touched, so a failed check leaves
 * the caller's tensor unmodified.
 */
Tensor& max_pool2d_with_indices_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input,
    const Tensor& indices,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");

  /* get contiguous gradOutput */
  const Tensor gradOutput = gradOutput_.contiguous();

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);
  const int64_t outputHeight = gradOutput.size(-2);
  const int64_t outputWidth = gradOutput.size(-1);

  /* XXX preserve the existing shape check behavior */
  const int64_t outputHeight_for_shape_check = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth_for_shape_check = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  max_pool2d_backward_shape_check(
    input,
    gradOutput_,
    indices,
    nbatch,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight_for_shape_check, outputWidth_for_shape_check);

  /* The kernels walk `indices` through a raw pointer assuming a dense
     layout, so take a contiguous view (a no-op if it already is one). */
  const Tensor indices_contig = indices.contiguous();

  /* resize (only after validation succeeded) */
  gradInput.resize_as_(input);
  gradInput.zero_();

  /* backprop */
  if (input.ndimension() == 3)
  {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
      "max_pool2d_with_indices_backward",
      [&] {
        /* get raw pointers */
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
        int64_t *indices_data = indices_contig.data_ptr<int64_t>();

        max_pool2d_with_indices_backward_single_out_frame(
          gradInput_data, gradOutput_data,
          indices_data,
          nInputPlane,
          inputWidth, inputHeight,
          outputWidth, outputHeight,
          dW, dH);
      }
    );
  }
  else
  {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
      "max_pool2d_with_indices_backward",
      [&] {
        /* get raw pointers */
        scalar_t *gradInput_data = gradInput.data_ptr<scalar_t>();
        scalar_t *gradOutput_data = gradOutput.data_ptr<scalar_t>();
        int64_t *indices_data = indices_contig.data_ptr<int64_t>();

        max_pool2d_with_indices_backward_out_frame<scalar_t>(
          gradInput_data, gradOutput_data,
          indices_data,
          nbatch,
          nInputPlane,
          inputWidth, inputHeight,
          outputWidth, outputHeight,
          dW, dH);
      }
    );
  }

  return gradInput;
}
} // namespace
/*
 * out= variant of max_pool2d_with_indices on CPU: runs the forward template
 * on the caller's `output` and `indices` tensors and returns references to
 * those same two tensors.
 */
std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out_cpu(
    Tensor& output,
    Tensor& indices,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  max_pool2d_with_indices_out_cpu_template(
      output, indices, input,
      kernel_size, stride, padding, dilation,
      ceil_mode);
  return std::tie(output, indices);
}
/*
 * Functional variant of max_pool2d_with_indices on CPU: allocates empty
 * result tensors (values with the input's options, indices as kLong), lets
 * the forward template resize and fill them, and returns the pair
 * (output, indices).
 */
std::tuple<Tensor, Tensor> max_pool2d_with_indices_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  /* placeholders; the template resizes both to the pooled shape */
  Tensor output = at::empty({0}, input.options());
  Tensor indices = at::empty({0}, input.options().dtype(kLong));

  max_pool2d_with_indices_out_cpu_template(
      output, indices, input,
      kernel_size, stride, padding, dilation,
      ceil_mode);

  return std::make_tuple(output, indices);
}
/*
 * out= variant of the max_pool2d_with_indices backward pass on CPU.
 * Reorders the arguments (indices comes last here, fourth in the template),
 * delegates to the backward template and returns `gradInput`.
 */
Tensor& max_pool2d_with_indices_backward_out_cpu(
    Tensor& gradInput,
    const Tensor& gradOutput_,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices)
{
  /* the template returns a reference to the gradInput it was given */
  return max_pool2d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput_, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);
}
/*
 * Functional variant of the max_pool2d_with_indices backward pass on CPU:
 * allocates a zero tensor shaped like `input`, lets the backward template
 * fill it, and returns it by value.
 */
Tensor max_pool2d_with_indices_backward_cpu(
    const Tensor& gradOutput_,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices)
{
  Tensor gradInput = at::zeros_like(input);

  max_pool2d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput_, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);

  return gradInput;
}
} // at::native
} // at