blob: 4601f258e5c6fe6559e79b8f69f7d8d3ff2293f8 [file] [log] [blame]
#include "caffe2/operators/half_float_ops.h"
namespace caffe2 {
namespace {

// Schemas for the two precision-conversion operators. Each one consumes
// exactly one tensor and produces exactly one tensor.
OPERATOR_SCHEMA(FloatToHalf)
    .NumInputs(1)
    .NumOutputs(1);
OPERATOR_SCHEMA(HalfToFloat)
    .NumInputs(1)
    .NumOutputs(1);

// Backward pass of FloatToHalf.
//
// The gradient is expressed as the reverse conversion: a single HalfToFloat
// op that reads GO(0) (gradient w.r.t. output 0) and writes GI(0) (gradient
// w.r.t. input 0).
// NOTE(review): this relies on the forward kernel being a pure elementwise
// type cast with no scaling — confirm against half_float_ops.h.
struct GetFloatToHalfGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> grad_inputs{GO(0)};
    const vector<string> grad_outputs{GI(0)};
    // The "" argument is presumably the operator name, left empty.
    return SingleGradientDef("HalfToFloat", "", grad_inputs, grad_outputs);
  }
};
REGISTER_GRADIENT(FloatToHalf, GetFloatToHalfGradient);

// Backward pass of HalfToFloat: the mirror image of the maker above. A single
// FloatToHalf op converts GO(0) back into GI(0).
struct GetHalfToFloatGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> grad_inputs{GO(0)};
    const vector<string> grad_outputs{GI(0)};
    return SingleGradientDef("FloatToHalf", "", grad_inputs, grad_outputs);
  }
};
REGISTER_GRADIENT(HalfToFloat, GetHalfToFloatGradient);

} // namespace
} // namespace caffe2