blob: 8cbd7d6abb231829d77da43a6e9112a4e6e3f585 [file] [log] [blame]
#pragma once
#include<stdint.h>
#ifdef AT_CUDA_ENABLED
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#endif
namespace at {
// Generic conversion helper: yields `f` converted to `To` through a plain
// static_cast. Type pairs that need different handling can supply an
// explicit specialization of this template.
template <typename To, typename From>
auto convert(From f) -> To {
  return static_cast<To>(f);
}
// AT_ALIGN(n): compiler-specific annotation requesting n-byte alignment for
// the type it is attached to.
#if defined(__GNUC__)
// Compilers that define __GNUC__: use the GNU attribute syntax.
#define AT_ALIGN(n) __attribute__((aligned(n)))
#elif defined(_WIN32)
// Windows builds where __GNUC__ is not defined: use the __declspec form.
#define AT_ALIGN(n) __declspec(align(n))
#else
// Any other toolchain: the macro expands to nothing, so no explicit
// alignment is requested.
#define AT_ALIGN(n)
#endif
// Half: a 2-byte-aligned wrapper around a single `unsigned short` payload.
// NOTE(review): `x` presumably holds the raw IEEE half-precision bit
// pattern -- confirm against the convert<> specializations.
//
// The struct is declared with its own tag name rather than as an unnamed
// struct named through `typedef`: an unnamed typedef'd class that declares
// member functions is rejected by C++20 (P1766R1) and triggers Clang's
// -Wnon-c-typedef-for-linkage. The type name, layout and members are
// unchanged.
struct AT_ALIGN(2) Half {
  // Stored 16-bit payload.
  unsigned short x;
#ifdef AT_CUDA_ENABLED
  // Builds a CUDA `half` by list-initializing it from the stored payload.
  operator half() { return half { x }; }
#endif
  // Conversion to double; declared here, defined out of class further down
  // in this header.
  operator double();
};
template<> Half convert(double f);
template<> double convert(Half f);
template<> Half convert(int64_t f);
template<> int64_t convert(Half f);
// Out-of-class definition of Half's conversion to double. It forwards a copy
// of this object to the convert<double, Half> specialization (From is
// deduced as Half from the argument).
inline Half::operator double() {
  return convert<double>(*this);
}
#if defined(AT_CUDA_ENABLED)
// CUDA builds only: declares the double -> CUDA `half` specialization of
// convert. The definition is not part of this header section.
template <> half convert<half, double>(double d);
#endif
// Builds a `To` by list-initializing it from the `x` member of `h`.
// Both types are template parameters; the only requirement visible here is
// that `From` exposes a member `x` and that `To` can be brace-initialized
// from it.
// NOTE(review): presumably used to carry the 16-bit payload between at::Half
// and CUDA's `half` without a numeric conversion -- confirm with callers.
template <typename To, typename From>
static inline To HalfFix(From h) {
  auto& payload = h.x;
  return To{ payload };
}
}