[backend][heavy_optimizer] Support most CSRs

Only setting FCsr is not yet supported because we don't yet have a way
to pass the mem arg to the macro assembler.

Bug: 291126436
Test: mm and berberis_host_tests
(cherry picked from https://googleplex-android-review.googlesource.com/q/commit:8d62bb20809b0a27e44bd2d9fcae066767b411f5)
Merged-In: I12761c9e5cc6ed5f4666395fa37a44a7f1b1c754
Change-Id: I12761c9e5cc6ed5f4666395fa37a44a7f1b1c754
diff --git a/backend/x86_64/lir_instructions.json b/backend/x86_64/lir_instructions.json
index 2e044a5..a79f26e 100644
--- a/backend/x86_64/lir_instructions.json
+++ b/backend/x86_64/lir_instructions.json
@@ -23,6 +23,8 @@
         "AddqRegImm",
         "AddqRegReg",
         "AddqRegMemInsns",
+        "AndbRegImm",
+        "AndbMemImmInsns",
         "AndlRegImm",
         "AndlRegReg",
         "AndnqRegRegReg",
@@ -66,10 +68,13 @@
         "SarlRegReg",
         "SarqRegImm",
         "SarqRegReg",
+        "ShlbRegImm",
+        "ShldlRegRegImm",
         "ShllRegImm",
         "ShllRegReg",
         "ShlqRegImm",
         "ShlqRegReg",
+        "ShrbRegImm",
         "ShrlRegImm",
         "ShrlRegReg",
         "ShrqRegImm",
@@ -86,6 +91,7 @@
         "LockCmpXchgqRegMemRegInsns",
         "LockCmpXchg16bRegRegRegRegMemInsns",
         "Mfence",
+        "MovbMemImmInsns",
         "MovbMemRegInsns",
         "MovdMemXRegInsns",
         "MovdRegXReg",
@@ -137,6 +143,9 @@
         "MulsdXRegXReg",
         "MulssXRegXReg",
         "NotqReg",
+        "OrbMemImmInsns",
+        "OrbMemRegInsns",
+        "OrbRegReg",
         "OrlRegImm",
         "OrlRegReg",
         "OrqRegImm",
diff --git a/heavy_optimizer/riscv64/frontend.cc b/heavy_optimizer/riscv64/frontend.cc
index 2ab51a8..6098327 100644
--- a/heavy_optimizer/riscv64/frontend.cc
+++ b/heavy_optimizer/riscv64/frontend.cc
@@ -767,6 +767,52 @@
   return res;
 }
 
+Register HeavyOptimizerFrontend::UpdateCsr(Decoder::CsrOpcode opcode, Register arg, Register csr) {
+  Register res = AllocTempReg();
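+  // Compute the new CSR value in a temp: csrrs sets the bits of arg in csr
+  // (csr | arg) and csrrc clears them (csr & ~arg); BMI's ANDN computes the
+  // latter in a single instruction.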
+  switch (opcode) {
+    case Decoder::CsrOpcode::kCsrrs:
+      Gen<PseudoCopy>(res, arg, 8);
+      Gen<x86_64::OrqRegReg>(res, csr, GetFlagsRegister());
+      break;
+    case Decoder::CsrOpcode::kCsrrc:
+      if (host_platform::kHasBMI) {
+        Gen<x86_64::AndnqRegRegReg>(res, arg, csr, GetFlagsRegister());
+      } else {
+        Gen<PseudoCopy>(res, arg, 8);
+        Gen<x86_64::NotqReg>(res);
+        Gen<x86_64::AndqRegReg>(res, csr, GetFlagsRegister());
+      }
+      break;
+    default:
+      Unimplemented();
+      return {};
+  }
+  return res;
+}
+
+Register HeavyOptimizerFrontend::UpdateCsr(Decoder::CsrImmOpcode opcode,
+                                           uint8_t imm,
+                                           Register csr) {
+  Register res = AllocTempReg();
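+  // The immediate is the zero-extended 5-bit uimm field: csrrwi replaces the
+  // CSR value, csrrsi ORs the immediate in, and csrrci ANDs with its complement.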
+  switch (opcode) {
+    case Decoder::CsrImmOpcode::kCsrrwi:
+      Gen<x86_64::MovlRegImm>(res, imm);
+      break;
+    case Decoder::CsrImmOpcode::kCsrrsi:
+      Gen<x86_64::MovlRegImm>(res, imm);
+      Gen<x86_64::OrqRegReg>(res, csr, GetFlagsRegister());
+      break;
+    case Decoder::CsrImmOpcode::kCsrrci:
+      Gen<x86_64::MovqRegImm>(res, static_cast<int8_t>(~imm));
+      Gen<x86_64::AndqRegReg>(res, csr, GetFlagsRegister());
+      break;
+    default:
+      Unimplemented();
+      return {};
+  }
+  return res;
+}
+
 void HeavyOptimizerFrontend::StoreWithoutRecovery(Decoder::StoreOperandType operand_type,
                                                   Register base,
                                                   int32_t disp,
diff --git a/heavy_optimizer/riscv64/frontend.h b/heavy_optimizer/riscv64/frontend.h
index 35db8b4..557ee0a 100644
--- a/heavy_optimizer/riscv64/frontend.h
+++ b/heavy_optimizer/riscv64/frontend.h
@@ -20,6 +20,7 @@
 #include "berberis/backend/x86_64/machine_ir.h"
 #include "berberis/backend/x86_64/machine_ir_builder.h"
 #include "berberis/base/arena_map.h"
+#include "berberis/base/checks.h"
 #include "berberis/base/dependent_false.h"
 #include "berberis/decoder/riscv64/decoder.h"
 #include "berberis/decoder/riscv64/semantics_player.h"
@@ -289,15 +290,8 @@
   // Csr
   //
 
-  Register UpdateCsr(Decoder::CsrOpcode /* opcode */, Register /* arg */, Register /* csr */) {
-    Unimplemented();
-    return {};
-  }
-
-  Register UpdateCsr(Decoder::CsrImmOpcode /* opcode */, uint8_t /* imm */, Register /* csr */) {
-    Unimplemented();
-    return {};
-  }
+  Register UpdateCsr(Decoder::CsrOpcode opcode, Register arg, Register csr);
+  Register UpdateCsr(Decoder::CsrImmOpcode opcode, uint8_t imm, Register csr);
 
   [[nodiscard]] bool success() const { return success_; }
 
@@ -325,18 +319,48 @@
 
   template <CsrName kName>
   [[nodiscard]] Register GetCsr() {
-    Unimplemented();
-    return {};
+    auto csr_reg = AllocTempReg();
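+    // CSR fields live in the guest ThreadState, which translated code addresses
+    // off RBP; load with the width of the field's type.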
+    if constexpr (std::is_same_v<CsrFieldType<kName>, uint8_t>) {
+      Gen<x86_64::MovzxblRegMemBaseDisp>(csr_reg, x86_64::kMachineRegRBP, kCsrFieldOffset<kName>);
+    } else if constexpr (std::is_same_v<CsrFieldType<kName>, uint64_t>) {
+      Gen<x86_64::MovqRegMemBaseDisp>(csr_reg, x86_64::kMachineRegRBP, kCsrFieldOffset<kName>);
+    } else {
+      static_assert(kDependentTypeFalse<CsrFieldType<kName>>);
+    }
+    return csr_reg;
   }
 
   template <CsrName kName>
-  void SetCsr(uint8_t /* imm */) {
-    Unimplemented();
+  void SetCsr(uint8_t imm) {
+    // Note: CSR immediates only have 5 bits in the RISC-V encoding, which guarantees that
+    // "imm & kCsrMask<kName>" can be used as an 8-bit immediate.
+    if constexpr (std::is_same_v<CsrFieldType<kName>, uint8_t>) {
+      Gen<x86_64::MovbMemBaseDispImm>(x86_64::kMachineRegRBP,
+                                      kCsrFieldOffset<kName>,
+                                      static_cast<int8_t>(imm & kCsrMask<kName>));
+    } else if constexpr (std::is_same_v<CsrFieldType<kName>, uint64_t>) {
+      Gen<x86_64::MovqMemBaseDispImm>(x86_64::kMachineRegRBP,
+                                      kCsrFieldOffset<kName>,
+                                      static_cast<int8_t>(imm & kCsrMask<kName>));
+    } else {
+      static_assert(kDependentTypeFalse<CsrFieldType<kName>>);
+    }
   }
 
   template <CsrName kName>
-  void SetCsr(Register /* arg */) {
-    Unimplemented();
+  void SetCsr(Register arg) {
+    auto tmp = AllocTempReg();
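+    // Mask the value in a temp so the store doesn't clobber the caller's register.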
+    Gen<PseudoCopy>(tmp, arg, sizeof(CsrFieldType<kName>));
+    if constexpr (sizeof(CsrFieldType<kName>) == 1) {
+      Gen<x86_64::AndbRegImm>(tmp, kCsrMask<kName>, GetFlagsRegister());
+      Gen<x86_64::MovbMemBaseDispReg>(x86_64::kMachineRegRBP, kCsrFieldOffset<kName>, tmp);
+    } else if constexpr (sizeof(CsrFieldType<kName>) == 8) {
+      Gen<x86_64::AndqRegImm>(
+          tmp, constants_pool::kConst<uint64_t{kCsrMask<kName>}>, GetFlagsRegister());
+      Gen<x86_64::MovqMemBaseDispReg>(x86_64::kMachineRegRBP, kCsrFieldOffset<kName>, tmp);
+    } else {
+      static_assert(kDependentTypeFalse<CsrFieldType<kName>>);
+    }
   }
 
  private:
@@ -442,6 +466,178 @@
   ArenaMap<GuestAddr, MachineInsnPosition> branch_targets_;
 };
 
+template <>
+[[nodiscard]] inline HeavyOptimizerFrontend::Register
+HeavyOptimizerFrontend::GetCsr<CsrName::kFCsr>() {
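+  // fcsr is frm (bits 7:5) concatenated with fflags (bits 4:0); fflags comes
+  // from the host FPU via the FeGetExceptions intrinsic.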
+  auto csr_reg = AllocTempReg();
+  auto tmp = AllocTempReg();
+  bool inline_successful = TryInlineIntrinsicForHeavyOptimizer<&intrinsics::FeGetExceptions>(
+      &builder_, GetFlagsRegister(), tmp);
+  CHECK(inline_successful);
+  Gen<x86_64::MovzxbqRegMemBaseDisp>(
+      csr_reg, x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kFrm>);
+  Gen<x86_64::ShlbRegImm>(csr_reg, 5, GetFlagsRegister());
+  Gen<x86_64::OrbRegReg>(csr_reg, tmp, GetFlagsRegister());
+  return csr_reg;
+}
+
+template <>
+[[nodiscard]] inline HeavyOptimizerFrontend::Register
+HeavyOptimizerFrontend::GetCsr<CsrName::kFFlags>() {
+  return FeGetExceptions();
+}
+
+template <>
+[[nodiscard]] inline HeavyOptimizerFrontend::Register
+HeavyOptimizerFrontend::GetCsr<CsrName::kVlenb>() {
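+  // vlenb is the vector register length in bytes; VLEN is 128 here, hence 16.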
+  return GetImm(16);
+}
+
+template <>
+[[nodiscard]] inline HeavyOptimizerFrontend::Register
+HeavyOptimizerFrontend::GetCsr<CsrName::kVxrm>() {
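+  // The guest-state vcsr field keeps vxrm in bits 1:0 and vxsat in bit 2.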
+  auto reg = AllocTempReg();
+  Gen<x86_64::MovzxbqRegMemBaseDisp>(reg, x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>);
+  Gen<x86_64::AndbRegImm>(reg, 0b11, GetFlagsRegister());
+  return reg;
+}
+
+template <>
+[[nodiscard]] inline HeavyOptimizerFrontend::Register
+HeavyOptimizerFrontend::GetCsr<CsrName::kVxsat>() {
+  auto reg = AllocTempReg();
+  Gen<x86_64::MovzxbqRegMemBaseDisp>(reg, x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>);
+  Gen<x86_64::ShrbRegImm>(reg, 2, GetFlagsRegister());
+  return reg;
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFCsr>(uint8_t /* imm */) {
+  Unimplemented();
+  // TODO(b/291126436) Figure out how to pass Mem arg to FeSetExceptionsAndRoundImmTranslate.
+  // // Note: instructions Csrrci and Csrrsi can't affect Frm because the immediate only has
+  // // five bits. But these instructions don't pass their immediate-specified argument into
+  // // `SetCsr`, they combine it with the register first. Fixing that can only be done by
+  // // changing code in the semantics player.
+  // //
+  // // But Csrrwi may clear it.  And we actually may only arrive here from Csrrwi.
+  // // Thus, technically, we know that imm >> 5 is always zero, but it doesn't look like a
+  // // good idea to rely on that: it's very subtle and it only affects code generation speed.
+  // Gen<x86_64::MovbMemBaseDispImm>(x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kFrm>,
+  //                                 static_cast<int8_t>(imm >> 5));
+  // bool successful =
+  // TryInlineIntrinsicForHeavyOptimizer<&intrinsics::FeSetExceptionsAndRoundImmTranslate>(
+  //     &builder_,
+  //     GetFlagsRegister(),
+  //     x86_64::kMachineRegRBP,
+  //     static_cast<int>(offsetof(ThreadState, intrinsics_scratch_area)),
+  //     imm);
+  // CHECK(successful);
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFCsr>(Register /* arg */) {
+  Unimplemented();
+  // TODO(b/291126436) Figure out how to pass Mem arg to FeSetExceptionsAndRoundTranslate.
+  // auto tmp1 = AllocTempReg();
+  // auto tmp2 = AllocTempReg();
+  // Gen<PseudoCopy>(tmp1, arg, 1);
+  // Gen<x86_64::AndlRegImm>(tmp1, 0b1'1111, GetFlagsRegister());
+  // Gen<x86_64::ShldlRegRegImm>(tmp2, arg, int8_t{32 - 5}, GetFlagsRegister());
+  // Gen<x86_64::AndbRegImm>(tmp2, kCsrMask<CsrName::kFrm>, GetFlagsRegister());
+  // Gen<x86_64::MovbMemBaseDispReg>(x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kFrm>,
+  //                                 tmp2);
+  // bool successful =
+  // TryInlineIntrinsicForHeavyOptimizer<&intrinsics::FeSetExceptionsAndRoundTranslate>(
+  //     &builder_,
+  //     GetFlagsRegister(),
+  //     tmp1,
+  //     x86_64::kMachineRegRBP,
+  //     static_cast<int>(offsetof(ThreadState, intrinsics_scratch_area)),
+  //     tmp1);
+  // CHECK(successful);
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFFlags>(uint8_t imm) {
+  FeSetExceptionsImm(static_cast<int8_t>(imm & 0b1'1111));
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFFlags>(Register arg) {
+  auto tmp = AllocTempReg();
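+  // Keep only the five exception-flag bits before handing them to the host FPU.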
+  Gen<PseudoCopy>(tmp, arg, 1);
+  Gen<x86_64::AndlRegImm>(tmp, 0b1'1111, GetFlagsRegister());
+  FeSetExceptions(tmp);
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFrm>(uint8_t imm) {
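+  // Record the masked rounding mode in guest state, then switch the host
+  // rounding mode to match.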
+  Gen<x86_64::MovbMemBaseDispImm>(x86_64::kMachineRegRBP,
+                                  kCsrFieldOffset<CsrName::kFrm>,
+                                  static_cast<int8_t>(imm & kCsrMask<CsrName::kFrm>));
+  FeSetRoundImm(static_cast<int8_t>(imm & kCsrMask<CsrName::kFrm>));
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kFrm>(Register arg) {
+  // Mask the rounding mode in a temp: it's both stored to guest state and passed to FeSetRound.
+  auto tmp = AllocTempReg();
+  Gen<PseudoCopy>(tmp, arg, 1);
+  Gen<x86_64::AndbRegImm>(tmp, kCsrMask<CsrName::kFrm>, GetFlagsRegister());
+  Gen<x86_64::MovbMemBaseDispReg>(x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kFrm>, tmp);
+  FeSetRound(tmp);
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kVxrm>(uint8_t imm) {
+  imm &= 0b11;
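+  // AND clears the old vxrm bits and OR sets the new ones; each op is skipped
+  // when the constant makes it a no-op (imm == 0b11 or imm == 0, respectively).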
+  if (imm != 0b11) {
+    Gen<x86_64::AndbMemBaseDispImm>(
+        x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, 0b100, GetFlagsRegister());
+  }
+  if (imm != 0b00) {
+    Gen<x86_64::OrbMemBaseDispImm>(
+        x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, imm, GetFlagsRegister());
+  }
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kVxrm>(Register arg) {
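+  // Clear the old vxrm bits in vcsr, reduce arg to two bits, then OR it in.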
+  Gen<x86_64::AndbMemBaseDispImm>(
+      x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, 0b100, GetFlagsRegister());
+  Gen<x86_64::AndbRegImm>(arg, 0b11, GetFlagsRegister());
+  Gen<x86_64::OrbMemBaseDispReg>(
+      x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, arg, GetFlagsRegister());
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kVxsat>(uint8_t imm) {
+  if (imm & 0b1) {
+    Gen<x86_64::OrbMemBaseDispImm>(
+        x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, 0b100, GetFlagsRegister());
+  } else {
+    Gen<x86_64::AndbMemBaseDispImm>(
+        x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, 0b11, GetFlagsRegister());
+  }
+}
+
+template <>
+inline void HeavyOptimizerFrontend::SetCsr<CsrName::kVxsat>(Register arg) {
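+  // Materialize arg's bit 0 as 0/1 via setcc, shift it into the vxsat position
+  // (bit 2), and OR it into the cleared vcsr field.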
+  using Condition = x86_64::Assembler::Condition;
+  Gen<x86_64::AndbMemBaseDispImm>(
+      x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, 0b11, GetFlagsRegister());
+  Gen<x86_64::TestbRegImm>(arg, 1, GetFlagsRegister());
+  auto tmp = AllocTempReg();
+  Gen<x86_64::SetccReg>(Condition::kNotZero, tmp, GetFlagsRegister());
+  Gen<x86_64::MovzxbqRegReg>(tmp, tmp);
+  Gen<x86_64::ShlbRegImm>(tmp, int8_t{2}, GetFlagsRegister());
+  Gen<x86_64::OrbMemBaseDispReg>(
+      x86_64::kMachineRegRBP, kCsrFieldOffset<CsrName::kVcsr>, tmp, GetFlagsRegister());
+}
+
 }  // namespace berberis
 
 #endif /* BERBERIS_HEAVY_OPTIMIZER_RISCV64_FRONTEND_H_ */
diff --git a/test_utils/include/berberis/test_utils/insn_tests_riscv64-inl.h b/test_utils/include/berberis/test_utils/insn_tests_riscv64-inl.h
index f05b785..9f8dada 100644
--- a/test_utils/include/berberis/test_utils/insn_tests_riscv64-inl.h
+++ b/test_utils/include/berberis/test_utils/insn_tests_riscv64-inl.h
@@ -222,6 +222,10 @@
     EXPECT_EQ(GetXReg<2>(state_.cpu), expected_fflags);
   }
 
+#endif  // defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR)
+#if defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR) || \
+    defined(TESTING_HEAVY_OPTIMIZER)
+
   void TestFrm(uint32_t insn_bytes, uint8_t frm_to_set, uint8_t expected_rm) {
     auto code_start = ToGuestAddr(&insn_bytes);
     state_.cpu.insn_addr = code_start;
@@ -232,10 +236,6 @@
     EXPECT_EQ(state_.cpu.frm, expected_rm);
   }
 
-#endif  // defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR)
-#if defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR) || \
-    defined(TESTING_HEAVY_OPTIMIZER)
-
   void TestOp(uint32_t insn_bytes,
               std::initializer_list<std::tuple<uint64_t, uint64_t, uint64_t>> args) {
     for (auto [arg1, arg2, expected_result] : args) {
@@ -1143,6 +1143,10 @@
 
 // Tests for Non-Compressed Instructions.
 
+#endif  // defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR)
+#if defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR) || \
+    defined(TESTING_HEAVY_OPTIMIZER)
+
 TEST_F(TESTSUITE, CsrInstructions) {
   ScopedRoundingMode scoped_rounding_mode;
   // Csrrw x2, frm, 2
@@ -1153,6 +1157,10 @@
   TestFrm(0x0020f173, 0, 0);
 }
 
+#endif  // defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR) ||
+        // defined(TESTING_HEAVY_OPTIMIZER)
+#if defined(TESTING_INTERPRETER) || defined(TESTING_LITE_TRANSLATOR)
+
 TEST_F(TESTSUITE, FCsrRegister) {
   fenv_t saved_environment;
   EXPECT_EQ(fegetenv(&saved_environment), 0);