| #include <ATen/native/mkldnn/Utils.h> |
| #include <ATen/native/Pool.h> |
| |
| namespace at { namespace native { |
| |
| std::vector<int64_t> pool_output_sizes( |
| IntArrayRef input_size, |
| IntArrayRef kernel_size, |
| IntArrayRef stride, |
| IntArrayRef padding_l, |
| IntArrayRef padding_r, |
| IntArrayRef dilation, |
| bool ceil_mode) { |
| std::vector<int64_t> output_size(input_size.size()); |
| // copy N and C |
| output_size[0] = input_size[0]; |
| output_size[1] = input_size[1]; |
| |
| for (size_t i = 2; i < input_size.size(); ++i) { |
| output_size[i] = pooling_output_shape_pad_lr<int64_t>( |
| input_size[i], |
| kernel_size[i - 2], |
| padding_l[i - 2], |
| padding_r[i - 2], |
| stride[i - 2], |
| dilation[i - 2], |
| ceil_mode |
| ); |
| } |
| |
| return output_size; |
| } |
| |
| }} |