Snap for 11670064 from b51ad95232e18eb4813f83b5aeb63877499b8111 to mainline-cellbroadcast-release Change-Id: Ied22ed8d67f908d0573e08cd1cf67c2b0c5d2b51
diff --git a/Android.bp b/Android.bp index 9af9da4..8910371 100644 --- a/Android.bp +++ b/Android.bp
@@ -315,7 +315,10 @@ srcs: [ ":framework-networkstack-shared-srcs", ], - libs: ["unsupportedappusage"], + libs: [ + "error_prone_annotations", + "unsupportedappusage", + ], static_libs: [ "androidx.annotation_annotation", "modules-utils-build_system", @@ -332,6 +335,7 @@ "net-utils-device-common-ip", "net-utils-device-common-netlink", "net-utils-device-common-struct", + "net-utils-device-common-struct-base", ], } @@ -374,8 +378,6 @@ ], manifest: "AndroidManifestBase.xml", visibility: [ - "//frameworks/base/packages/Connectivity/tests/integration", - "//frameworks/base/tests/net/integration", "//packages/modules/Connectivity/Tethering/tests/integration", "//packages/modules/Connectivity/tests/integration", "//packages/modules/NetworkStack/tests/unit", @@ -572,7 +574,7 @@ tools: ["stats-log-api-gen"], cmd: "$(location stats-log-api-gen) --java $(out) --module network_stack" + " --javaPackage com.android.networkstack.metrics --javaClass NetworkStackStatsLog" + - " --minApiLevel 30 --compileApiLevel 30", + " --minApiLevel 30", out: ["com/android/networkstack/metrics/NetworkStackStatsLog.java"], }
diff --git a/common/networkstackclient/Android.bp b/common/networkstackclient/Android.bp index c566715..a5d7230 100644 --- a/common/networkstackclient/Android.bp +++ b/common/networkstackclient/Android.bp
@@ -260,7 +260,6 @@ "networkstack-aidl-latest", ], visibility: [ - "//frameworks/base/packages/Connectivity/service", "//packages/modules/Connectivity/Tethering", "//packages/modules/Connectivity/service", "//frameworks/base/services/net",
diff --git a/src/android/net/apf/AndroidPacketFilter.java b/src/android/net/apf/AndroidPacketFilter.java index 18c704e..f4856ec 100644 --- a/src/android/net/apf/AndroidPacketFilter.java +++ b/src/android/net/apf/AndroidPacketFilter.java
@@ -19,6 +19,8 @@ import android.net.NattKeepalivePacketDataParcelable; import android.net.TcpKeepalivePacketDataParcelable; +import androidx.annotation.Nullable; + import com.android.internal.util.IndentingPrintWriter; /** @@ -77,4 +79,23 @@ * Dump the status of APF. */ void dump(IndentingPrintWriter pw); + + /** + * Indicates whether the ApfFilter is currently running / paused for test and debugging + * purposes. + */ + default boolean isRunning() { + return true; + } + + /** Pause ApfFilter updates for testing purposes. */ + default void pause() {} + + /** Resume ApfFilter updates for testing purposes. */ + default void resume() {} + + /** Return hex string of current APF snapshot for testing purposes. */ + default @Nullable String getDataSnapshotHexString() { + return null; + } }
diff --git a/src/android/net/apf/ApfConstant.java b/src/android/net/apf/ApfConstant.java new file mode 100644 index 0000000..6e59a7f --- /dev/null +++ b/src/android/net/apf/ApfConstant.java
@@ -0,0 +1,102 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.net.apf; + +/** + * The class which declares constants used in ApfFilter and unit tests. + */ +public final class ApfConstant { + + private ApfConstant() {} + public static final int ETH_HEADER_LEN = 14; + public static final int ETH_DEST_ADDR_OFFSET = 0; + public static final int ETH_ETHERTYPE_OFFSET = 12; + public static final int ETH_TYPE_MIN = 0x0600; + public static final int ETH_TYPE_MAX = 0xFFFF; + // TODO: Make these offsets relative to end of link-layer header; don't include ETH_HEADER_LEN. + public static final int IPV4_TOTAL_LENGTH_OFFSET = ETH_HEADER_LEN + 2; + public static final int IPV4_FRAGMENT_OFFSET_OFFSET = ETH_HEADER_LEN + 6; + // Endianness is not an issue for this constant because the APF interpreter always operates in + // network byte order. + public static final int IPV4_FRAGMENT_OFFSET_MASK = 0x1fff; + public static final int IPV4_FRAGMENT_MORE_FRAGS_MASK = 0x2000; + public static final int IPV4_PROTOCOL_OFFSET = ETH_HEADER_LEN + 9; + public static final int IPV4_SRC_ADDR_OFFSET = ETH_HEADER_LEN + 12; + public static final int IPV4_DEST_ADDR_OFFSET = ETH_HEADER_LEN + 16; + public static final int IPV4_ANY_HOST_ADDRESS = 0; + public static final int IPV4_BROADCAST_ADDRESS = -1; // 255.255.255.255 + + // Traffic class and Flow label are not byte aligned. 
Luckily we + // don't care about either value so we'll consider bytes 1-3 of the + // IPv6 header as don't care. + public static final int IPV6_FLOW_LABEL_OFFSET = ETH_HEADER_LEN + 1; + public static final int IPV6_FLOW_LABEL_LEN = 3; + public static final int IPV6_NEXT_HEADER_OFFSET = ETH_HEADER_LEN + 6; + public static final int IPV6_SRC_ADDR_OFFSET = ETH_HEADER_LEN + 8; + public static final int IPV6_DEST_ADDR_OFFSET = ETH_HEADER_LEN + 24; + public static final int IPV6_HEADER_LEN = 40; + // The IPv6 all nodes address ff02::1 + public static final byte[] IPV6_ALL_NODES_ADDRESS = + { (byte) 0xff, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; + + public static final int ICMP6_TYPE_OFFSET = ETH_HEADER_LEN + IPV6_HEADER_LEN; + + public static final int IPPROTO_HOPOPTS = 0; + + // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT + public static final int TCP_UDP_DESTINATION_PORT_OFFSET = ETH_HEADER_LEN + 2; + public static final int UDP_HEADER_LEN = 8; + + public static final int TCP_HEADER_SIZE_OFFSET = 12; + + public static final int DHCP_SERVER_PORT = 67; + public static final int DHCP_CLIENT_PORT = 68; + // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT + + public static final int ARP_HEADER_OFFSET = ETH_HEADER_LEN; + public static final byte[] ARP_IPV4_HEADER = { + 0, 1, // Hardware type: Ethernet (1) + 8, 0, // Protocol type: IP (0x0800) + 6, // Hardware size: 6 + 4, // Protocol size: 4 + }; + public static final int ARP_OPCODE_OFFSET = ARP_HEADER_OFFSET + 6; + // Opcode: ARP request (0x0001), ARP reply (0x0002) + public static final short ARP_OPCODE_REQUEST = 1; + public static final short ARP_OPCODE_REPLY = 2; + public static final int ARP_SOURCE_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 14; + public static final int ARP_TARGET_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 24; + // Limit on the Black List size to cap on program usage for this + // TODO: Select a proper max length + public static 
final int APF_MAX_ETH_TYPE_BLACK_LIST_LEN = 20; + + public static final byte[] ETH_MULTICAST_MDNS_V4_MAC_ADDRESS = + {(byte) 0x01, (byte) 0x00, (byte) 0x5e, (byte) 0x00, (byte) 0x00, (byte) 0xfb}; + public static final byte[] ETH_MULTICAST_MDNS_V6_MAC_ADDRESS = + {(byte) 0x33, (byte) 0x33, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0xfb}; + public static final int MDNS_PORT = 5353; + + public static final int ECHO_PORT = 7; + public static final int DNS_HEADER_LEN = 12; + public static final int DNS_QDCOUNT_OFFSET = 4; + // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT, or the + // IPv6 header length. + public static final int DHCP_CLIENT_MAC_OFFSET = ETH_HEADER_LEN + UDP_HEADER_LEN + 28; + public static final int MDNS_QDCOUNT_OFFSET = + ETH_HEADER_LEN + UDP_HEADER_LEN + DNS_QDCOUNT_OFFSET; + public static final int MDNS_QNAME_OFFSET = + ETH_HEADER_LEN + UDP_HEADER_LEN + DNS_HEADER_LEN; +}
diff --git a/src/android/net/apf/ApfCounterTracker.java b/src/android/net/apf/ApfCounterTracker.java index b2b52e9..a5c31ab 100644 --- a/src/android/net/apf/ApfCounterTracker.java +++ b/src/android/net/apf/ApfCounterTracker.java
@@ -19,8 +19,6 @@ import android.util.ArrayMap; import android.util.Log; -import com.android.internal.annotations.VisibleForTesting; - import java.util.Arrays; import java.util.List; import java.util.Map; @@ -38,7 +36,6 @@ * buffer, using negative byte offsets, where -4 is equivalent to maximumApfProgramSize - 4, * the last writable 32bit word. */ - @VisibleForTesting public enum Counter { RESERVED_OOB, // Points to offset 0 from the end of the buffer (out-of-bounds) ENDIANNESS, // APFv6 interpreter stores 0x12345678 here @@ -48,7 +45,9 @@ CORRUPT_DNS_PACKET, // hardcoded in APFv6 interpreter FILTER_AGE_SECONDS, FILTER_AGE_16384THS, - PASSED_ARP, + APF_VERSION, + APF_PROGRAM_ID, + PASSED_ARP, // see also MIN_PASS_COUNTER below PASSED_DHCP, PASSED_IPV4, PASSED_IPV6_NON_ICMP, @@ -60,10 +59,12 @@ PASSED_ARP_UNICAST_REPLY, PASSED_NON_IP_UNICAST, PASSED_MDNS, - DROPPED_ETH_BROADCAST, + PASSED_MLD, // see also MAX_PASS_COUNTER below + DROPPED_ETH_BROADCAST, // see also MIN_DROP_COUNTER below DROPPED_RA, DROPPED_GARP_REPLY, DROPPED_ARP_OTHER_HOST, + DROPPED_ARP_REQUEST_NO_ADDRESS, DROPPED_IPV4_L2_BROADCAST, DROPPED_IPV4_BROADCAST_ADDR, DROPPED_IPV4_BROADCAST_NET, @@ -82,7 +83,7 @@ DROPPED_MDNS, DROPPED_IPV4_TCP_PORT7_UNICAST, DROPPED_ARP_NON_IPV4, - DROPPED_ARP_UNKNOWN; + DROPPED_ARP_UNKNOWN; // see also MAX_DROP_COUNTER below /** * Returns the negative byte offset from the end of the APF data segment for @@ -108,6 +109,11 @@ } } + public static final Counter MIN_DROP_COUNTER = Counter.DROPPED_ETH_BROADCAST; + public static final Counter MAX_DROP_COUNTER = Counter.DROPPED_ARP_UNKNOWN; + public static final Counter MIN_PASS_COUNTER = Counter.PASSED_ARP; + public static final Counter MAX_PASS_COUNTER = Counter.PASSED_MLD; + private static final String TAG = ApfCounterTracker.class.getSimpleName(); private final List<Counter> mCounterList;
diff --git a/src/android/net/apf/ApfFilter.java b/src/android/net/apf/ApfFilter.java index 0b2c101..2f50ebc 100644 --- a/src/android/net/apf/ApfFilter.java +++ b/src/android/net/apf/ApfFilter.java
@@ -1,5 +1,5 @@ /* - * Copyright (C) 2016 The Android Open Source Project + * Copyright (C) 2024 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,51 @@ package android.net.apf; +import static android.net.apf.ApfConstant.APF_MAX_ETH_TYPE_BLACK_LIST_LEN; +import static android.net.apf.ApfConstant.ARP_HEADER_OFFSET; +import static android.net.apf.ApfConstant.ARP_IPV4_HEADER; +import static android.net.apf.ApfConstant.ARP_OPCODE_OFFSET; +import static android.net.apf.ApfConstant.ARP_OPCODE_REPLY; +import static android.net.apf.ApfConstant.ARP_OPCODE_REQUEST; +import static android.net.apf.ApfConstant.ARP_SOURCE_IP_ADDRESS_OFFSET; +import static android.net.apf.ApfConstant.ARP_TARGET_IP_ADDRESS_OFFSET; +import static android.net.apf.ApfConstant.DHCP_CLIENT_MAC_OFFSET; +import static android.net.apf.ApfConstant.DHCP_CLIENT_PORT; +import static android.net.apf.ApfConstant.ECHO_PORT; +import static android.net.apf.ApfConstant.ETH_DEST_ADDR_OFFSET; +import static android.net.apf.ApfConstant.ETH_ETHERTYPE_OFFSET; +import static android.net.apf.ApfConstant.ETH_HEADER_LEN; +import static android.net.apf.ApfConstant.ETH_MULTICAST_MDNS_V4_MAC_ADDRESS; +import static android.net.apf.ApfConstant.ETH_MULTICAST_MDNS_V6_MAC_ADDRESS; +import static android.net.apf.ApfConstant.ETH_TYPE_MAX; +import static android.net.apf.ApfConstant.ETH_TYPE_MIN; +import static android.net.apf.ApfConstant.ICMP6_TYPE_OFFSET; +import static android.net.apf.ApfConstant.IPPROTO_HOPOPTS; +import static android.net.apf.ApfConstant.IPV4_ANY_HOST_ADDRESS; +import static android.net.apf.ApfConstant.IPV4_BROADCAST_ADDRESS; +import static android.net.apf.ApfConstant.IPV4_DEST_ADDR_OFFSET; +import static android.net.apf.ApfConstant.IPV4_FRAGMENT_MORE_FRAGS_MASK; +import static android.net.apf.ApfConstant.IPV4_FRAGMENT_OFFSET_MASK; +import static 
android.net.apf.ApfConstant.IPV4_FRAGMENT_OFFSET_OFFSET; +import static android.net.apf.ApfConstant.IPV4_PROTOCOL_OFFSET; +import static android.net.apf.ApfConstant.IPV4_TOTAL_LENGTH_OFFSET; +import static android.net.apf.ApfConstant.IPV6_ALL_NODES_ADDRESS; +import static android.net.apf.ApfConstant.IPV6_DEST_ADDR_OFFSET; +import static android.net.apf.ApfConstant.IPV6_FLOW_LABEL_LEN; +import static android.net.apf.ApfConstant.IPV6_FLOW_LABEL_OFFSET; +import static android.net.apf.ApfConstant.IPV6_HEADER_LEN; +import static android.net.apf.ApfConstant.IPV6_NEXT_HEADER_OFFSET; +import static android.net.apf.ApfConstant.IPV6_SRC_ADDR_OFFSET; +import static android.net.apf.ApfConstant.MDNS_PORT; +import static android.net.apf.ApfConstant.MDNS_QDCOUNT_OFFSET; +import static android.net.apf.ApfConstant.MDNS_QNAME_OFFSET; +import static android.net.apf.ApfConstant.TCP_HEADER_SIZE_OFFSET; +import static android.net.apf.ApfConstant.TCP_UDP_DESTINATION_PORT_OFFSET; +import static android.net.apf.BaseApfGenerator.DROP_LABEL; +import static android.net.apf.BaseApfGenerator.FILTER_AGE_MEMORY_SLOT; +import static android.net.apf.BaseApfGenerator.IPV4_HEADER_SIZE_MEMORY_SLOT; +import static android.net.apf.BaseApfGenerator.PACKET_SIZE_MEMORY_SLOT; +import static android.net.apf.BaseApfGenerator.PASS_LABEL; import static android.net.apf.BaseApfGenerator.Register.R0; import static android.net.apf.BaseApfGenerator.Register.R1; import static android.net.util.SocketUtils.makePacketSocketAddress; @@ -81,7 +126,6 @@ import java.io.FileDescriptor; import java.io.IOException; import java.net.Inet4Address; -import java.net.Inet6Address; import java.net.InetAddress; import java.net.SocketAddress; import java.net.SocketException; @@ -99,7 +143,7 @@ * For networks that support packet filtering via APF programs, {@code ApfFilter} * listens for IPv6 ICMPv6 router advertisements (RAs) and generates APF programs to * filter out redundant duplicate ones. 
- * + * <p> * Threading model: * A collection of RAs we've received is kept in mRas. Generating APF programs uses mRas to * know what RAs to filter for, thus generating APF programs is dependent on mRas. @@ -118,6 +162,7 @@ // Helper class for specifying functional filter parameters. public static class ApfConfiguration { public ApfCapabilities apfCapabilities; + public int installableProgramSizeClamp = Integer.MAX_VALUE; public boolean multicastFilter; public boolean ieee802_3Filter; public int[] ethTypeBlackList; @@ -125,6 +170,7 @@ public int acceptRaMinLft; public boolean shouldHandleLightDoze; public long minMetricsSessionDurationMs; + public boolean enableApfV6; } /** A wrapper class of {@link SystemClock} to be mocked in unit tests. */ @@ -140,7 +186,7 @@ /** * When APFv4 is supported, loads R1 with the offset of the specified counter. */ - private void maybeSetupCounter(ApfV4Generator gen, Counter c) { + private void maybeSetupCounter(ApfV4GeneratorBase<?> gen, Counter c) { if (mApfCapabilities.hasDataAccess()) { gen.addLoadImmediate(R1, c.offset()); } @@ -189,86 +235,8 @@ private static final boolean DBG = true; private static final boolean VDBG = false; - private static final int ETH_HEADER_LEN = 14; - private static final int ETH_DEST_ADDR_OFFSET = 0; - private static final int ETH_ETHERTYPE_OFFSET = 12; - private static final int ETH_TYPE_MIN = 0x0600; - private static final int ETH_TYPE_MAX = 0xFFFF; - // TODO: Make these offsets relative to end of link-layer header; don't include ETH_HEADER_LEN. - private static final int IPV4_TOTAL_LENGTH_OFFSET = ETH_HEADER_LEN + 2; - private static final int IPV4_FRAGMENT_OFFSET_OFFSET = ETH_HEADER_LEN + 6; - // Endianness is not an issue for this constant because the APF interpreter always operates in - // network byte order. 
- private static final int IPV4_FRAGMENT_OFFSET_MASK = 0x1fff; - private static final int IPV4_FRAGMENT_MORE_FRAGS_MASK = 0x2000; - private static final int IPV4_PROTOCOL_OFFSET = ETH_HEADER_LEN + 9; - private static final int IPV4_DEST_ADDR_OFFSET = ETH_HEADER_LEN + 16; - private static final int IPV4_ANY_HOST_ADDRESS = 0; - private static final int IPV4_BROADCAST_ADDRESS = -1; // 255.255.255.255 - private static final int IPV4_HEADER_LEN = 20; // Without options - - // Traffic class and Flow label are not byte aligned. Luckily we - // don't care about either value so we'll consider bytes 1-3 of the - // IPv6 header as don't care. - private static final int IPV6_FLOW_LABEL_OFFSET = ETH_HEADER_LEN + 1; - private static final int IPV6_FLOW_LABEL_LEN = 3; - private static final int IPV6_NEXT_HEADER_OFFSET = ETH_HEADER_LEN + 6; - private static final int IPV6_SRC_ADDR_OFFSET = ETH_HEADER_LEN + 8; - private static final int IPV6_DEST_ADDR_OFFSET = ETH_HEADER_LEN + 24; - private static final int IPV6_HEADER_LEN = 40; - // The IPv6 all nodes address ff02::1 - private static final byte[] IPV6_ALL_NODES_ADDRESS = - { (byte) 0xff, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }; - - private static final int ICMP6_TYPE_OFFSET = ETH_HEADER_LEN + IPV6_HEADER_LEN; - - private static final int IPPROTO_HOPOPTS = 0; - - // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT - private static final int TCP_UDP_DESTINATION_PORT_OFFSET = ETH_HEADER_LEN + 2; - private static final int UDP_HEADER_LEN = 8; - - private static final int TCP_HEADER_SIZE_OFFSET = 12; - - private static final int DHCP_CLIENT_PORT = 68; - // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT - private static final int DHCP_CLIENT_MAC_OFFSET = ETH_HEADER_LEN + UDP_HEADER_LEN + 28; - - private static final int ARP_HEADER_OFFSET = ETH_HEADER_LEN; - private static final byte[] ARP_IPV4_HEADER = { - 0, 1, // Hardware type: Ethernet (1) - 8, 0, // Protocol 
type: IP (0x0800) - 6, // Hardware size: 6 - 4, // Protocol size: 4 - }; - private static final int ARP_OPCODE_OFFSET = ARP_HEADER_OFFSET + 6; - // Opcode: ARP request (0x0001), ARP reply (0x0002) - private static final short ARP_OPCODE_REQUEST = 1; - private static final short ARP_OPCODE_REPLY = 2; - private static final int ARP_SOURCE_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 14; - private static final int ARP_TARGET_IP_ADDRESS_OFFSET = ARP_HEADER_OFFSET + 24; - // Limit on the Black List size to cap on program usage for this - // TODO: Select a proper max length - private static final int APF_MAX_ETH_TYPE_BLACK_LIST_LEN = 20; - - private static final byte[] ETH_MULTICAST_MDNS_V4_MAC_ADDRESS = - {(byte) 0x01, (byte) 0x00, (byte) 0x5e, (byte) 0x00, (byte) 0x00, (byte) 0xfb}; - private static final byte[] ETH_MULTICAST_MDNS_V6_MAC_ADDRESS = - {(byte) 0x33, (byte) 0x33, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0xfb}; - private static final int MDNS_PORT = 5353; - - private static final int ECHO_PORT = 7; - private static final int DNS_HEADER_LEN = 12; - private static final int DNS_QDCOUNT_OFFSET = 4; - // NOTE: this must be added to the IPv4 header length in IPV4_HEADER_SIZE_MEMORY_SLOT, or the - // IPv6 header length. 
- private static final int MDNS_QDCOUNT_OFFSET = - ETH_HEADER_LEN + UDP_HEADER_LEN + DNS_QDCOUNT_OFFSET; - private static final int MDNS_QNAME_OFFSET = - ETH_HEADER_LEN + UDP_HEADER_LEN + DNS_HEADER_LEN; - - private final ApfCapabilities mApfCapabilities; + private final int mInstallableProgramSizeClamp; private final IpClientCallbacksWrapper mIpClientCallback; private final InterfaceParams mInterfaceParams; private final TokenBucket mTokenBucket; @@ -289,7 +257,7 @@ private final Clock mClock; private final ApfCounterTracker mApfCounterTracker = new ApfCounterTracker(); @GuardedBy("this") - private long mSessionStartMs = 0; + private final long mSessionStartMs; @GuardedBy("this") private int mNumParseErrorRas = 0; @GuardedBy("this") @@ -314,6 +282,7 @@ private final int mAcceptRaMinLft; private final boolean mShouldHandleLightDoze; + private final boolean mEnableApfV6; private final NetworkQuirkMetrics mNetworkQuirkMetrics; private final IpClientRaInfoMetrics mIpClientRaInfoMetrics; private final ApfSessionInfoMetrics mApfSessionInfoMetrics; @@ -362,7 +331,6 @@ } } }; - private final Context mContext; // Our IPv4 address, if we have just one, otherwise null. @GuardedBy("this") @@ -371,6 +339,11 @@ @GuardedBy("this") private int mIPv4PrefixLength; + // mIsRunning reflects the state of the ApfFilter during integration tests. ApfFilter can be + // paused using "adb shell cmd apf <iface> <cmd>" commands. A paused ApfFilter will not install + // any new programs, but otherwise operates normally.
+ private volatile boolean mIsRunning = true; + private final Dependencies mDependencies; public ApfFilter(Context context, ApfConfiguration config, InterfaceParams ifParams, @@ -392,13 +365,13 @@ IpClientCallbacksWrapper ipClientCallback, NetworkQuirkMetrics networkQuirkMetrics, Dependencies dependencies, Clock clock) { mApfCapabilities = config.apfCapabilities; + mInstallableProgramSizeClamp = config.installableProgramSizeClamp; mIpClientCallback = ipClientCallback; mInterfaceParams = ifParams; mMulticastFilter = config.multicastFilter; mDrop802_3Frames = config.ieee802_3Filter; mMinRdnssLifetimeSec = config.minRdnssLifetimeSec; mAcceptRaMinLft = config.acceptRaMinLft; - mContext = context; mShouldHandleLightDoze = config.shouldHandleLightDoze; mDependencies = dependencies; mNetworkQuirkMetrics = networkQuirkMetrics; @@ -407,6 +380,7 @@ mClock = clock; mSessionStartMs = mClock.elapsedRealtime(); mMinMetricsSessionDurationMs = config.minMetricsSessionDurationMs; + mEnableApfV6 = config.enableApfV6; if (mApfCapabilities.hasDataAccess()) { mCountAndPassLabel = "countAndPass"; @@ -414,8 +388,8 @@ } else { // APFv4 unsupported: turn jumps to the counter trampolines to immediately PASS or DROP, // preserving the original pre-APFv4 behavior. - mCountAndPassLabel = ApfV4Generator.PASS_LABEL; - mCountAndDropLabel = ApfV4Generator.DROP_LABEL; + mCountAndPassLabel = PASS_LABEL; + mCountAndDropLabel = DROP_LABEL; } // Now fill the black list from the passed array @@ -435,6 +409,10 @@ // Listen for doze-mode transition changes to enable/disable the IPv6 multicast filter. mDependencies.addDeviceIdleReceiver(mDeviceIdleReceiver, mShouldHandleLightDoze); + + mDependencies.onApfFilterCreated(this); + // mReceiveThread is created in maybeStartFilter() and halted in shutdown(). 
+ mDependencies.onThreadCreated(mReceiveThread); } /** @@ -475,11 +453,31 @@ public IpClientRaInfoMetrics getIpClientRaInfoMetrics() { return new IpClientRaInfoMetrics(); } + + /** + * Callback to be called when an ApfFilter instance is created. + * + * This method is designed to be overridden in test classes to collect created ApfFilter + * instances. + */ + public void onApfFilterCreated(@NonNull AndroidPacketFilter apfFilter) { + } + + /** + * Callback to be called when a ReceiveThread instance is created. + * + * This method is designed to be overridden in test classes to collect created threads and + * wait for their termination. + */ + public void onThreadCreated(@NonNull Thread thread) { + } } public synchronized void setDataSnapshot(byte[] data) { mDataSnapshot = data; - mApfCounterTracker.updateCountersFromData(data); + if (mIsRunning) { + mApfCounterTracker.updateCountersFromData(data); + } } private void log(String s) { @@ -492,7 +490,7 @@ } private static int[] filterEthTypeBlackList(int[] ethTypeBlackList) { - ArrayList<Integer> bl = new ArrayList<Integer>(); + ArrayList<Integer> bl = new ArrayList<>(); for (int p : ethTypeBlackList) { // Check if the protocol is a valid ether type @@ -532,7 +530,7 @@ // Clear the APF memory to reset all counters upon connecting to the first AP // in an SSID. This is limited to APFv4 devices because this large write triggers // a crash on some older devices (b/78905546). - if (mApfCapabilities.hasDataAccess()) { + if (mIsRunning && mApfCapabilities.hasDataAccess()) { byte[] zeroes = new byte[mApfCapabilities.maximumApfProgramSize]; if (!mIpClientCallback.installPacketFilter(zeroes)) { sendNetworkQuirkMetrics(NetworkQuirkEvent.QE_APF_INSTALL_FAILURE); @@ -631,10 +629,8 @@ private static final int ICMP6_RA_ROUTER_LIFETIME_LEN = 2; // Prefix information option.
private static final int ICMP6_PREFIX_OPTION_TYPE = 3; - private static final int ICMP6_PREFIX_OPTION_LEN = 32; private static final int ICMP6_PREFIX_OPTION_VALID_LIFETIME_OFFSET = 4; private static final int ICMP6_PREFIX_OPTION_VALID_LIFETIME_LEN = 4; - private static final int ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_OFFSET = 8; private static final int ICMP6_PREFIX_OPTION_PREFERRED_LIFETIME_LEN = 4; // From RFC4861: source link-layer address @@ -670,8 +666,8 @@ private long mMinRioRouteLifetime = Long.MAX_VALUE; // Minimum lifetime of RDNSSs in packet, Long.MAX_VALUE means not seen. private long mMinRdnssLifetime = Long.MAX_VALUE; - // Minimum lifetime in packet - private final int mMinLifetime; + // The time in seconds in which some of the information contained in this RA expires. + private final int mExpirationTime; // When the packet was last captured, in seconds since Unix Epoch private final int mLastSeen; @@ -701,7 +697,7 @@ return "???"; } byte[] addressBytes = Arrays.copyOfRange(array, pos, pos + 16); - InetAddress address = (Inet6Address) InetAddress.getByAddress(addressBytes); + InetAddress address = InetAddress.getByAddress(addressBytes); return address.getHostAddress(); } catch (UnsupportedOperationException e) { // array() failed. Cannot happen, mPacket is array-backed and read-write. @@ -747,7 +743,7 @@ System.arraycopy(mPacket.array(), offset + 8, prefix, 0, optLen - 8); sb.append("RIO ").append(lifetime).append("s "); try { - InetAddress address = (Inet6Address) InetAddress.getByAddress(prefix); + InetAddress address = InetAddress.getByAddress(prefix); sb.append(address.getHostAddress()); } catch (UnknownHostException impossible) { sb.append("???"); @@ -988,7 +984,7 @@ break; } } - mMinLifetime = minLifetime(); + mExpirationTime = getExpirationTime(); } public enum MatchType { @@ -1101,14 +1097,13 @@ return MatchType.MATCH_DROP; } - // What is the minimum of all lifetimes within {@code packet} in seconds? 
- // Precondition: matches(packet, length) already returned true. - private int minLifetime() { + // Get the number of seconds in which some of the information contained in this RA expires. + private int getExpirationTime() { // While technically most lifetimes in the RA are u32s, as far as the RA filter is // concerned, INT_MAX is still a *much* longer lifetime than any filter would ever // reasonably be active for. - // Clamp minLifetime at INT_MAX. - int minLifetime = Integer.MAX_VALUE; + // Clamp expirationTime at INT_MAX. + int expirationTime = Integer.MAX_VALUE; for (PacketSection section : mPacketSections) { if (section.type != PacketSection.Type.LIFETIME) { continue; @@ -1118,14 +1113,14 @@ continue; } - minLifetime = (int) Math.min(minLifetime, section.lifetime); + expirationTime = (int) Math.min(expirationTime, section.lifetime); } - return minLifetime; + return expirationTime; } - // Filter for a fraction of the lifetime and adjust for the age of the RA. + // Filter for a fraction of the expiration time and adjust for the age of the RA. int getRemainingFilterLft(int currentTimeSeconds) { - int filterLifetime = (int) ((mMinLifetime / FRACTION_OF_LIFETIME_TO_FILTER) + int filterLifetime = ((mExpirationTime / FRACTION_OF_LIFETIME_TO_FILTER) - (currentTimeSeconds - mLastSeen)); filterLifetime = Math.max(0, filterLifetime); // Clamp filterLifetime to <= 65535, so it fits in 2 bytes. @@ -1135,14 +1130,14 @@ // Append a filter for this RA to {@code gen}. Jump to DROP_LABEL if it should be dropped. // Jump to the next filter if packet doesn't match this RA. 
@GuardedBy("ApfFilter.this") - void generateFilterLocked(ApfV4Generator gen, int timeSeconds) + void generateFilterLocked(ApfV4GeneratorBase<?> gen, int timeSeconds) throws IllegalInstructionException { String nextFilterLabel = "Ra" + getUniqueNumberLocked(); // Skip if packet is not the right size - gen.addLoadFromMemory(R0, gen.PACKET_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R0, PACKET_SIZE_MEMORY_SLOT); gen.addJumpIfR0NotEquals(mPacket.capacity(), nextFilterLabel); // Skip filter if expired - gen.addLoadFromMemory(R0, gen.FILTER_AGE_MEMORY_SLOT); + gen.addLoadFromMemory(R0, FILTER_AGE_MEMORY_SLOT); gen.addJumpIfR0GreaterThan(getRemainingFilterLft(timeSeconds), nextFilterLabel); for (PacketSection section : mPacketSections) { // Generate code to match the packet bytes. @@ -1220,8 +1215,7 @@ } } } - maybeSetupCounter(gen, Counter.DROPPED_RA); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_RA); gen.defineLabel(nextFilterLabel); } } @@ -1236,12 +1230,12 @@ // Append a filter for this keepalive ack to {@code gen}. // Jump to drop if it matches the keepalive ack. // Jump to the next filter if packet doesn't match the keepalive ack. - abstract void generateFilterLocked(ApfV4Generator gen) throws IllegalInstructionException; + abstract void generateFilterLocked(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException; } // A class to hold NAT-T keepalive ack information. 
private class NattKeepaliveResponse extends KeepalivePacket { - static final int UDP_LENGTH_OFFSET = 4; static final int UDP_HEADER_LEN = 8; protected class NattKeepaliveResponseData { @@ -1280,7 +1274,7 @@ @Override @GuardedBy("ApfFilter.this") - void generateFilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + void generateFilterLocked(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { final String nextFilterLabel = "natt_keepalive_filter" + getUniqueNumberLocked(); gen.addLoadImmediate(R0, ETH_HEADER_LEN + IPV4_SRC_ADDR_OFFSET); @@ -1288,16 +1282,16 @@ // A NAT-T keepalive packet contains 1 byte payload with the value 0xff // Check payload length is 1 - gen.addLoadFromMemory(R0, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R0, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addAdd(UDP_HEADER_LEN); gen.addSwap(); gen.addLoad16(R0, IPV4_TOTAL_LENGTH_OFFSET); gen.addNeg(R1); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0NotEquals(1, nextFilterLabel); // Check that the ports match - gen.addLoadFromMemory(R0, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R0, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addAdd(ETH_HEADER_LEN); gen.addJumpIfBytesAtR0NotEqual(mPortFingerprint, nextFilterLabel); @@ -1305,8 +1299,7 @@ gen.addAdd(UDP_HEADER_LEN); gen.addJumpIfBytesAtR0NotEqual(mPayload, nextFilterLabel); - maybeSetupCounter(gen, Counter.DROPPED_IPV4_NATT_KEEPALIVE); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_IPV4_NATT_KEEPALIVE); gen.defineLabel(nextFilterLabel); } @@ -1382,7 +1375,8 @@ // Append a filter for this keepalive ack to {@code gen}. // Jump to drop if it matches the keepalive ack. // Jump to the next filter if packet doesn't match the keepalive ack. 
- abstract void generateFilterLocked(ApfV4Generator gen) throws IllegalInstructionException; + abstract void generateFilterLocked(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException; } private class TcpKeepaliveAckV4 extends TcpKeepaliveAck { @@ -1396,7 +1390,7 @@ @Override @GuardedBy("ApfFilter.this") - void generateFilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + void generateFilterLocked(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { final String nextFilterLabel = "keepalive_ack" + getUniqueNumberLocked(); gen.addLoadImmediate(R0, ETH_HEADER_LEN + IPV4_SRC_ADDR_OFFSET); @@ -1405,32 +1399,31 @@ // Skip to the next filter if it's not zero-sized : // TCP_HEADER_SIZE + IPV4_HEADER_SIZE - ipv4_total_length == 0 // Load the IP header size into R1 - gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R1, IPV4_HEADER_SIZE_MEMORY_SLOT); // Load the TCP header size into R0 (it's indexed by R1) gen.addLoad8Indexed(R0, ETH_HEADER_LEN + TCP_HEADER_SIZE_OFFSET); // Size offset is in the top nibble, but it must be multiplied by 4, and the two // top bits of the low nibble are guaranteed to be zeroes. Right-shift R0 by 2. 
gen.addRightShift(2); // R0 += R1 -> R0 contains TCP + IP headers length - gen.addAddR1(); + gen.addAddR1ToR0(); // Load IPv4 total length gen.addLoad16(R1, IPV4_TOTAL_LENGTH_OFFSET); gen.addNeg(R0); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0NotEquals(0, nextFilterLabel); // Add IPv4 header length - gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R1, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addLoadImmediate(R0, ETH_HEADER_LEN); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(mPortSeqAckFingerprint, nextFilterLabel); - maybeSetupCounter(gen, Counter.DROPPED_IPV4_KEEPALIVE_ACK); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_IPV4_KEEPALIVE_ACK); gen.defineLabel(nextFilterLabel); } } - private class TcpKeepaliveAckV6 extends TcpKeepaliveAck { + private static class TcpKeepaliveAckV6 extends TcpKeepaliveAck { TcpKeepaliveAckV6(final TcpKeepalivePacketDataParcelable sentKeepalivePacket) { this(new TcpKeepaliveAckData(sentKeepalivePacket)); } @@ -1439,7 +1432,7 @@ } @Override - void generateFilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + void generateFilterLocked(ApfV4GeneratorBase<?> gen) { throw new UnsupportedOperationException("IPv6 TCP Keepalive is not supported yet"); } } @@ -1448,9 +1441,9 @@ private static final int MAX_RAS = 10; @GuardedBy("this") - private ArrayList<Ra> mRas = new ArrayList<>(); + private final ArrayList<Ra> mRas = new ArrayList<>(); @GuardedBy("this") - private SparseArray<KeepalivePacket> mKeepalivePackets = new SparseArray<>(); + private final SparseArray<KeepalivePacket> mKeepalivePackets = new SparseArray<>(); @GuardedBy("this") private final List<String[]> mMdnsAllowList = new ArrayList<>(); @@ -1472,7 +1465,7 @@ /** * For debugging only. Contains the latest APF buffer snapshot captured from the firmware. - * + * <p> * A typical size for this buffer is 4KB. 
It is present only if the WiFi HAL supports * IWifiStaIface#readApfPacketFilterData(), and the APF interpreter advertised support for * the opcodes to access the data buffer (LDDW and STDW). @@ -1489,9 +1482,6 @@ // The maximum number of distinct RAs @GuardedBy("this") private int mMaxDistinctRas = 0; - // How many times the program was updated since we started for allowing multicast traffic. - @GuardedBy("this") - private int mNumProgramUpdatesAllowingMulticast = 0; /** * Generate filter code to process ARP packets. Execution of this code ends in either the @@ -1500,7 +1490,8 @@ * - Packet being filtered is ARP */ @GuardedBy("this") - private void generateArpFilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + private void generateArpFilterLocked(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException { // Here's a basic summary of what the ARP filter program does: // // if not ARP IPv4 @@ -1526,16 +1517,20 @@ maybeSetupCounter(gen, Counter.DROPPED_ARP_NON_IPV4); gen.addJumpIfBytesAtR0NotEqual(ARP_IPV4_HEADER, mCountAndDropLabel); - // Drop if unknown ARP opcode. gen.addLoad16(R0, ARP_OPCODE_OFFSET); - gen.addJumpIfR0Equals(ARP_OPCODE_REQUEST, checkTargetIPv4); // Skip to unicast check - maybeSetupCounter(gen, Counter.DROPPED_ARP_UNKNOWN); - gen.addJumpIfR0NotEquals(ARP_OPCODE_REPLY, mCountAndDropLabel); + if (mIPv4Address == null) { + // Drop if ARP REQUEST and we do not have an IPv4 address + gen.addCountAndDropIfR0Equals(ARP_OPCODE_REQUEST, + Counter.DROPPED_ARP_REQUEST_NO_ADDRESS); + } else { + gen.addJumpIfR0Equals(ARP_OPCODE_REQUEST, checkTargetIPv4); // Skip to unicast check + } + // Drop if unknown ARP opcode. 
+ gen.addCountAndDropIfR0NotEquals(ARP_OPCODE_REPLY, Counter.DROPPED_ARP_UNKNOWN); // Drop if ARP reply source IP is 0.0.0.0 gen.addLoad32(R0, ARP_SOURCE_IP_ADDRESS_OFFSET); - maybeSetupCounter(gen, Counter.DROPPED_ARP_REPLY_SPA_NO_HOST); - gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(IPV4_ANY_HOST_ADDRESS, Counter.DROPPED_ARP_REPLY_SPA_NO_HOST); // Pass if non-broadcast reply. gen.addLoadImmediate(R0, ETH_DEST_ADDR_OFFSET); @@ -1547,8 +1542,7 @@ if (mIPv4Address == null) { // When there is no IPv4 address, drop GARP replies (b/29404209). gen.addLoad32(R0, ARP_TARGET_IP_ADDRESS_OFFSET); - maybeSetupCounter(gen, Counter.DROPPED_GARP_REPLY); - gen.addJumpIfR0Equals(IPV4_ANY_HOST_ADDRESS, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(IPV4_ANY_HOST_ADDRESS, Counter.DROPPED_GARP_REPLY); } else { // When there is an IPv4 address, drop unicast/broadcast requests // and broadcast replies with a different target IPv4 address. @@ -1557,8 +1551,7 @@ gen.addJumpIfBytesAtR0NotEqual(mIPv4Address, mCountAndDropLabel); } - maybeSetupCounter(gen, Counter.PASSED_ARP); - gen.addJump(mCountAndPassLabel); + gen.addCountAndPass(Counter.PASSED_ARP); } /** @@ -1568,7 +1561,8 @@ * - Packet being filtered is IPv4 */ @GuardedBy("this") - private void generateIPv4FilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + private void generateIPv4FilterLocked(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException { // Here's a basic summary of what the IPv4 filter program does: // // if filtering multicast (i.e. multicast lock not held): @@ -1595,16 +1589,15 @@ gen.addLoad16(R0, IPV4_FRAGMENT_OFFSET_OFFSET); gen.addJumpIfR0AnyBitsSet(IPV4_FRAGMENT_OFFSET_MASK, skipDhcpv4Filter); // Check it's addressed to DHCP client port. 
- gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R1, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addLoad16Indexed(R0, TCP_UDP_DESTINATION_PORT_OFFSET); gen.addJumpIfR0NotEquals(DHCP_CLIENT_PORT, skipDhcpv4Filter); // Check it's DHCP to our MAC address. gen.addLoadImmediate(R0, DHCP_CLIENT_MAC_OFFSET); // NOTE: Relies on R1 containing IPv4 header offset. - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(mHardwareAddress, skipDhcpv4Filter); - maybeSetupCounter(gen, Counter.PASSED_DHCP); - gen.addJump(mCountAndPassLabel); + gen.addCountAndPass(Counter.PASSED_DHCP); // Drop all multicasts/broadcasts. gen.defineLabel(skipDhcpv4Filter); @@ -1612,17 +1605,15 @@ // If IPv4 destination address is in multicast range, drop. gen.addLoad8(R0, IPV4_DEST_ADDR_OFFSET); gen.addAnd(0xf0); - maybeSetupCounter(gen, Counter.DROPPED_IPV4_MULTICAST); - gen.addJumpIfR0Equals(0xe0, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(0xe0, Counter.DROPPED_IPV4_MULTICAST); // If IPv4 broadcast packet, drop regardless of L2 (b/30231088). 
- maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_ADDR); gen.addLoad32(R0, IPV4_DEST_ADDR_OFFSET); - gen.addJumpIfR0Equals(IPV4_BROADCAST_ADDRESS, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(IPV4_BROADCAST_ADDRESS, + Counter.DROPPED_IPV4_BROADCAST_ADDR); if (mIPv4Address != null && mIPv4PrefixLength < 31) { - maybeSetupCounter(gen, Counter.DROPPED_IPV4_BROADCAST_NET); int broadcastAddr = ipv4BroadcastAddress(mIPv4Address, mIPv4PrefixLength); - gen.addJumpIfR0Equals(broadcastAddr, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(broadcastAddr, Counter.DROPPED_IPV4_BROADCAST_NET); } } @@ -1642,20 +1633,18 @@ maybeSetupCounter(gen, Counter.PASSED_IPV4_UNICAST); gen.addLoadImmediate(R0, ETH_DEST_ADDR_OFFSET); gen.addJumpIfBytesAtR0NotEqual(ETHER_BROADCAST, mCountAndPassLabel); - maybeSetupCounter(gen, Counter.DROPPED_IPV4_L2_BROADCAST); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_IPV4_L2_BROADCAST); } // Otherwise, pass - maybeSetupCounter(gen, Counter.PASSED_IPV4); - gen.addJump(mCountAndPassLabel); + gen.addCountAndPass(Counter.PASSED_IPV4); } @GuardedBy("this") - private void generateKeepaliveFilters(ApfV4Generator gen, Class<?> filterType, int proto, + private void generateKeepaliveFilters(ApfV4GeneratorBase<?> gen, Class<?> filterType, int proto, int offset, String label) throws IllegalInstructionException { final boolean haveKeepaliveResponses = CollectionUtils.any(mKeepalivePackets, - ack -> filterType.isInstance(ack)); + filterType::isInstance); // If no keepalive packets of this type if (!haveKeepaliveResponses) return; @@ -1674,13 +1663,14 @@ } @GuardedBy("this") - private void generateV4KeepaliveFilters(ApfV4Generator gen) throws IllegalInstructionException { + private void generateV4KeepaliveFilters(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException { generateKeepaliveFilters(gen, TcpKeepaliveAckV4.class, IPPROTO_TCP, IPV4_PROTOCOL_OFFSET, "skip_v4_keepalive_filter"); } @GuardedBy("this") - 
private void generateV4NattKeepaliveFilters(ApfV4Generator gen) + private void generateV4NattKeepaliveFilters(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { generateKeepaliveFilters(gen, NattKeepaliveResponse.class, IPPROTO_UDP, IPV4_PROTOCOL_OFFSET, "skip_v4_nattkeepalive_filter"); @@ -1693,7 +1683,8 @@ * - Packet being filtered is IPv6 */ @GuardedBy("this") - private void generateIPv6FilterLocked(ApfV4Generator gen) throws IllegalInstructionException { + private void generateIPv6FilterLocked(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException { // Here's a basic summary of what the IPv6 filter program does: // // if there is a hop-by-hop option present (e.g. MLD query) @@ -1714,7 +1705,7 @@ // MLD packets set the router-alert hop-by-hop option. // TODO: be smarter about not blindly passing every packet with HBH options. - gen.addJumpIfR0Equals(IPPROTO_HOPOPTS, mCountAndPassLabel); + gen.addCountAndPassIfR0Equals(IPPROTO_HOPOPTS, Counter.PASSED_MLD); // Drop multicast if the multicast filter is enabled. if (mMulticastFilter) { @@ -1737,20 +1728,17 @@ // Drop all other packets sent to ff00::/8 (multicast prefix). gen.defineLabel(dropAllIPv6MulticastsLabel); - maybeSetupCounter(gen, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST); gen.addLoad8(R0, IPV6_DEST_ADDR_OFFSET); - gen.addJumpIfR0Equals(0xff, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(0xff, Counter.DROPPED_IPV6_NON_ICMP_MULTICAST); // If any keepalive filter matches, drop generateV6KeepaliveFilters(gen); // Not multicast. Pass. - maybeSetupCounter(gen, Counter.PASSED_IPV6_UNICAST_NON_ICMP); - gen.addJump(mCountAndPassLabel); + gen.addCountAndPass(Counter.PASSED_IPV6_UNICAST_NON_ICMP); gen.defineLabel(skipIPv6MulticastFilterLabel); } else { generateV6KeepaliveFilters(gen); // If not ICMPv6, pass. 
- maybeSetupCounter(gen, Counter.PASSED_IPV6_NON_ICMP); - gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6, mCountAndPassLabel); + gen.addCountAndPassIfR0NotEquals(IPPROTO_ICMPV6, Counter.PASSED_IPV6_NON_ICMP); } // If we got this far, the packet is ICMPv6. Drop some specific types. @@ -1759,8 +1747,8 @@ String skipUnsolicitedMulticastNALabel = "skipUnsolicitedMulticastNA"; gen.addLoad8(R0, ICMP6_TYPE_OFFSET); // Drop all router solicitations (b/32833400) - maybeSetupCounter(gen, Counter.DROPPED_IPV6_ROUTER_SOLICITATION); - gen.addJumpIfR0Equals(ICMPV6_ROUTER_SOLICITATION, mCountAndDropLabel); + gen.addCountAndDropIfR0Equals(ICMPV6_ROUTER_SOLICITATION, + Counter.DROPPED_IPV6_ROUTER_SOLICITATION); // If not neighbor announcements, skip filter. gen.addJumpIfR0NotEquals(ICMPV6_NEIGHBOR_ADVERTISEMENT, skipUnsolicitedMulticastNALabel); // Drop all multicast NA to ff02::/120. @@ -1770,8 +1758,7 @@ gen.addLoadImmediate(R0, IPV6_DEST_ADDR_OFFSET); gen.addJumpIfBytesAtR0NotEqual(unsolicitedNaDropPrefix, skipUnsolicitedMulticastNALabel); - maybeSetupCounter(gen, Counter.DROPPED_IPV6_MULTICAST_NA); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_IPV6_MULTICAST_NA); gen.defineLabel(skipUnsolicitedMulticastNALabel); // Note that this is immediately followed emitEpilogue which will: @@ -1796,7 +1783,7 @@ * or PASS_LABEL if the packet is mDNS packets. Otherwise, skip this check. */ @GuardedBy("this") - private void generateMdnsFilterLocked(ApfV4Generator gen) + private void generateMdnsFilterLocked(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { final String skipMdnsv4Filter = "skip_mdns_v4_filter"; final String skipMdnsFilter = "skip_mdns_filter"; @@ -1844,7 +1831,7 @@ gen.addJumpIfR0NotEquals(IPPROTO_UDP, skipMdnsFilter); // Set R1 to IPv4 header. 
- gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R1, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addJump(checkMdnsUdpPort); gen.defineLabel(skipMdnsv4Filter); @@ -1876,7 +1863,7 @@ // If QDCOUNT == 1, matches the QNAME with allowlist. // Load offset for the first QNAME. gen.addLoadImmediate(R0, MDNS_QNAME_OFFSET); - gen.addAddR1(); + gen.addAddR1ToR0(); // Check first QNAME against allowlist for (int i = 0; i < mMdnsAllowList.size(); ++i) { @@ -1890,12 +1877,10 @@ } // If QNAME doesn't match any entries in allowlist, drop the packet. gen.defineLabel(mDnsDropPacket); - maybeSetupCounter(gen, Counter.DROPPED_MDNS); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_MDNS); gen.defineLabel(mDnsAcceptPacket); - maybeSetupCounter(gen, Counter.PASSED_MDNS); - gen.addJump(mCountAndPassLabel); + gen.addCountAndPass(Counter.PASSED_MDNS); gen.defineLabel(skipMdnsFilter); @@ -1903,12 +1888,12 @@ /** * Generate filter code to drop IPv4 TCP packets on port 7. - * - * On entry we know it is IPv4 ethertype, but don't know anything else. + * <p> + * On entry, we know it is IPv4 ethertype, but don't know anything else. * R0/R1 have nothing useful in them, and can be clobbered. */ @GuardedBy("this") - private void generateV4TcpPort7FilterLocked(ApfV4Generator gen) + private void generateV4TcpPort7FilterLocked(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { final String skipPort7V4Filter = "skip_port7_v4_filter"; @@ -1921,20 +1906,20 @@ gen.addJumpIfR0AnyBitsSet(IPV4_FRAGMENT_OFFSET_MASK, skipPort7V4Filter); // Check it's destination port 7. - gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); + gen.addLoadFromMemory(R1, IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addLoad16Indexed(R0, TCP_UDP_DESTINATION_PORT_OFFSET); gen.addJumpIfR0NotEquals(ECHO_PORT, skipPort7V4Filter); // Drop it. 
- maybeSetupCounter(gen, Counter.DROPPED_IPV4_TCP_PORT7_UNICAST); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_IPV4_TCP_PORT7_UNICAST); // Skip label. gen.defineLabel(skipPort7V4Filter); } @GuardedBy("this") - private void generateV6KeepaliveFilters(ApfV4Generator gen) throws IllegalInstructionException { + private void generateV6KeepaliveFilters(ApfV4GeneratorBase<?> gen) + throws IllegalInstructionException { generateKeepaliveFilters(gen, TcpKeepaliveAckV6.class, IPPROTO_TCP, IPV6_NEXT_HEADER_OFFSET, "skip_v6_keepalive_filter"); } @@ -1961,9 +1946,14 @@ */ @GuardedBy("this") @VisibleForTesting - protected ApfV4Generator emitPrologueLocked() throws IllegalInstructionException { + protected ApfV4GeneratorBase<?> emitPrologueLocked() throws IllegalInstructionException { // This is guaranteed to succeed because of the check in maybeCreate. - ApfV4Generator gen = new ApfV4Generator(mApfCapabilities.apfVersionSupported); + ApfV4GeneratorBase<?> gen; + if (mEnableApfV6 && mApfCapabilities.apfVersionSupported > 4) { + gen = new ApfV6Generator().addData(); + } else { + gen = new ApfV4Generator(mApfCapabilities.apfVersionSupported); + } if (mApfCapabilities.hasDataAccess()) { // Increment TOTAL_PACKETS @@ -1980,6 +1970,16 @@ maybeSetupCounter(gen, Counter.FILTER_AGE_16384THS); gen.addLoadFromMemory(R0, 9); // m[9] is filter age in 16384ths gen.addStoreData(R0, 0); // store 'counter' + + // requires a new enough APFv5+ interpreter, otherwise will be 0 + maybeSetupCounter(gen, Counter.APF_VERSION); + gen.addLoadFromMemory(R0, 8); // m[8] is apf version + gen.addStoreData(R0, 0); // store 'counter' + + // store this program's sequential id, for later comparison + maybeSetupCounter(gen, Counter.APF_PROGRAM_ID); + gen.addLoadImmediate(R0, mNumProgramUpdates); + gen.addStoreData(R0, 0); // store 'counter' } // Here's a basic summary of what the initial program does: @@ -2002,14 +2002,13 @@ if (mDrop802_3Frames) { // drop 802.3 frames (ethtype < 
0x0600) - maybeSetupCounter(gen, Counter.DROPPED_802_3_FRAME); - gen.addJumpIfR0LessThan(ETH_TYPE_MIN, mCountAndDropLabel); + gen.addCountAndDropIfR0LessThan(ETH_TYPE_MIN, Counter.DROPPED_802_3_FRAME); } // Handle ether-type black list - maybeSetupCounter(gen, Counter.DROPPED_ETHERTYPE_DENYLISTED); for (int p : mEthTypeBlackList) { - gen.addJumpIfR0Equals(p, mCountAndDropLabel); + // TODO: Refactorings increased APFv4 code size; optimize for reduction. + gen.addCountAndDropIfR0Equals(p, Counter.DROPPED_ETHERTYPE_DENYLISTED); } // Add ARP filters: @@ -2039,8 +2038,7 @@ gen.addLoadImmediate(R0, ETH_DEST_ADDR_OFFSET); maybeSetupCounter(gen, Counter.PASSED_NON_IP_UNICAST); gen.addJumpIfBytesAtR0NotEqual(ETHER_BROADCAST, mCountAndPassLabel); - maybeSetupCounter(gen, Counter.DROPPED_ETH_BROADCAST); - gen.addJump(mCountAndDropLabel); + gen.addCountAndDrop(Counter.DROPPED_ETH_BROADCAST); // Add IPv6 filters: gen.defineLabel(ipv6FilterLabel); @@ -2050,12 +2048,12 @@ /** * Append packet counting epilogue to the APF program. - * + * <p> * Currently, the epilogue consists of two trampolines which count passed and dropped packets * before jumping to the actual PASS and DROP labels. */ @GuardedBy("this") - private void emitEpilogue(ApfV4Generator gen) throws IllegalInstructionException { + private void emitEpilogue(ApfV4GeneratorBase<?> gen) throws IllegalInstructionException { // If APFv4 is unsupported, no epilogue is necessary: if execution reached this far, it // will just fall-through to the PASS label. if (!mApfCapabilities.hasDataAccess()) return; @@ -2064,6 +2062,7 @@ // which will pass the packet to the application processor. maybeSetupCounter(gen, Counter.PASSED_IPV6_ICMP); + // TODO: remove the duplicated trampoline block after fully migrate to addCountAndXXX() API. // Append the count & pass trampoline, which increments the counter at the data address // pointed to by R1, then jumps to the pass label. 
This saves a few bytes over inserting // the entire sequence inline for every counter. @@ -2071,14 +2070,17 @@ gen.addLoadData(R0, 0); // R0 = *(R1 + 0) gen.addAdd(1); // R0++ gen.addStoreData(R0, 0); // *(R1 + 0) = R0 - gen.addJump(gen.PASS_LABEL); + gen.addJump(PASS_LABEL); // Same as above for the count & drop trampoline. gen.defineLabel(mCountAndDropLabel); gen.addLoadData(R0, 0); // R0 = *(R1 + 0) gen.addAdd(1); // R0++ gen.addStoreData(R0, 0); // *(R1 + 0) = R0 - gen.addJump(gen.DROP_LABEL); + gen.addJump(DROP_LABEL); + + // TODO: merge the addCountTrampoline() into generate() method + gen.addCountTrampoline(); } /** @@ -2097,11 +2099,16 @@ maximumApfProgramSize -= Counter.totalSize(); } + // Prevent generating (and thus installing) larger programs + if (maximumApfProgramSize > mInstallableProgramSizeClamp) { + maximumApfProgramSize = mInstallableProgramSizeClamp; + } + // Ensure the entire APF program uses the same time base. int timeSeconds = secondsSinceBoot(); try { // Step 1: Determine how many RA filters we can fit in the program. - ApfV4Generator gen = emitPrologueLocked(); + ApfV4GeneratorBase<?> gen = emitPrologueLocked(); // The epilogue normally goes after the RA filters, but add it early to include its // length when estimating the total. 
@@ -2141,10 +2148,12 @@ sendNetworkQuirkMetrics(NetworkQuirkEvent.QE_APF_GENERATE_FILTER_EXCEPTION); return; } - // Update data snapshot every time we install a new program - mIpClientCallback.startReadPacketFilter(); - if (!mIpClientCallback.installPacketFilter(program)) { - sendNetworkQuirkMetrics(NetworkQuirkEvent.QE_APF_INSTALL_FAILURE); + if (mIsRunning) { + // Update data snapshot every time we install a new program + mIpClientCallback.startReadPacketFilter(); + if (!mIpClientCallback.installPacketFilter(program)) { + sendNetworkQuirkMetrics(NetworkQuirkEvent.QE_APF_INSTALL_FAILURE); + } } mLastTimeInstalledProgram = timeSeconds; mLastInstalledProgramMinLifetime = programMinLft; @@ -2262,7 +2271,7 @@ if (context == null || config == null || ifParams == null) return null; ApfCapabilities apfCapabilities = config.apfCapabilities; if (apfCapabilities == null) return null; - if (apfCapabilities.apfVersionSupported == 0) return null; + if (apfCapabilities.apfVersionSupported < 2) return null; if (apfCapabilities.maximumApfProgramSize < 512) { Log.e(TAG, "Unacceptably small APF limit: " + apfCapabilities.maximumApfProgramSize); return null; @@ -2322,9 +2331,6 @@ public synchronized void setMulticastFilter(boolean isEnabled) { if (mMulticastFilter == isEnabled) return; mMulticastFilter = isEnabled; - if (!isEnabled) { - mNumProgramUpdatesAllowingMulticast++; - } installNewProgramLocked(); } @@ -2441,6 +2447,8 @@ public synchronized void dump(IndentingPrintWriter pw) { pw.println("Capabilities: " + mApfCapabilities); + pw.println("InstallableProgramSizeClamp: " + mInstallableProgramSizeClamp); + pw.println("Filter update status: " + (mIsRunning ? "RUNNING" : "PAUSED")); pw.println("Receive thread: " + (mReceiveThread != null ? "RUNNING" : "STOPPED")); pw.println("Multicast: " + (mMulticastFilter ? 
"DROP" : "ALLOW")); pw.println("Minimum RDNSS lifetime: " + mMinRdnssLifetimeSec); @@ -2546,6 +2554,29 @@ pw.decreaseIndent(); } + /** Return ApfFilter update status for testing purposes. */ + public boolean isRunning() { + return mIsRunning; + } + + /** Pause ApfFilter updates for testing purposes. */ + public void pause() { + mIsRunning = false; + } + + /** Resume ApfFilter updates for testing purposes. */ + public void resume() { + mIsRunning = true; + } + + /** Return data snapshot as hex string for testing purposes. */ + public synchronized @Nullable String getDataSnapshotHexString() { + if (mDataSnapshot == null) { + return null; + } + return HexDump.toHexString(mDataSnapshot, 0, mDataSnapshot.length, false /* lowercase */); + } + // TODO: move to android.net.NetworkUtils @VisibleForTesting public static int ipv4BroadcastAddress(byte[] addrBytes, int prefixLength) {
diff --git a/src/android/net/apf/ApfV4Generator.java b/src/android/net/apf/ApfV4Generator.java index e1b0fc3..96320ae 100644 --- a/src/android/net/apf/ApfV4Generator.java +++ b/src/android/net/apf/ApfV4Generator.java
@@ -1,5 +1,5 @@ /* - * Copyright (C) 2016 The Android Open Source Project + * Copyright (C) 2024 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,25 +13,34 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package android.net.apf; -import static android.net.apf.BaseApfGenerator.Rbit.Rbit0; +import static android.net.apf.BaseApfGenerator.Register.R0; import static android.net.apf.BaseApfGenerator.Register.R1; import com.android.internal.annotations.VisibleForTesting; /** - * APF assembler/generator. A tool for generating an APF program. - * - * Call add*() functions to add instructions to the program, then call - * {@link BaseApfGenerator#generate} to get the APF bytecode for the program. - * - * @param <Type> the generator class + * APFv4 assembler/generator. A tool for generating an APFv4 program. * * @hide */ -public class ApfV4Generator<Type extends BaseApfGenerator> extends BaseApfGenerator { +public final class ApfV4Generator extends ApfV4GeneratorBase<ApfV4Generator> { + + /** + * Jump to this label to terminate the program, increment the counter and indicate the packet + * should be passed to the AP. + */ + private static final String COUNT_AND_PASS_LABEL = "__COUNT_AND_PASS__"; + + /** + * Jump to this label to terminate the program, increment counter, and indicate the packet + * should be dropped. 
+ */ + private static final String COUNT_AND_DROP_LABEL = "__COUNT_AND_DROP__"; + + private final String mCountAndDropLabel; + private final String mCountAndPassLabel; /** * Creates an ApfV4Generator instance which is able to emit instructions for the specified @@ -41,377 +50,118 @@ @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE) public ApfV4Generator(int version) throws IllegalInstructionException { super(version); - requireApfVersion(MIN_APF_VERSION); + mCountAndDropLabel = version >= 4 ? COUNT_AND_DROP_LABEL : DROP_LABEL; + mCountAndPassLabel = version >= 4 ? COUNT_AND_PASS_LABEL : PASS_LABEL; } - Type append(Instruction instruction) { - if (mGenerated) { - throw new IllegalStateException("Program already generated"); + @Override + void addArithR1(Opcodes opcode) { + append(new Instruction(opcode, R1)); + } + + /** + * Generates instructions to prepare to increment the specified counter and jump to the + * "__COUNT_AND_PASS__" label. + * In APFv2, it will directly return PASS. + * + * @param counter The ApfCounterTracker.Counter to increment + * @return Type the generator object + */ + @Override + public ApfV4Generator addCountAndPass(ApfCounterTracker.Counter counter) { + checkPassCounterRange(counter); + return maybeAddLoadR1CounterOffset(counter).addJump(mCountAndPassLabel); + } + + /** + * Generates instructions to prepare to increment the specified counter and jump to the + * "__COUNT_AND_DROP__" label. + * In APFv2, it will directly return DROP. 
+ * + * @param counter The ApfCounterTracker.Counter to increment + * @return Type the generator object + */ + @Override + public ApfV4Generator addCountAndDrop(ApfCounterTracker.Counter counter) { + checkDropCounterRange(counter); + return maybeAddLoadR1CounterOffset(counter).addJump(mCountAndDropLabel); + } + + @Override + public ApfV4Generator addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt) { + checkDropCounterRange(cnt); + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0Equals(val, mCountAndDropLabel); + } + + @Override + public ApfV4Generator addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt) { + checkPassCounterRange(cnt); + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0Equals(val, mCountAndPassLabel); + } + + @Override + public ApfV4Generator addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) { + checkDropCounterRange(cnt); + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0NotEquals(val, mCountAndDropLabel); + } + + @Override + public ApfV4Generator addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) { + checkPassCounterRange(cnt); + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0NotEquals(val, mCountAndPassLabel); + } + + @Override + public ApfV4Generator addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt) { + checkDropCounterRange(cnt); + if (val <= 0) { + throw new IllegalArgumentException("val must > 0, current val: " + val); } - mInstructions.add(instruction); - return (Type) this; + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0LessThan(val, mCountAndDropLabel); + } + + @Override + public ApfV4Generator addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt) { + checkPassCounterRange(cnt); + if (val <= 0) { + throw new IllegalArgumentException("val must > 0, current val: " + val); + } + return maybeAddLoadR1CounterOffset(cnt).addJumpIfR0LessThan(val, mCountAndPassLabel); + } + + @Override + public ApfV4Generator 
addCountAndDropIfBytesAtR0NotEqual(byte[] bytes, + ApfCounterTracker.Counter cnt) throws IllegalInstructionException { + checkDropCounterRange(cnt); + return maybeAddLoadR1CounterOffset(cnt).addJumpIfBytesAtR0NotEqual(bytes, + mCountAndDropLabel); } /** - * Define a label at the current end of the program. Jumps can jump to this label. Labels are - * their own separate instructions, though with size 0. This facilitates having labels with - * no corresponding code to execute, for example a label at the end of a program. For example - * an {@link ApfV4Generator} might be passed to a function that adds a filter like so: - * <pre> - * load from packet - * compare loaded data, jump if not equal to "next_filter" - * load from packet - * compare loaded data, jump if not equal to "next_filter" - * jump to drop label - * define "next_filter" here - * </pre> - * In this case "next_filter" may not have any generated code associated with it. + * Append the count & (pass|drop) trampoline, which increments the counter at the data address + * pointed to by R1, then jumps to the (pass|drop) label. This saves a few bytes over inserting + * the entire sequence inline for every counter. + * This instruction is necessary to be called at the end of any APFv4 program in order to make + * counter incrementing logic work. + * In APFv2, it is a noop. 
*/ - public Type defineLabel(String name) throws IllegalInstructionException { - return append(new Instruction(Opcodes.LABEL).setLabel(name)); + @Override + public ApfV4Generator addCountTrampoline() throws IllegalInstructionException { + if (mVersion < 4) return self(); + return defineLabel(COUNT_AND_PASS_LABEL) + .addLoadData(R0, 0) // R0 = *(R1 + 0) + .addAdd(1) // R0++ + .addStoreData(R0, 0) // *(R1 + 0) = R0 + .addJump(PASS_LABEL) + .defineLabel(COUNT_AND_DROP_LABEL) + .addLoadData(R0, 0) // R0 = *(R1 + 0) + .addAdd(1) // R0++ + .addStoreData(R0, 0) // *(R1 + 0) = R0 + .addJump(DROP_LABEL); } - /** - * Add an unconditional jump instruction to the end of the program. - */ - public Type addJump(String target) { - return append(new Instruction(Opcodes.JMP).setTargetLabel(target)); + private ApfV4Generator maybeAddLoadR1CounterOffset(ApfCounterTracker.Counter counter) { + if (mVersion >= 4) return addLoadImmediate(R1, counter.offset()); + return self(); } - - /** - * Add an instruction to the end of the program to load the byte at offset {@code offset} - * bytes from the beginning of the packet into {@code register}. - */ - public Type addLoad8(Register r, int ofs) { - return append(new Instruction(Opcodes.LDB, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to load 16-bits at offset {@code offset} - * bytes from the beginning of the packet into {@code register}. - */ - public Type addLoad16(Register r, int ofs) { - return append(new Instruction(Opcodes.LDH, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to load 32-bits at offset {@code offset} - * bytes from the beginning of the packet into {@code register}. - */ - public Type addLoad32(Register r, int ofs) { - return append(new Instruction(Opcodes.LDW, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to load a byte from the packet into - * {@code register}. 
The offset of the loaded byte from the beginning of the packet is - * the sum of {@code offset} and the value in register R1. - */ - public Type addLoad8Indexed(Register r, int ofs) { - return append(new Instruction(Opcodes.LDBX, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to load 16-bits from the packet into - * {@code register}. The offset of the loaded 16-bits from the beginning of the packet is - * the sum of {@code offset} and the value in register R1. - */ - public Type addLoad16Indexed(Register r, int ofs) { - return append(new Instruction(Opcodes.LDHX, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to load 32-bits from the packet into - * {@code register}. The offset of the loaded 32-bits from the beginning of the packet is - * the sum of {@code offset} and the value in register R1. - */ - public Type addLoad32Indexed(Register r, int ofs) { - return append(new Instruction(Opcodes.LDWX, r).addPacketOffset(ofs)); - } - - /** - * Add an instruction to the end of the program to add {@code value} to register R0. - */ - public Type addAdd(int val) { - return append(new Instruction(Opcodes.ADD).addTwosCompUnsigned(val)); - } - - /** - * Add an instruction to the end of the program to multiply register R0 by {@code value}. - */ - public Type addMul(long val) { - return append(new Instruction(Opcodes.MUL).addUnsigned(val)); - } - - /** - * Add an instruction to the end of the program to divide register R0 by {@code value}. - */ - public Type addDiv(long val) { - return append(new Instruction(Opcodes.DIV).addUnsigned(val)); - } - - /** - * Add an instruction to the end of the program to logically and register R0 with {@code value}. - */ - public Type addAnd(int val) { - return append(new Instruction(Opcodes.AND).addTwosCompUnsigned(val)); - } - - /** - * Add an instruction to the end of the program to logically or register R0 with {@code value}. 
- */ - public Type addOr(int val) { - return append(new Instruction(Opcodes.OR).addTwosCompUnsigned(val)); - } - - /** - * Add an instruction to the end of the program to shift left register R0 by {@code value} bits. - */ - // TODO: consider whether should change the argument type to byte - public Type addLeftShift(int val) { - return append(new Instruction(Opcodes.SH).addSigned(val)); - } - - /** - * Add an instruction to the end of the program to shift right register R0 by {@code value} - * bits. - */ - // TODO: consider whether should change the argument type to byte - public Type addRightShift(int val) { - return append(new Instruction(Opcodes.SH).addSigned(-val)); - } - - /** - * Add an instruction to the end of the program to add register R1 to register R0. - */ - public Type addAddR1() { - return append(new Instruction(Opcodes.ADD, R1)); - } - - /** - * Add an instruction to the end of the program to multiply register R0 by register R1. - */ - public Type addMulR1() { - return append(new Instruction(Opcodes.MUL, R1)); - } - - /** - * Add an instruction to the end of the program to divide register R0 by register R1. - */ - public Type addDivR1() { - return append(new Instruction(Opcodes.DIV, R1)); - } - - /** - * Add an instruction to the end of the program to logically and register R0 with register R1 - * and store the result back into register R0. - */ - public Type addAndR1() { - return append(new Instruction(Opcodes.AND, R1)); - } - - /** - * Add an instruction to the end of the program to logically or register R0 with register R1 - * and store the result back into register R0. - */ - public Type addOrR1() { - return append(new Instruction(Opcodes.OR, R1)); - } - - /** - * Add an instruction to the end of the program to shift register R0 left by the value in - * register R1. 
- */ - public Type addLeftShiftR1() { - return append(new Instruction(Opcodes.SH, R1)); - } - - /** - * Add an instruction to the end of the program to move {@code value} into {@code register}. - */ - public Type addLoadImmediate(Register register, int value) { - return append(new Instruction(Opcodes.LI, register).addSigned(value)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value equals {@code value}. - */ - public Type addJumpIfR0Equals(int val, String tgt) { - return append(new Instruction(Opcodes.JEQ).addTwosCompUnsigned(val).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value does not equal {@code value}. - */ - public Type addJumpIfR0NotEquals(int val, String tgt) { - return append(new Instruction(Opcodes.JNE).addTwosCompUnsigned(val).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value is greater than {@code value}. - */ - public Type addJumpIfR0GreaterThan(long val, String tgt) { - return append(new Instruction(Opcodes.JGT).addUnsigned(val).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value is less than {@code value}. - */ - public Type addJumpIfR0LessThan(long val, String tgt) { - return append(new Instruction(Opcodes.JLT).addUnsigned(val).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value has any bits set that are also set in {@code value}. - */ - public Type addJumpIfR0AnyBitsSet(int val, String tgt) { - return append(new Instruction(Opcodes.JSET).addTwosCompUnsigned(val).setTargetLabel(tgt)); - } - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value equals register R1's value. 
- */ - public Type addJumpIfR0EqualsR1(String tgt) { - return append(new Instruction(Opcodes.JEQ, R1).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value does not equal register R1's value. - */ - public Type addJumpIfR0NotEqualsR1(String tgt) { - return append(new Instruction(Opcodes.JNE, R1).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value is greater than register R1's value. - */ - public Type addJumpIfR0GreaterThanR1(String tgt) { - return append(new Instruction(Opcodes.JGT, R1).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value is less than register R1's value. - */ - public Type addJumpIfR0LessThanR1(String target) { - return append(new Instruction(Opcodes.JLT, R1).setTargetLabel(target)); - } - - /** - * Add an instruction to the end of the program to jump to {@code target} if register R0's - * value has any bits set that are also set in R1's value. - */ - public Type addJumpIfR0AnyBitsSetR1(String tgt) { - return append(new Instruction(Opcodes.JSET, R1).setTargetLabel(tgt)); - } - - /** - * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the - * packet at an offset specified by {@code register} don't match {@code bytes} - * R=0 means check for not equal - */ - public Type addJumpIfBytesAtR0NotEqual(byte[] bytes, String tgt) { - return append(new Instruction(Opcodes.JNEBS).addUnsigned( - bytes.length).setTargetLabel(tgt).setBytesImm(bytes)); - } - - /** - * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the - * packet at an offset specified by {@code register} match {@code bytes} - * R=1 means check for equal. 
- */ - public Type addJumpIfBytesAtR0Equal(byte[] bytes, String tgt) - throws IllegalInstructionException { - requireApfVersion(MIN_APF_VERSION_IN_DEV); - return append(new Instruction(Opcodes.JNEBS, R1).addUnsigned( - bytes.length).setTargetLabel(tgt).setBytesImm(bytes)); - } - - /** - * Add an instruction to the end of the program to load memory slot {@code slot} into - * {@code register}. - */ - public Type addLoadFromMemory(Register r, int slot) - throws IllegalInstructionException { - return append(new BaseApfGenerator.Instruction(ExtendedOpcodes.LDM, slot, r)); - } - - /** - * Add an instruction to the end of the program to store {@code register} into memory slot - * {@code slot}. - */ - public Type addStoreToMemory(Register r, int slot) - throws IllegalInstructionException { - return append(new Instruction(ExtendedOpcodes.STM, slot, r)); - } - - /** - * Add an instruction to the end of the program to logically not {@code register}. - */ - public Type addNot(Register r) { - return append(new Instruction(ExtendedOpcodes.NOT, r)); - } - - /** - * Add an instruction to the end of the program to negate {@code register}. - */ - public Type addNeg(Register r) { - return append(new Instruction(ExtendedOpcodes.NEG, r)); - } - - /** - * Add an instruction to swap the values in register R0 and register R1. - */ - public Type addSwap() { - return append(new Instruction(ExtendedOpcodes.SWAP)); - } - - /** - * Add an instruction to the end of the program to move the value into - * {@code register} from the other register. - */ - public Type addMove(Register r) { - return append(new Instruction(ExtendedOpcodes.MOVE, r)); - } - - /** - * Add an instruction to the end of the program to let the program immediately return PASS. 
- */ - public Type addPass() { - // PASS requires using Rbit0 because it shares opcode with DROP - return append(new Instruction(Opcodes.PASSDROP, Rbit0)); - } - - /** - * Add an instruction to the end of the program to load 32 bits from the data memory into - * {@code register}. The source address is computed by adding the signed immediate - * @{code offset} to the other register. - * Requires APF v4 or greater. - */ - public Type addLoadData(Register dst, int ofs) - throws IllegalInstructionException { - requireApfVersion(APF_VERSION_4); - return append(new Instruction(Opcodes.LDDW, dst).addSigned(ofs)); - } - - /** - * Add an instruction to the end of the program to store 32 bits from {@code register} into the - * data memory. The destination address is computed by adding the signed immediate - * @{code offset} to the other register. - * Requires APF v4 or greater. - */ - public Type addStoreData(Register src, int ofs) - throws IllegalInstructionException { - requireApfVersion(APF_VERSION_4); - return append(new Instruction(Opcodes.STDW, src).addSigned(ofs)); - } - } -
diff --git a/src/android/net/apf/ApfV4GeneratorBase.java b/src/android/net/apf/ApfV4GeneratorBase.java new file mode 100644 index 0000000..e27a5a2 --- /dev/null +++ b/src/android/net/apf/ApfV4GeneratorBase.java
@@ -0,0 +1,518 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package android.net.apf; + +import static android.net.apf.BaseApfGenerator.Rbit.Rbit0; +import static android.net.apf.BaseApfGenerator.Register.R0; +import static android.net.apf.BaseApfGenerator.Register.R1; + +import com.android.internal.annotations.VisibleForTesting; + +/** + * APF assembler/generator. A tool for generating an APF program. + * + * Call add*() functions to add instructions to the program, then call + * {@link BaseApfGenerator#generate} to get the APF bytecode for the program. + * <p> + * Choose between these approaches for your instruction helper methods: If the functionality must + * be identical across APF versions, make it a final method within the base class. If it needs + * version-specific adjustments, use an abstract method in the base class with final + * implementations in generator instances. + * + * @param <Type> the generator class + * + * @hide + */ +public abstract class ApfV4GeneratorBase<Type extends ApfV4GeneratorBase<Type>> extends + BaseApfGenerator { + + /** + * Creates an ApfV4GeneratorBase instance which is able to emit instructions for the specified + * {@code version} of the APF interpreter. Throws {@code IllegalInstructionException} if + * the requested version is unsupported.
+ */ + @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE) + public ApfV4GeneratorBase(int version) throws IllegalInstructionException { + super(version); + requireApfVersion(MIN_APF_VERSION); + } + + final Type self() { + return (Type) this; // unchecked: relies on each subclass binding Type to itself + } + + final Type append(Instruction instruction) { + if (mGenerated) { // no instructions may be appended once the program has been generated + throw new IllegalStateException("Program already generated"); + } + mInstructions.add(instruction); + return self(); + } + + /** + * Define a label at the current end of the program. Jumps can jump to this label. Labels are + * their own separate instructions, though with size 0. This facilitates having labels with + * no corresponding code to execute, for example a label at the end of a program. For example + * an {@link ApfV4GeneratorBase} might be passed to a function that adds a filter like so: + * <pre> + * load from packet + * compare loaded data, jump if not equal to "next_filter" + * load from packet + * compare loaded data, jump if not equal to "next_filter" + * jump to drop label + * define "next_filter" here + * </pre> + * In this case "next_filter" may not have any generated code associated with it. + */ + public final Type defineLabel(String name) throws IllegalInstructionException { + return append(new Instruction(Opcodes.LABEL).setLabel(name)); + } + + /** + * Add an unconditional jump instruction to the end of the program. + */ + public final Type addJump(String target) { + return append(new Instruction(Opcodes.JMP).setTargetLabel(target)); + } + + /** + * Add an unconditional jump instruction to the next instruction - i.e. a no-op. + */ + public final Type addNop() { + return append(new Instruction(Opcodes.JMP).addUnsigned(0)); + } + + /** + * Add an instruction to the end of the program to load the byte at offset {@code offset} + * bytes from the beginning of the packet into {@code register}.
+ */ + public final Type addLoad8(Register r, int ofs) { + return append(new Instruction(Opcodes.LDB, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to load 16-bits at offset {@code offset} + * bytes from the beginning of the packet into {@code register}. + */ + public final Type addLoad16(Register r, int ofs) { + return append(new Instruction(Opcodes.LDH, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to load 32-bits at offset {@code offset} + * bytes from the beginning of the packet into {@code register}. + */ + public final Type addLoad32(Register r, int ofs) { + return append(new Instruction(Opcodes.LDW, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to load a byte from the packet into + * {@code register}. The offset of the loaded byte from the beginning of the packet is + * the sum of {@code offset} and the value in register R1. + */ + public final Type addLoad8Indexed(Register r, int ofs) { + return append(new Instruction(Opcodes.LDBX, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to load 16-bits from the packet into + * {@code register}. The offset of the loaded 16-bits from the beginning of the packet is + * the sum of {@code offset} and the value in register R1. + */ + public final Type addLoad16Indexed(Register r, int ofs) { + return append(new Instruction(Opcodes.LDHX, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to load 32-bits from the packet into + * {@code register}. The offset of the loaded 32-bits from the beginning of the packet is + * the sum of {@code offset} and the value in register R1. + */ + public final Type addLoad32Indexed(Register r, int ofs) { + return append(new Instruction(Opcodes.LDWX, r).addPacketOffset(ofs)); + } + + /** + * Add an instruction to the end of the program to add {@code value} to register R0.
+ */ + public final Type addAdd(long val) { + if (val == 0) return self(); // nop, as APFv6 would '+= R1' + return append(new Instruction(Opcodes.ADD).addTwosCompUnsigned(val)); + } + + /** + * Add an instruction to the end of the program to subtract {@code value} from register R0. + */ + public final Type addSub(long val) { + return addAdd(-val); // note: addSub(4 billion) isn't valid, as addAdd(-4 billion) isn't + } + + /** + * Add an instruction to the end of the program to multiply register R0 by {@code value}. + */ + public final Type addMul(long val) { + if (val == 0) return addLoadImmediate(R0, 0); // equivalent, as APFv6 would '*= R1' + return append(new Instruction(Opcodes.MUL).addUnsigned(val)); + } + + /** + * Add an instruction to the end of the program to divide register R0 by {@code value}. + */ + public final Type addDiv(long val) { + if (val == 0) return addPass(); // equivalent, as APFv6 would '/= R1' + return append(new Instruction(Opcodes.DIV).addUnsigned(val)); + } + + /** + * Add an instruction to the end of the program to logically and register R0 with {@code value}. + */ + public final Type addAnd(long val) { + if (val == 0) return addLoadImmediate(R0, 0); // equivalent, as APFv6 would '&= R1' + return append(new Instruction(Opcodes.AND).addTwosCompUnsigned(val)); + } + + /** + * Add an instruction to the end of the program to logically or register R0 with {@code value}. + */ + public final Type addOr(long val) { + if (val == 0) return self(); // nop, as APFv6 would '|= R1' + return append(new Instruction(Opcodes.OR).addTwosCompUnsigned(val)); + } + + /** + * Add an instruction to the end of the program to shift left register R0 by {@code value} bits.
+ */ + // TODO: consider whether should change the argument type to byte + public final Type addLeftShift(int val) { + if (val == 0) return self(); // nop, as APFv6 would '<<= R1' + return append(new Instruction(Opcodes.SH).addSigned(val)); + } + + /** + * Add an instruction to the end of the program to shift right register R0 by {@code value} + * bits. + */ + // TODO: consider whether should change the argument type to byte + public final Type addRightShift(int val) { + return addLeftShift(-val); // SH with a negative amount shifts right; NOTE(review): -Integer.MIN_VALUE overflows + } + + // Argument should be one of Opcodes.{ADD,MUL,DIV,AND,OR,SH}; subclasses emit 'R0 <op>= R1' in their version's encoding + abstract void addArithR1(Opcodes opcode); + + /** + * Add an instruction to the end of the program to add register R1 to register R0. + */ + public final Type addAddR1ToR0() { + addArithR1(Opcodes.ADD); + return self(); + } + + /** + * Add an instruction to the end of the program to multiply register R0 by register R1. + */ + public final Type addMulR0ByR1() { + addArithR1(Opcodes.MUL); + return self(); + } + + /** + * Add an instruction to the end of the program to divide register R0 by register R1. + */ + public final Type addDivR0ByR1() { + addArithR1(Opcodes.DIV); + return self(); + } + + /** + * Add an instruction to the end of the program to logically and register R0 with register R1 + * and store the result back into register R0. + */ + public final Type addAndR0WithR1() { + addArithR1(Opcodes.AND); + return self(); + } + + /** + * Add an instruction to the end of the program to logically or register R0 with register R1 + * and store the result back into register R0. + */ + public final Type addOrR0WithR1() { + addArithR1(Opcodes.OR); + return self(); + } + + /** + * Add an instruction to the end of the program to shift register R0 left by the value in + * register R1. + */ + public final Type addLeftShiftR0ByR1() { + addArithR1(Opcodes.SH); + return self(); + } + + /** + * Add an instruction to the end of the program to move {@code value} into {@code register}.
+ */ + public final Type addLoadImmediate(Register register, int value) { + return append(new Instruction(Opcodes.LI, register).addSigned(value)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value equals {@code value}. + */ + public final Type addJumpIfR0Equals(long val, String tgt) { + return append(new Instruction(Opcodes.JEQ).addTwosCompUnsigned(val).setTargetLabel(tgt)); + } + + /** + * Add instructions to the end of the program to increase counter and drop packet if R0 equals + * {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add instructions to the end of the program to increase counter and pass packet if R0 equals + * {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value does not equal {@code value}.
+ */ + public final Type addJumpIfR0NotEquals(long val, String tgt) { + return append(new Instruction(Opcodes.JNE).addTwosCompUnsigned(val).setTargetLabel(tgt)); + } + + /** + * Add instructions to the end of the program to increase counter and drop packet if R0 not + * equals {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add instructions to the end of the program to increase counter and pass packet if R0 not + * equals {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value is greater than {@code value}. + */ + public final Type addJumpIfR0GreaterThan(long val, String tgt) { + return append(new Instruction(Opcodes.JGT).addUnsigned(val).setTargetLabel(tgt)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value is less than {@code value}.
+ */ + public final Type addJumpIfR0LessThan(long val, String tgt) { + return append(new Instruction(Opcodes.JLT).addUnsigned(val).setTargetLabel(tgt)); + } + + /** + * Add instructions to the end of the program to increase counter and drop packet if R0 less + * than {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add instructions to the end of the program to increase counter and pass packet if R0 less + * than {@code val} + * WARNING: may modify R1 + */ + public abstract Type addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException; + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value has any bits set that are also set in {@code value}. + */ + public final Type addJumpIfR0AnyBitsSet(long val, String tgt) { + return append(new Instruction(Opcodes.JSET).addTwosCompUnsigned(val).setTargetLabel(tgt)); + } + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value equals register R1's value. + */ + public final Type addJumpIfR0EqualsR1(String tgt) { + return append(new Instruction(Opcodes.JEQ, R1).setTargetLabel(tgt)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value does not equal register R1's value. + */ + public final Type addJumpIfR0NotEqualsR1(String tgt) { + return append(new Instruction(Opcodes.JNE, R1).setTargetLabel(tgt)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value is greater than register R1's value.
+ */ + public final Type addJumpIfR0GreaterThanR1(String tgt) { + return append(new Instruction(Opcodes.JGT, R1).setTargetLabel(tgt)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value is less than register R1's value. + */ + public final Type addJumpIfR0LessThanR1(String target) { + return append(new Instruction(Opcodes.JLT, R1).setTargetLabel(target)); + } + + /** + * Add an instruction to the end of the program to jump to {@code target} if register R0's + * value has any bits set that are also set in R1's value. + */ + public final Type addJumpIfR0AnyBitsSetR1(String tgt) { + return append(new Instruction(Opcodes.JSET, R1).setTargetLabel(tgt)); + } + + /** + * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the + * packet at an offset specified by register0 don't match {@code bytes}. + * R=0 means check for not equal. + */ + public final Type addJumpIfBytesAtR0NotEqual(byte[] bytes, String tgt) { + return append(new Instruction(Opcodes.JNEBS).addUnsigned( + bytes.length).setTargetLabel(tgt).setBytesImm(bytes)); + } + + /** + * Add instructions to the end of the program to increase counter and drop packet if the + * bytes of the packet at an offset specified by register0 don't match {@code bytes}. + * WARNING: may modify R1 + */ + public abstract Type addCountAndDropIfBytesAtR0NotEqual(byte[] bytes, + ApfCounterTracker.Counter cnt) throws IllegalInstructionException; + + /** + * Add an instruction to the end of the program to load memory slot {@code slot} into + * {@code register}. + */ + public final Type addLoadFromMemory(Register r, int slot) + throws IllegalInstructionException { + return append(new BaseApfGenerator.Instruction(ExtendedOpcodes.LDM, slot, r)); + } + + /** + * Add an instruction to the end of the program to store {@code register} into memory slot + * {@code slot}.
+ */ + public final Type addStoreToMemory(Register r, int slot) + throws IllegalInstructionException { + return append(new Instruction(ExtendedOpcodes.STM, slot, r)); + } + + /** + * Add an instruction to the end of the program to logically not {@code register}. + */ + public final Type addNot(Register r) { + return append(new Instruction(ExtendedOpcodes.NOT, r)); + } + + /** + * Add an instruction to the end of the program to negate {@code register}. + */ + public final Type addNeg(Register r) { + return append(new Instruction(ExtendedOpcodes.NEG, r)); + } + + /** + * Add an instruction to swap the values in register R0 and register R1. + */ + public final Type addSwap() { + return append(new Instruction(ExtendedOpcodes.SWAP)); + } + + /** + * Add an instruction to the end of the program to move the value into + * {@code register} from the other register. + */ + public final Type addMove(Register r) { + return append(new Instruction(ExtendedOpcodes.MOVE, r)); + } + + /** + * Add an instruction to the end of the program to let the program immediately return PASS. + */ + public final Type addPass() { + // PASS requires using Rbit0 because it shares opcode with DROP + return append(new Instruction(Opcodes.PASSDROP, Rbit0)); + } + + /** + * Abstract method for adding instructions to increment the counter and return PASS. + */ + public abstract Type addCountAndPass(ApfCounterTracker.Counter counter); + + /** + * Abstract method for adding instructions to increment the counter and return DROP. + */ + public abstract Type addCountAndDrop(ApfCounterTracker.Counter counter); + + /** + * Add an instruction to the end of the program to load 32 bits from the data memory into + * {@code dst}. The source address is computed by adding the signed immediate + * {@code ofs} to the other register. + * Requires APF v4 or greater.
+ */ + public final Type addLoadData(Register dst, int ofs) + throws IllegalInstructionException { + requireApfVersion(APF_VERSION_4); + return append(new Instruction(Opcodes.LDDW, dst).addSigned(ofs)); + } + + /** + * Add an instruction to the end of the program to store 32 bits from {@code src} into the + * data memory. The destination address is computed by adding the signed immediate + * {@code ofs} to the other register. + * Requires APF v4 or greater. + */ + public final Type addStoreData(Register src, int ofs) + throws IllegalInstructionException { + requireApfVersion(APF_VERSION_4); + return append(new Instruction(Opcodes.STDW, src).addSigned(ofs)); + } + + + /** + * Append the version-specific count trampoline (shared counter-increment code); may be a no-op. + * @return this generator, for call chaining + * @throws IllegalInstructionException if an instruction of the trampoline cannot be emitted + */ + public abstract Type addCountTrampoline() throws IllegalInstructionException; +} +
diff --git a/src/android/net/apf/ApfV6Generator.java b/src/android/net/apf/ApfV6Generator.java index 40f5778..da624b2 100644 --- a/src/android/net/apf/ApfV6Generator.java +++ b/src/android/net/apf/ApfV6Generator.java
@@ -15,383 +15,131 @@ */ package android.net.apf; -import static android.net.apf.BaseApfGenerator.Rbit.Rbit0; -import static android.net.apf.BaseApfGenerator.Rbit.Rbit1; +import static android.net.apf.BaseApfGenerator.Register.R1; -import androidx.annotation.NonNull; - -import com.android.net.module.util.HexDump; +import com.android.internal.annotations.VisibleForTesting; /** * APFv6 assembler/generator. A tool for generating an APFv6 program. * * @hide */ -public class ApfV6Generator extends ApfV4Generator<ApfV6Generator> { - +public final class ApfV6Generator extends ApfV6GeneratorBase<ApfV6Generator> { /** * Creates an ApfV6Generator instance which is able to emit instructions for the specified * {@code version} of the APF interpreter. Throws {@code IllegalInstructionException} if * the requested version is unsupported. - * */ + @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE) public ApfV6Generator() throws IllegalInstructionException { - super(MIN_APF_VERSION_IN_DEV); + super(); + } + + @Override + void addArithR1(Opcodes opcode) { + append(new Instruction(opcode, R1)); } /** * Add an instruction to the end of the program to increment the counter value and * immediately return PASS. + * + * @param counter the counter enum to be incremented. */ - public ApfV6Generator addCountAndPass(int cnt) { - checkRange("CounterNumber", cnt /* value */, 1 /* lowerBound */, - 1000 /* upperBound */); - // PASS requires using Rbit0 because it shares opcode with DROP - return append(new Instruction(Opcodes.PASSDROP, Rbit0).addUnsigned(cnt)); - } - - /** - * Add an instruction to the end of the program to let the program immediately return DROP.
- */ - public ApfV6Generator addDrop() { - // DROP requires using Rbit1 because it shares opcode with PASS - return append(new Instruction(Opcodes.PASSDROP, Rbit1)); + @Override + public ApfV6Generator addCountAndPass(ApfCounterTracker.Counter counter) { + checkPassCounterRange(counter); + return addCountAndPass(counter.value()); } /** * Add an instruction to the end of the program to increment the counter value and * immediately return DROP. - */ - public ApfV6Generator addCountAndDrop(int cnt) { - checkRange("CounterNumber", cnt /* value */, 1 /* lowerBound */, - 1000 /* upperBound */); - // DROP requires using Rbit1 because it shares opcode with PASS - return append(new Instruction(Opcodes.PASSDROP, Rbit1).addUnsigned(cnt)); - } - - /** - * Add an instruction to the end of the program to call the apf_allocate_buffer() function. - * Buffer length to be allocated is stored in register 0. - */ - public ApfV6Generator addAllocateR0() { - return append(new Instruction(ExtendedOpcodes.ALLOCATE)); - } - - /** - * Add an instruction to the end of the program to call the apf_allocate_buffer() function. * - * @param size the buffer length to be allocated. + * @param counter the counter enum to be incremented. */ - public ApfV6Generator addAllocate(int size) { - // Rbit1 means the extra be16 immediate is present - return append(new Instruction(ExtendedOpcodes.ALLOCATE, Rbit1).addU16(size)); + @Override + public ApfV6Generator addCountAndDrop(ApfCounterTracker.Counter counter) { + checkDropCounterRange(counter); + return addCountAndDrop(counter.value()); } - /** - * Add an instruction to the beginning of the program to reserve the data region.
- * @param data the actual data byte - */ - public ApfV6Generator addData(byte[] data) throws IllegalInstructionException { - if (!mInstructions.isEmpty()) { - throw new IllegalInstructionException("data instruction has to come first"); + @Override + public ApfV6Generator addCountAndDropIfR0Equals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkDropCounterRange(cnt); + final String tgt = getUniqueLabel(); + return addJumpIfR0NotEquals(val, tgt).addCountAndDrop(cnt).defineLabel(tgt); + } + + @Override + public ApfV6Generator addCountAndPassIfR0Equals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkPassCounterRange(cnt); + final String tgt = getUniqueLabel(); + return addJumpIfR0NotEquals(val, tgt).addCountAndPass(cnt).defineLabel(tgt); + } + + @Override + public ApfV6Generator addCountAndDropIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkDropCounterRange(cnt); + final String tgt = getUniqueLabel(); + return addJumpIfR0Equals(val, tgt).addCountAndDrop(cnt).defineLabel(tgt); + } + + @Override + public ApfV6Generator addCountAndPassIfR0NotEquals(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkPassCounterRange(cnt); + final String tgt = getUniqueLabel(); + return addJumpIfR0Equals(val, tgt).addCountAndPass(cnt).defineLabel(tgt); + } + + @Override + public ApfV6Generator addCountAndDropIfR0LessThan(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkDropCounterRange(cnt); + if (val <= 0) { + throw new IllegalArgumentException("val must be > 0, current val: " + val); } - return append(new Instruction(Opcodes.JMP, Rbit1).addUnsigned(data.length) - .setBytesImm(data)); + final String tgt = getUniqueLabel(); + return addJumpIfR0GreaterThan(val - 1, tgt).addCountAndDrop(cnt).defineLabel(tgt); } - /** - * Add an instruction to the end of the program to transmit the allocated
buffer without - * checksum. - */ - public ApfV6Generator addTransmitWithoutChecksum() { - return addTransmit(-1 /* ipOfs */); - } - - /** - * Add an instruction to the end of the program to transmit the allocated buffer. - */ - public ApfV6Generator addTransmit(int ipOfs) { - if (ipOfs >= 255) { - throw new IllegalArgumentException("IP offset of " + ipOfs + " must be < 255"); + @Override + public ApfV6Generator addCountAndPassIfR0LessThan(long val, ApfCounterTracker.Counter cnt) + throws IllegalInstructionException { + checkPassCounterRange(cnt); + if (val <= 0) { + throw new IllegalArgumentException("val must be > 0, current val: " + val); } - if (ipOfs == -1) ipOfs = 255; - return append(new Instruction(ExtendedOpcodes.TRANSMIT, Rbit0).addU8(ipOfs).addU8(255)); + final String tgt = getUniqueLabel(); + return addJumpIfR0GreaterThan(val - 1, tgt).addCountAndPass(cnt).defineLabel(tgt); + } + + @Override + public ApfV6Generator addCountAndDropIfBytesAtR0NotEqual(byte[] bytes, + ApfCounterTracker.Counter cnt) throws IllegalInstructionException { + checkDropCounterRange(cnt); + final String tgt = getUniqueLabel(); + return addJumpIfBytesAtR0Equal(bytes, tgt).addCountAndDrop(cnt).defineLabel(tgt); + } + + private int mLabelCount = 0; + + /** + * Returns a unique label string: {@code "LABEL_"} followed by an incrementing counter. + */ + private String getUniqueLabel() { + return "LABEL_" + mLabelCount++; } /** - * Add an instruction to the end of the program to transmit the allocated buffer. + * This method is a no-op in APFv6. */ - public ApfV6Generator addTransmitL4(int ipOfs, int csumOfs, int csumStart, int partialCsum, - boolean isUdp) { - if (ipOfs >= 255) { - throw new IllegalArgumentException("IP offset of " + ipOfs + " must be < 255"); - } - if (ipOfs == -1) ipOfs = 255; - if (csumOfs >= 255) { - throw new IllegalArgumentException("L4 checksum requires csum offset of " - + csumOfs + " < 255"); - } - return append(new Instruction(ExtendedOpcodes.TRANSMIT, isUdp ?
Rbit1 : Rbit0) - .addU8(ipOfs).addU8(csumOfs).addU8(csumStart).addU16(partialCsum)); - } - - /** - * Add an instruction to the end of the program to write 1 byte value to output buffer. - */ - public ApfV6Generator addWriteU8(int val) { - return append(new Instruction(Opcodes.WRITE).overrideLenField(1).addU8(val)); - } - - /** - * Add an instruction to the end of the program to write 2 bytes value to output buffer. - */ - public ApfV6Generator addWriteU16(int val) { - return append(new Instruction(Opcodes.WRITE).overrideLenField(2).addU16(val)); - } - - /** - * Add an instruction to the end of the program to write 4 bytes value to output buffer. - */ - public ApfV6Generator addWriteU32(long val) { - return append(new Instruction(Opcodes.WRITE).overrideLenField(4).addU32(val)); - } - - /** - * Add an instruction to the end of the program to write 1 byte value from register to output - * buffer. - */ - public ApfV6Generator addWriteU8(Register reg) { - return append(new Instruction(ExtendedOpcodes.EWRITE1, reg)); - } - - /** - * Add an instruction to the end of the program to write 2 byte value from register to output - * buffer. - */ - public ApfV6Generator addWriteU16(Register reg) { - return append(new Instruction(ExtendedOpcodes.EWRITE2, reg)); - } - - /** - * Add an instruction to the end of the program to write 4 byte value from register to output - * buffer. - */ - public ApfV6Generator addWriteU32(Register reg) { - return append(new Instruction(ExtendedOpcodes.EWRITE4, reg)); - } - - /** - * Add an instruction to the end of the program to copy data from APF program/data region to - * output buffer and auto-increment the output buffer pointer. - * - * @param src the offset inside the APF program/data region for where to start copy. - * @param len the length of bytes needed to be copied, only <= 255 bytes can be copied at - * one time.
- * @return the ApfV6Generator object - */ - public ApfV6Generator addDataCopy(int src, int len) { - return append(new Instruction(Opcodes.PKTDATACOPY, Rbit1).addDataOffset(src).addU8(len)); - } - - /** - * Add an instruction to the end of the program to copy data from input packet to output - * buffer and auto-increment the output buffer pointer. - * - * @param src the offset inside the input packet for where to start copy. - * @param len the length of bytes needed to be copied, only <= 255 bytes can be copied at - * one time. - * @return the ApfV6Generator object - */ - public ApfV6Generator addPacketCopy(int src, int len) { - return append(new Instruction(Opcodes.PKTDATACOPY, Rbit0).addPacketOffset(src).addU8(len)); - } - - /** - * Add an instruction to the end of the program to copy data from APF program/data region to - * output buffer and auto-increment the output buffer pointer. - * Source offset is stored in R0. - * - * @param len the number of bytes to be copied, only <= 255 bytes can be copied at once. - * @return the ApfV6Generator object - */ - public ApfV6Generator addDataCopyFromR0(int len) { - return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYIMM, Rbit1).addU8(len)); - } - - /** - * Add an instruction to the end of the program to copy data from input packet to output - * buffer and auto-increment the output buffer pointer. - * Source offset is stored in R0. - * - * @param len the number of bytes to be copied, only <= 255 bytes can be copied at once. - * @return the ApfV6Generator object - */ - public ApfV6Generator addPacketCopyFromR0(int len) { - return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYIMM, Rbit0).addU8(len)); - } - - /** - * Add an instruction to the end of the program to copy data from APF program/data region to - * output buffer and auto-increment the output buffer pointer. - * Source offset is stored in R0. - * Copy length is stored in R1.
- * - * @return the ApfV6Generator object - */ - public ApfV6Generator addDataCopyFromR0LenR1() { - return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYR1, Rbit1)); - } - - /** - * Add an instruction to the end of the program to copy data from input packet to output - * buffer and auto-increment the output buffer pointer. - * Source offset is stored in R0. - * Copy length is stored in R1. - * - * @return the ApfV6Generator object - */ - public ApfV6Generator addPacketCopyFromR0LenR1() { - return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYR1, Rbit0)); - } - - /** - * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP - * payload's DNS questions do NOT contain the QNAMEs specified in {@code qnames} and qtype - * equals {@code qtype}. Examines the payload starting at the offset in R0. - * R = 0 means check for "does not contain". - * Drops packets if packets are corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0DoesNotContainDnsQ(@NonNull byte[] qnames, int qtype, - @NonNull String tgt) { - validateNames(qnames); - return append(new Instruction(ExtendedOpcodes.JDNSQMATCH, Rbit0).setTargetLabel(tgt).addU8( - qtype).setBytesImm(qnames)); - } - - /** - * Same as {@link #addJumpIfPktAtR0DoesNotContainDnsQ} except passes packets if packets are - * corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0DoesNotContainDnsQSafe(@NonNull byte[] qnames, int qtype, - @NonNull String tgt) { - validateNames(qnames); - return append(new Instruction(ExtendedOpcodes.JDNSQMATCHSAFE, Rbit0).setTargetLabel( - tgt).addU8(qtype).setBytesImm(qnames)); - } - - /** - * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP - * payload's DNS questions contain the QNAMEs specified in {@code qnames} and qtype - * equals {@code qtype}. Examines the payload starting at the offset in R0. - * R = 1 means check for "contain". - * Drops packets if packets are corrupted.
- */ - public ApfV6Generator addJumpIfPktAtR0ContainDnsQ(@NonNull byte[] qnames, int qtype, - @NonNull String tgt) { - validateNames(qnames); - return append(new Instruction(ExtendedOpcodes.JDNSQMATCH, Rbit1).setTargetLabel(tgt).addU8( - qtype).setBytesImm(qnames)); - } - - /** - * Same as {@link #addJumpIfPktAtR0ContainDnsQ} except passes packets if packets are - * corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0ContainDnsQSafe(@NonNull byte[] qnames, int qtype, - @NonNull String tgt) { - validateNames(qnames); - return append(new Instruction(ExtendedOpcodes.JDNSQMATCHSAFE, Rbit1).setTargetLabel( - tgt).addU8(qtype).setBytesImm(qnames)); - } - - /** - * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP - * payload's DNS answers/authority/additional records do NOT contain the NAMEs - * specified in {@code Names}. Examines the payload starting at the offset in R0. - * R = 0 means check for "does not contain". - * Drops packets if packets are corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0DoesNotContainDnsA(@NonNull byte[] names, - @NonNull String tgt) { - validateNames(names); - return append(new Instruction(ExtendedOpcodes.JDNSAMATCH, Rbit0).setTargetLabel(tgt) - .setBytesImm(names)); - } - - /** - * Same as {@link #addJumpIfPktAtR0DoesNotContainDnsA} except passes packets if packets are - * corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0DoesNotContainDnsASafe(@NonNull byte[] names, - @NonNull String tgt) { - validateNames(names); - return append(new Instruction(ExtendedOpcodes.JDNSAMATCHSAFE, Rbit0).setTargetLabel(tgt) - .setBytesImm(names)); - } - - /** - * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP - * payload's DNS answers/authority/additional records contain the NAMEs - * specified in {@code Names}. Examines the payload starting at the offset in R0. - * R = 1 means check for "contain". - * Drops packets if packets are corrupted.
- */ - public ApfV6Generator addJumpIfPktAtR0ContainDnsA(@NonNull byte[] names, - @NonNull String tgt) { - validateNames(names); - return append(new Instruction(ExtendedOpcodes.JDNSAMATCH, Rbit1).setTargetLabel( - tgt).setBytesImm(names)); - } - - /** - * Same as {@link #addJumpIfPktAtR0ContainDnsA} except passes packets if packets are - * corrupted. - */ - public ApfV6Generator addJumpIfPktAtR0ContainDnsASafe(@NonNull byte[] names, - @NonNull String tgt) { - validateNames(names); - return append(new Instruction(ExtendedOpcodes.JDNSAMATCHSAFE, Rbit1).setTargetLabel( - tgt).setBytesImm(names)); - } - - /** - * Check if the byte is valid dns character: A-Z,0-9,-,_ - */ - private static boolean isValidDnsCharacter(byte c) { - return (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '%'; - } - - private static void validateNames(@NonNull byte[] names) { - final int len = names.length; - if (len < 4) { - throw new IllegalArgumentException("qnames must have at least length 4"); - } - final String errorMessage = "qname: " + HexDump.toHexString(names) - + "is not null-terminated list of TLV-encoded names"; - int i = 0; - while (i < len - 1) { - int label_len = names[i++]; - // byte == 0xff means it is a '*' wildcard - if (label_len == -1) continue; - if (label_len < 1 || label_len > 63) { - throw new IllegalArgumentException( - "label len: " + label_len + " must be between 1 and 63"); - } - if (i + label_len >= len - 1) { - throw new IllegalArgumentException(errorMessage); - } - while (label_len-- > 0) { - if (!isValidDnsCharacter(names[i++])) { - throw new IllegalArgumentException("qname: " + HexDump.toHexString(names) - + " contains invalid character"); - } - } - if (names[i] == 0) { - i++; // skip null terminator. - } - } - if (names[len - 1] != 0) { - throw new IllegalArgumentException(errorMessage); - } + @Override + public ApfV6Generator addCountTrampoline() { + return self(); } }
diff --git a/src/android/net/apf/ApfV6GeneratorBase.java b/src/android/net/apf/ApfV6GeneratorBase.java new file mode 100644 index 0000000..99f07c2 --- /dev/null +++ b/src/android/net/apf/ApfV6GeneratorBase.java
@@ -0,0 +1,472 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.net.apf; + +import static android.net.apf.BaseApfGenerator.Rbit.Rbit0; +import static android.net.apf.BaseApfGenerator.Rbit.Rbit1; +import static android.net.apf.BaseApfGenerator.Register.R1; + +import androidx.annotation.NonNull; + +import com.android.net.module.util.HexDump; + +import java.util.Objects; + +/** + * The abstract class for APFv6 assembler/generator. + * + * @param <Type> the generator class + * + * @hide + */ +public abstract class ApfV6GeneratorBase<Type extends ApfV6GeneratorBase<Type>> extends + ApfV4GeneratorBase<Type> { + + // We have not *yet* switched to APFv6 mode (see addData), + // and are thus still in APFv2/4 backward compatibility mode. + boolean mIsV6 = false; + + /** + * Creates an ApfV6GeneratorBase instance which emits instructions for the in-development + * APF interpreter version ({@code MIN_APF_VERSION_IN_DEV}). Throws + * {@code IllegalInstructionException} if the base generator rejects that version. + * + */ + public ApfV6GeneratorBase() throws IllegalInstructionException { + super(MIN_APF_VERSION_IN_DEV); + } + + /** + * Add an instruction to the end of the program to increment the counter value and + * immediately return PASS. + * + * @param cnt the counter number to be incremented.
+ */ + public final Type addCountAndPass(int cnt) { + checkRange("CounterNumber", cnt /* value */, 1 /* lowerBound */, + 1000 /* upperBound */); + // PASS requires using Rbit0 because it shares opcode with DROP + return append(new Instruction(Opcodes.PASSDROP, Rbit0).addUnsigned(cnt)); + } + + /** + * Add an instruction to the end of the program to let the program immediately return DROP. + */ + public final Type addDrop() { + // DROP requires using Rbit1 because it shares opcode with PASS + return append(new Instruction(Opcodes.PASSDROP, Rbit1)); + } + + /** + * Add an instruction to the end of the program to increment the counter value and + * immediately return DROP. + * + * @param cnt the counter number to be incremented. + */ + public final Type addCountAndDrop(int cnt) { + checkRange("CounterNumber", cnt /* value */, 1 /* lowerBound */, + 1000 /* upperBound */); + // DROP requires using Rbit1 because it shares opcode with PASS + return append(new Instruction(Opcodes.PASSDROP, Rbit1).addUnsigned(cnt)); + } + + /** + * Add an instruction to the end of the program to call the apf_allocate_buffer() function. + * Buffer length to be allocated is stored in register 0. + */ + public final Type addAllocateR0() { + return append(new Instruction(ExtendedOpcodes.ALLOCATE)); + } + + /** + * Add an instruction to the end of the program to call the apf_allocate_buffer() function. + * + * @param size the buffer length to be allocated. + */ + public final Type addAllocate(int size) { + // Rbit1 means the extra be16 immediate is present + return append(new Instruction(ExtendedOpcodes.ALLOCATE, Rbit1).addU16(size)); + } + + /** + * Add an instruction to the beginning of the program to reserve the empty data region. + */ + public final Type addData() throws IllegalInstructionException { + return addData(new byte[0]); + } + + /** + * Add an instruction to the beginning of the program to reserve the data region.
+ * @param data the bytes to place in the data region (at most 65535 bytes) + */ + public final Type addData(byte[] data) throws IllegalInstructionException { + if (!mInstructions.isEmpty()) { + throw new IllegalInstructionException("data instruction has to come first"); + } + if (data.length > 65535) { + throw new IllegalArgumentException("data size larger than 65535"); + } + mIsV6 = true; + return append(new Instruction(Opcodes.JMP, Rbit1).addUnsigned(data.length) + .setBytesImm(data).overrideImmSize(2)); + } + + /** + * Add an instruction to the end of the program to transmit the allocated buffer without + * checksum. + */ + public final Type addTransmitWithoutChecksum() { + return addTransmit(-1 /* ipOfs */); + } + + /** + * Add an instruction to the end of the program to transmit the allocated buffer. + */ + public final Type addTransmit(int ipOfs) { + if (ipOfs >= 255) { + throw new IllegalArgumentException("IP offset of " + ipOfs + " must be < 255"); + } + if (ipOfs == -1) ipOfs = 255; + return append(new Instruction(ExtendedOpcodes.TRANSMIT, Rbit0).addU8(ipOfs).addU8(255)); + } + + /** + * Add an instruction to the end of the program to transmit the allocated buffer. + */ + public final Type addTransmitL4(int ipOfs, int csumOfs, int csumStart, int partialCsum, + boolean isUdp) { + if (ipOfs >= 255) { + throw new IllegalArgumentException("IP offset of " + ipOfs + " must be < 255"); + } + if (ipOfs == -1) ipOfs = 255; + if (csumOfs >= 255) { + throw new IllegalArgumentException("L4 checksum requires csum offset of " + + csumOfs + " < 255"); + } + return append(new Instruction(ExtendedOpcodes.TRANSMIT, isUdp ? Rbit1 : Rbit0) + .addU8(ipOfs).addU8(csumOfs).addU8(csumStart).addU16(partialCsum)); + } + + /** + * Add an instruction to the end of the program to write 1 byte value to output buffer.
+ */ + public final Type addWriteU8(int val) { + return append(new Instruction(Opcodes.WRITE).overrideImmSize(1).addU8(val)); + } + + /** + * Add an instruction to the end of the program to write 2 bytes value to output buffer. + */ + public final Type addWriteU16(int val) { + return append(new Instruction(Opcodes.WRITE).overrideImmSize(2).addU16(val)); + } + + /** + * Add an instruction to the end of the program to write 4 bytes value to output buffer. + */ + public final Type addWriteU32(long val) { + return append(new Instruction(Opcodes.WRITE).overrideImmSize(4).addU32(val)); + } + + /** + * Add an instruction to the end of the program to encode int value as 4 bytes to output buffer. + */ + public final Type addWrite32(int val) { + return addWriteU32((long) val & 0xffffffffL); + } + + /** + * Add an instruction to the end of the program to write 4 bytes array to output buffer. + */ + public final Type addWrite32(@NonNull byte[] bytes) { + Objects.requireNonNull(bytes); + if (bytes.length != 4) { + throw new IllegalArgumentException( + "bytes array size must be 4, current size: " + bytes.length); + } + return addWrite32(((bytes[0] & 0xff) << 24) + | ((bytes[1] & 0xff) << 16) + | ((bytes[2] & 0xff) << 8) + | (bytes[3] & 0xff)); + } + + /** + * Add an instruction to the end of the program to write 1 byte value from register to output + * buffer. + */ + public final Type addWriteU8(Register reg) { + return append(new Instruction(ExtendedOpcodes.EWRITE1, reg)); + } + + /** + * Add an instruction to the end of the program to write 2 byte value from register to output + * buffer. + */ + public final Type addWriteU16(Register reg) { + return append(new Instruction(ExtendedOpcodes.EWRITE2, reg)); + } + + /** + * Add an instruction to the end of the program to write 4 byte value from register to output + * buffer.
+ */ + public final Type addWriteU32(Register reg) { + return append(new Instruction(ExtendedOpcodes.EWRITE4, reg)); + } + + /** + * Add an instruction to the end of the program to copy data from APF program/data region to + * output buffer and auto-increment the output buffer pointer. + * This method requires the {@code addData} method to be called beforehand. + * It will first attempt to match {@code content} with existing data bytes. If there is no + * match, the {@code content} is appended to the data bytes. + */ + public final Type addDataCopy(@NonNull byte[] content) throws IllegalInstructionException { + if (mInstructions.isEmpty()) { + throw new IllegalInstructionException("There is no instructions"); + } + Objects.requireNonNull(content); + int copySrc = mInstructions.get(0).maybeUpdateBytesImm(content); + return addDataCopy(copySrc, content.length); + } + + /** + * Add an instruction to the end of the program to copy data from APF program/data region to + * output buffer and auto-increment the output buffer pointer. + * + * @param src the offset inside the APF program/data region for where to start copy. + * @param len the length of bytes needed to be copied, only <= 255 bytes can be copied at + * one time. + * @return the Type object + */ + public final Type addDataCopy(int src, int len) { + return append(new Instruction(Opcodes.PKTDATACOPY, Rbit1).addDataOffset(src).addU8(len)); + } + + /** + * Add an instruction to the end of the program to copy data from input packet to output + * buffer and auto-increment the output buffer pointer. + * + * @param src the offset inside the input packet for where to start copy. + * @param len the length of bytes needed to be copied, only <= 255 bytes can be copied at + * one time.
+ * @return the Type object + */ + public final Type addPacketCopy(int src, int len) { + return append(new Instruction(Opcodes.PKTDATACOPY, Rbit0).addPacketOffset(src).addU8(len)); + } + + /** + * Add an instruction to the end of the program to copy data from APF program/data region to + * output buffer and auto-increment the output buffer pointer. + * Source offset is stored in R0. + * + * @param len the number of bytes to be copied, only <= 255 bytes can be copied at once. + * @return the Type object + */ + public final Type addDataCopyFromR0(int len) { + return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYIMM, Rbit1).addU8(len)); + } + + /** + * Add an instruction to the end of the program to copy data from input packet to output + * buffer and auto-increment the output buffer pointer. + * Source offset is stored in R0. + * + * @param len the number of bytes to be copied, only <= 255 bytes can be copied at once. + * @return the Type object + */ + public final Type addPacketCopyFromR0(int len) { + return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYIMM, Rbit0).addU8(len)); + } + + /** + * Add an instruction to the end of the program to copy data from APF program/data region to + * output buffer and auto-increment the output buffer pointer. + * Source offset is stored in R0. + * Copy length is stored in R1. + * + * @return the Type object + */ + public final Type addDataCopyFromR0LenR1() { + return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYR1, Rbit1)); + } + + /** + * Add an instruction to the end of the program to copy data from input packet to output + * buffer and auto-increment the output buffer pointer. + * Source offset is stored in R0. + * Copy length is stored in R1.
+ * + * @return the Type object + */ + public final Type addPacketCopyFromR0LenR1() { + return append(new Instruction(ExtendedOpcodes.EPKTDATACOPYR1, Rbit0)); + } + + /** + * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP + * payload's DNS questions do NOT contain the QNAMEs specified in {@code qnames} and qtype + * equals {@code qtype}. Examines the payload starting at the offset in R0. + * R = 0 means check for "does not contain". + * Drops packets if packets are corrupted. + */ + public final Type addJumpIfPktAtR0DoesNotContainDnsQ(@NonNull byte[] qnames, int qtype, + @NonNull String tgt) { + validateNames(qnames); + return append(new Instruction(ExtendedOpcodes.JDNSQMATCH, Rbit0).setTargetLabel(tgt).addU8( + qtype).setBytesImm(qnames)); + } + + /** + * Same as {@link #addJumpIfPktAtR0DoesNotContainDnsQ} except passes packets if packets are + * corrupted. + */ + public final Type addJumpIfPktAtR0DoesNotContainDnsQSafe(@NonNull byte[] qnames, int qtype, + @NonNull String tgt) { + validateNames(qnames); + return append(new Instruction(ExtendedOpcodes.JDNSQMATCHSAFE, Rbit0).setTargetLabel( + tgt).addU8(qtype).setBytesImm(qnames)); + } + + /** + * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP + * payload's DNS questions contain the QNAMEs specified in {@code qnames} and qtype + * equals {@code qtype}. Examines the payload starting at the offset in R0. + * R = 1 means check for "contain". + * Drops packets if packets are corrupted. + */ + public final Type addJumpIfPktAtR0ContainDnsQ(@NonNull byte[] qnames, int qtype, + @NonNull String tgt) { + validateNames(qnames); + return append(new Instruction(ExtendedOpcodes.JDNSQMATCH, Rbit1).setTargetLabel(tgt).addU8( + qtype).setBytesImm(qnames)); + } + + /** + * Same as {@link #addJumpIfPktAtR0ContainDnsQ} except passes packets if packets are + * corrupted.
+ */ + public final Type addJumpIfPktAtR0ContainDnsQSafe(@NonNull byte[] qnames, int qtype, + @NonNull String tgt) { + validateNames(qnames); + return append(new Instruction(ExtendedOpcodes.JDNSQMATCHSAFE, Rbit1).setTargetLabel( + tgt).addU8(qtype).setBytesImm(qnames)); + } + + /** + * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP + * payload's DNS answers/authority/additional records do NOT contain the NAMEs + * specified in {@code Names}. Examines the payload starting at the offset in R0. + * R = 0 means check for "does not contain". + * Drops packets if packets are corrupted. + */ + public final Type addJumpIfPktAtR0DoesNotContainDnsA(@NonNull byte[] names, + @NonNull String tgt) { + validateNames(names); + return append(new Instruction(ExtendedOpcodes.JDNSAMATCH, Rbit0).setTargetLabel(tgt) + .setBytesImm(names)); + } + + /** + * Same as {@link #addJumpIfPktAtR0DoesNotContainDnsA} except passes packets if packets are + * corrupted. + */ + public final Type addJumpIfPktAtR0DoesNotContainDnsASafe(@NonNull byte[] names, + @NonNull String tgt) { + validateNames(names); + return append(new Instruction(ExtendedOpcodes.JDNSAMATCHSAFE, Rbit0).setTargetLabel(tgt) + .setBytesImm(names)); + } + + /** + * Appends a conditional jump instruction to the program: Jumps to {@code tgt} if the UDP + * payload's DNS answers/authority/additional records contain the NAMEs + * specified in {@code Names}. Examines the payload starting at the offset in R0. + * R = 1 means check for "contain". + * Drops packets if packets are corrupted. + */ + public final Type addJumpIfPktAtR0ContainDnsA(@NonNull byte[] names, + @NonNull String tgt) { + validateNames(names); + return append(new Instruction(ExtendedOpcodes.JDNSAMATCH, Rbit1).setTargetLabel( + tgt).setBytesImm(names)); + } + + /** + * Same as {@link #addJumpIfPktAtR0ContainDnsA} except passes packets if packets are + * corrupted.
+ */ + public final Type addJumpIfPktAtR0ContainDnsASafe(@NonNull byte[] names, + @NonNull String tgt) { + validateNames(names); + return append(new Instruction(ExtendedOpcodes.JDNSAMATCHSAFE, Rbit1).setTargetLabel( + tgt).setBytesImm(names)); + } + + /** + * Add an instruction to the end of the program to jump to {@code tgt} if the bytes of the + * packet at an offset specified by register0 match {@code bytes}. + * R=1 means check for equal. + */ + public final Type addJumpIfBytesAtR0Equal(byte[] bytes, String tgt) + throws IllegalInstructionException { + return append(new Instruction(Opcodes.JNEBS, R1).addUnsigned( + bytes.length).setTargetLabel(tgt).setBytesImm(bytes)); + } + + + /** + * Check if the byte is a valid DNS character: A-Z, 0-9, '-', '_' or '%'. + */ + private static boolean isValidDnsCharacter(byte c) { + return (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '%'; + } + + private static void validateNames(@NonNull byte[] names) { + final int len = names.length; + if (len < 4) { + throw new IllegalArgumentException("qnames must have at least length 4"); + } + final String errorMessage = "qname: " + HexDump.toHexString(names) + + " is not null-terminated list of TLV-encoded names"; + int i = 0; + while (i < len - 1) { + int label_len = names[i++]; + // byte == 0xff means it is a '*' wildcard + if (label_len == -1) continue; + if (label_len < 1 || label_len > 63) { + throw new IllegalArgumentException( + "label len: " + label_len + " must be between 1 and 63"); + } + if (i + label_len >= len - 1) { + throw new IllegalArgumentException(errorMessage); + } + while (label_len-- > 0) { + if (!isValidDnsCharacter(names[i++])) { + throw new IllegalArgumentException("qname: " + HexDump.toHexString(names) + + " contains invalid character"); + } + } + if (names[i] == 0) { + i++; // skip null terminator. + } + } + if (names[len - 1] != 0) { + throw new IllegalArgumentException(errorMessage); + } + } +}
diff --git a/src/android/net/apf/BaseApfGenerator.java b/src/android/net/apf/BaseApfGenerator.java index 75ef639..859f80b 100644 --- a/src/android/net/apf/BaseApfGenerator.java +++ b/src/android/net/apf/BaseApfGenerator.java
@@ -22,6 +22,10 @@ import androidx.annotation.NonNull; +import com.android.net.module.util.ByteUtils; +import com.android.net.module.util.CollectionUtils; +import com.android.net.module.util.HexDump; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -313,16 +317,16 @@ } class Instruction { - private final Opcodes mOpcode; + public final Opcodes mOpcode; private final Rbit mRbit; public final List<IntImmediate> mIntImms = new ArrayList<>(); // When mOpcode is a jump: private int mTargetLabelSize; - private int mLenFieldOverride = -1; + private int mImmSizeOverride = -1; private String mTargetLabel; // When mOpcode == Opcodes.LABEL: private String mLabel; - private byte[] mBytesImm; + public byte[] mBytesImm; // Offset in bytes from the beginning of this program. // Set by {@link BaseApfGenerator#generate}. int offset; @@ -383,13 +387,12 @@ return addUnsigned(imm); } - Instruction addTwosCompSigned(int imm) { + Instruction addTwosCompSigned(long imm) { mIntImms.add(IntImmediate.newTwosComplementSigned(imm)); return this; } - - Instruction addTwosCompUnsigned(int imm) { + Instruction addTwosCompUnsigned(long imm) { mIntImms.add(IntImmediate.newTwosComplementUnsigned(imm)); return this; } @@ -442,8 +445,8 @@ return this; } - Instruction overrideLenField(int size) { - mLenFieldOverride = size; + Instruction overrideImmSize(int size) { + mImmSizeOverride = size; return this; } @@ -453,6 +456,36 @@ } /** + * Attempts to match {@code content} with existing data bytes. If not exist, then + * append the {@code content} to the data bytes. + * Returns the start offset of the content from the beginning of the program. + */ + int maybeUpdateBytesImm(byte[] content) throws IllegalInstructionException { + if (mOpcode != Opcodes.JMP || mBytesImm == null) { + throw new IllegalInstructionException(String.format( + "maybeUpdateBytesImm() is only valid for jump data instruction, mOpcode " + ":%s, mBytesImm: %s", mOpcode, + mBytesImm == null ?
"(empty)" : HexDump.toHexString(mBytesImm))); + } + if (mImmSizeOverride != 2) { + throw new IllegalInstructionException( + "mImmSizeOverride must be 2, mImmSizeOverride: " + mImmSizeOverride); + } + int offsetInDataBytes = CollectionUtils.indexOfSubArray(mBytesImm, content); + if (offsetInDataBytes == -1) { + offsetInDataBytes = mBytesImm.length; + mBytesImm = ByteUtils.concat(mBytesImm, content); + // Update the length immediate (first imm) value. Due to mValue within + // IntImmediate being final, we must remove and re-add the value to apply changes. + mIntImms.remove(0); + addDataOffset(mBytesImm.length); + } + // Note that the data instruction encoding consumes 1 byte and the data length + // encoding consumes 2 bytes. + return 1 + mImmSizeOverride + offsetInDataBytes; + } + + /** * @return size of instruction in bytes. */ int size() { @@ -494,21 +527,6 @@ * Assemble value for instruction size field. */ private int generateImmSizeField() { - // If we already know the size the length field, just use it - switch (mLenFieldOverride) { - case -1: - break; - case 1: - return 1; - case 2: - return 2; - case 4: - return 3; - default: - throw new IllegalStateException( - "mLenFieldOverride has invalid value: " + mLenFieldOverride); - } - // Otherwise, calculate int immSize = calculateRequiredIndeterminateSize(); // Encode size field to fit in 2 bits: 0->0, 1->1, 2->2, 3->4. return immSize == 4 ? 
3 : immSize; @@ -583,7 +601,23 @@ for (IntImmediate imm : mIntImms) { maxSize = Math.max(maxSize, imm.calculateIndeterminateSize()); } - return maxSize; + if (mImmSizeOverride != -1 && maxSize > mImmSizeOverride) { + throw new IllegalStateException(String.format( + "maxSize: %d should not be greater than mImmSizeOverride: %d", maxSize, + mImmSizeOverride)); + } + // If we already know the size the length field, just use it + switch (mImmSizeOverride) { + case -1: + return maxSize; + case 1: + case 2: + case 4: + return mImmSizeOverride; + default: + throw new IllegalStateException( + "mImmSizeOverride has invalid value: " + mImmSizeOverride); + } } private int calculateTargetLabelOffset() throws IllegalInstructionException { @@ -643,6 +677,26 @@ upperBound)); } + static void checkPassCounterRange(ApfCounterTracker.Counter cnt) { + if (cnt.value() < ApfCounterTracker.MIN_PASS_COUNTER.value() + || cnt.value() > ApfCounterTracker.MAX_PASS_COUNTER.value()) { + throw new IllegalArgumentException( + String.format("Counter %s, is not in range [%s, %s]", cnt, + ApfCounterTracker.MIN_PASS_COUNTER, + ApfCounterTracker.MAX_PASS_COUNTER)); + } + } + + static void checkDropCounterRange(ApfCounterTracker.Counter cnt) { + if (cnt.value() < ApfCounterTracker.MIN_DROP_COUNTER.value() + || cnt.value() > ApfCounterTracker.MAX_DROP_COUNTER.value()) { + throw new IllegalArgumentException( + String.format("Counter %s, is not in range [%s, %s]", cnt, + ApfCounterTracker.MIN_DROP_COUNTER, + ApfCounterTracker.MAX_DROP_COUNTER)); + } + } + /** * Returns an overestimate of the size of the generated program. {@link #generate} may return * a program that is smaller. 
@@ -773,6 +827,6 @@ private final HashMap<String, Instruction> mLabels = new HashMap<String, Instruction>(); private final Instruction mDropLabel = new Instruction(Opcodes.LABEL); private final Instruction mPassLabel = new Instruction(Opcodes.LABEL); - private final int mVersion; + public final int mVersion; public boolean mGenerated; }
diff --git a/src/android/net/apf/DnsUtils.java b/src/android/net/apf/DnsUtils.java index 4fa02be..5afe1d5 100644 --- a/src/android/net/apf/DnsUtils.java +++ b/src/android/net/apf/DnsUtils.java
@@ -145,7 +145,7 @@ gen.addLoad16Indexed(R0, 0); gen.addAnd(0x3ff); gen.addLoadFromMemory(R1, SLOT_DNS_HEADER_OFFSET); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addLoadFromMemory(R1, SLOT_CURRENT_PARSE_OFFSET); gen.addJumpIfR0EqualsR1(ApfV4Generator.DROP_LABEL); gen.addJumpIfR0GreaterThanR1(ApfV4Generator.DROP_LABEL); @@ -228,7 +228,7 @@ gen.addJumpIfR0Equals(0, labelFindNextDnsQuestionNoPointer); // It's a pointer. Skip the pointer and question, and return. gen.addLoadImmediate(R0, POINTER_AND_QUESTION_HEADER_SIZE); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET); gen.addJump(labelFindNextDnsQuestionReturn); @@ -240,14 +240,14 @@ // Skip the label (1 byte) and query (2 bytes qtype, 2 bytes qclass) and return. gen.addJumpIfR0NotEquals(0, labelFindNextDnsQuestionLabel); gen.addLoadImmediate(R0, LABEL_AND_QUESTION_HEADER_SIZE); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addStoreToMemory(R0, SLOT_CURRENT_PARSE_OFFSET); gen.addJump(labelFindNextDnsQuestionReturn); // Non-zero length label. Consume it and continue. gen.defineLabel(labelFindNextDnsQuestionLabel); gen.addAdd(1); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addMove(R1); gen.addJump(labelFindNextDnsQuestionLoop);
diff --git a/src/android/net/apf/LegacyApfFilter.java b/src/android/net/apf/LegacyApfFilter.java index 20a9e33..072e467 100644 --- a/src/android/net/apf/LegacyApfFilter.java +++ b/src/android/net/apf/LegacyApfFilter.java
@@ -424,6 +424,10 @@ // Listen for doze-mode transition changes to enable/disable the IPv6 multicast filter. mContext.registerReceiver(mDeviceIdleReceiver, new IntentFilter(PowerManager.ACTION_DEVICE_IDLE_MODE_CHANGED)); + + mDependencies.onApfFilterCreated(this); + // mReceiveThread is created in maybeStartFilter() and halted in shutdown(). + mDependencies.onThreadCreated(mReceiveThread); } public synchronized void setDataSnapshot(byte[] data) { @@ -1075,7 +1079,7 @@ gen.addSwap(); gen.addLoad16(R0, IPV4_TOTAL_LENGTH_OFFSET); gen.addNeg(R1); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0NotEquals(1, nextFilterLabel); // Check that the ports match @@ -1194,16 +1198,16 @@ // top bits of the low nibble are guaranteed to be zeroes. Right-shift R0 by 2. gen.addRightShift(2); // R0 += R1 -> R0 contains TCP + IP headers length - gen.addAddR1(); + gen.addAddR1ToR0(); // Load IPv4 total length gen.addLoad16(R1, IPV4_TOTAL_LENGTH_OFFSET); gen.addNeg(R0); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0NotEquals(0, nextFilterLabel); // Add IPv4 header length gen.addLoadFromMemory(R1, gen.IPV4_HEADER_SIZE_MEMORY_SLOT); gen.addLoadImmediate(R0, ETH_HEADER_LEN); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(mPortSeqAckFingerprint, nextFilterLabel); maybeSetupCounter(gen, Counter.DROPPED_IPV4_KEEPALIVE_ACK); @@ -1395,7 +1399,7 @@ // Check it's DHCP to our MAC address. gen.addLoadImmediate(R0, DHCP_CLIENT_MAC_OFFSET); // NOTE: Relies on R1 containing IPv4 header offset. - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(mHardwareAddress, skipDhcpv4Filter); maybeSetupCounter(gen, Counter.PASSED_DHCP); gen.addJump(mCountAndPassLabel); @@ -1660,7 +1664,7 @@ // If QDCOUNT == 1, matches the QNAME with allowlist. // Load offset for the first QNAME. 
gen.addLoadImmediate(R0, MDNS_QNAME_OFFSET); - gen.addAddR1(); + gen.addAddR1ToR0(); // Check first QNAME against allowlist for (int i = 0; i < mMdnsAllowList.size(); ++i) { @@ -1731,6 +1735,16 @@ maybeSetupCounter(gen, Counter.FILTER_AGE_16384THS); gen.addLoadFromMemory(R0, 9); // m[9] is filter age in 16384ths gen.addStoreData(R0, 0); // store 'counter' + + // requires a new enough APFv5+ interpreter, otherwise will be 0 + maybeSetupCounter(gen, Counter.APF_VERSION); + gen.addLoadFromMemory(R0, 8); // m[8] is apf version + gen.addStoreData(R0, 0); // store 'counter' + + // store this program's sequential id, for later comparison + maybeSetupCounter(gen, Counter.APF_PROGRAM_ID); + gen.addLoadImmediate(R0, mNumProgramUpdates); + gen.addStoreData(R0, 0); // store 'counter' } // Here's a basic summary of what the initial program does:
diff --git a/src/android/net/dhcp/DhcpClient.java b/src/android/net/dhcp/DhcpClient.java index e415698..4b94968 100644 --- a/src/android/net/dhcp/DhcpClient.java +++ b/src/android/net/dhcp/DhcpClient.java
@@ -16,6 +16,8 @@ package android.net.dhcp; +import static android.net.dhcp.DhcpPacket.CONFIG_MINIMUM_LEASE; +import static android.net.dhcp.DhcpPacket.DEFAULT_MINIMUM_LEASE; import static android.net.dhcp.DhcpPacket.DHCP_BROADCAST_ADDRESS; import static android.net.dhcp.DhcpPacket.DHCP_CAPTIVE_PORTAL; import static android.net.dhcp.DhcpPacket.DHCP_DNS_SERVER; @@ -477,6 +479,14 @@ } /** + * Get the Integer value of relevant DeviceConfig properties of Connectivity namespace. + */ + public int getIntDeviceConfig(final String name, int defaultValue) { + return DeviceConfigUtils.getDeviceConfigPropertyInt(NAMESPACE_CONNECTIVITY, + name, defaultValue); + } + + /** * Get a new wake lock to force CPU keeping awake when transmitting packets or waiting * for timeout. */ @@ -1006,16 +1016,19 @@ public final List<DhcpOption> options; public final boolean isWifiManagedProfile; public final int hostnameSetting; + public final boolean populateLinkAddressLifetime; public Configuration(@Nullable final String l2Key, final boolean isPreconnectionEnabled, @NonNull final List<DhcpOption> options, final boolean isWifiManagedProfile, - final int hostnameSetting) { + final int hostnameSetting, + final boolean populateLinkAddressLifetime) { this.l2Key = l2Key; this.isPreconnectionEnabled = isPreconnectionEnabled; this.options = options; this.isWifiManagedProfile = isWifiManagedProfile; this.hostnameSetting = hostnameSetting; + this.populateLinkAddressLifetime = populateLinkAddressLifetime; } } @@ -1112,7 +1125,9 @@ } public void setDhcpLeaseExpiry(DhcpPacket packet) { - long leaseTimeMillis = packet.getLeaseTimeMillis(); + final int defaultMinimumLease = + mDependencies.getIntDeviceConfig(CONFIG_MINIMUM_LEASE, DEFAULT_MINIMUM_LEASE); + long leaseTimeMillis = packet.getLeaseTimeMillis(defaultMinimumLease); mDhcpLeaseExpiry = (leaseTimeMillis > 0) ? SystemClock.elapsedRealtime() + leaseTimeMillis : 0; } @@ -1882,8 +1897,23 @@ // the registered IpManager.Callback. 
IP address changes // are not supported here. acceptDhcpResults(results, mLeaseMsg); - notifySuccess(); - transitionTo(mDhcpBoundState); + if (mConfiguration.populateLinkAddressLifetime) { + // Transit to ConfiguringInterfaceState and notify address renew + // or rebind with success, and refresh the IPv4 address lifetime + // via netlink message there. Otherwise, the IPv4 address will end + // up being deleted from the interface when the address lifetime + // expires. Transit back to BoundState later and schedule new lease + // expiry once the address lifetime is successfully updated. + // This change is required since the user space updates the deprecationTime + // and expirationTime of IPv4 link address when it receives the netlink + // message from kernel. Previously the lifetime of an IPv4 address was + // always permanent, so we don't need to maintain lifetime updates in + // user space. + transitionTo(mConfiguringInterfaceState); + } else { + notifySuccess(); + transitionTo(mDhcpBoundState); + } } } else if (packet instanceof DhcpNakPacket) { Log.d(TAG, "Received NAK, returning to StoppedState");
diff --git a/src/android/net/dhcp/DhcpPacket.java b/src/android/net/dhcp/DhcpPacket.java index 2649851..595c63a 100644 --- a/src/android/net/dhcp/DhcpPacket.java +++ b/src/android/net/dhcp/DhcpPacket.java
@@ -64,7 +64,8 @@ // dhcpcd has a minimum lease of 20 seconds, but DhcpStateMachine would refuse to wake up the // CPU for anything shorter than 5 minutes. For sanity's sake, this must be higher than the // DHCP client timeout. - public static final int MINIMUM_LEASE = 60; + public static final String CONFIG_MINIMUM_LEASE = "dhcp_minimum_lease"; + public static final int DEFAULT_MINIMUM_LEASE = 60; public static final int INFINITE_LEASE = (int) 0xffffffff; public static final Inet4Address INADDR_ANY = IPV4_ADDR_ANY; @@ -1490,12 +1491,12 @@ /** * Returns the parsed lease time, in milliseconds, or 0 for infinite. */ - public long getLeaseTimeMillis() { + public long getLeaseTimeMillis(int defaultMinimumLease) { // dhcpcd treats the lack of a lease time option as an infinite lease. if (mLeaseTime == null || mLeaseTime == INFINITE_LEASE) { return 0; - } else if (0 <= mLeaseTime && mLeaseTime < MINIMUM_LEASE) { - return MINIMUM_LEASE * 1000; + } else if (0 <= mLeaseTime && mLeaseTime < defaultMinimumLease) { + return defaultMinimumLease * 1000L; } else { return (mLeaseTime & 0xffffffffL) * 1000; }
diff --git a/src/android/net/ip/IpClient.java b/src/android/net/ip/IpClient.java index c1d18ae..aba4f4e 100644 --- a/src/android/net/ip/IpClient.java +++ b/src/android/net/ip/IpClient.java
@@ -24,6 +24,8 @@ import static android.net.ip.IIpClient.PROV_IPV6_LINKLOCAL; import static android.net.ip.IIpClient.PROV_IPV6_SLAAC; import static android.net.ip.IIpClientCallbacks.DTIM_MULTIPLIER_RESET; +import static android.net.ip.IpClientLinkObserver.IpClientNetlinkMonitor; +import static android.net.ip.IpClientLinkObserver.IpClientNetlinkMonitor.INetlinkMessageProcessor; import static android.net.ip.IpReachabilityMonitor.INVALID_REACHABILITY_LOSS_TYPE; import static android.net.ip.IpReachabilityMonitor.nudEventTypeToInt; import static android.net.util.SocketUtils.makePacketSocketAddress; @@ -49,7 +51,6 @@ import static com.android.networkstack.util.NetworkStackUtils.APF_POLLING_COUNTERS_VERSION; import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_DHCPV6_PREFIX_DELEGATION_VERSION; import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_GARP_NA_ROAMING_VERSION; -import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_GRATUITOUS_NA_VERSION; import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_IGNORE_LOW_RA_LIFETIME_VERSION; import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_POPULATE_LINK_ADDRESS_LIFETIME_VERSION; import static com.android.networkstack.util.NetworkStackUtils.createInet6AddressFromEui64; @@ -149,7 +150,6 @@ import com.android.networkstack.packets.NeighborAdvertisement; import com.android.networkstack.packets.NeighborSolicitation; import com.android.networkstack.util.NetworkStackUtils; -import com.android.server.NetworkObserverRegistry; import com.android.server.NetworkStackService.NetworkStackServiceManager; import java.io.File; @@ -173,8 +173,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.StringJoiner; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import 
java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -670,10 +674,8 @@ @VisibleForTesting protected final IpClientCallbacksWrapper mCallback; private final Dependencies mDependencies; - private final CountDownLatch mShutdownLatch; private final ConnectivityManager mCm; private final INetd mNetd; - private final NetworkObserverRegistry mObserverRegistry; private final IpClientLinkObserver mLinkObserver; private final WakeupMessage mProvisioningTimeoutAlarm; private final WakeupMessage mDhcpActionTimeoutAlarm; @@ -879,7 +881,8 @@ final File sysctl = new File(path); return sysctl.exists(); } - /** + + /** * Get the configuration from RRO to check whether or not to send domain search list * option in DHCPDISCOVER/DHCPREQUEST message. */ @@ -887,17 +890,23 @@ return context.getResources().getBoolean(R.bool.config_dhcp_client_domain_search_list); } + /** + * Create an IpClientNetlinkMonitor instance. + */ + public IpClientNetlinkMonitor makeIpClientNetlinkMonitor(Handler h, SharedLog log, + String tag, int sockRcvbufSize, INetlinkMessageProcessor p) { + return new IpClientNetlinkMonitor(h, log, tag, sockRcvbufSize, p); + } } public IpClient(Context context, String ifName, IIpClientCallbacks callback, - NetworkObserverRegistry observerRegistry, NetworkStackServiceManager nssManager) { - this(context, ifName, callback, observerRegistry, nssManager, new Dependencies()); + NetworkStackServiceManager nssManager) { + this(context, ifName, callback, nssManager, new Dependencies()); } @VisibleForTesting public IpClient(Context context, String ifName, IIpClientCallbacks callback, - NetworkObserverRegistry observerRegistry, NetworkStackServiceManager nssManager, - Dependencies deps) { + NetworkStackServiceManager nssManager, Dependencies deps) { super(IpClient.class.getSimpleName() + "." 
+ ifName); Objects.requireNonNull(ifName); Objects.requireNonNull(callback); @@ -911,9 +920,7 @@ mDependencies = deps; mMetricsLog = deps.getIpConnectivityLog(); mNetworkQuirkMetrics = deps.getNetworkQuirkMetrics(); - mShutdownLatch = new CountDownLatch(1); mCm = mContext.getSystemService(ConnectivityManager.class); - mObserverRegistry = observerRegistry; mIpMemoryStore = deps.getIpMemoryStore(context, nssManager); sSmLogs.putIfAbsent(mInterfaceName, new SharedLog(MAX_LOG_RECORDS, mTag)); @@ -969,10 +976,8 @@ public void onIpv6AddressRemoved(final Inet6Address address) { // The update of Gratuitous NA target addresses set or unsolicited // multicast NS source addresses set should be only accessed from the - // handler thread of IpClient StateMachine, keeping the behaviour - // consistent with relying on the non-blocking NetworkObserver callbacks, - // see {@link registerObserverForNonblockingCallback}. This can be done - // by either sending a message to StateMachine or posting a handler. + // handler thread of IpClient StateMachine. This can be done by either + // sending a message to StateMachine or posting a handler. 
if (address.isLinkLocalAddress()) return; getHandler().post(() -> { mLog.log("Remove IPv6 GUA " + address @@ -1133,19 +1138,13 @@ } private void startStateMachineUpdaters() { - mObserverRegistry.registerObserverForNonblockingCallback(mLinkObserver); } private void stopStateMachineUpdaters() { - mObserverRegistry.unregisterObserver(mLinkObserver); mLinkObserver.clearInterfaceParams(); mLinkObserver.shutdown(); } - private boolean isGratuitousNaEnabled() { - return mDependencies.isFeatureNotChickenedOut(mContext, IPCLIENT_GRATUITOUS_NA_VERSION); - } - private boolean isGratuitousArpNaRoamingEnabled() { return mDependencies.isFeatureEnabled(mContext, IPCLIENT_GARP_NA_ROAMING_VERSION); } @@ -1177,7 +1176,6 @@ @Override protected void onQuitting() { mCallback.onQuit(); - mShutdownLatch.countDown(); } /** @@ -1414,6 +1412,86 @@ pw.decreaseIndent(); } + /** + * Handle "adb shell cmd apf" command. + */ + public String apfShellCommand(String cmd, @Nullable String optarg) { + final long oneDayInMs = 86400 * 1000; + if (SystemClock.elapsedRealtime() >= oneDayInMs) { + throw new IllegalStateException("Error: This test interface requires uptime < 24h"); + } + + // Waiting for a "read" result cannot block the handler thread, since the result gets + // processed on it. This is test only code, so mApfFilter going away is not a concern. + if (cmd.equals("read")) { + if (mApfFilter == null) { + throw new IllegalStateException("Error: No active APF filter"); + } + // Request a new snapshot, then wait for it. + mApfDataSnapshotComplete.close(); + mCallback.startReadPacketFilter(); + if (!mApfDataSnapshotComplete.block(5000 /* ms */)) { + throw new RuntimeException("Error: Failed to read APF program"); + } + } + + final CompletableFuture<String> result = new CompletableFuture<>(); + + getHandler().post(() -> { + try { + if (mApfFilter == null) { + // IpClient has either stopped or the interface does not support APF. 
+ throw new IllegalStateException("No active APF filter."); + } + switch (cmd) { + case "status": + result.complete(mApfFilter.isRunning() ? "running" : "paused"); + break; + case "pause": + mApfFilter.pause(); + result.complete("success"); + break; + case "resume": + mApfFilter.resume(); + result.complete("success"); + break; + case "install": + Objects.requireNonNull(optarg, "No program provided"); + if (mApfFilter.isRunning()) { + throw new IllegalStateException("APF filter must first be paused"); + } + mCallback.installPacketFilter(HexDump.hexStringToByteArray(optarg)); + result.complete("success"); + break; + case "capabilities": + final StringJoiner joiner = new StringJoiner(","); + joiner.add(Integer.toString(mCurrentApfCapabilities.apfVersionSupported)); + joiner.add(Integer.toString(mCurrentApfCapabilities.maximumApfProgramSize)); + joiner.add(Integer.toString(mCurrentApfCapabilities.apfPacketFormat)); + result.complete(joiner.toString()); + break; + case "read": + final String snapshot = mApfFilter.getDataSnapshotHexString(); + Objects.requireNonNull(snapshot, "No data snapshot recorded."); + result.complete(snapshot); + break; + default: + throw new IllegalArgumentException("Invalid apf command: " + cmd); + } + } catch (Exception e) { + result.completeExceptionally(e); + } + }); + + try { + return result.get(30, TimeUnit.SECONDS); + } catch (ExecutionException | InterruptedException | TimeoutException e) { + // completeExceptionally is solely used to return error messages back to the user, so + // the stack trace is not all that interesting. (A similar argument can be made for + // InterruptedException). Only extract the message from the checked exception. + throw new RuntimeException(e.getMessage()); + } + } /** * Internals. @@ -2046,8 +2124,17 @@ // Returns false if we have lost provisioning, true otherwise. 
private boolean handleLinkPropertiesUpdate(boolean sendCallbacks) { final LinkProperties newLp = assembleLinkProperties(); + // LinkProperties.equals just compares if the interface addresses are identical, + // it doesn't compare the LinkAddress objects, so it considers two LinkProperties + // objects are identical even with different address lifetime. However, we may want + // to notify the caller whenever the link address lifetime is updated, especially + // after we enable populating the deprecationTime/expirationTime fields. The caller + // can get the latest address lifetime from the onLinkPropertiesChange callback. if (Objects.equals(newLp, mLinkProperties)) { - return true; + if (!mPopulateLinkAddressLifetime) return true; + if (LinkPropertiesUtils.isIdenticalAllLinkAddresses(newLp, mLinkProperties)) { + return true; + } } // Set an alarm to wait for IPv6 autoconf via SLAAC to succeed after receiving an RA, @@ -2073,9 +2160,7 @@ // Check if new assigned IPv6 GUA is available in the LinkProperties now. If so, initiate // gratuitous multicast unsolicited Neighbor Advertisements as soon as possible to inform // first-hop routers that the new GUA host is goning to use. 
- if (isGratuitousNaEnabled()) { - maybeSendGratuitousNAs(newLp, false /* isGratuitousNaAfterRoaming */); - } + maybeSendGratuitousNAs(newLp, false /* isGratuitousNaAfterRoaming */); // Sending multicast NS from each new assigned IPv6 GUAs to the solicited-node multicast // address based on the default router's IPv6 link-local address should trigger default @@ -2436,6 +2521,10 @@ private AndroidPacketFilter maybeCreateApfFilter(final ApfCapabilities apfCapabilities) { ApfFilter.ApfConfiguration apfConfig = new ApfFilter.ApfConfiguration(); apfConfig.apfCapabilities = apfCapabilities; + if (apfCapabilities != null && !SdkLevel.isAtLeastV() + && apfCapabilities.apfVersionSupported <= 4) { + apfConfig.installableProgramSizeClamp = 1024; + } apfConfig.multicastFilter = mMulticastFiltering; // Get the Configuration for ApfFilter from Context // Resource settings were moved from ApfCapabilities APIs to NetworkStack resources in S @@ -2675,7 +2764,7 @@ } mDhcpClient.sendMessage(DhcpClient.CMD_START_DHCP, new DhcpClient.Configuration(mL2Key, isUsingPreconnection(), options, isManagedWifiProfile, - mConfiguration.mHostnameSetting)); + mConfiguration.mHostnameSetting, mPopulateLinkAddressLifetime)); } private boolean hasPermission(String permissionName) { @@ -3050,14 +3139,33 @@ } } + private void deleteInterfaceAddress(final LinkAddress address) { + if (!address.isIpv6()) { + // NetlinkUtils.sendRtmDelAddressRequest does not support deleting IPv4 addresses. 
+ Log.wtf(TAG, "Deleting IPv4 address not supported " + address); + return; + } + final Inet6Address in6addr = (Inet6Address) address.getAddress(); + final short plen = (short) address.getPrefixLength(); + if (!NetlinkUtils.sendRtmDelAddressRequest(mInterfaceParams.index, in6addr, plen)) { + Log.e(TAG, "Failed to delete IPv6 address " + address); + } + } + private void deleteIpv6PrefixDelegationAddresses(final IpPrefix prefix) { - for (LinkAddress la : mLinkProperties.getLinkAddresses()) { - final InetAddress address = la.getAddress(); - if (prefix.contains(address)) { - if (!NetlinkUtils.sendRtmDelAddressRequest(mInterfaceParams.index, - (Inet6Address) address, (short) la.getPrefixLength())) { - Log.e(TAG, "Failed to delete IPv6 address " + address.getHostAddress()); - } + // b/290747921: some kernels require the mngtmpaddr to be deleted first, to prevent the + // creation of a new tempaddr. + final List<LinkAddress> linkAddresses = mLinkProperties.getLinkAddresses(); + // delete addresses with IFA_F_MANAGETEMPADDR contained in the prefix. + for (LinkAddress la : linkAddresses) { + if (hasFlag(la, IFA_F_MANAGETEMPADDR) && prefix.contains(la.getAddress())) { + deleteInterfaceAddress(la); + } + } + // delete all other addresses contained in the prefix. + for (LinkAddress la : linkAddresses) { + if (!hasFlag(la, IFA_F_MANAGETEMPADDR) && prefix.contains(la.getAddress())) { + deleteInterfaceAddress(la); } } }
diff --git a/src/android/net/ip/IpClientLinkObserver.java b/src/android/net/ip/IpClientLinkObserver.java index 738a50b..624143b 100644 --- a/src/android/net/ip/IpClientLinkObserver.java +++ b/src/android/net/ip/IpClientLinkObserver.java
@@ -28,8 +28,6 @@ import static com.android.net.module.util.netlink.NetlinkConstants.RTPROT_KERNEL; import static com.android.net.module.util.netlink.NetlinkConstants.RTPROT_RA; import static com.android.net.module.util.netlink.NetlinkConstants.RT_SCOPE_UNIVERSE; -import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_ACCEPT_IPV6_LINK_LOCAL_DNS_VERSION; -import static com.android.networkstack.util.NetworkStackUtils.IPCLIENT_PARSE_NETLINK_EVENTS_FORCE_DISABLE; import android.app.AlarmManager; import android.content.Context; @@ -65,7 +63,6 @@ import com.android.net.module.util.netlink.StructNdOptRdnss; import com.android.networkstack.apishim.NetworkInformationShimImpl; import com.android.networkstack.apishim.common.NetworkInformationShim; -import com.android.server.NetworkObserver; import java.net.Inet6Address; import java.net.InetAddress; @@ -80,10 +77,9 @@ /** * Keeps track of link configuration received from Netd. * - * An instance of this class is constructed by passing in an interface name and a callback. The - * owner is then responsible for registering the tracker with NetworkObserverRegistry. When the - * class receives update notifications, it applies the update to its local LinkProperties, and if - * something has changed, notifies its owner of the update via the callback. + * An instance of this class is constructed by passing in an interface name and a callback. When + * the class receives update notifications, it applies the update to its local LinkProperties, and + * if something has changed, notifies its owner of the update via the callback. * * The owner can then call {@code getLinkProperties()} in order to find out * what changed. If in the meantime the LinkProperties stored here have changed, @@ -96,18 +92,15 @@ * * - The owner of this class is expected to create it, register it, and call * getLinkProperties or clearLinkProperties on its thread. 
- * - Most of the methods in the class are implementing NetworkObserver and are called - * on the handler used to register the observer. * - All accesses to mLinkProperties must be synchronized(this). All the other * member variables are immutable once the object is constructed. * * TODO: Now that all the methods are called on the handler thread, remove synchronization and * pass the LinkProperties to the update() callback. - * TODO: Stop extending NetworkObserver and get events from netlink directly. * * @hide */ -public class IpClientLinkObserver implements NetworkObserver { +public class IpClientLinkObserver { private final String mTag; /** @@ -163,9 +156,20 @@ private final IpClient.Dependencies mDependencies; private final String mClatInterfaceName; private final IpClientNetlinkMonitor mNetlinkMonitor; - private final boolean mNetlinkEventParsingEnabled; + private final NetworkInformationShim mShim; + private final AlarmManager.OnAlarmListener mExpirePref64Alarm; private boolean mClatInterfaceExists; + private long mNat64PrefixExpiry; + + /** + * Current interface index. Most of this class (and of IpClient), only uses interface names, + * not interface indices. This means that the interface index can in theory change, and that + * it's not necessarily correct to get the interface name at object creation time (and in + * fact, when the object is created, the interface might not even exist). + * TODO: once all netlink events pass through this class, stop depending on interface names. + */ + private int mIfindex; // This must match the interface prefix in clatd.c. // TODO: Revert this hack once IpClient and Nat464Xlat work in concert. @@ -177,7 +181,8 @@ // recv buffer size to avoid the ENOBUFS as much as possible. 
@VisibleForTesting static final String CONFIG_SOCKET_RECV_BUFSIZE = "ipclient_netlink_sock_recv_buf_size"; - private static final int SOCKET_RECV_BUFSIZE = 4 * 1024 * 1024; + @VisibleForTesting + static final int SOCKET_RECV_BUFSIZE = 4 * 1024 * 1024; public IpClientLinkObserver(Context context, Handler h, String iface, Callback callback, Configuration config, SharedLog log, IpClient.Dependencies deps) { @@ -194,9 +199,11 @@ mDnsServerRepository = new DnsServerRepository(config.minRdnssLifetime); mAlarmManager = (AlarmManager) context.getSystemService(Context.ALARM_SERVICE); mDependencies = deps; - mNetlinkEventParsingEnabled = deps.isFeatureNotChickenedOut(context, - IPCLIENT_PARSE_NETLINK_EVENTS_FORCE_DISABLE); - mNetlinkMonitor = new IpClientNetlinkMonitor(h, log, mTag); + mNetlinkMonitor = deps.makeIpClientNetlinkMonitor(h, log, mTag, + getSocketReceiveBufferSize(), + (nlMsg, whenMs) -> processNetlinkMessage(nlMsg, whenMs)); + mShim = NetworkInformationShimImpl.newInstance(); + mExpirePref64Alarm = new IpClientObserverAlarmListener(); mHandler.post(() -> { if (!mNetlinkMonitor.start()) { Log.wtf(mTag, "Fail to start NetlinkMonitor."); @@ -208,11 +215,6 @@ mHandler.post(mNetlinkMonitor::stop); } - private boolean isIpv6LinkLocalDnsAccepted() { - return mDependencies.isFeatureNotChickenedOut(mContext, - IPCLIENT_ACCEPT_IPV6_LINK_LOCAL_DNS_VERSION); - } - private void maybeLog(String operation, String iface, LinkAddress address) { if (DBG) { Log.d(mTag, operation + ": " + address + " on " + iface @@ -240,73 +242,6 @@ return size; } - @Override - public void onInterfaceAdded(String iface) { - if (mNetlinkEventParsingEnabled) return; - maybeLog("interfaceAdded", iface); - if (mClatInterfaceName.equals(iface)) { - mCallback.onClatInterfaceStateUpdate(true /* add interface */); - } - } - - @Override - public void onInterfaceRemoved(String iface) { - if (mNetlinkEventParsingEnabled) return; - maybeLog("interfaceRemoved", iface); - if (mClatInterfaceName.equals(iface)) 
{ - mCallback.onClatInterfaceStateUpdate(false /* remove interface */); - } else if (mInterfaceName.equals(iface)) { - updateInterfaceRemoved(); - } - } - - @Override - public void onInterfaceLinkStateChanged(String iface, boolean state) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(iface)) return; - maybeLog("interfaceLinkStateChanged", iface + (state ? " up" : " down")); - updateInterfaceLinkStateChanged(state); - } - - @Override - public void onInterfaceAddressUpdated(LinkAddress address, String iface) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(iface)) return; - maybeLog("addressUpdated", iface, address); - updateInterfaceAddress(address, true /* add address */); - } - - @Override - public void onInterfaceAddressRemoved(LinkAddress address, String iface) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(iface)) return; - maybeLog("addressRemoved", iface, address); - updateInterfaceAddress(address, false /* remove address */); - } - - @Override - public void onRouteUpdated(RouteInfo route) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(route.getInterface())) return; - maybeLog("routeUpdated", route); - updateInterfaceRoute(route, true /* add route */); - } - - @Override - public void onRouteRemoved(RouteInfo route) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(route.getInterface())) return; - maybeLog("routeRemoved", route); - updateInterfaceRoute(route, false /* remove route */); - } - - @Override - public void onInterfaceDnsServerInfo(String iface, long lifetime, String[] addresses) { - if (mNetlinkEventParsingEnabled) return; - if (!mInterfaceName.equals(iface)) return; - updateInterfaceDnsServerInfo(lifetime, addresses); - } - private synchronized void updateInterfaceLinkStateChanged(boolean state) { setInterfaceLinkStateLocked(state); } @@ -390,7 +325,7 @@ // while interfaceDnsServerInfo() is being called, we'll end up 
with no DNS servers in // mLinkProperties, as desired. mDnsServerRepository = new DnsServerRepository(mConfig.minRdnssLifetime); - mNetlinkMonitor.clearAlarms(); + cancelPref64Alarm(); mLinkProperties.clear(); mLinkProperties.setInterfaceName(mInterfaceName); } @@ -405,12 +340,19 @@ /** Notifies this object of new interface parameters. */ public void setInterfaceParams(InterfaceParams params) { - mNetlinkMonitor.setIfindex(params.index); + setIfindex(params.index); } /** Notifies this object not to listen on any interface. */ public void clearInterfaceParams() { - mNetlinkMonitor.setIfindex(0); // 0 is never a valid ifindex. + setIfindex(0); // 0 is never a valid ifindex. + } + + private void setIfindex(int ifindex) { + if (!mNetlinkMonitor.isRunning()) { + Log.wtf(mTag, "NetlinkMonitor is not running when setting interface parameter!"); + } + mIfindex = ifindex; } private static boolean isSupportedRouteProtocol(RtNetlinkRouteMessage msg) { @@ -429,47 +371,44 @@ * Simple NetlinkMonitor. Listen for netlink events from kernel. * All methods except the constructor must be called on the handler thread. */ - private class IpClientNetlinkMonitor extends NetlinkMonitor { - private final Handler mHandler; - - IpClientNetlinkMonitor(Handler h, SharedLog log, String tag) { - super(h, log, tag, OsConstants.NETLINK_ROUTE, - !mNetlinkEventParsingEnabled - ? NetlinkConstants.RTMGRP_ND_USEROPT - : (NetlinkConstants.RTMGRP_ND_USEROPT | NetlinkConstants.RTMGRP_LINK - | NetlinkConstants.RTMGRP_IPV4_IFADDR - | NetlinkConstants.RTMGRP_IPV6_IFADDR - | NetlinkConstants.RTMGRP_IPV6_ROUTE), - getSocketReceiveBufferSize()); - - mHandler = h; - } - - private final NetworkInformationShim mShim = NetworkInformationShimImpl.newInstance(); - - private long mNat64PrefixExpiry; - + static class IpClientNetlinkMonitor extends NetlinkMonitor { /** - * Current interface index. Most of this class (and of IpClient), only uses interface names, - * not interface indices. 
 This means that the interface index can in theory change, and that - * it's not necessarily correct to get the interface name at object creation time (and in - * fact, when the object is created, the interface might not even exist). - * TODO: once all netlink events pass through this class, stop depending on interface names. + * An interface used to process the received netlink messages, which makes it easier to inject + * the function in unit tests. */ - private int mIfindex; - - void setIfindex(int ifindex) { - if (!isRunning()) { - Log.wtf(mTag, "NetlinkMonitor is not running when setting interface parameter!"); - } - mIfindex = ifindex; + public interface INetlinkMessageProcessor { + void processNetlinkMessage(@NonNull NetlinkMessage nlMsg, long whenMs); } - void clearAlarms() { - cancelPref64Alarm(); + private final Handler mHandler; + private final INetlinkMessageProcessor mNetlinkMessageProcessor; + + IpClientNetlinkMonitor(Handler h, SharedLog log, String tag, int sockRcvbufSize, + INetlinkMessageProcessor p) { + super(h, log, tag, OsConstants.NETLINK_ROUTE, + (NetlinkConstants.RTMGRP_ND_USEROPT + | NetlinkConstants.RTMGRP_LINK + | NetlinkConstants.RTMGRP_IPV4_IFADDR + | NetlinkConstants.RTMGRP_IPV6_IFADDR + | NetlinkConstants.RTMGRP_IPV6_ROUTE), + sockRcvbufSize); + mHandler = h; + mNetlinkMessageProcessor = p; } - private final AlarmManager.OnAlarmListener mExpirePref64Alarm = () -> { + @Override + protected void processNetlinkMessage(NetlinkMessage nlMsg, long whenMs) { + mNetlinkMessageProcessor.processNetlinkMessage(nlMsg, whenMs); + } + + protected boolean isRunning() { + return super.isRunning(); + } + } + + private class IpClientObserverAlarmListener implements AlarmManager.OnAlarmListener { + @Override + public void onAlarm() { // Ignore the alarm if cancelPref64Alarm has already been called.
// // TODO: in the rare case where the alarm fires and posts the lambda to the handler @@ -478,254 +417,237 @@ // lifetime in the RA is zero this code will correctly do nothing, but if the lifetime // is nonzero then the prefix will be added and immediately removed by this code. if (mNat64PrefixExpiry == 0) return; - updatePref64(mShim.getNat64Prefix(mLinkProperties), - mNat64PrefixExpiry, mNat64PrefixExpiry); - }; + updatePref64(mShim.getNat64Prefix(mLinkProperties), mNat64PrefixExpiry, + mNat64PrefixExpiry); + } + } - private void cancelPref64Alarm() { - // Clear the expiry in case the alarm just fired and has not been processed yet. - if (mNat64PrefixExpiry == 0) return; - mNat64PrefixExpiry = 0; - mAlarmManager.cancel(mExpirePref64Alarm); + private void cancelPref64Alarm() { + // Clear the expiry in case the alarm just fired and has not been processed yet. + if (mNat64PrefixExpiry == 0) return; + mNat64PrefixExpiry = 0; + mAlarmManager.cancel(mExpirePref64Alarm); + } + + private void schedulePref64Alarm() { + // There is no need to cancel any existing alarms, because we are using the same + // OnAlarmListener object, and each such listener can only have at most one alarm. + final String tag = mTag + ".PREF64"; + mAlarmManager.setExact(AlarmManager.ELAPSED_REALTIME_WAKEUP, mNat64PrefixExpiry, tag, + mExpirePref64Alarm, mHandler); + } + + /** + * Processes a PREF64 ND option. + * + * @param prefix The NAT64 prefix. + * @param now The time (as determined by SystemClock.elapsedRealtime) when the event + * that triggered this method was received. + * @param expiry The time (as determined by SystemClock.elapsedRealtime) when the option + * expires. + */ + private synchronized void updatePref64(IpPrefix prefix, final long now, + final long expiry) { + final IpPrefix currentPrefix = mShim.getNat64Prefix(mLinkProperties); + + // If the prefix matches the current prefix, refresh its lifetime. 
+ if (prefix.equals(currentPrefix)) { + mNat64PrefixExpiry = expiry; + if (expiry > now) schedulePref64Alarm(); } - private void schedulePref64Alarm() { - // There is no need to cancel any existing alarms, because we are using the same - // OnAlarmListener object, and each such listener can only have at most one alarm. - final String tag = mTag + ".PREF64"; - mAlarmManager.setExact(AlarmManager.ELAPSED_REALTIME_WAKEUP, mNat64PrefixExpiry, tag, - mExpirePref64Alarm, mHandler); + // If we already have a prefix, continue using it and ignore the new one. Stopping and + // restarting clatd is disruptive because it will break existing IPv4 connections. + // Note: this means that if we receive an RA that adds a new prefix and deletes the old + // prefix, we might receive and ignore the new prefix, then delete the old prefix, and + // have no prefix until the next RA is received. This is because the kernel returns ND + // user options one at a time even if they are in the same RA. + // TODO: keep track of the last few prefixes seen, like DnsServerRepository does. + if (mNat64PrefixExpiry > now) return; + + // The current prefix has expired. Either replace it with the new one or delete it. + if (expiry > now) { + // If expiry > now, then prefix != currentPrefix (due to the return statement above) + mShim.setNat64Prefix(mLinkProperties, prefix); + mNat64PrefixExpiry = expiry; + schedulePref64Alarm(); + } else { + mShim.setNat64Prefix(mLinkProperties, null); + cancelPref64Alarm(); } - /** - * Processes a PREF64 ND option. - * - * @param prefix The NAT64 prefix. - * @param now The time (as determined by SystemClock.elapsedRealtime) when the event - * that triggered this method was received. - * @param expiry The time (as determined by SystemClock.elapsedRealtime) when the option - * expires. 
- */ - private synchronized void updatePref64(IpPrefix prefix, final long now, - final long expiry) { - final IpPrefix currentPrefix = mShim.getNat64Prefix(mLinkProperties); + mCallback.update(getInterfaceLinkStateLocked()); + } - // If the prefix matches the current prefix, refresh its lifetime. - if (prefix.equals(currentPrefix)) { - mNat64PrefixExpiry = expiry; - if (expiry > now) { - schedulePref64Alarm(); - } - } + private void processPref64Option(StructNdOptPref64 opt, final long now) { + final long expiry = now + TimeUnit.SECONDS.toMillis(opt.lifetime); + updatePref64(opt.prefix, now, expiry); + } - // If we already have a prefix, continue using it and ignore the new one. Stopping and - // restarting clatd is disruptive because it will break existing IPv4 connections. - // Note: this means that if we receive an RA that adds a new prefix and deletes the old - // prefix, we might receive and ignore the new prefix, then delete the old prefix, and - // have no prefix until the next RA is received. This is because the kernel returns ND - // user options one at a time even if they are in the same RA. - // TODO: keep track of the last few prefixes seen, like DnsServerRepository does. - if (mNat64PrefixExpiry > now) return; + private void processRdnssOption(StructNdOptRdnss opt) { + final String[] addresses = new String[opt.servers.length]; + for (int i = 0; i < opt.servers.length; i++) { + final Inet6Address addr = InetAddressUtils.withScopeId(opt.servers[i], mIfindex); + addresses[i] = addr.getHostAddress(); + } + updateInterfaceDnsServerInfo(opt.header.lifetime, addresses); + } - // The current prefix has expired. Either replace it with the new one or delete it. 
- if (expiry > now) { - // If expiry > now, then prefix != currentPrefix (due to the return statement above) - mShim.setNat64Prefix(mLinkProperties, prefix); - mNat64PrefixExpiry = expiry; - schedulePref64Alarm(); - } else { - mShim.setNat64Prefix(mLinkProperties, null); - cancelPref64Alarm(); - } + private void processNduseroptMessage(NduseroptMessage msg, final long whenMs) { + if (msg.family != AF_INET6 || msg.option == null || msg.ifindex != mIfindex) return; + if (msg.icmp_type != (byte) ICMPV6_ROUTER_ADVERTISEMENT) return; - mCallback.update(getInterfaceLinkStateLocked()); + switch (msg.option.type) { + case StructNdOptPref64.TYPE: + processPref64Option((StructNdOptPref64) msg.option, whenMs); + break; + + case StructNdOptRdnss.TYPE: + processRdnssOption((StructNdOptRdnss) msg.option); + break; + + default: + // TODO: implement DNSSL. + break; + } + } + + private void updateClatInterfaceLinkState(@Nullable final String ifname, short nlMsgType) { + switch (nlMsgType) { + case NetlinkConstants.RTM_NEWLINK: + if (mClatInterfaceExists) break; + maybeLog("clatInterfaceAdded", ifname); + mCallback.onClatInterfaceStateUpdate(true /* add interface */); + mClatInterfaceExists = true; + break; + + case NetlinkConstants.RTM_DELLINK: + if (!mClatInterfaceExists) break; + maybeLog("clatInterfaceRemoved", ifname); + mCallback.onClatInterfaceStateUpdate(false /* remove interface */); + mClatInterfaceExists = false; + break; + + default: + Log.e(mTag, "unsupported rtnetlink link msg type " + nlMsgType); + break; + } + } + + private void processRtNetlinkLinkMessage(RtNetlinkLinkMessage msg) { + // Check if receiving netlink link state update for clat interface. 
+ final String ifname = msg.getInterfaceName(); + final short nlMsgType = msg.getHeader().nlmsg_type; + final StructIfinfoMsg ifinfoMsg = msg.getIfinfoHeader(); + if (mClatInterfaceName.equals(ifname)) { + updateClatInterfaceLinkState(ifname, nlMsgType); + return; } - private void processPref64Option(StructNdOptPref64 opt, final long now) { - final long expiry = now + TimeUnit.SECONDS.toMillis(opt.lifetime); - updatePref64(opt.prefix, now, expiry); + if (ifinfoMsg.family != AF_UNSPEC || ifinfoMsg.index != mIfindex) return; + if ((ifinfoMsg.flags & IFF_LOOPBACK) != 0) return; + + switch (nlMsgType) { + case NetlinkConstants.RTM_NEWLINK: + final boolean state = (ifinfoMsg.flags & IFF_LOWER_UP) != 0; + maybeLog("interfaceLinkStateChanged", "ifindex " + mIfindex + + (state ? " up" : " down")); + updateInterfaceLinkStateChanged(state); + break; + + case NetlinkConstants.RTM_DELLINK: + maybeLog("interfaceRemoved", ifname); + updateInterfaceRemoved(); + break; + + default: + Log.e(mTag, "Unknown rtnetlink link msg type " + nlMsgType); + break; } + } - private void processRdnssOption(StructNdOptRdnss opt) { - if (!mNetlinkEventParsingEnabled) return; - final String[] addresses = new String[opt.servers.length]; - for (int i = 0; i < opt.servers.length; i++) { - final Inet6Address addr = isIpv6LinkLocalDnsAccepted() - ? 
InetAddressUtils.withScopeId(opt.servers[i], mIfindex) - : opt.servers[i]; - addresses[i] = addr.getHostAddress(); - } - updateInterfaceDnsServerInfo(opt.header.lifetime, addresses); - } + private void processRtNetlinkAddressMessage(RtNetlinkAddressMessage msg) { + final StructIfaddrMsg ifaddrMsg = msg.getIfaddrHeader(); + if (ifaddrMsg.index != mIfindex) return; - private void processNduseroptMessage(NduseroptMessage msg, final long whenMs) { - if (msg.family != AF_INET6 || msg.option == null || msg.ifindex != mIfindex) return; - if (msg.icmp_type != (byte) ICMPV6_ROUTER_ADVERTISEMENT) return; + final StructIfacacheInfo cacheInfo = msg.getIfacacheInfo(); + long deprecationTime = LinkAddress.LIFETIME_UNKNOWN; + long expirationTime = LinkAddress.LIFETIME_UNKNOWN; + if (cacheInfo != null && mConfig.populateLinkAddressLifetime) { + deprecationTime = LinkAddress.LIFETIME_PERMANENT; + expirationTime = LinkAddress.LIFETIME_PERMANENT; - switch (msg.option.type) { - case StructNdOptPref64.TYPE: - processPref64Option((StructNdOptPref64) msg.option, whenMs); - break; - - case StructNdOptRdnss.TYPE: - processRdnssOption((StructNdOptRdnss) msg.option); - break; - - default: - // TODO: implement DNSSL. 
- break; - } - } - - private void updateClatInterfaceLinkState(@NonNull final StructIfinfoMsg ifinfoMsg, - @Nullable final String ifname, short nlMsgType) { - switch (nlMsgType) { - case NetlinkConstants.RTM_NEWLINK: - if (mClatInterfaceExists) break; - maybeLog("clatInterfaceAdded", ifname); - mCallback.onClatInterfaceStateUpdate(true /* add interface */); - mClatInterfaceExists = true; - break; - - case NetlinkConstants.RTM_DELLINK: - if (!mClatInterfaceExists) break; - maybeLog("clatInterfaceRemoved", ifname); - mCallback.onClatInterfaceStateUpdate(false /* remove interface */); - mClatInterfaceExists = false; - break; - - default: - Log.e(mTag, "unsupported rtnetlink link msg type " + nlMsgType); - break; - } - } - - private void processRtNetlinkLinkMessage(RtNetlinkLinkMessage msg) { - if (!mNetlinkEventParsingEnabled) return; - - // Check if receiving netlink link state update for clat interface. - final String ifname = msg.getInterfaceName(); - final short nlMsgType = msg.getHeader().nlmsg_type; - final StructIfinfoMsg ifinfoMsg = msg.getIfinfoHeader(); - if (mClatInterfaceName.equals(ifname)) { - updateClatInterfaceLinkState(ifinfoMsg, ifname, nlMsgType); - return; - } - - if (ifinfoMsg.family != AF_UNSPEC || ifinfoMsg.index != mIfindex) return; - if ((ifinfoMsg.flags & IFF_LOOPBACK) != 0) return; - - switch (nlMsgType) { - case NetlinkConstants.RTM_NEWLINK: - final boolean state = (ifinfoMsg.flags & IFF_LOWER_UP) != 0; - maybeLog("interfaceLinkStateChanged", "ifindex " + mIfindex - + (state ? " up" : " down")); - updateInterfaceLinkStateChanged(state); - break; - - case NetlinkConstants.RTM_DELLINK: - maybeLog("interfaceRemoved", ifname); - updateInterfaceRemoved(); - break; - - default: - Log.e(mTag, "Unknown rtnetlink link msg type " + nlMsgType); - break; - } - } - - // The preferred/valid in ifa_cacheinfo expressed in units of seconds, convert - // it to milliseconds for deprecationTime or expirationTime used in LinkAddress. 
- // If the experiment flag is not enabled, LinkAddress.LIFETIME_UNKNOWN is retuend, - // the same as before. - private long getDeprecationOrExpirationTime(@Nullable final StructIfacacheInfo cacheInfo, - long now, boolean deprecationTime) { - if (!mConfig.populateLinkAddressLifetime || (cacheInfo == null)) { - return LinkAddress.LIFETIME_UNKNOWN; - } - final long lifetime = deprecationTime ? cacheInfo.preferred : cacheInfo.valid; - return (lifetime == Integer.toUnsignedLong(INFINITE_LEASE)) - ? LinkAddress.LIFETIME_PERMANENT - : now + lifetime * 1000; - } - - private void processRtNetlinkAddressMessage(RtNetlinkAddressMessage msg) { - if (!mNetlinkEventParsingEnabled) return; - - final StructIfaddrMsg ifaddrMsg = msg.getIfaddrHeader(); - if (ifaddrMsg.index != mIfindex) return; - - final StructIfacacheInfo cacheInfo = msg.getIfacacheInfo(); final long now = SystemClock.elapsedRealtime(); - final long deprecationTime = - getDeprecationOrExpirationTime(cacheInfo, now, true /* deprecationTime */); - final long expirationTime = - getDeprecationOrExpirationTime(cacheInfo, now, false /* deprecationTime */); - final LinkAddress la = new LinkAddress(msg.getIpAddress(), ifaddrMsg.prefixLen, - msg.getFlags(), ifaddrMsg.scope, deprecationTime, expirationTime); - - switch (msg.getHeader().nlmsg_type) { - case NetlinkConstants.RTM_NEWADDR: - if (updateInterfaceAddress(la, true /* add address */)) { - maybeLog("addressUpdated", mIfindex, la); - } - break; - case NetlinkConstants.RTM_DELADDR: - if (updateInterfaceAddress(la, false /* remove address */)) { - maybeLog("addressRemoved", mIfindex, la); - } - break; - default: - Log.e(mTag, "Unknown rtnetlink address msg type " + msg.getHeader().nlmsg_type); - return; + // TODO: change INFINITE_LEASE to long so the pesky conversions can be removed. 
+ if (cacheInfo.preferred < Integer.toUnsignedLong(INFINITE_LEASE)) { + deprecationTime = now + (cacheInfo.preferred /* seconds */ * 1000); + } + if (cacheInfo.valid < Integer.toUnsignedLong(INFINITE_LEASE)) { + expirationTime = now + (cacheInfo.valid /* seconds */ * 1000); } } - private void processRtNetlinkRouteMessage(RtNetlinkRouteMessage msg) { - if (!mNetlinkEventParsingEnabled) return; - if (msg.getInterfaceIndex() != mIfindex) return; - // Ignore the unsupported route protocol and non-global unicast routes. - if (!isSupportedRouteProtocol(msg) - || !isGlobalUnicastRoute(msg) - // don't support source routing - || (msg.getRtMsgHeader().srcLen != 0) - // don't support cloned routes - || ((msg.getRtMsgHeader().flags & RTM_F_CLONED) != 0)) { - return; - } + final LinkAddress la = new LinkAddress(msg.getIpAddress(), ifaddrMsg.prefixLen, + msg.getFlags(), ifaddrMsg.scope, deprecationTime, expirationTime); - final RouteInfo route = new RouteInfo(msg.getDestination(), msg.getGateway(), - mInterfaceName, msg.getRtMsgHeader().type); - switch (msg.getHeader().nlmsg_type) { - case NetlinkConstants.RTM_NEWROUTE: - if (updateInterfaceRoute(route, true /* add route */)) { - maybeLog("routeUpdated", route); - } - break; - case NetlinkConstants.RTM_DELROUTE: - if (updateInterfaceRoute(route, false /* remove route */)) { - maybeLog("routeRemoved", route); - } - break; - default: - Log.e(mTag, "Unknown rtnetlink route msg type " + msg.getHeader().nlmsg_type); - break; - } + switch (msg.getHeader().nlmsg_type) { + case NetlinkConstants.RTM_NEWADDR: + if (updateInterfaceAddress(la, true /* add address */)) { + maybeLog("addressUpdated", mIfindex, la); + } + break; + case NetlinkConstants.RTM_DELADDR: + if (updateInterfaceAddress(la, false /* remove address */)) { + maybeLog("addressRemoved", mIfindex, la); + } + break; + default: + Log.e(mTag, "Unknown rtnetlink address msg type " + msg.getHeader().nlmsg_type); + } + } + + private void 
processRtNetlinkRouteMessage(RtNetlinkRouteMessage msg) { + if (msg.getInterfaceIndex() != mIfindex) return; + // Ignore the unsupported route protocol and non-global unicast routes. + if (!isSupportedRouteProtocol(msg) + || !isGlobalUnicastRoute(msg) + // don't support source routing + || (msg.getRtMsgHeader().srcLen != 0) + // don't support cloned routes + || ((msg.getRtMsgHeader().flags & RTM_F_CLONED) != 0)) { + return; } - @Override - protected void processNetlinkMessage(NetlinkMessage nlMsg, long whenMs) { - if (nlMsg instanceof NduseroptMessage) { - processNduseroptMessage((NduseroptMessage) nlMsg, whenMs); - } else if (nlMsg instanceof RtNetlinkLinkMessage) { - processRtNetlinkLinkMessage((RtNetlinkLinkMessage) nlMsg); - } else if (nlMsg instanceof RtNetlinkAddressMessage) { - processRtNetlinkAddressMessage((RtNetlinkAddressMessage) nlMsg); - } else if (nlMsg instanceof RtNetlinkRouteMessage) { - processRtNetlinkRouteMessage((RtNetlinkRouteMessage) nlMsg); - } else { - Log.e(mTag, "Unknown netlink message: " + nlMsg); - } + final RouteInfo route = new RouteInfo(msg.getDestination(), msg.getGateway(), + mInterfaceName, msg.getRtMsgHeader().type); + switch (msg.getHeader().nlmsg_type) { + case NetlinkConstants.RTM_NEWROUTE: + if (updateInterfaceRoute(route, true /* add route */)) { + maybeLog("routeUpdated", route); + } + break; + case NetlinkConstants.RTM_DELROUTE: + if (updateInterfaceRoute(route, false /* remove route */)) { + maybeLog("routeRemoved", route); + } + break; + default: + Log.e(mTag, "Unknown rtnetlink route msg type " + msg.getHeader().nlmsg_type); + break; + } + } + + private void processNetlinkMessage(NetlinkMessage nlMsg, long whenMs) { + if (nlMsg instanceof NduseroptMessage) { + processNduseroptMessage((NduseroptMessage) nlMsg, whenMs); + } else if (nlMsg instanceof RtNetlinkLinkMessage) { + processRtNetlinkLinkMessage((RtNetlinkLinkMessage) nlMsg); + } else if (nlMsg instanceof RtNetlinkAddressMessage) { + 
processRtNetlinkAddressMessage((RtNetlinkAddressMessage) nlMsg); + } else if (nlMsg instanceof RtNetlinkRouteMessage) { + processRtNetlinkRouteMessage((RtNetlinkRouteMessage) nlMsg); + } else { + Log.e(mTag, "Unknown netlink message: " + nlMsg); } }
diff --git a/src/android/net/ip/IpReachabilityMonitor.java b/src/android/net/ip/IpReachabilityMonitor.java index daf9e51..fd16784 100644 --- a/src/android/net/ip/IpReachabilityMonitor.java +++ b/src/android/net/ip/IpReachabilityMonitor.java
@@ -460,7 +460,7 @@ // For on-link IPv6 DNS server or default router that never ever responds to address // resolution, kernel will send RTM_NEWNEIGH with NUD_FAILED to user space directly, // and there is no netlink neighbor events related to this neighbor received before. - return (prev == null || event.nudState == StructNdMsg.NUD_FAILED); + return (prev == null && event.nudState == StructNdMsg.NUD_FAILED); } private void handleNeighborLost(@Nullable final NeighborEvent prev, @@ -488,7 +488,9 @@ } } - if (avoidingBadLinks() || !(ip instanceof Inet6Address)) { + final boolean avoidingBadLinks = avoidingBadLinks(); + Log.d(TAG, "avoidingBadLinks: " + avoidingBadLinks); + if (avoidingBadLinks || !(ip instanceof Inet6Address)) { // We should do this unconditionally, but alas we cannot: b/31827713. whatIfLp.removeDnsServer(ip); } @@ -527,6 +529,15 @@ (mLinkProperties.isIpv4Provisioned() && !whatIfLp.isIpv4Provisioned()) || (mLinkProperties.isIpv6Provisioned() && !whatIfLp.isIpv6Provisioned() && !ignoreIncompleteIpv6Neighbor); + // TODO: for debugging flaky test only, delete it later. + Log.d(TAG, "lostProvisioning: " + lostProvisioning); + Log.d(TAG, "mLinkProperties.isIpv4Provisioned(): " + mLinkProperties.isIpv4Provisioned()); + Log.d(TAG, "mLinkProperties.isIpv6Provisioned(): " + mLinkProperties.isIpv6Provisioned()); + Log.d(TAG, "whatIfLp.isIpv6Provisioned(): " + whatIfLp.isIpv6Provisioned()); + Log.d(TAG, "whatIfLp.isIpv4Provisioned(): " + whatIfLp.isIpv4Provisioned()); + Log.d(TAG, "ignoreIncompleteIpv6Neighbor: " + ignoreIncompleteIpv6Neighbor); + Log.d(TAG, "IP address: " + ip); + final NudEventType type = getNudFailureEventType(isFromProbe(), isNudFailureDueToRoam(), lostProvisioning);
diff --git a/src/com/android/networkstack/netlink/TcpSocketTracker.java b/src/com/android/networkstack/netlink/TcpSocketTracker.java index 0d77dca..4140e64 100644 --- a/src/com/android/networkstack/netlink/TcpSocketTracker.java +++ b/src/com/android/networkstack/netlink/TcpSocketTracker.java
@@ -50,6 +50,7 @@ import android.net.NetworkCapabilities; import android.os.AsyncTask; import android.os.Build; +import android.os.Handler; import android.os.IBinder; import android.os.PowerManager; import android.os.RemoteException; @@ -66,7 +67,6 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; -import com.android.internal.annotations.GuardedBy; import com.android.internal.annotations.VisibleForTesting; import com.android.modules.utils.build.SdkLevel; import com.android.net.module.util.DeviceConfigUtils; @@ -133,10 +133,11 @@ private int mMinPacketsThreshold = DEFAULT_DATA_STALL_MIN_PACKETS_THRESHOLD; private int mTcpPacketsFailRateThreshold = DEFAULT_TCP_PACKETS_FAIL_PERCENTAGE; + // These variables are initialized when the NetworkMonitor enters DefaultState, + // and can only be accessed on the NetworkMonitor state machine thread after + // the NetworkMonitor state machine has been started. // TODO: Remove doze mode solution since uid networking blocked traffic is filtered out by // the info provided by bpf maps. - private final Object mDozeModeLock = new Object(); - @GuardedBy("mDozeModeLock") private boolean mInDozeMode = false; // These variables are initialized when the NetworkMonitor enters DefaultState, @@ -226,11 +227,23 @@ family, InetDiagMessage.buildInetDiagReqForAliveTcpSockets(family)); } mDependencies.addDeviceConfigChangedListener(mConfigListener); - mDependencies.addDeviceIdleReceiver(mDeviceIdleReceiver, mShouldDisableInDeepDoze, - mShouldDisableInLightDoze); + mCm = mDependencies.getContext().getSystemService(ConnectivityManager.class); } + /** + * Called from NetworkMonitor to notify NetworkMonitor is created. + * This is for initializing TcpSocketTracker from default state. 
+ */ + public void init(@NonNull final Handler handler, @NonNull LinkProperties lp, + @NonNull NetworkCapabilities nc) { + mDependencies.addDeviceIdleReceiver(mDeviceIdleReceiver, mShouldDisableInDeepDoze, + mShouldDisableInLightDoze, handler); + setOpportunisticMode(false); + setLinkProperties(lp); + setNetworkCapabilities(nc); + } + @Nullable private MarkMaskParcel getNetworkMarkMask() { try { @@ -254,9 +267,7 @@ // Traffic will be restricted in doze mode. TCP info may not reflect the correct network // behavior. // TODO: Traffic may be restricted by other reason. Get the restriction info from bpf in T+. - synchronized (mDozeModeLock) { - if (mInDozeMode) return false; - } + if (mInDozeMode) return false; FileDescriptor fd = null; @@ -464,10 +475,8 @@ public boolean isDataStallSuspected() { // Skip checking data stall since the traffic will be restricted and it will not be real // network stall. - // TODO: Traffic may be restricted by other reason. Get the restriction info from bpf in T+. - synchronized (mDozeModeLock) { - if (mInDozeMode) return false; - } + if (mInDozeMode) return false; + final boolean ret = (getLatestPacketFailPercentage() >= getTcpPacketsFailRateThreshold()); if (ret) { log("data stall suspected, uids: " + mLatestReportedUids.toString()); @@ -640,11 +649,9 @@ } private void setDozeMode(boolean isEnabled) { - synchronized (mDozeModeLock) { - if (mInDozeMode == isEnabled) return; - mInDozeMode = isEnabled; - logd("Doze mode enabled=" + mInDozeMode); - } + if (mInDozeMode == isEnabled) return; + mInDozeMode = isEnabled; + logd("Doze mode enabled=" + mInDozeMode); } public void setOpportunisticMode(boolean isEnabled) { @@ -748,7 +755,8 @@ /** Add receiver for detecting doze mode change to control TCP detection. 
*/ @TargetApi(Build.VERSION_CODES.TIRAMISU) public void addDeviceIdleReceiver(@NonNull final BroadcastReceiver receiver, - boolean shouldDisableInDeepDoze, boolean shouldDisableInLightDoze) { + boolean shouldDisableInDeepDoze, boolean shouldDisableInLightDoze, + @NonNull final Handler handler) { // No need to register receiver if no related feature is enabled. if (!shouldDisableInDeepDoze && !shouldDisableInLightDoze) return; @@ -759,7 +767,8 @@ if (shouldDisableInLightDoze) { intentFilter.addAction(ACTION_DEVICE_LIGHT_IDLE_MODE_CHANGED); } - mContext.registerReceiver(receiver, intentFilter); + mContext.registerReceiver(receiver, intentFilter, null /* broadcastPermission */, + handler); } /** Remove broadcast receiver. */
diff --git a/src/com/android/networkstack/util/NetworkStackUtils.java b/src/com/android/networkstack/util/NetworkStackUtils.java index 0cb31fe..4e3a8fd 100755 --- a/src/com/android/networkstack/util/NetworkStackUtils.java +++ b/src/com/android/networkstack/util/NetworkStackUtils.java
@@ -186,12 +186,6 @@ public static final String VALIDATION_METRICS_VERSION = "validation_metrics_version"; /** - * Experiment flag to enable sending gratuitous multicast unsolicited Neighbor Advertisements - * to propagate new assigned IPv6 GUA as quickly as possible. - */ - public static final String IPCLIENT_GRATUITOUS_NA_VERSION = "ipclient_gratuitous_na_version"; - - /** * Experiment flag to enable sending Gratuitous APR and Gratuitous Neighbor Advertisement for * all assigned IPv4 and IPv6 GUAs after completing L2 roaming. */ @@ -199,14 +193,6 @@ "ipclient_garp_na_roaming_version"; /** - * Experiment flag to check if an on-link IPv6 link local DNS is acceptable. The default flag - * value is true, just add this flag for A/B testing to see if this fix works as expected via - * experiment rollout. - */ - public static final String IPCLIENT_ACCEPT_IPV6_LINK_LOCAL_DNS_VERSION = - "ipclient_accept_ipv6_link_local_dns_version"; - - /** * Experiment flag to enable "mcast_resolicit" neighbor parameter in IpReachabilityMonitor, * set it to 3 by default. */ @@ -279,13 +265,6 @@ /**** BEGIN Feature Kill Switch Flags ****/ /** - * Kill switch flag to disable the feature of parsing netlink events from kernel directly - * instead from netd aidl interface by flag push. - */ - public static final String IPCLIENT_PARSE_NETLINK_EVENTS_FORCE_DISABLE = - "ipclient_parse_netlink_events_force_disable"; - - /** * Kill switch flag to disable the feature of handle light doze mode in Apf. */ public static final String APF_HANDLE_LIGHT_DOZE_FORCE_DISABLE =
diff --git a/src/com/android/server/NetworkObserver.java b/src/com/android/server/NetworkObserver.java deleted file mode 100644 index cccec0b..0000000 --- a/src/com/android/server/NetworkObserver.java +++ /dev/null
@@ -1,88 +0,0 @@ -/* - * Copyright (C) 2019 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.android.server; - -import android.net.LinkAddress; -import android.net.RouteInfo; - -/** - * Observer for network events, to use with {@link NetworkObserverRegistry}. - */ -public interface NetworkObserver { - - /** - * @see android.net.INetdUnsolicitedEventListener#onInterfaceChanged(java.lang.String, boolean) - */ - default void onInterfaceChanged(String ifName, boolean up) {} - - /** - * @see android.net.INetdUnsolicitedEventListener#onInterfaceRemoved(String) - */ - default void onInterfaceRemoved(String ifName) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * #onInterfaceAddressUpdated(String, String, int, int) - */ - default void onInterfaceAddressUpdated(LinkAddress address, String ifName) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * #onInterfaceAddressRemoved(String, String, int, int) - */ - default void onInterfaceAddressRemoved(LinkAddress address, String ifName) {} - - /** - * @see android.net.INetdUnsolicitedEventListener#onInterfaceLinkStateChanged(String, boolean) - */ - default void onInterfaceLinkStateChanged(String ifName, boolean up) {} - - /** - * @see android.net.INetdUnsolicitedEventListener#onInterfaceAdded(String) - */ - default void onInterfaceAdded(String ifName) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * 
#onInterfaceClassActivityChanged(boolean, int, long, int) - */ - default void onInterfaceClassActivityChanged( - boolean isActive, int label, long timestamp, int uid) {} - - /** - * @see android.net.INetdUnsolicitedEventListener#onQuotaLimitReached(String, String) - */ - default void onQuotaLimitReached(String alertName, String ifName) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * #onInterfaceDnsServerInfo(String, long, String[]) - */ - default void onInterfaceDnsServerInfo(String ifName, long lifetime, String[] servers) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * #onRouteChanged(boolean, String, String, String) - */ - default void onRouteUpdated(RouteInfo route) {} - - /** - * @see android.net.INetdUnsolicitedEventListener - * #onRouteChanged(boolean, String, String, String) - */ - default void onRouteRemoved(RouteInfo route) {} -}
diff --git a/src/com/android/server/NetworkObserverRegistry.java b/src/com/android/server/NetworkObserverRegistry.java deleted file mode 100644 index 38a0008..0000000 --- a/src/com/android/server/NetworkObserverRegistry.java +++ /dev/null
@@ -1,189 +0,0 @@ -/* - * Copyright (C) 2019 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.android.server; - -import static android.net.RouteInfo.RTN_UNICAST; - -import android.net.INetd; -import android.net.INetdUnsolicitedEventListener; -import android.net.InetAddresses; -import android.net.IpPrefix; -import android.net.LinkAddress; -import android.net.RouteInfo; -import android.os.Handler; -import android.os.RemoteException; -import android.util.Log; - -import androidx.annotation.NonNull; - -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; - -/** - * A class for reporting network events to clients. - * - * Implements INetdUnsolicitedEventListener and registers with netd, and relays those events to - * all INetworkManagementEventObserver objects that have registered with it. - */ -public class NetworkObserverRegistry extends INetdUnsolicitedEventListener.Stub { - private static final String TAG = NetworkObserverRegistry.class.getSimpleName(); - - /** - * Start listening for Netd events. - * - * <p>This should be called before allowing any observer to be registered. - * Note there is no unregister method. The only way to unregister is when the process - * terminates. 
- */ - public void register(@NonNull INetd netd) throws RemoteException { - netd.registerUnsolicitedEventListener(this); - } - - private final ConcurrentHashMap<NetworkObserver, Optional<Handler>> mObservers = - new ConcurrentHashMap<>(); - - /** - * Registers the specified observer and start sending callbacks to it. - * This method may be called on any thread. - */ - public void registerObserver(@NonNull NetworkObserver observer, @NonNull Handler handler) { - if (handler == null) { - throw new IllegalArgumentException("handler must be non-null"); - } - mObservers.put(observer, Optional.of(handler)); - } - - /** - * Registers the specified observer, and start sending callbacks to it. - * - * <p>This method must only be called with callbacks that are nonblocking, such as callbacks - * that only send a message to a StateMachine. - */ - public void registerObserverForNonblockingCallback(@NonNull NetworkObserver observer) { - mObservers.put(observer, Optional.empty()); - } - - /** - * Unregisters the specified observer and stop sending callbacks to it. - * This method may be called on any thread. - */ - public void unregisterObserver(@NonNull NetworkObserver observer) { - mObservers.remove(observer); - } - - @FunctionalInterface - private interface NetworkObserverEventCallback { - void sendCallback(NetworkObserver o); - } - - private void invokeForAllObservers(@NonNull final NetworkObserverEventCallback callback) { - // ConcurrentHashMap#entrySet is weakly consistent: observers that were in the map before - // creation will be processed, those added during traversal may or may not. 
- for (Map.Entry<NetworkObserver, Optional<Handler>> entry : mObservers.entrySet()) { - final NetworkObserver observer = entry.getKey(); - final Optional<Handler> handler = entry.getValue(); - if (handler.isPresent()) { - handler.get().post(() -> callback.sendCallback(observer)); - return; - } - - try { - callback.sendCallback(observer); - } catch (RuntimeException e) { - Log.e(TAG, "Error sending callback to observer", e); - } - } - } - - @Override - public void onInterfaceClassActivityChanged(boolean isActive, - int label, long timestamp, int uid) { - invokeForAllObservers(o -> o.onInterfaceClassActivityChanged( - isActive, label, timestamp, uid)); - } - - /** - * Notify our observers of a limit reached. - */ - @Override - public void onQuotaLimitReached(String alertName, String ifName) { - invokeForAllObservers(o -> o.onQuotaLimitReached(alertName, ifName)); - } - - @Override - public void onInterfaceDnsServerInfo(String ifName, long lifetime, String[] servers) { - invokeForAllObservers(o -> o.onInterfaceDnsServerInfo(ifName, lifetime, servers)); - } - - @Override - public void onInterfaceAddressUpdated(String addr, String ifName, int flags, int scope) { - final LinkAddress address = new LinkAddress(addr, flags, scope); - invokeForAllObservers(o -> o.onInterfaceAddressUpdated(address, ifName)); - } - - @Override - public void onInterfaceAddressRemoved(String addr, - String ifName, int flags, int scope) { - final LinkAddress address = new LinkAddress(addr, flags, scope); - invokeForAllObservers(o -> o.onInterfaceAddressRemoved(address, ifName)); - } - - @Override - public void onInterfaceAdded(String ifName) { - invokeForAllObservers(o -> o.onInterfaceAdded(ifName)); - } - - @Override - public void onInterfaceRemoved(String ifName) { - invokeForAllObservers(o -> o.onInterfaceRemoved(ifName)); - } - - @Override - public void onInterfaceChanged(String ifName, boolean up) { - invokeForAllObservers(o -> o.onInterfaceChanged(ifName, up)); - } - - @Override - public 
void onInterfaceLinkStateChanged(String ifName, boolean up) { - invokeForAllObservers(o -> o.onInterfaceLinkStateChanged(ifName, up)); - } - - @Override - public void onRouteChanged(boolean updated, String route, String gateway, String ifName) { - final RouteInfo processRoute = new RouteInfo(new IpPrefix(route), - ("".equals(gateway)) ? null : InetAddresses.parseNumericAddress(gateway), - ifName, RTN_UNICAST); - if (updated) { - invokeForAllObservers(o -> o.onRouteUpdated(processRoute)); - } else { - invokeForAllObservers(o -> o.onRouteRemoved(processRoute)); - } - } - - @Override - public void onStrictCleartextDetected(int uid, String hex) {} - - @Override - public int getInterfaceVersion() { - return INetdUnsolicitedEventListener.VERSION; - } - - @Override - public String getInterfaceHash() { - return INetdUnsolicitedEventListener.HASH; - } -}
diff --git a/src/com/android/server/NetworkStackService.java b/src/com/android/server/NetworkStackService.java index 40aee28..aa8f3fa 100644 --- a/src/com/android/server/NetworkStackService.java +++ b/src/com/android/server/NetworkStackService.java
@@ -85,6 +85,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; +import java.util.ListIterator; import java.util.Objects; import java.util.SortedSet; import java.util.TreeSet; @@ -179,9 +180,9 @@ /** @see IpClient */ @NonNull public IpClient makeIpClient(@NonNull Context context, @NonNull String ifName, - @NonNull IIpClientCallbacks cb, @NonNull NetworkObserverRegistry observerRegistry, + @NonNull IIpClientCallbacks cb, @NonNull NetworkStackServiceManager nsServiceManager) { - return new IpClient(context, ifName, cb, observerRegistry, nsServiceManager); + return new IpClient(context, ifName, cb, nsServiceManager); } } @@ -196,7 +197,6 @@ private final PermissionChecker mPermChecker; private final Dependencies mDeps; private final INetd mNetd; - private final NetworkObserverRegistry mObserverRegistry; @GuardedBy("mIpClients") private final ArrayList<WeakReference<IpClient>> mIpClients = new ArrayList<>(); private final IpMemoryStoreService mIpMemoryStoreService; @@ -294,7 +294,6 @@ mDeps = deps; mNetd = INetd.Stub.asInterface( (IBinder) context.getSystemService(Context.NETD_SERVICE)); - mObserverRegistry = new NetworkObserverRegistry(); mIpMemoryStoreService = mDeps.makeIpMemoryStoreService(context); // NetworkStackNotifier only shows notifications relevant for API level > Q if (ShimUtils.isReleaseOrDevelopmentApiAbove(Build.VERSION_CODES.Q)) { @@ -317,12 +316,6 @@ netdHash = HASH_UNKNOWN; } updateNetdAidlVersion(netdVersion, netdHash); - - try { - mObserverRegistry.register(mNetd); - } catch (RemoteException e) { - mLog.e("Error registering observer on Netd", e); - } } private void updateNetdAidlVersion(final int version, final String hash) { @@ -385,7 +378,7 @@ mPermChecker.enforceNetworkStackCallingPermission(); updateNetworkStackAidlVersion(cb.getInterfaceVersion(), cb.getInterfaceHash()); final IpClient ipClient = mDeps.makeIpClient( - mContext, ifName, cb, mObserverRegistry, this); + mContext, ifName, cb, this); synchronized 
(mIpClients) { final Iterator<WeakReference<IpClient>> it = mIpClients.iterator(); @@ -511,6 +504,23 @@ err.getFileDescriptor(), args); } + private String apfShellCommand(String iface, String cmd, @Nullable String optarg) { + synchronized (mIpClients) { + // HACK: An old IpClient serving the given interface name might not have been + // garbage collected. Since new IpClients are always appended to the list, iterate + // through it in reverse order to get the most up-to-date IpClient instance. + // Create a ListIterator at the end of the list. + final ListIterator it = mIpClients.listIterator(mIpClients.size()); + while (it.hasPrevious()) { + final IpClient ipClient = ((WeakReference<IpClient>) it.previous()).get(); + if (ipClient != null && ipClient.getInterfaceName().equals(iface)) { + return ipClient.apfShellCommand(cmd, optarg); + } + } + } + throw new IllegalArgumentException("No active IpClient found for interface " + iface); + } + private class ShellCmd extends BasicShellCommandHandler { @Override public int onCommand(String cmd) { @@ -518,39 +528,51 @@ return handleDefaultCommands(cmd); } final PrintWriter pw = getOutPrintWriter(); - try { - switch (cmd) { - case "is-uid-networking-blocked": - if (!DeviceConfigUtils.isFeatureSupported(mContext, - FEATURE_IS_UID_NETWORKING_BLOCKED)) { - pw.println("API is unsupported"); - return -1; - } + switch (cmd) { + case "is-uid-networking-blocked": + if (!DeviceConfigUtils.isFeatureSupported(mContext, + FEATURE_IS_UID_NETWORKING_BLOCKED)) { + throw new IllegalStateException("API is unsupported"); + } - // Usage : cmd network_stack is-uid-networking-blocked <uid> <metered> - // If no argument, get and display the usage help. - if (getRemainingArgsCount() != 2) { - onHelp(); - return -1; - } - final int uid; - final boolean metered; - // If any fail, throws and output to the stdout. - // Let the caller handle it. 
- uid = Integer.parseInt(getNextArg()); - metered = Boolean.parseBoolean(getNextArg()); - final ConnectivityManager cm = - mContext.getSystemService(ConnectivityManager.class); - pw.println(cm.isUidNetworkingBlocked( - uid, metered /* isNetworkMetered */)); - return 0; - default: - return handleDefaultCommands(cmd); - } - } catch (Exception e) { - pw.println(e); + // Usage : cmd network_stack is-uid-networking-blocked <uid> <metered> + // If no argument, get and display the usage help. + if (getRemainingArgsCount() != 2) { + onHelp(); + throw new IllegalArgumentException("Incorrect number of arguments"); + } + final int uid; + final boolean metered; + uid = Integer.parseInt(getNextArg()); + metered = Boolean.parseBoolean(getNextArg()); + final ConnectivityManager cm = + mContext.getSystemService(ConnectivityManager.class); + pw.println(cm.isUidNetworkingBlocked(uid, metered /* isNetworkMetered */)); + return 0; + case "apf": + // Usage: cmd network_stack apf <iface> <cmd> + final String iface = getNextArg(); + if (iface == null) { + throw new IllegalArgumentException("No <iface> specified"); + } + + final String subcmd = getNextArg(); + if (subcmd == null) { + throw new IllegalArgumentException("No <cmd> specified"); + } + + final String optarg = getNextArg(); + if (getRemainingArgsCount() != 0) { + throw new IllegalArgumentException("Too many arguments passed"); + } + + final String result = apfShellCommand(iface, subcmd, optarg); + pw.println(result); + return 0; + + default: + return handleDefaultCommands(cmd); } - return -1; } @Override @@ -563,6 +585,24 @@ pw.println(" Get whether the networking is blocked for given uid and metered."); pw.println(" <uid>: The target uid."); pw.println(" <metered>: [true|false], Whether the target network is metered."); + pw.println(" apf <iface> <cmd>"); + pw.println(" APF utility commands for integration tests."); + pw.println(" <iface>: the network interface the provided command operates on."); + pw.println(" <cmd>: 
[status]"); + pw.println(" status"); + pw.println(" returns whether the APF filter is \"running\" or \"paused\"."); + pw.println(" pause"); + pw.println(" pause APF filter generation."); + pw.println(" resume"); + pw.println(" resume APF filter generation."); + pw.println(" install <program-hex-string>"); + pw.println(" install the APF program contained in <program-hex-string>."); + pw.println(" The filter must be paused before installing a new program."); + pw.println(" capabilities"); + pw.println(" return the reported APF capabilities."); + pw.println(" Format: <apfVersion>,<maxProgramSize>,<packetFormat>"); + pw.println(" read"); + pw.println(" reads and returns the current state of APF memory."); } }
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java index c62fb90..e564cd7 100755 --- a/src/com/android/server/connectivity/NetworkMonitor.java +++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -987,9 +987,7 @@ final TcpSocketTracker tst = getTcpSocketTracker(); if (tst != null) { // Initialization. - tst.setOpportunisticMode(false); - tst.setLinkProperties(mLinkProperties); - tst.setNetworkCapabilities(mNetworkCapabilities); + tst.init(getHandler(), mLinkProperties, mNetworkCapabilities); } Log.d(TAG, "Starting on network " + mNetwork + " with capport HTTPS URL " + Arrays.toString(mCaptivePortalHttpsUrls) @@ -2022,7 +2020,8 @@ recordProbeEventMetrics(ProbeType.PT_PRIVDNS, elapsedNanos, success ? ProbeResult.PR_SUCCESS : ProbeResult.PR_FAILURE, null /* capportData */); - logValidationProbe(elapsedNanos, PROBE_PRIVDNS, success ? DNS_SUCCESS : DNS_FAILURE); + logValidationProbe(elapsedNanos / 1000, PROBE_PRIVDNS, + success ? DNS_SUCCESS : DNS_FAILURE); final String strIps = Objects.toString(answer); validationLog(PROBE_PRIVDNS, queryName,
diff --git a/tests/integration/AndroidManifest.xml b/tests/integration/AndroidManifest.xml index 85c971a..ca5382e 100644 --- a/tests/integration/AndroidManifest.xml +++ b/tests/integration/AndroidManifest.xml
@@ -23,10 +23,13 @@ 05-14 00:41:02.723 18330 18330 E AndroidRuntime: java.lang.IllegalStateException: Signature|privileged permissions not in privapp-permissions whitelist: {com.android.server.networkstack.integrationtests: android.permission.CONNECTIVITY_INTERNAL} --> - <!-- Used by creating test network --> - <uses-permission android:name="android.permission.MANAGE_TEST_NETWORKS" /> - <uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" /> - <uses-permission android:name="android.permission.WRITE_SETTINGS" /> + <!-- Used by creating test network. + comment out the permission needed by NetworkStatsIntegrationTest. + The test case migrated to Connectivity + --> + <!-- uses-permission android:name="android.permission.MANAGE_TEST_NETWORKS" /--> + <!-- uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" /--> + <!-- uses-permission android:name="android.permission.WRITE_SETTINGS" /--> <application android:debuggable="true"> <uses-library android:name="android.test.runner" /> </application>
diff --git a/tests/integration/common/android/net/ip/IpClientIntegrationTestCommon.java b/tests/integration/common/android/net/ip/IpClientIntegrationTestCommon.java index 06b0ca2..d310b9c 100644 --- a/tests/integration/common/android/net/ip/IpClientIntegrationTestCommon.java +++ b/tests/integration/common/android/net/ip/IpClientIntegrationTestCommon.java
@@ -26,6 +26,7 @@ import static android.net.NetworkCapabilities.TRANSPORT_TEST; import static android.net.RouteInfo.RTN_UNICAST; import static android.net.dhcp.DhcpClient.EXPIRED_LEASE; +import static android.net.dhcp.DhcpPacket.CONFIG_MINIMUM_LEASE; import static android.net.dhcp.DhcpPacket.DHCP_BOOTREQUEST; import static android.net.dhcp.DhcpPacket.DHCP_CLIENT; import static android.net.dhcp.DhcpPacket.DHCP_IPV6_ONLY_PREFERRED; @@ -220,8 +221,6 @@ import com.android.networkstack.packets.NeighborAdvertisement; import com.android.networkstack.packets.NeighborSolicitation; import com.android.networkstack.util.NetworkStackUtils; -import com.android.server.NetworkObserver; -import com.android.server.NetworkObserverRegistry; import com.android.server.NetworkStackService.NetworkStackServiceManager; import com.android.testutils.CompatUtil; import com.android.testutils.DevSdkIgnoreRule; @@ -239,8 +238,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; import org.mockito.InOrder; import org.mockito.Mock; @@ -282,7 +279,6 @@ * * Tests in this class can either be run with signature permissions, or with root access. */ -@RunWith(Parameterized.class) @SmallTest public abstract class IpClientIntegrationTestCommon { private static final String TAG = IpClientIntegrationTestCommon.class.getSimpleName(); @@ -318,17 +314,6 @@ @Rule public final TestName mTestNameRule = new TestName(); - // Indicate whether the flag of parsing netlink event is enabled or not. If it's disabled, - // integration test still covers the old codepath(i.e. using NetworkObserver), otherwise, - // test goes through the new codepath(i.e. processRtNetlinkxxx). - @Parameterized.Parameter(0) - public boolean mIsNetlinkEventParseEnabled; - - @Parameterized.Parameters - public static Iterable<? 
extends Object> data() { - return Arrays.asList(Boolean.FALSE, Boolean.TRUE); - } - /** * Indicates that a test requires signature permissions to run. * @@ -364,7 +349,6 @@ @Mock private DevicePolicyManager mDevicePolicyManager; @Mock private PackageManager mPackageManager; @Spy private INetd mNetd; - private NetworkObserverRegistry mNetworkObserverRegistry; protected IpClient mIpc; protected Dependencies mDependencies; @@ -595,6 +579,11 @@ } @Override + public int getIntDeviceConfig(final String name, int defaultValue) { + return Dependencies.this.getDeviceConfigPropertyInt(name, defaultValue); + } + + @Override public PowerManager.WakeLock getWakeLock(final PowerManager powerManager) { return mTimeoutWakeLock; } @@ -718,32 +707,11 @@ @Before public void setUp() throws Exception { - // Suffix "[0]" or "[1]" is added to the end of test method name after running with - // Parameterized.class, that's intended behavior, to iterate each test method with the - // parameterize value. However, Class#getMethod() throws NoSuchMethodException when - // searching the target test method name due to this change. Just keep the original test - // method name to fix NoSuchMethodException, and find the correct annotation associated - // to test method. - final String testMethodName = mTestNameRule.getMethodName().split("\\[")[0]; + final String testMethodName = mTestNameRule.getMethodName(); final Method testMethod = IpClientIntegrationTestCommon.class.getMethod(testMethodName); mIsSignatureRequiredTest = testMethod.getAnnotation(SignatureRequiredTest.class) != null; assumeFalse(testSkipped()); - // Depend on the parameterized value to enable/disable netlink message refactor flag. - // Make sure both of the old codepath(rely on the INetdUnsolicitedEventListener aidl) - // and new codepath(parse netlink event from kernel) will be executed. 
- // - // Note this must be called before making IpClient instance since MyNetlinkMontior ctor - // in IpClientLinkObserver will use mIsNetlinkEventParseEnabled to decide the proper - // bindGroups, otherwise, the parameterized value got from ArrayMap(integration test) is - // always false. - // - // Set feature kill switch flag with the parameterized value to keep running test cases on - // both code paths. Once we clean up the old code path (i.e.when the parameterized variable - // is false), then we can also delete this code. - setFeatureChickenedOut(NetworkStackUtils.IPCLIENT_PARSE_NETLINK_EVENTS_FORCE_DISABLE, - !mIsNetlinkEventParseEnabled); - // Enable DHCPv6 Prefix Delegation. setFeatureEnabled(NetworkStackUtils.IPCLIENT_DHCPV6_PREFIX_DELEGATION_VERSION, true /* isDhcp6PrefixDelegationEnabled */); @@ -752,6 +720,18 @@ setFeatureEnabled(NetworkStackUtils.IPCLIENT_POPULATE_LINK_ADDRESS_LIFETIME_VERSION, true /* enabled */); + // Disable the experiment flag IP_REACHABILITY_IGNORE_INCOMPLETE_IPV6_DNS_SERVER_VERSION + // for testIpReachabilityMonitor_incompleteIpv6DnsServerInDualStack_flagoff testcase, given + // the experiment flag is read at IpReachabilityMonitor constructor so we have to turn it + // off before creating the IpClient instance. + // TODO: cleanup this code as well when cleaning up the experiment flag. + if (testMethodName.equals( + "testIpReachabilityMonitor_incompleteIpv6DnsServerInDualStack_flagoff")) { + setFeatureEnabled( + NetworkStackUtils.IP_REACHABILITY_IGNORE_INCOMPLETE_IPV6_DNS_SERVER_VERSION, + false /* enabled */); + } + setUpTapInterface(); // It turns out that Router Solicitation will also be sent out even after the tap interface // is brought up, however, we want to wait for RS which is sent due to IPv6 stack is enabled @@ -767,6 +747,7 @@ setUpIpClient(); // Enable packet retransmit alarm in DhcpClient. enableRealAlarm("DhcpClient." + mIfaceName + ".KICK"); + enableRealAlarm("DhcpClient." 
+ mIfaceName + ".RENEW"); // Enable alarm for IPv6 autoconf via SLAAC in IpClient. enableRealAlarm("IpClient." + mIfaceName + ".EVENT_IPV6_AUTOCONF_TIMEOUT"); // Enable packet retransmit alarm in Dhcp6Client. @@ -786,6 +767,8 @@ // in this case and start DHCPv6 Prefix Delegation then. final int timeout = useNetworkStackSignature() ? 500 : (int) TEST_TIMEOUT_MS; setDeviceConfigProperty(IpClient.CONFIG_IPV6_AUTOCONF_TIMEOUT, timeout /* default value */); + // Set DHCP minimum lease. + setDeviceConfigProperty(DhcpPacket.CONFIG_MINIMUM_LEASE, DhcpPacket.DEFAULT_MINIMUM_LEASE); } protected void setUpMocks() throws Exception { @@ -945,8 +928,8 @@ } private IpClient makeIpClient() throws Exception { - IpClient ipc = new IpClient(mContext, mIfaceName, mCb, mNetworkObserverRegistry, - mNetworkStackServiceManager, mDependencies); + IpClient ipc = + new IpClient(mContext, mIfaceName, mCb, mNetworkStackServiceManager, mDependencies); // Wait for IpClient to enter its initial state. Otherwise, additional setup steps or tests // that mock IpClient's dependencies might interact with those mocks while IpClient is // starting. This would cause UnfinishedStubbingExceptions as mocks cannot be interacted @@ -963,8 +946,6 @@ when(mContext.getSystemService(eq(Context.NETD_SERVICE))).thenReturn(netdIBinder); assertNotNull(mNetd); - mNetworkObserverRegistry = new NetworkObserverRegistry(); - mNetworkObserverRegistry.register(mNetd); mIpc = makeIpClient(); // Tell the IpMemoryStore immediately to answer any question about network attributes with a @@ -2164,11 +2145,7 @@ } private boolean isStablePrivacyAddress(LinkAddress addr) { - // The Q netd does not understand the IFA_F_STABLE_PRIVACY flag. - // See r.android.com/1295670. - final int flag = (mIsNetlinkEventParseEnabled || ShimUtils.isAtLeastR()) - ? 
IFA_F_STABLE_PRIVACY : 0; - return addr.isGlobalPreferred() && hasFlag(addr, flag); + return addr.isGlobalPreferred() && hasFlag(addr, IFA_F_STABLE_PRIVACY); } private LinkProperties doIpv6OnlyProvisioning() throws Exception { @@ -2263,14 +2240,11 @@ reset(mCb); } - private void runRaRdnssIpv6LinkLocalDnsTest(boolean isIpv6LinkLocalDnsAccepted) - throws Exception { + private void runRaRdnssIpv6LinkLocalDnsTest() throws Exception { ProvisioningConfiguration config = new ProvisioningConfiguration.Builder() .withoutIpReachabilityMonitor() .withoutIPv4() .build(); - setFeatureEnabled(NetworkStackUtils.IPCLIENT_ACCEPT_IPV6_LINK_LOCAL_DNS_VERSION, - isIpv6LinkLocalDnsAccepted /* default value */); startIpClientProvisioning(config); final ByteBuffer pio = buildPioOption(600, 300, "2001:db8:1::/64"); @@ -2286,7 +2260,7 @@ @Test public void testRaRdnss_Ipv6LinkLocalDns() throws Exception { - runRaRdnssIpv6LinkLocalDnsTest(true /* isIpv6LinkLocalDnsAccepted */); + runRaRdnssIpv6LinkLocalDnsTest(); final ArgumentCaptor<LinkProperties> captor = ArgumentCaptor.forClass(LinkProperties.class); verify(mCb, timeout(TEST_TIMEOUT_MS)).onProvisioningSuccess(captor.capture()); final LinkProperties lp = captor.getValue(); @@ -2296,20 +2270,6 @@ assertTrue(lp.isIpv6Provisioned()); } - @Test - public void testRaRdnss_disableIpv6LinkLocalDns() throws Exception { - // Only run the test when the flag of parsing netlink events is enabled, feature flag - // "ipclient_accept_ipv6_link_local_dns" doesn't affect the legacy code. 
- assumeTrue(mIsNetlinkEventParseEnabled); - runRaRdnssIpv6LinkLocalDnsTest(false /* isIpv6LinkLocalDnsAccepted */); - verify(mCb, timeout(TEST_TIMEOUT_MS)).onLinkPropertiesChange(argThat(lp -> { - return lp.hasGlobalIpv6Address() - && lp.hasIpv6DefaultRoute() - && !lp.hasIpv6DnsServer(); - })); - verify(mCb, never()).onProvisioningSuccess(any()); - } - private void expectNat64PrefixUpdate(InOrder inOrder, IpPrefix expected) throws Exception { inOrder.verify(mCb, timeout(TEST_TIMEOUT_MS)).onLinkPropertiesChange( argThat(lp -> Objects.equals(expected, lp.getNat64Prefix()))); @@ -2438,49 +2398,15 @@ HandlerUtils.waitForIdle(mIpc.getHandler(), TEST_TIMEOUT_MS); } - private void waitForAddressViaNetworkObserver(final String iface, final String addr1, - final String addr2, int prefixLength) throws Exception { - final CountDownLatch latch = new CountDownLatch(1); - - // Add two IPv4 addresses to the specified interface, and proceed when the NetworkObserver - // has seen the second one. This ensures that every other NetworkObserver registered with - // mNetworkObserverRegistry - in particular, IpClient's - has seen the addition of the first - // address. 
- final LinkAddress trigger = new LinkAddress(addr2 + "/" + prefixLength); - NetworkObserver observer = new NetworkObserver() { - @Override - public void onInterfaceAddressUpdated(LinkAddress address, String ifName) { - if (ifName.equals(iface) && address.isSameAddressAs(trigger)) { - latch.countDown(); - } - } - }; - - mNetworkObserverRegistry.registerObserverForNonblockingCallback(observer); - try { - mNetd.interfaceAddAddress(iface, addr1, prefixLength); - mNetd.interfaceAddAddress(iface, addr2, prefixLength); - assertTrue("Trigger IP address " + addr2 + " not seen after " + TEST_TIMEOUT_MS + "ms", - latch.await(TEST_TIMEOUT_MS, TimeUnit.MILLISECONDS)); - } finally { - mNetworkObserverRegistry.unregisterObserver(observer); - } - } - private void addIpAddressAndWaitForIt(final String iface) throws Exception { final String addr1 = "192.0.2.99"; final String addr2 = "192.0.2.3"; final int prefixLength = 26; - if (!mIsNetlinkEventParseEnabled) { - waitForAddressViaNetworkObserver(iface, addr1, addr2, prefixLength); - } else { - // IpClient gets IP addresses directly from netlink instead of from netd, unnecessary - // to rely on the NetworkObserver callbacks to confirm new added address update. Just - // add the addresses directly and wait to see if IpClient has seen the address - mNetd.interfaceAddAddress(iface, addr1, prefixLength); - mNetd.interfaceAddAddress(iface, addr2, prefixLength); - } + // IpClient gets IP addresses directly from netlink instead of from netd, just + // add the addresses directly and wait to see if IpClient has seen the address. + mNetd.interfaceAddAddress(iface, addr1, prefixLength); + mNetd.interfaceAddAddress(iface, addr2, prefixLength); // Wait for IpClient to process the addition of the address. 
HandlerUtils.waitForIdle(mIpc.getHandler(), TEST_TIMEOUT_MS); @@ -2912,6 +2838,17 @@ (byte) 0x06, data); } + private void assertDhcpResultsParcelable(final DhcpResultsParcelable lease) { + assertNotNull(lease); + assertEquals(CLIENT_ADDR, lease.baseConfiguration.getIpAddress().getAddress()); + assertEquals(SERVER_ADDR, lease.baseConfiguration.getGateway()); + assertEquals(1, lease.baseConfiguration.getDnsServers().size()); + assertTrue(lease.baseConfiguration.getDnsServers().contains(SERVER_ADDR)); + assertEquals(SERVER_ADDR, InetAddresses.parseNumericAddress(lease.serverAddress)); + assertEquals(TEST_DEFAULT_MTU, lease.mtu); + assertEquals(TEST_LEASE_DURATION_S, lease.leaseDuration); + } + private void doUpstreamHotspotDetectionTest(final int id, final String displayName, final String ssid, final byte[] oui, final byte type, final byte[] data, final boolean expectMetered) throws Exception { @@ -2930,13 +2867,7 @@ ArgumentCaptor.forClass(DhcpResultsParcelable.class); verify(mCb, timeout(TEST_TIMEOUT_MS)).onNewDhcpResults(captor.capture()); final DhcpResultsParcelable lease = captor.getValue(); - assertNotNull(lease); - assertEquals(CLIENT_ADDR, lease.baseConfiguration.getIpAddress().getAddress()); - assertEquals(SERVER_ADDR, lease.baseConfiguration.getGateway()); - assertEquals(1, lease.baseConfiguration.getDnsServers().size()); - assertTrue(lease.baseConfiguration.getDnsServers().contains(SERVER_ADDR)); - assertEquals(SERVER_ADDR, InetAddresses.parseNumericAddress(lease.serverAddress)); - assertEquals(TEST_DEFAULT_MTU, lease.mtu); + assertDhcpResultsParcelable(lease); if (expectMetered) { assertEquals(lease.vendorInfo, DhcpPacket.VENDOR_INFO_ANDROID_METERED); @@ -3864,8 +3795,6 @@ .withoutIPv4() .build(); - setFeatureEnabled(NetworkStackUtils.IPCLIENT_GRATUITOUS_NA_VERSION, - true /* isGratuitousNaEnabled */); startIpClientProvisioning(config); doIpv6OnlyProvisioning(); @@ -3898,11 +3827,6 @@ setDhcpFeatures(true /* isRapidCommitEnabled */, false /* 
isDhcpIpConflictDetectEnabled */); - // Disable gratuitious neighbor discovery feature manually, if the feature is enabled on - // the DUT during experiment launch, that will send another two duplicate NA packets and - // mess up the assert of received NA packets. - setFeatureEnabled(NetworkStackUtils.IPCLIENT_GRATUITOUS_NA_VERSION, - false /* isGratuitousNaEnabled */); if (isGratuitousArpNaRoamingEnabled) { setFeatureEnabled(NetworkStackUtils.IPCLIENT_GARP_NA_ROAMING_VERSION, true); } else { @@ -3939,7 +3863,8 @@ final List<ArpPacket> arpList = new ArrayList<>(); final List<NeighborAdvertisement> naList = new ArrayList<>(); waitForGratuitousArpAndNaPacket(arpList, naList); - assertEquals(2, naList.size()); // privacy address and stable privacy address + // 2 NAs sent due to RFC9131 implement and 2 NAs sent after roam + assertEquals(4, naList.size()); // privacy address and stable privacy address assertEquals(1, arpList.size()); // IPv4 address } @@ -3953,7 +3878,7 @@ final List<ArpPacket> arpList = new ArrayList<>(); final List<NeighborAdvertisement> naList = new ArrayList<>(); waitForGratuitousArpAndNaPacket(arpList, naList); - assertEquals(0, naList.size()); + assertEquals(2, naList.size()); // NAs sent due to RFC9131 implement, not from roam assertEquals(0, arpList.size()); } @@ -3967,7 +3892,8 @@ final List<ArpPacket> arpList = new ArrayList<>(); final List<NeighborAdvertisement> naList = new ArrayList<>(); waitForGratuitousArpAndNaPacket(arpList, naList); - assertEquals(2, naList.size()); + // 2 NAs sent due to RFC9131 implement and 2 NAs sent after roam + assertEquals(4, naList.size()); assertEquals(0, arpList.size()); } @@ -4688,9 +4614,6 @@ @Test @SignatureRequiredTest(reason = "requires mock callback object") public void testNetlinkSocketReceiveENOBUFS() throws Exception { - // Only run the test when the flag of parsing netlink events is enabled. 
- assumeTrue(mIsNetlinkEventParseEnabled); - ProvisioningConfiguration config = new ProvisioningConfiguration.Builder() .withoutIPv4() .build(); @@ -4995,20 +4918,16 @@ final LinkProperties lp = captor.getValue(); assertTrue(hasIpv6AddressPrefixedWith(lp, prefix)); - // Only run the test when the flag of parsing netlink events is enabled, where the - // deprecationTime and expirationTime is set. - if (mIsNetlinkEventParseEnabled) { - final long now = SystemClock.elapsedRealtime(); - long when = 0; - for (LinkAddress la : lp.getLinkAddresses()) { - if (la.getAddress().isLinkLocalAddress()) { - assertLinkAddressPermanentLifetime(la); - } else if (la.isGlobalPreferred()) { - when = now + 4500 * 1000; // preferred=4500s - assertLinkAddressDeprecationTime(la, when); - when = now + 7200 * 1000; // valid=7200s - assertLinkAddressExpirationTime(la, when); - } + final long now = SystemClock.elapsedRealtime(); + long when = 0; + for (LinkAddress la : lp.getLinkAddresses()) { + if (la.getAddress().isLinkLocalAddress()) { + assertLinkAddressPermanentLifetime(la); + } else if (la.isGlobalPreferred()) { + when = now + 4500 * 1000; // preferred=4500s + assertLinkAddressDeprecationTime(la, when); + when = now + 7200 * 1000; // valid=7200s + assertLinkAddressExpirationTime(la, when); } } } @@ -5777,11 +5696,6 @@ @Test public void testPopulateLinkAddressLifetime() throws Exception { - // Only run the test when the flag of parsing netlink events is enabled to verify the - // code of setting deprecationTime/expirationTime added when IpClientLinkObserver sees - // the RTM_NEWADDR, and we are going to delete the dead old code path completely soon. 
- assumeTrue(mIsNetlinkEventParseEnabled); - final LinkProperties lp = doDualStackProvisioning(); final long now = SystemClock.elapsedRealtime(); long when = 0; @@ -5803,9 +5717,6 @@ @Test public void testPopulateLinkAddressLifetime_infiniteLeaseDuration() throws Exception { - // Only run the test when the flag of parsing netlink events is enabled. - assumeTrue(mIsNetlinkEventParseEnabled); - final ProvisioningConfiguration cfg = new ProvisioningConfiguration.Builder() .withoutIPv6() .build(); @@ -5829,9 +5740,6 @@ @Test public void testPopulateLinkAddressLifetime_minimalLeaseDuration() throws Exception { - // Only run the test when the flag of parsing netlink events is enabled. - assumeTrue(mIsNetlinkEventParseEnabled); - final ProvisioningConfiguration cfg = new ProvisioningConfiguration.Builder() .withoutIPv6() .build(); @@ -5856,6 +5764,65 @@ } } + @Test + public void testPopulateLinkAddressLifetime_onDhcpRenew() throws Exception { + final ProvisioningConfiguration cfg = new ProvisioningConfiguration.Builder() + .withoutIPv6() + .build(); + setDeviceConfigProperty(CONFIG_MINIMUM_LEASE, 5 /* default minimum lease */); + startIpClientProvisioning(cfg); + handleDhcpPackets(true /* isSuccessLease */, 4 /* lease duration */, + false /* shouldReplyRapidCommitAck */, TEST_DEFAULT_MTU, + null /* captivePortalApiUrl */, null /* ipv6OnlyWaitTime */, + null /* domainName */, null /* domainSearchList */); + + verify(mCb, timeout(TEST_TIMEOUT_MS)).onProvisioningSuccess(any()); + + // Device sends ARP request for address resolution of default gateway first. + final ArpPacket request = getNextArpPacket(); + assertArpRequest(request, SERVER_ADDR); + sendArpReply(request.senderHwAddress.toByteArray() /* dst */, ROUTER_MAC_BYTES /* srcMac */, + request.senderIp /* target IP */, SERVER_ADDR /* sender IP */); + + clearInvocations(mCb); + + // Then client sends unicast DHCPREQUEST to extend the IPv4 address lifetime, and we reply + // with DHCPACK to refresh the DHCP lease. 
+ final DhcpPacket packet = getNextDhcpPacket(); + assertTrue(packet instanceof DhcpRequestPacket); + assertDhcpRequestForReacquire(packet); + mPacketReader.sendResponse(buildDhcpAckPacket(packet, CLIENT_ADDR, + TEST_LEASE_DURATION_S, (short) TEST_DEFAULT_MTU, + false /* rapidCommit */, null /* captivePortalApiUrl */)); + + // The IPv4 link address lifetime should also be refreshed by a successful DHCP renewal; + // check that provisioning never fails afterwards. + verify(mCb, after(100).never()).onProvisioningFailure(any()); + + final ArgumentCaptor<DhcpResultsParcelable> dhcpResultsCaptor = + ArgumentCaptor.forClass(DhcpResultsParcelable.class); + verify(mCb, timeout(TEST_TIMEOUT_MS)).onNewDhcpResults(dhcpResultsCaptor.capture()); + final DhcpResultsParcelable lease = dhcpResultsCaptor.getValue(); + assertDhcpResultsParcelable(lease); + + // Check that the IPv4 address lifetime was extended to the renewed lease duration. + verify(mCb, timeout(TEST_TIMEOUT_MS)).onLinkPropertiesChange(argThat(x -> { + for (LinkAddress la : x.getLinkAddresses()) { + if (la.isIpv4()) { + final long now = SystemClock.elapsedRealtime(); + final long when = now + TEST_LEASE_DURATION_S * 1000L; // lease from the DHCPACK above + return (la.getDeprecationTime() != LinkAddress.LIFETIME_UNKNOWN) + && (la.getExpirationTime() != LinkAddress.LIFETIME_UNKNOWN) + && (la.getDeprecationTime() < when + TEST_LIFETIME_TOLERANCE_MS) + && (la.getDeprecationTime() > when - TEST_LIFETIME_TOLERANCE_MS) + && (la.getExpirationTime() < when + TEST_LIFETIME_TOLERANCE_MS) + && (la.getExpirationTime() > when - TEST_LIFETIME_TOLERANCE_MS); + } + } + return false; + })); + } + private void doDhcpHostnameSettingTest(int hostnameSetting, boolean isHostnameConfigurationEnabled, boolean expectSendHostname) throws Exception { final ProvisioningConfiguration cfg = new ProvisioningConfiguration.Builder()
diff --git a/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt b/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt deleted file mode 100644 index 6c56add..0000000 --- a/tests/integration/signature/android/net/NetworkStatsIntegrationTest.kt +++ /dev/null
@@ -1,587 +0,0 @@ -/* - * Copyright (C) 2023 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License - */ - -package android.net - -import android.Manifest.permission.MANAGE_TEST_NETWORKS -import android.annotation.TargetApi -import android.app.usage.NetworkStats -import android.app.usage.NetworkStats.Bucket -import android.app.usage.NetworkStats.Bucket.TAG_NONE -import android.app.usage.NetworkStatsManager -import android.content.Context -import android.net.ConnectivityManager.TYPE_TEST -import android.net.NetworkStatsIntegrationTest.Direction.DOWNLOAD -import android.net.NetworkStatsIntegrationTest.Direction.UPLOAD -import android.net.NetworkTemplate.MATCH_TEST -import android.os.Build -import android.os.Process -import androidx.test.platform.app.InstrumentationRegistry -import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo -import com.android.testutils.DevSdkIgnoreRunner -import com.android.testutils.PacketBridge -import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged -import com.android.testutils.TestDnsServer -import com.android.testutils.TestHttpServer -import com.android.testutils.TestableNetworkCallback -import com.android.testutils.runAsShell -import fi.iki.elonen.NanoHTTPD -import java.io.BufferedInputStream -import java.io.BufferedOutputStream -import java.net.HttpURLConnection -import java.net.HttpURLConnection.HTTP_OK -import java.net.InetSocketAddress -import java.net.URL -import 
java.nio.charset.Charset -import kotlin.math.ceil -import kotlin.test.assertEquals -import kotlin.test.assertTrue -import org.junit.After -import org.junit.Assume.assumeTrue -import org.junit.Before -import org.junit.Test -import org.junit.runner.RunWith - -private const val TEST_TAG = 0xF00D - -@RunWith(DevSdkIgnoreRunner::class) -@TargetApi(Build.VERSION_CODES.S) -@IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) -class NetworkStatsIntegrationTest { - private val TAG = NetworkStatsIntegrationTest::class.java.simpleName - private val LOCAL_V6ADDR = - LinkAddress(InetAddresses.parseNumericAddress("2001:db8::1234"), 64) - - // Remote address, both the client and server will have a hallucination that - // they are talking to this address. - private val REMOTE_V6ADDR = - LinkAddress(InetAddresses.parseNumericAddress("dead:beef::808:808"), 64) - private val REMOTE_V4ADDR = - LinkAddress(InetAddresses.parseNumericAddress("8.8.8.8"), 32) - private val DEFAULT_MTU = 1500 - private val DEFAULT_BUFFER_SIZE = 1500 // Any size greater than or equal to mtu - private val CONNECTION_TIMEOUT_MILLIS = 15000 - private val TEST_DOWNLOAD_SIZE = 10000L - private val TEST_UPLOAD_SIZE = 20000L - private val HTTP_SERVER_NAME = "test.com" - private val HTTP_SERVER_PORT = 8080 // Use port > 1024 to avoid restrictions on system ports - private val DNS_INTERNAL_SERVER_PORT = 53 - private val DNS_EXTERNAL_SERVER_PORT = 1053 - private val TCP_ACK_SIZE = 72 - - // Packet overheads that are not part of the actual data transmission, these - // include DNS packets, TCP handshake/termination packets, and HTTP header - // packets. These overheads were gathered from real samples and may not - // be perfectly accurate because of DNS caches and TCP retransmissions, etc. - private val CONSTANT_PACKET_OVERHEAD = 8 - - // 130 is an observed average. 
- private val CONSTANT_BYTES_OVERHEAD = 130 * CONSTANT_PACKET_OVERHEAD - private val TOLERANCE = 1.3 - - // Set up the packet bridge with two IPv6 address only test networks. - private val inst = InstrumentationRegistry.getInstrumentation() - private val context = inst.getContext() - private val packetBridge = runAsShell(MANAGE_TEST_NETWORKS) { - PacketBridge( - context, - listOf(LOCAL_V6ADDR), - REMOTE_V6ADDR.address, - listOf( - Pair(DNS_INTERNAL_SERVER_PORT, DNS_EXTERNAL_SERVER_PORT) - ) - ) - } - private val cm = context.getSystemService(ConnectivityManager::class.java)!! - - // Set up DNS server for testing server and DNS64. - private val fakeDns = TestDnsServer( - packetBridge.externalNetwork, - InetSocketAddress(LOCAL_V6ADDR.address, DNS_EXTERNAL_SERVER_PORT) - ).apply { - start() - setAnswer( - "ipv4only.arpa", - listOf(IpPrefix(REMOTE_V6ADDR.address, REMOTE_V6ADDR.prefixLength).address) - ) - setAnswer(HTTP_SERVER_NAME, listOf(REMOTE_V4ADDR.address)) - } - - // Start up test http server. - private val httpServer = TestHttpServer( - LOCAL_V6ADDR.address.hostAddress, - HTTP_SERVER_PORT - ).apply { - start() - } - - @Before - fun setUp() { - assumeTrue(shouldRunTests()) - packetBridge.start() - } - - // For networkstack tests, it is not guaranteed that the tethering module will be - // updated at the same time. If the tethering module is not new enough, it may not contain - // the necessary abilities to run these tests. For example, The tests depends on test - // network stats being counted, which can only be achieved when they are marked as TYPE_TEST. - // If the tethering module does not support TYPE_TEST stats, then these tests will need - // to be skipped. 
- fun shouldRunTests() = cm.getNetworkInfo(packetBridge.internalNetwork)!!.type == TYPE_TEST - - @After - fun tearDown() { - packetBridge.stop() - fakeDns.stop() - httpServer.stop() - } - - private fun waitFor464XlatReady(network: Network): String { - val iface = cm.getLinkProperties(network)!!.interfaceName!! - - // Make a network request to listen to the specific test network. - val nr = NetworkRequest.Builder() - .clearCapabilities() - .addTransportType(NetworkCapabilities.TRANSPORT_TEST) - .setNetworkSpecifier(TestNetworkSpecifier(iface)) - .build() - val testCb = TestableNetworkCallback() - cm.registerNetworkCallback(nr, testCb) - - // Wait for the stacked address to be available. - testCb.eventuallyExpect<LinkPropertiesChanged> { - it.lp.stackedLinks.getOrNull(0)?.linkAddresses?.getOrNull(0) != null - } - - return iface - } - - private val Network.mtu: Int get() { - val lp = cm.getLinkProperties(this)!! - val mtuStacked = if (lp.stackedLinks[0]?.mtu != 0) lp.stackedLinks[0].mtu else DEFAULT_MTU - val mtuInterface = if (lp.mtu != 0) lp.mtu else DEFAULT_MTU - return mtuInterface.coerceAtMost(mtuStacked) - } - - /** - * Verify data usage download stats with test 464xlat networks. - * - * This test starts two test networks and binds them together, the internal one is for the - * client to make http traffic on the test network, and the external one is for the mocked - * http and dns server to bind to and provide responses. - * - * After Clat setup, the client will use clat v4 address to send packets to the mocked - * server v4 address, which will be translated into a v6 packet by the clat daemon with - * NAT64 prefix learned from the mocked DNS64 response. And send to the interface. - * - * While the packets are being forwarded to the external interface, the servers will see - * the packets originated from the mocked v6 address, and destined to a local v6 address. - */ - @Test - fun test464XlatTcpStats() { - // Wait for 464Xlat to be ready. 
- val internalInterfaceName = waitFor464XlatReady(packetBridge.internalNetwork) - val mtu = packetBridge.internalNetwork.mtu - - val snapshotBeforeTest = StatsSnapshot(context, internalInterfaceName) - - // Generate the download traffic. - genHttpTraffic(packetBridge.internalNetwork, uploadSize = 0L, TEST_DOWNLOAD_SIZE) - - // In practice, for one way 10k download payload, the download usage is about - // 11222~12880 bytes, with 14~17 packets. And the upload usage is about 1279~1626 bytes - // with 14~17 packets, which is majorly contributed by TCP ACK packets. - val snapshotAfterDownload = StatsSnapshot(context, internalInterfaceName) - val (expectedDownloadLower, expectedDownloadUpper) = getExpectedStatsBounds( - TEST_DOWNLOAD_SIZE, - mtu, - DOWNLOAD - ) - assertOnlyNonTaggedStatsIncreases( - snapshotBeforeTest, - snapshotAfterDownload, - expectedDownloadLower, - expectedDownloadUpper - ) - - // Generate upload traffic with tag to verify tagged data accounting as well. - genHttpTrafficWithTag( - packetBridge.internalNetwork, - TEST_UPLOAD_SIZE, - downloadSize = 0L, - TEST_TAG - ) - - // Verify upload data usage accounting. - val snapshotAfterUpload = StatsSnapshot(context, internalInterfaceName) - val (expectedUploadLower, expectedUploadUpper) = getExpectedStatsBounds( - TEST_UPLOAD_SIZE, - mtu, - UPLOAD - ) - assertAllStatsIncreases( - snapshotAfterDownload, - snapshotAfterUpload, - expectedUploadLower, - expectedUploadUpper - ) - } - - private enum class Direction { - DOWNLOAD, - UPLOAD - } - - private fun getExpectedStatsBounds( - transmittedSize: Long, - mtu: Int, - direction: Direction - ): Pair<BareStats, BareStats> { - // This is already an underestimated value since the input doesn't include TCP/IP - // layer overhead. - val txBytesLower = transmittedSize - // Include TCP/IP header overheads and retransmissions in the upper bound. 
- val txBytesUpper = (transmittedSize * TOLERANCE).toLong() - val txPacketsLower = txBytesLower / mtu + (CONSTANT_PACKET_OVERHEAD / TOLERANCE).toLong() - val estTransmissionPacketsUpper = ceil(txBytesUpper / mtu.toDouble()).toLong() - val txPacketsUpper = estTransmissionPacketsUpper + - (CONSTANT_PACKET_OVERHEAD * TOLERANCE).toLong() - // Assume ACK only sent once for the entire transmission. - val rxPacketsLower = 1L + (CONSTANT_PACKET_OVERHEAD / TOLERANCE).toLong() - // Assume ACK sent for every RX packet. - val rxPacketsUpper = txPacketsUpper - val rxBytesLower = 1L * TCP_ACK_SIZE + (CONSTANT_BYTES_OVERHEAD / TOLERANCE).toLong() - val rxBytesUpper = estTransmissionPacketsUpper * TCP_ACK_SIZE + - (CONSTANT_BYTES_OVERHEAD * TOLERANCE).toLong() - - return if (direction == UPLOAD) { - BareStats(rxBytesLower, rxPacketsLower, txBytesLower, txPacketsLower) to - BareStats(rxBytesUpper, rxPacketsUpper, txBytesUpper, txPacketsUpper) - } else { - BareStats(txBytesLower, txPacketsLower, rxBytesLower, rxPacketsLower) to - BareStats(txBytesUpper, txPacketsUpper, rxBytesUpper, rxPacketsUpper) - } - } - - private fun genHttpTraffic(network: Network, uploadSize: Long, downloadSize: Long) = - genHttpTrafficWithTag(network, uploadSize, downloadSize, NetworkStats.Bucket.TAG_NONE) - - private fun genHttpTrafficWithTag( - network: Network, - uploadSize: Long, - downloadSize: Long, - tag: Int - ) { - val path = "/test_upload_download" - val buf = ByteArray(DEFAULT_BUFFER_SIZE) - - httpServer.addResponse( - TestHttpServer.Request(path, NanoHTTPD.Method.POST), NanoHTTPD.Response.Status.OK, - content = getRandomString(downloadSize) - ) - var httpConnection: HttpURLConnection? 
= null - try { - TrafficStats.setThreadStatsTag(tag) - val spec = "http://$HTTP_SERVER_NAME:${httpServer.listeningPort}$path" - val url = URL(spec) - httpConnection = network.openConnection(url) as HttpURLConnection - httpConnection.connectTimeout = CONNECTION_TIMEOUT_MILLIS - httpConnection.requestMethod = "POST" - httpConnection.doOutput = true - // Tell the server that the response should not be compressed. Otherwise, the data usage - // accounted will be less than expected. - httpConnection.setRequestProperty("Accept-Encoding", "identity") - // Tell the server that to close connection after this request, this is needed to - // prevent from reusing the same socket that has different tagging requirement. - httpConnection.setRequestProperty("Connection", "close") - - // Send http body. - val outputStream = BufferedOutputStream(httpConnection.outputStream) - outputStream.write(getRandomString(uploadSize).toByteArray(Charset.forName("UTF-8"))) - outputStream.close() - assertEquals(HTTP_OK, httpConnection.responseCode) - - // Receive response from the server. - val inputStream = BufferedInputStream(httpConnection.getInputStream()) - var total = 0L - while (true) { - val count = inputStream.read(buf) - if (count == -1) break // End-of-Stream - total += count - } - assertEquals(downloadSize, total) - } finally { - httpConnection?.inputStream?.close() - TrafficStats.clearThreadStatsTag() - } - } - - // NetworkStats.Bucket cannot be written. So another class is needed to - // perform arithmetic operations. 
- data class BareStats( - val rxBytes: Long, - val rxPackets: Long, - val txBytes: Long, - val txPackets: Long - ) { - operator fun plus(other: BareStats): BareStats { - return BareStats( - this.rxBytes + other.rxBytes, this.rxPackets + other.rxPackets, - this.txBytes + other.txBytes, this.txPackets + other.txPackets - ) - } - - operator fun minus(other: BareStats): BareStats { - return BareStats( - this.rxBytes - other.rxBytes, this.rxPackets - other.rxPackets, - this.txBytes - other.txBytes, this.txPackets - other.txPackets - ) - } - - fun reverse(): BareStats = - BareStats( - rxBytes = txBytes, - rxPackets = txPackets, - txBytes = rxBytes, - txPackets = rxPackets - ) - - override fun toString(): String { - return "BareStats{rx/txBytes=$rxBytes/$txBytes, rx/txPackets=$rxPackets/$txPackets}" - } - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is BareStats) return false - - if (rxBytes != other.rxBytes) return false - if (rxPackets != other.rxPackets) return false - if (txBytes != other.txBytes) return false - if (txPackets != other.txPackets) return false - - return true - } - - override fun hashCode(): Int { - return (rxBytes * 11 + rxPackets * 13 + txBytes * 17 + txPackets * 19).toInt() - } - - companion object { - val EMPTY = BareStats(0L, 0L, 0L, 0L) - } - } - - data class StatsSnapshot(val context: Context, val iface: String) { - val statsSummary = getNetworkSummary(iface) - val statsUid = getUidDetail(iface, TAG_NONE) - val taggedSummary = getTaggedNetworkSummary(iface, TEST_TAG) - val taggedUid = getUidDetail(iface, TEST_TAG) - val trafficStatsIface = getTrafficStatsIface(iface) - val trafficStatsUid = getTrafficStatsUid(Process.myUid()) - - private fun getUidDetail(iface: String, tag: Int): BareStats { - return getNetworkStatsThat(iface, tag) { nsm, template -> - nsm.queryDetailsForUidTagState( - template, Long.MIN_VALUE, Long.MAX_VALUE, - Process.myUid(), tag, Bucket.STATE_ALL - ) - } - } - - private fun 
getNetworkSummary(iface: String): BareStats { - return getNetworkStatsThat(iface, TAG_NONE) { nsm, template -> - nsm.querySummary(template, Long.MIN_VALUE, Long.MAX_VALUE) - } - } - - private fun getTaggedNetworkSummary(iface: String, tag: Int): BareStats { - return getNetworkStatsThat(iface, tag) { nsm, template -> - nsm.queryTaggedSummary(template, Long.MIN_VALUE, Long.MAX_VALUE) - } - } - - private fun getNetworkStatsThat( - iface: String, - tag: Int, - queryApi: (nsm: NetworkStatsManager, template: NetworkTemplate) -> NetworkStats - ): BareStats { - val nsm = context.getSystemService(NetworkStatsManager::class.java)!! - nsm.forceUpdate() - val testTemplate = NetworkTemplate.Builder(MATCH_TEST) - .setWifiNetworkKeys(setOf(iface)).build() - val stats = queryApi.invoke(nsm, testTemplate) - val filteredBuckets = - stats.buckets().filter { it.uid == Process.myUid() && it.tag == tag } - return filteredBuckets.fold(BareStats.EMPTY) { acc, it -> - acc + BareStats( - it.rxBytes, - it.rxPackets, - it.txBytes, - it.txPackets - ) - } - } - - // Helper function to iterate buckets in app.usage.NetworkStats. 
- private fun NetworkStats.buckets() = object : Iterable<NetworkStats.Bucket> { - override fun iterator() = object : Iterator<NetworkStats.Bucket> { - override operator fun hasNext() = hasNextBucket() - override operator fun next() = - NetworkStats.Bucket().also { assertTrue(getNextBucket(it)) } - } - } - - private fun getTrafficStatsIface(iface: String): BareStats = BareStats( - TrafficStats.getRxBytes(iface), - TrafficStats.getRxPackets(iface), - TrafficStats.getTxBytes(iface), - TrafficStats.getTxPackets(iface) - ) - - private fun getTrafficStatsUid(uid: Int): BareStats = BareStats( - TrafficStats.getUidRxBytes(uid), - TrafficStats.getUidRxPackets(uid), - TrafficStats.getUidTxBytes(uid), - TrafficStats.getUidTxPackets(uid) - ) - } - - private fun assertAllStatsIncreases( - before: StatsSnapshot, - after: StatsSnapshot, - lower: BareStats, - upper: BareStats - ) { - assertNonTaggedStatsIncreases(before, after, lower, upper) - assertTaggedStatsIncreases(before, after, lower, upper) - } - - private fun assertOnlyNonTaggedStatsIncreases( - before: StatsSnapshot, - after: StatsSnapshot, - lower: BareStats, - upper: BareStats - ) { - assertNonTaggedStatsIncreases(before, after, lower, upper) - assertTaggedStatsEquals(before, after) - } - - private fun assertNonTaggedStatsIncreases( - before: StatsSnapshot, - after: StatsSnapshot, - lower: BareStats, - upper: BareStats - ) { - assertInRange( - "Unexpected iface traffic stats", - after.iface, - before.trafficStatsIface, after.trafficStatsIface, - lower, upper - ) - // Uid traffic stats are counted in both direction because the external network - // traffic is also attributed to the test uid. 
- assertInRange( - "Unexpected uid traffic stats", - after.iface, - before.trafficStatsUid, after.trafficStatsUid, - lower + lower.reverse(), upper + upper.reverse() - ) - assertInRange( - "Unexpected non-tagged summary stats", - after.iface, - before.statsSummary, after.statsSummary, - lower, upper - ) - assertInRange( - "Unexpected non-tagged uid stats", - after.iface, - before.statsUid, after.statsUid, - lower, upper - ) - } - - private fun assertTaggedStatsEquals(before: StatsSnapshot, after: StatsSnapshot) { - // Increment of tagged data should be zero since no tagged traffic was generated. - assertEquals( - before.taggedSummary, - after.taggedSummary, - "Unexpected tagged summary stats: ${after.iface}" - ) - assertEquals( - before.taggedUid, - after.taggedUid, - "Unexpected tagged uid stats: ${Process.myUid()} on ${after.iface}" - ) - } - - private fun assertTaggedStatsIncreases( - before: StatsSnapshot, - after: StatsSnapshot, - lower: BareStats, - upper: BareStats - ) { - assertInRange( - "Unexpected tagged summary stats", - after.iface, - before.taggedSummary, after.taggedSummary, - lower, - upper - ) - assertInRange( - "Unexpected tagged uid stats: ${Process.myUid()}", - after.iface, - before.taggedUid, after.taggedUid, - lower, - upper - ) - } - - /** Verify the given BareStats is in range [lower, upper] */ - private fun assertInRange( - tag: String, - iface: String, - before: BareStats, - after: BareStats, - lower: BareStats, - upper: BareStats - ) { - // Passing the value after operation and the value before operation to dump the actual - // numbers if it fails. 
- assertTrue(checkInRange(before, after, lower, upper), - "$tag on $iface: $after - $before is not within range [$lower, $upper]" - ) - } - - private fun checkInRange( - before: BareStats, - after: BareStats, - lower: BareStats, - upper: BareStats - ): Boolean { - val value = after - before - return value.rxBytes in lower.rxBytes..upper.rxBytes && - value.rxPackets in lower.rxPackets..upper.rxPackets && - value.txBytes in lower.txBytes..upper.txBytes && - value.txPackets in lower.txPackets..upper.txPackets - } - - fun getRandomString(length: Long): String { - val allowedChars = ('A'..'Z') + ('a'..'z') + ('0'..'9') - return (1..length) - .map { allowedChars.random() } - .joinToString("") - } -}
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp index 88d995f..91e94a8 100644 --- a/tests/unit/Android.bp +++ b/tests/unit/Android.bp
@@ -21,7 +21,6 @@ java_defaults { name: "NetworkStackTestsDefaults", - platform_apis: true, srcs: [ "src/**/*.java", "src/**/*.kt", @@ -58,6 +57,7 @@ // Tests for NetworkStackNext. android_test { name: "NetworkStackNextTests", + platform_apis: true, target_sdk_version: "current", min_sdk_version: "30", srcs: [], // TODO: tests that only apply to the current, non-stable API can be added here @@ -93,6 +93,7 @@ android_test { name: "NetworkStackTests", + platform_apis: true, min_sdk_version: "30", test_suites: [ "general-tests",
diff --git a/tests/unit/AndroidManifest.xml b/tests/unit/AndroidManifest.xml index 000863a..b55e6b6 100644 --- a/tests/unit/AndroidManifest.xml +++ b/tests/unit/AndroidManifest.xml
@@ -26,6 +26,8 @@ <uses-permission android:name="android.permission.ACCESS_WIFI_STATE" /> <uses-permission android:name="android.permission.INTERNET" /> <uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" /> + <!-- Needed to check tethering module version. --> + <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" /> <application android:debuggable="true"> <uses-library android:name="android.test.runner" />
diff --git a/tests/unit/src/android/net/apf/ApfStandaloneTest.kt b/tests/unit/src/android/net/apf/ApfStandaloneTest.kt new file mode 100644 index 0000000..7a58ea6 --- /dev/null +++ b/tests/unit/src/android/net/apf/ApfStandaloneTest.kt
@@ -0,0 +1,389 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package android.net.apf + +import android.net.apf.ApfConstant.DHCP_SERVER_PORT +import android.net.apf.ApfConstant.ETH_HEADER_LEN +import android.net.apf.ApfConstant.ICMP6_TYPE_OFFSET +import android.net.apf.ApfConstant.IPV4_BROADCAST_ADDRESS +import android.net.apf.ApfConstant.IPV4_DEST_ADDR_OFFSET +import android.net.apf.ApfConstant.IPV4_PROTOCOL_OFFSET +import android.net.apf.ApfConstant.IPV4_SRC_ADDR_OFFSET +import android.net.apf.ApfConstant.IPV6_NEXT_HEADER_OFFSET +import android.net.apf.ApfConstant.TCP_UDP_DESTINATION_PORT_OFFSET +import android.net.apf.BaseApfGenerator.APF_VERSION_4 +import android.net.apf.BaseApfGenerator.Register.R0 +import android.net.apf.BaseApfGenerator.Register.R1 +import android.system.OsConstants +import android.system.OsConstants.ETH_P_IP +import android.system.OsConstants.IPPROTO_ICMPV6 +import android.util.Log +import androidx.test.filters.SmallTest +import com.android.net.module.util.HexDump +import com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_OFFSET +import com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_SOLICITATION +import com.android.testutils.DevSdkIgnoreRunner +import kotlin.test.assertEquals +import org.junit.Test +import org.junit.runner.RunWith + +/** + * This class generate ApfStandaloneTest programs for side-loading into firmware without needing the 
+ * ApfFilter.java dependency. Its bytecode facilitates Wi-Fi chipset vendor regression tests, + * preventing issues caused by APF interpreter integration. + * + * Note: Code size optimization is not a priority for these test programs, so some redundancy may + * exist. + */ +@RunWith(DevSdkIgnoreRunner::class) +@SmallTest +class ApfStandaloneTest { + + private val etherTypeDenyList = listOf(0x88A2, 0x88A4, 0x88B8, 0x88CD, 0x88E1, 0x88E3) + + fun runApfTest(isSuspendMode: Boolean) { + val program = generateApfV4Program(isSuspendMode) + Log.w(TAG, "Program should be run in SETSUSPENDMODE $isSuspendMode: " + + HexDump.toHexString(program)) + // packet whose ethertype is in the denylist: + // ###[ Ethernet ]### + // dst = ff:ff:ff:ff:ff:ff + // src = 04:7b:cb:46:3f:b5 + // type = 0x88a2 + // ###[ Raw ]### + // load = b'\x01' + // + // raw bytes: + // ffffffffffff047bcb463fb588a201 + + val packetBadEtherType = + HexDump.hexStringToByteArray("ffffffffffff047bcb463fb588a201") + val dataRegion = ByteArray(Counter.totalSize()) { 0 } + ApfTestUtils.assertVerdict(APF_VERSION_4, ApfTestUtils.DROP, + program, packetBadEtherType, dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.TOTAL_PACKETS to 1, + Counter.DROPPED_ETHERTYPE_DENYLISTED to 1), decodeCountersIntoMap(dataRegion)) + + // dhcp request packet. 
+ // ###[ Ethernet ]### + // dst = ff:ff:ff:ff:ff:ff + // src = 04:7b:cb:46:3f:b5 + // type = IPv4 + // ###[ IP ]### + // version = 4 + // ihl = None + // tos = 0x0 + // len = None + // id = 1 + // flags = + // frag = 0 + // ttl = 64 + // proto = udp + // chksum = None + // src = 0.0.0.0 + // dst = 255.255.255.255 + // \options \ + // ###[ UDP ]### + // sport = bootpc + // dport = bootps + // len = None + // chksum = None + // ###[ BOOTP ]### + // op = BOOTREQUEST + // htype = Ethernet (10Mb) + // hlen = 6 + // hops = 0 + // xid = 0x1020304 + // secs = 0 + // flags = + // ciaddr = 0.0.0.0 + // yiaddr = 0.0.0.0 + // siaddr = 0.0.0.0 + // giaddr = 0.0.0.0 + // chaddr = 30:34:3a:37:62:3a (pad: b'cb:46:3f:b5') + // sname = '' + // file = '' + // options = b'c\x82Sc' (DHCP magic) + // ###[ DHCP options ]### + // options = [message-type='request' server_id=192.168.1.1 requested_addr=192.168.1.100 end] + // + // raw bytes: + // ffffffffffff047bcb463fb508004500011c00010000401179d100000000ffffffff004400430108393b010106000000000b000000000000000000000000000000000000000030343a37623a63623a34363a33663a62000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000638253633501033604c0a801013204c0a80164ff + + val dhcpRequestPkt = 
HexDump.hexStringToByteArray("ffffffffffff047bcb463fb508004500011c00010000401179d100000000ffffffff004400430108393b010106000000000b000000000000000000000000000000000000000030343a37623a63623a34363a33663a62000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000638253633501033604c0a801013204c0a80164ff") + ApfTestUtils.assertVerdict(APF_VERSION_4, ApfTestUtils.DROP, + program, dhcpRequestPkt, dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.TOTAL_PACKETS to 2, + Counter.DROPPED_ETHERTYPE_DENYLISTED to 1, + Counter.DROPPED_DHCP_REQUEST_DISCOVERY to 1), decodeCountersIntoMap(dataRegion)) + + // RS packet: + // ###[ Ethernet ]### + // dst = ff:ff:ff:ff:ff:ff + // src = 04:7b:cb:46:3f:b5 + // type = IPv6 + // ###[ IPv6 ]### + // version = 6 + // tc = 0 + // fl = 0 + // plen = None + // nh = ICMPv6 + // hlim = 255 + // src = fe80::30b4:5e42:ef3d:36e5 + // dst = ff02::2 + // ###[ ICMPv6 Neighbor Discovery - Router Solicitation ]### + // type = Router Solicitation + // code = 0 + // cksum = None + // res = 0 + // + // raw bytes: + // ffffffffffff047bcb463fb586dd6000000000083afffe8000000000000030b45e42ef3d36e5ff0200000000000000000000000000028500c81d00000000 + val rsPkt = HexDump.hexStringToByteArray("ffffffffffff047bcb463fb586dd6000000000083afffe8000000000000030b45e42ef3d36e5ff0200000000000000000000000000028500c81d00000000") + ApfTestUtils.assertVerdict(APF_VERSION_4, ApfTestUtils.DROP, program, rsPkt, dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.TOTAL_PACKETS to 3, + Counter.DROPPED_RS to 1, + Counter.DROPPED_ETHERTYPE_DENYLISTED to 1, + Counter.DROPPED_DHCP_REQUEST_DISCOVERY to 1), decodeCountersIntoMap(dataRegion)) + if 
(isSuspendMode) { + // Ping request packet + // ###[ Ethernet ]### + // dst = ff:ff:ff:ff:ff:ff + // src = 04:7b:cb:46:3f:b5 + // type = IPv4 + // ###[ IP ]### + // version = 4 + // ihl = None + // tos = 0x0 + // len = None + // id = 1 + // flags = + // frag = 0 + // ttl = 64 + // proto = icmp + // chksum = None + // src = 100.79.97.84 + // dst = 8.8.8.8 + // \options \ + // ###[ ICMP ]### + // type = echo-request + // code = 0 + // chksum = None + // id = 0x0 + // seq = 0x0 + // unused = '' + // + // raw bytes: 84 + // ffffffffffff047bcb463fb508004500001c000100004001a52d644f6154080808080800f7ff00000000 + val pingRequestPkt = HexDump.hexStringToByteArray("ffffffffffff047bcb463fb508004500001c000100004001a52d644f6154080808080800f7ff00000000") + ApfTestUtils.assertVerdict(APF_VERSION_4, ApfTestUtils.DROP, program, pingRequestPkt, dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.TOTAL_PACKETS to 4, + Counter.DROPPED_RS to 1, + Counter.DROPPED_ICMP4_ECHO_REQUEST to 1, + Counter.DROPPED_ETHERTYPE_DENYLISTED to 1, + Counter.DROPPED_DHCP_REQUEST_DISCOVERY to 1), decodeCountersIntoMap(dataRegion)) + } + } + + @Test + fun testApfProgramInNormalMode() { + runApfTest(isSuspendMode = false) + } + + @Test + fun testApfProgramInSuspendMode() { + runApfTest(isSuspendMode = true) + } + + private fun generateApfV4Program(isDeviceIdle: Boolean): ByteArray { + val countAndPassLabel = "countAndPass" + val countAndDropLabel = "countAndDrop" + val endOfDhcpFilter = "endOfDhcpFilter" + val endOfRsFilter = "endOfRsFiler" + val endOfPingFilter = "endOfPingFilter" + val gen = ApfV4Generator(APF_VERSION_4) + + maybeSetupCounter(gen, Counter.TOTAL_PACKETS) + gen.addLoadData(R0, 0) + gen.addAdd(1) + gen.addStoreData(R0, 0) + + maybeSetupCounter(gen, Counter.FILTER_AGE_SECONDS) + gen.addLoadFromMemory(R0, 15) + gen.addStoreData(R0, 0) + + maybeSetupCounter(gen, Counter.FILTER_AGE_16384THS) + gen.addLoadFromMemory(R0, 9) + gen.addStoreData(R0, 0) + + // ethtype filter + 
+ gen.addLoad16(R0, ETHER_TYPE_OFFSET) + maybeSetupCounter(gen, Counter.DROPPED_ETHERTYPE_DENYLISTED) + for (p in etherTypeDenyList) { + gen.addJumpIfR0Equals(p.toLong(), countAndDropLabel) + } + + // dhcp request filters + + // Check IPv4 + gen.addLoad16(R0, ETHER_TYPE_OFFSET) + gen.addJumpIfR0NotEquals(ETH_P_IP.toLong(), endOfDhcpFilter) + + // Match DHCP packets broadcast by clients (DISCOVER/REQUEST). + // Check src IP is 0.0.0.0 + gen.addLoad32(R0, IPV4_SRC_ADDR_OFFSET) + gen.addJumpIfR0NotEquals(0, endOfDhcpFilter) + // Check dst ip is 255.255.255.255 + gen.addLoad32(R0, IPV4_DEST_ADDR_OFFSET) + gen.addJumpIfR0NotEquals(IPV4_BROADCAST_ADDRESS.toLong(), endOfDhcpFilter) + // Check it's UDP. + gen.addLoad8(R0, IPV4_PROTOCOL_OFFSET) + gen.addJumpIfR0NotEquals(OsConstants.IPPROTO_UDP.toLong(), endOfDhcpFilter) + // Check it's addressed to the DHCP server port. + gen.addLoadFromMemory(R1, BaseApfGenerator.IPV4_HEADER_SIZE_MEMORY_SLOT) + gen.addLoad16Indexed(R0, TCP_UDP_DESTINATION_PORT_OFFSET) + gen.addJumpIfR0NotEquals(DHCP_SERVER_PORT.toLong(), endOfDhcpFilter) + // drop the DHCP discovery and request + maybeSetupCounter(gen, Counter.DROPPED_DHCP_REQUEST_DISCOVERY) + gen.addJump(countAndDropLabel) + + gen.defineLabel(endOfDhcpFilter) + + // rs filters + + // check IPv6 + gen.addLoad16(R0, ETHER_TYPE_OFFSET) + gen.addJumpIfR0NotEquals(OsConstants.ETH_P_IPV6.toLong(), endOfRsFilter) + // check ICMP6 packet + gen.addLoad8(R0, IPV6_NEXT_HEADER_OFFSET) + gen.addJumpIfR0NotEquals(IPPROTO_ICMPV6.toLong(), endOfRsFilter) + // check the ICMPv6 type is RS + gen.addLoad8(R0, ICMP6_TYPE_OFFSET) + gen.addJumpIfR0NotEquals(ICMPV6_ROUTER_SOLICITATION.toLong(), endOfRsFilter) + // drop rs packet + maybeSetupCounter(gen, Counter.DROPPED_RS) + gen.addJump(countAndDropLabel) + + gen.defineLabel(endOfRsFilter) + + if (isDeviceIdle) { + // ping filter + + // Check IPv4 + gen.addLoad16(R0, ETHER_TYPE_OFFSET) + gen.addJumpIfR0NotEquals(ETH_P_IP.toLong(), endOfPingFilter) + // Check it's ICMP. 
+ gen.addLoad8(R0, IPV4_PROTOCOL_OFFSET) + gen.addJumpIfR0NotEquals(OsConstants.IPPROTO_ICMP.toLong(), endOfPingFilter) + // Check if it is echo request + gen.addLoadFromMemory(R1, BaseApfGenerator.IPV4_HEADER_SIZE_MEMORY_SLOT) + gen.addLoad8Indexed(R0, ETH_HEADER_LEN) + gen.addJumpIfR0NotEquals(8, endOfPingFilter) + // drop ping request + maybeSetupCounter(gen, Counter.DROPPED_ICMP4_ECHO_REQUEST) + gen.addJump(countAndDropLabel) + + gen.defineLabel(endOfPingFilter) + } + + // end of filters. + maybeSetupCounter(gen, Counter.PASSED_PACKET) + + gen.defineLabel(countAndPassLabel) + gen.addLoadData(BaseApfGenerator.Register.R0, 0) // R0 = *(R1 + 0) + gen.addAdd(1) // R0++ + gen.addStoreData(BaseApfGenerator.Register.R0, 0) // *(R1 + 0) = R0 + gen.addJump(BaseApfGenerator.PASS_LABEL) + + gen.defineLabel(countAndDropLabel) + gen.addLoadData(BaseApfGenerator.Register.R0, 0) // R0 = *(R1 + 0) + gen.addAdd(1) // R0++ + gen.addStoreData(BaseApfGenerator.Register.R0, 0) // *(R1 + 0) = R0 + gen.addJump(BaseApfGenerator.DROP_LABEL) + + return gen.generate() + } + + enum class Counter { + RESERVED, + ENDIANNESS, + FILTER_AGE_SECONDS, + FILTER_AGE_16384THS, + TOTAL_PACKETS, + DROPPED_ETHERTYPE_DENYLISTED, + DROPPED_DHCP_REQUEST_DISCOVERY, + DROPPED_ICMP4_ECHO_REQUEST, + DROPPED_RS, + PASSED_PACKET; + + fun offset(): Int { + return -4 * this.ordinal + } + + companion object { + fun totalSize(): Int { + return (Counter::class.java.enumConstants.size - 1) * 4 + } + } + } + + private fun maybeSetupCounter(gen: ApfV4Generator, c: Counter) { + gen.addLoadImmediate(R1, c.offset()) + } + + private fun decodeCountersIntoMap(counterBytes: ByteArray): Map<Counter, Long> { + val counters = Counter::class.java.enumConstants + val ret = HashMap<Counter, Long>() + // starting from index 2 to skip the endianness mark + for (c in listOf(*counters).subList(2, counters.size)) { + val value = getCounterValue(counterBytes, c) + if (value != 0L) { + ret[c] = value + } + } + return ret + } + + private 
fun getCounterValue(data: ByteArray, counter: Counter): Long { + var offset = data.size + Counter.ENDIANNESS.offset() + var endianness = 0 + for (i in 0..3) { + endianness = endianness shl 8 or (data[offset + i].toInt() and 0xff) + } + // Follow the same wrap-around addressing scheme of the interpreter. + offset = data.size + counter.offset() + var isBe = true + when (endianness) { + 0, 0x12345678 -> isBe = true + 0x78563412 -> isBe = false + } + + var value: Long = 0 + for (i in 0..3) { + value = value shl 8 or + (data[offset + (if (isBe) i else 3 - i)].toInt() and 0xff).toLong() + } + return value + } + + companion object { + const val TAG = "ApfStandaloneTest" + } +}
diff --git a/tests/unit/src/android/net/apf/ApfTest.java b/tests/unit/src/android/net/apf/ApfTest.java index 29a5045..ab90e2d 100644 --- a/tests/unit/src/android/net/apf/ApfTest.java +++ b/tests/unit/src/android/net/apf/ApfTest.java
@@ -48,6 +48,8 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -85,6 +87,7 @@ import androidx.test.InstrumentationRegistry; import androidx.test.filters.SmallTest; +import com.android.internal.annotations.GuardedBy; import com.android.internal.util.HexDump; import com.android.net.module.util.DnsPacket; import com.android.net.module.util.Inet4AddressUtils; @@ -94,11 +97,13 @@ import com.android.networkstack.metrics.IpClientRaInfoMetrics; import com.android.networkstack.metrics.NetworkQuirkMetrics; import com.android.server.networkstack.tests.R; +import com.android.testutils.ConcurrentUtils; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRunner; import libcore.io.Streams; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -107,6 +112,7 @@ import org.junit.runners.Parameterized; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import java.io.ByteArrayOutputStream; @@ -132,6 +138,7 @@ * * The test cases will be executed by both APFv4 and APFv6 interpreter. 
*/ +@DevSdkIgnoreRunner.MonitorThreadLeak @RunWith(DevSdkIgnoreRunner.class) @SmallTest public class ApfTest { @@ -157,12 +164,71 @@ @Mock private ApfSessionInfoMetrics mApfSessionInfoMetrics; @Mock private IpClientRaInfoMetrics mIpClientRaInfoMetrics; @Mock private ApfFilter.Clock mClock; + @GuardedBy("mApfFilterCreated") + private final ArrayList<AndroidPacketFilter> mApfFilterCreated = new ArrayList<>(); + @GuardedBy("mThreadsToBeCleared") + private final ArrayList<Thread> mThreadsToBeCleared = new ArrayList<>(); + @Before public void setUp() throws Exception { MockitoAnnotations.initMocks(this); doReturn(mPowerManager).when(mContext).getSystemService(PowerManager.class); doReturn(mApfSessionInfoMetrics).when(mDependencies).getApfSessionInfoMetrics(); doReturn(mIpClientRaInfoMetrics).when(mDependencies).getIpClientRaInfoMetrics(); + doAnswer((invocation) -> { + synchronized (mApfFilterCreated) { + mApfFilterCreated.add(invocation.getArgument(0)); + } + return null; + }).when(mDependencies).onApfFilterCreated(any()); + doAnswer((invocation) -> { + synchronized (mThreadsToBeCleared) { + mThreadsToBeCleared.add(invocation.getArgument(0)); + } + return null; + }).when(mDependencies).onThreadCreated(any()); + } + + private void quitThreads() throws Exception { + ConcurrentUtils.quitThreads( + THREAD_QUIT_MAX_RETRY_COUNT, + false /* interrupt */, + HANDLER_TIMEOUT_MS, + () -> { + synchronized (mThreadsToBeCleared) { + final ArrayList<Thread> ret = new ArrayList<>(mThreadsToBeCleared); + mThreadsToBeCleared.clear(); + return ret; + } + }); + } + + private void shutdownApfFilters() throws Exception { + ConcurrentUtils.quitResources(THREAD_QUIT_MAX_RETRY_COUNT, () -> { + synchronized (mApfFilterCreated) { + final ArrayList<AndroidPacketFilter> ret = + new ArrayList<>(mApfFilterCreated); + mApfFilterCreated.clear(); + return ret; + } + }, (apf) -> { + apf.shutdown(); + }); + synchronized (mApfFilterCreated) { + assertEquals("ApfFilters did not fully shutdown.", + 0, 
mApfFilterCreated.size()); + } + // It's necessary to wait until all ReceiveThreads have finished running because + // clearInlineMocks clears all Mock objects, including some privilege frameworks + // required by logStats, at the end of ReceiveThread#run. + quitThreads(); + } + + @After + public void tearDown() throws Exception { + shutdownApfFilters(); + // Clear mocks to prevent from stubs holding instances and cause memory leaks. + Mockito.framework().clearInlineMocks(); } private static final String TAG = "ApfTest"; @@ -179,6 +245,9 @@ private static final int MIN_RDNSS_LIFETIME_SEC = 0; private static final int MIN_METRICS_SESSION_DURATIONS_MS = 300_000; + private static final int HANDLER_TIMEOUT_MS = 1000; + private static final int THREAD_QUIT_MAX_RETRY_COUNT = 3; + // Constants for opcode encoding private static final byte LI_OP = (byte)(13 << 3); private static final byte LDDW_OP = (byte)(22 << 3); @@ -238,7 +307,13 @@ private void assertDataMemoryContents(int expected, byte[] program, byte[] packet, byte[] data, byte[] expectedData) throws Exception { ApfTestUtils.assertDataMemoryContents(mApfVersion, expected, program, packet, data, - expectedData); + expectedData, false /* ignoreInterpreterVersion */); + } + + private void assertDataMemoryContentsIgnoreVersion(int expected, byte[] program, + byte[] packet, byte[] data, byte[] expectedData) throws Exception { + ApfTestUtils.assertDataMemoryContents(mApfVersion, expected, program, packet, data, + expectedData, true /* ignoreInterpreterVersion */); } private void assertVerdict(String msg, int expected, byte[] program, @@ -388,21 +463,21 @@ // Test add. gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R1, 1234567890); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0Equals(1234567890, DROP_LABEL); assertDrop(gen); // Test subtract. 
gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R1, -1234567890); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0Equals(-1234567890, DROP_LABEL); assertDrop(gen); // Test or. gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R1, 1234567890); - gen.addOrR1(); + gen.addOrR0WithR1(); gen.addJumpIfR0Equals(1234567890, DROP_LABEL); assertDrop(gen); @@ -410,7 +485,7 @@ gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R0, 1234567890); gen.addLoadImmediate(R1, 123456789); - gen.addAndR1(); + gen.addAndR0WithR1(); gen.addJumpIfR0Equals(1234567890 & 123456789, DROP_LABEL); assertDrop(gen); @@ -418,7 +493,7 @@ gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R0, 1234567890); gen.addLoadImmediate(R1, 1); - gen.addLeftShiftR1(); + gen.addLeftShiftR0ByR1(); gen.addJumpIfR0Equals(1234567890 << 1, DROP_LABEL); assertDrop(gen); @@ -426,7 +501,7 @@ gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R0, 1234567890); gen.addLoadImmediate(R1, -1); - gen.addLeftShiftR1(); + gen.addLeftShiftR0ByR1(); gen.addJumpIfR0Equals(1234567890 >> 1, DROP_LABEL); assertDrop(gen); @@ -434,7 +509,7 @@ gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R0, 123456789); gen.addLoadImmediate(R1, 2); - gen.addMulR1(); + gen.addMulR0ByR1(); gen.addJumpIfR0Equals(123456789 * 2, DROP_LABEL); assertDrop(gen); @@ -442,13 +517,13 @@ gen = new ApfV4Generator(MIN_APF_VERSION); gen.addLoadImmediate(R0, 1234567890); gen.addLoadImmediate(R1, 2); - gen.addDivR1(); + gen.addDivR0ByR1(); gen.addJumpIfR0Equals(1234567890 / 2, DROP_LABEL); assertDrop(gen); // Test divide by zero. 
gen = new ApfV4Generator(MIN_APF_VERSION); - gen.addDivR1(); + gen.addDivR0ByR1(); gen.addJump(DROP_LABEL); assertPass(gen); @@ -949,7 +1024,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); apfFilter.setLinkProperties(lp); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); byte[] data = new byte[Counter.totalSize()]; @@ -960,7 +1035,6 @@ assertTrue("Failed to drop all packets by filter. \nAPF counters:" + HexDump.toHexString(data, false), result); - apfFilter.shutdown(); } private static final int ETH_HEADER_LEN = 14; @@ -1126,7 +1200,7 @@ ApfConfiguration config = getDefaultConfig(); config.multicastFilter = DROP_MULTICAST; TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); apfFilter.setLinkProperties(lp); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); @@ -1170,8 +1244,6 @@ // Verify unicast IPv4 DHCP to us is passed put(packet, ETH_DEST_ADDR_OFFSET, TestApfFilter.MOCK_MAC_ADDR); assertPass(program, packet.array()); - - apfFilter.shutdown(); } @Test @@ -1179,7 +1251,7 @@ MockIpClientCallback ipClientCallback = new MockIpClientCallback(); ApfConfiguration config = getDefaultConfig(); TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); // Verify empty IPv6 packet is passed @@ -1211,8 +1283,6 @@ assertDrop(program, packet.array()); put(packet, IPV6_DEST_ADDR_OFFSET, IPV6_ALL_ROUTERS_ADDRESS); assertDrop(program, packet.array()); - - apfFilter.shutdown(); } private static void fillQuestionSection(ByteBuffer buf, String... 
qnames) throws IOException { @@ -1395,21 +1465,16 @@ } - /** Adds to the program a no-op instruction that is one byte long. */ - private void addOneByteNoop(ApfV4Generator gen) { - gen.addLeftShift(0); - } - @Test - public void testAddOneByteNoopAddsOneByte() throws Exception { + public void testAddNopAddsOneByte() throws Exception { ApfV4Generator gen = new ApfV4Generator(MIN_APF_VERSION); - addOneByteNoop(gen); + gen.addNop(); assertEquals(1, gen.generate().length); final int count = 42; gen = new ApfV4Generator(MIN_APF_VERSION); for (int i = 0; i < count; i++) { - addOneByteNoop(gen); + gen.addNop(); } assertEquals(count, gen.generate().length); } @@ -1435,7 +1500,7 @@ ApfConfiguration config = getDefaultConfig(); TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); apfFilter.setLinkProperties(lp); // Construct IPv4 mDNS packet @@ -1482,8 +1547,6 @@ mdnsv6packet = makeMdnsV6Packet("abcd.local"); assertDrop(program, mdnsv4packet); assertDrop(program, mdnsv6packet); - - apfFilter.shutdown(); } private ApfV4Generator generateDnsFilter(boolean ipv6, String... labels) throws Exception { @@ -1500,7 +1563,7 @@ // Hack to prevent the APF instruction limit triggering. for (int i = 0; i < 500; i++) { - addOneByteNoop(gen); + gen.addNop(); } byte[] program = gen.generate(); @@ -1586,7 +1649,7 @@ // bytes, is capable of dropping the packet. ApfV4Generator gen = generateDnsFilter(/*ipv6=*/ true, labels); for (int i = 0; i < expectedNecessaryOverhead; i++) { - addOneByteNoop(gen); + gen.addNop(); } final byte[] programWithJustEnoughOverhead = gen.generate(); assertVerdict( @@ -1600,7 +1663,7 @@ // cannot correctly drop the packet because it hits the interpreter instruction limit. 
gen = generateDnsFilter(/*ipv6=*/ true, labels); for (int i = 0; i < expectedNecessaryOverhead - 1; i++) { - addOneByteNoop(gen); + gen.addNop(); } final byte[] programWithNotEnoughOverhead = gen.generate(); @@ -1673,7 +1736,7 @@ ApfConfiguration config = getDefaultConfig(); config.ieee802_3Filter = DROP_802_3_FRAMES; TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); apfFilter.setLinkProperties(lp); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); @@ -1731,10 +1794,10 @@ // Verify it can be initialized to on ipClientCallback.resetApfProgramWait(); - apfFilter.shutdown(); config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; - apfFilter = new TestApfFilter(mContext, config, ipClientCallback, mNetworkQuirkMetrics); + apfFilter = new TestApfFilter(mContext, config, ipClientCallback, mNetworkQuirkMetrics, + mDependencies); apfFilter.setLinkProperties(lp); program = ipClientCallback.assertProgramUpdateAndGet(); assertDrop(program, mcastv4packet.array()); @@ -1745,8 +1808,6 @@ // Verify that ICMPv6 multicast is not dropped. 
mcastv6packet.put(IPV6_NEXT_HEADER_OFFSET, (byte)IPPROTO_ICMPV6); assertPass(program, mcastv6packet.array()); - - apfFilter.shutdown(); } @Test @@ -1828,8 +1889,6 @@ receiver.onReceive(mContext, new Intent(ACTION_DEVICE_IDLE_MODE_CHANGED)); } assertPass(ipClientCallback.assertProgramUpdateAndGet(), packet.array()); - - apfFilter.shutdown(); } @Test @@ -1855,7 +1914,6 @@ // Now turn on the filter ipClientCallback.resetApfProgramWait(); - apfFilter.shutdown(); config.ieee802_3Filter = DROP_802_3_FRAMES; apfFilter = TestApfFilter.createTestApfFilter(mContext, ipClientCallback, config, mNetworkQuirkMetrics, mDependencies); @@ -1873,8 +1931,6 @@ // Verify that IPv6 (as example of Ethernet II) frame will pass setIpv6VersionFields(packet); assertPass(program, packet.array()); - - apfFilter.shutdown(); } @Test @@ -1904,7 +1960,6 @@ // Now add IPv4 to the black list ipClientCallback.resetApfProgramWait(); - apfFilter.shutdown(); config.ethTypeBlackList = ipv4BlackList; apfFilter = TestApfFilter.createTestApfFilter(mContext, ipClientCallback, config, mNetworkQuirkMetrics, mDependencies); @@ -1920,7 +1975,6 @@ // Now let us have both IPv4 and IPv6 in the black list ipClientCallback.resetApfProgramWait(); - apfFilter.shutdown(); config.ethTypeBlackList = ipv4Ipv6BlackList; apfFilter = TestApfFilter.createTestApfFilter(mContext, ipClientCallback, config, mNetworkQuirkMetrics, mDependencies); @@ -1933,8 +1987,6 @@ // Verify that IPv6 frame will be dropped setIpv6VersionFields(packet); assertDrop(program, packet.array()); - - apfFilter.shutdown(); } private byte[] getProgram(MockIpClientCallback cb, ApfFilter filter, LinkProperties lp) { @@ -1945,8 +1997,8 @@ private void verifyArpFilter(byte[] program, int filterResult) { // Verify ARP request packet - assertPass(program, arpRequestBroadcast(MOCK_IPV4_ADDR)); - assertVerdict(filterResult, program, arpRequestBroadcast(ANOTHER_IPV4_ADDR)); + assertVerdict(filterResult, program, arpRequestBroadcast(MOCK_IPV4_ADDR)); + 
assertDrop(program, arpRequestBroadcast(ANOTHER_IPV4_ADDR)); assertDrop(program, arpRequestBroadcast(IPV4_ANY_HOST_ADDR)); // Verify ARP reply packets from different source ip @@ -1971,21 +2023,19 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); - // Verify initially ARP request filter is off, and GARP filter is on. - verifyArpFilter(ipClientCallback.assertProgramUpdateAndGet(), PASS); + // Verify initially ARP request filter and GARP filter are on. + verifyArpFilter(ipClientCallback.assertProgramUpdateAndGet(), DROP); // Inform ApfFilter of our address and verify ARP filtering is on LinkAddress linkAddress = new LinkAddress(InetAddress.getByAddress(MOCK_IPV4_ADDR), 24); LinkProperties lp = new LinkProperties(); assertTrue(lp.addLinkAddress(linkAddress)); - verifyArpFilter(getProgram(ipClientCallback, apfFilter, lp), DROP); + verifyArpFilter(getProgram(ipClientCallback, apfFilter, lp), PASS); - // Inform ApfFilter of loss of IP and verify ARP filtering is off - verifyArpFilter(getProgram(ipClientCallback, apfFilter, new LinkProperties()), PASS); - - apfFilter.shutdown(); + // Inform ApfFilter of loss of IP and verify ARP filtering is on + verifyArpFilter(getProgram(ipClientCallback, apfFilter, new LinkProperties()), DROP); } private static byte[] arpReply(byte[] sip, byte[] tip) { @@ -2032,7 +2082,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program; final int srcPort = 12345; final int dstPort = 54321; @@ -2177,8 +2227,6 @@ assertPass(program, ipv6TcpPacket(IPV6_ANOTHER_ADDR, IPV6_KEEPALIVE_SRC_ADDR, srcPort, dstPort, anotherSeqNum, anotherAckNum)); - - apfFilter.shutdown(); } private static byte[] 
ipv4TcpPacket(byte[] sip, byte[] dip, int sport, @@ -2226,7 +2274,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program; final int srcPort = 1024; final int dstPort = 4500; @@ -2272,7 +2320,6 @@ dstPort, srcPort, 10 /* dataLength */)); apfFilter.removeKeepalivePacketFilter(slot1); - apfFilter.shutdown(); } private static byte[] ipv4UdpPacket(byte[] sip, byte[] dip, int sport, @@ -2391,7 +2438,7 @@ public RaPacketBuilder addDnsslOption(int lifetime, String... domains) { ByteArrayOutputStream dnssl = new ByteArrayOutputStream(); for (String domain : domains) { - for (String label : domain.split(".")) { + for (String label : domain.split("\\.")) { final byte[] bytes = label.getBytes(StandardCharsets.UTF_8); dnssl.write((byte) bytes.length); dnssl.write(bytes, 0, bytes.length); @@ -2475,7 +2522,8 @@ public void testRaToString() throws Exception { MockIpClientCallback cb = new MockIpClientCallback(); ApfConfiguration config = getDefaultConfig(); - TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics); + TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics, + mDependencies); byte[] packet = buildLargeRa(); ApfFilter.Ra ra = apfFilter.new Ra(packet, packet.length); @@ -2546,7 +2594,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); final int ROUTER_LIFETIME = 1000; @@ -2627,8 +2675,6 @@ verifyRaLifetime(program, routeInfoOptionPacket, ROUTE_LIFETIME); verifyRaLifetime(program, dnsslOptionPacket, ROUTER_LIFETIME); verifyRaLifetime(program, largeRaPacket, 300); - - 
apfFilter.shutdown(); } @Test @@ -2638,7 +2684,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); final int RA_REACHABLE_TIME = 1800; final int RA_RETRANSMISSION_TIMER = 1234; @@ -2680,7 +2726,7 @@ config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); byte[] program = ipClientCallback.assertProgramUpdateAndGet(); final int routerLifetime = 1000; @@ -2710,8 +2756,6 @@ } program = ipClientCallback.assertProgramUpdateAndGet(); assertPass(program, basePacket.array()); - - apfFilter.shutdown(); } /** @@ -2749,7 +2793,8 @@ ApfConfiguration config = getDefaultConfig(); config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; - TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics); + TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics, + mDependencies); for (int i = 0; i < 1000; i++) { byte[] packet = new byte[r.nextInt(maxRandomPacketSize + 1)]; r.nextBytes(packet); @@ -2760,7 +2805,6 @@ throw new Exception("bad packet: " + HexDump.toHexString(packet), e); } } - apfFilter.shutdown(); } @Test @@ -2771,7 +2815,8 @@ ApfConfiguration config = getDefaultConfig(); config.multicastFilter = DROP_MULTICAST; config.ieee802_3Filter = DROP_802_3_FRAMES; - TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics); + TestApfFilter apfFilter = new TestApfFilter(mContext, config, cb, mNetworkQuirkMetrics, + mDependencies); for (int i = 0; i < 1000; i++) { byte[] packet = new byte[r.nextInt(maxRandomPacketSize + 1)]; r.nextBytes(packet); 
@@ -2781,14 +2826,13 @@ throw new Exception("bad packet: " + HexDump.toHexString(packet), e); } } - apfFilter.shutdown(); } @Test public void testMatchedRaUpdatesLifetime() throws Exception { final MockIpClientCallback ipClientCallback = new MockIpClientCallback(); final TestApfFilter apfFilter = new TestApfFilter(mContext, getDefaultConfig(), - ipClientCallback, mNetworkQuirkMetrics); + ipClientCallback, mNetworkQuirkMetrics, mDependencies); // Create an RA and build an APF program byte[] ra = new RaPacketBuilder(1800 /* router lifetime */).build(); @@ -2805,7 +2849,6 @@ // assert program was updated and new lifetimes were taken into account. assertDrop(program, ra); - apfFilter.shutdown(); } @Test @@ -2862,13 +2905,13 @@ final String packetStringFmt = "33330000000128C68E23672C86DD60054C6B00603AFFFE800000000000002AC68EFFFE23672CFF02000000000000000000000000000186000ACD40C01B580000000000000000010128C68E23672C05010000000005DC030440C0%s000000002401FA000480F00000000000000000001903000000001B582401FA000480F000000000000000000107010000000927C0"; final List<String> lifetimes = List.of("FFFFFFFF", "00000000", "00000001", "00001B58"); for (String lifetime : lifetimes) { - apfFilter = new TestApfFilter(mContext, config, ipClientCallback, mNetworkQuirkMetrics); + apfFilter = new TestApfFilter(mContext, config, ipClientCallback, mNetworkQuirkMetrics, + mDependencies); final byte[] ra = hexStringToByteArray( String.format(packetStringFmt, lifetime + lifetime)); // feed the RA into APF and generate the filter, the filter shouldn't crash. 
apfFilter.pretendPacketReceived(ra); ipClientCallback.assertProgramUpdateAndGet(); - apfFilter.shutdown(); } } @@ -2881,7 +2924,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(1800 /* router lifetime */) @@ -2899,7 +2942,6 @@ .addPioOption(1800 /*valid*/, 1 /*preferred*/, "2001:db8::/64") .build(); assertPass(program, ra); - apfFilter.shutdown(); } // Test for go/apf-ra-filter Case 2a. @@ -2911,7 +2953,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(1800 /* router lifetime */) @@ -2935,7 +2977,6 @@ .addPioOption(1800 /*valid*/, 33 /*preferred*/, "2001:db8::/64") .build(); assertPass(program, ra); - apfFilter.shutdown(); } @@ -2948,7 +2989,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(0 /* router lifetime */).build(); @@ -2966,7 +3007,6 @@ // lifetime increases to accept_ra_min_lft ra = new RaPacketBuilder(180 /* router lifetime */).build(); assertPass(program, ra); - apfFilter.shutdown(); } @@ -2979,7 +3019,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and 
build an APF program byte[] ra = new RaPacketBuilder(100 /* router lifetime */).build(); @@ -3005,7 +3045,6 @@ // lifetime is 0 ra = new RaPacketBuilder(0 /* router lifetime */).build(); assertPass(program, ra); - apfFilter.shutdown(); } // Test for go/apf-ra-filter Case 3b. @@ -3017,7 +3056,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(200 /* router lifetime */).build(); @@ -3039,7 +3078,6 @@ // lifetime is 0 ra = new RaPacketBuilder(0 /* router lifetime */).build(); assertPass(program, ra); - apfFilter.shutdown(); } // Test for go/apf-ra-filter Case 4b. @@ -3051,7 +3089,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(1800 /* router lifetime */).build(); @@ -3081,7 +3119,6 @@ // lifetime is 0 ra = new RaPacketBuilder(0 /* router lifetime */).build(); assertPass(program, ra); - apfFilter.shutdown(); } @Test @@ -3091,7 +3128,7 @@ final ApfConfiguration config = getDefaultConfig(); config.acceptRaMinLft = 180; final TestApfFilter apfFilter = new TestApfFilter(mContext, config, ipClientCallback, - mNetworkQuirkMetrics); + mNetworkQuirkMetrics, mDependencies); // Create an initial RA and build an APF program byte[] ra = new RaPacketBuilder(1800 /* router lifetime */).build(); @@ -3137,7 +3174,6 @@ apfFilter.pretendPacketReceived(ra); program = ipClientCallback.assertProgramUpdateAndGet(); assertDrop(program, ra); - apfFilter.shutdown(); } @Test @@ -3195,7 +3231,6 @@ } 
verify(mNetworkQuirkMetrics).setEvent(NetworkQuirkEvent.QE_APF_INSTALL_FAILURE); verify(mNetworkQuirkMetrics).statsWrite(); - apfFilter.shutdown(); } @Test @@ -3222,7 +3257,6 @@ program = ipClientCallback.assertProgramUpdateAndGet(); verify(mNetworkQuirkMetrics).setEvent(NetworkQuirkEvent.QE_APF_OVER_SIZE_FAILURE); verify(mNetworkQuirkMetrics).statsWrite(); - apfFilter.shutdown(); } @Test @@ -3252,7 +3286,6 @@ } verify(mNetworkQuirkMetrics).setEvent(NetworkQuirkEvent.QE_APF_GENERATE_FILTER_EXCEPTION); verify(mNetworkQuirkMetrics).statsWrite(); - apfFilter.shutdown(); } @Test @@ -3293,7 +3326,7 @@ final byte[] ra = buildLargeRa(); expectedData[totalPacketsCounterIdx + 3] += 1; expectedData[passedIpv6IcmpCounterIdx + 3] += 1; - assertDataMemoryContents(PASS, program, ra, data, expectedData); + assertDataMemoryContentsIgnoreVersion(PASS, program, ra, data, expectedData); apfFilter.pretendPacketReceived(ra); program = ipClientCallback.assertProgramUpdateAndGet(); maxProgramSize = Math.max(maxProgramSize, program.length); @@ -3311,7 +3344,8 @@ put(mcastv4packet, IPV4_DEST_ADDR_OFFSET, multicastIpv4Addr); expectedData[totalPacketsCounterIdx + 3] += 1; expectedData[droppedIpv4MulticastIdx + 3] += 1; - assertDataMemoryContents(DROP, program, mcastv4packet.array(), data, expectedData); + assertDataMemoryContentsIgnoreVersion(DROP, program, mcastv4packet.array(), data, + expectedData); // Set data snapshot and update counters. 
apfFilter.setDataSnapshot(data); @@ -3483,6 +3517,23 @@ verifyNoMetricsWrittenForShortDuration(true /* isLegacy */); } + private int deriveApfGeneratorVersion(ApfV4GeneratorBase<?> gen) { + if (gen instanceof ApfV4Generator) { + return 4; + } else if (gen instanceof ApfV6Generator) { + return 6; + } + return -1; + } + + @Test + public void testApfGeneratorPropagation() throws IllegalInstructionException { + ApfV4Generator v4Gen = new ApfV4Generator(APF_VERSION_4); + ApfV6Generator v6Gen = new ApfV6Generator(); + assertEquals(4, deriveApfGeneratorVersion(v4Gen)); + assertEquals(6, deriveApfGeneratorVersion(v6Gen)); + } + @Test public void testFullApfV4ProgramGenerationIPV6() throws IllegalInstructionException { ApfV4Generator gen = new ApfV4Generator(APF_VERSION_4); @@ -3533,7 +3584,7 @@ gen.addLoad16Indexed(R0, 16); gen.addJumpIfR0NotEquals(0x44, "LABEL_159"); gen.addLoadImmediate(R0, 50); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(hexStringToByteArray("e212507c6345"), "LABEL_159"); gen.addLoadImmediate(R1, -12); gen.addJump("LABEL_498"); @@ -3686,7 +3737,7 @@ gen.addLoad16Indexed(R0, 16); gen.addJumpIfR0NotEquals(0x44, "LABEL_151"); gen.addLoadImmediate(R0, 50); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(hexStringToByteArray("f683d58f832b"), "LABEL_151"); gen.addLoadImmediate(R1, -12); gen.addJump("LABEL_277"); @@ -3807,7 +3858,7 @@ gen.addLoad16Indexed(R0, 16); gen.addJumpIfR0NotEquals(0x44, "LABEL_157"); gen.addLoadImmediate(R0, 50); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(hexStringToByteArray("ea42226789c0"), "LABEL_157"); gen.addLoadImmediate(R1, -12); gen.addJump("LABEL_339"); @@ -3831,7 +3882,7 @@ gen.addSwap(); gen.addLoad16(R0, 16); gen.addNeg(R1); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfR0NotEquals(0x1, "LABEL_243"); gen.addLoadFromMemory(R0, 13); gen.addAdd(14); @@ -3949,7 +4000,7 @@ gen.addLoad16Indexed(R0, 16); gen.addJumpIfR0NotEquals(0x44, 
"LABEL_165"); gen.addLoadImmediate(R0, 50); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addJumpIfBytesAtR0NotEqual(hexStringToByteArray("7e9046bc7008"), "LABEL_165"); gen.addLoadImmediate(R1, -24); gen.addJump("LABEL_576");
diff --git a/tests/unit/src/android/net/apf/ApfTestUtils.java b/tests/unit/src/android/net/apf/ApfTestUtils.java index 3fcb8d7..79143eb 100644 --- a/tests/unit/src/android/net/apf/ApfTestUtils.java +++ b/tests/unit/src/android/net/apf/ApfTestUtils.java
@@ -87,9 +87,12 @@ private static void assertVerdict(int apfVersion, int expected, byte[] program, byte[] packet, int filterAge) { - final String msg = "Unexpected APF verdict. To debug:\n" + " apf_run --program " - + HexDump.toHexString(program) + " --packet " + HexDump.toHexString(packet) - + " --trace | less\n "; + final String msg = "Unexpected APF verdict. To debug:\n" + + " apf_run --program " + HexDump.toHexString(program) + + " --packet " + HexDump.toHexString(packet) + + " --age " + filterAge + + (apfVersion > 4 ? " --v6" : "") + + " --trace " + " | less\n "; assertReturnCodesEqual(msg, expected, apfSimulate(apfVersion, program, packet, null, filterAge)); } @@ -154,11 +157,21 @@ * Runs the APF program and checks the return code and data regions equals to expected value. */ public static void assertDataMemoryContents(int apfVersion, int expected, byte[] program, - byte[] packet, byte[] data, byte[] expectedData) + byte[] packet, byte[] data, byte[] expectedData, boolean ignoreInterpreterVersion) throws ApfV4Generator.IllegalInstructionException, Exception { assertReturnCodesEqual(expected, apfSimulate(apfVersion, program, packet, data, 0 /* filterAge */)); + if (ignoreInterpreterVersion) { + final int apfVersionIdx = ApfCounterTracker.Counter.totalSize() + + ApfCounterTracker.Counter.APF_VERSION.offset(); + final int apfProgramIdIdx = ApfCounterTracker.Counter.totalSize() + + ApfCounterTracker.Counter.APF_PROGRAM_ID.offset(); + for (int i = 0; i < 4; ++i) { + data[apfVersionIdx + i] = 0; + data[apfProgramIdIdx + i] = 0; + } + } // assertArrayEquals() would only print one byte, making debugging difficult. 
if (!Arrays.equals(expectedData, data)) { throw new Exception("\nprogram: " + HexDump.toHexString(program) + "\ndata memory: " @@ -172,14 +185,25 @@ */ public static void assertVerdict(int apfVersion, int expected, byte[] program, byte[] packet, byte[] data) { - assertReturnCodesEqual(expected, - apfSimulate(apfVersion, program, packet, data, 0 /* filterAge */)); + assertVerdict(apfVersion, expected, program, packet, data, 0 /* filterAge */); } private static void assertVerdict(int apfVersion, int expected, ApfV4Generator gen, byte[] packet, int filterAge) throws ApfV4Generator.IllegalInstructionException { - assertReturnCodesEqual(expected, - apfSimulate(apfVersion, gen.generate(), packet, null, filterAge)); + assertVerdict(apfVersion, expected, gen.generate(), packet, null, filterAge); + } + + private static void assertVerdict(int apfVersion, int expected, byte[] program, byte[] packet, + byte[] data, int filterAge) { + final String msg = "Unexpected APF verdict. To debug:\n" + + " apf_run --program " + HexDump.toHexString(program) + + " --packet " + HexDump.toHexString(packet) + + (data != null ? " --data " + HexDump.toHexString(data) : "") + + " --age " + filterAge + + (apfVersion > 4 ? 
" --v6" : "") + + " --trace " + " | less\n "; + assertReturnCodesEqual(msg, expected, + apfSimulate(apfVersion, program, packet, data, filterAge)); } /** @@ -276,13 +300,6 @@ private final boolean mThrowsExceptionWhenGeneratesProgram; public TestApfFilter(Context context, ApfConfiguration config, - MockIpClientCallback ipClientCallback, NetworkQuirkMetrics networkQuirkMetrics) - throws Exception { - this(context, config, ipClientCallback, networkQuirkMetrics, new Dependencies(context), - false /* throwsExceptionWhenGeneratesProgram */, new ApfFilter.Clock()); - } - - public TestApfFilter(Context context, ApfConfiguration config, MockIpClientCallback ipClientCallback, NetworkQuirkMetrics networkQuirkMetrics, Dependencies dependencies) throws Exception { this(context, config, ipClientCallback, networkQuirkMetrics, dependencies, @@ -383,7 +400,7 @@ @Override @GuardedBy("this") - protected ApfV4Generator emitPrologueLocked() throws IllegalInstructionException { + protected ApfV4GeneratorBase<?> emitPrologueLocked() throws IllegalInstructionException { if (mThrowsExceptionWhenGeneratesProgram) { throw new IllegalStateException(); }
diff --git a/tests/unit/src/android/net/apf/ApfV5Test.kt b/tests/unit/src/android/net/apf/ApfV5Test.kt index 421ed5b..416321a 100644 --- a/tests/unit/src/android/net/apf/ApfV5Test.kt +++ b/tests/unit/src/android/net/apf/ApfV5Test.kt
@@ -16,12 +16,16 @@ package android.net.apf import android.net.apf.ApfCounterTracker.Counter +import android.net.apf.ApfCounterTracker.Counter.DROPPED_ETHERTYPE_DENYLISTED +import android.net.apf.ApfCounterTracker.Counter.DROPPED_ETH_BROADCAST +import android.net.apf.ApfCounterTracker.Counter.PASSED_ARP import android.net.apf.ApfTestUtils.DROP import android.net.apf.ApfTestUtils.MIN_PKT_SIZE import android.net.apf.ApfTestUtils.PASS import android.net.apf.ApfTestUtils.assertDrop import android.net.apf.ApfTestUtils.assertPass import android.net.apf.ApfTestUtils.assertVerdict +import android.net.apf.BaseApfGenerator.APF_VERSION_4 import android.net.apf.BaseApfGenerator.DROP_LABEL import android.net.apf.BaseApfGenerator.IllegalInstructionException import android.net.apf.BaseApfGenerator.MIN_APF_VERSION @@ -30,6 +34,7 @@ import android.net.apf.BaseApfGenerator.Register.R1 import androidx.test.filters.SmallTest import androidx.test.runner.AndroidJUnit4 +import com.android.net.module.util.HexDump import com.android.net.module.util.Struct import com.android.net.module.util.structs.EthernetHeader import com.android.net.module.util.structs.Ipv4Header @@ -38,6 +43,7 @@ import kotlin.test.assertContentEquals import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import org.junit.After import org.junit.Test import org.junit.runner.RunWith @@ -51,6 +57,11 @@ private val testPacket = byteArrayOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + @After + fun tearDown() { + ApfJniUtils.resetTransmittedPacketMemory() + } + @Test fun testDataInstructionMustComeFirst() { var gen = ApfV6Generator() @@ -61,6 +72,7 @@ @Test fun testApfInstructionEncodingSizeCheck() { var gen = ApfV6Generator() + assertFailsWith<IllegalArgumentException> { gen.addData(ByteArray(65536) { 0x01 }) } assertFailsWith<IllegalArgumentException> { gen.addAllocate(65536) } assertFailsWith<IllegalArgumentException> { gen.addAllocate(-1) } assertFailsWith<IllegalArgumentException> { 
gen.addDataCopy(-1, 1) } @@ -74,77 +86,192 @@ assertFailsWith<IllegalArgumentException> { gen.addPacketCopyFromR0(-1) } assertFailsWith<IllegalArgumentException> { gen.addDataCopyFromR0(-1) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( - byteArrayOf(1, 'A'.code.toByte(), 0, 0), 256, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte(), 0, 0), + 256, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( - byteArrayOf(1, 'a'.code.toByte(), 0, 0), 0x0c, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'a'.code.toByte(), 0, 0), + 0x0c, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( - byteArrayOf(1, '.'.code.toByte(), 0, 0), 0x0c, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, '.'.code.toByte(), 0, 0), + 0x0c, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( - byteArrayOf(0, 0), 0xc0, ApfV4Generator.DROP_LABEL) } + byteArrayOf(0, 0), + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( - byteArrayOf(1, 'A'.code.toByte()), 0xc0, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte()), + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0), - 0xc0, ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0), - 0xc0, ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsQ( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()), - 0xc0, 
ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( - byteArrayOf(1, 'A'.code.toByte(), 0, 0), 256, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte(), 0, 0), + 256, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( - byteArrayOf(1, 'a'.code.toByte(), 0, 0), 0x0c, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'a'.code.toByte(), 0, 0), + 0x0c, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( - byteArrayOf(1, '.'.code.toByte(), 0, 0), 0x0c, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, '.'.code.toByte(), 0, 0), + 0x0c, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( - byteArrayOf(0, 0), 0xc0, ApfV4Generator.DROP_LABEL) } + byteArrayOf(0, 0), + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( - byteArrayOf(1, 'A'.code.toByte()), 0xc0, ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte()), + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0), - 0xc0, ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0), - 0xc0, ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsQ( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()), - 0xc0, ApfV4Generator.DROP_LABEL) } + 0xc0, + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( - byteArrayOf(1, 'a'.code.toByte(), 0, 0), 
ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'a'.code.toByte(), 0, 0), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( - byteArrayOf(1, '.'.code.toByte(), 0, 0), ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, '.'.code.toByte(), 0, 0), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( - byteArrayOf(0, 0), ApfV4Generator.DROP_LABEL) } + byteArrayOf(0, 0), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( - byteArrayOf(1, 'A'.code.toByte()), ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte()), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0DoesNotContainDnsA( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( - byteArrayOf(1, 'a'.code.toByte(), 0, 0), ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'a'.code.toByte(), 0, 0), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( - byteArrayOf(1, '.'.code.toByte(), 0, 0), ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, '.'.code.toByte(), 0, 0), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( - byteArrayOf(0, 0), ApfV4Generator.DROP_LABEL) } + byteArrayOf(0, 0), + 
ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( - byteArrayOf(1, 'A'.code.toByte()), ApfV4Generator.DROP_LABEL) } + byteArrayOf(1, 'A'.code.toByte()), + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( byteArrayOf(64) + ByteArray(64) { 'A'.code.toByte() } + byteArrayOf(0, 0), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } assertFailsWith<IllegalArgumentException> { gen.addJumpIfPktAtR0ContainDnsA( byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte()), - ApfV4Generator.DROP_LABEL) } + ApfV4Generator.DROP_LABEL + ) } + assertFailsWith<IllegalArgumentException> { gen.addCountAndDrop(PASSED_ARP) } + assertFailsWith<IllegalArgumentException> { gen.addCountAndPass(DROPPED_ETH_BROADCAST) } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndDropIfR0Equals(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndPassIfR0Equals(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndDropIfR0NotEquals(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndPassIfR0NotEquals(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndDropIfR0LessThan(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndPassIfR0LessThan(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + gen.addCountAndDropIfBytesAtR0NotEqual(byteArrayOf(1), PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + gen.addWrite32(byteArrayOf()) + } + + val v4gen = ApfV4Generator(APF_VERSION_4) + assertFailsWith<IllegalArgumentException> { v4gen.addCountAndDrop(PASSED_ARP) } + 
assertFailsWith<IllegalArgumentException> { v4gen.addCountAndPass(DROPPED_ETH_BROADCAST) } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndDropIfR0Equals(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndPassIfR0Equals(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndDropIfR0NotEquals(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndPassIfR0NotEquals(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndDropIfR0LessThan(3, PASSED_ARP) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndPassIfR0LessThan(3, DROPPED_ETH_BROADCAST) + } + assertFailsWith<IllegalArgumentException> { + v4gen.addCountAndDropIfBytesAtR0NotEqual(byteArrayOf(1), PASSED_ARP) + } } @Test @@ -154,57 +281,105 @@ val program = ApfV6Generator().addJumpIfPktAtR0ContainDnsQ( byteArrayOf(1, '%'.code.toByte(), 0, 0), 1, - DROP_LABEL) - .addJumpIfPktAtR0ContainDnsA( - byteArrayOf(0xff.toByte(), 1, 'B'.code.toByte(), 0, 0), - DROP_LABEL - ) - .generate() + DROP_LABEL + ).addJumpIfPktAtR0ContainDnsA( + byteArrayOf(0xff.toByte(), 1, 'B'.code.toByte(), 0, 0), + DROP_LABEL + ).generate() } @Test fun testApfInstructionsEncoding() { - val v4gen = ApfV4Generator<ApfV4Generator<BaseApfGenerator>>(MIN_APF_VERSION) + val v4gen = ApfV4Generator(MIN_APF_VERSION) v4gen.addPass() var program = v4gen.generate() // encoding PASS opcode: opcode=0, imm_len=0, R=0 assertContentEquals( - byteArrayOf(encodeInstruction(opcode = 0, immLength = 0, register = 0)), program) + byteArrayOf(encodeInstruction(opcode = 0, immLength = 0, register = 0)), + program + ) assertContentEquals( - listOf("0: pass"), - ApfJniUtils.disassembleApf(program).map { it.trim() } ) + listOf("0: pass"), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) var gen = ApfV6Generator() gen.addDrop() program = gen.generate() // encoding DROP opcode: opcode=0, 
imm_len=0, R=1 assertContentEquals( - byteArrayOf(encodeInstruction(opcode = 0, immLength = 0, register = 1)), program) + byteArrayOf(encodeInstruction(opcode = 0, immLength = 0, register = 1)), + program + ) assertContentEquals( - listOf("0: drop"), - ApfJniUtils.disassembleApf(program).map { it.trim() } ) + listOf("0: drop"), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) gen = ApfV6Generator() gen.addCountAndPass(129) program = gen.generate() // encoding COUNT(PASS) opcode: opcode=0, imm_len=size_of(imm), R=0, imm=counterNumber assertContentEquals( - byteArrayOf(encodeInstruction(opcode = 0, immLength = 1, register = 0), - 0x81.toByte()), program) + byteArrayOf( + encodeInstruction(opcode = 0, immLength = 1, register = 0), + 0x81.toByte() + ), + program + ) assertContentEquals( - listOf("0: pass 129"), - ApfJniUtils.disassembleApf(program).map { it.trim() } ) + listOf("0: pass 129"), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) gen = ApfV6Generator() gen.addCountAndDrop(1000) program = gen.generate() // encoding COUNT(DROP) opcode: opcode=0, imm_len=size_of(imm), R=1, imm=counterNumber assertContentEquals( - byteArrayOf(encodeInstruction(opcode = 0, immLength = 2, register = 1), - 0x03, 0xe8.toByte()), program) + byteArrayOf( + encodeInstruction(opcode = 0, immLength = 2, register = 1), + 0x03, + 0xe8.toByte() + ), + program + ) assertContentEquals( - listOf("0: drop 1000"), - ApfJniUtils.disassembleApf(program).map { it.trim() } ) + listOf("0: drop 1000"), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) + + gen = ApfV6Generator() + gen.addCountAndPass(PASSED_ARP) + program = gen.generate() + // encoding COUNT(PASS) opcode: opcode=0, imm_len=size_of(imm), R=0, imm=counterNumber + assertContentEquals( + byteArrayOf( + encodeInstruction(opcode = 0, immLength = 1, register = 0), + PASSED_ARP.value().toByte() + ), + program + ) + assertContentEquals( + listOf("0: pass 10"), + ApfJniUtils.disassembleApf(program).map { it.trim() 
} + ) + + gen = ApfV6Generator() + gen.addCountAndDrop(DROPPED_ETHERTYPE_DENYLISTED) + program = gen.generate() + // encoding COUNT(DROP) opcode: opcode=0, imm_len=size_of(imm), R=1, imm=counterNumber + assertContentEquals( + byteArrayOf( + encodeInstruction(opcode = 0, immLength = 1, register = 1), + DROPPED_ETHERTYPE_DENYLISTED.value().toByte() + ), + program + ) + assertContentEquals( + listOf("0: drop 38"), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) gen = ApfV6Generator() gen.addAllocateR0() @@ -212,13 +387,21 @@ program = gen.generate() // encoding ALLOC opcode: opcode=21(EXT opcode number), imm=36(TRANS opcode number). // R=0 means length stored in R0. R=1 means the length stored in imm1. - assertContentEquals(byteArrayOf( - encodeInstruction(opcode = 21, immLength = 1, register = 0), 36, - encodeInstruction(opcode = 21, immLength = 1, register = 1), 36, 0x05, - 0xDC.toByte()), - program) - assertContentEquals(listOf("0: allocate r0", "2: allocate 1500"), - ApfJniUtils.disassembleApf(program).map { it.trim() }) + assertContentEquals( + byteArrayOf( + encodeInstruction(opcode = 21, immLength = 1, register = 0), + 36, + encodeInstruction(opcode = 21, immLength = 1, register = 1), + 36, + 0x05, + 0xDC.toByte() + ), + program + ) + assertContentEquals(listOf( + "0: allocate r0", + "2: allocate 1500" + ), ApfJniUtils.disassembleApf(program).map { it.trim() }) gen = ApfV6Generator() gen.addTransmitWithoutChecksum() @@ -231,21 +414,28 @@ 37, 255.toByte(), 255.toByte(), encodeInstruction(opcode = 21, immLength = 1, register = 1), 37, 30, 40, 50, 1, 0 ), program) - assertContentEquals(listOf( - "0: transmit ip_ofs=255", - "4: transmitudp ip_ofs=30, csum_ofs=40, csum_start=50, partial_csum=0x0100", - ), ApfJniUtils.disassembleApf(program).map { it.trim() }) + assertContentEquals(listOf( + "0: transmit ip_ofs=255", + "4: transmitudp ip_ofs=30, csum_ofs=40, csum_start=50, partial_csum=0x0100", + ), ApfJniUtils.disassembleApf(program).map { it.trim() }) 
gen = ApfV6Generator() val largeByteArray = ByteArray(256) { 0x01 } gen.addData(largeByteArray) program = gen.generate() // encoding DATA opcode: opcode=14(JMP), R=1 - assertContentEquals(byteArrayOf( - encodeInstruction(opcode = 14, immLength = 2, register = 1), 0x01, 0x00) + - largeByteArray, program) - assertContentEquals(listOf("0: data 256, " + "01".repeat(256) ), - ApfJniUtils.disassembleApf(program).map { it.trim() }) + assertContentEquals( + byteArrayOf( + encodeInstruction(opcode = 14, immLength = 2, register = 1), + 0x01, + 0x00 + ) + largeByteArray, + program + ) + assertContentEquals( + listOf("0: data 256, " + "01".repeat(256) ), + ApfJniUtils.disassembleApf(program).map { it.trim() } + ) gen = ApfV6Generator() gen.addWriteU8(0x01) @@ -257,6 +447,8 @@ gen.addWriteU16(0x8000) gen.addWriteU32(0x00000000) gen.addWriteU32(0x80000000) + gen.addWrite32(-2) + gen.addWrite32(byteArrayOf(0xff.toByte(), 0xfe.toByte(), 0xfd.toByte(), 0xfc.toByte())) program = gen.generate() assertContentEquals(byteArrayOf( encodeInstruction(24, 1, 0), 0x01, @@ -267,20 +459,24 @@ encodeInstruction(24, 2, 0), 0x00, 0x00, encodeInstruction(24, 2, 0), 0x80.toByte(), 0x00, encodeInstruction(24, 4, 0), 0x00, 0x00, 0x00, 0x00, - encodeInstruction(24, 4, 0), 0x80.toByte(), 0x00, 0x00, - 0x00), program) + encodeInstruction(24, 4, 0), 0x80.toByte(), 0x00, 0x00, 0x00, + encodeInstruction(24, 4, 0), 0xff.toByte(), 0xff.toByte(), + 0xff.toByte(), 0xfe.toByte(), + encodeInstruction(24, 4, 0), 0xff.toByte(), 0xfe.toByte(), + 0xfd.toByte(), 0xfc.toByte()), program) assertContentEquals(listOf( - "0: write 0x01", - "2: write 0x0102", - "5: write 0x01020304", - "10: write 0x00", - "12: write 0x80", - "14: write 0x0000", - "17: write 0x8000", - "20: write 0x00000000", - "25: write 0x80000000" - ), - ApfJniUtils.disassembleApf(program).map { it.trim() }) + "0: write 0x01", + "2: write 0x0102", + "5: write 0x01020304", + "10: write 0x00", + "12: write 0x80", + "14: write 0x0000", + "17: write 0x8000", 
+ "20: write 0x00000000", + "25: write 0x80000000", + "30: write 0xfffffffe", + "35: write 0xfffefdfc" + ), ApfJniUtils.disassembleApf(program).map { it.trim() }) gen = ApfV6Generator() gen.addWriteU8(R0) @@ -304,7 +500,8 @@ "4: ewrite4 r0", "6: ewrite1 r1", "8: ewrite2 r1", - "10: ewrite4 r1"), ApfJniUtils.disassembleApf(program).map { it.trim() }) + "10: ewrite4 r1" + ), ApfJniUtils.disassembleApf(program).map { it.trim() }) gen = ApfV6Generator() gen.addDataCopy(0, 10) @@ -321,8 +518,7 @@ "0: datacopy src=0, len=10", "2: datacopy src=1, len=5", "5: pktcopy src=1000, len=255" - ), - ApfJniUtils.disassembleApf(program).map { it.trim() }) + ), ApfJniUtils.disassembleApf(program).map { it.trim() }) gen = ApfV6Generator() gen.addDataCopyFromR0(5) @@ -337,20 +533,24 @@ encodeInstruction(21, 1, 0), 42, ), program) assertContentEquals(listOf( - "0: edatacopy src=r0, len=5", - "3: epktcopy src=r0, len=5", - "6: edatacopy src=r0, len=r1", - "8: epktcopy src=r0, len=r1"), ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "0: edatacopy src=r0, len=5", + "3: epktcopy src=r0, len=5", + "6: edatacopy src=r0, len=r1", + "8: epktcopy src=r0, len=r1" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) gen = ApfV6Generator() gen.addJumpIfBytesAtR0Equal(byteArrayOf('a'.code.toByte()), ApfV4Generator.DROP_LABEL) program = gen.generate() - assertContentEquals( - byteArrayOf(encodeInstruction(opcode = 20, immLength = 1, register = 1), - 1, 1, 'a'.code.toByte()), program) + assertContentEquals(byteArrayOf( + encodeInstruction(opcode = 20, immLength = 1, register = 1), + 1, + 1, + 'a'.code.toByte() + ), program) assertContentEquals(listOf( - "0: jbseq r0, 0x1, DROP, 61"), - ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "0: jbseq r0, 0x1, DROP, 61" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) val qnames = byteArrayOf(1, 'A'.code.toByte(), 1, 'B'.code.toByte(), 0, 0) gen = ApfV6Generator() @@ -363,9 +563,9 @@ encodeInstruction(21, 1, 1), 43, 1, 
0x0c.toByte(), ) + qnames, program) assertContentEquals(listOf( - "0: jdnsqne r0, DROP, 12, (1)A(1)B(0)(0)", - "10: jdnsqeq r0, DROP, 12, (1)A(1)B(0)(0)"), - ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "0: jdnsqne r0, DROP, 12, (1)A(1)B(0)(0)", + "10: jdnsqeq r0, DROP, 12, (1)A(1)B(0)(0)" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) gen = ApfV6Generator() gen.addJumpIfPktAtR0DoesNotContainDnsQSafe(qnames, 0x0c, ApfV4Generator.DROP_LABEL) @@ -378,8 +578,8 @@ ) + qnames, program) assertContentEquals(listOf( "0: jdnsqnesafe r0, DROP, 12, (1)A(1)B(0)(0)", - "10: jdnsqeqsafe r0, DROP, 12, (1)A(1)B(0)(0)"), - ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "10: jdnsqeqsafe r0, DROP, 12, (1)A(1)B(0)(0)" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) gen = ApfV6Generator() gen.addJumpIfPktAtR0DoesNotContainDnsA(qnames, ApfV4Generator.DROP_LABEL) @@ -391,9 +591,9 @@ encodeInstruction(21, 1, 1), 44, 1, ) + qnames, program) assertContentEquals(listOf( - "0: jdnsane r0, DROP, (1)A(1)B(0)(0)", - "9: jdnsaeq r0, DROP, (1)A(1)B(0)(0)"), - ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "0: jdnsane r0, DROP, (1)A(1)B(0)(0)", + "9: jdnsaeq r0, DROP, (1)A(1)B(0)(0)" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) gen = ApfV6Generator() gen.addJumpIfPktAtR0DoesNotContainDnsASafe(qnames, ApfV4Generator.DROP_LABEL) @@ -406,8 +606,8 @@ ) + qnames, program) assertContentEquals(listOf( "0: jdnsanesafe r0, DROP, (1)A(1)B(0)(0)", - "9: jdnsaeqsafe r0, DROP, (1)A(1)B(0)(0)"), - ApfJniUtils.disassembleApf(program).map{ it.trim() }) + "9: jdnsaeqsafe r0, DROP, (1)A(1)B(0)(0)" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) } @Test @@ -417,6 +617,8 @@ .addWriteU8(0x01) .addWriteU16(0x0203) .addWriteU32(0x04050607) + .addWrite32(-2) + .addWrite32(byteArrayOf(0xff.toByte(), 0xfe.toByte(), 0xfd.toByte(), 0xfc.toByte())) .addLoadImmediate(R0, 1) .addWriteU8(R0) .addLoadImmediate(R0, 0x0203) @@ -426,34 +628,66 @@ 
.addTransmitWithoutChecksum() .generate() assertPass(MIN_APF_VERSION_IN_DEV, program, ByteArray(MIN_PKT_SIZE)) - assertContentEquals(byteArrayOf(0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x01, 0x02, 0x03, - 0x04, 0x05, 0x06, 0x07), ApfJniUtils.getTransmittedPacket()) + assertContentEquals( + byteArrayOf( + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xff.toByte(), + 0xff.toByte(), 0xff.toByte(), 0xfe.toByte(), 0xff.toByte(), 0xfe.toByte(), + 0xfd.toByte(), 0xfc.toByte(), 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07), + ApfJniUtils.getTransmittedPacket() + ) } @Test fun testCopyToTxBuffer() { var program = ApfV6Generator() - .addData(byteArrayOf(33, 34, 35)) - .addAllocate(14) - .addDataCopy(2 /* src */, 2 /* len */) - .addDataCopy(4 /* src */, 1 /* len */) - .addPacketCopy(0 /* src */, 1 /* len */) - .addPacketCopy(1 /* src */, 3 /* len */) - .addLoadImmediate(R0, 2) // data copy offset - .addDataCopyFromR0(2 /* len */) - .addLoadImmediate(R0, 4) // data copy offset - .addLoadImmediate(R1, 1) // len - .addDataCopyFromR0LenR1() - .addLoadImmediate(R0, 0) // packet copy offset - .addPacketCopyFromR0(1 /* len */) - .addLoadImmediate(R0, 1) // packet copy offset - .addLoadImmediate(R1, 3) // len - .addPacketCopyFromR0LenR1() - .addTransmitWithoutChecksum() - .generate() + .addData(byteArrayOf(33, 34, 35)) + .addAllocate(14) + .addDataCopy(3, 2) // arg1=src, arg2=len + .addDataCopy(5, 1) // arg1=src, arg2=len + .addPacketCopy(0, 1) // arg1=src, arg2=len + .addPacketCopy(1, 3) // arg1=src, arg2=len + .addLoadImmediate(R0, 3) // data copy offset + .addDataCopyFromR0(2) // len + .addLoadImmediate(R0, 5) // data copy offset + .addLoadImmediate(R1, 1) // len + .addDataCopyFromR0LenR1() + .addLoadImmediate(R0, 0) // packet copy offset + .addPacketCopyFromR0(1) // len + .addLoadImmediate(R0, 1) // packet copy offset + .addLoadImmediate(R1, 3) // len + .addPacketCopyFromR0LenR1() + .addTransmitWithoutChecksum() + .generate() assertPass(MIN_APF_VERSION_IN_DEV, program, testPacket) - 
assertContentEquals(byteArrayOf(33, 34, 35, 1, 2, 3, 4, 33, 34, 35, 1, 2, 3, 4), - ApfJniUtils.getTransmittedPacket()) + assertContentEquals( + byteArrayOf(33, 34, 35, 1, 2, 3, 4, 33, 34, 35, 1, 2, 3, 4), + ApfJniUtils.getTransmittedPacket() + ) + } + + @Test + fun testCopyContentToTxBuffer() { + val program = ApfV6Generator() + .addData() + .addAllocate(18) + .addDataCopy(HexDump.hexStringToByteArray("112233445566")) + .addDataCopy(HexDump.hexStringToByteArray("223344")) + .addDataCopy(HexDump.hexStringToByteArray("778899")) + .addDataCopy(HexDump.hexStringToByteArray("112233445566")) + .addTransmitWithoutChecksum() + .generate() + assertContentEquals(listOf( + "0: data 9, 112233445566778899", + "12: allocate 18", + "16: datacopy src=3, len=6", + "19: datacopy src=4, len=3", + "22: datacopy src=9, len=3", + "25: datacopy src=3, len=6", + "28: transmit ip_ofs=255" + ), ApfJniUtils.disassembleApf(program).map{ it.trim() }) + assertPass(MIN_APF_VERSION_IN_DEV, program, testPacket) + val transmitPkt = HexDump.toHexString(ApfJniUtils.getTransmittedPacket()) + assertEquals("112233445566223344778899112233445566", transmitPkt) } @Test @@ -467,24 +701,179 @@ var dataRegion = ByteArray(Counter.totalSize()) { 0 } program = ApfV6Generator() .addData(byteArrayOf()) - .addCountAndDrop(Counter.DROPPED_ETH_BROADCAST.value()) + .addCountAndDrop(Counter.DROPPED_ETH_BROADCAST) .generate() assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) var counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.DROPPED_ETH_BROADCAST to 1), counterMap) + Counter.DROPPED_ETH_BROADCAST to 1 + ), counterMap) dataRegion = ByteArray(Counter.totalSize()) { 0 } program = ApfV6Generator() .addData(byteArrayOf()) - .addCountAndPass(Counter.PASSED_ARP.value()) + .addCountAndPass(Counter.PASSED_ARP) .generate() assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) counterMap = 
decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.PASSED_ARP to 1), counterMap) + Counter.PASSED_ARP to 1 + ), counterMap) + } + + @Test + fun testCountAndPassDropCompareR0() { + doTestCountAndPassDropCompareR0( + { mutableMapOf() }, + { ApfV4Generator(APF_VERSION_4) } + ) + doTestCountAndPassDropCompareR0( + { mutableMapOf(Counter.TOTAL_PACKETS to 1) }, + { ApfV6Generator().addData(byteArrayOf()) } + ) + } + + private fun doTestCountAndPassDropCompareR0( + getInitialMap: () -> MutableMap<Counter, Long>, + getGenerator: () -> ApfV4GeneratorBase<*> + ) { + var program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndDropIfR0Equals(123, Counter.DROPPED_ETH_BROADCAST) + .addPass() + .addCountTrampoline() + .generate() + var dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + var counterMap = decodeCountersIntoMap(dataRegion) + var expectedMap = getInitialMap() + expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1 + assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndPassIfR0Equals(123, Counter.PASSED_ARP) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[Counter.PASSED_ARP] = 1 + assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndDropIfR0NotEquals(124, Counter.DROPPED_ETH_BROADCAST) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1 + 
assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndPassIfR0NotEquals(124, Counter.PASSED_ARP) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[Counter.PASSED_ARP] = 1 + assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndDropIfR0LessThan(124, Counter.DROPPED_ETH_BROADCAST) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[Counter.DROPPED_ETH_BROADCAST] = 1 + assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 123) + .addCountAndPassIfR0LessThan(124, Counter.PASSED_ARP) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[Counter.PASSED_ARP] = 1 + assertEquals(expectedMap, counterMap) + + program = getGenerator() + .addLoadImmediate(R0, 1) + .addCountAndDropIfBytesAtR0NotEqual( + byteArrayOf(5, 5), DROPPED_ETH_BROADCAST) + .addPass() + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + expectedMap = getInitialMap() + expectedMap[DROPPED_ETH_BROADCAST] = 1 + assertEquals(expectedMap, counterMap) + } + + @Test + fun testV4CountAndPassDrop() { + var program = 
ApfV4Generator(APF_VERSION_4) + .addCountAndDrop(Counter.DROPPED_ETH_BROADCAST) + .addCountTrampoline() + .generate() + var dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + var counterMap = decodeCountersIntoMap(dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.DROPPED_ETH_BROADCAST to 1 + ), counterMap) + + program = ApfV4Generator(APF_VERSION_4) + .addCountAndPass(Counter.PASSED_ARP) + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) + counterMap = decodeCountersIntoMap(dataRegion) + assertEquals(mapOf<Counter, Long>( + Counter.PASSED_ARP to 1 + ), counterMap) + } + + @Test + fun testV2CountAndPassDrop() { + var program = ApfV4Generator(MIN_APF_VERSION) + .addCountAndDrop(Counter.DROPPED_ETH_BROADCAST) + .addCountTrampoline() + .generate() + var dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, DROP, program, testPacket, dataRegion) + assertContentEquals(ByteArray(Counter.totalSize()) { 0 }, dataRegion) + + program = ApfV4Generator(MIN_APF_VERSION) + .addCountAndPass(PASSED_ARP) + .addCountTrampoline() + .generate() + dataRegion = ByteArray(Counter.totalSize()) { 0 } + assertVerdict(MIN_APF_VERSION_IN_DEV, PASS, program, testPacket, dataRegion) + assertContentEquals(ByteArray(Counter.totalSize()) { 0 }, dataRegion) } @Test @@ -500,7 +889,8 @@ val counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.PASSED_ALLOCATE_FAILURE to 1), counterMap) + Counter.PASSED_ALLOCATE_FAILURE to 1 + ), counterMap) } @Test @@ -519,43 +909,46 @@ val counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.PASSED_TRANSMIT_FAILURE to 1), counterMap) + Counter.PASSED_TRANSMIT_FAILURE to 1 + ), counterMap) } @Test 
fun testTransmitL4() { val etherIpv4UdpPacket = intArrayOf( - 0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb, - 0x38, 0xca, 0x84, 0xb7, 0x7f, 0x16, - 0x08, 0x00, // end of ethernet header - 0x45, - 0x04, - 0x00, 0x3f, - 0x43, 0xcd, - 0x40, 0x00, - 0xff, - 0x11, - 0x00, 0x00, // ipv4 checksum set to 0 - 0xc0, 0xa8, 0x01, 0x03, - 0xe0, 0x00, 0x00, 0xfb, // end of ipv4 header - 0x14, 0xe9, - 0x14, 0xe9, - 0x00, 0x2b, - 0x00, 0x2b, // end of udp header. udp checksum set to udp (header + payload) size - 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x62, 0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, - 0x00, 0x00, 0x01, 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x04, 0xc0, 0xa8, 0x01, - 0x09, + 0x01, 0x00, 0x5e, 0x00, 0x00, 0xfb, + 0x38, 0xca, 0x84, 0xb7, 0x7f, 0x16, + 0x08, 0x00, // end of ethernet header + 0x45, + 0x04, + 0x00, 0x3f, + 0x43, 0xcd, + 0x40, 0x00, + 0xff, + 0x11, + 0x00, 0x00, // ipv4 checksum set to 0 + 0xc0, 0xa8, 0x01, 0x03, + 0xe0, 0x00, 0x00, 0xfb, // end of ipv4 header + 0x14, 0xe9, + 0x14, 0xe9, + 0x00, 0x2b, + 0x00, 0x2b, // end of udp header. 
udp checksum set to udp (header + payload) size + 0x00, 0x00, 0x84, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x62, 0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, + 0x00, 0x00, 0x01, 0x80, 0x01, 0x00, 0x00, 0x00, 0x78, 0x00, 0x04, 0xc0, 0xa8, 0x01, + 0x09, ).map { it.toByte() }.toByteArray() val program = ApfV6Generator() .addData(etherIpv4UdpPacket) .addAllocate(etherIpv4UdpPacket.size) - .addDataCopy(2 /* src */, etherIpv4UdpPacket.size /* len */) - .addTransmitL4(ETH_HLEN /* ipOfs */, - ETH_HLEN + IPV4_HLEN + 6 /* csumOfs */, - ETH_HLEN + IPV4_HLEN - 8 /* csumStart */, - IPPROTO_UDP /* partialCsum */, - true /* isUdp */) + .addDataCopy(3, etherIpv4UdpPacket.size) // arg1=src, arg2=len + .addTransmitL4( + ETH_HLEN, // ipOfs, + ETH_HLEN + IPV4_HLEN + 6, // csumOfs + ETH_HLEN + IPV4_HLEN - 8, // csumStart + IPPROTO_UDP, // partialCsum + true // isUdp + ) .generate() assertPass(MIN_APF_VERSION_IN_DEV, program, testPacket) val txBuf = ByteBuffer.wrap(ApfJniUtils.getTransmittedPacket()) @@ -570,32 +963,32 @@ fun testDnsQuestionMatch() { // needles = { A, B.LOCAL } val needlesMatch = intArrayOf( - 0x01, 'A'.code, - 0x00, - 0x01, 'B'.code, - 0x05, 'L'.code, 'O'.code, 'C'.code, 'A'.code, 'L'.code, - 0x00, - 0x00 + 0x01, 'A'.code, + 0x00, + 0x01, 'B'.code, + 0x05, 'L'.code, 'O'.code, 'C'.code, 'A'.code, 'L'.code, + 0x00, + 0x00 ).map { it.toByte() }.toByteArray() val udpPayload = intArrayOf( - 0x00, 0x00, 0x00, 0x00, // tid = 0x00, flags = 0x00, - 0x00, 0x02, // qdcount = 2 - 0x00, 0x00, // ancount = 0 - 0x00, 0x00, // nscount = 0 - 0x00, 0x00, // arcount = 0 - 0x01, 'a'.code, - 0x01, 'b'.code, - 0x05, 'l'.code, 'o'.code, 'c'.code, 'a'.code, 'l'.code, - 0x00, // qname1 = a.b.local - 0x00, 0x01, 0x00, 0x01, // type = A, class = 0x0001 - 0xc0, 0x0e, // qname2 = b.local (name compression) - 0x00, 0x01, 0x00, 0x01 // type = A, class = 0x0001 + 0x00, 0x00, 0x00, 0x00, // tid = 0x00, flags = 0x00, + 0x00, 0x02, // qdcount = 2 + 0x00, 0x00, // ancount = 0 + 0x00, 0x00, 
// nscount = 0 + 0x00, 0x00, // arcount = 0 + 0x01, 'a'.code, + 0x01, 'b'.code, + 0x05, 'l'.code, 'o'.code, 'c'.code, 'a'.code, 'l'.code, + 0x00, // qname1 = a.b.local + 0x00, 0x01, 0x00, 0x01, // type = A, class = 0x0001 + 0xc0, 0x0e, // qname2 = b.local (name compression) + 0x00, 0x01, 0x00, 0x01 // type = A, class = 0x0001 ).map { it.toByte() }.toByteArray() var program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0ContainDnsQ(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0ContainDnsQ(needlesMatch, 0x01, DROP_LABEL) // arg2=qtype .addPass() .generate() assertDrop(MIN_APF_VERSION_IN_DEV, program, udpPayload) @@ -603,7 +996,7 @@ program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0ContainDnsQSafe(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0ContainDnsQSafe(needlesMatch, 0x01, DROP_LABEL) .addPass() .generate() assertDrop(MIN_APF_VERSION_IN_DEV, program, udpPayload) @@ -611,7 +1004,7 @@ program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0DoesNotContainDnsQ(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0DoesNotContainDnsQ(needlesMatch, 0x01, DROP_LABEL) // arg2=qtype .addPass() .generate() assertPass(MIN_APF_VERSION_IN_DEV, program, udpPayload) @@ -619,7 +1012,7 @@ program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0DoesNotContainDnsQSafe(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0DoesNotContainDnsQSafe(needlesMatch, 0x01, DROP_LABEL) // arg2=qtype .addPass() .generate() assertPass(MIN_APF_VERSION_IN_DEV, program, udpPayload) @@ -642,7 +1035,7 @@ program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0ContainDnsQ(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0ContainDnsQ(needlesMatch, 0x01, DROP_LABEL) // arg2=qtype .addPass() .generate() var dataRegion = 
ByteArray(Counter.totalSize()) { 0 } @@ -650,12 +1043,13 @@ var counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.CORRUPT_DNS_PACKET to 1), counterMap) + Counter.CORRUPT_DNS_PACKET to 1 + ), counterMap) program = ApfV6Generator() .addData(byteArrayOf()) .addLoadImmediate(R0, 0) - .addJumpIfPktAtR0ContainDnsQSafe(needlesMatch, 0x01 /* qtype */, DROP_LABEL) + .addJumpIfPktAtR0ContainDnsQSafe(needlesMatch, 0x01, DROP_LABEL) // arg2=qtype .addPass() .generate() dataRegion = ByteArray(Counter.totalSize()) { 0 } @@ -663,7 +1057,8 @@ counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.CORRUPT_DNS_PACKET to 1), counterMap) + Counter.CORRUPT_DNS_PACKET to 1 + ), counterMap) } @Test @@ -759,7 +1154,8 @@ var counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.CORRUPT_DNS_PACKET to 1), counterMap) + Counter.CORRUPT_DNS_PACKET to 1 + ), counterMap) program = ApfV6Generator() .addData(byteArrayOf()) @@ -772,7 +1168,8 @@ counterMap = decodeCountersIntoMap(dataRegion) assertEquals(mapOf<Counter, Long>( Counter.TOTAL_PACKETS to 1, - Counter.CORRUPT_DNS_PACKET to 1), counterMap) + Counter.CORRUPT_DNS_PACKET to 1 + ), counterMap) } @Test
diff --git a/tests/unit/src/android/net/apf/Bpf2Apf.java b/tests/unit/src/android/net/apf/Bpf2Apf.java index 5d2f9a9..57c560e 100644 --- a/tests/unit/src/android/net/apf/Bpf2Apf.java +++ b/tests/unit/src/android/net/apf/Bpf2Apf.java
@@ -163,17 +163,17 @@ if (arg.equals("x")) { switch(opcode) { case "add": - gen.addAddR1(); + gen.addAddR1ToR0(); break; case "and": - gen.addAndR1(); + gen.addAndR0WithR1(); break; case "or": - gen.addOrR1(); + gen.addOrR0WithR1(); break; case "sub": gen.addNeg(R1); - gen.addAddR1(); + gen.addAddR1ToR0(); gen.addNeg(R1); break; }
diff --git a/tests/unit/src/android/net/apf/JumpTableTest.kt b/tests/unit/src/android/net/apf/JumpTableTest.kt index 2c48e38..f2f0015 100644 --- a/tests/unit/src/android/net/apf/JumpTableTest.kt +++ b/tests/unit/src/android/net/apf/JumpTableTest.kt
@@ -35,7 +35,7 @@ class JumpTableTest { @Mock - lateinit var gen: ApfV4Generator<ApfV4Generator<BaseApfGenerator>> + lateinit var gen: ApfV4Generator @Before fun setUp() {
diff --git a/tests/unit/src/android/net/dhcp/DhcpPacketTest.java b/tests/unit/src/android/net/dhcp/DhcpPacketTest.java index 42ea54b..2eedbfb 100644 --- a/tests/unit/src/android/net/dhcp/DhcpPacketTest.java +++ b/tests/unit/src/android/net/dhcp/DhcpPacketTest.java
@@ -239,7 +239,8 @@ assertNotNull(offerPacket); assertEquals(rawLeaseTime, offerPacket.mLeaseTime); DhcpResults dhcpResults = offerPacket.toDhcpResults(); // Just check this doesn't crash. - assertEquals(leaseTimeMillis, offerPacket.getLeaseTimeMillis()); + assertEquals(leaseTimeMillis, + offerPacket.getLeaseTimeMillis(DhcpPacket.DEFAULT_MINIMUM_LEASE)); } @Test
diff --git a/tests/unit/src/android/net/ip/IpClientTest.java b/tests/unit/src/android/net/ip/IpClientTest.java index 1849776..8d99b11 100644 --- a/tests/unit/src/android/net/ip/IpClientTest.java +++ b/tests/unit/src/android/net/ip/IpClientTest.java
@@ -16,8 +16,20 @@ package android.net.ip; +import static android.net.ip.IpClientLinkObserver.CONFIG_SOCKET_RECV_BUFSIZE; +import static android.net.ip.IpClientLinkObserver.SOCKET_RECV_BUFSIZE; import static android.system.OsConstants.RT_SCOPE_UNIVERSE; +import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT; +import static com.android.net.module.util.netlink.NetlinkConstants.RTPROT_KERNEL; +import static com.android.net.module.util.netlink.NetlinkConstants.RTM_DELROUTE; +import static com.android.net.module.util.netlink.NetlinkConstants.RTM_NEWADDR; +import static com.android.net.module.util.netlink.NetlinkConstants.RTM_NEWNDUSEROPT; +import static com.android.net.module.util.netlink.NetlinkConstants.RTM_NEWROUTE; +import static com.android.net.module.util.netlink.NetlinkConstants.RTN_UNICAST; +import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_ACK; +import static com.android.net.module.util.netlink.StructNlMsgHdr.NLM_F_REQUEST; + import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -27,6 +39,7 @@ import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.clearInvocations; import static org.mockito.Mockito.doReturn; @@ -34,7 +47,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -57,6 +69,8 @@ import android.net.RouteInfo; import android.net.apf.ApfCapabilities; import android.net.apf.ApfFilter.ApfConfiguration; +import 
android.net.ip.IpClientLinkObserver.IpClientNetlinkMonitor; +import android.net.ip.IpClientLinkObserver.IpClientNetlinkMonitor.INetlinkMessageProcessor; import android.net.ipmemorystore.NetworkAttributes; import android.net.metrics.IpConnectivityLog; import android.net.shared.InitialConfiguration; @@ -64,15 +78,21 @@ import android.net.shared.ProvisioningConfiguration; import android.net.shared.ProvisioningConfiguration.ScanResultInfo; import android.os.Build; +import android.system.OsConstants; import androidx.test.filters.SmallTest; import androidx.test.runner.AndroidJUnit4; import com.android.net.module.util.InterfaceParams; +import com.android.net.module.util.netlink.NduseroptMessage; +import com.android.net.module.util.netlink.RtNetlinkAddressMessage; +import com.android.net.module.util.netlink.RtNetlinkRouteMessage; +import com.android.net.module.util.netlink.StructIfaddrMsg; +import com.android.net.module.util.netlink.StructNdOptRdnss; +import com.android.net.module.util.netlink.StructNlMsgHdr; +import com.android.net.module.util.netlink.StructRtMsg; import com.android.networkstack.R; import com.android.networkstack.ipmemorystore.IpMemoryStoreService; -import com.android.server.NetworkObserver; -import com.android.server.NetworkObserverRegistry; import com.android.server.NetworkStackService; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter; @@ -135,10 +155,12 @@ private static final String TEST_IPV6_GATEWAY = "fd2c:4e57:8e3c::43"; private static final String TEST_IPV4_GATEWAY = "192.168.42.11"; private static final long TEST_DNS_LIFETIME = 3600; + // `whenMs` param in processNetlinkMessage is only used to process PREF64 option in RA, which + // is not used for RTM_NEWADDR, RTM_NEWROUTE and RDNSS option. 
+ private static final long TEST_UNUSED_REAL_TIME = 0; @Mock private Context mContext; @Mock private ConnectivityManager mCm; - @Mock private NetworkObserverRegistry mObserverRegistry; @Mock private INetd mNetd; @Mock private Resources mResources; @Mock private IIpClientCallbacks mCb; @@ -152,9 +174,10 @@ @Mock private IpConnectivityLog mMetricsLog; @Mock private FileDescriptor mFd; @Mock private PrintWriter mWriter; + @Mock private IpClientNetlinkMonitor mNetlinkMonitor; - private NetworkObserver mObserver; private InterfaceParams mIfParams; + private INetlinkMessageProcessor mNetlinkMessageProcessor; @Before public void setUp() throws Exception { @@ -172,6 +195,11 @@ when(mDependencies.getIpMemoryStore(mContext, mNetworkStackServiceManager)) .thenReturn(mIpMemoryStore); when(mDependencies.getIpConnectivityLog()).thenReturn(mMetricsLog); + when(mDependencies.getDeviceConfigPropertyInt(eq(CONFIG_SOCKET_RECV_BUFSIZE), anyInt())) + .thenReturn(SOCKET_RECV_BUFSIZE); + when(mDependencies.makeIpClientNetlinkMonitor( + any(), any(), any(), anyInt(), any())).thenReturn(mNetlinkMonitor); + when(mNetlinkMonitor.start()).thenReturn(true); mIfParams = null; } @@ -185,14 +213,15 @@ private IpClient makeIpClient(String ifname) throws Exception { setTestInterfaceParams(ifname); - final IpClient ipc = new IpClient(mContext, ifname, mCb, mObserverRegistry, - mNetworkStackServiceManager, mDependencies); + final IpClient ipc = + new IpClient(mContext, ifname, mCb, mNetworkStackServiceManager, mDependencies); verify(mNetd, timeout(TEST_TIMEOUT_MS).times(1)).interfaceSetEnableIPv6(ifname, false); verify(mNetd, timeout(TEST_TIMEOUT_MS).times(1)).interfaceClearAddrs(ifname); - ArgumentCaptor<NetworkObserver> arg = ArgumentCaptor.forClass(NetworkObserver.class); - verify(mObserverRegistry, times(1)).registerObserverForNonblockingCallback(arg.capture()); - mObserver = arg.getValue(); - reset(mObserverRegistry); + final ArgumentCaptor<INetlinkMessageProcessor> processorCaptor = + 
ArgumentCaptor.forClass(INetlinkMessageProcessor.class); + verify(mDependencies).makeIpClientNetlinkMonitor(any(), any(), any(), anyInt(), + processorCaptor.capture()); + mNetlinkMessageProcessor = processorCaptor.getValue(); reset(mNetd); // Verify IpClient doesn't call onLinkPropertiesChange() when it starts. verify(mCb, never()).onLinkPropertiesChange(any()); @@ -212,12 +241,87 @@ // verify(mIpMemoryStore).storeNetworkAttributes(eq(l2Key), eq(attributes), any()); } + private static StructNlMsgHdr makeNetlinkMessageHeader(short type, short flags) { + final StructNlMsgHdr nlmsghdr = new StructNlMsgHdr(); + nlmsghdr.nlmsg_type = type; + nlmsghdr.nlmsg_flags = flags; + nlmsghdr.nlmsg_seq = 1; + return nlmsghdr; + } + + private static RtNetlinkAddressMessage buildRtmAddressMessage(short type, final LinkAddress la, + int ifindex, int flags) { + final StructNlMsgHdr nlmsghdr = + makeNetlinkMessageHeader(type, (short) (NLM_F_REQUEST | NLM_F_ACK)); + InetAddress ip = la.getAddress(); + final byte family = + (byte) ((ip instanceof Inet6Address) ? OsConstants.AF_INET6 : OsConstants.AF_INET); + StructIfaddrMsg ifaddrMsg = new StructIfaddrMsg(family, + (short) la.getPrefixLength(), + (short) la.getFlags(), (short) la.getScope(), ifindex); + + return new RtNetlinkAddressMessage(nlmsghdr, ifaddrMsg, ip, + null /* structIfacacheInfo */, flags); + } + + private static RtNetlinkRouteMessage buildRtmRouteMessage(short type, final RouteInfo route, + int ifindex) { + final StructNlMsgHdr nlmsghdr = + makeNetlinkMessageHeader(type, (short) (NLM_F_REQUEST | NLM_F_ACK)); + final IpPrefix destination = route.getDestination(); + final byte family = (byte) ((destination.getAddress() instanceof Inet6Address) + ? 
OsConstants.AF_INET6 + : OsConstants.AF_INET); + + final StructRtMsg rtMsg = new StructRtMsg(family, + (short) destination.getPrefixLength() /* dstLen */, (short) 0 /* srcLen */, + (short) 0 /* tos */, (short) 0xfd /* main table */, RTPROT_KERNEL /* protocol */, + (short) RT_SCOPE_UNIVERSE /* scope */, RTN_UNICAST /* type */, 0 /* flags */); + return new RtNetlinkRouteMessage(nlmsghdr, rtMsg, null /* source */, route.getDestination(), + route.getGateway(), 0 /* iif */, ifindex /* oif */, null /* cacheInfo */); + } + + private static NduseroptMessage buildNduseroptMessage(int ifindex, long lifetime, + final String[] servers) { + final StructNlMsgHdr nlmsghdr = + makeNetlinkMessageHeader(RTM_NEWNDUSEROPT, (short) (NLM_F_REQUEST | NLM_F_ACK)); + final Inet6Address[] serverArray = new Inet6Address[servers.length]; + for (int i = 0; i < servers.length; i++) { + serverArray[i] = (Inet6Address) InetAddresses.parseNumericAddress(servers[i]); + } + final StructNdOptRdnss option = new StructNdOptRdnss(serverArray, lifetime); + return new NduseroptMessage(nlmsghdr, (byte) OsConstants.AF_INET6 /* family */, + 0 /* opts_len */, ifindex, (byte) ICMPV6_ROUTER_ADVERTISEMENT /* icmp_type */, + (byte) 0 /* icmp_code */, option, null /* srcaddr */); + } + + private void onInterfaceAddressUpdated(final LinkAddress la, int flags) { + final RtNetlinkAddressMessage msg = + buildRtmAddressMessage(RTM_NEWADDR, la, TEST_IFINDEX, flags); + mNetlinkMessageProcessor.processNetlinkMessage(msg, TEST_UNUSED_REAL_TIME /* whenMs */); + } + + private void onRouteUpdated(final RouteInfo route) { + final RtNetlinkRouteMessage msg = buildRtmRouteMessage(RTM_NEWROUTE, route, TEST_IFINDEX); + mNetlinkMessageProcessor.processNetlinkMessage(msg, TEST_UNUSED_REAL_TIME /* whenMs */); + } + + private void onRouteRemoved(final RouteInfo route) { + final RtNetlinkRouteMessage msg = buildRtmRouteMessage(RTM_DELROUTE, route, TEST_IFINDEX); + mNetlinkMessageProcessor.processNetlinkMessage(msg, 
TEST_UNUSED_REAL_TIME /* whenMs */); + } + + private void onInterfaceDnsServerInfo(long lifetime, final String[] dnsServers) { + final NduseroptMessage msg = buildNduseroptMessage(TEST_IFINDEX, lifetime, dnsServers); + mNetlinkMessageProcessor.processNetlinkMessage(msg, TEST_UNUSED_REAL_TIME /* whenMs */); + } + @Test public void testNullInterfaceNameMostDefinitelyThrows() throws Exception { setTestInterfaceParams(null); try { - final IpClient ipc = new IpClient(mContext, null, mCb, mObserverRegistry, - mNetworkStackServiceManager, mDependencies); + final IpClient ipc = new IpClient(mContext, null, mCb, mNetworkStackServiceManager, + mDependencies); ipc.shutdown(); fail(); } catch (NullPointerException npe) { @@ -230,8 +334,8 @@ final String ifname = "lo"; setTestInterfaceParams(ifname); try { - final IpClient ipc = new IpClient(mContext, ifname, null, mObserverRegistry, - mNetworkStackServiceManager, mDependencies); + final IpClient ipc = new IpClient(mContext, ifname, null, mNetworkStackServiceManager, + mDependencies); ipc.shutdown(); fail(); } catch (NullPointerException npe) { @@ -242,8 +346,8 @@ @Test public void testInvalidInterfaceDoesNotThrow() throws Exception { setTestInterfaceParams(TEST_IFNAME); - final IpClient ipc = new IpClient(mContext, TEST_IFNAME, mCb, mObserverRegistry, - mNetworkStackServiceManager, mDependencies); + final IpClient ipc = new IpClient(mContext, TEST_IFNAME, mCb, mNetworkStackServiceManager, + mDependencies); verifyNoMoreInteractions(mIpMemoryStore); ipc.shutdown(); } @@ -251,8 +355,8 @@ @Test public void testInterfaceNotFoundFailsImmediately() throws Exception { setTestInterfaceParams(null); - final IpClient ipc = new IpClient(mContext, TEST_IFNAME, mCb, mObserverRegistry, - mNetworkStackServiceManager, mDependencies); + final IpClient ipc = new IpClient(mContext, TEST_IFNAME, mCb, mNetworkStackServiceManager, + mDependencies); ipc.startProvisioning(new ProvisioningConfiguration()); verify(mCb, 
timeout(TEST_TIMEOUT_MS).times(1)).onProvisioningFailure(any()); verify(mIpMemoryStore, never()).storeNetworkAttributes(any(), any(), any()); @@ -286,9 +390,10 @@ verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setFallbackMulticastFilter(false); final LinkProperties lp = makeIPv6ProvisionedLinkProperties(); - lp.getRoutes().forEach(mObserver::onRouteUpdated); - lp.getLinkAddresses().forEach(la -> mObserver.onInterfaceAddressUpdated(la, TEST_IFNAME)); - mObserver.onInterfaceDnsServerInfo(TEST_IFNAME, TEST_DNS_LIFETIME, + lp.getRoutes().forEach(route -> onRouteUpdated(route)); + lp.getLinkAddresses().forEach( + la -> onInterfaceAddressUpdated(la, la.getFlags())); + onInterfaceDnsServerInfo(TEST_DNS_LIFETIME, lp.getDnsServers().stream().map(InetAddress::getHostAddress) .toArray(String[]::new)); @@ -305,8 +410,8 @@ final LinkAddress la = new LinkAddress(TEST_IPV4_LINKADDRESS); final RouteInfo defaultRoute = new RouteInfo(new IpPrefix(Inet4Address.ANY, 0), InetAddresses.parseNumericAddress(TEST_IPV4_GATEWAY), TEST_IFNAME); - mObserver.onInterfaceAddressUpdated(la, TEST_IFNAME); - mObserver.onRouteUpdated(defaultRoute); + onInterfaceAddressUpdated(la, la.getFlags()); + onRouteUpdated(defaultRoute); lp.addLinkAddress(la); lp.addRoute(defaultRoute); @@ -319,7 +424,7 @@ */ private void doIPv6ProvisioningLoss(LinkProperties lp) { final RouteInfo defaultRoute = defaultIPV6Route(TEST_IPV6_GATEWAY); - mObserver.onRouteRemoved(defaultRoute); + onRouteRemoved(defaultRoute); lp.removeRoute(defaultRoute); } @@ -413,13 +518,13 @@ // Add N - 1 addresses for (int i = 0; i < lastAddr; i++) { - mObserver.onInterfaceAddressUpdated(new LinkAddress(TEST_LOCAL_ADDRESSES[i]), iface); + onInterfaceAddressUpdated(new LinkAddress(TEST_LOCAL_ADDRESSES[i]), 0 /* flags */); verify(mCb, timeout(TEST_TIMEOUT_MS)).onLinkPropertiesChange(any()); reset(mCb); } // Add Nth address - mObserver.onInterfaceAddressUpdated(new LinkAddress(TEST_LOCAL_ADDRESSES[lastAddr]), iface); + onInterfaceAddressUpdated(new 
LinkAddress(TEST_LOCAL_ADDRESSES[lastAddr]), 0 /* flags */); LinkProperties want = linkproperties(links(TEST_LOCAL_ADDRESSES), routes(TEST_PREFIXES), emptySet() /* dnses */); want.setInterfaceName(iface);
diff --git a/tests/unit/src/com/android/networkstack/NetworkStackServiceTest.kt b/tests/unit/src/com/android/networkstack/NetworkStackServiceTest.kt index 4c4864b..7770eca 100644 --- a/tests/unit/src/com/android/networkstack/NetworkStackServiceTest.kt +++ b/tests/unit/src/com/android/networkstack/NetworkStackServiceTest.kt
@@ -36,15 +36,22 @@ import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.filters.SmallTest import com.android.net.module.util.Inet4AddressUtils.inet4AddressToIntHTH +import com.android.networkstack.ipmemorystore.IpMemoryStoreService import com.android.server.NetworkStackService.Dependencies import com.android.server.NetworkStackService.NetworkStackConnector import com.android.server.NetworkStackService.PermissionChecker import com.android.server.connectivity.NetworkMonitor -import com.android.networkstack.ipmemorystore.IpMemoryStoreService import com.android.testutils.DevSdkIgnoreRule import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo import com.android.testutils.assertThrows +import java.io.FileDescriptor +import java.io.PrintWriter +import java.io.StringWriter +import java.net.Inet4Address +import kotlin.reflect.KVisibility +import kotlin.reflect.full.declaredMemberFunctions +import kotlin.test.assertEquals import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith @@ -56,13 +63,6 @@ import org.mockito.Mockito.spy import org.mockito.Mockito.times import org.mockito.Mockito.verify -import java.io.FileDescriptor -import java.io.PrintWriter -import java.io.StringWriter -import java.net.Inet4Address -import kotlin.reflect.KVisibility -import kotlin.reflect.full.declaredMemberFunctions -import kotlin.test.assertEquals private val TEST_NETD_VERSION = 9991001 private val TEST_NETD_HASH = "test_netd_hash" @@ -85,7 +85,7 @@ doReturn(mockDhcpServer).`when`(this).makeDhcpServer(any(), any(), any(), any()) doReturn(mockNetworkMonitor).`when`(this).makeNetworkMonitor(any(), any(), any(), any(), any()) - doReturn(mockIpClient).`when`(this).makeIpClient(any(), any(), any(), any(), any()) + doReturn(mockIpClient).`when`(this).makeIpClient(any(), any(), any(), any()) } private val netd = mock(INetd::class.java).apply { doReturn(TEST_NETD_VERSION).`when`(this).interfaceVersion @@ 
-195,7 +195,7 @@ connector.makeIpClient(TEST_IFACE, mockIpClientCb) - verify(deps).makeIpClient(any(), eq(TEST_IFACE), any(), any(), any()) + verify(deps).makeIpClient(any(), eq(TEST_IFACE), any(), any()) verify(mockIpClientCb).onIpClientCreated(any()) // Call some methods one more time with a shared version number and hash to verify no
diff --git a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java index 4944812..e8f14b5 100644 --- a/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java +++ b/tests/unit/src/com/android/networkstack/netlink/TcpSocketTrackerTest.java
@@ -32,6 +32,7 @@ import static junit.framework.Assert.assertFalse; import static junit.framework.Assert.assertTrue; +import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; @@ -54,14 +55,19 @@ import android.net.Network; import android.net.NetworkCapabilities; import android.os.Build; +import android.os.Handler; +import android.os.Looper; import android.os.PowerManager; import android.util.Log; import android.util.Log.TerribleFailureHandler; import androidx.test.filters.SmallTest; +import androidx.test.platform.app.InstrumentationRegistry; import androidx.test.runner.AndroidJUnit4; import com.android.modules.utils.build.SdkLevel; +import com.android.net.module.util.DeviceConfigUtils; +import com.android.net.module.util.FeatureVersions; import com.android.net.module.util.netlink.NetlinkUtils; import com.android.net.module.util.netlink.StructNlMsgHdr; import com.android.testutils.DevSdkIgnoreRule; @@ -151,6 +157,7 @@ private final Network mOtherNetwork = new Network(TEST_NETID2); private TerribleFailureHandler mOldWtfHandler; @Mock private Context mContext; + private final Context mRealContext = InstrumentationRegistry.getInstrumentation().getContext(); @Mock private PowerManager mPowerManager; @Mock private ConnectivityManager mCm; @@ -343,6 +350,9 @@ @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @Test public void testPollSocketsInfo_ignoreBlockedUid_featureDisabled_UOrAbove() throws Exception { + // Test only if the Tethering module is new enough to support the API. 
+ assumeTrue(DeviceConfigUtils.isFeatureSupported(mRealContext, + FeatureVersions.FEATURE_IS_UID_NETWORKING_BLOCKED)); doTestPollSocketsInfo_ignoreBlockedUid_featureDisabled(); verify(mCm, never()).isUidNetworkingBlocked(anyInt(), anyBoolean()); } @@ -377,6 +387,9 @@ @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @Test public void testPollSocketsInfo_ignoreBlockedUid_featureEnabled() throws Exception { + // Test only if the Tethering module is new enough to support the API. + assumeTrue(DeviceConfigUtils.isFeatureSupported(mRealContext, + FeatureVersions.FEATURE_IS_UID_NETWORKING_BLOCKED)); doReturn(true).when(mDependencies).shouldIgnoreTcpInfoForBlockedUids(); final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); tst.setNetworkCapabilities(CELL_NOT_METERED_CAPABILITIES); @@ -409,6 +422,9 @@ @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @Test public void testPollSocketsInfo_ignoreBlockedUid_featureEnabled_dataSaver() throws Exception { + // Test only if the Tethering module is new enough to support the API. + assumeTrue(DeviceConfigUtils.isFeatureSupported(mRealContext, + FeatureVersions.FEATURE_IS_UID_NETWORKING_BLOCKED)); doReturn(true).when(mDependencies).shouldIgnoreTcpInfoForBlockedUids(); final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); @@ -694,11 +710,14 @@ doTestTcpInfoDisableParsingWithDozeMode(DEEP_DOZE, true /* featureEnabled */); } - // Ignore blocked uids is supported on T. Thus, for pre-T device this feature is always + // Ignore blocked uids is supported on U. Thus, for pre-U device this feature is always // needed since there is no replacement. - @IgnoreUpTo(Build.VERSION_CODES.S_V2) + @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU) @Test public void testTcpInfoParsingWithDozeMode_disabled() throws Exception { + // Test only if the Tethering module is new enough to support the API. 
+ assumeTrue(DeviceConfigUtils.isFeatureSupported(mRealContext, + FeatureVersions.FEATURE_IS_UID_NETWORKING_BLOCKED)); doReturn(true).when(mDependencies).shouldIgnoreTcpInfoForBlockedUids(); doReturn(false).when(mDependencies).shouldDisableInLightDoze(anyBoolean()); doTestTcpInfoDisableParsingWithDozeMode(DEEP_DOZE, false /* featureEnabled */); @@ -720,12 +739,20 @@ boolean featureEnabled) throws Exception { final TcpSocketTracker tst = new TcpSocketTracker(mDependencies, mNetwork); tst.setNetworkCapabilities(CELL_NOT_METERED_CAPABILITIES); + + // Verify that device idle mode receiver does not register as the event for NM creation + // is not yet received. + verify(mDependencies, never()).addDeviceIdleReceiver(any(), + anyBoolean(), anyBoolean(), any()); + + final Handler nmHandler = new Handler(Looper.getMainLooper()); + tst.init(nmHandler, new LinkProperties(), CELL_NOT_METERED_CAPABILITIES); final ArgumentCaptor<BroadcastReceiver> receiverCaptor = ArgumentCaptor.forClass(BroadcastReceiver.class); // Enable doze mode with 1 netlink message. verify(mDependencies).addDeviceIdleReceiver(receiverCaptor.capture(), - anyBoolean(), anyBoolean()); + anyBoolean(), anyBoolean(), eq(nmHandler)); final BroadcastReceiver receiver = receiverCaptor.getValue(); if (dozeModeType == DEEP_DOZE) { doReturn(true).when(mPowerManager).isDeviceIdleMode();
diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java index d6e9c8e..a949f80 100644 --- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java +++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -81,6 +81,7 @@ import static org.junit.Assert.fail; import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeTrue; +import static org.mockito.AdditionalMatchers.aryEq; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -130,6 +131,8 @@ import android.net.Uri; import android.net.captiveportal.CaptivePortalProbeResult; import android.net.metrics.IpConnectivityLog; +import android.net.metrics.NetworkEvent; +import android.net.metrics.ValidationProbeEvent; import android.net.networkstack.aidl.NetworkMonitorParameters; import android.net.shared.PrivateDnsConfig; import android.net.wifi.WifiInfo; @@ -177,11 +180,11 @@ import com.android.server.connectivity.nano.CellularData; import com.android.server.connectivity.nano.DnsEvent; import com.android.server.connectivity.nano.WifiData; +import com.android.testutils.ConcurrentUtils; import com.android.testutils.DevSdkIgnoreRule; import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter; import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo; import com.android.testutils.DevSdkIgnoreRunner; -import com.android.testutils.FunctionalUtils.ThrowingConsumer; import com.android.testutils.HandlerUtils; import com.google.protobuf.nano.MessageNano; @@ -229,7 +232,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.Supplier; +import java.util.function.Predicate; import javax.net.ssl.SSLHandshakeException; @@ -329,6 +332,7 @@ private static final NetworkAgentConfigShim TEST_AGENT_CONFIG = NetworkAgentConfigShimImpl.newInstance(null); private static final LinkProperties TEST_LINK_PROPERTIES = new LinkProperties(); + // Each thread that runs isCaptivePortal could generate 2 more probing threads. 
private static final int THREAD_QUIT_MAX_RETRY_COUNT = 3; // Cannot have a static member for the LinkProperties with captive portal API information, as @@ -629,6 +633,7 @@ doReturn(0).when(mRandom).nextInt(); doReturn(mNetd).when(mTstDependencies).getNetd(); + doNothing().when(mTst).init(any(), any(), any()); // DNS probe timeout should not be defined more than half of HANDLER_TIMEOUT_MS. Otherwise, // it will fail the test because of timeout expired for querying AAAA and A sequentially. doReturn(200).when(mResources) @@ -710,32 +715,15 @@ setConsecutiveDnsTimeoutThreshold(5); } - private static <T> void quitResourcesThat(Supplier<List<T>> supplier, - ThrowingConsumer terminator) throws Exception { - // Run it multiple times since new threads might be generated in a thread - // that is about to be terminated, e.g. each thread that runs - // isCaptivePortal could generate 2 more probing threads. - for (int retryCount = 0; retryCount < THREAD_QUIT_MAX_RETRY_COUNT; retryCount++) { - final List<T> resourcesToBeCleared = supplier.get(); - if (resourcesToBeCleared.isEmpty()) return; - for (final T resource : resourcesToBeCleared) { - terminator.accept(resource); - } - } - - assertEquals(Collections.emptyList(), supplier.get()); - } - private void quitNetworkMonitors() throws Exception { - quitResourcesThat(() -> { + ConcurrentUtils.quitResources(THREAD_QUIT_MAX_RETRY_COUNT, () -> { synchronized (mCreatedNetworkMonitors) { final ArrayList<WrappedNetworkMonitor> ret = new ArrayList<>(mCreatedNetworkMonitors); mCreatedNetworkMonitors.clear(); return ret; } - }, (it) -> { - final WrappedNetworkMonitor nm = (WrappedNetworkMonitor) it; + }, nm -> { nm.notifyNetworkDisconnected(); nm.awaitQuit(); }); @@ -748,31 +736,33 @@ } private void quitExecutorServices() throws Exception { - quitResourcesThat(() -> { - synchronized (mExecutorServiceToBeCleared) { - final ArrayList<ExecutorService> ret = new ArrayList<>(mExecutorServiceToBeCleared); - mExecutorServiceToBeCleared.clear(); - 
return ret; - } - }, (it) -> { - final ExecutorService ecs = (ExecutorService) it; - ecs.awaitTermination(HANDLER_TIMEOUT_MS, TimeUnit.MILLISECONDS); - }); + ConcurrentUtils.quitExecutorServices( + THREAD_QUIT_MAX_RETRY_COUNT, + // ExecutorService should already have been terminated by NetworkMonitor. + false /* interrupt */, + HANDLER_TIMEOUT_MS, + () -> { + synchronized (mExecutorServiceToBeCleared) { + final ArrayList<ExecutorService> ret = + new ArrayList<>(mExecutorServiceToBeCleared); + mExecutorServiceToBeCleared.clear(); + return ret; + } + }); } private void quitThreads() throws Exception { - quitResourcesThat(() -> { - synchronized (mThreadsToBeCleared) { - final ArrayList<Thread> ret = new ArrayList<>(mThreadsToBeCleared); - mThreadsToBeCleared.clear(); - return ret; - } - }, (it) -> { - final Thread th = (Thread) it; - th.interrupt(); - th.join(HANDLER_TIMEOUT_MS); - if (th.isAlive()) fail("Threads did not terminate within timeout."); - }); + ConcurrentUtils.quitThreads( + THREAD_QUIT_MAX_RETRY_COUNT, + true /* interrupt */, + HANDLER_TIMEOUT_MS, + () -> { + synchronized (mThreadsToBeCleared) { + final ArrayList<Thread> ret = new ArrayList<>(mThreadsToBeCleared); + mThreadsToBeCleared.clear(); + return ret; + } + }); } @After @@ -2741,12 +2731,30 @@ } @Test + public void testTcpSocketTracker_init() throws Exception { + setDataStallEvaluationType(DATA_STALL_EVALUATION_TYPE_TCP); + final WrappedNetworkMonitor wnm = makeCellMeteredNetworkMonitor(); + // makeCellMeteredNetworkMonitor() creates the NM first and then assign + // new NetworkCapabilities, so notifyNMCreated() will start with a empty NC + // then update CELL_METERED_CAPABILITIES in the follow up call. 
+ final InOrder inOrder = inOrder(mTst); + inOrder.verify(mTst).init( + eq(wnm.getHandler()), + eq(new LinkProperties()), + eq(new NetworkCapabilities(null))); + inOrder.verify(mTst).setNetworkCapabilities(eq(CELL_METERED_CAPABILITIES)); + } + + @Test public void testDataStall_setOpportunisticMode() { setDataStallEvaluationType(DATA_STALL_EVALUATION_TYPE_TCP); WrappedNetworkMonitor wnm = makeCellNotMeteredNetworkMonitor(); InOrder inOrder = inOrder(mTst); - // Initialized with default value. - inOrder.verify(mTst).setOpportunisticMode(false); + // Initialized. + inOrder.verify(mTst).init( + eq(wnm.getHandler()), + eq(new LinkProperties()), + eq(new NetworkCapabilities(null))); // Strict mode. wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0])); @@ -3269,6 +3277,101 @@ TEST_REDIRECT_URL, 1 /* interactions */); } + private void doLegacyConnectivityLogTest() throws Exception { + mFakeDns.setAnswer("www.google.com", () -> { + // Make sure the DNS probes take at least 1ms + SystemClock.sleep(1); + return List.of(parseNumericAddress("2001:db8::443")); + }, TYPE_AAAA); + mFakeDns.setAnswer(PRIVATE_DNS_PROBE_HOST_SUFFIX, () -> { + SystemClock.sleep(1); + return List.of(parseNumericAddress("2001:db8::444")); + }, TYPE_AAAA); + setStatus(mHttpsConnection, 204); + setStatus(mHttpConnection, 204); + + mFakeDns.setAnswer("dns6.google", new String[]{"2001:db8::53"}, TYPE_AAAA); + WrappedNetworkMonitor wnm = makeCellNotMeteredNetworkMonitor(); + wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns6.google", + new InetAddress[0])); + notifyNetworkConnected(wnm, CELL_NOT_METERED_CAPABILITIES); + verifyNetworkTestedValidFromPrivateDns(1 /* interactions */); + + + final ArgumentCaptor<IpConnectivityLog.Event> eventCaptor = + ArgumentCaptor.forClass(IpConnectivityLog.Event.class); + verify(mLogger, atLeastOnce()).log(eq(mCleartextDnsNetwork), + aryEq(CELL_METERED_CAPABILITIES.getTransportTypes()), + eventCaptor.capture()); + + final 
List<IpConnectivityLog.Event> events = eventCaptor.getAllValues(); + final String msg = "Did not find the expected event; events are " + events; + + final int firstValidation = 1 << 8; + + assertHasEvent(msg, NetworkEvent.class, events, 0, + e -> e.eventType == NetworkEvent.NETWORK_CONNECTED); + + final int probesStartIndex = 1; + assertHasEvent(msg, ValidationProbeEvent.class, events, probesStartIndex, + e -> e.probeType == (firstValidation | ValidationProbeEvent.PROBE_DNS) + && e.returnCode == ValidationProbeEvent.DNS_SUCCESS + && e.durationMs >= 1L && e.durationMs < 1000L); + // The first probe has to be DNS, but then the order of the next DNS probe and HTTP/HTTPS + // probes is unknown. + final int httpProbesStartIndex = 2; + assertHasEvent(msg, ValidationProbeEvent.class, events, httpProbesStartIndex, + e -> e.probeType == (firstValidation | ValidationProbeEvent.PROBE_DNS) + && e.returnCode == ValidationProbeEvent.DNS_SUCCESS + && e.durationMs >= 1L && e.durationMs < 1000L); + + final int httpsProbeIndex = assertHasEvent(msg, ValidationProbeEvent.class, events, + httpProbesStartIndex, + e -> e.probeType == (firstValidation | ValidationProbeEvent.PROBE_HTTPS) + && e.returnCode == 204); + + assertHasEvent(msg, ValidationProbeEvent.class, events, httpProbesStartIndex, + e -> e.probeType == (firstValidation | ValidationProbeEvent.PROBE_HTTP) + && e.returnCode == 204); + + // Private DNS starts after validation, so at least after the HTTPS probe + final int privDnsIndex = assertHasEvent(msg, ValidationProbeEvent.class, events, + httpsProbeIndex + 1, + e -> e.probeType == (firstValidation | ValidationProbeEvent.PROBE_PRIVDNS) + && e.returnCode == ValidationProbeEvent.DNS_SUCCESS + && e.durationMs >= 1L && e.durationMs < 1000L); + + assertHasEvent(msg, NetworkEvent.class, events, privDnsIndex + 1, + e -> e.eventType == NetworkEvent.NETWORK_FIRST_VALIDATION_SUCCESS); + } + + private static <T> int assertHasEvent(String msg, Class<T> clazz, + 
List<IpConnectivityLog.Event> events, int startIdx, + Predicate<T> predicate) { + for (int i = startIdx; i < events.size(); i++) { + if (events.get(i).getClass().isAssignableFrom(clazz) + && predicate.test((T) events.get(i))) { + return i; + } + } + fail(msg + " at startIdx " + startIdx); + return -1; + } + + @Test + public void testLegacyConnectivityLog_SyncDns() throws Exception { + doReturn(false).when(mDependencies).isFeatureEnabled( + any(), eq(NetworkStackUtils.NETWORKMONITOR_ASYNC_PRIVDNS_RESOLUTION)); + doLegacyConnectivityLogTest(); + } + + @Test + public void testLegacyConnectivityLog_AsyncDns() throws Exception { + doReturn(true).when(mDependencies).isFeatureEnabled( + any(), eq(NetworkStackUtils.NETWORKMONITOR_ASYNC_PRIVDNS_RESOLUTION)); + doLegacyConnectivityLogTest(); + } + @Test public void testExtractCharset() { assertEquals(StandardCharsets.UTF_8, extractCharset(null));