blob: bb3d2e21ef29e121ef7510838fe4d935a6715d30 [file] [log] [blame]
#pragma once
#include <caffe2/core/types.h>
#ifdef __CUDA_ARCH__
#include <cuda_fp16.h>
#endif
#ifdef __CUDA_ARCH__
#define CONVERSIONS_DECL __host__ __device__ inline
#else
#define CONVERSIONS_DECL inline
#endif
namespace caffe2 {
namespace convert {
namespace {
inline float16 cpu_float2half_rn(float f) {
float16 ret;
static_assert(
sizeof(unsigned int) == sizeof(float),
"Programming error sizeof(unsigned int) != sizeof(float)");
unsigned* xp = reinterpret_cast<unsigned int*>(&f);
unsigned x = *xp;
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
ret.x = 0x7fffU;
return ret;
}
sign = ((x >> 16) & 0x8000);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
ret.x = sign | 0x7c00U;
return ret;
}
if (u < 0x33000001) {
ret.x = (sign | 0x0000);
return ret;
}
exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
} else {
shift = 0x7e - exponent;
exponent = 0;
mantissa |= 0x800000;
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);
// Round to nearest even.
remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
++exponent;
mantissa = 0;
}
}
ret.x = (sign | (exponent << 10) | mantissa);
return ret;
}
inline float cpu_half2float(float16 h) {
unsigned sign = ((h.x >> 15) & 1);
unsigned exponent = ((h.x >> 10) & 0x1f);
unsigned mantissa = ((h.x & 0x3ff) << 13);
if (exponent == 0x1f) { /* NaN or Inf */
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
exponent = 0xff;
} else if (!exponent) { /* Denorm or Zero */
if (mantissa) {
unsigned int msb;
exponent = 0x71;
do {
msb = (mantissa & 0x400000);
mantissa <<= 1; /* normalize */
--exponent;
} while (!msb);
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
}
} else {
exponent += 0x70;
}
unsigned i = ((sign << 31) | (exponent << 23) | mantissa);
float ret;
memcpy(&ret, &i, sizeof(i));
return ret;
}
}; // anonymous
// general version: defer to static_cast
template <typename IN, typename OUT>
CONVERSIONS_DECL OUT To(const IN in) {
return static_cast<OUT>(in);
}
#if __CUDA_ARCH__
__device__ __inline__ __half inf_clip(__half h) {
int isi = __hisinf(h);
if (isi > 0) {
// Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
h.x = 0x7bffU;
} else if (isi < 0) {
// As above, negated
h.x = 0x7bffU ^ 0x8000;
}
return h;
}
#endif
// explicit for fp16
template <>
CONVERSIONS_DECL float16 To(const float in) {
#if __CUDA_ARCH__
// hacky interface between C2 fp16 and CUDA
float16 ret;
#if 0
// alternative truncation scheme
__half r;
r.x = __float2half_rn(in);
ret.x = inf_clip(r).x;
#else
ret.x = __float2half(in).x;
#endif
return ret;
#else
return cpu_float2half_rn(in);
#endif
}
template <>
CONVERSIONS_DECL float To(const float16 in) {
#if __CUDA_ARCH__
__half tmp;
tmp.x = in.x;
return __half2float(tmp);
#else
return cpu_half2float(in);
#endif
};
template <>
CONVERSIONS_DECL float To(const float in) {
return in;
}
template <typename OUT, typename IN>
CONVERSIONS_DECL OUT Get(IN x) {
return static_cast<OUT>(x);
}
template <>
CONVERSIONS_DECL float Get(float16 x) {
return To<float16, float>(x);
}
template <>
CONVERSIONS_DECL float16 Get(float x) {
return To<float, float16>(x);
}
}; // namespace convert
}; // namespace caffe2
#undef CONVERSIONS_DECL