blob: 0526f5b2b8e43e0e3808f6958b690ce582f51bd7 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <tuple>
namespace at { namespace native {
// Validates that a 1d pooling argument (kernel_size, stride, padding, ...)
// holds exactly one integer; raises a c10::Error otherwise.
static void check1d(
    const char* function_name,
    const char* argument_name,
    IntArrayRef x) {
  const auto num_elements = x.size();
  TORCH_CHECK(
      num_elements == 1,
      function_name, "() argument '", argument_name,
      "' should contain one int (got ", num_elements, ")");
}
// 1d adaptive average pooling, implemented by running the 2d kernel over a
// dummy height dimension of size 1 and squeezing it back out.
Tensor adaptive_avg_pool1d(const Tensor & self, IntArrayRef output_size) {
  checkDimRange("adaptive_avg_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
  check1d("adaptive_avg_pool1d", "output_size", output_size);
  return at::adaptive_avg_pool2d(
             self.unsqueeze(-2),
             {1, output_size[0]})
      .squeeze(-2);
}
// 1d adaptive max pooling via the 2d kernel on a dummy height dimension.
// Returns (values, indices), each with the dummy dimension squeezed away.
std::tuple<Tensor,Tensor> adaptive_max_pool1d(const Tensor & self, IntArrayRef output_size) {
  checkDimRange("adaptive_max_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
  check1d("adaptive_max_pool1d", "output_size", output_size);
  auto result = at::adaptive_max_pool2d(
      self.unsqueeze(-2),
      {1, output_size[0]});
  return std::make_tuple(
      std::get<0>(result).squeeze(-2),
      std::get<1>(result).squeeze(-2));
}
std::tuple<Tensor, Tensor> max_pool1d_with_indices(
const Tensor& self,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation,
bool ceil_mode) {
if (stride.empty()) {
stride = kernel_size;
}
checkDimRange("max_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
check1d("max_pool1d", "kernel_size", kernel_size);
check1d("max_pool1d", "stride", stride);
check1d("max_pool1d", "padding", padding);
check1d("max_pool1d", "dilation", dilation);
NoNamesGuard guard;
Tensor output, indices;
std::tie(output, indices) = at::max_pool2d_with_indices(
self.unsqueeze(-2),
{1, kernel_size[0]},
{1, stride[0]},
{0, padding[0]},
{1, dilation[0]},
ceil_mode);
output = output.squeeze(-2);
indices = indices.squeeze(-2);
guard.reset();
namedinference::propagate_names(output, self);
namedinference::propagate_names(indices, self);
return std::make_tuple(output, indices);
}
// 1d average pooling, implemented via the 2d kernel over a dummy height
// dimension of size 1.
Tensor avg_pool1d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad) {
  // An omitted stride defaults to the kernel size.
  if (stride.empty()) {
    stride = kernel_size;
  }
  checkDimRange("avg_pool1d", TensorArg(self, "self", 1), 2, 4 /* exclusive */);
  check1d("avg_pool1d", "kernel_size", kernel_size);
  check1d("avg_pool1d", "stride", stride);
  check1d("avg_pool1d", "padding", padding);
  return at::avg_pool2d(
             self.unsqueeze(-2),
             {1, kernel_size[0]},
             {1, stride[0]},
             {0, padding[0]},
             ceil_mode,
             count_include_pad)
      .squeeze(-2);
}
// 2d max pooling (values only). Routes to a backend-specific implementation
// when available, otherwise falls back to the indices-producing kernel and
// discards the indices.
Tensor max_pool2d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  if (self.is_quantized()) {
    return at::quantized_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
  if (self.is_mkldnn()) {
    return at::mkldnn_max_pool2d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
#if defined(C10_MOBILE)
  // NB: the xnnpack entry points take padding *before* stride.
  if (xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
      dilation, ceil_mode)) {
    return xnnpack::max_pool2d(
        self, kernel_size, padding, stride, dilation, ceil_mode);
  }
#endif
  return std::get<0>(at::max_pool2d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}
// 3d max pooling (values only). Uses the MKLDNN implementation when the
// input is an MKLDNN tensor; otherwise computes values+indices and keeps
// only the values.
Tensor max_pool3d(
    const Tensor& self,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode) {
  if (self.is_mkldnn()) {
    return at::mkldnn_max_pool3d(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
  return std::get<0>(at::max_pool3d_with_indices(
      self, kernel_size, stride, padding, dilation, ceil_mode));
}
} // namespace native
} // namespace at