[backend][heavy_optimizer] Support mov for simd regs
Supports moves into and from simd registers in the automatic intrinsics.
Bug: 291126189
Test: mm and berberis_host_tests
Change-Id: I5052a41a58f02bb986757d0fa15598659a3288d6
Merged-In: I5052a41a58f02bb986757d0fa15598659a3288d6
diff --git a/backend/include/berberis/backend/common/machine_ir.h b/backend/include/berberis/backend/common/machine_ir.h
index ec19952..2683936 100644
--- a/backend/include/berberis/backend/common/machine_ir.h
+++ b/backend/include/berberis/backend/common/machine_ir.h
@@ -60,6 +60,8 @@
return reg_ > kInvalidMachineVRegNumber && reg_ < kFirstVRegNumber;
}
+ [[nodiscard]] constexpr bool IsInvalidReg() const { return reg_ == kInvalidMachineVRegNumber; }
+
[[nodiscard]] constexpr bool IsVReg() const { return reg_ >= kFirstVRegNumber; }
[[nodiscard]] constexpr uint32_t GetVRegIndex() const {
@@ -120,7 +122,7 @@
int num_regs;
const MachineReg regs[sizeof(reg_mask) * CHAR_BIT];
- [[nodiscard]] int RegSize() const { return reg_size; }
+ [[nodiscard]] constexpr int RegSize() const { return reg_size; }
[[nodiscard]] bool HasReg(MachineReg r) const { return reg_mask & (uint64_t{1} << r.reg()); }
diff --git a/backend/x86_64/lir_instructions.json b/backend/x86_64/lir_instructions.json
index 9a83174..2c0b6e7 100644
--- a/backend/x86_64/lir_instructions.json
+++ b/backend/x86_64/lir_instructions.json
@@ -111,6 +111,8 @@
"MovqRegXReg",
"MovqXRegMemInsns",
"MovqXRegReg",
+ "MovsdXRegXReg",
+ "MovssXRegXReg",
"MovsxbqRegMemInsns",
"MovsxbqRegReg",
"MovsxlqRegMemInsns",
@@ -188,6 +190,8 @@
"Vfnmadd231ssXRegXRegXReg",
"Vfnmsub231sdXRegXRegXReg",
"Vfnmsub231ssXRegXRegXReg",
+ "VmovsdXRegXRegXReg",
+ "VmovssXRegXRegXReg",
"XorlRegImm",
"XorlRegReg",
"XorpdXRegXReg",
diff --git a/heavy_optimizer/riscv64/frontend.h b/heavy_optimizer/riscv64/frontend.h
index 98e9726..691461e 100644
--- a/heavy_optimizer/riscv64/frontend.h
+++ b/heavy_optimizer/riscv64/frontend.h
@@ -332,7 +332,7 @@
}
if (TryInlineIntrinsicForHeavyOptimizer<kFunction>(
- &builder_, UnwrapSimdReg(result), GetFlagsRegister(), UnwrapSimdReg(args)...)) {
+ &builder_, result, GetFlagsRegister(), args...)) {
return result;
}
@@ -362,20 +362,6 @@
const MachineBasicBlock* old_bb,
MachineBasicBlock* new_bb);
- template <typename T>
- static constexpr auto UnwrapSimdReg(T r) {
- if constexpr (std::is_same_v<T, SimdReg>) {
- return r.machine_reg();
- } else {
- return r;
- }
- }
-
- template <typename T, typename U>
- static constexpr auto UnwrapSimdReg(std::tuple<T, U> regs) {
- return std::make_tuple(UnwrapSimdReg(std::get<0>(regs)), UnwrapSimdReg(std::get<1>(regs)));
- }
-
void StartRegion() {
auto* region_entry_bb = builder_.ir()->NewBasicBlock();
auto* cont_bb = builder_.ir()->NewBasicBlock();
diff --git a/heavy_optimizer/riscv64/inline_intrinsic.h b/heavy_optimizer/riscv64/inline_intrinsic.h
index 57759d7..c65d6d5 100644
--- a/heavy_optimizer/riscv64/inline_intrinsic.h
+++ b/heavy_optimizer/riscv64/inline_intrinsic.h
@@ -21,18 +21,24 @@
#include <cstdint>
#include <tuple>
#include <type_traits>
+#include <utility>
#include <variant>
#include "berberis/assembler/x86_64.h"
+#include "berberis/backend/common/machine_ir.h"
#include "berberis/backend/x86_64/machine_insn_intrinsics.h"
#include "berberis/backend/x86_64/machine_ir.h"
#include "berberis/backend/x86_64/machine_ir_builder.h"
#include "berberis/base/dependent_false.h"
+#include "berberis/intrinsics/common_to_x86/intrinsics_bindings.h"
#include "berberis/intrinsics/intrinsics.h"
+#include "berberis/intrinsics/intrinsics_args.h"
#include "berberis/intrinsics/intrinsics_process_bindings.h"
#include "berberis/intrinsics/macro_assembler.h"
#include "berberis/runtime_primitives/platform.h"
+#include "simd_register.h"
+
namespace berberis {
template <auto kFunc>
@@ -70,13 +76,77 @@
}
};
-template <std::size_t size, typename DestType, typename SrcType>
-auto GenPseudoCopy(x86_64::MachineIRBuilder* builder, DestType dest, SrcType src)
- -> decltype(std::declval<x86_64::MachineIRBuilder*>()->Gen<PseudoCopy>(
- std::declval<DestType>(),
- std::declval<SrcType>(),
- std::declval<std::size_t>())) {
- return builder->Gen<PseudoCopy>(dest, src, size);
+template <typename DestRegClass, typename SrcRegClass>
+void Mov(x86_64::MachineIRBuilder* builder, MachineReg dest, MachineReg src) {
+ using DestType = typename DestRegClass::Type;
+ using SrcType = typename SrcRegClass::Type;
+ constexpr const auto src_reg_class = SrcRegClass::template kRegClass<x86_64::MachineInsnX86_64>;
+ if constexpr (std::is_integral_v<DestType>) {
+ if constexpr (std::is_integral_v<SrcType>) {
+ builder->Gen<PseudoCopy>(dest, src, src_reg_class.RegSize());
+ } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+ if constexpr (src_reg_class.RegSize() == 4) {
+ if (host_platform::kHasAVX) {
+ builder->Gen<x86_64::VmovdRegXReg>(dest, src);
+ } else {
+ builder->Gen<x86_64::MovdRegXReg>(dest, src);
+ }
+ } else {
+ static_assert(src_reg_class.RegSize() >= 8);
+ if (host_platform::kHasAVX) {
+ builder->Gen<x86_64::VmovqRegXReg>(dest, src);
+ } else {
+ builder->Gen<x86_64::MovqRegXReg>(dest, src);
+ }
+ }
+ } else {
+ static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+ }
+ } else if (DestRegClass::kAsRegister == 'x') {
+ if constexpr (src_reg_class.RegSize() == 4) {
+ if constexpr (std::is_integral_v<SrcType>) {
+ if (host_platform::kHasAVX) {
+ builder->Gen<x86_64::VmovdXRegReg>(dest, src);
+ } else {
+ builder->Gen<x86_64::MovdXRegReg>(dest, src);
+ }
+ } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+ builder->Gen<PseudoCopy>(dest, src, 16);
+ } else {
+ static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+ }
+ } else {
+ static_assert(src_reg_class.RegSize() >= 8);
+ if constexpr (std::is_integral_v<SrcType>) {
+ if (host_platform::kHasAVX) {
+ builder->Gen<x86_64::VmovqXRegReg>(dest, src);
+ } else {
+ builder->Gen<x86_64::MovqXRegReg>(dest, src);
+ }
+ } else if constexpr (SrcRegClass::kAsRegister == 'x') {
+ builder->Gen<PseudoCopy>(dest, src, 16);
+ } else {
+ static_assert(kDependentTypeFalse<std::tuple<DestRegClass, SrcRegClass>>);
+ }
+ }
+ }
+}
+
+template <typename DestRegClass, typename SrcReg>
+void MovFromInput(x86_64::MachineIRBuilder* builder, MachineReg dest, SrcReg src) {
+ if constexpr (std::is_same_v<SrcReg, SimdReg>) {
+ Mov<DestRegClass, intrinsics::bindings::XmmReg>(builder, dest, src.machine_reg());
+ } else {
+ Mov<DestRegClass, intrinsics::bindings::GeneralReg64>(builder, dest, src);
+ }
+}
+template <typename SrcRegClass, typename DestReg>
+void MovToResult(x86_64::MachineIRBuilder* builder, DestReg dest, MachineReg src) {
+ if constexpr (std::is_same_v<DestReg, SimdReg>) {
+ Mov<intrinsics::bindings::XmmReg, SrcRegClass>(builder, dest.machine_reg(), src);
+ } else {
+ Mov<intrinsics::bindings::GeneralReg64, SrcRegClass>(builder, dest, src);
+ }
}
template <auto kFunction, typename ResType, typename FlagRegister, typename... ArgType>
@@ -130,6 +200,7 @@
ArgType... args)
: builder_(builder),
result_{result},
+ xmm_result_reg_{},
flag_register_{flag_register},
input_args_(std::tuple{args...}),
success_(
@@ -191,10 +262,117 @@
using MachineInsn =
typename AsmCallInfo::template MachineInsn<berberis::x86_64::MachineInsn, MachineOpcode>;
std::apply(MachineInsn::kGenFunc,
- std::tuple_cat(
- std::tuple<x86_64::MachineIRBuilder&>{*builder_},
- AsmCallInfo::template MakeTuplefromBindings<
- TryBindingBasedInlineIntrinsicForHeavyOptimizer&>(*this, asm_call_info)));
+ std::tuple_cat(std::tuple<x86_64::MachineIRBuilder&>{*builder_},
+ UnwrapSimdReg(AsmCallInfo::template MakeTuplefromBindings<
+ TryBindingBasedInlineIntrinsicForHeavyOptimizer&>(
+ *this, asm_call_info))));
+ ProcessBindingsResults<AsmCallInfo>(type_wrapper<typename AsmCallInfo::Bindings>());
+ return true;
+ }
+
+ template <typename ArgBinding, typename AsmCallInfo>
+ auto /*MakeTuplefromBindingsClient*/ operator()(ArgTraits<ArgBinding>, AsmCallInfo) {
+ static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+ if constexpr (arg_info.arg_type == ArgInfo::IMM_ARG) {
+ auto imm = std::get<arg_info.from>(input_args_);
+ return std::tuple{imm};
+ } else {
+ return ProcessArgInput<ArgBinding, AsmCallInfo>();
+ }
+ }
+
+ template <typename ArgBinding, typename AsmCallInfo>
+ auto ProcessArgInput() {
+ static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+ using RegisterClass = typename ArgTraits<ArgBinding>::RegisterClass;
+ using Usage = typename ArgTraits<ArgBinding>::Usage;
+ static constexpr const auto kNumOut = std::tuple_size_v<typename AsmCallInfo::OutputArguments>;
+
+ if constexpr (arg_info.arg_type == ArgInfo::IN_ARG) {
+ static_assert(std::is_same_v<Usage, intrinsics::bindings::Use>);
+ static_assert(!RegisterClass::kIsImplicitReg);
+ if constexpr (RegisterClass::kAsRegister == 'x' &&
+ std::is_same_v<std::tuple_element_t<arg_info.from, std::tuple<ArgType...>>,
+ MachineReg>) {
+ auto xmm_reg = AllocVReg();
+ MovFromInput<RegisterClass>(builder_, xmm_reg, std::get<arg_info.from>(input_args_));
+ return std::tuple{xmm_reg};
+ } else {
+ return std::tuple{std::get<arg_info.from>(input_args_)};
+ }
+ } else if constexpr (arg_info.arg_type == ArgInfo::IN_OUT_ARG) {
+ static_assert(!std::is_same_v<ResType, std::monostate>);
+ static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
+ static_assert(!RegisterClass::kIsImplicitReg);
+ if constexpr (RegisterClass::kAsRegister == 'x') {
+ if constexpr (kNumOut > 1) {
+ static_assert(kDependentTypeFalse<ArgTraits<ArgBinding>>);
+ } else {
+ CHECK(xmm_result_reg_.IsInvalidReg());
+ xmm_result_reg_ = AllocVReg();
+ MovFromInput<RegisterClass>(
+ builder_, xmm_result_reg_, std::get<arg_info.from>(input_args_));
+ return std::tuple{xmm_result_reg_};
+ }
+ } else if constexpr (kNumOut > 1) {
+ auto res = std::get<arg_info.to>(result_);
+ MovFromInput<RegisterClass>(builder_, res, std::get<arg_info.from>(input_args_));
+ return std::tuple{res};
+ } else {
+ MovFromInput<RegisterClass>(builder_, result_, std::get<arg_info.from>(input_args_));
+ return std::tuple{result_};
+ }
+ } else if constexpr (arg_info.arg_type == ArgInfo::IN_TMP_ARG) {
+ if constexpr (RegisterClass::kIsImplicitReg) {
+ auto implicit_reg = AllocVReg();
+ MovFromInput<RegisterClass>(builder_, implicit_reg, std::get<arg_info.from>(input_args_));
+ return std::tuple{implicit_reg};
+ } else {
+ static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
+ return std::tuple{std::get<arg_info.from>(input_args_)};
+ }
+ } else if constexpr (arg_info.arg_type == ArgInfo::OUT_ARG) {
+ static_assert(!std::is_same_v<ResType, std::monostate>);
+ static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
+ std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
+ static_assert(!RegisterClass::kIsImplicitReg);
+ if constexpr (RegisterClass::kAsRegister == 'x') {
+ CHECK(xmm_result_reg_.IsInvalidReg());
+ xmm_result_reg_ = AllocVReg();
+ return std::tuple{xmm_result_reg_};
+ } else if constexpr (kNumOut > 1) {
+ return std::tuple{std::get<arg_info.to>(result_)};
+ } else {
+ return std::tuple{result_};
+ }
+ } else if constexpr (arg_info.arg_type == ArgInfo::TMP_ARG) {
+ static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
+ std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
+ if constexpr (RegisterClass::kAsRegister == 'm') {
+ static_assert(kDependentTypeFalse<RegisterClass>);
+ } else if constexpr (RegisterClass::kIsImplicitReg) {
+ if constexpr (RegisterClass::kAsRegister == 0) {
+ return std::tuple{flag_register_};
+ } else {
+ return std::tuple{};
+ }
+ } else {
+ auto reg = AllocVReg();
+ return std::tuple{reg};
+ }
+ } else {
+ static_assert(berberis::kDependentValueFalse<arg_info.arg_type>);
+ }
+ }
+
+ template <typename T>
+ struct type_wrapper {
+ using type = T;
+ };
+
+ template <typename AsmCallInfo, typename... ArgBinding>
+ void ProcessBindingsResults(type_wrapper<std::tuple<ArgBinding...>>) {
+ (ProcessBindingResult<ArgBinding, AsmCallInfo>(), ...);
if constexpr (std::tuple_size_v<typename AsmCallInfo::OutputArguments> == 0) {
// No return value. Do nothing.
} else if constexpr (std::tuple_size_v<typename AsmCallInfo::OutputArguments> == 1) {
@@ -220,98 +398,46 @@
} else {
static_assert(kDependentTypeFalse<typename AsmCallInfo::OutputArguments>);
}
- return true;
}
template <typename ArgBinding, typename AsmCallInfo>
- auto /*MakeTuplefromBindingsClient*/ operator()(ArgTraits<ArgBinding>, AsmCallInfo) {
- static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
- if constexpr (arg_info.arg_type == ArgInfo::IMM_ARG) {
- auto imm = std::get<arg_info.from>(input_args_);
- return std::tuple{imm};
- } else {
- return ProcessArgInput<ArgBinding, AsmCallInfo>();
- }
- }
-
- template <typename ArgBinding, typename AsmCallInfo>
- auto ProcessArgInput() {
- static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+ void ProcessBindingResult() {
using RegisterClass = typename ArgTraits<ArgBinding>::RegisterClass;
- using Usage = typename ArgTraits<ArgBinding>::Usage;
- static constexpr const auto kNumOut = std::tuple_size_v<typename AsmCallInfo::OutputArguments>;
-
- if constexpr (arg_info.arg_type == ArgInfo::IN_ARG) {
- static_assert(std::is_same_v<Usage, intrinsics::bindings::Use>);
- static_assert(!RegisterClass::kIsImplicitReg);
- return std::tuple{std::get<arg_info.from>(input_args_)};
- } else if constexpr (arg_info.arg_type == ArgInfo::IN_OUT_ARG) {
- static_assert(!std::is_same_v<ResType, std::monostate>);
- static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
- static_assert(!RegisterClass::kIsImplicitReg);
- if constexpr (RegisterClass::kAsRegister == 'x') {
- if constexpr (kNumOut > 1) {
- auto res = std::get<arg_info.to>(result_);
- GenPseudoCopy<16>(builder_, res, std::get<arg_info.from>(input_args_));
- return std::tuple{res};
- } else {
- GenPseudoCopy<16>(builder_, result_, std::get<arg_info.from>(input_args_));
- return std::tuple{result_};
- }
- } else if constexpr (kNumOut > 1) {
- auto res = std::get<arg_info.to>(result_);
- GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
- builder_, res, std::get<arg_info.from>(input_args_));
- return std::tuple{res};
- } else {
- GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
- builder_, result_, std::get<arg_info.from>(input_args_));
- return std::tuple{result_};
- }
- } else if constexpr (arg_info.arg_type == ArgInfo::IN_TMP_ARG) {
- if constexpr (RegisterClass::kIsImplicitReg) {
- auto implicit_reg = builder_->ir()->AllocVReg();
- GenPseudoCopy<sizeof(typename RegisterClass::Type)>(
- builder_, implicit_reg, std::get<arg_info.from>(input_args_));
- return std::tuple{implicit_reg};
- } else {
- static_assert(std::is_same_v<Usage, intrinsics::bindings::UseDef>);
- static_assert(!RegisterClass::kIsImplicitReg);
- return std::tuple{std::get<arg_info.from>(input_args_)};
- }
- } else if constexpr (arg_info.arg_type == ArgInfo::OUT_ARG) {
- static_assert(!std::is_same_v<ResType, std::monostate>);
- static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
- std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
- static_assert(!RegisterClass::kIsImplicitReg);
- if constexpr (kNumOut > 1) {
- return std::tuple{std::get<arg_info.to>(result_)};
- } else {
- return std::tuple{result_};
- }
- } else if constexpr (arg_info.arg_type == ArgInfo::TMP_ARG) {
- static_assert(std::is_same_v<Usage, intrinsics::bindings::Def> ||
- std::is_same_v<Usage, intrinsics::bindings::DefEarlyClobber>);
- if constexpr (RegisterClass::kAsRegister == 'm') {
- static_assert(kDependentTypeFalse<RegisterClass>);
- } else if constexpr (RegisterClass::kIsImplicitReg) {
- if constexpr (RegisterClass::kAsRegister == 0) {
- return std::tuple{flag_register_};
- } else {
- return std::tuple{};
- }
- } else {
- auto reg = builder_->ir()->AllocVReg();
- return std::tuple{reg};
- }
- } else {
- static_assert(berberis::kDependentValueFalse<arg_info.arg_type>);
+ static constexpr const auto& arg_info = ArgTraits<ArgBinding>::arg_info;
+ if constexpr ((arg_info.arg_type == ArgInfo::IN_OUT_ARG ||
+ arg_info.arg_type == ArgInfo::OUT_ARG) &&
+ RegisterClass::kAsRegister == 'x') {
+ CHECK(!xmm_result_reg_.IsInvalidReg());
+ MovToResult<RegisterClass>(builder_, result_, xmm_result_reg_);
}
}
+ MachineReg AllocVReg() { return builder_->ir()->AllocVReg(); }
+
+ template <typename T>
+ static constexpr auto UnwrapSimdReg(T r) {
+ if constexpr (std::is_same_v<T, SimdReg>) {
+ return r.machine_reg();
+ } else {
+ return r;
+ }
+ }
+
+ template <typename... T>
+ static constexpr auto UnwrapSimdReg(std::tuple<T...> regs) {
+ constexpr const auto num_args = std::tuple_size<std::tuple<T...>>::value;
+ return UnwrapSimdReg(std::make_index_sequence<num_args>(), regs);
+ }
+
+ template <typename... T, auto... I>
+ static constexpr auto UnwrapSimdReg(std::index_sequence<I...>, std::tuple<T...> regs) {
+ return std::make_tuple(UnwrapSimdReg(std::get<I>(regs))...);
+ }
+
private:
x86_64::MachineIRBuilder* builder_;
ResType result_;
+ MachineReg xmm_result_reg_;
FlagRegister flag_register_;
std::tuple<ArgType...> input_args_;
bool success_;