// blob: 06b48271a8297db22bf03871897eab441bfcc736 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Pool.h>
namespace at {
namespace native {
namespace {
/**
 * Shared CPU implementation of the max_pool2d_with_indices forward pass.
 *
 * Validates the pooling hyper-parameters, computes the pooled spatial size,
 * resizes `output` and `indices` to that size, and dispatches to
 * `max_pool2d_kernel` with kCPU.
 *
 * @param output      resized in place to the pooled shape; must share the
 *                    dtype of `input`.
 * @param indices     resized in place to the same shape as `output`.
 * @param input       3D or 4D tensor; the last three dimensions are read as
 *                    (plane, height, width), and a leading batch dimension is
 *                    read when the tensor is 4D.
 * @param kernel_size one int (used for both height and width) or two ints
 *                    (height, width).
 * @param stride      empty (falls back to the kernel size), one int, or two
 *                    ints (height, width).
 * @param padding     one int or two ints (height, width).
 * @param dilation    one int or two ints (height, width).
 * @param ceil_mode   forwarded to pooling_output_shape.
 */
void max_pool2d_with_indices_out_cpu_template(
    Tensor& output,
    Tensor& indices,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");

  TORCH_CHECK(input.dtype() == output.dtype(),
    "expected dtype ", input.dtype(), " for `output` but got dtype ", output.dtype());

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  const int64_t outputHeight = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  pool2d_shape_check(
    input,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight, outputWidth, input.suggest_memory_format());

  /* resize output and indices */
  if (input.ndimension() == 3) {
    output.resize_({nInputPlane, outputHeight, outputWidth});
    /* indices will contain the locations for each output point */
    indices.resize_({nInputPlane, outputHeight, outputWidth});
  } else {
    // Batched case: keep the memory format suggested by the input.
    output.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, input.suggest_memory_format());
    /* indices will contain the locations for each output point */
    indices.resize_({nbatch, nInputPlane, outputHeight, outputWidth}, input.suggest_memory_format());
  }

  // Note the argument order: the kernel is called with width before height.
  max_pool2d_kernel(
    kCPU, output, indices, input,
    kW, kH,
    dW, dH,
    padW, padH,
    dilationW, dilationH);
}
/**
 * Shared CPU implementation of the max_pool2d_with_indices backward pass.
 *
 * Validates the pooling hyper-parameters and the shapes of `gradOutput` and
 * `indices` against the size implied by `input`, then resizes and zeroes
 * `gradInput` and dispatches to `max_pool2d_backward_kernel` with kCPU.
 * All validation runs before `gradInput` is touched, so a rejected call
 * leaves the caller's tensor unmodified.
 *
 * @param gradInput   resized in place to `input.sizes()` and zero-filled;
 *                    must share the dtype of `input`.
 * @param gradOutput  gradient w.r.t. the pooled output; must share the dtype
 *                    of `input`.
 * @param input       3D or 4D tensor that was pooled in the forward pass.
 * @param indices     indices tensor passed through to the shape check and
 *                    the backward kernel.
 * @param kernel_size one int or two ints (height, width).
 * @param stride      empty (falls back to the kernel size), one int, or two
 *                    ints (height, width).
 * @param padding     one int or two ints (height, width).
 * @param dilation    one int or two ints (height, width).
 * @param ceil_mode   forwarded to pooling_output_shape.
 * @return            reference to `gradInput`.
 */
Tensor& max_pool2d_with_indices_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    const Tensor& indices,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode)
{
  // #20866, #22032: Guarantee this for the official C++ API?
  TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2,
    "max_pool2d: kernel_size must either be a single int, or a tuple of two ints");
  const int kH = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kW = kernel_size.size() == 1 ? kH : safe_downcast<int, int64_t>(kernel_size[1]);

  // NB: stride default is not expressible as an integer constant, so we accept
  // empty stride for this case
  TORCH_CHECK(stride.size() == 0 || stride.size() == 1 || stride.size() == 2,
    "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints");
  const int dH = stride.empty() ? kH : safe_downcast<int, int64_t>(stride[0]);
  const int dW = stride.empty() ? kW :
                 stride.size() == 1 ? dH : safe_downcast<int, int64_t>(stride[1]);

  TORCH_CHECK(padding.size() == 1 || padding.size() == 2,
    "max_pool2d: padding must either be a single int, or a tuple of two ints");
  const int padH = safe_downcast<int, int64_t>(padding[0]);
  const int padW = padding.size() == 1 ? padH : safe_downcast<int, int64_t>(padding[1]);

  TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2,
    "max_pool2d: dilation must be either a single int, or a tuple of two ints");
  const int dilationH = safe_downcast<int, int64_t>(dilation[0]);
  const int dilationW = dilation.size() == 1 ? dilationH : safe_downcast<int, int64_t>(dilation[1]);

  TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4),
    "non-empty 3D or 4D (batch mode) tensor expected for input");

  TORCH_CHECK(input.dtype() == gradOutput.dtype(),
    "expected dtype ", input.dtype(), " for `gradOutput` but got dtype ", gradOutput.dtype());
  TORCH_CHECK(input.dtype() == gradInput.dtype(),
    "expected dtype ", input.dtype(), " for `gradInput` but got dtype ", gradInput.dtype());

  /* sizes */
  const int64_t nbatch = input.ndimension() == 4 ? input.size(-4) : 1;
  const int64_t nInputPlane = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  /* XXX preserve the existing shape check behavior */
  const int64_t outputHeight_for_shape_check = pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, dilationH, ceil_mode);
  const int64_t outputWidth_for_shape_check = pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, dilationW, ceil_mode);

  max_pool2d_backward_shape_check(
    input,
    gradOutput,
    indices,
    nbatch,
    kH, kW, dH, dW, padH, padW, dilationH, dilationW,
    nInputPlane,
    inputHeight, inputWidth,
    outputHeight_for_shape_check, outputWidth_for_shape_check,
    input.suggest_memory_format());

  /* resize -- only after every check has passed, so that an invalid call
     does not clobber the caller-provided gradInput */
  gradInput.resize_(input.sizes(), input.suggest_memory_format());
  gradInput.zero_();

  max_pool2d_backward_kernel(kCPU, gradInput, gradOutput, indices);

  return gradInput;
}
} // namespace
// Out-variant CPU entry point for max_pool2d_with_indices. Fills the
// caller-provided `output` and `indices` through the shared forward template
// and hands back references to both tensors.
std::tuple<Tensor&, Tensor&> max_pool2d_with_indices_out_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& output,
    Tensor& indices)
{
  max_pool2d_with_indices_out_cpu_template(
      output, indices, input,
      kernel_size, stride, padding, dilation,
      ceil_mode);

  return std::tuple<Tensor&, Tensor&>{output, indices};
}
std::tuple<Tensor, Tensor> max_pool2d_with_indices_cpu(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode)
{
NoNamesGuard guard;
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
max_pool2d_with_indices_out_cpu_template(
output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode);
guard.reset();
namedinference::propagate_names(output, input);
namedinference::propagate_names(indices, input);
return std::tuple<Tensor, Tensor>(output, indices);
}
// Out-variant CPU entry point for the max_pool2d_with_indices backward pass.
// Reorders the arguments for the shared backward template, which writes into
// `gradInput` and returns a reference to it.
Tensor& max_pool2d_with_indices_backward_out_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices,
    Tensor& gradInput)
{
  return max_pool2d_with_indices_backward_out_cpu_template(
      gradInput, gradOutput, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);
}
// Functional CPU entry point for the max_pool2d_with_indices backward pass.
// Allocates an empty gradient tensor with input's options and lets the shared
// backward template resize and fill it.
Tensor max_pool2d_with_indices_backward_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    const Tensor& indices)
{
  Tensor grad_input = at::empty({0}, input.options());

  max_pool2d_with_indices_backward_out_cpu_template(
      grad_input, gradOutput, input, indices,
      kernel_size, stride, padding, dilation,
      ceil_mode);

  return grad_input;
}
// Dispatch-stub definitions for the forward and backward pooling kernels that
// the templates above invoke with kCPU as the first argument.
// NOTE(review): the kernel implementations are not visible in this file;
// presumably they are registered against these stubs elsewhere -- confirm.
DEFINE_DISPATCH(max_pool2d_kernel);
DEFINE_DISPATCH(max_pool2d_backward_kernel);
} // at::native
} // at