guest_abi: Fix extension for float arguments
Uses a custom GuestArgument for RISC-V to ensure proper NaN boxing and
unboxing of floating-point arguments.
Test: berberis_run_host_tests
Bug: 300490784
Change-Id: Ie7d5bbbe5b99da08dad0215ea5c3f6a68f4309f8
diff --git a/guest_abi/riscv64/guest_abi_test.cc b/guest_abi/riscv64/guest_abi_test.cc
index e134d72..22686ef 100644
--- a/guest_abi/riscv64/guest_abi_test.cc
+++ b/guest_abi/riscv64/guest_abi_test.cc
@@ -95,6 +95,50 @@
EXPECT_EQ(value, 0xffff'ffff'f123'4567U);
}
+TEST(GuestAbi_riscv64_lp64, GuestArgumentFloat32) {
+ uint64_t value = 0;
+ auto& param = *reinterpret_cast<GuestAbi::GuestArgument<float, GuestAbi::kLp64>*>(&value);
+
+ value = 0x0000'0000'3f00'0000;
+ EXPECT_FLOAT_EQ(param, 0.5f);
+
+ param = 7.125f;
+ EXPECT_EQ(value, 0x0000'0000'40e4'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64, GuestArgumentFloat64) {
+ uint64_t value = 0;
+ auto& param = *reinterpret_cast<GuestAbi::GuestArgument<double, GuestAbi::kLp64>*>(&value);
+
+ value = 0x3fd5'c28f'5c28'f5c3;
+ EXPECT_DOUBLE_EQ(param, 0.34);
+
+ param = 0.125f;
+ EXPECT_EQ(value, 0x3fc0'0000'0000'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64d, GuestArgumentFloat32) {
+ uint64_t value = 0;
+ auto& param = *reinterpret_cast<GuestAbi::GuestArgument<float, GuestAbi::kLp64d>*>(&value);
+
+ value = 0xffff'ffff'3f00'0000;
+ EXPECT_FLOAT_EQ(param, 0.5f);
+
+ param = 7.125f;
+ EXPECT_EQ(value, 0xffff'ffff'40e4'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64d, GuestArgumentFloat64) {
+ uint64_t value = 0;
+ auto& param = *reinterpret_cast<GuestAbi::GuestArgument<double, GuestAbi::kLp64d>*>(&value);
+
+ value = 0x3fd5'c28f'5c28'f5c3;
+ EXPECT_DOUBLE_EQ(param, 0.34);
+
+ param = 0.125f;
+ EXPECT_EQ(value, 0x3fc0'0000'0000'0000U);
+}
+
} // namespace
} // namespace berberis
diff --git a/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h b/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
index 3438d6e..8f46ada 100644
--- a/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
+++ b/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
@@ -20,6 +20,7 @@
#include <cstdint>
#include <type_traits>
+#include "berberis/base/bit_util.h"
#include "berberis/calling_conventions/calling_conventions_riscv64.h" // IWYU pragma: export.
#include "berberis/guest_abi/guest_type.h"
@@ -124,6 +125,99 @@
GuestArgument<UnderlyingType, kCallingConventionsVariant> value_;
};
+ template <typename FloatingPointType>
+ class alignas(sizeof(uint64_t))
+ GuestArgument<FloatingPointType,
+ kLp64,
+ std::enable_if_t<std::is_floating_point_v<FloatingPointType>>> {
+ public:
+ using Type = FloatingPointType;
+ GuestArgument(const FloatingPointType& value) : value_(Box(value)) {}
+ GuestArgument(FloatingPointType&& value) : value_(Box(value)) {}
+ GuestArgument() = default;
+ GuestArgument(const GuestArgument&) = default;
+ GuestArgument(GuestArgument&&) = default;
+ GuestArgument& operator=(const GuestArgument&) = default;
+ GuestArgument& operator=(GuestArgument&&) = default;
+ ~GuestArgument() = default;
+ operator FloatingPointType() const { return Unbox(value_); }
+
+ private:
+ // Floating-point arguments in integer registers do not require NaN boxing. They are stored in
+ // the lower bits of the 64-bit integer register with the high bits undefined. Bit casting and
+ // unsigned narrowing/widening conversions are sufficient.
+
+ static constexpr uint64_t Box(FloatingPointType value) {
+ if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+ return bit_cast<uint64_t>(value);
+ } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+ return static_cast<uint64_t>(bit_cast<uint32_t>(value));
+ } else {
+ FATAL("Unsupported floating-point argument width");
+ }
+ }
+
+ static constexpr FloatingPointType Unbox(uint64_t value) {
+ if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+ return bit_cast<FloatingPointType>(value);
+ } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+ return bit_cast<FloatingPointType>(static_cast<uint32_t>(value));
+ } else {
+ FATAL("Unsupported floating-point argument width");
+ }
+ }
+
+ uint64_t value_ = 0;
+ };
+
+ template <typename FloatingPointType>
+ class alignas(sizeof(uint64_t))
+ GuestArgument<FloatingPointType,
+ kLp64d,
+ std::enable_if_t<std::is_floating_point_v<FloatingPointType>>> {
+ public:
+ using Type = FloatingPointType;
+ GuestArgument(const FloatingPointType& value) : value_(Box(value)) {}
+ GuestArgument(FloatingPointType&& value) : value_(Box(value)) {}
+ GuestArgument() = default;
+ GuestArgument(const GuestArgument&) = default;
+ GuestArgument(GuestArgument&&) = default;
+ GuestArgument& operator=(const GuestArgument&) = default;
+ GuestArgument& operator=(GuestArgument&&) = default;
+ ~GuestArgument() = default;
+ operator FloatingPointType() const { return Unbox(value_); }
+
+ private:
+ // Floating-point arguments passed in floating-point registers require NaN boxing when they are
+ // narrower than 64 bits. The argument is stored in the lower bits of the 64-bit floating-point
+ // register with the high bits set to 1.
+
+ static constexpr uint64_t Box(FloatingPointType value) {
+ if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+ return bit_cast<uint64_t>(value);
+ } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+ return bit_cast<uint32_t>(value) | kNanBoxFloat32;
+ } else {
+ FATAL("Unsupported floating-point argument width");
+ }
+ }
+
+ static constexpr FloatingPointType Unbox(uint64_t value) {
+ if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+ return bit_cast<Type>(value);
+ } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+ // Integer narrowing removes the NaN box.
+ return bit_cast<Type>(static_cast<uint32_t>(value));
+ } else {
+ FATAL("Unsupported floating-point argument width");
+ }
+ }
+
+ static constexpr uint64_t kNanBoxFloat32 = 0xffff'ffff'0000'0000ULL;
+
+ uint64_t value_ = 0;
+ };
+
protected:
enum class ArgumentClass { kInteger, kFp, kLargeStruct };
@@ -180,13 +274,17 @@
constexpr static ArgumentClass kArgumentClass = ArgumentClass::kInteger;
constexpr static unsigned kSize = 4;
constexpr static unsigned kAlignment = 4;
- using GuestType = GuestType<float>;
- using HostType = float;
+ using GuestType = GuestArgument<float, kLp64>;
+ using HostType = GuestType;
};
template <>
- struct GuestArgumentInfo<float, kLp64d> : GuestArgumentInfo<float, kLp64> {
+ struct GuestArgumentInfo<float, kLp64d> {
constexpr static ArgumentClass kArgumentClass = ArgumentClass::kFp;
+ constexpr static unsigned kSize = 4;
+ constexpr static unsigned kAlignment = 4;
+ using GuestType = GuestArgument<float, kLp64d>;
+ using HostType = GuestType;
};
template <>
@@ -194,13 +292,17 @@
constexpr static ArgumentClass kArgumentClass = ArgumentClass::kInteger;
constexpr static unsigned kSize = 8;
constexpr static unsigned kAlignment = 8;
- using GuestType = GuestType<double>;
- using HostType = double;
+ using GuestType = GuestArgument<double, kLp64>;
+ using HostType = GuestType;
};
template <>
- struct GuestArgumentInfo<double, kLp64d> : GuestArgumentInfo<double, kLp64> {
+ struct GuestArgumentInfo<double, kLp64d> {
constexpr static ArgumentClass kArgumentClass = ArgumentClass::kFp;
+ constexpr static unsigned kSize = 8;
+ constexpr static unsigned kAlignment = 8;
+ using GuestType = GuestArgument<double, kLp64d>;
+ using HostType = GuestType;
};
// Structures larger than 16 bytes are passed by reference.