|  | #ifndef CAFFE2_OPERATOR_GLU_OP_H_ | 
|  | #define CAFFE2_OPERATOR_GLU_OP_H_ | 
|  |  | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/operator.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  | template <typename T, class Context> | 
|  | class GluOp final : public Operator<Context> { | 
|  | public: | 
|  | template <class... Args> | 
|  | explicit GluOp(Args&&... args) | 
|  | : Operator<Context>(std::forward<Args>(args)...), | 
|  | dim_(this->template GetSingleArgument<int>("dim", -1)) {} | 
|  |  | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | bool RunOnDevice() { | 
|  | auto& X = Input(0); | 
|  |  | 
|  | vector<int64_t> Yshape; | 
|  | Yshape.insert(Yshape.end(), X.sizes().begin(), X.sizes().end()); | 
|  | const int split_index = dim_ == -1 ? Yshape.size() - 1 : dim_; | 
|  | CAFFE_ENFORCE( | 
|  | Yshape[split_index] % 2 == 0, | 
|  | "Split dimension ", | 
|  | Yshape[split_index], | 
|  | " should be divided by two"); | 
|  | const int split_dim_size = Yshape[split_index] / 2; | 
|  | const int M = X.size_to_dim(split_index); | 
|  | const int N = X.size_from_dim(split_index + 1); | 
|  | Yshape[split_index] = split_dim_size; | 
|  | auto* Y = Output(0, Yshape, at::dtype<T>()); | 
|  | ComputeGlu( | 
|  | M, | 
|  | split_dim_size, | 
|  | N, | 
|  | X.template data<T>(), | 
|  | Y->template mutable_data<T>()); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | void ComputeGlu( | 
|  | const int M, | 
|  | const int split_dim_size, | 
|  | const int N, | 
|  | const T* X, | 
|  | T* output); | 
|  |  | 
|  | private: | 
|  | const int dim_; | 
|  | }; | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATOR_GLU_OP_H_ |