| #pragma once |
| |
| #include <ATen/ATen.h> |
| #include <ATen/Parallel.h> |
| #include <ATen/native/DispatchStub.h> |
| |
| namespace at { |
| namespace native { |
| |
| // TODO(Heitor) Template by dimension |
| struct PoolingParams1D { |
| int64_t NB; // Number of batches |
| int64_t NC; // Number of channels |
| int64_t IW; // Input width |
| int64_t OW; // Output width |
| int64_t KW; // Kernel width |
| int64_t SJ; // Column stride |
| int64_t PJ; // Column padding |
| int64_t DJ; // Column dilation |
| |
| // Return index of input element for the given kernel and output index |
| inline int64_t index(int64_t kj, int64_t oj) const { |
| return oj * SJ + kj * DJ - PJ; |
| } |
| |
| // Return index of first output within bounds for this kernel index |
| inline int64_t valid_output_start(int64_t kj) const { |
| int64_t ij = index(kj, 0);; |
| return ij < 0 ? at::divup(-ij, SJ) : 0; |
| } |
| |
| // Return index one past last output within bounds for this kernel index |
| inline int64_t valid_output_end(int64_t kj) const { |
| int64_t ij = index(kj, OW - 1); |
| return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW; |
| } |
| }; |
| |
| using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&); |
| |
| DECLARE_DISPATCH(pooling_fn, max_pool1d_stub); |
| |
| } // namespace native |
| } // namespace at |