[backend][heavy_optimizer] Support mov for simd regs

Supports moves into and out of SIMD registers in the automatic intrinsics inlining. GPR<->XMM transfers now go through dedicated move helpers that pick (V)movd or (V)movq based on the operand size and AVX availability, instead of a plain PseudoCopy.
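
A condensed sketch of how the new helpers are used (excerpted from the
ProcessArgInput/ProcessBindingResult call sites in inline_intrinsic.h below;
the surrounding setup is illustrative only, not additional API):

  // Inputs: a SimdReg argument is unwrapped to its MachineReg and moved into
  // a freshly allocated vreg of the register class the binding asks for; GPR
  // arguments take the GeneralReg64 path of Mov<>.
  auto xmm_arg = AllocVReg();
  MovFromInput<RegisterClass>(builder_, xmm_arg, std::get<arg_info.from>(input_args_));

  // Outputs: an 'x'-class result is staged in xmm_result_reg_ while the
  // intrinsic is emitted, then copied back into the SimdReg (or GPR) result.
  MovToResult<RegisterClass>(builder_, result_, xmm_result_reg_);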

Bug: 291126189
Test: mm and berberis_host_tests
Change-Id: I5052a41a58f02bb986757d0fa15598659a3288d6
Merged-In: I5052a41a58f02bb986757d0fa15598659a3288d6
diff --git a/backend/include/berberis/backend/common/machine_ir.h b/backend/include/berberis/backend/common/machine_ir.h
index ec19952..2683936 100644
--- a/backend/include/berberis/backend/common/machine_ir.h
+++ b/backend/include/berberis/backend/common/machine_ir.h
@@ -60,6 +60,8 @@
     return reg_ > kInvalidMachineVRegNumber && reg_ < kFirstVRegNumber;
   }
 
+  [[nodiscard]] constexpr bool IsInvalidReg() const { return reg_ == kInvalidMachineVRegNumber; }
+
   [[nodiscard]] constexpr bool IsVReg() const { return reg_ >= kFirstVRegNumber; }
 
   [[nodiscard]] constexpr uint32_t GetVRegIndex() const {
@@ -120,7 +122,7 @@
   int num_regs;
   const MachineReg regs[sizeof(reg_mask) * CHAR_BIT];
 
-  [[nodiscard]] int RegSize() const { return reg_size; }
+  [[nodiscard]] constexpr int RegSize() const { return reg_size; }
 
   [[nodiscard]] bool HasReg(MachineReg r) const { return reg_mask & (uint64_t{1} << r.reg()); }
 
diff --git a/backend/x86_64/lir_instructions.json b/backend/x86_64/lir_instructions.json
index 9a83174..2c0b6e7 100644
--- a/backend/x86_64/lir_instructions.json
+++ b/backend/x86_64/lir_instructions.json
@@ -111,6 +111,8 @@
         "MovqRegXReg",
         "MovqXRegMemInsns",
         "MovqXRegReg",
+        "MovsdXRegXReg",
+        "MovssXRegXReg",
         "MovsxbqRegMemInsns",
         "MovsxbqRegReg",
         "MovsxlqRegMemInsns",
@@ -188,6 +190,8 @@
         "Vfnmadd231ssXRegXRegXReg",
         "Vfnmsub231sdXRegXRegXReg",
         "Vfnmsub231ssXRegXRegXReg",
+        "VmovsdXRegXRegXReg",
+        "VmovssXRegXRegXReg",
         "XorlRegImm",
         "XorlRegReg",
         "XorpdXRegXReg",
diff --git a/heavy_optimizer/riscv64/frontend.h b/heavy_optimizer/riscv64/frontend.h
index 98e9726..691461e 100644
--- a/heavy_optimizer/riscv64/frontend.h
+++ b/heavy_optimizer/riscv64/frontend.h
@@ -332,7 +332,7 @@
     }
 
     if (TryInlineIntrinsicForHeavyOptimizer<kFunction>(
-            &builder_, UnwrapSimdReg(result), GetFlagsRegister(), UnwrapSimdReg(args)...)) {
+            &builder_, result, GetFlagsRegister(), args...)) {
       return result;
     }
 
@@ -362,20 +362,6 @@
                                      const MachineBasicBlock* old_bb,
                                      MachineBasicBlock* new_bb);
 
-  template <typename T>
-  static constexpr auto UnwrapSimdReg(T r) {
-    if constexpr (std::is_same_v<T, SimdReg>) {
-      return r.machine_reg();
-    } else {
-      return r;
-    }
-  }
-
-  template <typename T, typename U>
-  static constexpr auto UnwrapSimdReg(std::tuple<T, U> regs) {
-    return std::make_tuple(UnwrapSimdReg(std::get<0>(regs)), UnwrapSimdReg(std::get<1>(regs)));
-  }
-
   void StartRegion() {
     auto* region_entry_bb = builder_.ir()->NewBasicBlock();
     auto* cont_bb = builder_.ir()->NewBasicBlock();
diff --git a/heavy_optimizer/riscv64/inline_intrinsic.h b/heavy_optimizer/riscv64/inline_intrinsic.h
index 57759d7..c65d6d5 100644
--- a/heavy_optimizer/riscv64/inline_intrinsic.h
+++ b/heavy_optimizer/riscv64/inline_intrinsic.h
@@ -21,18 +21,24 @@
 #include <cstdint>
 #include <tuple>
 #include <type_traits>
+#include <utility>
 #include <variant>
 
 #include "berberis/assembler/x86_64.h"
+#include "berberis/backend/common/machine_ir.h"
 #include "berberis/backend/x86_64/machine_insn_intrinsics.h"
 #include "berberis/backend/x86_64/machine_ir.h"
 #include "berberis/backend/x86_64/machine_ir_builder.h"
 #include "berberis/base/dependent_false.h"
+#include "berberis/intrinsics/common_to_x86/intrinsics_bindings.h"
 #include "berberis/intrinsics/intrinsics.h"
+#include "berberis/intrinsics/intrinsics_args.h"
 #include "berberis/intrinsics/intrinsics_process_bindings.h"
 #include "berberis/intrinsics/macro_assembler.h"
 #include "berberis/runtime_primitives/platform.h"
 
+#include "simd_register.h"
+
 namespace berberis {
 
 template <auto kFunc>
@@ -70,13 +76,77 @@
   }
 };
 
-template <std::size_t size, typename DestType, typename SrcType>
-auto GenPseudoCopy(x86_64::MachineIRBuilder* builder, DestType dest, SrcType src)
-    -> decltype(std::declval<x86_64::MachineIRBuilder*>()->Gen<PseudoCopy>(
-        std::declval<DestType>(),
-        std::declval<SrcType>(),
-        std::declval<std::size_t>())) {
-  return builder->Gen<PseudoCopy>(dest, src, size);
+template <typename DestRegClass, typename SrcRegClass>
+void Mov(x86_64::MachineIRBuilder* builder, MachineReg dest, MachineReg src) {
+  using DestType = typename DestRegClass::Type;
+  using SrcType = typename SrcRegClass::Type;
+  constexpr const auto src_reg_class = SrcRegClass::template kRegClass<x86_64::MachineInsnX86_64>;
+  if constexpr (std::is_integral_v<DestType>) {
+    if constexpr (std::is_integral_v<SrcType>) {
+      builder->Gen<PseudoCopy>(dest, src, src_reg_class.RegSize());
+    } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+      if constexpr (src_reg_class.RegSize() == 4) {
+        if (host_platform::kHasAVX) {
+          builder->Gen<x86_64::VmovdRegXReg>(dest, src);
+        } else {
+          builder->Gen<x86_64::MovdRegXReg>(dest, src);
+        }
+      } else {
+        static_assert(src_reg_class.RegSize() >= 8);
+        if (host_platform::kHasAVX) {
+          builder->Gen<x86_64::VmovqRegXReg>(dest, src);
+        } else {
+          builder->Gen<x86_64::MovqRegXReg>(dest, src);
+        }
+      }
+    } else {
+      static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+    }
+  } else if (DestRegClass::kAsRegister == 'x') {
+    if constexpr (src_reg_class.RegSize() == 4) {
+      if constexpr (std::is_integral_v<SrcType>) {
+        if (host_platform::kHasAVX) {
+          builder->Gen<x86_64::VmovdXRegReg>(dest, src);
+        } else {
+          builder->Gen<x86_64::MovdXRegReg>(dest, src);
+        }
+      } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+        builder->Gen<PseudoCopy>(dest, src, 16);
+      } else {
+        static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+      }
+    } else {
+      static_assert(src_reg_class.RegSize() >= 8);
+      if constexpr (std::is_integral_v<SrcType>) {
+        if (host_platform::kHasAVX) {
+          builder->Gen<x86_64::VmovqXRegReg>(dest, src);
+        } else {
+          builder->Gen<x86_64::MovqXRegReg>(dest, src);
+        }
+      } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+        builder->Gen<PseudoCopy>(dest, src, 16);
+      } else {
+        static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+      }
+    }
+  }
+}
+
+template <typename DestRegClass, typename SrcReg>
+void MovFromInput(x86_64::MachineIRBuilder* builder, MachineReg dest, SrcReg src) {
+  if constexpr (std::is_same_v<SrcReg, SimdReg>) {
+    Mov<DestRegClass, intrinsics::bindings::XmmReg>(builder, dest, src.machine_reg());
+  } else {
+    Mov<DestRegClass, intrinsics::bindings::GeneralReg64>(builder, dest, src);
+  }
+}
+template <typename SrcRegClass, typename DestReg>
+void MovToResult(x86_64::MachineIRBuilder* builder, DestReg dest, MachineReg src) {
+  if constexpr (std::is_same_v<DestReg, SimdReg>) {
+    Mov<intrinsics::bindings::XmmReg, SrcRegClass>(builder, dest.machine_reg(), src);
+  } else {
+    Mov<intrinsics::bindings::GeneralReg64, SrcRegClass>(builder, dest, src);
+  }
 }
 
 template <auto kFunction, typename ResType, typename FlagRegister, typename... ArgType>
@@ -130,6 +200,7 @@
                                                   ArgType... args)
       : builder_(builder),
         result_{result},
+        xmm_result_reg_{},
         flag_register_{flag_register},
         input_args_(std::tuple{args...}),
         success_(
@@ -191,10 +262,117 @@
     using MachineInsn =
         typename AsmCallInfo::template MachineInsn<berberis::x86_64::MachineInsn, MachineOpcode>;
     std::apply(MachineInsn::kGenFunc,
-               std::tuple_cat(
-                   std::tuple<x86_64::MachineIRBuilder&>{*builder_},
-                   AsmCallInfo::template MakeTuplefromBindings<
-                       TryBindingBasedInlineIntrinsicForHeavyOptimizer&>(*this, asm_call_info)));
+               std::tuple_cat(std::tuple<x86_64::MachineIRBuilder&>{*builder_},
+                              UnwrapSimdReg(AsmCallInfo::template MakeTuplefromBindings<
+                                            TryBindingBasedInlineIntrinsicForHeavyOptimizer&>(
+                                  *this, asm_call_info))));
+    ProcessBindingsResults<AsmCallInfo>(type_wrapper<typename AsmCallInfo::Bindings>());
+    return true;
+  }
+
+  template <typename ArgBinding, typename AsmCallInfo>
+  auto /*MakeTuplefromBindingsClient*/ operator()(ArgTraits<ArgBinding>, AsmCallInfo) {
+    static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+    if constexpr (arg_info.arg_type == ArgInfo::IMM_ARG) {
+      auto imm = std::get<arg_info.from>(input_args_);
+      return std::tuple{imm};
+    } else {
+      return ProcessArgInput<ArgBinding, AsmCallInfo>();
+    }
+  }
+
+  template <typename ArgBinding, typename AsmCallInfo>
+  auto ProcessArgInput() {
+    static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+    using RegisterClass = typename ArgTraits<ArgBinding>::RegisterClass;
+    using Usage = typename ArgTraits<ArgBinding>::Usage;
+    static constexpr const auto kNumOut = std::tuple_size_v<typename AsmCallInfo::OutputArguments>;
+
+    if constexpr (arg_info.arg_type == ArgInfo::IN_ARG) {
+      static_assert(std::is_same_v<Usage, intrinsics::bindings::Use>);
+      static_assert(!RegisterClass::kIsImplicitReg);
+      if constexpr (RegisterClass::kAsRegister == 'x' &&
+                    std::is_same_v<std::tuple_element_t<arg_info.from, std::tuple<ArgType...>>,
+                                   MachineReg>) {
+        auto xmm_reg = AllocVReg();
+        MovFromInput<RegisterClass>(builder_, xmm_reg, std::get<arg_info.from>(input_args_));
+        return std::tuple{xmm_reg};
+      } else {
+        return std::tuple{std::get<arg_info.from>(input_args_)};
+      }
+    } else if constexpr (arg_info.arg_type == ArgInfo::IN_OUT_ARG) {
+      static_assert(!std::is_same_v<ResType, std::monostate>);
+      static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
+      static_assert(!RegisterClass::kIsImplicitReg);
+      if constexpr (RegisterClass::kAsRegister == 'x') {
+        if constexpr (kNumOut > 1) {
+          static_assert(kDependentTypeFalse<ArgTraits<ArgBinding>>);
+        } else {
+          CHECK(xmm_result_reg_.IsInvalidReg());
+          xmm_result_reg_ = AllocVReg();
+          MovFromInput<RegisterClass>(
+              builder_, xmm_result_reg_, std::get<arg_info.from>(input_args_));
+          return std::tuple{xmm_result_reg_};
+        }
+      } else if constexpr (kNumOut > 1) {
+        auto res = std::get<arg_info.to>(result_);
+        MovFromInput<RegisterClass>(builder_, res, std::get<arg_info.from>(input_args_));
+        return std::tuple{res};
+      } else {
+        MovFromInput<RegisterClass>(builder_, result_, std::get<arg_info.from>(input_args_));
+        return std::tuple{result_};
+      }
+    } else if constexpr (arg_info.arg_type == ArgInfo::IN_TMP_ARG) {
+      if constexpr (RegisterClass::kIsImplicitReg) {
+        auto implicit_reg = AllocVReg();
+        MovFromInput<RegisterClass>(builder_, implicit_reg, std::get<arg_info.from>(input_args_));
+        return std::tuple{implicit_reg};
+      } else {
+        static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
+        return std::tuple{std::get<arg_info.from>(input_args_)};
+      }
+    } else if constexpr (arg_info.arg_type == ArgInfo::OUT_ARG) {
+      static_assert(!std::is_same_v<ResType, std::monostate>);
+      static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
+                    std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
+      static_assert(!RegisterClass::kIsImplicitReg);
+      if constexpr (RegisterClass::kAsRegister == 'x') {
+        CHECK(xmm_result_reg_.IsInvalidReg());
+        xmm_result_reg_ = AllocVReg();
+        return std::tuple{xmm_result_reg_};
+      } else if constexpr (kNumOut > 1) {
+        return std::tuple{std::get<arg_info.to>(result_)};
+      } else {
+        return std::tuple{result_};
+      }
+    } else if constexpr (arg_info.arg_type == ArgInfo::TMP_ARG) {
+      static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
+                    std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
+      if constexpr (RegisterClass::kAsRegister == 'm') {
+        static_assert(kDependentTypeFalse<RegisterClass>);
+      } else if constexpr (RegisterClass::kIsImplicitReg) {
+        if constexpr (RegisterClass::kAsRegister == 0) {
+          return std::tuple{flag_register_};
+        } else {
+          return std::tuple{};
+        }
+      } else {
+        auto reg = AllocVReg();
+        return std::tuple{reg};
+      }
+    } else {
+      static_assert(berberis::kDependentValueFalse<arg_info.arg_type>);
+    }
+  }
+
+  template <typename T>
+  struct type_wrapper {
+    using type = T;
+  };
+
+  template <typename AsmCallInfo, typename... ArgBinding>
+  void ProcessBindingsResults(type_wrapper<std::tuple<ArgBinding...>>) {
+    (ProcessBindingResult<ArgBinding, AsmCallInfo>(), ...);
     if constexpr (std::tuple_size_v<typename AsmCallInfo::OutputArguments> == 0) {
       // No return value. Do nothing.
     } else if constexpr (std::tuple_size_v<typename AsmCallInfo::OutputArguments> == 1) {
@@ -220,98 +398,46 @@
     } else {
       static_assert(kDependentTypeFalse<typename AsmCallInfo::OutputArguments>);
     }
-    return true;
   }
 
   template <typename ArgBinding, typename AsmCallInfo>
-  auto /*MakeTuplefromBindingsClient*/ operator()(ArgTraits<ArgBinding>, AsmCallInfo) {
-    static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
-    if constexpr (arg_info.arg_type == ArgInfo::IMM_ARG) {
-      auto imm = std::get<arg_info.from>(input_args_);
-      return std::tuple{imm};
-    } else {
-      return ProcessArgInput<ArgBinding, AsmCallInfo>();
-    }
-  }
-
-  template <typename ArgBinding, typename AsmCallInfo>
-  auto ProcessArgInput() {
-    static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+  void ProcessBindingResult() {
     using RegisterClass = typename ArgTraits<ArgBinding>::RegisterClass;
-    using Usage = typename ArgTraits<ArgBinding>::Usage;
-    static constexpr const auto kNumOut = std::tuple_size_v<typename AsmCallInfo::OutputArguments>;
-
-    if constexpr (arg_info.arg_type == ArgInfo::IN_ARG) {
-      static_assert(std::is_same_v<Usage, intrinsics::bindings::Use>);
-      static_assert(!RegisterClass::kIsImplicitReg);
-      return std::tuple{std::get<arg_info.from>(input_args_)};
-    } else if constexpr (arg_info.arg_type == ArgInfo::IN_OUT_ARG) {
-      static_assert(!std::is_same_v<ResType, std::monostate>);
-      static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
-      static_assert(!RegisterClass::kIsImplicitReg);
-      if constexpr (RegisterClass::kAsRegister == 'x') {
-        if constexpr (kNumOut > 1) {
-          auto res = std::get<arg_info.to>(result_);
-          GenPseudoCopy<16>(builder_, res, std::get<arg_info.from>(input_args_));
-          return std::tuple{res};
-        } else {
-          GenPseudoCopy<16>(builder_, result_, std::get<arg_info.from>(input_args_));
-          return std::tuple{result_};
-        }
-      } else if constexpr (kNumOut > 1) {
-        auto res = std::get<arg_info.to>(result_);
-        GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
-            builder_, res, std::get<arg_info.from>(input_args_));
-        return std::tuple{res};
-      } else {
-        GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
-            builder_, result_, std::get<arg_info.from>(input_args_));
-        return std::tuple{result_};
-      }
-    } else if constexpr (arg_info.arg_type == ArgInfo::IN_TMP_ARG) {
-      if constexpr (RegisterClass::kIsImplicitReg) {
-        auto implicit_reg = builder_->ir()->AllocVReg();
-        GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
-            builder_, implicit_reg, std::get<arg_info.from>(input_args_));
-        return std::tuple{implicit_reg};
-      } else {
-        static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
-        static_assert(!RegisterClass::kIsImplicitReg);
-        return std::tuple{std::get<arg_info.from>(input_args_)};
-      }
-    } else if constexpr (arg_info.arg_type == ArgInfo::OUT_ARG) {
-      static_assert(!std::is_same_v<ResType, std::monostate>);
-      static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
-                    std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
-      static_assert(!RegisterClass::kIsImplicitReg);
-      if constexpr (kNumOut > 1) {
-        return std::tuple{std::get<arg_info.to>(result_)};
-      } else {
-        return std::tuple{result_};
-      }
-    } else if constexpr (arg_info.arg_type == ArgInfo::TMP_ARG) {
-      static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
-                    std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
-      if constexpr (RegisterClass::kAsRegister == 'm') {
-        static_assert(kDependentTypeFalse<RegisterClass>);
-      } else if constexpr (RegisterClass::kIsImplicitReg) {
-        if constexpr (RegisterClass::kAsRegister == 0) {
-          return std::tuple{flag_register_};
-        } else {
-          return std::tuple{};
-        }
-      } else {
-        auto reg = builder_->ir()->AllocVReg();
-        return std::tuple{reg};
-      }
-    } else {
-      static_assert(berberis::kDependentValueFalse<arg_info.arg_type>);
+    static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+    if constexpr ((arg_info.arg_type == ArgInfo::IN_OUT_ARG ||
+                   arg_info.arg_type == ArgInfo::OUT_ARG) &&
+                  RegisterClass::kAsRegister == 'x') {
+      CHECK(!xmm_result_reg_.IsInvalidReg());
+      MovToResult<RegisterClass>(builder_, result_, xmm_result_reg_);
     }
   }
 
+  MachineReg AllocVReg() { return builder_->ir()->AllocVReg(); }
+
+  template <typename T>
+  static constexpr auto UnwrapSimdReg(T r) {
+    if constexpr (std::is_same_v<T, SimdReg>) {
+      return r.machine_reg();
+    } else {
+      return r;
+    }
+  }
+
+  template <typename... T>
+  static constexpr auto UnwrapSimdReg(std::tuple<T...> regs) {
+    constexpr const auto num_args = std::tuple_size<std::tuple<T...>>::value;
+    return UnwrapSimdReg(std::make_index_sequence<num_args>(), regs);
+  }
+
+  template <typename... T, auto... I>
+  static constexpr auto UnwrapSimdReg(std::index_sequence<I...>, std::tuple<T...> regs) {
+    return std::make_tuple(UnwrapSimdReg(std::get<I>(regs))...);
+  }
+
  private:
   x86_64::MachineIRBuilder* builder_;
   ResType result_;
+  MachineReg xmm_result_reg_;
   FlagRegister flag_register_;
   std::tuple<ArgType...> input_args_;
   bool success_;