/*
* Copyright © 2020 Collabora Ltd.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including the next
* paragraph) shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
#ifndef NIR_CONVERSION_BUILDER_H
#define NIR_CONVERSION_BUILDER_H
#include "util/u_math.h"
#include "nir_builder.h"
#include "nir_builtin_builder.h"
#ifdef __cplusplus
extern "C" {
#endif
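/* Round a floating-point source to an integral value according to "round".
 * Only the directed modes and round-to-nearest-even are handled here; RTZ
 * and undef are expected to have been folded away by
 * nir_simplify_conversion_rounding(), since the float-to-int conversion op
 * itself truncates.
 */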
static inline nir_ssa_def *
nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
nir_rounding_mode round)
{
switch (round) {
case nir_rounding_mode_ru:
return nir_fceil(b, src);
case nir_rounding_mode_rd:
return nir_ffloor(b, src);
case nir_rounding_mode_rtne:
return nir_fround_even(b, src);
case nir_rounding_mode_undef:
case nir_rounding_mode_rtz:
break;
}
unreachable("unexpected rounding mode");
}
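/* Pre-round a float so that a later down-conversion to "dest_bit_size" bits
 * yields the value rounded in the requested direction.  The helper performs
 * the down-conversion, converts back up, and nudges the result by one ULP
 * with nextafter() whenever the round trip moved the value against the
 * requested direction.
 */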
static inline nir_ssa_def *
nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
unsigned dest_bit_size,
nir_rounding_mode round)
{
unsigned src_bit_size = src->bit_size;
if (dest_bit_size > src_bit_size)
return src; /* No rounding is needed for an up-convert */
nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
nir_type_float | dest_bit_size,
nir_rounding_mode_undef);
nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
nir_type_float | src_bit_size,
nir_rounding_mode_undef);
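/* low_conv drops the value to the destination precision; high_conv widens
 * it back up so it can be compared against the original source. */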
switch (round) {
case nir_rounding_mode_ru: {
/* If lower-precision conversion results in a lower value, push it
* up one ULP. */
nir_ssa_def *lower_prec =
nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
nir_ssa_def *roundtrip =
nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
}
case nir_rounding_mode_rd: {
/* If lower-precision conversion results in a higher value, push it
* down one ULP. */
nir_ssa_def *lower_prec =
nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
nir_ssa_def *roundtrip =
nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
}
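/* Round-toward-zero rounds negative sources up and non-negative sources
 * down. */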
case nir_rounding_mode_rtz:
return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
nir_round_float_to_float(b, src, dest_bit_size,
nir_rounding_mode_ru),
nir_round_float_to_float(b, src, dest_bit_size,
nir_rounding_mode_rd));
case nir_rounding_mode_rtne:
case nir_rounding_mode_undef:
break;
}
unreachable("unexpected rounding mode");
}
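/* Round an integer source to the nearest value exactly representable in a
 * "dest_bit_size"-bit float, in the requested direction, so that a later
 * int-to-float conversion (which typically rounds to nearest) cannot drift
 * against the requested mode.
 */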
static inline nir_ssa_def *
nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
nir_alu_type src_type,
unsigned dest_bit_size,
nir_rounding_mode round)
{
/* We only care whether or not it's signed */
src_type = nir_alu_type_get_base_type(src_type);
unsigned mantissa_bits;
switch (dest_bit_size) {
case 16:
mantissa_bits = 10;
break;
case 32:
mantissa_bits = 23;
break;
case 64:
mantissa_bits = 52;
break;
default: unreachable("Unsupported bit size");
}
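/* Sources narrower than the mantissa are always exactly representable. */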
if (src->bit_size < mantissa_bits)
return src;
if (src_type == nir_type_int) {
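/* Handle signed sources by rounding the absolute value as unsigned and
 * negating the result for negative inputs.  Negation flips the ordering,
 * so the rounding direction is flipped as well for negative values, and
 * the rounded magnitude is clamped to the largest positive value of the
 * source size so it stays representable as a signed integer.
 */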
nir_ssa_def *sign =
nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
nir_ssa_def *abs = nir_iabs(b, src);
nir_ssa_def *positive_rounded =
nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
nir_ssa_def *max_positive =
nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
switch (round) {
case nir_rounding_mode_rtz:
return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
positive_rounded);
break;
case nir_rounding_mode_ru:
return nir_bcsel(b, sign,
nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
nir_umin(b, positive_rounded, max_positive));
break;
case nir_rounding_mode_rd:
return nir_bcsel(b, sign,
nir_ineg(b,
nir_umin(b, max_positive,
nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
positive_rounded);
case nir_rounding_mode_rtne:
case nir_rounding_mode_undef:
break;
}
unreachable("unexpected rounding mode");
} else {
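/* For unsigned sources, mask off the low bits that the destination
 * mantissa cannot hold.  "adjust" is the weight of the lowest surviving
 * bit, i.e. one integer-space ULP of the destination float at this
 * magnitude, and is added back (with saturation) when rounding up.
 */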
nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
nir_ssa_def *truncated = nir_iand(b, src, mask);
switch (round) {
case nir_rounding_mode_rtz:
case nir_rounding_mode_rd:
return truncated;
break;
case nir_rounding_mode_ru:
return nir_bcsel(b, nir_ieq(b, src, truncated),
src, nir_uadd_sat(b, truncated, adjust));
case nir_rounding_mode_rtne:
case nir_rounding_mode_undef:
break;
}
unreachable("unexpected rounding mode");
}
}
/** Returns true if the representable range of a contains the representable
* range of b.
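*
* For example, int32 contains both uint16 and float16, while uint32 does
* not contain int8 because negative values are unrepresentable.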
*/
static inline bool
nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
{
/* Split types from bit sizes */
nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
unsigned a_bit_size = nir_alu_type_get_type_size(a);
unsigned b_bit_size = nir_alu_type_get_type_size(b);
/* This requires sized types */
assert(a_bit_size > 0 && b_bit_size > 0);
if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
return true;
if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
a_bit_size > b_bit_size)
return true;
/* The finite range of a 16-bit float fits in a 32-bit or wider signed int */
if (a_base_type == nir_type_int && a_bit_size >= 32 &&
b == nir_type_float16)
return true;
/* The range of any signed or unsigned int fits in a 32-bit or wider float,
* and any 8-bit int's range fits in a float16.
*/
if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
(a_bit_size >= 32 || b_bit_size == 8))
return true;
return false;
}
/**
* Clamp the source value into the widest representable range of the
* destination type with cmp + bcsel.
*/
static inline nir_ssa_def *
nir_clamp_to_type_range(nir_builder *b,
nir_ssa_def *src, nir_alu_type src_type,
nir_alu_type dest_type)
{
assert(nir_alu_type_get_type_size(src_type) == 0 ||
nir_alu_type_get_type_size(src_type) == src->bit_size);
src_type |= src->bit_size;
if (nir_alu_type_range_contains_type_range(dest_type, src_type))
return src;
/* Split types from bit sizes */
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
assert(dest_bit_size != 0);
/* limits of the destination type, expressed in the source type */
nir_ssa_def *low = NULL, *high = NULL;
switch (dest_base_type) {
case nir_type_int: {
int64_t ilow, ihigh;
if (dest_bit_size == 64) {
ilow = INT64_MIN;
ihigh = INT64_MAX;
} else {
ilow = -(1ll << (dest_bit_size - 1));
ihigh = (1ll << (dest_bit_size - 1)) - 1;
}
if (src_base_type == nir_type_int) {
low = nir_imm_intN_t(b, ilow, src->bit_size);
high = nir_imm_intN_t(b, ihigh, src->bit_size);
} else if (src_base_type == nir_type_uint) {
assert(src->bit_size >= dest_bit_size);
high = nir_imm_intN_t(b, ihigh, src->bit_size);
} else {
low = nir_imm_floatN_t(b, ilow, src->bit_size);
high = nir_imm_floatN_t(b, ihigh, src->bit_size);
}
break;
}
case nir_type_uint: {
uint64_t uhigh = dest_bit_size == 64 ?
~0ull : (1ull << dest_bit_size) - 1;
if (src_base_type != nir_type_float) {
low = nir_imm_intN_t(b, 0, src->bit_size);
if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size)
high = nir_imm_intN_t(b, uhigh, src->bit_size);
} else {
low = nir_imm_floatN_t(b, 0.0f, src->bit_size);
high = nir_imm_floatN_t(b, uhigh, src->bit_size);
}
break;
}
case nir_type_float: {
double flow, fhigh;
switch (dest_bit_size) {
case 16:
flow = -65504.0f;
fhigh = 65504.0f;
break;
case 32:
flow = -FLT_MAX;
fhigh = FLT_MAX;
break;
case 64:
flow = -DBL_MAX;
fhigh = DBL_MAX;
break;
default:
unreachable("Unhandled bit size");
}
switch (src_base_type) {
case nir_type_int: {
int64_t src_ilow, src_ihigh;
if (src->bit_size == 64) {
src_ilow = INT64_MIN;
src_ihigh = INT64_MAX;
} else {
src_ilow = -(1ll << (src->bit_size - 1));
src_ihigh = (1ll << (src->bit_size - 1)) - 1;
}
if (src_ilow < flow)
low = nir_imm_intN_t(b, flow, src->bit_size);
if (src_ihigh > fhigh)
high = nir_imm_intN_t(b, fhigh, src->bit_size);
break;
}
case nir_type_uint: {
uint64_t src_uhigh = src->bit_size == 64 ?
~0ull : (1ull << src->bit_size) - 1;
if (src_uhigh > fhigh)
high = nir_imm_intN_t(b, fhigh, src->bit_size);
break;
}
case nir_type_float:
low = nir_imm_floatN_t(b, flow, src->bit_size);
high = nir_imm_floatN_t(b, fhigh, src->bit_size);
break;
default:
unreachable("Clamping from unknown type");
}
break;
}
default:
unreachable("clamping to unknown type");
break;
}
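/* Clamp with a compare + select in the source type.  A NULL bound means
 * the source cannot exceed the destination range on that side, so no
 * select is needed for it.
 */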
nir_ssa_def *low_cond = NULL, *high_cond = NULL;
switch (src_base_type) {
case nir_type_int:
low_cond = low ? nir_ilt(b, src, low) : NULL;
high_cond = high ? nir_ilt(b, high, src) : NULL;
break;
case nir_type_uint:
low_cond = low ? nir_ult(b, src, low) : NULL;
high_cond = high ? nir_ult(b, high, src) : NULL;
break;
case nir_type_float:
low_cond = low ? nir_flt(b, src, low) : NULL;
high_cond = high ? nir_flt(b, high, src) : NULL;
break;
default:
unreachable("clamping from unknown type");
}
nir_ssa_def *res = src;
if (low_cond)
res = nir_bcsel(b, low_cond, low, res);
if (high_cond)
res = nir_bcsel(b, high_cond, high, res);
return res;
}
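/* Collapse a requested rounding mode to nir_rounding_mode_undef whenever
 * the plain NIR conversion op already behaves that way (integer-only
 * conversions, float up-casts, RTZ float-to-int, and RTNE conversions to
 * float), so the caller can emit a trivial conversion instead of
 * open-coding the rounding.
 */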
static inline nir_rounding_mode
nir_simplify_conversion_rounding(nir_alu_type src_type,
nir_alu_type dest_type,
nir_rounding_mode rounding)
{
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
assert(src_bit_size > 0 && dest_bit_size > 0);
if (rounding == nir_rounding_mode_undef)
return rounding;
/* Pure integer conversion doesn't have any rounding */
if (src_base_type != nir_type_float &&
dest_base_type != nir_type_float)
return nir_rounding_mode_undef;
/* Float up-casts are exact, so no rounding applies */
if (src_base_type == nir_type_float &&
dest_base_type == nir_type_float &&
dest_bit_size >= src_bit_size)
return nir_rounding_mode_undef;
/* Regular float to int conversions are RTZ */
if (src_base_type == nir_type_float &&
dest_base_type != nir_type_float &&
rounding == nir_rounding_mode_rtz)
return nir_rounding_mode_undef;
/* The CL spec requires regular conversions to float to be RTNE */
if (dest_base_type == nir_type_float &&
rounding == nir_rounding_mode_rtne)
return nir_rounding_mode_undef;
/* Couldn't simplify */
return rounding;
}
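/* Build a conversion from src_type to dest_type with an explicit rounding
 * mode and, if "clamp" is set, saturation to the destination type's
 * representable range.
 *
 * A hypothetical use from a lowering pass, converting an f32 value to a
 * saturated uint8 with round-to-nearest-even:
 *
 *    nir_ssa_def *u8 =
 *       nir_convert_with_rounding(b, val, nir_type_float32, nir_type_uint8,
 *                                 nir_rounding_mode_rtne, true);
 */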
static inline nir_ssa_def *
nir_convert_with_rounding(nir_builder *b,
nir_ssa_def *src, nir_alu_type src_type,
nir_alu_type dest_type,
nir_rounding_mode round,
bool clamp)
{
/* Some stuff wants sized types */
assert(nir_alu_type_get_type_size(src_type) == 0 ||
nir_alu_type_get_type_size(src_type) == src->bit_size);
src_type |= src->bit_size;
/* Split types from bit sizes */
nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
/* Try to simplify the conversion if we can */
clamp = clamp &&
!nir_alu_type_range_contains_type_range(dest_type, src_type);
round = nir_simplify_conversion_rounding(src_type, dest_type, round);
/*
* If we don't care about rounding and clamping, we can just use NIR's
* built-in ops. There is also a special case for SPIR-V in shaders, where
* f32/f64 -> f16 conversions can have one of two rounding modes applied,
* which NIR has built-in opcodes for.
*
* For the rest, we have our own implementation of rounding and clamping.
*/
bool trivial_convert;
if (!clamp && round == nir_rounding_mode_undef) {
trivial_convert = true;
} else if (!clamp && src_type == nir_type_float32 &&
dest_type == nir_type_float16 &&
(round == nir_rounding_mode_rtne ||
round == nir_rounding_mode_rtz)) {
trivial_convert = true;
} else {
trivial_convert = false;
}
if (trivial_convert) {
nir_op op = nir_type_conversion_op(src_type, dest_type, round);
return nir_build_alu(b, op, src, NULL, NULL, NULL);
}
nir_ssa_def *dest = src;
/* clamp the result into range */
if (clamp)
dest = nir_clamp_to_type_range(b, dest, src_type, dest_type);
/* round with selected rounding mode */
if (!trivial_convert && round != nir_rounding_mode_undef) {
if (src_base_type == nir_type_float) {
if (dest_base_type == nir_type_float) {
dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
} else {
dest = nir_round_float_to_int(b, dest, round);
}
} else {
dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
}
round = nir_rounding_mode_undef;
}
/* now we can convert the value */
nir_op op = nir_type_conversion_op(src_type, dest_type, round);
return nir_build_alu(b, op, dest, NULL, NULL, NULL);
}
#ifdef __cplusplus
}
#endif
#endif /* NIR_CONVERSION_BUILDER_H */