|  | #ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_ | 
|  | #define CAFFE2_OPERATORS_FLATTEN_OP_H_ | 
|  |  | 
|  | #include "caffe2/core/operator.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <class Context> | 
|  | class FlattenOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | template <class... Args> | 
|  | explicit FlattenOp(Args&&... args) | 
|  | : Operator<Context>(std::forward<Args>(args)...), | 
|  | axis_(this->template GetSingleArgument<int>("axis", 1)) {} | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | auto& input = Input(0); | 
|  | auto* output = Output(0); | 
|  | CAFFE_ENFORCE_GE( | 
|  | input.dim(), axis_, "The rank of the tensor must be >= axis."); | 
|  | output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_)); | 
|  | context_.CopyItemsSameDevice( | 
|  | input.dtype(), | 
|  | input.numel(), | 
|  | input.raw_data(), | 
|  | output->raw_mutable_data(input.dtype())); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | int axis_; | 
|  | }; | 
|  |  | 
|  | inline std::vector<TensorShape> TensorInferenceForFlatten( | 
|  | const OperatorDef& def, | 
|  | const std::vector<TensorShape>& in) { | 
|  | ArgumentHelper helper(def); | 
|  | const int axis = helper.GetSingleArgument<int>("axis", 1); | 
|  | std::vector<TensorShape> out(1); | 
|  | int64_t outer = 1; | 
|  | int64_t inner = 1; | 
|  | std::size_t index = 0; | 
|  | for (auto d : in[0].dims()) { | 
|  | if (index < axis) { | 
|  | outer *= d; | 
|  | } else { | 
|  | inner *= d; | 
|  | } | 
|  | ++index; | 
|  | } | 
|  | out[0].set_data_type(in[0].data_type()); | 
|  | out[0].add_dims(outer); | 
|  | out[0].add_dims(inner); | 
|  | return out; | 
|  | } | 
|  |  | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_ |