blob: 64b1171856fc3b5dba7ec6a24c17a5761b4be8da [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_CAST_OP_H_
#define CAFFE2_OPERATORS_CAST_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace cast {

// Reads the "to" argument of a Cast operator and returns it as a
// TensorProto_DataType.
//
// "to" may be given in one of two forms:
//   * as a string naming the type (e.g. "float"). The string is upper-cased
//     before being parsed with TensorProto_DataType_Parse, so lower- and
//     mixed-case spellings are accepted.
//   * as the integer value of the enum. If "to" is absent, this path yields
//     TensorProto_DataType_UNDEFINED. Callers must handle that value.
//
// Throws:
//   * via CAFFE_ENFORCE, if a string value is not recognized by
//     TensorProto_DataType_Parse;
//   * via CAFFE_THROW, if a string value is supplied in a build with
//     CAFFE2_USE_LITE_PROTO defined.
inline TensorProto_DataType GetCastDataType(const ArgumentHelper& helper) {
  TensorProto_DataType to = TensorProto_DataType_UNDEFINED;
  if (helper.HasSingleArgumentOfType<string>("to")) {
#ifndef CAFFE2_USE_LITE_PROTO
    string s = helper.GetSingleArgument<string>("to", "");
    // Widen each char through unsigned char: calling toupper with a negative
    // value (other than EOF) is undefined behavior. Non-ASCII bytes in a
    // user-supplied string can produce such values on platforms where char
    // is signed.
    std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
      return static_cast<char>(::toupper(c));
    });
    CAFFE_ENFORCE(
        TensorProto_DataType_Parse(s, &to), "Unknown 'to' argument: ", s);
#else
    CAFFE_THROW("String cast op not supported");
#endif
  } else {
    to = static_cast<TensorProto_DataType>(
        helper.GetSingleArgument<int>("to", TensorProto_DataType_UNDEFINED));
  }
  return to;
}

} // namespace cast
// Cast operator. Converts every element of Input(0) to the element type
// selected by the "to" argument (see cast::GetCastDataType). The result is
// written to Output(0), which is resized to the input's shape.
//
// The destination type is resolved once, in the constructor, into a pointer to
// the matching DoRunWithDstType<DstType> instantiation. RunOnDevice() only
// invokes that pointer. The source type is dispatched per run from Input(0).
template <class Context>
class CastOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Resolves "to" and selects the destination-type kernel.
  // Throws for destination types that are not supported: BYTE, STRING,
  // FLOAT16, UNDEFINED (i.e. "to" missing), and any unrecognized value.
  // Every case either assigns body_ or throws, so a successfully constructed
  // op always has a valid body_.
  CastOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {
    const TensorProto_DataType to = cast::GetCastDataType(this->arg_helper());
    switch (to) {
      case TensorProto_DataType_FLOAT:
        body_ = &CastOp::DoRunWithDstType<float>;
        break;
      case TensorProto_DataType_INT32:
        body_ = &CastOp::DoRunWithDstType<int>;
        break;
      case TensorProto_DataType_BYTE:
        // Throw (rather than LOG(FATAL), which terminates the process and
        // would leave body_ unset if it ever returned). This is consistent
        // with the other unsupported types below.
        CAFFE_THROW("BYTE is deprecated");
      case TensorProto_DataType_STRING:
        CAFFE_THROW("Casting to and from strings is not supported yet");
      case TensorProto_DataType_BOOL:
        body_ = &CastOp::DoRunWithDstType<bool>;
        break;
      case TensorProto_DataType_UINT8:
        body_ = &CastOp::DoRunWithDstType<uint8_t>;
        break;
      case TensorProto_DataType_INT8:
        body_ = &CastOp::DoRunWithDstType<int8_t>;
        break;
      case TensorProto_DataType_UINT16:
        body_ = &CastOp::DoRunWithDstType<uint16_t>;
        break;
      case TensorProto_DataType_INT16:
        body_ = &CastOp::DoRunWithDstType<int16_t>;
        break;
      case TensorProto_DataType_INT64:
        body_ = &CastOp::DoRunWithDstType<int64_t>;
        break;
      case TensorProto_DataType_FLOAT16:
        CAFFE_THROW("Casting to and from float16 is not supported yet");
      case TensorProto_DataType_DOUBLE:
        body_ = &CastOp::DoRunWithDstType<double>;
        break;
      case TensorProto_DataType_UNDEFINED:
        CAFFE_THROW("Cast op must have 'to' argument of type DataType");
      default:
        CAFFE_THROW("Unexpected 'to' argument value: ", to);
    }
  }

  // Runs the destination-type kernel chosen in the constructor.
  bool RunOnDevice() override {
    return (this->*body_)();
  }

  // Dispatches on Input(0)'s element type over the source types listed below.
  // DispatchHelper is expected to call DoRunWithType<DstType, SrcType> for the
  // matching source type. What happens for a source type not in the list
  // (e.g. float16 or string) is decided by DispatchHelper, not shown here.
  template <typename DstType>
  bool DoRunWithDstType() {
    return DispatchHelper<
        TensorTypes<
            float,
            int32_t,
            bool,
            uint8_t,
            int8_t,
            uint16_t,
            int16_t,
            int64_t,
            double>,
        DstType>::call(this, Input(0));
  }

  // Element-wise conversion kernel: resizes Output(0) like Input(0) and stores
  // static_cast<DstType>(x) for every input element x.
  // NOTE(review): no range checking is done. Per the C++ standard, casting a
  // floating-point value outside the destination integer type's range is
  // undefined behavior.
  template <typename DstType, typename SrcType>
  bool DoRunWithType() {
    auto& input = Input(0);
    auto* output = Output(0);
    output->ResizeLike(input);
    const auto* data = input.template data<SrcType>();
    auto* out = output->template mutable_data<DstType>();
    const auto N = input.size();
    for (TIndex i = 0; i < N; ++i) {
      out[i] = static_cast<DstType>(data[i]);
    }
    return true;
  }

 private:
  // Pointer to the DoRunWithDstType<DstType> instantiation selected by the
  // constructor. Initialized so it is never indeterminate.
  bool (CastOp::*body_)() = nullptr;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_CAST_OP_H_