| #ifndef CAFFE2_OPERATORS_POW_OP_H_ |
| #define CAFFE2_OPERATORS_POW_OP_H_ |
| |
| #include "caffe2/core/common_omp.h" |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/operators/elementwise_ops.h" |
| #include "caffe2/operators/elementwise_ops_utils.h" |
| #include "caffe2/utils/math.h" |
| |
| namespace caffe2 { |
| |
| template < |
| typename InputTypes, |
| class Context, |
| class Functor, |
| class TypeMap = SameTypeAsInput> |
| class PowOp : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| template <class... Args> |
| explicit PowOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0), |
| OP_SINGLE_ARG(int, "axis", axis_, -1), |
| OP_SINGLE_ARG(string, "axis_str", axis_str_, ""), |
| OP_SINGLE_ARG(string, "order", order_, "NCHW"), |
| functor_() { |
| if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp |
| exponent_ = this->template GetSingleArgument<float>( |
| "exponent", 0); // based on pow_ops.h |
| } else if (InputSize() == 2) { // BinaryElementwiseOp |
| // Figure out the correct axis to use. |
| if (enable_broadcast_) { |
| if (axis_ != -1) { |
| // Get axis from an explicit axis argument. |
| CAFFE_ENFORCE_EQ( |
| axis_str_.size(), |
| 0U, |
| "Args axis and axis_str cannot be used simultaneously."); |
| } else if (axis_str_.size()) { |
| // Get the axis index semantically. |
| CAFFE_ENFORCE_EQ( |
| axis_str_.size(), 1U, "Unsupported axis string", axis_str_); |
| size_t semantic_axis_ = order_.find(axis_str_); |
| CAFFE_ENFORCE_NE( |
| semantic_axis_, |
| string::npos, |
| "Unrecognizable axis string ", |
| axis_str_, |
| " from order string ", |
| order_); |
| axis_ = semantic_axis_; |
| } |
| } else { |
| CAFFE_ENFORCE( |
| axis_ == -1 && axis_str_.empty(), |
| "Do not specify axis or axis_str if broadcast is not enabled."); |
| } |
| } else { |
| CAFFE_THROW( |
| "Only a tensor with an argument or two input tensors are supported as input to pow operator."); |
| } |
| } |
| |
| bool RunOnDevice() override { |
| return DispatchHelper<InputTypes>::call(this, Input(0)); |
| } |
| |
| template <typename T> |
| bool DoRunWithType() { |
| if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp |
| const auto& A = Input(0); |
| |
| auto* C = |
| Output(0, A.sizes(), at::dtype<typename TypeMap::template type<T>>()); |
| const T* Adata = A.template data<T>(); |
| auto* Cdata = |
| C->template mutable_data<typename TypeMap::template type<T>>(); |
| functor_.template Run<true, T, float, T>( |
| A.numel(), Adata, NULL, exponent_, Cdata, &context_); |
| } else if (InputSize() == 2) { // BinaryElementwiseOp |
| const auto& A = Input(0); |
| const auto& B = Input(1); |
| CAFFE_ENFORCE( |
| !IsInputOutputAlias(1, 0) || !enable_broadcast_, |
| "In-place is allowed only with the first tensor when broadcasting"); |
| auto* C = |
| Output(0, A.sizes(), at::dtype<typename TypeMap::template type<T>>()); |
| const T* Adata = A.template data<T>(); |
| const T* Bdata = B.template data<T>(); |
| auto* Cdata = |
| C->template mutable_data<typename TypeMap::template type<T>>(); |
| if (!enable_broadcast_) { |
| CAFFE_ENFORCE_EQ( |
| A.sizes(), |
| B.sizes(), |
| "Dimension mismatch - did you forget to set broadcast=1?"); |
| functor_.template Run<false, T, T, T>( |
| A.numel(), Adata, Bdata, 0, Cdata, &context_); |
| } else if (B.numel() == 1) { |
| functor_.template Run<true, T, T, T>( |
| A.numel(), Adata, Bdata, 0, Cdata, &context_); |
| } else { |
| size_t pre, n, post; |
| std::tie(pre, n, post) = |
| elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_); |
| if (post == 1) { |
| functor_.template RunWithBroadcast<T, T, T>( |
| Adata, Bdata, Cdata, pre, n, &context_); |
| } else { |
| functor_.template RunWithBroadcast2<T, T, T>( |
| Adata, Bdata, Cdata, pre, n, post, &context_); |
| } |
| } |
| } else { |
| CAFFE_THROW( |
| "Only a tensor with an argument or two input tensors are supported as input to pow operator."); |
| } |
| return true; |
| } |
| |
| private: |
| bool enable_broadcast_; |
| int axis_; |
| string axis_str_; |
| string order_; |
| float exponent_; |
| Functor functor_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_POW_OP_H_ |