Implement addJumpIfBytesAtR0Equals{Any|None}Of() in ApfV6GeneratorBase

Test: TH
Change-Id: I8596e3fd9f231488eb7e03e1034436a21b6a2f35
diff --git a/src/android/net/apf/ApfV6GeneratorBase.java b/src/android/net/apf/ApfV6GeneratorBase.java
index c00bc3d..e2932d5 100644
--- a/src/android/net/apf/ApfV6GeneratorBase.java
+++ b/src/android/net/apf/ApfV6GeneratorBase.java
@@ -24,7 +24,11 @@
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.HexDump;
 
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Objects;
 import java.util.Set;
 
@@ -427,6 +431,74 @@
                 bytes.length).setTargetLabel(tgt).setBytesImm(bytes));
     }
 
+    private List<byte[]> validateDeduplicateBytesList(List<byte[]> bytesList) {
+        if (bytesList == null || bytesList.size() == 0) {
+            throw new IllegalArgumentException(
+                    "bytesList size must > 0, current size: "
+                            + (bytesList == null ? "null" : bytesList.size()));
+        }
+        for (byte[] bytes : bytesList) {
+            validateBytes(bytes);
+        }
+        final int elementSize = bytesList.get(0).length;
+        if (elementSize > 2097151) { // 2 ^ 21 - 1
+            throw new IllegalArgumentException("too many elements");
+        }
+        List<byte[]> deduplicatedList = new ArrayList<>();
+        deduplicatedList.add(bytesList.get(0));
+        for (int i = 1; i < bytesList.size(); ++i) {
+            if (elementSize != bytesList.get(i).length) {
+                throw new IllegalArgumentException("byte arrays in the set have different size");
+            }
+            int j = 0;
+            for (; j < deduplicatedList.size(); ++j) {
+                if (Arrays.equals(bytesList.get(i), deduplicatedList.get(j))) {
+                    break;
+                }
+            }
+            if (j == deduplicatedList.size()) {
+                deduplicatedList.add(bytesList.get(i));
+            }
+        }
+        return deduplicatedList;
+    }
+
+    private Type addJumpIfBytesAtR0EqualsHelper(@NonNull List<byte[]> bytesList, String tgt,
+            boolean jumpOnMatch) {
+        final List<byte[]> deduplicatedList = validateDeduplicateBytesList(bytesList);
+        final int elementSize = deduplicatedList.get(0).length;
+        final int totalElements = deduplicatedList.size();
+        final int totalSize = elementSize * totalElements;
+        final ByteBuffer buffer = ByteBuffer.allocate(totalSize);
+        for (byte[] array : deduplicatedList) {
+            buffer.put(array);
+        }
+        final Rbit rbit = jumpOnMatch ? Rbit1 : Rbit0;
+        final byte[] combinedBytes = buffer.array();
+        return append(new Instruction(Opcodes.JBSMATCH, rbit)
+                .addUnsigned((totalElements - 1) << 11 | elementSize)
+                .setTargetLabel(tgt)
+                .setBytesImm(combinedBytes));
+    }
+
+    /**
+     * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the
+     * packet at an offset specified by register0 match any of the elements in {@code bytesSet}.
+     * R=1 means check for equal.
+     */
+    public final Type addJumpIfBytesAtR0EqualsAnyOf(@NonNull List<byte[]> bytesList, String tgt) {
+        return addJumpIfBytesAtR0EqualsHelper(bytesList, tgt, true /* jumpOnMatch */);
+    }
+
+    /**
+     * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the
+     * packet at an offset specified by register0 match none of the elements in {@code bytesSet}.
+     * R=0 means check for not equal.
+     */
+    public final Type addJumpIfBytesAtR0EqualNoneOf(@NonNull List<byte[]> bytesList, String tgt) {
+        return addJumpIfBytesAtR0EqualsHelper(bytesList, tgt, false /* jumpOnMatch */);
+    }
+
 
     /**
      * Check if the byte is valid dns character: A-Z,0-9,-,_
diff --git a/tests/unit/src/android/net/apf/ApfV5Test.kt b/tests/unit/src/android/net/apf/ApfV5Test.kt
index 717f7ab..4107198 100644
--- a/tests/unit/src/android/net/apf/ApfV5Test.kt
+++ b/tests/unit/src/android/net/apf/ApfV5Test.kt
@@ -315,6 +315,15 @@
         assertFailsWith<IllegalArgumentException> {
             gen.addJumpIfOneOf(R0, List(34) { (it + 1).toLong() }.toSet(), PASS_LABEL)
         }
+        assertFailsWith<IllegalArgumentException> {
+            gen.addJumpIfBytesAtR0EqualsAnyOf(listOf(ByteArray(2048) { 1 }), PASS_LABEL )
+        }
+        assertFailsWith<IllegalArgumentException> {
+            gen.addJumpIfBytesAtR0EqualsAnyOf(
+                    listOf(byteArrayOf(1), byteArrayOf(1, 2)),
+                    PASS_LABEL
+            )
+        }
 
         val v4gen = ApfV4Generator(APF_VERSION_4)
         assertFailsWith<IllegalArgumentException> { v4gen.addCountAndDrop(PASSED_ARP) }
@@ -698,6 +707,25 @@
                 "0: joneof      r0, DROP, { 0, 128, 256, 65536 }",
                 "20: jnoneof     r1, DROP, { 0, 128, 256, 65536 }"
         ), ApfJniUtils.disassembleApf(program).map{ it.trim() })
+
+        gen = ApfV6Generator()
+        gen.addJumpIfBytesAtR0EqualsAnyOf(listOf(byteArrayOf(1, 2), byteArrayOf(3, 4)), DROP_LABEL)
+        gen.addJumpIfBytesAtR0EqualNoneOf(listOf(byteArrayOf(1, 2), byteArrayOf(3, 4)), DROP_LABEL)
+        gen.addJumpIfBytesAtR0EqualNoneOf(listOf(byteArrayOf(1, 1), byteArrayOf(1, 1)), DROP_LABEL)
+        program = gen.generate().skipEmptyData()
+        assertContentEquals(byteArrayOf(
+                encodeInstruction(opcode = 20, immLength = 2, register = 1),
+                0, 15, 8, 2, 1, 2, 3, 4,
+                encodeInstruction(opcode = 20, immLength = 2, register = 0),
+                0, 6, 8, 2, 1, 2, 3, 4,
+                encodeInstruction(opcode = 20, immLength = 1, register = 0),
+                1, 2, 1, 1
+        ), program)
+        assertContentEquals(listOf(
+                "0: jbseq       r0, 0x2, DROP, { 0102, 0304 }",
+                "9: jbsne       r0, 0x2, DROP, { 0102, 0304 }",
+                "18: jbsne       r0, 0x2, DROP, 0101"
+        ), ApfJniUtils.disassembleApf(program).map{ it.trim() })
     }
 
     @Test
@@ -1293,6 +1321,49 @@
     }
 
     @Test
+    fun testJumpMultipleByteSequencesMatch() {
+        var program = ApfV6Generator()
+                .addLoadImmediate(R0, 0)
+                .addJumpIfBytesAtR0EqualsAnyOf(
+                        listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+                        DROP_LABEL
+                )
+                .addPass()
+                .generate()
+        assertDrop(APF_VERSION_6, program, testPacket)
+
+        program = ApfV6Generator()
+                .addLoadImmediate(R0, 2)
+                .addJumpIfBytesAtR0EqualsAnyOf(
+                        listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+                        DROP_LABEL
+                )
+                .addPass()
+                .generate()
+        assertPass(APF_VERSION_6, program, testPacket)
+
+        program = ApfV6Generator()
+                .addLoadImmediate(R0, 1)
+                .addJumpIfBytesAtR0EqualNoneOf(
+                        listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+                        DROP_LABEL
+                )
+                .addPass()
+                .generate()
+        assertDrop(APF_VERSION_6, program, testPacket)
+
+        program = ApfV6Generator()
+                .addLoadImmediate(R0, 0)
+                .addJumpIfBytesAtR0EqualNoneOf(
+                        listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+                        DROP_LABEL
+                )
+                .addPass()
+                .generate()
+        assertPass(APF_VERSION_6, program, testPacket)
+    }
+
+    @Test
     fun testJumpOneOf() {
         var program = ApfV6Generator()
                 .addLoadImmediate(R0, 255)