guest_abi: Fix extension for float arguments

Uses a custom GuestArgument for RISC-V to ensure proper NaN boxing and
unboxing of floating-point arguments.

Test: berberis_run_host_tests
Bug: 300490784
Change-Id: Ie7d5bbbe5b99da08dad0215ea5c3f6a68f4309f8
diff --git a/guest_abi/riscv64/guest_abi_test.cc b/guest_abi/riscv64/guest_abi_test.cc
index e134d72..22686ef 100644
--- a/guest_abi/riscv64/guest_abi_test.cc
+++ b/guest_abi/riscv64/guest_abi_test.cc
@@ -95,6 +95,50 @@
   EXPECT_EQ(value, 0xffff'ffff'f123'4567U);
 }
 
+TEST(GuestAbi_riscv64_lp64, GuestArgumentFloat32) {
+  uint64_t value = 0;
+  auto& param = *reinterpret_cast<GuestAbi::GuestArgument<float, GuestAbi::kLp64>*>(&value);
+
+  value = 0x0000'0000'3f00'0000;
+  EXPECT_FLOAT_EQ(param, 0.5f);
+
+  param = 7.125f;
+  EXPECT_EQ(value, 0x0000'0000'40e4'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64, GuestArgumentFloat64) {
+  uint64_t value = 0;
+  auto& param = *reinterpret_cast<GuestAbi::GuestArgument<double, GuestAbi::kLp64>*>(&value);
+
+  value = 0x3fd5'c28f'5c28'f5c3;
+  EXPECT_DOUBLE_EQ(param, 0.34);
+
+  param = 0.125f;
+  EXPECT_EQ(value, 0x3fc0'0000'0000'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64d, GuestArgumentFloat32) {
+  uint64_t value = 0;
+  auto& param = *reinterpret_cast<GuestAbi::GuestArgument<float, GuestAbi::kLp64d>*>(&value);
+
+  value = 0xffff'ffff'3f00'0000;
+  EXPECT_FLOAT_EQ(param, 0.5f);
+
+  param = 7.125f;
+  EXPECT_EQ(value, 0xffff'ffff'40e4'0000U);
+}
+
+TEST(GuestAbi_riscv64_lp64d, GuestArgumentFloat64) {
+  uint64_t value = 0;
+  auto& param = *reinterpret_cast<GuestAbi::GuestArgument<double, GuestAbi::kLp64d>*>(&value);
+
+  value = 0x3fd5'c28f'5c28'f5c3;
+  EXPECT_DOUBLE_EQ(param, 0.34);
+
+  param = 0.125f;
+  EXPECT_EQ(value, 0x3fc0'0000'0000'0000U);
+}
+
 }  // namespace
 
 }  // namespace berberis
diff --git a/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h b/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
index 3438d6e..8f46ada 100644
--- a/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
+++ b/guest_abi/riscv64/include/berberis/guest_abi/guest_abi_arch.h
@@ -20,6 +20,7 @@
 #include <cstdint>
 #include <type_traits>
 
+#include "berberis/base/bit_util.h"
 #include "berberis/calling_conventions/calling_conventions_riscv64.h"  // IWYU pragma: export.
 #include "berberis/guest_abi/guest_type.h"
 
@@ -124,6 +125,99 @@
     GuestArgument<UnderlyingType, kCallingConventionsVariant> value_;
   };
 
+  template <typename FloatingPointType>
+  class alignas(sizeof(uint64_t))
+      GuestArgument<FloatingPointType,
+                    kLp64,
+                    std::enable_if_t<std::is_floating_point_v<FloatingPointType>>> {
+   public:
+    using Type = FloatingPointType;
+    GuestArgument(const FloatingPointType& value) : value_(Box(value)) {}
+    GuestArgument(FloatingPointType&& value) : value_(Box(value)) {}
+    GuestArgument() = default;
+    GuestArgument(const GuestArgument&) = default;
+    GuestArgument(GuestArgument&&) = default;
+    GuestArgument& operator=(const GuestArgument&) = default;
+    GuestArgument& operator=(GuestArgument&&) = default;
+    ~GuestArgument() = default;
+    operator FloatingPointType() const { return Unbox(value_); }
+
+   private:
+    // Floating-point arguments in integer registers do not require NaN boxing.  They are stored in
+    // the lower bits of the 64-bit integer register with the high bits undefined.  Bit casting and
+    // unsigned narrowing/widening conversions are sufficient.
+
+    static constexpr uint64_t Box(FloatingPointType value) {
+      if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+        return bit_cast<uint64_t>(value);
+      } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+        return static_cast<uint64_t>(bit_cast<uint32_t>(value));
+      } else {
+        FATAL("Unsupported floating-point argument width");
+      }
+    }
+
+    static constexpr FloatingPointType Unbox(uint64_t value) {
+      if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+        return bit_cast<FloatingPointType>(value);
+      } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+        return bit_cast<FloatingPointType>(static_cast<uint32_t>(value));
+      } else {
+        FATAL("Unsupported floating-point argument width");
+      }
+    }
+
+    uint64_t value_ = 0;
+  };
+
+  template <typename FloatingPointType>
+  class alignas(sizeof(uint64_t))
+      GuestArgument<FloatingPointType,
+                    kLp64d,
+                    std::enable_if_t<std::is_floating_point_v<FloatingPointType>>> {
+   public:
+    using Type = FloatingPointType;
+    GuestArgument(const FloatingPointType& value) : value_(Box(value)) {}
+    GuestArgument(FloatingPointType&& value) : value_(Box(value)) {}
+    GuestArgument() = default;
+    GuestArgument(const GuestArgument&) = default;
+    GuestArgument(GuestArgument&&) = default;
+    GuestArgument& operator=(const GuestArgument&) = default;
+    GuestArgument& operator=(GuestArgument&&) = default;
+    ~GuestArgument() = default;
+    operator FloatingPointType() const { return Unbox(value_); }
+
+   private:
+    // Floating-point arguments passed in floating-point registers require NaN boxing when they are
+    // narrower than 64 bits.  The argument is stored in the lower bits of the 64-bit floating-point
+    // register with the high bits set to 1.
+
+    static constexpr uint64_t Box(FloatingPointType value) {
+      if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+        return bit_cast<uint64_t>(value);
+      } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+        return bit_cast<uint32_t>(value) | kNanBoxFloat32;
+      } else {
+        FATAL("Unsupported floating-point argument width");
+      }
+    }
+
+    static constexpr FloatingPointType Unbox(uint64_t value) {
+      if constexpr (sizeof(FloatingPointType) == sizeof(uint64_t)) {
+        return bit_cast<Type>(value);
+      } else if constexpr (sizeof(FloatingPointType) == sizeof(uint32_t)) {
+        // Integer narrowing removes the NaN box.
+        return bit_cast<Type>(static_cast<uint32_t>(value));
+      } else {
+        FATAL("Unsupported floating-point argument width");
+      }
+    }
+
+    static constexpr uint64_t kNanBoxFloat32 = 0xffff'ffff'0000'0000ULL;
+
+    uint64_t value_ = 0;
+  };
+
  protected:
   enum class ArgumentClass { kInteger, kFp, kLargeStruct };
 
@@ -180,13 +274,17 @@
     constexpr static ArgumentClass kArgumentClass = ArgumentClass::kInteger;
     constexpr static unsigned kSize = 4;
     constexpr static unsigned kAlignment = 4;
-    using GuestType = GuestType<float>;
-    using HostType = float;
+    using GuestType = GuestArgument<float, kLp64>;
+    using HostType = GuestType;
   };
 
   template <>
-  struct GuestArgumentInfo<float, kLp64d> : GuestArgumentInfo<float, kLp64> {
+  struct GuestArgumentInfo<float, kLp64d> {
     constexpr static ArgumentClass kArgumentClass = ArgumentClass::kFp;
+    constexpr static unsigned kSize = 4;
+    constexpr static unsigned kAlignment = 4;
+    using GuestType = GuestArgument<float, kLp64d>;
+    using HostType = GuestType;
   };
 
   template <>
@@ -194,13 +292,17 @@
     constexpr static ArgumentClass kArgumentClass = ArgumentClass::kInteger;
     constexpr static unsigned kSize = 8;
     constexpr static unsigned kAlignment = 8;
-    using GuestType = GuestType<double>;
-    using HostType = double;
+    using GuestType = GuestArgument<double, kLp64>;
+    using HostType = GuestType;
   };
 
   template <>
-  struct GuestArgumentInfo<double, kLp64d> : GuestArgumentInfo<double, kLp64> {
+  struct GuestArgumentInfo<double, kLp64d> {
     constexpr static ArgumentClass kArgumentClass = ArgumentClass::kFp;
+    constexpr static unsigned kSize = 8;
+    constexpr static unsigned kAlignment = 8;
+    using GuestType = GuestArgument<double, kLp64d>;
+    using HostType = GuestType;
   };
 
   // Structures larger than 16 bytes are passed by reference.