blob: 352923698413923711cc0f0e6ff9d238359df7f5 [file] [log] [blame]
#include "ATen/Scalar.h"
#include <TH/TH.h>
namespace at {
// Specialization: double -> Half.
// The input is first narrowed to float, because TH_float2halfbits takes a
// pointer to a float; that routine then fills in the `x` member of the
// Half that is returned.
template<> Half convert(double f) {
  Half result;
  float narrowed = static_cast<float>(f);
  TH_float2halfbits(&narrowed, &result.x);
  return result;
}
// Specialization: Half -> double.
// TH_halfbits2float decodes the `x` member of the Half into a float, which
// is then widened to double for the caller.
template<> double convert(Half f) {
  float widened;
  TH_halfbits2float(&f.x, &widened);
  return static_cast<double>(widened);
}
// Specialization: int64_t -> Half.
// Goes through double and reuses the double -> Half specialization rather
// than converting the integer directly.
template<> Half convert(int64_t f) {
  const double as_double = static_cast<double>(f);
  return convert<Half,double>(as_double);
}
// Specialization: Half -> int64_t.
// Decodes the Half to double via the Half -> double specialization and then
// casts that double to int64_t. No check is made here on the intermediate
// value before the cast.
template<> int64_t convert(Half f) {
  const double as_double = convert<double,Half>(f);
  return static_cast<int64_t>(as_double);
}
#ifdef AT_CUDA_ENABLED
// Specialization: double -> CUDA `half` (only built when AT_CUDA_ENABLED is
// defined). The value is encoded once through the double -> Half
// specialization, and the resulting `x` member is handed to `half`.
template<> half convert(double d) {
  Half encoded = convert<Half,double>(d);
#if CUDA_VERSION < 9000
  // Before CUDA_VERSION 9000, `half` is brace-initialized straight from x.
  return half {encoded.x};
#else
  // Otherwise the value is routed through a __half_raw carrier.
  // NOTE(review): presumably `half` can no longer be initialized directly
  // from the raw member in these toolkits -- confirm against cuda_fp16.h.
  __half_raw carrier;
  carrier.x = encoded.x;
  return half {carrier};
#endif
}
#endif
}