blob: f090f3507bfa06fbbe6c64ce3db93a11d99edb1d [file] [log] [blame]
#pragma once
#include <limits>
#include <stdint.h>
#ifdef AT_CUDA_ENABLED
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#endif
namespace at {
template<typename To, typename From> To convert(From f) {
return static_cast<To>(f);
}
template<typename To, typename From> bool overflows(From f) {
using limit = std::numeric_limits<To>;
return f < limit::lowest() || f > limit::max();
}
template<typename To, typename From> To checked_convert(From f, const char* name) {
if (overflows<To, From>(f)) {
std::string msg = "value cannot be converted to type ";
msg += name;
msg += " without overflow: ";
msg += std::to_string(f);
throw std::domain_error(std::move(msg));
}
return convert<To, From>(f);
}
#if defined(__GNUC__)
#define AT_ALIGN(n) __attribute__((aligned(n)))
#elif defined(_WIN32)
#define AT_ALIGN(n) __declspec(align(n))
#else
#define AT_ALIGN(n)
#endif
typedef struct AT_ALIGN(2) {
unsigned short x;
#ifdef AT_CUDA_ENABLED
#if CUDA_VERSION < 9000
operator half() { return half{ x }; }
#else
operator half() {
__half_raw x_raw;
x_raw.x = x;
return half(x_raw);
}
#endif
#endif
operator double();
} Half;
template<> Half convert(double f);
template<> double convert(Half f);
template<> Half convert(int64_t f);
template<> int64_t convert(Half f);
template<> bool overflows<Half, double>(double f);
template<> bool overflows<Half, int64_t>(int64_t f);
inline Half::operator double() {
return convert<double,Half>(*this);
}
#ifdef AT_CUDA_ENABLED
template<> half convert(double d);
#endif
template<typename To, typename From>
static inline To HalfFix(From h) {
return To { h.x };
}
#ifdef AT_CUDA_ENABLED
#if CUDA_VERSION >= 9000
template<>
inline __half HalfFix<__half, Half>(Half h) {
__half_raw raw;
raw.x = h.x;
return __half { raw };
}
template<>
inline Half HalfFix<Half, __half>(__half h) {
__half_raw raw(h);
return Half { raw.x };
}
#endif
#endif
}