Add JDNSAMATCH opcode support to ApfGenerator

* design doc: go/apf-v6-proposal

Bug: 293811969
Test: TH
Change-Id: I9171386442c0ed9a7be76afcac3de7e63a005247
diff --git a/src/android/net/apf/ApfGenerator.java b/src/android/net/apf/ApfGenerator.java
index 05b6484..6346a02 100644
--- a/src/android/net/apf/ApfGenerator.java
+++ b/src/android/net/apf/ApfGenerator.java
@@ -136,15 +136,25 @@
         // e.g. "pktcopy r0, 5", "pktcopy r0, r1", "datacopy r0, 5", "datacopy r0, r1"
         EPKTCOPY(41),
         EDATACOPY(42),
-        // Jumps if the UDP payload content (starting at R0) does not contain the specified QNAME,
-        // applying MDNS case insensitivity.
+        // Jumps if the UDP payload content (starting at R0) does not contain ont
+        // of the specified QNAME, applying case insensitivity.
         // R0: Offset to UDP payload content
+        // R=0/1 meanining 'does not match' vs 'matches'
         // imm1: Opcode
         // imm2: Label offset
         // imm3(u8): Question type (PTR/SRV/TXT/A/AAAA)
         // imm4(bytes): TLV-encoded QNAME list (null-terminated)
         // e.g.: "jdnsqmatch R0,label,0x0c,\002aa\005local\0\0"
-        JDNSQMATCH(43);
+        JDNSQMATCH(43), // Jumps if the UDP payload content (starting at R0) does not contain one
+        // of the specified NAME in answers/authority/additional records, applying
+        // case insensitivity.
+        // R=0/1 meanining 'does not match' vs 'matches'
+        // R0: Offset to UDP payload content
+        // imm1: Opcode
+        // imm2: Label offset
+        // imm3(bytes): TLV-encoded QNAME list (null-terminated)
+        // e.g.: "jdnsamatch R0,label,0x0c,\002aa\005local\0\0"
+        JDNSAMATCH(44);
 
         final int value;
 
@@ -1218,16 +1228,16 @@
         return (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_';
     }
 
-    private static void validateQnames(@NonNull byte[] qnames) {
-        final int len = qnames.length;
+    private static void validateNames(@NonNull byte[] names) {
+        final int len = names.length;
         if (len < 4) {
             throw new IllegalArgumentException("qnames must have at least length 4");
         }
-        final String errorMessage = "qname: " + HexDump.toHexString(qnames)
+        final String errorMessage = "qname: " + HexDump.toHexString(names)
                 + "is not null-terminated list of TLV-encoded names";
         int i = 0;
         while (i < len - 1) {
-            int label_len = qnames[i++];
+            int label_len = names[i++];
             if (label_len < 1 || label_len > 63) {
                 throw new IllegalArgumentException(
                         "label len: " + label_len + " must be between 1 and 63");
@@ -1236,16 +1246,16 @@
                 throw new IllegalArgumentException(errorMessage);
             }
             while (label_len-- > 0) {
-                if (!isValidDnsCharacter(qnames[i++])) {
-                    throw new IllegalArgumentException("qname: " + HexDump.toHexString(qnames)
+                if (!isValidDnsCharacter(names[i++])) {
+                    throw new IllegalArgumentException("qname: " + HexDump.toHexString(names)
                             + " contains invalid character");
                 }
             }
-            if (qnames[i] == 0) {
+            if (names[i] == 0) {
                 i++; // skip null terminator.
             }
         }
-        if (qnames[len - 1] != 0) {
+        if (names[len - 1] != 0) {
             throw new IllegalArgumentException(errorMessage);
         }
     }
@@ -1259,7 +1269,7 @@
     public ApfGenerator addJumpIfPktAtR0DoesNotContainDnsQ(@NonNull byte[] qnames, int qtype,
             @NonNull String tgt) throws IllegalInstructionException {
         requireApfVersion(MIN_APF_VERSION_IN_DEV);
-        validateQnames(qnames);
+        validateNames(qnames);
         return append(new Instruction(ExtendedOpcodes.JDNSQMATCH).setTargetLabel(tgt).addU8(
                 qtype).setBytesImm(qnames));
     }
@@ -1273,11 +1283,39 @@
     public ApfGenerator addJumpIfPktAtR0ContainDnsQ(@NonNull byte[] qnames, int qtype,
             @NonNull String tgt) throws IllegalInstructionException {
         requireApfVersion(MIN_APF_VERSION_IN_DEV);
-        validateQnames(qnames);
+        validateNames(qnames);
         return append(new Instruction(ExtendedOpcodes.JDNSQMATCH, R1).setTargetLabel(tgt).addU8(
                 qtype).setBytesImm(qnames));
     }
 
+    /**
+     * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP
+     * payload's DNS answers/authority/additional records do NOT contain the NAME
+     * specified in {@code Names}. Examines the payload starting at the offset in R0.
+     * R = 0 means check for "does not contain".
+     */
+    public ApfGenerator addJumpIfPktAtR0DoesNotContainDnsA(@NonNull byte[] names,
+            @NonNull String tgt) throws IllegalInstructionException {
+        requireApfVersion(MIN_APF_VERSION_IN_DEV);
+        validateNames(names);
+        return append(new Instruction(ExtendedOpcodes.JDNSAMATCH).setTargetLabel(tgt).setBytesImm(
+                names));
+    }
+
+    /**
+     * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP
+     * payload's DNS answers/authority/additional records contain the NAME
+     * specified in {@code Names}. Examines the payload starting at the offset in R0.
+     * R = 1 means check for "contain".
+     */
+    public ApfGenerator addJumpIfPktAtR0ContainDnsA(@NonNull byte[] names,
+            @NonNull String tgt) throws IllegalInstructionException {
+        requireApfVersion(MIN_APF_VERSION_IN_DEV);
+        validateNames(names);
+        return append(new Instruction(ExtendedOpcodes.JDNSAMATCH, R1).setTargetLabel(
+                tgt).setBytesImm(names));
+    }
+
     private static void checkRange(@NonNull String variableName, long value, long lowerBound,
             long upperBound) {
         if (value >= lowerBound && value <= upperBound) {
diff --git a/tests/unit/src/android/net/apf/ApfV5Test.kt b/tests/unit/src/android/net/apf/ApfV5Test.kt
index c02ede9..162feef 100644
--- a/tests/unit/src/android/net/apf/ApfV5Test.kt
+++ b/tests/unit/src/android/net/apf/ApfV5Test.kt
@@ -65,6 +65,10 @@
                 byteArrayOf(1, 'A'.code.toByte()), 0x0c, ApfGenerator.DROP_LABEL) }
         assertFailsWith<IllegalInstructionException> { gen.addJumpIfPktAtR0ContainDnsQ(
                 byteArrayOf(1, 'A'.code.toByte()), 0x0c, ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalInstructionException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte()), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalInstructionException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte()), ApfGenerator.DROP_LABEL) }
     }
 
     @Test
@@ -127,6 +131,40 @@
         assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ(
                 byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()),
                 0xc0, ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, 'a'.code.toByte(), 0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, '.'.code.toByte(), 0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte()), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0),
+                 ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0),
+                ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()),
+                ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, 'a'.code.toByte(), 0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, '.'.code.toByte(), 0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(0, 0), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte()), ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0),
+                ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0),
+                ApfGenerator.DROP_LABEL) }
+        assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA(
+                byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()),
+                ApfGenerator.DROP_LABEL) }
     }
 
     @Test
@@ -307,6 +345,16 @@
         ) + qnames + byteArrayOf(
                 encodeInstruction(21, 1, 1), 43, 1, 0x0c.toByte(),
         ) + qnames, program)
+
+        gen = ApfGenerator(ApfGenerator.MIN_APF_VERSION_IN_DEV)
+        gen.addJumpIfPktAtR0DoesNotContainDnsA(qnames, ApfGenerator.DROP_LABEL)
+        gen.addJumpIfPktAtR0ContainDnsA(qnames, ApfGenerator.DROP_LABEL)
+        program = gen.generate()
+        assertContentEquals(byteArrayOf(
+                encodeInstruction(21, 1, 0), 44, 10,
+        ) + qnames + byteArrayOf(
+                encodeInstruction(21, 1, 1), 44, 1,
+        ) + qnames, program)
     }
 
     private fun encodeInstruction(opcode: Int, immLength: Int, register: Int): Byte {