blob: 873b03855f14de78fe110aac515d990cfb36d4b3 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/cast_op.h"
namespace caffe2 {
template <>
template <typename DstType, typename SrcType>
bool CastOp<CPUContext>::DoRunWithType() {
auto& input = Input(0);
auto* output = Output(0);
output->ResizeLike(input);
const auto* data = input.template data<SrcType>();
auto* out = output->template mutable_data<DstType>();
auto N = input.size();
for (TIndex i = 0; i < N; ++i) {
out[i] = static_cast<DstType>(data[i]);
}
return true;
}
template <>
void CastOp<CPUContext>::SetBody(TensorProto_DataType to) {
switch (to) {
case TensorProto_DataType_FLOAT:
// body_ = &CastOp::DoRunIncFp16WithDstType<float>;
body_ = &CastOp<CPUContext>::DoRunWithDstType<float>;
break;
case TensorProto_DataType_INT32:
body_ = &CastOp<CPUContext>::DoRunWithDstType<int>;
break;
case TensorProto_DataType_BYTE:
LOG(FATAL) << "BYTE is deprecated";
break;
case TensorProto_DataType_STRING:
CAFFE_THROW("Casting to and from strings is not supported yet");
// break;
case TensorProto_DataType_BOOL:
body_ = &CastOp<CPUContext>::DoRunWithDstType<bool>;
break;
case TensorProto_DataType_UINT8:
body_ = &CastOp<CPUContext>::DoRunWithDstType<uint8_t>;
break;
case TensorProto_DataType_INT8:
body_ = &CastOp<CPUContext>::DoRunWithDstType<int8_t>;
break;
case TensorProto_DataType_UINT16:
body_ = &CastOp<CPUContext>::DoRunWithDstType<uint16_t>;
break;
case TensorProto_DataType_INT16:
body_ = &CastOp<CPUContext>::DoRunWithDstType<int16_t>;
break;
case TensorProto_DataType_INT64:
body_ = &CastOp<CPUContext>::DoRunWithDstType<int64_t>;
break;
case TensorProto_DataType_FLOAT16:
CAFFE_THROW("Casting to and from float16 on CPU is not supported yet");
// break;
case TensorProto_DataType_DOUBLE:
//body_ = &CastOp::DoRunIncFp16WithDstType<double>;
body_ = &CastOp<CPUContext>::DoRunWithDstType<double>;
break;
case TensorProto_DataType_UNDEFINED:
CAFFE_THROW("Cast op must have 'to' argument of type DataType");
// break;
default:
CAFFE_THROW("Unexpected 'to' argument value: ", to);
}
}
template <>
template <typename DstType>
bool CastOp<CPUContext>::DoRunWithDstType() {
return DispatchHelper<
TensorTypes<
float,
int32_t,
bool,
uint8_t,
int8_t,
uint16_t,
int16_t,
int64_t,
double>,
DstType>::call(this, Input(0));
}
REGISTER_CPU_OPERATOR(Cast, CastOp<CPUContext>);
OPERATOR_SCHEMA(Cast)
.NumInputs(1)
.NumOutputs(1)
.TensorInferenceFunction(
[](const OperatorDef& def, const vector<TensorShape>& in) {
ArgumentHelper helper(def);
vector<TensorShape> out;
out.push_back(in[0]);
out[0].set_data_type(cast::GetCastDataType(helper, "to"));
return out;
})
.SetDoc(R"DOC(
The operator casts the elements of a given input tensor to a data type
specified by the 'to' argument and returns an output tensor of the same size in
the converted type. The 'to' argument must be one of the data types specified
in the 'DataType' enum field in the TensorProto message. If the 'to' argument
is not provided or is not one of the enumerated types in DataType, Caffe2
throws an Enforce error.
NOTE: Casting to and from strings is not supported yet.
)DOC")
.Arg(
"to",
"The data type to which the elements of the input tensor are cast."
"Strictly must be one of the types from DataType enum in TensorProto")
.Input(0, "input", "Input tensor to be cast.")
.Output(
0,
"output",
"Output tensor with the same shape as input with type "
"specified by the 'to' argument");
// Some Casts are compatible with gradients, but for now we don't support it
// GRADIENT_NOT_IMPLEMENTED_YET(Cast);
class GetCastGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
vector<OperatorDef> defs = SingleGradientDef("Cast", "", vector<string>{GO(0)}, vector<string>{GI(0)});
// now modify the arguments in defs[0]
ArgumentHelper argsHelper(def_);
auto to_name = cast::GetCastDataType(argsHelper, "to");
CAFFE_ENFORCE(
argsHelper.HasSingleArgumentOfType<string>("from_type") ||
argsHelper.HasSingleArgumentOfType<int>("from_type"),
"Argument 'from_type' of type int or string"
" is required to get the gradient of CastOp");
auto from_name = cast::GetCastDataType(argsHelper, "from_type");
Argument *to = defs[0].add_arg();
to->set_name("to");
to->set_i(from_name);
Argument *from = defs[0].add_arg();
from->set_name("from_type");
from->set_i(to_name);
return defs;
}
bool CopyArguments() const override {
return false;
}
};
REGISTER_GRADIENT(Cast, GetCastGradient);
} // namespace caffe2