// blob: 2659e1ec7d170103e1c3caff7fa5dd10d28c5feb [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <tuple>
namespace at { namespace native {
// Validates that an integer-list argument of a 1-d pooling function holds
// exactly one element.
//
// function_name / argument_name are only used to build the error message.
// When x does not contain exactly one value, TORCH_CHECK fails with a message
// that names the function and argument and reports the actual element count.
static void check1d(
    const char* function_name,
    const char* argument_name,
    IntArrayRef x) {
  const auto element_count = x.size();
  TORCH_CHECK(
      element_count == 1,
      function_name, "() argument '", argument_name,
      "' should contain one int (got ", element_count, ")");
}
// 1-d adaptive average pooling, implemented on top of the 2-d kernel.
//
// self is passed to checkDim with an expected dimension count of 3, and
// output_size must contain exactly one int (enforced by check1d).
// A dimension is inserted at index 2, at::adaptive_avg_pool2d is run with an
// output size of {1, output_size[0]}, and the inserted dimension is squeezed
// back out of the result.
Tensor adaptive_avg_pool1d(const Tensor & self, IntArrayRef output_size) {
  checkDim("adaptive_avg_pool1d", TensorArg(self, "self", 1), 3);
  check1d("adaptive_avg_pool1d", "output_size", output_size);

  return at::adaptive_avg_pool2d(self.unsqueeze(2), {1, output_size[0]})
      .squeeze(2);
}
// 1-d adaptive max pooling, implemented on top of the 2-d kernel.
//
// self is passed to checkDim with an expected dimension count of 3, and
// output_size must contain exactly one int (enforced by check1d).
// Returns the two tensors produced by at::adaptive_max_pool2d (output first,
// indices second), each with the temporarily inserted dimension 2 squeezed
// away again.
std::tuple<Tensor,Tensor> adaptive_max_pool1d(const Tensor & self, IntArrayRef output_size) {
  checkDim("adaptive_max_pool1d", TensorArg(self, "self", 1), 3);
  check1d("adaptive_max_pool1d", "output_size", output_size);

  // Run the 2-d kernel on the input with an extra dimension at index 2.
  auto pooled_2d = at::adaptive_max_pool2d(
      self.unsqueeze(2),
      {1, output_size[0]});

  return std::make_tuple(
      std::get<0>(pooled_2d).squeeze(2),
      std::get<1>(pooled_2d).squeeze(2));
}
std::tuple<Tensor, Tensor> max_pool1d_with_indices(
const Tensor& self,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
if (stride.empty()) {
stride = kernel_size;
}
checkDim("max_pool1d", TensorArg(self, "self", 1), 3);
check1d("max_pool1d", "kernel_size", kernel_size);
check1d("max_pool1d", "stride", stride);
check1d("max_pool1d", "padding", padding);
check1d("max_pool1d", "dilation", dilation);
Tensor output, indices;
std::tie(output, indices) = at::max_pool2d_with_indices(
self.unsqueeze(2),
{1, kernel_size[0]},
{1, stride[0]},
{0, padding[0]},
{1, dilation[0]},
ceil_mode);
return std::make_tuple(output.squeeze(2), indices.squeeze(2));
}
// 1-d average pooling, implemented on top of at::avg_pool2d.
//
// An empty stride falls back to kernel_size. After that fallback,
// kernel_size, stride and padding must each contain exactly one int
// (enforced by check1d); self is passed to checkDim with an expected
// dimension count of 3.
//
// The 2-d call uses kernel {1, k}, stride {1, s} and padding {0, p} on the
// input with a dimension inserted at index 2; ceil_mode and
// count_include_pad are forwarded unchanged. The inserted dimension is
// squeezed out of the result.
Tensor avg_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad) {
  // Omitted stride means "same as the kernel size".
  const IntArrayRef effective_stride = stride.empty() ? kernel_size : stride;

  checkDim("avg_pool1d", TensorArg(self, "self", 1), 3);
  check1d("avg_pool1d", "kernel_size", kernel_size);
  check1d("avg_pool1d", "stride", effective_stride);
  check1d("avg_pool1d", "padding", padding);

  return at::avg_pool2d(
             self.unsqueeze(2),
             {1, kernel_size[0]},
             {1, effective_stride[0]},
             {0, padding[0]},
             ceil_mode,
             count_include_pad)
      .squeeze(2);
}
// 1-d max pooling without indices.
//
// Forwards every argument to at::max_pool1d_with_indices and returns only
// the first element of the resulting tuple; the second element is dropped.
Tensor max_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  return std::get<0>(at::max_pool1d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}
// 2-d max pooling without indices.
//
// Dispatch order:
//   1. self.is_quantized() -> at::quantized_max_pool2d
//   2. self.is_mkldnn()    -> at::mkldnn_max_pool2d
//   3. otherwise           -> first element of at::max_pool2d_with_indices
// All pooling arguments are forwarded unchanged in every case.
Tensor max_pool2d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  if (self.is_quantized()) {
    return at::quantized_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  } else if (self.is_mkldnn()) {
    return at::mkldnn_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }

  // Generic path: compute output and indices, keep only the output.
  return std::get<0>(at::max_pool2d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}
// 3-d max pooling without indices.
//
// Forwards every argument to at::max_pool3d_with_indices and returns only
// the first element of the resulting tuple; the second element is dropped.
Tensor max_pool3d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  return std::get<0>(at::max_pool3d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}
} // namespace native
} // namespace at