Implement addJumpIfBytesAtR0Equals{Any|None}Of() in ApfV6GeneratorBase
Test: TH
Change-Id: I8596e3fd9f231488eb7e03e1034436a21b6a2f35
diff --git a/src/android/net/apf/ApfV6GeneratorBase.java b/src/android/net/apf/ApfV6GeneratorBase.java
index c00bc3d..e2932d5 100644
--- a/src/android/net/apf/ApfV6GeneratorBase.java
+++ b/src/android/net/apf/ApfV6GeneratorBase.java
@@ -24,7 +24,11 @@
import com.android.net.module.util.CollectionUtils;
import com.android.net.module.util.HexDump;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -427,6 +431,74 @@
bytes.length).setTargetLabel(tgt).setBytesImm(bytes));
}
+ private List<byte[]> validateDeduplicateBytesList(List<byte[]> bytesList) {
+ if (bytesList == null || bytesList.size() == 0) {
+ throw new IllegalArgumentException(
+ "bytesList size must > 0, current size: "
+ + (bytesList == null ? "null" : bytesList.size()));
+ }
+ for (byte[] bytes : bytesList) {
+ validateBytes(bytes);
+ }
+ final int elementSize = bytesList.get(0).length;
+ if (elementSize > 2097151) { // 2 ^ 21 - 1
+ throw new IllegalArgumentException("too many elements");
+ }
+ List<byte[]> deduplicatedList = new ArrayList<>();
+ deduplicatedList.add(bytesList.get(0));
+ for (int i = 1; i < bytesList.size(); ++i) {
+ if (elementSize != bytesList.get(i).length) {
+ throw new IllegalArgumentException("byte arrays in the set have different size");
+ }
+ int j = 0;
+ for (; j < deduplicatedList.size(); ++j) {
+ if (Arrays.equals(bytesList.get(i), deduplicatedList.get(j))) {
+ break;
+ }
+ }
+ if (j == deduplicatedList.size()) {
+ deduplicatedList.add(bytesList.get(i));
+ }
+ }
+ return deduplicatedList;
+ }
+
+ private Type addJumpIfBytesAtR0EqualsHelper(@NonNull List<byte[]> bytesList, String tgt,
+ boolean jumpOnMatch) {
+ final List<byte[]> deduplicatedList = validateDeduplicateBytesList(bytesList);
+ final int elementSize = deduplicatedList.get(0).length;
+ final int totalElements = deduplicatedList.size();
+ final int totalSize = elementSize * totalElements;
+ final ByteBuffer buffer = ByteBuffer.allocate(totalSize);
+ for (byte[] array : deduplicatedList) {
+ buffer.put(array);
+ }
+ final Rbit rbit = jumpOnMatch ? Rbit1 : Rbit0;
+ final byte[] combinedBytes = buffer.array();
+ return append(new Instruction(Opcodes.JBSMATCH, rbit)
+ .addUnsigned((totalElements - 1) << 11 | elementSize)
+ .setTargetLabel(tgt)
+ .setBytesImm(combinedBytes));
+ }
+
+ /**
+ * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the
+ * packet at an offset specified by register0 match any of the elements in {@code bytesSet}.
+ * R=1 means check for equal.
+ */
+ public final Type addJumpIfBytesAtR0EqualsAnyOf(@NonNull List<byte[]> bytesList, String tgt) {
+ return addJumpIfBytesAtR0EqualsHelper(bytesList, tgt, true /* jumpOnMatch */);
+ }
+
+ /**
+ * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the
+ * packet at an offset specified by register0 match none of the elements in {@code bytesSet}.
+ * R=0 means check for not equal.
+ */
+ public final Type addJumpIfBytesAtR0EqualNoneOf(@NonNull List<byte[]> bytesList, String tgt) {
+ return addJumpIfBytesAtR0EqualsHelper(bytesList, tgt, false /* jumpOnMatch */);
+ }
+
/**
* Check if the byte is valid dns character: A-Z,0-9,-,_
diff --git a/tests/unit/src/android/net/apf/ApfV5Test.kt b/tests/unit/src/android/net/apf/ApfV5Test.kt
index 717f7ab..4107198 100644
--- a/tests/unit/src/android/net/apf/ApfV5Test.kt
+++ b/tests/unit/src/android/net/apf/ApfV5Test.kt
@@ -315,6 +315,15 @@
assertFailsWith<IllegalArgumentException> {
gen.addJumpIfOneOf(R0, List(34) { (it + 1).toLong() }.toSet(), PASS_LABEL)
}
+ assertFailsWith<IllegalArgumentException> {
+ gen.addJumpIfBytesAtR0EqualsAnyOf(listOf(ByteArray(2048) { 1 }), PASS_LABEL )
+ }
+ assertFailsWith<IllegalArgumentException> {
+ gen.addJumpIfBytesAtR0EqualsAnyOf(
+ listOf(byteArrayOf(1), byteArrayOf(1, 2)),
+ PASS_LABEL
+ )
+ }
val v4gen = ApfV4Generator(APF_VERSION_4)
assertFailsWith<IllegalArgumentException> { v4gen.addCountAndDrop(PASSED_ARP) }
@@ -698,6 +707,25 @@
"0: joneof r0, DROP, { 0, 128, 256, 65536 }",
"20: jnoneof r1, DROP, { 0, 128, 256, 65536 }"
), ApfJniUtils.disassembleApf(program).map{ it.trim() })
+
+ gen = ApfV6Generator()
+ gen.addJumpIfBytesAtR0EqualsAnyOf(listOf(byteArrayOf(1, 2), byteArrayOf(3, 4)), DROP_LABEL)
+ gen.addJumpIfBytesAtR0EqualNoneOf(listOf(byteArrayOf(1, 2), byteArrayOf(3, 4)), DROP_LABEL)
+ gen.addJumpIfBytesAtR0EqualNoneOf(listOf(byteArrayOf(1, 1), byteArrayOf(1, 1)), DROP_LABEL)
+ program = gen.generate().skipEmptyData()
+ assertContentEquals(byteArrayOf(
+ encodeInstruction(opcode = 20, immLength = 2, register = 1),
+ 0, 15, 8, 2, 1, 2, 3, 4,
+ encodeInstruction(opcode = 20, immLength = 2, register = 0),
+ 0, 6, 8, 2, 1, 2, 3, 4,
+ encodeInstruction(opcode = 20, immLength = 1, register = 0),
+ 1, 2, 1, 1
+ ), program)
+ assertContentEquals(listOf(
+ "0: jbseq r0, 0x2, DROP, { 0102, 0304 }",
+ "9: jbsne r0, 0x2, DROP, { 0102, 0304 }",
+ "18: jbsne r0, 0x2, DROP, 0101"
+ ), ApfJniUtils.disassembleApf(program).map{ it.trim() })
}
@Test
@@ -1293,6 +1321,49 @@
}
@Test
+ fun testJumpMultipleByteSequencesMatch() {
+ var program = ApfV6Generator()
+ .addLoadImmediate(R0, 0)
+ .addJumpIfBytesAtR0EqualsAnyOf(
+ listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+ DROP_LABEL
+ )
+ .addPass()
+ .generate()
+ assertDrop(APF_VERSION_6, program, testPacket)
+
+ program = ApfV6Generator()
+ .addLoadImmediate(R0, 2)
+ .addJumpIfBytesAtR0EqualsAnyOf(
+ listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+ DROP_LABEL
+ )
+ .addPass()
+ .generate()
+ assertPass(APF_VERSION_6, program, testPacket)
+
+ program = ApfV6Generator()
+ .addLoadImmediate(R0, 1)
+ .addJumpIfBytesAtR0EqualNoneOf(
+ listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+ DROP_LABEL
+ )
+ .addPass()
+ .generate()
+ assertDrop(APF_VERSION_6, program, testPacket)
+
+ program = ApfV6Generator()
+ .addLoadImmediate(R0, 0)
+ .addJumpIfBytesAtR0EqualNoneOf(
+ listOf(byteArrayOf(1, 2, 3), byteArrayOf(6, 5, 4)),
+ DROP_LABEL
+ )
+ .addPass()
+ .generate()
+ assertPass(APF_VERSION_6, program, testPacket)
+ }
+
+ @Test
fun testJumpOneOf() {
var program = ApfV6Generator()
.addLoadImmediate(R0, 255)