Merge "Implement addCountAnd{Drop|Pass}IfR0Is{OneOf|NoneOf}() methods" into main
diff --git a/src/android/net/apf/ApfV4Generator.java b/src/android/net/apf/ApfV4Generator.java
index 0fc05ec..c8d899d 100644
--- a/src/android/net/apf/ApfV4Generator.java
+++ b/src/android/net/apf/ApfV4Generator.java
@@ -19,8 +19,12 @@
 import static android.net.apf.BaseApfGenerator.Register.R0;
 import static android.net.apf.BaseApfGenerator.Register.R1;
 
+import android.annotation.NonNull;
+
 import com.android.internal.annotations.VisibleForTesting;
 
+import java.util.Set;
+
 /**
  * APFv4 assembler/generator. A tool for generating an APFv4 program.
  *
@@ -176,6 +180,54 @@
                 mCountAndPassLabel);
     }
 
+    @Override
+    public ApfV4Generator addCountAndPassIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkPassCounterRange(cnt);  // cnt has to be a pass counter
+        maybeAddLoadCounterOffset(R1, cnt);  // the step that may overwrite R1
+        for (final Long candidate : values) {
+            addJumpIfR0Equals(candidate, mCountAndPassLabel);  // any match => count and pass
+        }
+        return this;
+    }
+
+    @Override
+    public ApfV4Generator addCountAndDropIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkDropCounterRange(cnt);  // cnt has to be a drop counter
+        maybeAddLoadCounterOffset(R1, cnt);  // the step that may overwrite R1
+        for (final Long candidate : values) {
+            addJumpIfR0Equals(candidate, mCountAndDropLabel);  // any match => count and drop
+        }
+        return this;
+    }
+
+    @Override
+    public ApfV4Generator addCountAndPassIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkPassCounterRange(cnt);  // cnt has to be a pass counter
+        final String skipLabel = getUniqueLabel();
+        for (final Long candidate : values) {
+            addJumpIfR0Equals(candidate, skipLabel);  // a match bypasses the count-and-pass
+        }
+        addCountAndPass(cnt);  // only reached when R0 equalled none of the values
+        defineLabel(skipLabel);
+        return this;
+    }
+
+    @Override
+    public ApfV4Generator addCountAndDropIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkDropCounterRange(cnt);  // cnt has to be a drop counter
+        final String skipLabel = getUniqueLabel();
+        for (final Long candidate : values) {
+            addJumpIfR0Equals(candidate, skipLabel);  // a match bypasses the count-and-drop
+        }
+        addCountAndDrop(cnt);  // only reached when R0 equalled none of the values
+        defineLabel(skipLabel);
+        return this;
+    }
+
     /**
      * Add an instruction to the end of the program to load 32 bits from the data memory into
      * {@code register}. The source address is computed by adding the signed immediate
diff --git a/src/android/net/apf/ApfV4GeneratorBase.java b/src/android/net/apf/ApfV4GeneratorBase.java
index 5efe061..3d5414a 100644
--- a/src/android/net/apf/ApfV4GeneratorBase.java
+++ b/src/android/net/apf/ApfV4GeneratorBase.java
@@ -26,6 +26,7 @@
 import com.android.internal.annotations.VisibleForTesting;
 
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * APF assembler/generator.  A tool for generating an APF program.
@@ -459,6 +460,38 @@
             ApfCounterTracker.Counter cnt) throws IllegalInstructionException;
 
     /**
+     * Add instructions to the end of the program that increment counter {@code cnt} and pass
+     * the packet if the value in register R0 is one of {@code values}.
+     * WARNING: may modify R1. A drop counter as {@code cnt} throws IllegalArgumentException.
+     */
+    public abstract Type addCountAndPassIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program that increment counter {@code cnt} and drop
+     * the packet if the value in register R0 is one of {@code values}.
+     * WARNING: may modify R1. A pass counter as {@code cnt} throws IllegalArgumentException.
+     */
+    public abstract Type addCountAndDropIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program that increment counter {@code cnt} and pass
+     * the packet if the value in register R0 is none of {@code values}.
+     * WARNING: may modify R1. A drop counter as {@code cnt} throws IllegalArgumentException.
+     */
+    public abstract Type addCountAndPassIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException;
+
+    /**
+     * Add instructions to the end of the program that increment counter {@code cnt} and drop
+     * the packet if the value in register R0 is none of {@code values}.
+     * WARNING: may modify R1. A pass counter as {@code cnt} throws IllegalArgumentException.
+     */
+    public abstract Type addCountAndDropIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException;
+
+    /**
      * Add an instruction to the end of the program to load memory slot {@code slot} into
      * {@code register}.
      */
diff --git a/src/android/net/apf/ApfV6GeneratorBase.java b/src/android/net/apf/ApfV6GeneratorBase.java
index 8791c77..a94e887 100644
--- a/src/android/net/apf/ApfV6GeneratorBase.java
+++ b/src/android/net/apf/ApfV6GeneratorBase.java
@@ -736,6 +736,38 @@
     }
 
     @Override
+    public Type addCountAndPassIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkPassCounterRange(cnt);  // cnt has to be a pass counter
+        final String skipLabel = getUniqueLabel();  // where execution resumes on a non-match
+        return addJumpIfNoneOf(R0, values, skipLabel).addCountAndPass(cnt).defineLabel(skipLabel);
+    }
+
+    @Override
+    public Type addCountAndDropIfR0IsOneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkDropCounterRange(cnt);  // cnt has to be a drop counter
+        final String skipLabel = getUniqueLabel();  // where execution resumes on a non-match
+        return addJumpIfNoneOf(R0, values, skipLabel).addCountAndDrop(cnt).defineLabel(skipLabel);
+    }
+
+    @Override
+    public Type addCountAndPassIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkPassCounterRange(cnt);  // cnt has to be a pass counter
+        final String skipLabel = getUniqueLabel();  // where execution resumes on a match
+        return addJumpIfOneOf(R0, values, skipLabel).addCountAndPass(cnt).defineLabel(skipLabel);
+    }
+
+    @Override
+    public Type addCountAndDropIfR0IsNoneOf(@NonNull Set<Long> values,
+            ApfCounterTracker.Counter cnt) throws IllegalInstructionException {
+        checkDropCounterRange(cnt);  // cnt has to be a drop counter
+        final String skipLabel = getUniqueLabel();  // where execution resumes on a match
+        return addJumpIfOneOf(R0, values, skipLabel).addCountAndDrop(cnt).defineLabel(skipLabel);
+    }
+
+    @Override
     public final Type addLoadCounter(Register register, ApfCounterTracker.Counter counter)
             throws IllegalInstructionException {
         return append(new Instruction(Opcodes.LDDW, register).addUnsigned(counter.value()));
diff --git a/tests/unit/src/android/net/apf/ApfNewTest.kt b/tests/unit/src/android/net/apf/ApfNewTest.kt
index 3c8e168..620a72f 100644
--- a/tests/unit/src/android/net/apf/ApfNewTest.kt
+++ b/tests/unit/src/android/net/apf/ApfNewTest.kt
@@ -317,6 +317,18 @@
             gen.addCountAndPassIfR0AnyBitsSet(3, DROPPED_ETH_BROADCAST)
         }
         assertFailsWith<IllegalArgumentException> {
+            gen.addCountAndDropIfR0IsOneOf(setOf(3), PASSED_ARP)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            gen.addCountAndPassIfR0IsOneOf(setOf(3), DROPPED_ETH_BROADCAST)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            gen.addCountAndDropIfR0IsNoneOf(setOf(3), PASSED_ARP)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            gen.addCountAndPassIfR0IsNoneOf(setOf(3), DROPPED_ETH_BROADCAST)
+        }
+        assertFailsWith<IllegalArgumentException> {
             gen.addWrite32(byteArrayOf())
         }
         assertFailsWith<IllegalArgumentException> {
@@ -371,6 +383,18 @@
         assertFailsWith<IllegalArgumentException> {
             v4gen.addCountAndPassIfR0AnyBitsSet(3, DROPPED_ETH_BROADCAST)
         }
+        assertFailsWith<IllegalArgumentException> {
+            v4gen.addCountAndDropIfR0IsOneOf(setOf(3), PASSED_ARP)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            v4gen.addCountAndPassIfR0IsOneOf(setOf(3), DROPPED_ETH_BROADCAST)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            v4gen.addCountAndDropIfR0IsNoneOf(setOf(3), PASSED_ARP)
+        }
+        assertFailsWith<IllegalArgumentException> {
+            v4gen.addCountAndPassIfR0IsNoneOf(setOf(3), DROPPED_ETH_BROADCAST)
+        }
     }
 
     @Test
@@ -1041,6 +1065,58 @@
         expectedMap = getInitialMap()
         expectedMap[PASSED_ARP] = 1
         assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndDropIfR0IsOneOf(setOf(123, 124), DROPPED_ETH_BROADCAST)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(APF_VERSION_6, DROP, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[DROPPED_ETH_BROADCAST] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndPassIfR0IsOneOf(setOf(123, 124), PASSED_ARP)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(APF_VERSION_6, PASS, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[PASSED_ARP] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndDropIfR0IsNoneOf(setOf(122, 124), DROPPED_ETH_BROADCAST)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(APF_VERSION_6, DROP, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[DROPPED_ETH_BROADCAST] = 1
+        assertEquals(expectedMap, counterMap)
+
+        program = getGenerator()
+                .addLoadImmediate(R0, 123)
+                .addCountAndPassIfR0IsNoneOf(setOf(122, 124), PASSED_ARP)
+                .addPass()
+                .addCountTrampoline()
+                .generate()
+        dataRegion = ByteArray(Counter.totalSize()) { 0 }
+        assertVerdict(APF_VERSION_6, PASS, program, testPacket, dataRegion)
+        counterMap = decodeCountersIntoMap(dataRegion)
+        expectedMap = getInitialMap()
+        expectedMap[PASSED_ARP] = 1
+        assertEquals(expectedMap, counterMap)
     }
 
     @Test