Implement dup (vector) for aarch64 assembler

PiperOrigin-RevId: 423438655
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index 5105c05..550f72c 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -37,17 +37,11 @@
 constexpr uint32_t kTbxzImmMask = 0x3FFF;
 constexpr uint32_t kUnconditionalImmMask = 0x03FFFFFF;
 
-inline uint32_t rd(VRegister vn) { return vn.code; }
-inline uint32_t rd(XRegister xn) { return xn.code; }
-inline uint32_t rt(QRegister qn) { return qn.code; }
-inline uint32_t rt(SRegister sn) { return sn.code; }
-inline uint32_t rt(VRegister vn) { return vn.code; }
+template <typename Reg> inline uint32_t rd(Reg vn) { return vn.code; }  // Any type with a .code member; code is left unshifted (lowest bits).
+template <typename Reg> inline uint32_t rt(Reg qn) { return qn.code; }  // Same unshifted placement as rd(); accepts any register-like type.
 inline uint32_t rt2(QRegister qn) { return qn.code << 10; }
-inline uint32_t rm(XRegister xn) { return xn.code << 16; }
-inline uint32_t rm(VRegister vn) { return vn.code << 16; }
-inline uint32_t rm(VRegisterLane vn) { return vn.code << 16; }
-inline uint32_t rn(XRegister xn) { return xn.code << 5; }
-inline uint32_t rn(VRegister vn) { return vn.code << 5; }
+template <typename Reg> inline uint32_t rm(Reg xn) { return xn.code << 16; }  // Register code moved up to start at bit 16.
+template <typename Reg> inline uint32_t rn(Reg rn) { return rn.code << 5; }  // Register code moved up to start at bit 5; the parameter shadows the function name.
 inline uint32_t q(VRegister vt) { return vt.q << 30; }
 inline uint32_t size(VRegister vt) { return vt.size << 10; }
 inline uint32_t fp_sz(VRegister vn) { return vn.is_s() ? 0 : 1 << 22; }
@@ -287,6 +281,15 @@
 
 // SIMD instructions.
 
+Assembler& Assembler::dup(DRegister dd, VRegisterLane vn) {  // Encodes `dup dd, vn` as a single 32-bit word via emit32.
+  if (vn.size != 3 || vn.lane > 1) {  // Accept only size code 3 with lane index 0 or 1.
+    error_ = Error::kInvalidOperand;  // Record the failure; nothing is emitted.
+    return *this;
+  }
+  const uint8_t imm5 = 0b1000 | (vn.lane & 1) << 4;  // << binds tighter than |: imm5 = 0b01000 | (lane << 4).
+  return emit32(0x5E000400 | imm5 << 16 | rn(vn) | rd(dd));  // imm5 starts at bit 16, vn.code at bit 5, dd.code at bit 0.
+}
+
 Assembler& Assembler::fadd(VRegister vd, VRegister vn, VRegister vm) {
   if (!is_same_shape(vd, vn, vm)) {
     error_ = Error::kInvalidOperand;
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index 921ed65..9d4a807 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -79,6 +79,7 @@
   VRegister v2d() const { return {code, 3, 1}; }
 
   ScalarVRegister s() const { return {code, 2}; }
+  ScalarVRegister d() const { return {code, 3}; }  // Same register code, second field 3 (presumably the size code; s() above passes 2).
 
   const bool is_s() { return size == 2; };
 };
@@ -170,6 +171,43 @@
 constexpr SRegister s30{30};
 constexpr SRegister s31{31};
 
+struct DRegister {  // Distinct operand type so overloads such as dup(DRegister, VRegisterLane) can select on it.
+  uint8_t code;  // Register number; the d0..d31 constants use 0 through 31.
+};
+
+constexpr DRegister d0{0};  // d0..d31: one named constant per register code, in order.
+constexpr DRegister d1{1};
+constexpr DRegister d2{2};
+constexpr DRegister d3{3};
+constexpr DRegister d4{4};
+constexpr DRegister d5{5};
+constexpr DRegister d6{6};
+constexpr DRegister d7{7};
+constexpr DRegister d8{8};
+constexpr DRegister d9{9};
+constexpr DRegister d10{10};
+constexpr DRegister d11{11};
+constexpr DRegister d12{12};
+constexpr DRegister d13{13};
+constexpr DRegister d14{14};
+constexpr DRegister d15{15};
+constexpr DRegister d16{16};
+constexpr DRegister d17{17};
+constexpr DRegister d18{18};
+constexpr DRegister d19{19};
+constexpr DRegister d20{20};
+constexpr DRegister d21{21};
+constexpr DRegister d22{22};
+constexpr DRegister d23{23};
+constexpr DRegister d24{24};
+constexpr DRegister d25{25};
+constexpr DRegister d26{26};
+constexpr DRegister d27{27};
+constexpr DRegister d28{28};
+constexpr DRegister d29{29};
+constexpr DRegister d30{30};
+constexpr DRegister d31{31};  // Highest register code defined here.
+
 struct QRegister {
   uint8_t code;
 };
@@ -295,6 +333,7 @@
   Assembler& tbz(XRegister xd, uint8_t bit, Label& l);
 
   // SIMD instructions
+  Assembler& dup(DRegister dd, VRegisterLane vn);  // vn must have size 3 and lane 0 or 1; otherwise the error is set to kInvalidOperand.
   Assembler& fadd(VRegister vd, VRegister vn, VRegister vm);
   Assembler& fmax(VRegister vd, VRegister vn, VRegister vm);
   Assembler& fmin(VRegister vd, VRegister vn, VRegister vm);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index 3710188..e749305 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -67,6 +67,10 @@
   xnn_allocate_code_memory(&b, XNN_DEFAULT_CODE_BUFFER_SIZE);
   Assembler a(&b);
 
+  CHECK_ENCODING(0x5E180610, a.dup(d16, v16.d()[1]));  // 0x5E000400 | 0b11000 << 16 | 16 << 5 | 16.
+  EXPECT_ERROR(Error::kInvalidOperand, a.dup(d16, v16.d()[2]));  // Lane 2 fails dup's lane > 1 check.
+  EXPECT_ERROR(Error::kInvalidOperand, a.dup(d16, v16.s()[1]));  // s() is built with 2 where dup requires size 3.
+
   CHECK_ENCODING(0x4E25D690, a.fadd(v16.v4s(), v20.v4s(), v5.v4s()));
   EXPECT_ERROR(Error::kInvalidOperand, a.fadd(v16.v4s(), v20.v4s(), v5.v2s()));