// blob: bfb1cc657edc41aca1bd9181be5f2081a5235214 [file] [log] [blame]
#pragma once
#include <stdint.h>
#ifdef AT_CUDA_ENABLED
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#endif
namespace at {
// Generic scalar conversion: yields `f` converted to `To` with a plain
// static_cast. Explicit specializations declared further down in this header
// replace this default for specific (To, From) pairs.
template <typename To, typename From>
To convert(From f) {
  return static_cast<To>(f);
}
// AT_ALIGN(n): alignment specifier placed between `struct` and the type body.
//   - compilers defining __GNUC__ : __attribute__((aligned(n)))
//   - otherwise, when _WIN32 is set: __declspec(align(n))
//   - anything else               : expands to nothing (no alignment request)
#ifdef __GNUC__
#  define AT_ALIGN(n) __attribute__((aligned(n)))
#elif defined(_WIN32)
#  define AT_ALIGN(n) __declspec(align(n))
#else
#  define AT_ALIGN(n)
#endif
// 16-bit storage type, aligned to 2 bytes via AT_ALIGN(2).
//
// `x` holds the raw 16-bit payload (presumably the IEEE 754 binary16 bit
// pattern -- the conversions below copy it verbatim into CUDA's half type
// rather than interpreting it).
//
// The struct carries a tag name instead of being an unnamed struct that is
// named only through `typedef`: an unnamed class given a typedef name for
// linkage purposes must be C-compatible, and this type declares member
// functions (one of them defined out of line), which newer compilers
// diagnose. Layout and aggregate initialisation (`Half{ bits }`) are
// unaffected by the tag.
struct AT_ALIGN(2) Half {
  unsigned short x;
#ifdef AT_CUDA_ENABLED
#if CUDA_VERSION < 9000
  // Before CUDA 9 `half` is brace-initialised directly from the raw bits.
  operator half() { return half{ x }; }
#else
  // From CUDA 9 on, the bits are routed through __half_raw, from which a
  // `half` is then constructed.
  operator half() {
    __half_raw x_raw;
    x_raw.x = x;
    return half(x_raw);
  }
#endif
#endif
  // Declared here; the definition appears after the convert<> specializations
  // it relies on have been declared.
  operator double();
};
template<> Half convert(double f);
template<> double convert(Half f);
template<> Half convert(int64_t f);
template<> int64_t convert(Half f);
// Conversion of a Half to double, delegated to the convert<double, Half>
// specialization declared above.
inline Half::operator double() {
  Half& self = *this;
  return convert<double, Half>(self);
}
#if defined(AT_CUDA_ENABLED)
// CUDA builds only: double -> CUDA `half` specialization of convert<>.
// Declared here; the definition is not part of this section of the header.
template <> half convert<half, double>(double d);
#endif
// Re-wraps a half-like value as another half-like type by brace-initialising
// `To` from the source's `x` member. Both types are expected to expose (and
// be initialisable from) a single `x` field; the value of `x` is carried over
// unchanged.
template <typename To, typename From>
static inline To HalfFix(From h) {
  To fixed{ h.x };
  return fixed;
}
#ifdef AT_CUDA_ENABLED
#if CUDA_VERSION >= 9000
// Half -> __half: copies the stored 16-bit value into a __half_raw and
// constructs the __half from that raw wrapper.
template <>
inline __half HalfFix<__half, Half>(Half h) {
  __half_raw bits;
  bits.x = h.x;
  return __half(bits);
}
// __half -> Half: converts the __half to its __half_raw form and stores the
// raw `x` value in a Half.
template <>
inline Half HalfFix<Half, __half>(__half h) {
  const __half_raw bits = static_cast<__half_raw>(h);
  Half result;
  result.x = bits.x;
  return result;
}
#endif
#endif
}