blob: f66fb5971e15de1c179d60e72dc1a04d97369ec6 [file] [log] [blame]
#pragma once
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <cmath>
#include <cstring>
namespace c10 {
namespace detail {
// Expands a raw 16-bit bfloat16 pattern into a 32-bit float.
//
// The bfloat16 bits become the upper 16 bits of the float's bit pattern and
// the lower 16 bits are zero, so the conversion is exact (no rounding).
//
// @param src  raw bfloat16 bit pattern
// @return     the float whose upper half equals `src`
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  const uint32_t widened = static_cast<uint32_t>(src) << 16;
  float result = 0;
#ifdef __HIP_PLATFORM_HCC__
  // memcpy would be the strict-aliasing-safe choice, but it fails in the HIP
  // environment, so the bits are reinterpreted through a pointer cast here.
  result = *reinterpret_cast<const float*>(&widened);
#else
  std::memcpy(&result, &widened, sizeof(widened));
#endif
  return result;
}
// Narrows a float to a bfloat16 bit pattern by truncation.
//
// The upper 16 bits of the float's bit pattern are kept. The lower 16
// mantissa bits are discarded without rounding; see round_to_nearest_even
// for the rounding variant.
//
// @param src  value to narrow
// @return     upper 16 bits of `src`'s bit pattern
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  uint32_t word = 0;
#ifdef __HIP_PLATFORM_HCC__
  // memcpy would be the strict-aliasing-safe choice, but it fails in the HIP
  // environment, so the bits are read through a pointer cast here.
  word = *reinterpret_cast<uint32_t*>(&src);
#else
  std::memcpy(&word, &src, sizeof(word));
#endif
  return static_cast<uint16_t>(word >> 16);
}
// Converts a float to a bfloat16 bit pattern, rounding to nearest with ties
// to even.
//
// Every NaN input (of either sign) maps to the single pattern 0x7FC0.
// For other inputs, the bias is 0x7FFF plus the lowest retained bit (bit 16).
// That bias is added to the full 32-bit pattern, which is then truncated to
// its upper 16 bits. The result:
//   * discarded bits below one half round down;
//   * discarded bits above one half round up;
//   * an exact half rounds up only when the retained part is odd, so the
//     result is always even.
// A carry out of the mantissa increments the exponent. For example, finite
// values just below the float maximum round to infinity (0x7F80).
//
// @param src  value to convert
// @return     rounded bfloat16 bit pattern
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
  if (std::isnan(src)) {
    return 0x7FC0;
  }
  // Read the float's bits without union type punning. Reading an inactive
  // union member is undefined behavior in standard C++. This mirrors the
  // memcpy / HIP pointer-cast pattern used by bits_from_f32.
  uint32_t bits = 0;
#ifdef __HIP_PLATFORM_HCC__
  // memcpy fails in the HIP environment, so fall back to a pointer cast there.
  bits = *reinterpret_cast<const uint32_t*>(&src);
#else
  std::memcpy(&bits, &src, sizeof(bits));
#endif
  const uint32_t rounding_bias = ((bits >> 16) & 1) + 0x7FFF;
  // No 32-bit overflow is possible: the largest non-NaN pattern is 0xFF800000.
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
} // namespace detail
// 16-bit brain floating-point value stored as its raw bit pattern.
//
// The struct is 2-byte aligned and holds exactly one uint16_t. Conversions
// to and from float are declared here. Their definitions are presumably
// provided by the BFloat16-inl.h header included at the end of this file.
struct alignas(2) BFloat16 {
  // Raw bfloat16 bits: 1 sign bit, 8 exponent bits, 7 mantissa bits.
  uint16_t x;

  // Tag type that selects the raw-bits constructor below. It distinguishes
  // "store these bits" from "convert this number".
  struct from_bits_t {};
  static constexpr from_bits_t from_bits() {
    return from_bits_t();
  }

  // Defaulted: `x` is not initialized by default-initialization.
  // HIP wants __host__ __device__ tag, CUDA does not.
#ifdef __HIP_PLATFORM_HCC__
  C10_HOST_DEVICE BFloat16() = default;
#else
  BFloat16() = default;
#endif

  // Stores `bits` verbatim as the bfloat16 pattern; no conversion occurs.
  constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
      : x(bits) {}

  // Implicit float <-> BFloat16 conversions (defined out of line).
  inline C10_HOST_DEVICE BFloat16(float value);
  inline C10_HOST_DEVICE operator float() const;
};
} // namespace c10
#include <c10/util/BFloat16-inl.h>