Add APF opcodes to read/write data memory (take 2)

The new opcodes are LDDW (LoaD Data Word) and STDW (STore Data Word).

The only supported addressing mode is register-indirect with immediate
offset (register value + immediate value). Since there's a single
register bit encoded in the opcode, the other register is implicitly
used for the address operand. Hence, the following variations are
possible:

  lddw R0, [1234+R1]  ; R0 = *(R1 + 1234)
  lddw R1, [1234+R0]  ; R1 = *(R0 + 1234)
  stdw R0, [1234+R1]  ; *(R1 + 1234) = R0
  stdw R1, [1234+R0]  ; *(R0 + 1234) = R1

The immediate can also be encoded with length 0, making the memory
access equivalent to plain register-indirect addressing with no offset.

  lddw R0, R1  ; R0 = *R1
  lddw R1, R0  ; R1 = *R0
  stdw R0, R1  ; *R1 = R0
  stdw R1, R0  ; *R0 = R1

With a zero-length immediate, each of these instructions encodes in a
single byte, making a typical counter increment more efficient,
especially when the address requires a multi-byte immediate:

  li    R1, 1234    ; address of our packet counter (3 bytes)
  lddw  R0, R1      ; load the counter from address 1234 (1 byte)
  add   R0, 1       ; increment the counter (2 bytes)
  stdw  R0, R1      ; write-back to data ram (1 byte)

Total: 7 bytes. Defining a separate INCW instruction would reduce the
above sequence down to 3 bytes (or 4 bytes if using an EXT opcode to
avoid wasting opcodes). This optimization can be added at a later point
if the data access patterns of production APF code justify it.

Bug: 73804303
Test: runtest -x tests/net/java/android/net/apf/ApfTest.java
Merged-In: I4bea29ea701cc11dc61cdcf60cc824bbe14b24f6
Merged-In: Ibbd427e12987a1eef63b41d816af05a1bd9f9170
Change-Id: Ibbd427e12987a1eef63b41d816af05a1bd9f9170
(cherry picked from commit 75410970184bf98626342588ba2eabf79cda6d38)
diff --git a/apf.h b/apf.h
index 2ae48ec..2d64930 100644
--- a/apf.h
+++ b/apf.h
@@ -19,7 +19,7 @@
 // APF machine is composed of:
 //  1. A read-only program consisting of bytecodes as described below.
 //  2. Two 32-bit registers, called R0 and R1.
-//  3. Sixteen 32-bit memory slots.
+//  3. Sixteen 32-bit temporary memory slots (cleared between packets).
 //  4. A read-only packet.
 // The program is executed by the interpreter below and parses the packet
 // to determine if the application processor (AP) should be woken up to
@@ -47,7 +47,7 @@
 //    They load either 1, 2 or 4 bytes, as determined by the "opcode" field.
 //    They load into the register specified by the "register" field.
 //    The immediate value that follows the first byte of the instruction is
-//    the byte offset from the begining of the packet to load from.
+//    the byte offset from the beginning of the packet to load from.
 //    There are "indexing" loads which add the value in R1 to the byte offset
 //    to load from. The "opcode" field determines which loads are "indexing".
 //  Arithmetic instructions
@@ -79,7 +79,7 @@
 //
 //  Miscellaneous details:
 //
-//  Pre-filled memory slot values
+//  Pre-filled temporary memory slot values
 //    When the APF program begins execution, three of the sixteen memory slots
 //    are pre-filled by the interpreter with values that may be useful for
 //    programs:
@@ -116,9 +116,9 @@
 //        position specified by the value of the register specified by the
 //        "register" field of the instruction.
 
-// Number of memory slots, see ldm/stm instructions.
+// Number of temporary memory slots, see ldm/stm instructions.
 #define MEMORY_ITEMS 16
-// Upon program execution starting some memory slots are prefilled:
+// Upon program execution, some temporary memory slots are prefilled:
 #define MEMORY_OFFSET_IPV4_HEADER_SIZE 13 // 4*([APF_FRAME_HEADER_SIZE]&15)
 #define MEMORY_OFFSET_PACKET_SIZE 14      // Size of packet in bytes.
 #define MEMORY_OFFSET_FILTER_AGE 15       // Age since filter installed in seconds.
@@ -127,9 +127,9 @@
 #define LDB_OPCODE 1    // Load 1 byte from immediate offset, e.g. "ldb R0, [5]"
 #define LDH_OPCODE 2    // Load 2 bytes from immediate offset, e.g. "ldh R0, [5]"
 #define LDW_OPCODE 3    // Load 4 bytes from immediate offset, e.g. "ldw R0, [5]"
-#define LDBX_OPCODE 4   // Load 1 byte from immediate offset plus register, e.g. "ldbx R0, [5]R0"
-#define LDHX_OPCODE 5   // Load 2 byte from immediate offset plus register, e.g. "ldhx R0, [5]R0"
-#define LDWX_OPCODE 6   // Load 4 byte from immediate offset plus register, e.g. "ldwx R0, [5]R0"
+#define LDBX_OPCODE 4   // Load 1 byte from immediate offset plus register, e.g. "ldbx R0, [5+R1]"
+#define LDHX_OPCODE 5   // Load 2 bytes from immediate offset plus register, e.g. "ldhx R0, [5+R1]"
+#define LDWX_OPCODE 6   // Load 4 bytes from immediate offset plus register, e.g. "ldwx R0, [5+R1]"
 #define ADD_OPCODE 7    // Add, e.g. "add R0,5"
 #define MUL_OPCODE 8    // Multiply, e.g. "mul R0,5"
 #define DIV_OPCODE 9    // Divide, e.g. "div R0,5"
@@ -145,12 +145,15 @@
 #define JSET_OPCODE 19  // Compare any bits set and branch, e.g. "jset R0,5,label"
 #define JNEBS_OPCODE 20 // Compare not equal byte sequence, e.g. "jnebs R0,5,label,0x1122334455"
 #define EXT_OPCODE 21   // Immediate value is one of *_EXT_OPCODE
+#define LDDW_OPCODE 22  // Load 4 bytes from data address (register + imm): "lddw R0, [5+R1]"
+#define STDW_OPCODE 23  // Store 4 bytes to data address (register + imm): "stdw R0, [5+R1]"
+
 // Extended opcodes. These all have an opcode of EXT_OPCODE
 // and specify the actual opcode in the immediate field.
-#define LDM_EXT_OPCODE 0   // Load from memory, e.g. "ldm R0,5"
-  // Values 0-15 represent loading the different memory slots.
-#define STM_EXT_OPCODE 16  // Store to memory, e.g. "stm R0,5"
-  // Values 16-31 represent storing to the different memory slots.
+#define LDM_EXT_OPCODE 0   // Load from temporary memory, e.g. "ldm R0,5"
+  // Values 0-15 represent loading the different temporary memory slots.
+#define STM_EXT_OPCODE 16  // Store to temporary memory, e.g. "stm R0,5"
+  // Values 16-31 represent storing to the different temporary memory slots.
 #define NOT_EXT_OPCODE 32  // Not, e.g. "not R0"
 #define NEG_EXT_OPCODE 33  // Negate, e.g. "neg R0"
 #define SWAP_EXT_OPCODE 34 // Swap, e.g. "swap R0,R1"
diff --git a/apf_disassembler.c b/apf_disassembler.c
index 03e6a06..b61202c 100644
--- a/apf_disassembler.c
+++ b/apf_disassembler.c
@@ -50,6 +50,8 @@
     [JLT_OPCODE] = "jlt",
     [JSET_OPCODE] = "jset",
     [JNEBS_OPCODE] = "jnebs",
+    [LDDW_OPCODE] = "lddw",
+    [STDW_OPCODE] = "stdw",
 };
 
 static void print_jump_target(uint32_t target, uint32_t program_len) {
@@ -206,6 +208,12 @@
                       break;
               }
               break;
+          case LDDW_OPCODE:
+          case STDW_OPCODE:
+              PRINT_OPCODE();
+              printf("r%u, [%u+r%u]", reg_num, imm, reg_num ^ 1);
+              break;
+
           // Unknown opcode
           default:
               printf("unknown %u", opcode);
diff --git a/apf_interpreter.c b/apf_interpreter.c
index 924b23e..f5e0072 100644
--- a/apf_interpreter.c
+++ b/apf_interpreter.c
@@ -31,30 +31,26 @@
 // superfluous ">= 0" with unsigned expressions generates compile warnings.
 #define ENFORCE_UNSIGNED(c) ((c)==(uint32_t)(c))
 
-/**
- * Runs a packet filtering program over a packet.
- *
- * @param program the program bytecode.
- * @param program_len the length of {@code apf_program} in bytes.
- * @param packet the packet bytes, starting from the 802.3 header and not
- *               including any CRC bytes at the end.
- * @param packet_len the length of {@code packet} in bytes.
- * @param filter_age the number of seconds since the filter was programmed.
- *
- * @return non-zero if packet should be passed to AP, zero if
- *         packet should be dropped.
- */
 int accept_packet(const uint8_t* program, uint32_t program_len,
                   const uint8_t* packet, uint32_t packet_len,
+                  uint8_t* data, uint32_t data_len,
                   uint32_t filter_age) {
 // Is offset within program bounds?
 #define IN_PROGRAM_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < program_len)
 // Is offset within packet bounds?
 #define IN_PACKET_BOUNDS(p) (ENFORCE_UNSIGNED(p) && (p) < packet_len)
+// Is the |size|-byte access at offset |p| entirely inside data memory?
+#define IN_DATA_BOUNDS(p, size) (ENFORCE_UNSIGNED(p) && \
+                                 ENFORCE_UNSIGNED(size) && \
+                                 (p) + (size) <= data_len && /* end may equal data_len */ \
+                                 (p) + (size) >= (p))  // catch wraparounds
 // Accept packet if not within program bounds
 #define ASSERT_IN_PROGRAM_BOUNDS(p) ASSERT_RETURN(IN_PROGRAM_BOUNDS(p))
 // Accept packet if not within packet bounds
 #define ASSERT_IN_PACKET_BOUNDS(p) ASSERT_RETURN(IN_PACKET_BOUNDS(p))
+// Accept packet if not within data bounds
+#define ASSERT_IN_DATA_BOUNDS(p, size) ASSERT_RETURN(IN_DATA_BOUNDS(p, size))
+
   // Program counter.
   uint32_t pc = 0;
 // Accept packet if not within program or not ahead of program counter
@@ -268,6 +264,27 @@
                     return PASS_PACKET;
               }
               break;
+          case LDDW_OPCODE: {  // REG = 32-bit word read from data memory
+              uint32_t offs = imm + OTHER_REG;  // byte offset into |data|
+              uint32_t size = 4;
+              uint32_t val = 0;
+              ASSERT_IN_DATA_BOUNDS(offs, size);
+              while (size--)
+                  val = (val << 8) | data[offs++];  // most significant byte first
+              REG = val;
+              break;
+          }
+          case STDW_OPCODE: {  // 32-bit word written from REG to data memory
+              uint32_t offs = imm + OTHER_REG;  // byte offset into |data|
+              uint32_t size = 4;
+              uint32_t val = REG;
+              ASSERT_IN_DATA_BOUNDS(offs, size);
+              while (size--) {
+                  data[offs++] = (val >> 24);  // emit the current top byte...
+                  val <<= 8;                   // ...then shift the next one up
+              }
+              break;
+          }
           // Unknown opcode
           default:
               // Bail out
diff --git a/apf_interpreter.h b/apf_interpreter.h
index 2012c28..78a0dd3 100644
--- a/apf_interpreter.h
+++ b/apf_interpreter.h
@@ -27,7 +27,7 @@
  * Version of APF instruction set processed by accept_packet().
  * Should be returned by wifi_get_packet_filter_info.
  */
-#define APF_VERSION 2
+#define APF_VERSION 3
 
 /**
  * Runs a packet filtering program over a packet.
@@ -37,6 +37,8 @@
  * @param packet the packet bytes, starting from the 802.3 header and not
  *               including any CRC bytes at the end.
  * @param packet_len the length of {@code packet} in bytes.
+ * @param data writable data memory region (preserved between packets).
+ * @param data_len the length of {@code data} in bytes.
  * @param filter_age the number of seconds since the filter was programmed.
  *
  * @return non-zero if packet should be passed to AP, zero if
@@ -44,6 +46,7 @@
  */
 int accept_packet(const uint8_t* program, uint32_t program_len,
                   const uint8_t* packet, uint32_t packet_len,
+                  uint8_t* data, uint32_t data_len,
                   uint32_t filter_age);
 
 #ifdef __cplusplus
diff --git a/apf_run.c b/apf_run.c
index 32b4506..dab7f5e 100644
--- a/apf_run.c
+++ b/apf_run.c
@@ -50,13 +50,20 @@
     return length;
 }
 
+// Prints |len| bytes of |input| to stdout as two-digit lowercase hex, no separators.
+void print_hex(uint8_t* input, int len) {
+    for (const uint8_t* byte = input; len > 0; --len, ++byte)
+        printf("%02x", *byte);
+}
+
 int main(int argc, char* argv[]) {
-    if (argc != 4) {
+    if (argc < 3 || argc > 5) {
         fprintf(stderr,
-                "Usage: %s <program> <packet> <program age>\n"
+                "Usage: %s <program> <packet> [<data>] [<age>]\n"
                 "  program:     APF program, in hex\n"
                 "  packet:      Packet to run through program, in hex\n"
-                "  program age: Age of program in seconds.\n",
+                "  data:        Data memory contents, in hex\n"
+                "  age:         Age of program in seconds (default: 0)\n",
                 basename(argv[0]));
         exit(1);
     }
@@ -64,11 +71,19 @@
     uint32_t program_len = parse_hex(argv[1], &program);
     uint8_t* packet;
     uint32_t packet_len = parse_hex(argv[2], &packet);
-    uint32_t filter_age = atoi(argv[3]);
+    uint8_t* data = NULL;
+    uint32_t data_len = argc > 3 ? parse_hex(argv[3], &data) : 0;
+    uint32_t filter_age = argc > 4 ? atoi(argv[4]) : 0;
     int ret = accept_packet(program, program_len, packet, packet_len,
-                            filter_age);
+                            data, data_len, filter_age);
     printf("Packet %sed\n", ret ? "pass" : "dropp");
+    if (data) {
+        printf("Data: ");
+        print_hex(data, data_len);
+        printf("\n");
+        free(data);
+    }
     free(program);
     free(packet);
     return ret;
-}
\ No newline at end of file
+}