|  | #pragma once | 
|  |  | 
|  | #include <c10/util/irange.h> | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/logging.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/core/types.h" | 
|  | #include "caffe2/utils/cast.h" | 
|  | #include "caffe2/utils/conversions.h" | 
|  | #include "caffe2/utils/math.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <class Context> | 
|  | class CastOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | explicit CastOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws) { | 
|  | const ArgumentHelper helper(operator_def); | 
|  | TensorProto_DataType to = cast::GetCastDataType(helper, "to"); | 
|  |  | 
|  | SetBody(to); | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | return (this->*body_)(); | 
|  | } | 
|  |  | 
|  | // Allow for Context-specific implementations | 
|  | void SetBody(TensorProto_DataType to); | 
|  |  | 
|  | template <typename DstType> | 
|  | bool DoRunWithDstType(); | 
|  |  | 
|  | template <typename DstType, typename SrcType> | 
|  | bool DoRunWithType() { | 
|  | auto& input = Input(0); | 
|  | auto* output = Output(0); | 
|  | output->ResizeLike(input); | 
|  | const auto* data = input.template data<SrcType>(); | 
|  | auto* out = output->template mutable_data<DstType>(); | 
|  | auto N = input.size(); | 
|  | for (const auto i : c10::irange(N)) { | 
|  | out[i] = static_cast<DstType>(data[i]); | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | bool (CastOp::*body_)(); | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 |