| #ifndef CAFFE2_OPERATORS_POOL_OP_H_ |
| #define CAFFE2_OPERATORS_POOL_OP_H_ |
| |
| #include <vector> |
| |
| #include "caffe2/core/common_omp.h" |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/operators/conv_pool_op_base.h" |
| |
| namespace caffe2 { |
| |
| template <typename T, class Context, class Functor> |
| class PoolOp final : public ConvPoolOpBase<Context> { |
| public: |
| USE_CONV_POOL_BASE_FUNCTIONS(Context); |
| |
| template <class... Args> |
| explicit PoolOp(Args&&... args) |
| : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) { |
| const int kernel_size = kernel_.size(); |
| for (int i = 0; i < kernel_size; ++i) { |
| CAFFE_ENFORCE_EQ( |
| dilation_[i], 1, "Pooling op does not support dilation right now."); |
| } |
| if (!global_pooling_) { |
| for (int i = 0; i < kernel_size; ++i) { |
| CAFFE_ENFORCE( |
| pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i], |
| "Pad should be smaller than kernel."); |
| } |
| } |
| } |
| |
| ~PoolOp() = default; |
| |
| bool RunOnDeviceWithOrderNCHW() override { |
| const auto& X = Input(0); |
| auto* Y = Output(0); |
| const int N = X.dim32(0); |
| const int C = X.dim32(1); |
| ConvPoolOpBase<Context>::SetOutputSize(X, Y, C); |
| const T* X_data = X.template data<T>(); |
| T* Y_data = Y->template mutable_data<T>(); |
| if (N == 0) { |
| return true; |
| } |
| if (global_pooling_) { |
| const int HxW = X.numel() / (N * C); |
| return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>( |
| N, C, HxW, X_data, Y_data, &context_); |
| } |
| const std::vector<int> X_HW_dims = GetDims(X); |
| const std::vector<int> Y_HW_dims = GetDims(*Y); |
| return functor_.template Forward<T, StorageOrder::NCHW>( |
| N, |
| C, |
| X_HW_dims, |
| Y_HW_dims, |
| kernel_, |
| dilation_, |
| stride_, |
| pads_, |
| X.template data<T>(), |
| Y->template mutable_data<T>(), |
| &context_); |
| } |
| |
| bool RunOnDeviceWithOrderNHWC() override { |
| const auto& X = Input(0); |
| auto* Y = Output(0); |
| const int ndim = X.dim(); |
| const int N = X.dim32(0); |
| const int C = X.dim32(ndim - 1); |
| ConvPoolOpBase<Context>::SetOutputSize(X, Y, C); |
| const T* X_data = X.template data<T>(); |
| T* Y_data = Y->template mutable_data<T>(); |
| if (N == 0) { |
| return true; |
| } |
| if (global_pooling_) { |
| const int HxW = X.numel() / (N * C); |
| return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>( |
| N, C, HxW, X_data, Y_data, &context_); |
| } |
| const std::vector<int> X_HW_dims = GetDims(X); |
| const std::vector<int> Y_HW_dims = GetDims(*Y); |
| return functor_.template Forward<T, StorageOrder::NHWC>( |
| N, |
| C, |
| X_HW_dims, |
| Y_HW_dims, |
| kernel_, |
| dilation_, |
| stride_, |
| pads_, |
| X.template data<T>(), |
| Y->template mutable_data<T>(), |
| &context_); |
| } |
| |
| private: |
| const Functor functor_; |
| }; |
| |
| template <typename T, class Context, class Functor> |
| class PoolGradientOp final : public ConvPoolOpBase<Context> { |
| public: |
| USE_CONV_POOL_BASE_FUNCTIONS(Context); |
| template <class... Args> |
| explicit PoolGradientOp(Args&&... args) |
| : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {} |
| |
| ~PoolGradientOp() = default; |
| |
| bool RunOnDeviceWithOrderNCHW() override { |
| const auto& X = Input(0); |
| const auto& Y = Input(1); |
| const auto& dY = Input(2); |
| auto* dX = Output(0, X.sizes(), at::dtype<T>()); |
| const int N = X.dim32(0); |
| const int C = X.dim32(1); |
| const std::vector<int> X_HW_dims = GetDims(X); |
| const std::vector<int> Y_HW_dims = GetDims(Y); |
| ConvPoolOpBase<Context>::ComputePads(X_HW_dims); |
| const T* dY_data = dY.template data<T>(); |
| const T* X_data = X.template data<T>(); |
| const T* Y_data = Y.template data<T>(); |
| T* dX_data = dX->template mutable_data<T>(); |
| if (N == 0) { |
| return true; |
| } |
| if (global_pooling_) { |
| const int HxW = X.numel() / (N * C); |
| return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>( |
| N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_); |
| } |
| return functor_.template Backward<T, StorageOrder::NCHW>( |
| N, |
| C, |
| X_HW_dims, |
| Y_HW_dims, |
| kernel_, |
| dilation_, |
| stride_, |
| pads_, |
| dY_data, |
| X_data, |
| Y_data, |
| dX_data, |
| &context_); |
| } |
| |
| bool RunOnDeviceWithOrderNHWC() override { |
| const auto& X = Input(0); |
| const auto& Y = Input(1); |
| const auto& dY = Input(2); |
| auto* dX = Output(0, X.sizes(), at::dtype<T>()); |
| const int ndim = X.dim(); |
| const int N = X.dim32(0); |
| const int C = X.dim32(ndim - 1); |
| const std::vector<int> X_HW_dims = GetDims(X); |
| const std::vector<int> Y_HW_dims = GetDims(Y); |
| ConvPoolOpBase<Context>::ComputePads(X_HW_dims); |
| const T* dY_data = dY.template data<T>(); |
| const T* X_data = X.template data<T>(); |
| const T* Y_data = Y.template data<T>(); |
| T* dX_data = dX->template mutable_data<T>(); |
| if (N == 0) { |
| return true; |
| } |
| if (global_pooling_) { |
| const int HxW = X.numel() / (N * C); |
| return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>( |
| N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_); |
| } |
| return functor_.template Backward<T, StorageOrder::NHWC>( |
| N, |
| C, |
| X_HW_dims, |
| Y_HW_dims, |
| kernel_, |
| dilation_, |
| stride_, |
| pads_, |
| dY_data, |
| X_data, |
| Y_data, |
| dX_data, |
| &context_); |
| } |
| |
| private: |
| const Functor functor_; |
| }; |
| |
| template <class Context> |
| struct AveragePoolFunctor { |
| explicit AveragePoolFunctor(const OperatorBase& op) |
| : count_include_pad( |
| op.template GetSingleArgument<bool>("count_include_pad", false)) {} |
| |
| template <typename T, StorageOrder kOrder> |
| bool GlobalPoolingForward( |
| int N, |
| int C, |
| int HxW, |
| const T* X, |
| T* Y, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool Forward( |
| int N, |
| int C, |
| const std::vector<int>& X_dims, |
| const std::vector<int>& Y_dims, |
| const std::vector<int>& kernel, |
| const std::vector<int>& dilation, |
| const std::vector<int>& stride, |
| const std::vector<int>& pads, |
| const T* X, |
| T* Y, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool GlobalPoolingBackward( |
| int N, |
| int C, |
| int HxW, |
| const T* dY, |
| const T* X, |
| const T* Y, |
| T* dX, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool Backward( |
| int N, |
| int C, |
| const std::vector<int>& X_dims, |
| const std::vector<int>& Y_dims, |
| const std::vector<int>& kernel, |
| const std::vector<int>& dilation, |
| const std::vector<int>& stride, |
| const std::vector<int>& pads, |
| const T* dY, |
| const T* X, |
| const T* Y, |
| T* dX, |
| Context* context) const; |
| |
| const bool count_include_pad; |
| Tensor ones{Context::GetDeviceType()}; |
| }; |
| |
| template <class Context> |
| struct MaxPoolFunctor { |
| explicit MaxPoolFunctor(const OperatorBase& /* op */) {} |
| |
| template <typename T, StorageOrder kOrder> |
| bool GlobalPoolingForward( |
| int N, |
| int C, |
| int HxW, |
| const T* X, |
| T* Y, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool Forward( |
| int N, |
| int C, |
| const std::vector<int>& X_dims, |
| const std::vector<int>& Y_dims, |
| const std::vector<int>& kernel, |
| const std::vector<int>& dilation, |
| const std::vector<int>& stride, |
| const std::vector<int>& pads, |
| const T* X, |
| T* Y, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool GlobalPoolingBackward( |
| int N, |
| int C, |
| int HxW, |
| const T* dY, |
| const T* X, |
| const T* Y, |
| T* dX, |
| Context* context) const; |
| |
| template <typename T, StorageOrder kOrder> |
| bool Backward( |
| int N, |
| int C, |
| const std::vector<int>& X_dims, |
| const std::vector<int>& Y_dims, |
| const std::vector<int>& kernel, |
| const std::vector<int>& dilation, |
| const std::vector<int>& stride, |
| const std::vector<int>& pads, |
| const T* dY, |
| const T* X, |
| const T* Y, |
| T* dX, |
| Context* context) const; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_POOL_OP_H_ |