blob: 4904f48cb48a24f57e487ef98f56840f9a74dfb4 [file] [log] [blame]
#include "ATen/Scalar.h"
#include <TH/TH.h>
namespace at {

// Explicit specializations of at::convert for the Half scalar type.
// Half stores its raw 16-bit pattern in the member `x`; the actual bit
// packing/unpacking is delegated to TH_float2halfbits / TH_halfbits2float.

// double -> Half: narrow to float first, then let TH pack the float into
// Half's 16-bit representation.
template<> Half convert(double f) {
  Half result;
  float narrowed = static_cast<float>(f);
  TH_float2halfbits(&narrowed, &result.x);
  return result;
}

// Half -> double: TH unpacks the 16-bit pattern into a float, which is then
// widened to double.
template<> double convert(Half f) {
  float unpacked;
  TH_halfbits2float(&f.x, &unpacked);
  return static_cast<double>(unpacked);
}

// int64_t -> Half: routed through the double -> Half specialization above.
template<> Half convert(int64_t f) {
  const double as_double = static_cast<double>(f);
  return convert<Half, double>(as_double);
}

// Half -> int64_t: routed through the Half -> double specialization above,
// then truncated toward zero by static_cast.
// NOTE(review): if the Half holds inf or NaN, the double -> int64_t cast is
// undefined behavior in C++; callers presumably never pass such values —
// confirm.
template<> int64_t convert(Half f) {
  const double as_double = convert<double, Half>(f);
  return static_cast<int64_t>(as_double);
}

#ifdef AT_CUDA_ENABLED
// double -> CUDA half: reuse the host-side Half conversion and copy its raw
// bits into a CUDA `half` via aggregate initialization.
// NOTE(review): this relies on `half` being an aggregate whose first member
// is the raw bit pattern; newer CUDA toolkits may define `half` differently —
// verify against the toolkit in use.
template<> half convert(double d) {
  const Half packed = convert<Half, double>(d);
  return half { packed.x };
}
#endif

}  // namespace at