| #ifndef ROI_POOL_OP_H_ |
| #define ROI_POOL_OP_H_ |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/utils/math.h" |
| |
| namespace caffe2 { |
| |
| template <typename T, class Context> |
| class RoIPoolOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit RoIPoolOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| is_test_( |
| this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)), |
| order_(StringToStorageOrder( |
| this->template GetSingleArgument<string>("order", "NCHW"))), |
| pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)), |
| pooled_width_(this->template GetSingleArgument<int>("pooled_w", 1)), |
| spatial_scale_( |
| this->template GetSingleArgument<float>("spatial_scale", 1.)) { |
| CAFFE_ENFORCE( |
| (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 2), |
| "Output size mismatch."); |
| CAFFE_ENFORCE_GT(spatial_scale_, 0); |
| CAFFE_ENFORCE_GT(pooled_height_, 0); |
| CAFFE_ENFORCE_GT(pooled_width_, 0); |
| CAFFE_ENFORCE_EQ( |
| order_, StorageOrder::NCHW, "Only NCHW order is supported right now."); |
| } |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override; |
| |
| protected: |
| bool is_test_; |
| StorageOrder order_; |
| int pooled_height_; |
| int pooled_width_; |
| float spatial_scale_; |
| }; |
| |
| template <typename T, class Context> |
| class RoIPoolGradientOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit RoIPoolGradientOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| spatial_scale_( |
| this->template GetSingleArgument<float>("spatial_scale", 1.)), |
| pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)), |
| pooled_width_(this->template GetSingleArgument<int>("pooled_w", 1)), |
| order_(StringToStorageOrder( |
| this->template GetSingleArgument<string>("order", "NCHW"))) { |
| CAFFE_ENFORCE_GT(spatial_scale_, 0); |
| CAFFE_ENFORCE_GT(pooled_height_, 0); |
| CAFFE_ENFORCE_GT(pooled_width_, 0); |
| CAFFE_ENFORCE_EQ( |
| order_, StorageOrder::NCHW, "Only NCHW order is supported right now."); |
| } |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override { |
| CAFFE_NOT_IMPLEMENTED; |
| } |
| |
| protected: |
| float spatial_scale_; |
| int pooled_height_; |
| int pooled_width_; |
| StorageOrder order_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // ROI_POOL_OP_H_ |