Snap for 9886856 from 6da82b7951fccd978b158875d4e828d8bed23a60 to mainline-sdkext-release Change-Id: I45a11f64482d3291bd3a987ab6b5cd8a4949c1e0
diff --git a/src/android/net/apf/ApfGenerator.java b/src/android/net/apf/ApfGenerator.java index db51186..ee713c5 100644 --- a/src/android/net/apf/ApfGenerator.java +++ b/src/android/net/apf/ApfGenerator.java
@@ -282,10 +282,6 @@ } // Calculate distance from end of this instruction to instruction.offset. final int targetLabelOffset = targetLabelInstruction.offset - (offset + size()); - if (targetLabelOffset < 0) { - throw new IllegalInstructionException("backward branches disallowed; label: " + - mTargetLabel); - } return targetLabelOffset; }
diff --git a/src/android/net/apf/DnsUtils.java b/src/android/net/apf/DnsUtils.java new file mode 100644 index 0000000..e30ebc7 --- /dev/null +++ b/src/android/net/apf/DnsUtils.java
@@ -0,0 +1,516 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.apf;
+
+import static android.net.apf.ApfGenerator.Register.R0;
+import static android.net.apf.ApfGenerator.Register.R1;
+
+import static com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.UDP_HEADER_LEN;
+
+import androidx.annotation.NonNull;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Objects;
+
+/**
+ * Utility class that generates APF filters for DNS packets.
+ */
+public class DnsUtils {
+
+    /** Length of the DNS header. */
+    private static final int DNS_HEADER_LEN = 12;
+    /** Offset of the qdcount field within the DNS header. */
+    private static final int DNS_QDCOUNT_OFFSET = 4;
+    /** Maximum length of a DNS label, in bytes (RFC 1035 section 2.3.4). */
+    private static final int DNS_MAX_LABEL_LEN = 63;
+    /** Mask that extracts the 14-bit offset from a compression pointer (RFC 1035 4.1.4). */
+    private static final int DNS_POINTER_OFFSET_MASK = 0x3fff;
+
+    // Static labels
+    private static final String LABEL_START_MATCH = "start_match";
+    private static final String LABEL_PARSE_DNS_LABEL = "parse_dns_label";
+    private static final String LABEL_FIND_NEXT_DNS_QUESTION = "find_next_dns_question";
+    private static final String LABEL_MATCH_END_OF_NAME = "dns_match_end_of_name";
+    private static final String LABEL_MATCH_END_OF_NAME_PARSED = "dns_match_end_of_name_parsed";
+    /** Prefix of the per-DNS-label jump labels generated by {@link #getMatchLabel}. */
+    private static final String LABEL_MATCH_PREFIX = "dns_match_label_";
+
+    // Sizes of the parts of a question, used to skip over names and questions.
+    /** Size of a label length byte (and of the zero-length label that terminates a name). */
+    private static final int LABEL_SIZE = Byte.BYTES;
+    /** Size of a compression pointer. */
+    private static final int POINTER_SIZE = Short.BYTES;
+    /** Size of the qtype and qclass fields that follow the name in a question. */
+    private static final int QUESTION_HEADER_SIZE = Short.BYTES + Short.BYTES;
+    private static final int LABEL_AND_QUESTION_HEADER_SIZE = LABEL_SIZE + QUESTION_HEADER_SIZE;
+    private static final int POINTER_AND_QUESTION_HEADER_SIZE = POINTER_SIZE + QUESTION_HEADER_SIZE;
+
+    /** Memory slot that stores the offset within the packet of the DNS header. */
+    private static final int SLOT_DNS_HEADER_OFFSET = 1;
+    /** Memory slot that stores the current parsing offset. */
+    private static final int SLOT_CURRENT_PARSE_OFFSET = 2;
+    /**
+     * Memory slot that stores the offset after the current question, if the code is currently
+     * parsing a pointer, or 0 if it is not.
+     */
+    private static final int SLOT_AFTER_POINTER_OFFSET = 3;
+    /**
+     * Contains qdcount remaining, as a negative number. For example, will be -1 when starting to
+     * parse a DNS packet with one question in it. It's stored as a negative number because adding 1
+     * is much easier than subtracting 1 (which can't be done just by adding -1, because that just
+     * adds 254).
+     */
+    private static final int SLOT_NEGATIVE_QDCOUNT_REMAINING = 6;
+    /** Memory slot that holds the jump table index of the label a subroutine returns to. */
+    private static final int SLOT_RETURN_VALUE_INDEX = 10;
+
+    /**
+     * APF function: parse_dns_label
+     *
+     * Parses a label potentially containing a pointer, and calculates the label length and the
+     * offset of the label data.
+     *
+     * Inputs:
+     * - m[SLOT_DNS_HEADER_OFFSET]: offset of DNS header
+     * - m[SLOT_CURRENT_PARSE_OFFSET]: current parsing offset
+     * - m[SLOT_AFTER_POINTER_OFFSET]: offset after the question (e.g., offset of the next question,
+     *   or offset of the answer section) if a pointer is being chased, 0 otherwise
+     * - m[SLOT_RETURN_VALUE_INDEX]: index into return jump table
+     *
+     * Outputs:
+     * - R1: label length (0 for the zero-length label that terminates a name)
+     * - m[SLOT_CURRENT_PARSE_OFFSET]: offset of the label's length byte; the label text starts at
+     *   the next byte. If pointers were followed, this is the final pointer destination.
+     * - m[SLOT_AFTER_POINTER_OFFSET]: if it was 0 and a pointer was followed, set to the offset
+     *   just past that pointer and the qtype and qclass fields that follow it.
+     *
+     * Drops the packet if the label starts before the DNS header, or if a pointer does not point
+     * strictly backwards (which also guarantees that pointer chasing terminates).
+     */
+    private static void genParseDnsLabel(ApfGenerator gen, JumpTable jumpTable) throws Exception {
+        final String labelParseDnsLabelReal = "parse_dns_label_real";
+        final String labelPointerOffsetStored = "pointer_offset_stored";
+
+        /*
+         * :parse_dns_label
+         * // Load parsing offset.
+         * LDM R1, 2          // R1 = parsing offset. (All indexed loads use R1.)
+         */
+        gen.defineLabel(LABEL_PARSE_DNS_LABEL);
+        gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
+
+        /*
+         * // Check that we're in the DNS packet, i.e., that R1 >= m[SLOT_DNS_HEADER_OFFSET].
+         * LDM R0, 1          // R0 = DNS header offset
+         * JGT R0, R1, DROP   // Bad pointer. Drop.
+         */
+        gen.addLoadFromMemory(R0, SLOT_DNS_HEADER_OFFSET);
+        gen.addJumpIfR0GreaterThanR1(ApfGenerator.DROP_LABEL);
+
+        /*
+         * // Now parse the label.
+         * LDBX R0, [R1+0]    // R0 = label length byte, R1 = parsing offset
+         * AND R0, 0xc0       // Is this a pointer?
+         * JEQ R0, 0, :parse_dns_label_real
+         */
+        gen.addLoad8Indexed(R0, 0);
+        gen.addAnd(0xc0);
+        gen.addJumpIfR0Equals(0, labelParseDnsLabelReal);
+
+        /*
+         * // If we're not already chasing a pointer, store offset after pointer into
+         * // m[SLOT_AFTER_POINTER_OFFSET].
+         * LDM R0, 3          // R0 = previous offset after pointer
+         * JNE R0, 0, :pointer_offset_stored
+         * MOV R0, R1         // R0 = R1
+         * ADD R0, 6          // R0 = offset after pointer, qtype and qclass
+         * STM R0, 3          // Store offset after pointer
+         */
+        gen.addLoadFromMemory(R0, SLOT_AFTER_POINTER_OFFSET);
+        gen.addJumpIfR0NotEquals(0, labelPointerOffsetStored);
+        gen.addMove(R0);
+        gen.addAdd(POINTER_AND_QUESTION_HEADER_SIZE);
+        gen.addStoreToMemory(R0, SLOT_AFTER_POINTER_OFFSET);
+
+        /*
+         * :pointer_offset_stored
+         * LDHX R0, [R1+0]    // R0 = 2-byte pointer value
+         * AND R0, 0x3fff     // R0 = pointer destination offset (from DNS header)
+         * LDM R1, 1          // R1 = offset in packet of DNS header
+         * ADD R0, R1         // R0 = pointer destination offset in packet
+         * LDM R1, 2          // R1 = current parsing offset
+         * JEQ R0, R1, DROP   // Drop if pointer points here...
+         * JGT R0, R1, DROP   // ... or after here (must point backwards)
+         * STM R0, 2          // Set next parsing offset to pointer destination
+         */
+        gen.defineLabel(labelPointerOffsetStored);
+        gen.addLoad16Indexed(R0, 0);
+        // The offset is the low 14 bits of the 16-bit pointer (RFC 1035 section 4.1.4).
+        gen.addAnd(DNS_POINTER_OFFSET_MASK);
+        gen.addLoadFromMemory(R1, SLOT_DNS_HEADER_OFFSET);
+        gen.addAddR1();
+        gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
+        gen.addJumpIfR0EqualsR1(ApfGenerator.DROP_LABEL);
+        gen.addJumpIfR0GreaterThanR1(ApfGenerator.DROP_LABEL);
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+
+        /*
+         * // Pointer chased. Parse starting from the pointer destination (which may also be a
+         * // pointer).
+         * JMP :parse_dns_label
+         */
+        gen.addJump(LABEL_PARSE_DNS_LABEL);
+
+        /*
+         * :parse_dns_label_real
+         * // This is where the real (non-pointer) label starts.
+         * // Load label length into R1, and return to caller.
+         * // m[SLOT_CURRENT_PARSE_OFFSET] already contains the offset of the length byte.
+         * LDBX R1, [R1+0]    // R1 = label length
+         */
+        gen.defineLabel(labelParseDnsLabelReal);
+        gen.addLoad8Indexed(R1, 0);
+
+        /*
+         * // Return
+         * LDM R0, 10
+         * JMP :jump_table
+         */
+        gen.addLoadFromMemory(R0, SLOT_RETURN_VALUE_INDEX);
+        gen.addJump(jumpTable.getStartLabel());
+    }
+
+    /**
+     * APF function: find_next_dns_question
+     *
+     * Finds the next question in the question section, or drops the packet if there is none.
+     *
+     * Inputs:
+     * - m[SLOT_CURRENT_PARSE_OFFSET]: offset of a label within the name of the current question
+     * - m[SLOT_AFTER_POINTER_OFFSET]: offset after first pointer in name, or 0 if not chasing a
+     *   pointer
+     * - m[SLOT_NEGATIVE_QDCOUNT_REMAINING]: qdcount remaining, as a negative number. This is
+     *   because adding 1 is much easier than subtracting 1 (which can't be done just by
+     *   adding -1, because that just adds 254)
+     * - m[SLOT_RETURN_VALUE_INDEX]: index into return jump table
+     *
+     * Outputs:
+     * - m[SLOT_CURRENT_PARSE_OFFSET]: offset of the next question
+     * - m[SLOT_AFTER_POINTER_OFFSET]: 0
+     * - m[SLOT_NEGATIVE_QDCOUNT_REMAINING]: incremented by 1
+     */
+    private static void genFindNextDnsQuestion(ApfGenerator gen, JumpTable jumpTable)
+            throws Exception {
+        final String labelFindNextDnsQuestionFollow = "find_next_dns_question_follow";
+        final String labelFindNextDnsQuestionLabel = "find_next_dns_question_label";
+        final String labelFindNextDnsQuestionLoop = "find_next_dns_question_loop";
+        final String labelFindNextDnsQuestionNoPointer = "find_next_dns_question_no_pointer";
+        final String labelFindNextDnsQuestionReturn = "find_next_dns_question_return";
+
+        // Function entry point.
+        gen.defineLabel(LABEL_FIND_NEXT_DNS_QUESTION);
+
+        // Are we chasing a pointer?
+        gen.addLoadFromMemory(R0, SLOT_AFTER_POINTER_OFFSET);
+        gen.addJumpIfR0Equals(0, labelFindNextDnsQuestionFollow);
+
+        // If so, offset after the pointer and question is stored in m[SLOT_AFTER_POINTER_OFFSET].
+        // Move parsing offset there, clear m[SLOT_AFTER_POINTER_OFFSET], and return.
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+        gen.addLoadImmediate(R0, 0);
+        gen.addStoreToMemory(R0, SLOT_AFTER_POINTER_OFFSET);
+        gen.addJump(labelFindNextDnsQuestionReturn);
+
+        // We weren't chasing a pointer. Loop, following the label chain, until we reach a
+        // zero-length label or a pointer. At the beginning of the loop, the current parsing offset
+        // is m[SLOT_CURRENT_PARSE_OFFSET]. Move it to R1 and keep it in R1 throughout the loop.
+        gen.defineLabel(labelFindNextDnsQuestionFollow);
+        gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET);
+
+        // Load label length.
+        gen.defineLabel(labelFindNextDnsQuestionLoop);
+        gen.addLoad8Indexed(R0, 0);
+        // Is it a pointer?
+        gen.addAnd(0xc0);
+        gen.addJumpIfR0Equals(0, labelFindNextDnsQuestionNoPointer);
+        // It's a pointer. Skip the pointer and question, and return.
+        gen.addLoadImmediate(R0, POINTER_AND_QUESTION_HEADER_SIZE);
+        gen.addAddR1();
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+        gen.addJump(labelFindNextDnsQuestionReturn);
+
+        // R1 still contains parsing offset. Reload the length byte, since R0 was masked above.
+        gen.defineLabel(labelFindNextDnsQuestionNoPointer);
+        gen.addLoad8Indexed(R0, 0);
+
+        // Zero-length label? We're done.
+        // Skip the label (1 byte) and query (2 bytes qtype, 2 bytes qclass) and return.
+        gen.addJumpIfR0NotEquals(0, labelFindNextDnsQuestionLabel);
+        gen.addLoadImmediate(R0, LABEL_AND_QUESTION_HEADER_SIZE);
+        gen.addAddR1();
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+        gen.addJump(labelFindNextDnsQuestionReturn);
+
+        // Non-zero length label. Consume it and continue.
+        gen.defineLabel(labelFindNextDnsQuestionLabel);
+        gen.addAdd(LABEL_SIZE); // R0 = label length + length byte
+        gen.addAddR1();         // R0 = offset of the next label
+        gen.addMove(R1);        // R1 = R0
+        gen.addJump(labelFindNextDnsQuestionLoop);
+
+        gen.defineLabel(labelFindNextDnsQuestionReturn);
+
+        // Is this the last question? If so, drop.
+        gen.addLoadFromMemory(R0, SLOT_NEGATIVE_QDCOUNT_REMAINING);
+        gen.addAdd(1);
+        gen.addStoreToMemory(R0, SLOT_NEGATIVE_QDCOUNT_REMAINING);
+        gen.addJumpIfR0Equals(0, ApfGenerator.DROP_LABEL);
+
+        // If not, return.
+        gen.addJump(jumpTable.getStartLabel());
+    }
+
+    /** Returns the jump label at which matching of the DNS label at {@code index} starts. */
+    private static String getMatchLabel(int index) {
+        return LABEL_MATCH_PREFIX + index;
+    }
+
+    /**
+     * Returns the name of the jump label that parse_dns_label returns to while matching the DNS
+     * label at {@code index}.
+     *
+     * Jump labels are named by position rather than by content, so that names containing
+     * duplicate labels (e.g., "a.b.a"), or labels that happen to equal a static jump label such
+     * as "start_match", still produce unique jump labels.
+     */
+    private static String getPostMatchJumpTargetForLabel(int index) {
+        return getMatchLabel(index) + "_parsed";
+    }
+
+    /**
+     * Generates code that matches one DNS label of the name being searched for.
+     *
+     * @param label the DNS label to match; its UTF-8 encoding is compared byte-for-byte
+     * @param index position of the label within the name, used to name the jump labels
+     * @param nextLabel jump label to go to if this DNS label matches
+     */
+    private static void addMatchLabel(@NonNull ApfGenerator gen, @NonNull JumpTable jumpTable,
+            @NonNull String label, int index, @NonNull String nextLabel) throws Exception {
+        final byte[] labelBytes = label.getBytes(StandardCharsets.UTF_8);
+        final String parsedLabel = getPostMatchJumpTargetForLabel(index);
+        final String noMatchLabel = getMatchLabel(index) + "_nomatch";
+        gen.defineLabel(getMatchLabel(index));
+
+        // Store return address.
+        gen.addLoadImmediate(R0, jumpTable.getIndex(parsedLabel));
+        gen.addStoreToMemory(R0, SLOT_RETURN_VALUE_INDEX);
+
+        // Call the parse_dns_label function.
+        gen.addJump(LABEL_PARSE_DNS_LABEL);
+
+        gen.defineLabel(parsedLabel);
+
+        // If label length is 0, this is the end of the name and the match failed.
+        gen.addSwap(); // Move label length from R1 to R0
+        gen.addJumpIfR0Equals(0, noMatchLabel);
+
+        // Label parsed, check it matches what we're looking for. The length on the wire is in
+        // bytes, so compare it against the UTF-8 length, not String#length() (UTF-16 units).
+        gen.addJumpIfR0NotEquals(labelBytes.length, noMatchLabel);
+        gen.addLoadFromMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+        gen.addAdd(LABEL_SIZE); // Skip the length byte: R0 = offset of the label text.
+        gen.addJumpIfBytesNotEqual(R0, labelBytes, noMatchLabel);
+
+        // Prep offset of next label.
+        gen.addAdd(labelBytes.length);
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+
+        // Match, go to next label.
+        gen.addJump(nextLabel);
+
+        // Match failed. Go to next question, and restart from the first match.
+        gen.defineLabel(noMatchLabel);
+        addRestartAtNextQuestion(gen, jumpTable);
+    }
+
+    /**
+     * Generates code that gives up on the current question: it calls find_next_dns_question
+     * (which drops the packet if there are no more questions) and arranges for that function to
+     * return to the start of the matching code.
+     */
+    private static void addRestartAtNextQuestion(@NonNull ApfGenerator gen,
+            @NonNull JumpTable jumpTable) throws Exception {
+        gen.addLoadImmediate(R1, jumpTable.getIndex(LABEL_START_MATCH));
+        gen.addStoreToMemory(R1, SLOT_RETURN_VALUE_INDEX);
+        gen.addJump(LABEL_FIND_NEXT_DNS_QUESTION);
+    }
+
+    /**
+     * Generates code that runs after all DNS labels matched. It passes the packet if the name ends
+     * there, and otherwise moves on to the next question. Without this check, a question for a
+     * longer name that merely starts with the requested labels (e.g., "googlecast.tcp.local.foo"
+     * when matching "googlecast.tcp.local") would also be accepted.
+     */
+    private static void addMatchEndOfName(@NonNull ApfGenerator gen,
+            @NonNull JumpTable jumpTable) throws Exception {
+        gen.defineLabel(LABEL_MATCH_END_OF_NAME);
+
+        // Call parse_dns_label to get the length of the label after the last matched one.
+        gen.addLoadImmediate(R0, jumpTable.getIndex(LABEL_MATCH_END_OF_NAME_PARSED));
+        gen.addStoreToMemory(R0, SLOT_RETURN_VALUE_INDEX);
+        gen.addJump(LABEL_PARSE_DNS_LABEL);
+
+        gen.defineLabel(LABEL_MATCH_END_OF_NAME_PARSED);
+
+        // A zero-length label terminates the name, so the question is exactly for our name.
+        gen.addSwap(); // Move label length from R1 to R0
+        gen.addJumpIfR0Equals(0, ApfGenerator.PASS_LABEL);
+
+        // The name continues past the labels we matched. Try the next question.
+        addRestartAtNextQuestion(gen, jumpTable);
+    }
+
+    /**
+     * Checks that the labels can form a DNS name to match.
+     *
+     * @throws IllegalArgumentException if {@code labels} is empty, or if a label is empty or
+     *     longer than {@link #DNS_MAX_LABEL_LEN} bytes when encoded as UTF-8
+     */
+    private static void validateLabels(@NonNull String[] labels) {
+        Objects.requireNonNull(labels, "labels");
+        if (labels.length == 0) {
+            throw new IllegalArgumentException("No labels to match");
+        }
+        for (String label : labels) {
+            Objects.requireNonNull(label, "label");
+            final int len = label.getBytes(StandardCharsets.UTF_8).length;
+            if (len == 0 || len > DNS_MAX_LABEL_LEN) {
+                throw new IllegalArgumentException("Invalid DNS label \"" + label + "\": length "
+                        + len + " bytes, must be between 1 and " + DNS_MAX_LABEL_LEN);
+            }
+        }
+    }
+
+    /**
+     * Generates a filter that accepts DNS packets that ask for the specified name.
+     *
+     * The filter supports compressed DNS names and scanning through multiple questions in the same
+     * packet, e.g., as used by MDNS. However, it currently only supports one DNS name. A question
+     * matches only if its name consists of exactly the specified labels, in order.
+     *
+     * Limitations:
+     * - Filter size is roughly 300 bytes for a typical question.
+     * - Because the bytecode extensively uses backwards jumps, it can hit the APF interpreter
+     *   instruction limit. This limit causes the APF interpreter to accept the packet once it has
+     *   executed a number of instructions equal to the program length in bytes.
+     *   A program that consists *only* of this filter will be able to execute roughly 300
+     *   instructions, and will be able to correctly drop packets with two questions but not three
+     *   questions. In a real APF setup, there will be other code (e.g., RA filtering) which counts
+     *   against the limit, so the filter should be able to parse packets with more questions.
+     * - Matches are case-sensitive. This is due to the use of JNEBS to match DNS labels and is
+     *   likely impossible to overcome without interpreter changes.
+     *
+     * TODO:
+     * - Add unit tests for the parse_dns_label and find_next_dns_question functions.
+     * - Add an efficient way to parse the first question in the packet. This can be done much more
+     *   efficiently because the first name cannot be compressed.
+     * - Support accepting more than one name.
+     * - For devices where power saving is a priority (e.g., flat panel TVs), add support for
+     *   dropping packets with more than X queries, to ensure the filter will drop the packet rather
+     *   than hit the instruction limit.
+     *
+     * @param gen the generator to append the filter to
+     * @param ipv6 true if the packet is IPv6, false if it is IPv4 (in which case the IP header
+     *     length is read from {@link ApfGenerator#IPV4_HEADER_SIZE_MEMORY_SLOT})
+     * @param labels the labels of the name to match, e.g., "googlecast", "tcp", "local"
+     * @throws IllegalArgumentException if {@code labels} is empty, or if a label is empty or
+     *     longer than 63 bytes when encoded as UTF-8
+     * @throws Exception if the generator rejects the generated code
+     */
+    public static void generateFilter(ApfGenerator gen, boolean ipv6, String[] labels)
+            throws Exception {
+        validateLabels(labels);
+
+        final int etherPlusUdpLen = ETHER_HEADER_LEN + UDP_HEADER_LEN;
+
+        final String labelJumpTable = "jump_table";
+
+        // Initialize parsing:
+        // - m[SLOT_DNS_HEADER_OFFSET]: offset of DNS header
+        // - m[SLOT_CURRENT_PARSE_OFFSET]: current parsing offset (start of question section)
+        // - m[SLOT_AFTER_POINTER_OFFSET]: offset after first pointer in name, must be 0 when
+        //   starting a new name
+        // - m[SLOT_NEGATIVE_QDCOUNT_REMAINING]: negative qdcount
+        if (ipv6) {
+            gen.addLoadImmediate(R0, IPV6_HEADER_LEN);
+        } else {
+            gen.addLoadFromMemory(R0, ApfGenerator.IPV4_HEADER_SIZE_MEMORY_SLOT);
+        }
+        gen.addAdd(etherPlusUdpLen);
+        gen.addStoreToMemory(R0, SLOT_DNS_HEADER_OFFSET);
+
+        gen.addAdd(DNS_QDCOUNT_OFFSET);
+        gen.addMove(R1);
+        gen.addLoad16Indexed(R1, 0);
+        gen.addNeg(R1);
+        gen.addStoreToMemory(R1, SLOT_NEGATIVE_QDCOUNT_REMAINING);
+
+        gen.addAdd(DNS_HEADER_LEN - DNS_QDCOUNT_OFFSET);
+        gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET);
+
+        gen.addLoadImmediate(R0, 0);
+        gen.addStoreToMemory(R0, SLOT_AFTER_POINTER_OFFSET);
+
+        gen.addJump(LABEL_START_MATCH);
+
+        // Create the jump table that the subroutines below use to return to their callers.
+        final JumpTable table = new JumpTable(labelJumpTable, SLOT_RETURN_VALUE_INDEX);
+
+        // Generate bytecode for the parse_dns_label and find_next_dns_question functions.
+        genParseDnsLabel(gen, table);
+        genFindNextDnsQuestion(gen, table);
+
+        // Populate jump table. It must be generated before the code it returns to (i.e., the
+        // addMatchLabel code below), because otherwise all its jumps are backwards, and backwards
+        // jumps are more expensive (5 bytes of bytecode).
+        for (int i = 0; i < labels.length; i++) {
+            table.addLabel(getPostMatchJumpTargetForLabel(i));
+        }
+        table.addLabel(LABEL_MATCH_END_OF_NAME_PARSED);
+        table.addLabel(LABEL_START_MATCH);
+        table.generate(gen);
+
+        // Add match statements for name.
+        gen.defineLabel(LABEL_START_MATCH);
+        for (int i = 0; i < labels.length; i++) {
+            final String nextLabel = (i == labels.length - 1)
+                    ? LABEL_MATCH_END_OF_NAME
+                    : getMatchLabel(i + 1);
+            addMatchLabel(gen, table, labels[i], i, nextLabel);
+        }
+        addMatchEndOfName(gen, table);
+
+        // Not reachable, since every path above ends in a jump; kept as a safety net so that
+        // execution can never fall through into code the caller appends after this filter.
+        gen.addJump(ApfGenerator.DROP_LABEL);
+    }
+
+    private DnsUtils() {
+    }
+}
diff --git a/src/android/net/apf/JumpTable.java b/src/android/net/apf/JumpTable.java new file mode 100644 index 0000000..b449697 --- /dev/null +++ b/src/android/net/apf/JumpTable.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.apf;
+
+import static android.net.apf.ApfGenerator.Register.R0;
+
+import androidx.annotation.NonNull;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+
+/**
+ * A table that stores program labels to jump to.
+ *
+ * <p>This is needed to implement subroutines because APF jump targets must be known at compile
+ * time and cannot be computed dynamically.
+ *
+ * <p>At compile time, any code that calls a subroutine must:
+ *
+ * <ul>
+ * <li>Define a label (via {@link ApfGenerator#defineLabel}) immediately after the code that invokes
+ *     the subroutine.
+ * <li>Add the label to the jump table using {@link #addLabel}.
+ * <li>Generate the jump table in the program.
+ * </ul>
+ *
+ * <p>At runtime, before invoking the subroutine, the APF code must store the index of the return
+ * label (obtained via {@link #getIndex}) into the jump table's return address memory slot, and then
+ * jump to the subroutine. To return to the caller, the subroutine must jump to the label returned
+ * by {@link #getStartLabel}, and the jump table will then jump to the return label.
+ *
+ * <p>Implementation details:
+ * <ul>
+ * <li>The jumps are added to the program in the same order as the labels were added.
+ * <li>Using the jump table will overwrite the value of register R0.
+ * <li>If, before calling a subroutine, the APF code stores a nonexistent return label index, then
+ *     the jump table will pass the packet. This cannot happen if the code correctly obtains the
+ *     label using {@link #getIndex}, as that would throw an exception when generating the program.
+ * </ul>
+ *
+ * <p>For example:
+ * <pre>
+ *     JumpTable t = new JumpTable("my_jump_table", 7);
+ *     t.addLabel("jump_1");
+ *     ...
+ *     t.addLabel("after_parsing");
+ *     ...
+ *     t.addLabel("after_subroutine");
+ *     t.generate(gen);
+ * </pre>
+ * generates the following APF code:
+ * <pre>
+ *     :my_jump_table
+ *     ldm r0, 7
+ *     jeq r0, 0, jump_1
+ *     jeq r0, 1, after_parsing
+ *     jeq r0, 2, after_subroutine
+ *     jmp PASS
+ * </pre>
+ */
+public class JumpTable {
+    /** Maps labels to jump indices. LinkedHashMap iterates in insertion order. */
+    private final Map<String, Integer> mJumpLabels = new LinkedHashMap<>();
+    /** Label to jump to in order to execute this jump table. */
+    private final String mStartLabel;
+    /** Memory slot that contains the return value index. */
+    private final int mReturnAddressMemorySlot;
+
+    /** Index to assign to the next label added via {@link #addLabel}. */
+    private int mIndex = 0;
+
+    /**
+     * Creates an empty jump table.
+     *
+     * @param startLabel label that {@link #generate} defines at the start of the table
+     * @param returnAddressMemorySlot memory slot that holds the index of the label to return to;
+     *     must be at least 0 and less than {@link ApfGenerator#FIRST_PREFILLED_MEMORY_SLOT}
+     * @throws NullPointerException if {@code startLabel} is null
+     * @throws IllegalArgumentException if {@code returnAddressMemorySlot} is out of range
+     */
+    public JumpTable(@NonNull String startLabel, int returnAddressMemorySlot) {
+        mStartLabel = Objects.requireNonNull(startLabel, "startLabel");
+        if (returnAddressMemorySlot < 0
+                || returnAddressMemorySlot >= ApfGenerator.FIRST_PREFILLED_MEMORY_SLOT) {
+            throw new IllegalArgumentException("Invalid memory slot " + returnAddressMemorySlot);
+        }
+        mReturnAddressMemorySlot = returnAddressMemorySlot;
+    }
+
+    /** Returns the label that code must jump to in order to execute the table. */
+    @NonNull
+    public String getStartLabel() {
+        return mStartLabel;
+    }
+
+    /**
+     * Adds a jump label to this table. Passing a label that was already added is not an error;
+     * the label keeps the index it was first given.
+     *
+     * @param label the label to add
+     * @throws NullPointerException if {@code label} is null
+     */
+    public void addLabel(@NonNull String label) {
+        Objects.requireNonNull(label, "label");
+        if (mJumpLabels.putIfAbsent(label, mIndex) == null) mIndex++;
+    }
+
+    /**
+     * Gets the index of a previously-added label.
+     *
+     * @param label the label to look up
+     * @return the label's index.
+     * @throws NoSuchElementException if the label was never added.
+     */
+    public int getIndex(@NonNull String label) {
+        final Integer index = mJumpLabels.get(label);
+        if (index == null) throw new NoSuchElementException("Unknown label " + label);
+        return index;
+    }
+
+    /**
+     * Generates APF code for this jump table: defines the start label, loads the return index
+     * into R0, and jumps to the label with that index. If no label matches, passes the packet.
+     *
+     * @param gen the generator to append the code to
+     * @throws ApfGenerator.IllegalInstructionException if the generator rejects an instruction
+     */
+    public void generate(@NonNull ApfGenerator gen)
+            throws ApfGenerator.IllegalInstructionException {
+        gen.defineLabel(mStartLabel);
+        gen.addLoadFromMemory(R0, mReturnAddressMemorySlot);
+        for (Map.Entry<String, Integer> e : mJumpLabels.entrySet()) {
+            gen.addJumpIfR0Equals(e.getValue(), e.getKey());
+        }
+        // Cannot happen unless the program is malformed (i.e., the APF code loads an invalid return
+        // label index before jumping to the subroutine).
+        gen.addJump(ApfGenerator.PASS_LABEL);
+    }
+}
diff --git a/tests/unit/src/android/net/apf/ApfTest.java b/tests/unit/src/android/net/apf/ApfTest.java index f0747b1..46c36a0 100644 --- a/tests/unit/src/android/net/apf/ApfTest.java +++ b/tests/unit/src/android/net/apf/ApfTest.java
@@ -100,6 +100,7 @@ import java.net.Inet6Address; import java.net.InetAddress; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; import java.util.Random; @@ -189,7 +190,7 @@ final String msg = "Unexpected APF verdict. To debug:\n" + " apf_run --program " + HexDump.toHexString(program) + " --packet " + HexDump.toHexString(packet) + " --trace | less\n "; - assertReturnCodesEqual(expected, apfSimulate(program, packet, null, filterAge)); + assertReturnCodesEqual(msg, expected, apfSimulate(program, packet, null, filterAge)); } private void assertVerdict(int expected, byte[] program, byte[] packet) { @@ -1164,8 +1165,6 @@ private static final byte[] IPV4_MDNS_MULTICAST_ADDR = {(byte) 224, 0, 0, (byte) 251}; private static final byte[] IPV6_MDNS_MULTICAST_ADDR = {(byte) 0xff, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, (byte) 0xfb}; - private static final int DNS_HEADER_LEN = 12; - private static final int DNS_QDCOUNT_OFFSET = 4; private static final int IPV6_UDP_DEST_PORT_OFFSET = IPV6_PAYLOAD_OFFSET + 2; private static final int MDNS_UDP_PORT = 5353; @@ -1304,8 +1303,17 @@ apfFilter.shutdown(); } - private static byte[] makeMdnsV4Packet(String qname) throws IOException { - final ByteBuffer buf = ByteBuffer.wrap(new byte[100]); + private static void fillQuestionSection(ByteBuffer buf, String... qnames) throws IOException { + buf.put(new DnsPacket.DnsHeader(0 /* id */, 0 /* flags */, qnames.length, 0 /* ancount */) + .getBytes()); + for (String qname : qnames) { + buf.put(DnsPacket.DnsRecord.makeQuestion(qname, 0 /* nsType */, 0 /* nsClass */) + .getBytes()); + } + } + + private static byte[] makeMdnsV4Packet(String... 
qnames) throws IOException { + final ByteBuffer buf = ByteBuffer.wrap(new byte[256]); final PacketBuilder builder = new PacketBuilder(buf); builder.writeL2Header(MacAddress.fromString("11:22:33:44:55:66"), MacAddress.fromBytes(ETH_MULTICAST_MDNS_v4_MAC_ADDRESS), @@ -1315,13 +1323,12 @@ (Inet4Address) Inet4Address.getByAddress(IPV4_SOURCE_ADDR), (Inet4Address) Inet4Address.getByAddress(IPV4_MDNS_MULTICAST_ADDR)); builder.writeUdpHeader((short) MDNS_UDP_PORT, (short) MDNS_UDP_PORT); - buf.put(new DnsPacket.DnsHeader(0, 0, 1, 0).getBytes()); - buf.put(DnsPacket.DnsRecord.makeQuestion(qname, 0, 0).getBytes()); + fillQuestionSection(buf, qnames); return builder.finalizePacket().array(); } - private static byte[] makeMdnsV6Packet(String qname) throws IOException { - ByteBuffer buf = ByteBuffer.wrap(new byte[100]); + private static byte[] makeMdnsV6Packet(String... qnames) throws IOException { + ByteBuffer buf = ByteBuffer.wrap(new byte[256]); final PacketBuilder builder = new PacketBuilder(buf); builder.writeL2Header(MacAddress.fromString("11:22:33:44:55:66"), MacAddress.fromBytes(ETH_MULTICAST_MDNS_V6_MAC_ADDRESS), @@ -1330,8 +1337,61 @@ (Inet6Address) InetAddress.getByAddress(IPV6_ANOTHER_ADDR), (Inet6Address) Inet6Address.getByAddress(IPV6_MDNS_MULTICAST_ADDR)); builder.writeUdpHeader((short) MDNS_UDP_PORT, (short) MDNS_UDP_PORT); - buf.put(new DnsPacket.DnsHeader(0, 0, 1, 0).getBytes()); - buf.put(DnsPacket.DnsRecord.makeQuestion(qname, 0, 0).getBytes()); + fillQuestionSection(buf, qnames); + return builder.finalizePacket().array(); + } + + private static void putLabel(ByteBuffer buf, String label) { + final byte[] bytes = label.getBytes(StandardCharsets.UTF_8); + buf.put((byte) bytes.length); + buf.put(bytes); + } + + private static void putPointer(ByteBuffer buf, int offset) { + short pointer = (short) (offset | 0xc000); + buf.putShort(pointer); + } + + private static byte[] makeMdnsCompressedV6Packet() throws IOException { + ByteBuffer buf = ByteBuffer.wrap(new 
byte[256]); + final PacketBuilder builder = new PacketBuilder(buf); + builder.writeL2Header(MacAddress.fromString("11:22:33:44:55:66"), + MacAddress.fromBytes(ETH_MULTICAST_MDNS_V6_MAC_ADDRESS), + (short) ETH_P_IPV6); + builder.writeIpv6Header(0x680515ca /* vtf */, (byte) IPPROTO_UDP, (short) 0 /* hopLimit */, + (Inet6Address) InetAddress.getByAddress(IPV6_ANOTHER_ADDR), + (Inet6Address) Inet6Address.getByAddress(IPV6_MDNS_MULTICAST_ADDR)); + builder.writeUdpHeader((short) MDNS_UDP_PORT, (short) MDNS_UDP_PORT); + + ByteBuffer questions = ByteBuffer.allocate(128); + questions.put(new DnsPacket.DnsHeader(123, 0, 4, 0).getBytes()); + + // myservice.tcp.local + putLabel(questions, "myservice"); + final int offsetTcpLocal = questions.position(); + putLabel(questions, "tcp"); + final int offsetLocal = questions.position(); + putLabel(questions, "local"); + putLabel(questions, ""); + questions.put(new byte[4]); + + // googlecast.tcp.local + putLabel(questions, "googlecast"); + putPointer(questions, offsetTcpLocal); + questions.put(new byte[4]); + + // matter.tcp.local + putLabel(questions, "matter"); + putPointer(questions, offsetTcpLocal); + questions.put(new byte[4]); + + // myhostname.local + putLabel(questions, "myhostname"); + putPointer(questions, offsetLocal); + questions.put(new byte[4]); + + buf.put(questions.array()); + return builder.finalizePacket().array(); } @@ -1397,6 +1457,66 @@ apfFilter.shutdown(); } + private void doTestDnsParsing(boolean expectPass, boolean ipv6, String filterName, + byte[] pkt) throws Exception { + ApfGenerator gen = new ApfGenerator(MIN_APF_VERSION); + final String[] labels = filterName.split(/*regex=*/ "[.]"); + DnsUtils.generateFilter(gen, ipv6, labels); + + // Hack to prevent the APF instruction limit triggering. 
+ for (int i = 0; i < 500; i++) { + gen.addOr(0); + } + + byte[] program = gen.generate(); + Log.d(TAG, "prog_len=" + program.length); + if (expectPass) { + assertPass(program, pkt, 0); + } else { + assertDrop(program, pkt, 0); + } + } + + private void doTestDnsParsing(boolean expectPass, boolean ipv6, String filterName, + String... packetNames) throws Exception { + final byte[] pkt = ipv6 ? makeMdnsV6Packet(packetNames) : makeMdnsV4Packet(packetNames); + doTestDnsParsing(expectPass, ipv6, filterName, pkt); + } + + @Test + public void testDnsParsing() throws Exception { + final boolean ipv4 = false, ipv6 = true; + + // Packets with one question. + doTestDnsParsing(true, ipv6, "googlecast.tcp.local", "googlecast.tcp.local"); + doTestDnsParsing(true, ipv4, "googlecast.tcp.local", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv6, "googlecast.tcp.lozal", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv4, "googlecast.tcp.lozal", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv6, "googlecast.udp.local", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv4, "googlecast.udp.local", "googlecast.tcp.local"); + + // Packets with multiple questions that can't be compressed. Not realistic for MDNS since + // everything ends in .local, but useful to ensure only the non-compression code is tested. 
+ doTestDnsParsing(true, ipv6, "googlecast.tcp.local", + "googlecast.tcp.local", "developer.android.com"); + doTestDnsParsing(true, ipv4, "googlecast.tcp.local", + "developer.android.com", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv4, "googlecast.tcp.local", + "developer.android.com", "googlecast.tcp.invalid"); + doTestDnsParsing(true, ipv6, "googlecast.tcp.local", + "developer.android.com", "www.google.co.jp", "googlecast.tcp.local"); + doTestDnsParsing(false, ipv4, "veryverylongservicename.tcp.local", + "www.google.co.jp", "veryverylongservicename.tcp.invalid"); + doTestDnsParsing(true, ipv6, "googlecast.tcp.local", + "www.google.co.jp", "googlecast.tcp.local", "developer.android.com"); + + final byte[] pkt = makeMdnsCompressedV6Packet(); + doTestDnsParsing(true, ipv6, "googlecast.tcp.local", pkt); + doTestDnsParsing(true, ipv6, "matter.tcp.local", pkt); + doTestDnsParsing(true, ipv6, "myservice.tcp.local", pkt); + doTestDnsParsing(false, ipv6, "otherservice.tcp.local", pkt); + } + @Test public void testApfFilterMulticast() throws Exception { final byte[] unicastIpv4Addr = {(byte)192,0,2,63};
diff --git a/tests/unit/src/android/net/apf/JumpTableTest.kt b/tests/unit/src/android/net/apf/JumpTableTest.kt new file mode 100644 index 0000000..6fdf38f --- /dev/null +++ b/tests/unit/src/android/net/apf/JumpTableTest.kt
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.apf
+
+import androidx.test.filters.SmallTest
+import androidx.test.runner.AndroidJUnit4
+import com.android.testutils.assertThrows
+import java.util.NoSuchElementException
+import java.util.concurrent.atomic.AtomicReference
+import kotlin.test.assertEquals
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mock
+import org.mockito.Mockito.inOrder
+import org.mockito.MockitoAnnotations
+
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class JumpTableTest {
+
+    @Mock
+    lateinit var gen: ApfGenerator
+
+    @Before
+    fun setUp() {
+        MockitoAnnotations.initMocks(this)
+    }
+
+    @Test(expected = NullPointerException::class)
+    fun testNullStartLabel() {
+        // Can't use "null" because the method is @NonNull.
+        JumpTable(AtomicReference<String>(null).get(), 10)
+    }
+
+    @Test(expected = IllegalArgumentException::class)
+    fun testNegativeSlot() {
+        JumpTable("my_jump_table", -1)
+    }
+
+    @Test(expected = IllegalArgumentException::class)
+    fun testSlotTooLarge() {
+        // JumpTable rejects slots starting at the first prefilled slot.
+        JumpTable("my_jump_table", ApfGenerator.FIRST_PREFILLED_MEMORY_SLOT)
+    }
+
+    @Test
+    fun testValidSlotNumbers() {
+        JumpTable("my_jump_table", 0)
+        JumpTable("my_jump_table", 1)
+        JumpTable("my_jump_table", 10)
+        JumpTable("my_jump_table", ApfGenerator.FIRST_PREFILLED_MEMORY_SLOT - 1)
+    }
+
+    @Test
+    fun testGetStartLabel() {
+        assertEquals("xyz", JumpTable("xyz", 3).startLabel)
+        assertEquals("abc", JumpTable("abc", 9).startLabel)
+    }
+
+    @Test
+    fun testCodeGeneration() {
+        val name = "my_jump_table"
+        val slot = 7
+
+        val j = JumpTable(name, slot)
+        j.addLabel("foo")
+        j.addLabel("bar")
+        j.addLabel("bar")
+        j.addLabel("baz")
+
+        assertEquals(0, j.getIndex("foo"))
+        assertEquals(1, j.getIndex("bar"))
+        assertEquals(2, j.getIndex("baz"))
+
+        assertThrows(NoSuchElementException::class.java) {
+            j.getIndex("nonexistent")
+        }
+
+        val inOrder = inOrder(gen)
+
+        j.generate(gen)
+
+        inOrder.verify(gen).defineLabel(name)
+        inOrder.verify(gen).addLoadFromMemory(ApfGenerator.Register.R0, slot)
+        inOrder.verify(gen).addJumpIfR0Equals(0, "foo")
+        inOrder.verify(gen).addJumpIfR0Equals(1, "bar")
+        inOrder.verify(gen).addJumpIfR0Equals(2, "baz")
+        inOrder.verify(gen).addJump(ApfGenerator.PASS_LABEL)
+        inOrder.verifyNoMoreInteractions()
+    }
+}