Snap for 7277907 from ffc11490cad94d75d6823ea854f5242fda7ecf64 to mainline-permission-release

Change-Id: I8d3f17b6d5c0b87e5d89f69ebbddb5d0fddd154a
diff --git a/Tethering/apishim/30/com/android/networkstack/tethering/apishim/api30/BpfCoordinatorShimImpl.java b/Tethering/apishim/30/com/android/networkstack/tethering/apishim/api30/BpfCoordinatorShimImpl.java
index f27c831..e310fb6 100644
--- a/Tethering/apishim/30/com/android/networkstack/tethering/apishim/api30/BpfCoordinatorShimImpl.java
+++ b/Tethering/apishim/30/com/android/networkstack/tethering/apishim/api30/BpfCoordinatorShimImpl.java
@@ -80,12 +80,14 @@
 
     @Override
     public boolean startUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex,
-            MacAddress srcMac, MacAddress dstMac, int mtu) {
+            @NonNull MacAddress inDstMac, @NonNull MacAddress outSrcMac,
+            @NonNull MacAddress outDstMac, int mtu) {
         return true;
     }
 
     @Override
-    public boolean stopUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex) {
+    public boolean stopUpstreamIpv6Forwarding(int downstreamIfindex,
+            int upstreamIfindex, @NonNull MacAddress inDstMac) {
         return true;
     }
 
@@ -171,6 +173,12 @@
     }
 
     @Override
+    public boolean isAnyIpv4RuleOnUpstream(int ifIndex) {
+        /* no op */
+        return false;
+    }
+
+    @Override
     public String toString() {
         return "Netd used";
     }
diff --git a/Tethering/apishim/31/com/android/networkstack/tethering/apishim/api31/BpfCoordinatorShimImpl.java b/Tethering/apishim/31/com/android/networkstack/tethering/apishim/api31/BpfCoordinatorShimImpl.java
index 4f7fe65..d7ce139 100644
--- a/Tethering/apishim/31/com/android/networkstack/tethering/apishim/api31/BpfCoordinatorShimImpl.java
+++ b/Tethering/apishim/31/com/android/networkstack/tethering/apishim/api31/BpfCoordinatorShimImpl.java
@@ -23,6 +23,7 @@
 import android.system.ErrnoException;
 import android.system.Os;
 import android.system.OsConstants;
+import android.util.Log;
 import android.util.SparseArray;
 
 import androidx.annotation.NonNull;
@@ -84,6 +85,20 @@
     @Nullable
     private final BpfMap<TetherLimitKey, TetherLimitValue> mBpfLimitMap;
 
+    // Tracks the number of IPv4 rules using each upstream interface. Used to reduce BPF map
+    // iteration queries. The count is increased or decreased when a rule is successfully
+    // added to or removed from mBpfDownstream4Map. Rules are counted on the downstream4 map
+    // because tetherOffloadRuleRemove can't get the upstream interface index from the upstream
+    // key without the upstream value, which is not required for deleting the map entry. The
+    // upstream interface index is the same in Upstream4Value.oif and Downstream4Key.iif, so for
+    // now it is okay to count on Downstream4Key. See BpfConntrackEventConsumer#accept.
+    // Note that, except in the constructor, any call to mBpfDownstream4Map.clear() must also
+    // clear this counter.
+    // TODO: Count the rule on upstream if multi-upstream is supported and the
+    // packet needs to be sent and responded on different upstream interfaces.
+    // TODO: Add IPv6 rule count.
+    private final SparseArray<Integer> mRule4CountOnUpstream = new SparseArray<>();
+
     public BpfCoordinatorShimImpl(@NonNull final Dependencies deps) {
         mLog = deps.getSharedLog().forSubComponent(TAG);
 
@@ -169,12 +184,13 @@
 
     @Override
     public boolean startUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex,
-            MacAddress srcMac, MacAddress dstMac, int mtu) {
+            @NonNull MacAddress inDstMac, @NonNull MacAddress outSrcMac,
+            @NonNull MacAddress outDstMac, int mtu) {
         if (!isInitialized()) return false;
 
-        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfindex);
-        final Tether6Value value = new Tether6Value(upstreamIfindex, srcMac,
-                dstMac, OsConstants.ETH_P_IPV6, mtu);
+        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfindex, inDstMac);
+        final Tether6Value value = new Tether6Value(upstreamIfindex, outSrcMac,
+                outDstMac, OsConstants.ETH_P_IPV6, mtu);
         try {
             mBpfUpstream6Map.insertEntry(key, value);
         } catch (ErrnoException | IllegalStateException e) {
@@ -185,10 +201,11 @@
     }
 
     @Override
-    public boolean stopUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex) {
+    public boolean stopUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex,
+            @NonNull MacAddress inDstMac) {
         if (!isInitialized()) return false;
 
-        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfindex);
+        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfindex, inDstMac);
         try {
             mBpfUpstream6Map.deleteEntry(key);
         } catch (ErrnoException e) {
@@ -324,18 +341,22 @@
         if (!isInitialized()) return false;
 
         try {
-            // The last used time field of the value is updated by the bpf program. Adding the same
-            // map pair twice causes the unexpected refresh. Must be fixed before starting the
-            // conntrack timeout extension implementation.
-            // TODO: consider using insertEntry.
             if (downstream) {
-                mBpfDownstream4Map.updateEntry(key, value);
+                mBpfDownstream4Map.insertEntry(key, value);
+
+                // Increment the count of rules using this upstream interface.
+                final int upstreamIfindex = (int) key.iif;
+                int count = mRule4CountOnUpstream.get(upstreamIfindex, 0 /* default */);
+                mRule4CountOnUpstream.put(upstreamIfindex, ++count);
             } else {
-                mBpfUpstream4Map.updateEntry(key, value);
+                mBpfUpstream4Map.insertEntry(key, value);
             }
         } catch (ErrnoException e) {
-            mLog.e("Could not update entry: ", e);
+            mLog.e("Could not insert entry (" + key + ", " + value + "): " + e);
             return false;
+        } catch (IllegalStateException e) {
+            // Silently ignore if the rule already exists; the rule count is not incremented.
+            // Note that errno EEXIST is rethrown as IllegalStateException. See BpfMap#insertEntry.
         }
         return true;
     }
@@ -346,7 +367,26 @@
 
         try {
             if (downstream) {
-                mBpfDownstream4Map.deleteEntry(key);
+                if (!mBpfDownstream4Map.deleteEntry(key)) {
+                    mLog.e("Could not delete entry (key: " + key + ")");
+                    return false;
+                }
+
+                // Decrement the count of rules using this upstream interface now that the rule
+                // has been deleted.
+                final int upstreamIfindex = (int) key.iif;
+                Integer count = mRule4CountOnUpstream.get(upstreamIfindex);
+                if (count == null) {
+                    Log.wtf(TAG, "Could not delete count for interface " + upstreamIfindex);
+                    return false;
+                }
+
+                if (--count == 0) {
+                    // Remove the entry if the count decreases to zero.
+                    mRule4CountOnUpstream.remove(upstreamIfindex);
+                } else {
+                    mRule4CountOnUpstream.put(upstreamIfindex, count);
+                }
             } else {
                 mBpfUpstream4Map.deleteEntry(key);
             }
@@ -386,6 +426,12 @@
         return true;
     }
 
+    @Override
+    public boolean isAnyIpv4RuleOnUpstream(int ifIndex) {
+        // No entry means no rule for the given interface because 0 has never been stored.
+        return mRule4CountOnUpstream.get(ifIndex) != null;
+    }
+
     private String mapStatus(BpfMap m, String name) {
         return name + "{" + (m != null ? "OK" : "ERROR") + "}";
     }
diff --git a/Tethering/apishim/common/com/android/networkstack/tethering/apishim/common/BpfCoordinatorShim.java b/Tethering/apishim/common/com/android/networkstack/tethering/apishim/common/BpfCoordinatorShim.java
index b7b4c47..79a628b 100644
--- a/Tethering/apishim/common/com/android/networkstack/tethering/apishim/common/BpfCoordinatorShim.java
+++ b/Tethering/apishim/common/com/android/networkstack/tethering/apishim/common/BpfCoordinatorShim.java
@@ -78,21 +78,25 @@
 
      * @param downstreamIfindex the downstream interface index
      * @param upstreamIfindex the upstream interface index
-     * @param srcMac the source MAC address to use for packets
-     * @oaram dstMac the destination MAC address to use for packets
+     * @param inDstMac the destination MAC address to use for XDP
+     * @param outSrcMac the source MAC address to use for packets
+     * @param outDstMac the destination MAC address to use for packets
      * @return true if operation succeeded or was a no-op, false otherwise
      */
     public abstract boolean startUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex,
-            MacAddress srcMac, MacAddress dstMac, int mtu);
+            @NonNull MacAddress inDstMac, @NonNull MacAddress outSrcMac,
+            @NonNull MacAddress outDstMac, int mtu);
 
     /**
      * Stops IPv6 forwarding between the specified interfaces.
 
      * @param downstreamIfindex the downstream interface index
      * @param upstreamIfindex the upstream interface index
+     * @param inDstMac the destination MAC address to use for XDP
      * @return true if operation succeeded or was a no-op, false otherwise
      */
-    public abstract boolean stopUpstreamIpv6Forwarding(int downstreamIfindex, int upstreamIfindex);
+    public abstract boolean stopUpstreamIpv6Forwarding(int downstreamIfindex,
+            int upstreamIfindex, @NonNull MacAddress inDstMac);
 
     /**
      * Return BPF tethering offload statistics.
@@ -145,6 +149,11 @@
     public abstract boolean tetherOffloadRuleRemove(boolean downstream, @NonNull Tether4Key key);
 
     /**
+     * Whether there is currently any IPv4 rule on the specified upstream.
+     */
+    public abstract boolean isAnyIpv4RuleOnUpstream(int ifIndex);
+
+    /**
      * Attach BPF program.
      *
      * TODO: consider using InterfaceParams to replace interface name.
diff --git a/Tethering/bpf_progs/Android.bp b/Tethering/bpf_progs/Android.bp
index 2b10f89..289d75d 100644
--- a/Tethering/bpf_progs/Android.bp
+++ b/Tethering/bpf_progs/Android.bp
@@ -51,8 +51,6 @@
     include_dirs: [
         // TODO: get rid of system/netd.
         "system/netd/bpf_progs",             // for bpf_net_helpers.h
-        "system/netd/libnetdbpf/include",    // for bpf_shared.h
-        "system/netd/libnetdutils/include",  // for UidConstants.h
     ],
 }
 
@@ -66,7 +64,5 @@
     include_dirs: [
         // TODO: get rid of system/netd.
         "system/netd/bpf_progs",             // for bpf_net_helpers.h
-        "system/netd/libnetdbpf/include",    // for bpf_shared.h
-        "system/netd/libnetdutils/include",  // for UidConstants.h
     ],
 }
diff --git a/Tethering/bpf_progs/bpf_tethering.h b/Tethering/bpf_progs/bpf_tethering.h
index efda228..5fdf8cd 100644
--- a/Tethering/bpf_progs/bpf_tethering.h
+++ b/Tethering/bpf_progs/bpf_tethering.h
@@ -107,11 +107,12 @@
 // Ethernet) have 6-byte MAC addresses.
 
 typedef struct {
-    uint32_t iif;            // The input interface index
-                             // TODO: extend this to include dstMac
-    struct in6_addr neigh6;  // The destination IPv6 address
+    uint32_t iif;              // The input interface index
+    uint8_t dstMac[ETH_ALEN];  // destination ethernet mac address (zeroed iff rawip ingress)
+    uint8_t zero[2];           // zero pad for 8 byte alignment
+    struct in6_addr neigh6;    // The destination IPv6 address
 } TetherDownstream6Key;
-STRUCT_SIZE(TetherDownstream6Key, 4 + 16);  // 20
+STRUCT_SIZE(TetherDownstream6Key, 4 + 6 + 2 + 16);  // 28
 
 typedef struct {
     uint32_t oif;             // The output interface to redirect to
@@ -154,10 +155,12 @@
 #define TETHER_UPSTREAM6_MAP_PATH BPF_PATH_TETHER "map_offload_tether_upstream6_map"
 
 typedef struct {
-    uint32_t iif;  // The input interface index
-                   // TODO: extend this to include dstMac and src ip /64 subnet
+    uint32_t iif;              // The input interface index
+    uint8_t dstMac[ETH_ALEN];  // destination ethernet mac address (zeroed iff rawip ingress)
+    uint8_t zero[2];           // zero pad for 8 byte alignment
+                               // TODO: extend this to include src ip /64 subnet
 } TetherUpstream6Key;
-STRUCT_SIZE(TetherUpstream6Key, 4);
+STRUCT_SIZE(TetherUpstream6Key, 12);
 
 #define TETHER_DOWNSTREAM4_TC_PROG_RAWIP_NAME "prog_offload_schedcls_tether_downstream4_rawip"
 #define TETHER_DOWNSTREAM4_TC_PROG_ETHER_NAME "prog_offload_schedcls_tether_downstream4_ether"
diff --git a/Tethering/bpf_progs/offload.c b/Tethering/bpf_progs/offload.c
index 7f9754d..36f6783 100644
--- a/Tethering/bpf_progs/offload.c
+++ b/Tethering/bpf_progs/offload.c
@@ -72,11 +72,11 @@
 DEFINE_BPF_MAP_GRW(tether_error_map, ARRAY, uint32_t, uint32_t, BPF_TETHER_ERR__MAX,
                    AID_NETWORK_STACK)
 
-#define COUNT_AND_RETURN(counter, ret) do {                  \
+#define COUNT_AND_RETURN(counter, ret) do {                     \
     uint32_t code = BPF_TETHER_ERR_ ## counter;                 \
     uint32_t *count = bpf_tether_error_map_lookup_elem(&code);  \
-    if (count) __sync_fetch_and_add(count, 1);               \
-    return ret;                                              \
+    if (count) __sync_fetch_and_add(count, 1);                  \
+    return ret;                                                 \
 } while(0)
 
 #define TC_DROP(counter) COUNT_AND_RETURN(counter, TC_ACT_SHOT)
@@ -107,17 +107,24 @@
 
 static inline __always_inline int do_forward6(struct __sk_buff* skb, const bool is_ethernet,
         const bool downstream) {
-    const int l2_header_size = is_ethernet ? sizeof(struct ethhdr) : 0;
-    void* data = (void*)(long)skb->data;
-    const void* data_end = (void*)(long)skb->data_end;
-    struct ethhdr* eth = is_ethernet ? data : NULL;  // used iff is_ethernet
-    struct ipv6hdr* ip6 = is_ethernet ? (void*)(eth + 1) : data;
+    // Must be meta-ethernet IPv6 frame
+    if (skb->protocol != htons(ETH_P_IPV6)) return TC_ACT_OK;
 
     // Require ethernet dst mac address to be our unicast address.
     if (is_ethernet && (skb->pkt_type != PACKET_HOST)) return TC_ACT_OK;
 
-    // Must be meta-ethernet IPv6 frame
-    if (skb->protocol != htons(ETH_P_IPV6)) return TC_ACT_OK;
+    const int l2_header_size = is_ethernet ? sizeof(struct ethhdr) : 0;
+
+    // Since the program never writes via DPA (direct packet access) auto-pull/unclone logic does
+    // not trigger and thus we need to manually make sure we can read packet headers via DPA.
+    // Note: this is a blind best effort pull, which may fail or pull less - this doesn't matter.
+    // It has to be done early because it will invalidate any skb->data/data_end derived pointers.
+    try_make_readable(skb, l2_header_size + IP6_HLEN + TCP_HLEN);
+
+    void* data = (void*)(long)skb->data;
+    const void* data_end = (void*)(long)skb->data_end;
+    struct ethhdr* eth = is_ethernet ? data : NULL;  // used iff is_ethernet
+    struct ipv6hdr* ip6 = is_ethernet ? (void*)(eth + 1) : data;
 
     // Must have (ethernet and) ipv6 header
     if (data + l2_header_size + sizeof(*ip6) > data_end) return TC_ACT_OK;
@@ -169,6 +176,7 @@
     TetherUpstream6Key ku = {
             .iif = skb->ifindex,
     };
+    if (is_ethernet) __builtin_memcpy(downstream ? kd.dstMac : ku.dstMac, eth->h_dest, ETH_ALEN);
 
     Tether6Value* v = downstream ? bpf_tether_downstream6_map_lookup_elem(&kd)
                                  : bpf_tether_upstream6_map_lookup_elem(&ku);
@@ -346,18 +354,25 @@
 
 static inline __always_inline int do_forward4(struct __sk_buff* skb, const bool is_ethernet,
         const bool downstream, const bool updatetime) {
-    const int l2_header_size = is_ethernet ? sizeof(struct ethhdr) : 0;
-    void* data = (void*)(long)skb->data;
-    const void* data_end = (void*)(long)skb->data_end;
-    struct ethhdr* eth = is_ethernet ? data : NULL;  // used iff is_ethernet
-    struct iphdr* ip = is_ethernet ? (void*)(eth + 1) : data;
-
     // Require ethernet dst mac address to be our unicast address.
     if (is_ethernet && (skb->pkt_type != PACKET_HOST)) return TC_ACT_OK;
 
     // Must be meta-ethernet IPv4 frame
     if (skb->protocol != htons(ETH_P_IP)) return TC_ACT_OK;
 
+    const int l2_header_size = is_ethernet ? sizeof(struct ethhdr) : 0;
+
+    // Since the program never writes via DPA (direct packet access) auto-pull/unclone logic does
+    // not trigger and thus we need to manually make sure we can read packet headers via DPA.
+    // Note: this is a blind best effort pull, which may fail or pull less - this doesn't matter.
+    // It has to be done early because it will invalidate any skb->data/data_end derived pointers.
+    try_make_readable(skb, l2_header_size + IP4_HLEN + TCP_HLEN);
+
+    void* data = (void*)(long)skb->data;
+    const void* data_end = (void*)(long)skb->data_end;
+    struct ethhdr* eth = is_ethernet ? data : NULL;  // used iff is_ethernet
+    struct iphdr* ip = is_ethernet ? (void*)(eth + 1) : data;
+
     // Must have (ethernet and) ipv4 header
     if (data + l2_header_size + sizeof(*ip) > data_end) return TC_ACT_OK;
 
@@ -474,7 +489,7 @@
             .srcPort = is_tcp ? tcph->source : udph->source,
             .dstPort = is_tcp ? tcph->dest : udph->dest,
     };
-    if (is_ethernet) for (int i = 0; i < ETH_ALEN; ++i) k.dstMac[i] = eth->h_dest[i];
+    if (is_ethernet) __builtin_memcpy(k.dstMac, eth->h_dest, ETH_ALEN);
 
     Tether4Value* v = downstream ? bpf_tether_downstream4_map_lookup_elem(&k)
                                  : bpf_tether_upstream4_map_lookup_elem(&k);
diff --git a/Tethering/src/android/net/ip/IpServer.java b/Tethering/src/android/net/ip/IpServer.java
index e5380e0..da15fa8 100644
--- a/Tethering/src/android/net/ip/IpServer.java
+++ b/Tethering/src/android/net/ip/IpServer.java
@@ -742,16 +742,14 @@
                     params.dnses.add(dnsServer);
                 }
             }
-
-            // Add upstream index to name mapping for the tether stats usage in the coordinator.
-            // Although this mapping could be added by both class Tethering and IpServer, adding
-            // mapping from IpServer guarantees that the mapping is added before the adding
-            // forwarding rules. That is because there are different state machines in both
-            // classes. It is hard to guarantee the link property update order between multiple
-            // state machines.
-            mBpfCoordinator.addUpstreamNameToLookupTable(upstreamIfIndex, upstreamIface);
         }
 
+        // Add upstream index to name mapping. See the comment on the interface mapping update in
+        // CMD_TETHER_CONNECTION_CHANGED. The mapping is also updated here to avoid a potential
+        // timing issue in which the IPv6 capability is updated later than
+        // CMD_TETHER_CONNECTION_CHANGED.
+        mBpfCoordinator.addUpstreamNameToLookupTable(upstreamIfIndex, upstreamIface);
+
         // If v6only is null, we pass in null to setRaParams(), which handles
         // deprecation of any existing RA data.
 
@@ -1335,6 +1333,26 @@
                     mUpstreamIfaceSet = newUpstreamIfaceSet;
 
                     for (String ifname : added) {
+                        // Add upstream index to name mapping for tether stats usage in the
+                        // coordinator. Although this mapping could be added by either the
+                        // Tethering or the IpServer class, adding it from IpServer guarantees that
+                        // the mapping is added before any forwarding rules. That is because the
+                        // two classes run different state machines, and it is hard to guarantee
+                        // the link property update order across multiple state machines.
+                        // Note that both IPv4 and IPv6 interfaces may be added because
+                        // Tethering::setUpstreamNetwork calls getTetheringInterfaces, which merges
+                        // the IPv4 and IPv6 interface names (if any) into an InterfaceSet. The
+                        // IPv6 capability may be updated later; in that case, the IPv6 interface
+                        // mapping is updated in updateUpstreamIPv6LinkProperties.
+                        if (!ifname.startsWith("v4-")) {  // ignore clat interfaces
+                            final InterfaceParams upstreamIfaceParams =
+                                    mDeps.getInterfaceParams(ifname);
+                            if (upstreamIfaceParams != null) {
+                                mBpfCoordinator.addUpstreamNameToLookupTable(
+                                        upstreamIfaceParams.index, ifname);
+                            }
+                        }
+
                         mBpfCoordinator.maybeAttachProgram(mIfaceName, ifname);
                         try {
                             mNetd.tetherAddForward(mIfaceName, ifname);
diff --git a/Tethering/src/android/net/ip/RouterAdvertisementDaemon.java b/Tethering/src/android/net/ip/RouterAdvertisementDaemon.java
index 7c0b7cc..543a5c7 100644
--- a/Tethering/src/android/net/ip/RouterAdvertisementDaemon.java
+++ b/Tethering/src/android/net/ip/RouterAdvertisementDaemon.java
@@ -16,7 +16,6 @@
 
 package android.net.ip;
 
-import static android.net.util.NetworkConstants.IPV6_MIN_MTU;
 import static android.net.util.NetworkConstants.RFC7421_PREFIX_LENGTH;
 import static android.net.util.TetheringUtils.getAllNodesForScopeId;
 import static android.system.OsConstants.AF_INET6;
@@ -25,8 +24,18 @@
 import static android.system.OsConstants.SOL_SOCKET;
 import static android.system.OsConstants.SO_SNDTIMEO;
 
+import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_SLLA;
+import static com.android.net.module.util.NetworkStackConstants.ICMPV6_RA_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
+import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_SOLICITATION;
+import static com.android.net.module.util.NetworkStackConstants.IPV6_MIN_MTU;
+import static com.android.net.module.util.NetworkStackConstants.PIO_FLAG_AUTONOMOUS;
+import static com.android.net.module.util.NetworkStackConstants.PIO_FLAG_ON_LINK;
+import static com.android.net.module.util.NetworkStackConstants.TAG_SYSTEM_NEIGHBOR;
+
 import android.net.IpPrefix;
 import android.net.LinkAddress;
+import android.net.MacAddress;
 import android.net.TrafficStats;
 import android.net.util.InterfaceParams;
 import android.net.util.SocketUtils;
@@ -37,7 +46,12 @@
 import android.util.Log;
 
 import com.android.internal.annotations.GuardedBy;
-import com.android.internal.util.TrafficStatsConstants;
+import com.android.net.module.util.structs.Icmpv6Header;
+import com.android.net.module.util.structs.LlaOption;
+import com.android.net.module.util.structs.MtuOption;
+import com.android.net.module.util.structs.PrefixInformationOption;
+import com.android.net.module.util.structs.RaHeader;
+import com.android.net.module.util.structs.RdnssOption;
 
 import java.io.FileDescriptor;
 import java.io.IOException;
@@ -69,9 +83,6 @@
  */
 public class RouterAdvertisementDaemon {
     private static final String TAG = RouterAdvertisementDaemon.class.getSimpleName();
-    private static final byte ICMPV6_ND_ROUTER_SOLICIT = asByte(133);
-    private static final byte ICMPV6_ND_ROUTER_ADVERT  = asByte(134);
-    private static final int MIN_RA_HEADER_SIZE = 16;
 
     // Summary of various timers and lifetimes.
     private static final int MIN_RTR_ADV_INTERVAL_SEC = 300;
@@ -366,54 +377,27 @@
     }
 
     private static void putHeader(ByteBuffer ra, boolean hasDefaultRoute, byte hopLimit) {
-        /**
-            Router Advertisement Message Format
-
-             0                   1                   2                   3
-             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |     Type      |     Code      |          Checksum             |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            | Cur Hop Limit |M|O|H|Prf|P|R|R|       Router Lifetime         |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                         Reachable Time                        |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                          Retrans Timer                        |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |   Options ...
-            +-+-+-+-+-+-+-+-+-+-+-+-
-        */
-        ra.put(ICMPV6_ND_ROUTER_ADVERT)
-                .put(asByte(0))
-                .putShort(asShort(0))
-                .put(hopLimit)
-                // RFC 4191 "high" preference, iff. advertising a default route.
-                .put(hasDefaultRoute ? asByte(0x08) : asByte(0))
-                .putShort(hasDefaultRoute ? asShort(DEFAULT_LIFETIME) : asShort(0))
-                .putInt(0)
-                .putInt(0);
+        // RFC 4191 "high" preference, iff. advertising a default route.
+        final byte flags = hasDefaultRoute ? asByte(0x08) : asByte(0);
+        final short lifetime = hasDefaultRoute ? asShort(DEFAULT_LIFETIME) : asShort(0);
+        final Icmpv6Header icmpv6Header =
+                new Icmpv6Header(asByte(ICMPV6_ROUTER_ADVERTISEMENT) /* type */,
+                        asByte(0) /* code */, asShort(0) /* checksum */);
+        final RaHeader raHeader = new RaHeader(hopLimit, flags, lifetime, 0 /* reachableTime */,
+                0 /* retransTimer */);
+        icmpv6Header.writeToByteBuffer(ra);
+        raHeader.writeToByteBuffer(ra);
     }
 
     private static void putSlla(ByteBuffer ra, byte[] slla) {
-        /**
-            Source/Target Link-layer Address
-
-             0                   1                   2                   3
-             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |     Type      |    Length     |    Link-Layer Address ...
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-        */
         if (slla == null || slla.length != 6) {
             // Only IEEE 802.3 6-byte addresses are supported.
             return;
         }
 
-        final byte nd_option_slla = 1;
-        final byte slla_num_8octets = 1;
-        ra.put(nd_option_slla)
-            .put(slla_num_8octets)
-            .put(slla);
+        final ByteBuffer sllaOption = LlaOption.build(asByte(ICMPV6_ND_OPTION_SLLA),
+                MacAddress.fromBytes(slla));
+        ra.put(sllaOption);
     }
 
     private static void putExpandedFlagsOption(ByteBuffer ra) {
@@ -439,70 +423,24 @@
     }
 
     private static void putMtu(ByteBuffer ra, int mtu) {
-        /**
-            MTU
-
-             0                   1                   2                   3
-             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |     Type      |    Length     |           Reserved            |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                              MTU                              |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-        */
-        final byte nd_option_mtu = 5;
-        final byte mtu_num_8octs = 1;
-        ra.put(nd_option_mtu)
-            .put(mtu_num_8octs)
-            .putShort(asShort(0))
-            .putInt((mtu < IPV6_MIN_MTU) ? IPV6_MIN_MTU : mtu);
+        final ByteBuffer mtuOption = MtuOption.build((mtu < IPV6_MIN_MTU) ? IPV6_MIN_MTU : mtu);
+        ra.put(mtuOption);
     }
 
     private static void putPio(ByteBuffer ra, IpPrefix ipp,
                                int validTime, int preferredTime) {
-        /**
-            Prefix Information
-
-             0                   1                   2                   3
-             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |     Type      |    Length     | Prefix Length |L|A| Reserved1 |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                         Valid Lifetime                        |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                       Preferred Lifetime                      |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                           Reserved2                           |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                                                               |
-            +                                                               +
-            |                                                               |
-            +                            Prefix                             +
-            |                                                               |
-            +                                                               +
-            |                                                               |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-        */
         final int prefixLength = ipp.getPrefixLength();
         if (prefixLength != 64) {
             return;
         }
-        final byte nd_option_pio = 3;
-        final byte pio_num_8octets = 4;
 
         if (validTime < 0) validTime = 0;
         if (preferredTime < 0) preferredTime = 0;
         if (preferredTime > validTime) preferredTime = validTime;
 
-        final byte[] addr = ipp.getAddress().getAddress();
-        ra.put(nd_option_pio)
-            .put(pio_num_8octets)
-            .put(asByte(prefixLength))
-            .put(asByte(0xc0)) /* L & A set */
-            .putInt(validTime)
-            .putInt(preferredTime)
-            .putInt(0)
-            .put(addr);
+        final ByteBuffer pioOption = PrefixInformationOption.build(ipp,
+                asByte(PIO_FLAG_ON_LINK | PIO_FLAG_AUTONOMOUS), validTime, preferredTime);
+        ra.put(pioOption);
     }
 
     private static void putRio(ByteBuffer ra, IpPrefix ipp) {
@@ -543,22 +481,6 @@
     }
 
     private static void putRdnss(ByteBuffer ra, Set<Inet6Address> dnses, int lifetime) {
-        /**
-            Recursive DNS Server (RDNSS) Option
-
-             0                   1                   2                   3
-             0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |     Type      |     Length    |           Reserved            |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                           Lifetime                            |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-            |                                                               |
-            :            Addresses of IPv6 Recursive DNS Servers            :
-            |                                                               |
-            +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
-         */
-
         final HashSet<Inet6Address> filteredDnses = new HashSet<>();
         for (Inet6Address dns : dnses) {
             if ((new LinkAddress(dns, RFC7421_PREFIX_LENGTH)).isGlobalPreferred()) {
@@ -567,29 +489,22 @@
         }
         if (filteredDnses.isEmpty()) return;
 
-        final byte nd_option_rdnss = 25;
-        final byte rdnss_num_8octets = asByte(dnses.size() * 2 + 1);
-        ra.put(nd_option_rdnss)
-            .put(rdnss_num_8octets)
-            .putShort(asShort(0))
-            .putInt(lifetime);
-
-        for (Inet6Address dns : filteredDnses) {
-            // NOTE: If the full of list DNS servers doesn't fit in the packet,
-            // this code will cause a buffer overflow and the RA won't include
-            // this instance of the option at all.
-            //
-            // TODO: Consider looking at ra.remaining() to determine how many
-            // DNS servers will fit, and adding only those.
-            ra.put(dns.getAddress());
-        }
+        final Inet6Address[] dnsesArray =
+                filteredDnses.toArray(new Inet6Address[filteredDnses.size()]);
+        final ByteBuffer rdnssOption = RdnssOption.build(lifetime, dnsesArray);
+        // NOTE: If the full list of DNS servers doesn't fit in the packet,
+        // ra.put() throws BufferOverflowException and the RA won't include
+        // this instance of the option at all.
+        //
+        // TODO: Consider looking at ra.remaining() to determine how many
+        // DNS servers will fit, and adding only those.
+        ra.put(rdnssOption);
     }
 
     private boolean createSocket() {
         final int send_timout_ms = 300;
 
-        final int oldTag = TrafficStats.getAndSetThreadStatsTag(
-                TrafficStatsConstants.TAG_SYSTEM_NEIGHBOR);
+        final int oldTag = TrafficStats.getAndSetThreadStatsTag(TAG_SYSTEM_NEIGHBOR);
         try {
             mSocket = Os.socket(AF_INET6, SOCK_RAW, IPPROTO_ICMPV6);
             // Setting SNDTIMEO is purely for defensive purposes.
@@ -639,7 +554,7 @@
 
         try {
             synchronized (mLock) {
-                if (mRaLength < MIN_RA_HEADER_SIZE) {
+                if (mRaLength < ICMPV6_RA_HEADER_LEN) {
                     // No actual RA to send.
                     return;
                 }
@@ -668,7 +583,7 @@
                     final int rval = Os.recvfrom(
                             mSocket, mSolicitation, 0, mSolicitation.length, 0, mSolicitor);
                     // Do the least possible amount of validation.
-                    if (rval < 1 || mSolicitation[0] != ICMPV6_ND_ROUTER_SOLICIT) {
+                    if (rval < 1 || mSolicitation[0] != asByte(ICMPV6_ROUTER_SOLICITATION)) {
                         continue;
                     }
                 } catch (ErrnoException | SocketException e) {
@@ -721,7 +636,7 @@
         private int getNextMulticastTransmitDelaySec() {
             boolean deprecationInProgress = false;
             synchronized (mLock) {
-                if (mRaLength < MIN_RA_HEADER_SIZE) {
+                if (mRaLength < ICMPV6_RA_HEADER_LEN) {
                     // No actual RA to send; just sleep for 1 day.
                     return DAY_IN_SECONDS;
                 }
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 74eb87b..add4f37 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -137,6 +137,8 @@
     private final BpfTetherStatsProvider mStatsProvider;
     @NonNull
     private final BpfCoordinatorShim mBpfCoordinatorShim;
+    @NonNull
+    private final BpfConntrackEventConsumer mBpfConntrackEventConsumer;
 
     // True if BPF offload is supported, false otherwise. The BPF offload could be disabled by
     // a runtime resource overlay package or device configuration. This flag is only initialized
@@ -248,6 +250,11 @@
             return new ConntrackMonitor(getHandler(), getSharedLog(), consumer);
         }
 
+        /** Get interface information for a given interface. */
+        @NonNull public InterfaceParams getInterfaceParams(String ifName) {
+            return InterfaceParams.getByName(ifName);
+        }
+
         /**
          * Check OS Build at least S.
          *
@@ -339,7 +346,14 @@
         mNetd = mDeps.getNetd();
         mLog = mDeps.getSharedLog().forSubComponent(TAG);
         mIsBpfEnabled = isBpfEnabled();
-        mConntrackMonitor = mDeps.getConntrackMonitor(new BpfConntrackEventConsumer());
+
+        // The conntrack consumer needs to be initialized in the BpfCoordinator constructor
+        // because BpfConntrackEventConsumer is a non-static inner class that accesses data
+        // members of BpfCoordinator. The consumer object is also needed for initializing the
+        // conntrack monitor, which may be mocked for testing.
+        mBpfConntrackEventConsumer = new BpfConntrackEventConsumer();
+        mConntrackMonitor = mDeps.getConntrackMonitor(mBpfConntrackEventConsumer);
+
         BpfTetherStatsProvider provider = new BpfTetherStatsProvider();
         try {
             mDeps.getNetworkStatsManager().registerNetworkStatsProvider(
@@ -478,25 +492,14 @@
         LinkedHashMap<Inet6Address, Ipv6ForwardingRule> rules = mIpv6ForwardingRules.get(ipServer);
 
         // When the first rule is added to an upstream, setup upstream forwarding and data limit.
-        final int upstreamIfindex = rule.upstreamIfindex;
-        if (!isAnyRuleOnUpstream(upstreamIfindex)) {
-            // If failed to set a data limit, probably should not use this upstream, because
-            // the upstream may not want to blow through the data limit that was told to apply.
-            // TODO: Perhaps stop the coordinator.
-            boolean success = updateDataLimit(upstreamIfindex);
-            if (!success) {
-                final String iface = mInterfaceNames.get(upstreamIfindex);
-                mLog.e("Setting data limit for " + iface + " failed.");
-            }
-
-        }
+        maybeSetLimit(rule.upstreamIfindex);
 
         if (!isAnyRuleFromDownstreamToUpstream(rule.downstreamIfindex, rule.upstreamIfindex)) {
             final int downstream = rule.downstreamIfindex;
             final int upstream = rule.upstreamIfindex;
             // TODO: support upstream forwarding on non-point-to-point interfaces.
             // TODO: get the MTU from LinkProperties and update the rules when it changes.
-            if (!mBpfCoordinatorShim.startUpstreamIpv6Forwarding(downstream, upstream,
+            if (!mBpfCoordinatorShim.startUpstreamIpv6Forwarding(downstream, upstream, rule.srcMac,
                     NULL_MAC_ADDRESS, NULL_MAC_ADDRESS, NetworkStackConstants.ETHER_MTU)) {
                 mLog.e("Failed to enable upstream IPv6 forwarding from "
                         + mInterfaceNames.get(downstream) + " to " + mInterfaceNames.get(upstream));
@@ -537,29 +540,15 @@
         if (!isAnyRuleFromDownstreamToUpstream(rule.downstreamIfindex, rule.upstreamIfindex)) {
             final int downstream = rule.downstreamIfindex;
             final int upstream = rule.upstreamIfindex;
-            if (!mBpfCoordinatorShim.stopUpstreamIpv6Forwarding(downstream, upstream)) {
+            if (!mBpfCoordinatorShim.stopUpstreamIpv6Forwarding(downstream, upstream,
+                    rule.srcMac)) {
                 mLog.e("Failed to disable upstream IPv6 forwarding from "
                         + mInterfaceNames.get(downstream) + " to " + mInterfaceNames.get(upstream));
             }
         }
 
         // Do cleanup functionality if there is no more rule on the given upstream.
-        final int upstreamIfindex = rule.upstreamIfindex;
-        if (!isAnyRuleOnUpstream(upstreamIfindex)) {
-            final TetherStatsValue statsValue =
-                    mBpfCoordinatorShim.tetherOffloadGetAndClearStats(upstreamIfindex);
-            if (statsValue == null) {
-                Log.wtf(TAG, "Fail to cleanup tether stats for upstream index " + upstreamIfindex);
-                return;
-            }
-
-            SparseArray<TetherStatsValue> tetherStatsList = new SparseArray<TetherStatsValue>();
-            tetherStatsList.put(upstreamIfindex, statsValue);
-
-            // Update the last stats delta and delete the local cache for a given upstream.
-            updateQuotaAndStatsFromSnapshot(tetherStatsList);
-            mStats.remove(upstreamIfindex);
-        }
+        maybeClearLimit(rule.upstreamIfindex);
     }
 
     /**
@@ -687,7 +676,7 @@
         if (lp == null || !lp.hasIpv4Address()) return;
 
         // Support raw ip upstream interface only.
-        final InterfaceParams params = InterfaceParams.getByName(lp.getInterfaceName());
+        final InterfaceParams params = mDeps.getInterfaceParams(lp.getInterfaceName());
         if (params == null || params.hasMacAddress) return;
 
         Collection<InetAddress> addresses = lp.getAddresses();
@@ -814,15 +803,15 @@
                 final int upstreamIfindex = rule.upstreamIfindex;
                 pw.println(String.format("%d(%s) %d(%s) %s %s %s", upstreamIfindex,
                         mInterfaceNames.get(upstreamIfindex), rule.downstreamIfindex,
-                        downstreamIface, rule.address, rule.srcMac, rule.dstMac));
+                        downstreamIface, rule.address.getHostAddress(), rule.srcMac, rule.dstMac));
             }
             pw.decreaseIndent();
         }
     }
 
     private String ipv6UpstreamRuletoString(TetherUpstream6Key key, Tether6Value value) {
-        return String.format("%d(%s) -> %d(%s) %04x %s %s",
-                key.iif, getIfName(key.iif), value.oif, getIfName(value.oif),
+        return String.format("%d(%s) %s -> %d(%s) %04x %s %s",
+                key.iif, getIfName(key.iif), key.dstMac, value.oif, getIfName(value.oif),
                 value.ethProto, value.ethSrcMac, value.ethDstMac);
     }
 
@@ -851,9 +840,10 @@
         } catch (UnknownHostException impossible) {
             throw new AssertionError("4-byte array not valid IPv4 address!");
         }
-        return String.format("%d(%s) %d(%s) %s:%d -> %s:%d -> %s:%d",
-                key.iif, getIfName(key.iif), value.oif, getIfName(value.oif),
-                private4, key.srcPort, public4, value.srcPort, dst4, key.dstPort);
+        return String.format("[%s] %d(%s) %s:%d -> %d(%s) %s:%d -> %s:%d",
+                key.dstMac, key.iif, getIfName(key.iif), private4, key.srcPort,
+                value.oif, getIfName(value.oif),
+                public4, value.srcPort, dst4, key.dstPort);
     }
 
     private void dumpIpv4ForwardingRules(IndentingPrintWriter pw) {
@@ -866,7 +856,7 @@
                 pw.println("No IPv4 rules");
                 return;
             }
-            pw.println("[IPv4]: iif(iface) oif(iface) src nat dst");
+            pw.println("IPv4: [inDstMac] iif(iface) src -> oif(iface) nat -> dst");
             pw.increaseIndent();
             map.forEach((k, v) -> pw.println(ipv4RuleToString(k, v)));
         } catch (ErrnoException e) {
@@ -911,6 +901,62 @@
 
     /** IPv6 forwarding rule class. */
     public static class Ipv6ForwardingRule {
+        // The upstream6 and downstream6 rules are built as shown in the following tables. Only
+        // raw ip upstream interfaces are supported.
+        // TODO: support ether ip upstream interface.
+        //
+        // NAT network topology:
+        //
+        //         public network (rawip)                 private network
+        //                   |                 UE                |
+        // +------------+    V    +------------+------------+    V    +------------+
+        // |   Server   +---------+  Upstream  | Downstream +---------+   Client   |
+        // +------------+         +------------+------------+         +------------+
+        //
+        // upstream6 key and value:
+        //
+        // +------+-------------+
+        // | TetherUpstream6Key |
+        // +------+------+------+
+        // |field |iif   |dstMac|
+        // |      |      |      |
+        // +------+------+------+
+        // |value |downst|downst|
+        // |      |ream  |ream  |
+        // +------+------+------+
+        //
+        // +------+----------------------------------+
+        // |      |Tether6Value                      |
+        // +------+------+------+------+------+------+
+        // |field |oif   |ethDst|ethSrc|ethPro|pmtu  |
+        // |      |      |mac   |mac   |to    |      |
+        // +------+------+------+------+------+------+
+        // |value |upstre|--    |--    |ETH_P_|1500  |
+        // |      |am    |      |      |IP    |      |
+        // +------+------+------+------+------+------+
+        //
+        // downstream6 key and value:
+        //
+        // +------+--------------------+
+        // |      |TetherDownstream6Key|
+        // +------+------+------+------+
+        // |field |iif   |dstMac|neigh6|
+        // |      |      |      |      |
+        // +------+------+------+------+
+        // |value |upstre|--    |client|
+        // |      |am    |      |      |
+        // +------+------+------+------+
+        //
+        // +------+----------------------------------+
+        // |      |Tether6Value                      |
+        // +------+------+------+------+------+------+
+        // |field |oif   |ethDst|ethSrc|ethPro|pmtu  |
+        // |      |      |mac   |mac   |to    |      |
+        // +------+------+------+------+------+------+
+        // |value |downst|client|downst|ETH_P_|1500  |
+        // |      |ream  |      |ream  |IP    |      |
+        // +------+------+------+------+------+------+
+        //
         public final int upstreamIfindex;
         public final int downstreamIfindex;
 
@@ -960,7 +1006,8 @@
          */
         @NonNull
         public TetherDownstream6Key makeTetherDownstream6Key() {
-            return new TetherDownstream6Key(upstreamIfindex, address.getAddress());
+            return new TetherDownstream6Key(upstreamIfindex, NULL_MAC_ADDRESS,
+                    address.getAddress());
         }
 
         /**
@@ -1113,7 +1160,10 @@
 
     // Support raw ip only.
     // TODO: add ether ip support.
-    private class BpfConntrackEventConsumer implements ConntrackEventConsumer {
+    // TODO: parse CTA_PROTOINFO of conntrack event in ConntrackMonitor. For TCP, only add rules
+    // while TCP status is established.
+    @VisibleForTesting
+    class BpfConntrackEventConsumer implements ConntrackEventConsumer {
         @NonNull
         private Tether4Key makeTetherUpstream4Key(
                 @NonNull ConntrackEvent e, @NonNull ClientInfo c) {
@@ -1178,8 +1228,9 @@
 
             if (e.msgType == (NetlinkConstants.NFNL_SUBSYS_CTNETLINK << 8
                     | NetlinkConstants.IPCTNL_MSG_CT_DELETE)) {
-                mBpfCoordinatorShim.tetherOffloadRuleRemove(false, upstream4Key);
-                mBpfCoordinatorShim.tetherOffloadRuleRemove(true, downstream4Key);
+                mBpfCoordinatorShim.tetherOffloadRuleRemove(UPSTREAM, upstream4Key);
+                mBpfCoordinatorShim.tetherOffloadRuleRemove(DOWNSTREAM, downstream4Key);
+                maybeClearLimit(upstreamIndex);
                 return;
             }
 
@@ -1187,8 +1238,9 @@
             final Tether4Value downstream4Value = makeTetherDownstream4Value(e, tetherClient,
                     upstreamIndex);
 
-            mBpfCoordinatorShim.tetherOffloadRuleAdd(false, upstream4Key, upstream4Value);
-            mBpfCoordinatorShim.tetherOffloadRuleAdd(true, downstream4Key, downstream4Value);
+            maybeSetLimit(upstreamIndex);
+            mBpfCoordinatorShim.tetherOffloadRuleAdd(UPSTREAM, upstream4Key, upstream4Value);
+            mBpfCoordinatorShim.tetherOffloadRuleAdd(DOWNSTREAM, downstream4Key, downstream4Value);
         }
     }
 
@@ -1250,6 +1302,47 @@
         return sendDataLimitToBpfMap(ifIndex, quotaBytes);
     }
 
+    private void maybeSetLimit(int upstreamIfindex) {
+        if (isAnyRuleOnUpstream(upstreamIfindex)
+                || mBpfCoordinatorShim.isAnyIpv4RuleOnUpstream(upstreamIfindex)) {
+            return;
+        }
+
+        // If setting a data limit fails, we probably should not use this upstream, because
+        // the upstream may not want to blow through the data limit that was told to apply.
+        // TODO: Perhaps stop the coordinator.
+        boolean success = updateDataLimit(upstreamIfindex);
+        if (!success) {
+            final String iface = mInterfaceNames.get(upstreamIfindex);
+            mLog.e("Setting data limit for " + iface + " failed.");
+        }
+    }
+
+    // TODO: This should also be called when IpServer wants to clear all IPv4 rules. Relying on
+    // conntrack events can't cover this case.
+    private void maybeClearLimit(int upstreamIfindex) {
+        if (isAnyRuleOnUpstream(upstreamIfindex)
+                || mBpfCoordinatorShim.isAnyIpv4RuleOnUpstream(upstreamIfindex)) {
+            return;
+        }
+
+        final TetherStatsValue statsValue =
+                mBpfCoordinatorShim.tetherOffloadGetAndClearStats(upstreamIfindex);
+        if (statsValue == null) {
+            Log.wtf(TAG, "Fail to cleanup tether stats for upstream index " + upstreamIfindex);
+            return;
+        }
+
+        SparseArray<TetherStatsValue> tetherStatsList = new SparseArray<TetherStatsValue>();
+        tetherStatsList.put(upstreamIfindex, statsValue);
+
+        // Update the last stats delta and delete the local cache for a given upstream.
+        updateQuotaAndStatsFromSnapshot(tetherStatsList);
+        mStats.remove(upstreamIfindex);
+    }
+
+    // TODO: Rename to isAnyIpv6RuleOnUpstream and define an isAnyRuleOnUpstream method that calls
+    // both isAnyIpv6RuleOnUpstream and mBpfCoordinatorShim.isAnyIpv4RuleOnUpstream.
     private boolean isAnyRuleOnUpstream(int upstreamIfindex) {
         for (LinkedHashMap<Inet6Address, Ipv6ForwardingRule> rules : mIpv6ForwardingRules
                 .values()) {
@@ -1419,5 +1512,13 @@
         return mInterfaceNames;
     }
 
+    // Returns the BPF conntrack event consumer. This is used for testing only.
+    // Note that this can only be called on the handler thread.
+    @NonNull
+    @VisibleForTesting
+    final BpfConntrackEventConsumer getBpfConntrackEventConsumerForTesting() {
+        return mBpfConntrackEventConsumer;
+    }
+
     private static native String[] getBpfCounterNames();
 }
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfMap.java b/Tethering/src/com/android/networkstack/tethering/BpfMap.java
index e9b4ccf..1363dc5 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfMap.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfMap.java
@@ -98,6 +98,7 @@
 
     /**
      * Update an existing or create a new key -> value entry in an eBbpf map.
+     * (use insertOrReplaceEntry() if you need to know whether insert or replace happened)
      */
     public void updateEntry(K key, V value) throws ErrnoException {
         writeToMapEntry(mMapFd, key.writeToBytes(), value.writeToBytes(), BPF_ANY);
@@ -133,6 +134,35 @@
         }
     }
 
+    /**
+     * Update an existing or create a new key -> value entry in an eBPF map.
+     * Returns true if inserted, false if replaced.
+     * (use updateEntry() if you don't care whether insert or replace happened)
+     * Note: see the inline comment below if this may run concurrently with deletes.
+     */
+    public boolean insertOrReplaceEntry(K key, V value)
+            throws ErrnoException {
+        try {
+            writeToMapEntry(mMapFd, key.writeToBytes(), value.writeToBytes(), BPF_NOEXIST);
+            return true;   /* insert succeeded */
+        } catch (ErrnoException e) {
+            if (e.errno != EEXIST) throw e;
+        }
+        try {
+            writeToMapEntry(mMapFd, key.writeToBytes(), value.writeToBytes(), BPF_EXIST);
+            return false;   /* replace succeeded */
+        } catch (ErrnoException e) {
+            if (e.errno != ENOENT) throw e;
+        }
+        /* If we reach here somebody deleted after our insert attempt and before our replace:
+         * this implies a race happened.  The kernel bpf delete interface only takes a key,
+         * and not the value, so we can safely pretend the replace actually succeeded and
+         * was immediately followed by the other thread's delete, since the delete cannot
+         * observe the potential change to the value.
+         */
+        return false;   /* pretend replace succeeded */
+    }
+
     /** Remove existing key from eBpf map. Return false if map was not modified. */
     public boolean deleteEntry(K key) throws ErrnoException {
         return deleteMapEntry(mMapFd, key.writeToBytes());
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfUtils.java b/Tethering/src/com/android/networkstack/tethering/BpfUtils.java
index 289452c..0b44249 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfUtils.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfUtils.java
@@ -56,7 +56,7 @@
     // Sync from system/netd/server/OffloadUtils.h.
     static final short PRIO_TETHER6 = 1;
     static final short PRIO_TETHER4 = 2;
-    static final short PRIO_CLAT = 3;
+    // note that the above must be lower than PRIO_CLAT from netd's OffloadUtils.cpp
 
     private static String makeProgPath(boolean downstream, int ipVersion, boolean ether) {
         String path = "/sys/fs/bpf/tethering/prog_offload_schedcls_tether_"
diff --git a/Tethering/src/com/android/networkstack/tethering/TetherDownstream6Key.java b/Tethering/src/com/android/networkstack/tethering/TetherDownstream6Key.java
index 3860cba..a08ad4a 100644
--- a/Tethering/src/com/android/networkstack/tethering/TetherDownstream6Key.java
+++ b/Tethering/src/com/android/networkstack/tethering/TetherDownstream6Key.java
@@ -16,6 +16,10 @@
 
 package com.android.networkstack.tethering;
 
+import android.net.MacAddress;
+
+import androidx.annotation.NonNull;
+
 import com.android.net.module.util.Struct;
 import com.android.net.module.util.Struct.Field;
 import com.android.net.module.util.Struct.Type;
@@ -24,16 +28,23 @@
 import java.net.InetAddress;
 import java.net.UnknownHostException;
 import java.util.Arrays;
+import java.util.Objects;
 
 /** The key of BpfMap which is used for bpf offload. */
 public class TetherDownstream6Key extends Struct {
     @Field(order = 0, type = Type.U32)
     public final long iif; // The input interface index.
 
-    @Field(order = 1, type = Type.ByteArray, arraysize = 16)
+    @Field(order = 1, type = Type.EUI48, padding = 2)
+    public final MacAddress dstMac; // Destination ethernet mac address (zeroed iff rawip ingress).
+
+    @Field(order = 2, type = Type.ByteArray, arraysize = 16)
     public final byte[] neigh6; // The destination IPv6 address.
 
-    public TetherDownstream6Key(final long iif, final byte[] neigh6) {
+    public TetherDownstream6Key(final long iif, @NonNull final MacAddress dstMac,
+            final byte[] neigh6) {
+        Objects.requireNonNull(dstMac);
+
         try {
             final Inet6Address unused = (Inet6Address) InetAddress.getByAddress(neigh6);
         } catch (ClassCastException | UnknownHostException e) {
@@ -41,29 +52,15 @@
                     + Arrays.toString(neigh6));
         }
         this.iif = iif;
+        this.dstMac = dstMac;
         this.neigh6 = neigh6;
     }
 
     @Override
-    public boolean equals(Object obj) {
-        if (this == obj) return true;
-
-        if (!(obj instanceof TetherDownstream6Key)) return false;
-
-        final TetherDownstream6Key that = (TetherDownstream6Key) obj;
-
-        return iif == that.iif && Arrays.equals(neigh6, that.neigh6);
-    }
-
-    @Override
-    public int hashCode() {
-        return Long.hashCode(iif) ^ Arrays.hashCode(neigh6);
-    }
-
-    @Override
     public String toString() {
         try {
-            return String.format("iif: %d, neigh: %s", iif, Inet6Address.getByAddress(neigh6));
+            return String.format("iif: %d, dstMac: %s, neigh: %s", iif, dstMac,
+                    Inet6Address.getByAddress(neigh6));
         } catch (UnknownHostException e) {
             // Should not happen because construtor already verify neigh6.
             throw new IllegalStateException("Invalid TetherDownstream6Key");
diff --git a/Tethering/src/com/android/networkstack/tethering/TetherUpstream6Key.java b/Tethering/src/com/android/networkstack/tethering/TetherUpstream6Key.java
index c736f2a..5893885 100644
--- a/Tethering/src/com/android/networkstack/tethering/TetherUpstream6Key.java
+++ b/Tethering/src/com/android/networkstack/tethering/TetherUpstream6Key.java
@@ -16,14 +16,26 @@
 
 package com.android.networkstack.tethering;
 
+import android.net.MacAddress;
+
+import androidx.annotation.NonNull;
+
 import com.android.net.module.util.Struct;
 
+import java.util.Objects;
+
 /** Key type for upstream IPv6 forwarding map. */
 public class TetherUpstream6Key extends Struct {
     @Field(order = 0, type = Type.S32)
     public final int iif; // The input interface index.
 
-    public TetherUpstream6Key(int iif) {
+    @Field(order = 1, type = Type.EUI48, padding = 2)
+    public final MacAddress dstMac; // Destination ethernet mac address (zeroed iff rawip ingress).
+
+    public TetherUpstream6Key(int iif, @NonNull final MacAddress dstMac) {
+        Objects.requireNonNull(dstMac);
+
         this.iif = iif;
+        this.dstMac = dstMac;
     }
 }
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index ac5857d..f795747 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -442,7 +442,8 @@
     // NOTE: This is always invoked on the mLooper thread.
     private void updateConfiguration() {
         mConfig = mDeps.generateTetheringConfiguration(mContext, mLog, mActiveDataSubId);
-        mUpstreamNetworkMonitor.updateMobileRequiresDun(mConfig.isDunRequired);
+        mUpstreamNetworkMonitor.setUpstreamConfig(mConfig.chooseUpstreamAutomatically,
+                mConfig.isDunRequired);
         reportConfigurationChanged(mConfig.toStableParcelable());
     }
 
@@ -1559,7 +1560,7 @@
                             config.preferredUpstreamIfaceTypes);
             if (ns == null) {
                 if (tryCell) {
-                    mUpstreamNetworkMonitor.registerMobileNetworkRequest();
+                    mUpstreamNetworkMonitor.setTryCell(true);
                     // We think mobile should be coming up; don't set a retry.
                 } else {
                     sendMessageDelayed(CMD_RETRY_UPSTREAM, UPSTREAM_SETTLE_TIME_MS);
@@ -1718,6 +1719,12 @@
                     break;
             }
 
+            if (mConfig.chooseUpstreamAutomatically
+                    && arg1 == UpstreamNetworkMonitor.EVENT_DEFAULT_SWITCHED) {
+                chooseUpstreamType(true);
+                return;
+            }
+
             if (ns == null || !pertainsToCurrentUpstream(ns)) {
                 // TODO: In future, this is where upstream evaluation and selection
                 // could be handled for notifications which include sufficient data.
@@ -1852,7 +1859,7 @@
                         // longer desired, release any mobile requests.
                         final boolean previousUpstreamWanted = updateUpstreamWanted();
                         if (previousUpstreamWanted && !mUpstreamWanted) {
-                            mUpstreamNetworkMonitor.releaseMobileNetworkRequest();
+                            mUpstreamNetworkMonitor.setTryCell(false);
                         }
                         break;
                     }
diff --git a/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java b/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
index b17065c..f9af777 100644
--- a/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
+++ b/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
@@ -42,11 +42,15 @@
 import android.util.Log;
 import android.util.SparseIntArray;
 
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.StateMachine;
 
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Objects;
 import java.util.Set;
 
 
@@ -60,7 +64,7 @@
  * Calling #startObserveAllNetworks() to observe all networks. Listening all
  * networks is necessary while the expression of preferred upstreams remains
  * a list of legacy connectivity types.  In future, this can be revisited.
- * Calling #registerMobileNetworkRequest() to bring up mobile DUN/HIPRI network.
+ * Calling #setTryCell() to request bringing up mobile DUN or HIPRI.
  *
  * The methods and data members of this class are only to be accessed and
  * modified from the tethering main state machine thread. Any other
@@ -82,6 +86,7 @@
     public static final int EVENT_ON_CAPABILITIES   = 1;
     public static final int EVENT_ON_LINKPROPERTIES = 2;
     public static final int EVENT_ON_LOST           = 3;
+    public static final int EVENT_DEFAULT_SWITCHED  = 4;
     public static final int NOTIFY_LOCAL_PREFIXES   = 10;
     // This value is used by deprecated preferredUpstreamIfaceTypes selection which is default
     // disabled.
@@ -114,7 +119,14 @@
     private NetworkCallback mListenAllCallback;
     private NetworkCallback mDefaultNetworkCallback;
     private NetworkCallback mMobileNetworkCallback;
+
+    /** Whether Tethering has requested a cellular upstream. */
+    private boolean mTryCell;
+    /** Whether the carrier requires DUN. */
     private boolean mDunRequired;
+    /** Whether automatic upstream selection is enabled. */
+    private boolean mAutoUpstream;
+
     // Whether the current default upstream is mobile or not.
     private boolean mIsDefaultCellularUpstream;
     // The current system default network (not really used yet).
@@ -190,23 +202,49 @@
         mNetworkMap.clear();
     }
 
-    /** Setup or teardown DUN connection according to |dunRequired|. */
-    public void updateMobileRequiresDun(boolean dunRequired) {
-        final boolean valueChanged = (mDunRequired != dunRequired);
+    private void reevaluateUpstreamRequirements(boolean tryCell, boolean autoUpstream,
+            boolean dunRequired) {
+        final boolean mobileRequestRequired = tryCell && (dunRequired || !autoUpstream);
+        final boolean dunRequiredChanged = (mDunRequired != dunRequired);
+
+        mTryCell = tryCell;
         mDunRequired = dunRequired;
-        if (valueChanged && mobileNetworkRequested()) {
-            releaseMobileNetworkRequest();
+        mAutoUpstream = autoUpstream;
+
+        if (mobileRequestRequired && !mobileNetworkRequested()) {
             registerMobileNetworkRequest();
+        } else if (mobileNetworkRequested() && !mobileRequestRequired) {
+            releaseMobileNetworkRequest();
+        } else if (mobileNetworkRequested() && dunRequiredChanged) {
+            releaseMobileNetworkRequest();
+            if (mobileRequestRequired) {
+                registerMobileNetworkRequest();
+            }
         }
     }
 
+    /**
+     * Informs UpstreamNetworkMonitor that a cellular upstream is desired.
+     *
+     * This may result in filing a NetworkRequest for DUN if it is required, or for MOBILE_HIPRI if
+     * automatic upstream selection is disabled and MOBILE_HIPRI is the preferred upstream.
+     */
+    public void setTryCell(boolean tryCell) {
+        reevaluateUpstreamRequirements(tryCell, mAutoUpstream, mDunRequired);
+    }
+
+    /** Informs UpstreamNetworkMonitor of upstream configuration parameters. */
+    public void setUpstreamConfig(boolean autoUpstream, boolean dunRequired) {
+        reevaluateUpstreamRequirements(mTryCell, autoUpstream, dunRequired);
+    }
+
     /** Whether mobile network is requested. */
     public boolean mobileNetworkRequested() {
         return (mMobileNetworkCallback != null);
     }
 
     /** Request mobile network if mobile upstream is permitted. */
-    public void registerMobileNetworkRequest() {
+    private void registerMobileNetworkRequest() {
         if (!isCellularUpstreamPermitted()) {
             mLog.i("registerMobileNetworkRequest() is not permitted");
             releaseMobileNetworkRequest();
@@ -241,14 +279,16 @@
         // TODO: Change the timeout from 0 (no onUnavailable callback) to some
         // moderate callback timeout. This might be useful for updating some UI.
         // Additionally, we log a message to aid in any subsequent debugging.
-        mLog.i("requesting mobile upstream network: " + mobileUpstreamRequest);
+        mLog.i("requesting mobile upstream network: " + mobileUpstreamRequest
+                + " mTryCell=" + mTryCell + " mAutoUpstream=" + mAutoUpstream
+                + " mDunRequired=" + mDunRequired);
 
         cm().requestNetwork(mobileUpstreamRequest, 0, legacyType, mHandler,
                 mMobileNetworkCallback);
     }
 
     /** Release mobile network request. */
-    public void releaseMobileNetworkRequest() {
+    private void releaseMobileNetworkRequest() {
         if (mMobileNetworkCallback == null) return;
 
         cm().unregisterNetworkCallback(mMobileNetworkCallback);
@@ -363,13 +403,20 @@
         notifyTarget(EVENT_ON_CAPABILITIES, network);
     }
 
-    private void handleLinkProp(Network network, LinkProperties newLp) {
+    private @Nullable UpstreamNetworkState updateLinkProperties(@NonNull Network network,
+            LinkProperties newLp) {
         final UpstreamNetworkState prev = mNetworkMap.get(network);
         if (prev == null || newLp.equals(prev.linkProperties)) {
             // Ignore notifications about networks for which we have not yet
             // received onAvailable() (should never happen) and any duplicate
             // notifications (e.g. matching more than one of our callbacks).
-            return;
+            //
+            // Also, it can happen that onLinkPropertiesChanged is called after
+            // onLost removed the state from mNetworkMap. This appears to be due
+            // to a bug in disconnectAndDestroyNetwork, which calls
+            // nai.clatd.update() after the onLost callbacks.
+            // TODO: fix the bug and make this method void.
+            return null;
         }
 
         if (VDBG) {
@@ -377,11 +424,17 @@
                     network, newLp));
         }
 
-        mNetworkMap.put(network, new UpstreamNetworkState(
-                newLp, prev.networkCapabilities, network));
-        // TODO: If sufficient information is available to select a more
-        // preferable upstream, do so now and notify the target.
-        notifyTarget(EVENT_ON_LINKPROPERTIES, network);
+        final UpstreamNetworkState ns = new UpstreamNetworkState(newLp, prev.networkCapabilities,
+                network);
+        mNetworkMap.put(network, ns);
+        return ns;
+    }
+
+    private void handleLinkProp(Network network, LinkProperties newLp) {
+        final UpstreamNetworkState ns = updateLinkProperties(network, newLp);
+        if (ns != null) {
+            notifyTarget(EVENT_ON_LINKPROPERTIES, ns);
+        }
     }
 
     private void handleLost(Network network) {
@@ -410,6 +463,24 @@
         notifyTarget(EVENT_ON_LOST, mNetworkMap.remove(network));
     }
 
+    private void maybeHandleNetworkSwitch(@NonNull Network network) {
+        if (Objects.equals(mDefaultInternetNetwork, network)) return;
+
+        final UpstreamNetworkState ns = mNetworkMap.get(network);
+        if (ns == null) {
+            // Can never happen unless there is a bug in ConnectivityService. Entries are only
+            // removed from mNetworkMap when receiving onLost, and onLost for a given network can
+            // never be followed by any other callback on that network.
+            Log.wtf(TAG, "maybeHandleNetworkSwitch: no UpstreamNetworkState for " + network);
+            return;
+        }
+
+        // Default network changed. Update local data and notify tethering.
+        Log.d(TAG, "New default Internet network: " + network);
+        mDefaultInternetNetwork = network;
+        notifyTarget(EVENT_DEFAULT_SWITCHED, ns);
+    }
+
     private void recomputeLocalPrefixes() {
         final HashSet<IpPrefix> localPrefixes = allLocalPrefixes(mNetworkMap.values());
         if (!mLocalPrefixes.equals(localPrefixes)) {
@@ -447,7 +518,22 @@
         @Override
         public void onCapabilitiesChanged(Network network, NetworkCapabilities newNc) {
             if (mCallbackType == CALLBACK_DEFAULT_INTERNET) {
-                mDefaultInternetNetwork = network;
+                // mDefaultInternetNetwork is not updated here because upstream selection must only
+                // run when the LinkProperties have been updated as well as the capabilities. If
+                // this callback is due to a default network switch, then the system will invoke
+                // onLinkPropertiesChanged right after this method and mDefaultInternetNetwork will
+                // be updated then.
+                //
+                // Technically, deferring the update isn't necessary, because the notifications to
+                // Tethering sent by notifyTarget are messages sent to a state machine running on
+                // the same thread as this method, and so cannot arrive until after this method has
+                // returned. However, it is not a good idea to rely on that, because the fact that
+                // Tethering uses multiple state machines running on the same thread is a major
+                // source of race conditions and something that should be fixed.
+                //
+                // TODO: is it correct that this code always updates EntitlementManager?
+                // This code runs when the default network connects or changes capabilities, but the
+                // default network might not be the tethering upstream.
                 final boolean newIsCellular = isCellular(newNc);
                 if (mIsDefaultCellularUpstream != newIsCellular) {
                     mIsDefaultCellularUpstream = newIsCellular;
@@ -461,7 +547,15 @@
 
         @Override
         public void onLinkPropertiesChanged(Network network, LinkProperties newLp) {
-            if (mCallbackType == CALLBACK_DEFAULT_INTERNET) return;
+            if (mCallbackType == CALLBACK_DEFAULT_INTERNET) {
+                updateLinkProperties(network, newLp);
+                // When the default network callback calls onLinkPropertiesChanged, it means that
+                // all the network information for the default network is known (because
+                // onLinkPropertiesChanged is called after onAvailable and onCapabilitiesChanged).
+                // Inform tethering that the default network might have changed.
+                maybeHandleNetworkSwitch(network);
+                return;
+            }
 
             handleLinkProp(network, newLp);
             // Any non-LISTEN_ALL callback will necessarily concern a network that will
@@ -478,6 +572,8 @@
                 mDefaultInternetNetwork = null;
                 mIsDefaultCellularUpstream = false;
                 mEntitlementMgr.notifyUpstream(false);
+                Log.d(TAG, "Lost default Internet network: " + network);
+                notifyTarget(EVENT_DEFAULT_SWITCHED, null);
                 return;
             }
 
diff --git a/Tethering/tests/privileged/src/android/net/ip/DadProxyTest.java b/Tethering/tests/privileged/src/android/net/ip/DadProxyTest.java
index 42a91aa..a933e1b 100644
--- a/Tethering/tests/privileged/src/android/net/ip/DadProxyTest.java
+++ b/Tethering/tests/privileged/src/android/net/ip/DadProxyTest.java
@@ -21,9 +21,8 @@
 import static com.android.net.module.util.IpUtils.icmpv6Checksum;
 import static com.android.net.module.util.NetworkStackConstants.ETHER_SRC_ADDR_OFFSET;
 
-import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
 
 import android.app.Instrumentation;
 import android.content.Context;
@@ -52,13 +51,14 @@
 import org.junit.runner.RunWith;
 import org.mockito.MockitoAnnotations;
 
+import java.io.IOException;
 import java.nio.ByteBuffer;
 
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class DadProxyTest {
     private static final int DATA_BUFFER_LEN = 4096;
-    private static final int PACKET_TIMEOUT_MS = 5_000;
+    private static final int PACKET_TIMEOUT_MS = 2_000;  // Long enough for DAD to succeed.
 
     // Start the readers manually on a common handler shared with DadProxy, for simplicity
     @Rule
@@ -119,16 +119,18 @@
         }
     }
 
-    private void setupTapInterfaces() {
+    private void setupTapInterfaces() throws Exception {
         // Create upstream test iface.
         mUpstreamReader.start(mHandler);
-        mUpstreamParams = InterfaceParams.getByName(mUpstreamReader.iface.getInterfaceName());
+        final String upstreamIface = mUpstreamReader.iface.getInterfaceName();
+        mUpstreamParams = InterfaceParams.getByName(upstreamIface);
         assertNotNull(mUpstreamParams);
         mUpstreamPacketReader = mUpstreamReader.getReader();
 
         // Create tethered test iface.
         mTetheredReader.start(mHandler);
-        mTetheredParams = InterfaceParams.getByName(mTetheredReader.getIface().getInterfaceName());
+        final String tetheredIface = mTetheredReader.getIface().getInterfaceName();
+        mTetheredParams = InterfaceParams.getByName(tetheredIface);
         assertNotNull(mTetheredParams);
         mTetheredPacketReader = mTetheredReader.getReader();
     }
@@ -224,6 +226,12 @@
         return false;
     }
 
+    private ByteBuffer copy(ByteBuffer buf) {
+        // There does not seem to be a way to copy ByteBuffers. ByteBuffer does not implement
+        // clone() and duplicate() copies the metadata but shares the contents.
+        return ByteBuffer.wrap(buf.array().clone());
+    }
+
     private void updateDstMac(ByteBuffer buf, MacAddress mac) {
         buf.put(mac.toByteArray());
         buf.rewind();
@@ -234,14 +242,50 @@
         buf.rewind();
     }
 
+    private void receivePacketAndMaybeExpectForwarded(boolean expectForwarded,
+            ByteBuffer in, TapPacketReader inReader, ByteBuffer out, TapPacketReader outReader)
+            throws IOException {
+        inReader.sendResponse(in);
+        if (waitForPacket(out, outReader)) {
+            assertEquals("Unexpectedly received packet:", expectForwarded, true);
+            return;
+        }
+        // When the test runs, DAD may be in progress, because the interface has just been created.
+        // If so, the DAD proxy will get EADDRNOTAVAIL when trying to send packets. It is not
+        // possible to work around this using IPV6_FREEBIND or IPV6_TRANSPARENT options because the
+        // kernel rawv6 code doesn't consider those options either when binding or when sending, and
+        // doesn't get the source address from the packet even in IPPROTO_RAW/HDRINCL mode (it only
+        // gets it from the socket or from cmsg).
+        //
+        // If DAD was in progress when the above was attempted, try again and check the result.
+        // Don't disable DAD in the test because if we did, the test would not notice if, for
+        // example, the DAD proxy code just crashed if it received EADDRNOTAVAIL.
+        final String msg = expectForwarded
+                ? "Did not receive expected packet even after waiting for DAD:"
+                : "Unexpectedly received packet:";
+        inReader.sendResponse(in);
+        assertEquals(msg, expectForwarded, waitForPacket(out, outReader));
+    }
+
+    private void receivePacketAndExpectForwarded(ByteBuffer in, TapPacketReader inReader,
+            ByteBuffer out, TapPacketReader outReader) throws IOException {
+        receivePacketAndMaybeExpectForwarded(true, in, inReader, out, outReader);
+    }
+
+    private void receivePacketAndExpectNotForwarded(ByteBuffer in, TapPacketReader inReader,
+            ByteBuffer out, TapPacketReader outReader) throws IOException {
+        receivePacketAndMaybeExpectForwarded(false, in, inReader, out, outReader);
+    }
+
     @Test
     public void testNaForwardingFromUpstreamToTether() throws Exception {
         ByteBuffer na = createDadPacket(NeighborPacketForwarder.ICMPV6_NEIGHBOR_ADVERTISEMENT);
 
-        mUpstreamPacketReader.sendResponse(na);
-        updateDstMac(na, MacAddress.fromString("33:33:00:00:00:01"));
-        updateSrcMac(na, mTetheredParams);
-        assertTrue(waitForPacket(na, mTetheredPacketReader));
+        ByteBuffer out = copy(na);
+        updateDstMac(out, MacAddress.fromString("33:33:00:00:00:01"));
+        updateSrcMac(out, mTetheredParams);
+
+        receivePacketAndExpectForwarded(na, mUpstreamPacketReader, out, mTetheredPacketReader);
     }
 
     @Test
@@ -249,19 +293,21 @@
     public void testNaForwardingFromTetherToUpstream() throws Exception {
         ByteBuffer na = createDadPacket(NeighborPacketForwarder.ICMPV6_NEIGHBOR_ADVERTISEMENT);
 
-        mTetheredPacketReader.sendResponse(na);
-        updateDstMac(na, MacAddress.fromString("33:33:00:00:00:01"));
-        updateSrcMac(na, mTetheredParams);
-        assertFalse(waitForPacket(na, mUpstreamPacketReader));
+        ByteBuffer out = copy(na);
+        updateDstMac(out, MacAddress.fromString("33:33:00:00:00:01"));
+        updateSrcMac(out, mTetheredParams);
+
+        receivePacketAndExpectNotForwarded(na, mTetheredPacketReader, out, mUpstreamPacketReader);
     }
 
     @Test
     public void testNsForwardingFromTetherToUpstream() throws Exception {
         ByteBuffer ns = createDadPacket(NeighborPacketForwarder.ICMPV6_NEIGHBOR_SOLICITATION);
 
-        mTetheredPacketReader.sendResponse(ns);
-        updateSrcMac(ns, mUpstreamParams);
-        assertTrue(waitForPacket(ns, mUpstreamPacketReader));
+        ByteBuffer out = copy(ns);
+        updateSrcMac(out, mUpstreamParams);
+
+        receivePacketAndExpectForwarded(ns, mTetheredPacketReader, out, mUpstreamPacketReader);
     }
 
     @Test
@@ -269,8 +315,9 @@
     public void testNsForwardingFromUpstreamToTether() throws Exception {
         ByteBuffer ns = createDadPacket(NeighborPacketForwarder.ICMPV6_NEIGHBOR_SOLICITATION);
 
-        mUpstreamPacketReader.sendResponse(ns);
+        ByteBuffer out = copy(ns);
-        updateSrcMac(ns, mUpstreamParams);
+        updateSrcMac(out, mUpstreamParams);
-        assertFalse(waitForPacket(ns, mTetheredPacketReader));
+
+        receivePacketAndExpectNotForwarded(ns, mUpstreamPacketReader, out, mTetheredPacketReader);
     }
 }
diff --git a/Tethering/tests/privileged/src/android/net/ip/RouterAdvertisementDaemonTest.java b/Tethering/tests/privileged/src/android/net/ip/RouterAdvertisementDaemonTest.java
index 14dae5c..1d94214 100644
--- a/Tethering/tests/privileged/src/android/net/ip/RouterAdvertisementDaemonTest.java
+++ b/Tethering/tests/privileged/src/android/net/ip/RouterAdvertisementDaemonTest.java
@@ -16,6 +16,8 @@
 
 package android.net.ip;
 
+import static android.net.RouteInfo.RTN_UNICAST;
+
 import static com.android.net.module.util.NetworkStackConstants.ETHER_HEADER_LEN;
 import static com.android.net.module.util.NetworkStackConstants.ETHER_TYPE_IPV6;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_MTU;
@@ -27,6 +29,8 @@
 import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_ALL_NODES_MULTICAST;
 import static com.android.net.module.util.NetworkStackConstants.IPV6_ADDR_LEN;
 import static com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN;
+import static com.android.net.module.util.NetworkStackConstants.PIO_FLAG_AUTONOMOUS;
+import static com.android.net.module.util.NetworkStackConstants.PIO_FLAG_ON_LINK;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
@@ -38,7 +42,9 @@
 import android.net.INetd;
 import android.net.IpPrefix;
 import android.net.MacAddress;
+import android.net.RouteInfo;
 import android.net.ip.RouterAdvertisementDaemon.RaParams;
+import android.net.shared.RouteUtils;
 import android.net.util.InterfaceParams;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -74,6 +80,7 @@
 import java.net.InetAddress;
 import java.nio.ByteBuffer;
 import java.util.HashSet;
+import java.util.List;
 
 @RunWith(AndroidJUnit4.class)
 @SmallTest
@@ -96,8 +103,6 @@
 
     @BeforeClass
     public static void setupOnce() {
-        System.loadLibrary("tetherutilsjni");
-
         final Instrumentation inst = InstrumentationRegistry.getInstrumentation();
         final IBinder netdIBinder =
                 (IBinder) inst.getContext().getSystemService(Context.NETD_SERVICE);
@@ -151,7 +156,7 @@
             mNewParams = newParams;
         }
 
-        public boolean isPacketMatched(final byte[] pkt) throws Exception {
+        public boolean isPacketMatched(final byte[] pkt, boolean multicast) throws Exception {
             if (pkt.length < (ETHER_HEADER_LEN + IPV6_HEADER_LEN + ICMPV6_RA_HEADER_LEN)) {
                 return false;
             }
@@ -172,6 +177,15 @@
             final Icmpv6Header icmpv6Hdr = Struct.parse(Icmpv6Header.class, buf);
             if (icmpv6Hdr.type != (short) ICMPV6_ROUTER_ADVERTISEMENT) return false;
 
+            // Check whether IPv6 destination address is multicast or unicast
+            if (multicast) {
+                assertEquals(ipv6Hdr.dstIp, IPV6_ADDR_ALL_NODES_MULTICAST);
+            } else {
+                // The unicast IPv6 destination address in RA can be either link-local or global
+                // IPv6 address. This test only expects link-local address.
+                assertTrue(ipv6Hdr.dstIp.isLinkLocalAddress());
+            }
+
             // Parse RA header
             final RaHeader raHdr = Struct.parse(RaHeader.class, buf);
             assertEquals(mNewParams.hopLimit, raHdr.hopLimit);
@@ -182,13 +196,15 @@
                 final int length = Byte.toUnsignedInt(buf.get());
                 switch (type) {
                     case ICMPV6_ND_OPTION_PIO:
+                        // The PIO length is always 4 (units of 8 octets, i.e. 32 bytes); see
+                        // RFC 4861 section 4.6.2.
                         assertEquals(4, length);
 
                         final ByteBuffer pioBuf = ByteBuffer.wrap(buf.array(), currentPos,
                                 Struct.getSize(PrefixInformationOption.class));
                         final PrefixInformationOption pio =
                                 Struct.parse(PrefixInformationOption.class, pioBuf);
-                        assertEquals((byte) 0xc0, pio.flags); // L & A set
+                        assertEquals((byte) (PIO_FLAG_ON_LINK | PIO_FLAG_AUTONOMOUS), pio.flags);
 
                         final InetAddress address = InetAddress.getByAddress(pio.prefix);
                         final IpPrefix prefix = new IpPrefix(address, pio.prefixLen);
@@ -199,7 +215,7 @@
                             assertEquals(0, pio.validLifetime);
                             assertEquals(0, pio.preferredLifetime);
                         } else {
-                            fail("Unepxected prefix: " + prefix);
+                            fail("Unexpected prefix: " + prefix);
                         }
 
                         // Move ByteBuffer position to the next option.
@@ -261,15 +277,24 @@
         return params;
     }
 
-    private boolean assertRaPacket(final TestRaPacket testRa)
-            throws Exception {
+    private boolean isRaPacket(final TestRaPacket testRa, boolean multicast) throws Exception {
         byte[] packet;
         while ((packet = mTetheredPacketReader.poll(PACKET_TIMEOUT_MS)) != null) {
-            if (testRa.isPacketMatched(packet)) return true;
+            if (testRa.isPacketMatched(packet, multicast)) {
+                return true;
+            }
         }
         return false;
     }
 
+    private void assertUnicastRaPacket(final TestRaPacket testRa) throws Exception {
+        assertTrue(isRaPacket(testRa, false /* multicast */));
+    }
+
+    private void assertMulticastRaPacket(final TestRaPacket testRa) throws Exception {
+        assertTrue(isRaPacket(testRa, true /* multicast */));
+    }
+
     private ByteBuffer createRsPacket(final String srcIp) throws Exception {
         final MacAddress dstMac = MacAddress.fromString("33:33:03:04:05:06");
         final MacAddress srcMac = mTetheredParams.macAddr;
@@ -284,22 +309,36 @@
         assertTrue(mRaDaemon.start());
         final RaParams params1 = createRaParams("2001:1122:3344::5566");
         mRaDaemon.buildNewRa(null, params1);
-        assertRaPacket(new TestRaPacket(null, params1));
+        assertMulticastRaPacket(new TestRaPacket(null, params1));
 
         final RaParams params2 = createRaParams("2006:3344:5566::7788");
         mRaDaemon.buildNewRa(params1, params2);
-        assertRaPacket(new TestRaPacket(params1, params2));
+        assertMulticastRaPacket(new TestRaPacket(params1, params2));
     }
 
     @Test
     public void testSolicitRouterAdvertisement() throws Exception {
+        // Enabling IPv6 forwarding is necessary: it makes the kernel process RS correctly and
+        // create the neighbor entry for peer's link-layer address and IPv6 address. Otherwise,
+        // when the device receives an RS with an IPv6 link-local source address, it has to
+        // perform address resolution before it can respond with the unicast RA.
+        sNetd.setProcSysNet(INetd.IPV6, INetd.CONF, mTetheredParams.name, "forwarding", "1");
+
         assertTrue(mRaDaemon.start());
         final RaParams params1 = createRaParams("2001:1122:3344::5566");
         mRaDaemon.buildNewRa(null, params1);
-        assertRaPacket(new TestRaPacket(null, params1));
+        assertMulticastRaPacket(new TestRaPacket(null, params1));
+
+        // Add a link-local route "fe80::/64 -> ::" to the local network. Otherwise, the device
+        // will fail to send the unicast RA with ENETUNREACH (there is no route to the peer's
+        // link-local address).
+        final String iface = mTetheredParams.name;
+        final RouteInfo linkLocalRoute =
+                new RouteInfo(new IpPrefix("fe80::/64"), null, iface, RTN_UNICAST);
+        RouteUtils.addRoutesToLocalNetwork(sNetd, iface, List.of(linkLocalRoute));
 
         final ByteBuffer rs = createRsPacket("fe80::1122:3344:5566:7788");
         mTetheredPacketReader.sendResponse(rs);
-        assertRaPacket(new TestRaPacket(null, params1));
+        assertUnicastRaPacket(new TestRaPacket(null, params1));
     }
 }
diff --git a/Tethering/tests/privileged/src/com/android/networkstack/tethering/BpfMapTest.java b/Tethering/tests/privileged/src/com/android/networkstack/tethering/BpfMapTest.java
index 62302c3..830729d 100644
--- a/Tethering/tests/privileged/src/com/android/networkstack/tethering/BpfMapTest.java
+++ b/Tethering/tests/privileged/src/com/android/networkstack/tethering/BpfMapTest.java
@@ -65,13 +65,13 @@
     @Before
     public void setUp() throws Exception {
         mTestData = new ArrayMap<>();
-        mTestData.put(createTetherDownstream6Key(101, "2001:db8::1"),
+        mTestData.put(createTetherDownstream6Key(101, "00:00:00:00:00:aa", "2001:db8::1"),
                 createTether6Value(11, "00:00:00:00:00:0a", "11:11:11:00:00:0b",
                 ETH_P_IPV6, 1280));
-        mTestData.put(createTetherDownstream6Key(102, "2001:db8::2"),
+        mTestData.put(createTetherDownstream6Key(102, "00:00:00:00:00:bb", "2001:db8::2"),
                 createTether6Value(22, "00:00:00:00:00:0c", "22:22:22:00:00:0d",
                 ETH_P_IPV6, 1400));
-        mTestData.put(createTetherDownstream6Key(103, "2001:db8::3"),
+        mTestData.put(createTetherDownstream6Key(103, "00:00:00:00:00:cc", "2001:db8::3"),
                 createTether6Value(33, "00:00:00:00:00:0e", "33:33:33:00:00:0f",
                 ETH_P_IPV6, 1500));
 
@@ -94,11 +94,12 @@
         assertTrue(mTestMap.isEmpty());
     }
 
-    private TetherDownstream6Key createTetherDownstream6Key(long iif, String address)
-            throws Exception {
+    private TetherDownstream6Key createTetherDownstream6Key(long iif, String mac,
+            String address) throws Exception {
+        final MacAddress dstMac = MacAddress.fromString(mac);
         final InetAddress ipv6Address = InetAddress.getByName(address);
 
-        return new TetherDownstream6Key(iif, ipv6Address.getAddress());
+        return new TetherDownstream6Key(iif, dstMac, ipv6Address.getAddress());
     }
 
     private Tether6Value createTether6Value(int oif, String src, String dst, int proto, int pmtu) {
@@ -164,7 +165,7 @@
     public void testGetNextKey() throws Exception {
         // [1] If the passed-in key is not found on empty map, return null.
         final TetherDownstream6Key nonexistentKey =
-                createTetherDownstream6Key(1234, "2001:db8::10");
+                createTetherDownstream6Key(1234, "00:00:00:00:00:01", "2001:db8::10");
         assertNull(mTestMap.getNextKey(nonexistentKey));
 
         // [2] If the passed-in key is null on empty map, throw NullPointerException.
@@ -209,7 +210,7 @@
     }
 
     @Test
-    public void testUpdateBpfMap() throws Exception {
+    public void testUpdateEntry() throws Exception {
         final TetherDownstream6Key key = mTestData.keyAt(0);
         final Tether6Value value = mTestData.valueAt(0);
         final Tether6Value value2 = mTestData.valueAt(1);
@@ -232,6 +233,29 @@
     }
 
     @Test
+    public void testInsertOrReplaceEntry() throws Exception {
+        final TetherDownstream6Key key = mTestData.keyAt(0);
+        final Tether6Value value = mTestData.valueAt(0);
+        final Tether6Value value2 = mTestData.valueAt(1);
+        assertFalse(mTestMap.deleteEntry(key));
+
+        // insertOrReplaceEntry will create an entry if it does not exist already.
+        assertTrue(mTestMap.insertOrReplaceEntry(key, value));
+        assertTrue(mTestMap.containsKey(key));
+        final Tether6Value result = mTestMap.getValue(key);
+        assertEquals(value, result);
+
+        // insertOrReplaceEntry will replace an entry that already exists.
+        assertFalse(mTestMap.insertOrReplaceEntry(key, value2));
+        assertTrue(mTestMap.containsKey(key));
+        final Tether6Value result2 = mTestMap.getValue(key);
+        assertEquals(value2, result2);
+
+        assertTrue(mTestMap.deleteEntry(key));
+        assertFalse(mTestMap.containsKey(key));
+    }
+
+    @Test
     public void testInsertReplaceEntry() throws Exception {
         final TetherDownstream6Key key = mTestData.keyAt(0);
         final Tether6Value value = mTestData.valueAt(0);
@@ -344,7 +368,8 @@
 
         // Build test data for TEST_MAP_SIZE + 1 entries.
         for (int i = 1; i <= TEST_MAP_SIZE + 1; i++) {
-            testData.put(createTetherDownstream6Key(i, "2001:db8::1"),
+            testData.put(
+                    createTetherDownstream6Key(i, "00:00:00:00:00:01", "2001:db8::1"),
                     createTether6Value(100, "de:ad:be:ef:00:01", "de:ad:be:ef:00:02",
                     ETH_P_IPV6, 1500));
         }
diff --git a/Tethering/tests/unit/src/android/net/ip/IpServerTest.java b/Tethering/tests/unit/src/android/net/ip/IpServerTest.java
index adf1f67..435cab5 100644
--- a/Tethering/tests/unit/src/android/net/ip/IpServerTest.java
+++ b/Tethering/tests/unit/src/android/net/ip/IpServerTest.java
@@ -474,6 +474,8 @@
         InOrder inOrder = inOrder(mNetd, mBpfCoordinator);
 
         // Add the forwarding pair <IFACE_NAME, UPSTREAM_IFACE>.
+        inOrder.verify(mBpfCoordinator).addUpstreamNameToLookupTable(UPSTREAM_IFINDEX,
+                UPSTREAM_IFACE);
         inOrder.verify(mBpfCoordinator).maybeAttachProgram(IFACE_NAME, UPSTREAM_IFACE);
         inOrder.verify(mNetd).tetherAddForward(IFACE_NAME, UPSTREAM_IFACE);
         inOrder.verify(mNetd).ipfwdAddInterfaceForward(IFACE_NAME, UPSTREAM_IFACE);
@@ -494,6 +496,8 @@
         inOrder.verify(mNetd).tetherRemoveForward(IFACE_NAME, UPSTREAM_IFACE);
 
         // Add the forwarding pair <IFACE_NAME, UPSTREAM_IFACE2>.
+        inOrder.verify(mBpfCoordinator).addUpstreamNameToLookupTable(UPSTREAM_IFINDEX2,
+                UPSTREAM_IFACE2);
         inOrder.verify(mBpfCoordinator).maybeAttachProgram(IFACE_NAME, UPSTREAM_IFACE2);
         inOrder.verify(mNetd).tetherAddForward(IFACE_NAME, UPSTREAM_IFACE2);
         inOrder.verify(mNetd).ipfwdAddInterfaceForward(IFACE_NAME, UPSTREAM_IFACE2);
@@ -517,6 +521,8 @@
 
         // Add the forwarding pair <IFACE_NAME, UPSTREAM_IFACE2> and expect that failed on
         // tetherAddForward.
+        inOrder.verify(mBpfCoordinator).addUpstreamNameToLookupTable(UPSTREAM_IFINDEX2,
+                UPSTREAM_IFACE2);
         inOrder.verify(mBpfCoordinator).maybeAttachProgram(IFACE_NAME, UPSTREAM_IFACE2);
         inOrder.verify(mNetd).tetherAddForward(IFACE_NAME, UPSTREAM_IFACE2);
 
@@ -543,6 +549,8 @@
 
         // Add the forwarding pair <IFACE_NAME, UPSTREAM_IFACE2> and expect that failed on
         // ipfwdAddInterfaceForward.
+        inOrder.verify(mBpfCoordinator).addUpstreamNameToLookupTable(UPSTREAM_IFINDEX2,
+                UPSTREAM_IFACE2);
         inOrder.verify(mBpfCoordinator).maybeAttachProgram(IFACE_NAME, UPSTREAM_IFACE2);
         inOrder.verify(mNetd).tetherAddForward(IFACE_NAME, UPSTREAM_IFACE2);
         inOrder.verify(mNetd).ipfwdAddInterfaceForward(IFACE_NAME, UPSTREAM_IFACE2);
@@ -830,8 +838,8 @@
 
     @NonNull
     private static TetherDownstream6Key makeDownstream6Key(int upstreamIfindex,
-            @NonNull final InetAddress dst) {
-        return new TetherDownstream6Key(upstreamIfindex, dst.getAddress());
+            @NonNull MacAddress upstreamMac, @NonNull final InetAddress dst) {
+        return new TetherDownstream6Key(upstreamIfindex, upstreamMac, dst.getAddress());
     }
 
     @NonNull
@@ -849,10 +857,12 @@
     }
 
     private void verifyTetherOffloadRuleAdd(@Nullable InOrder inOrder, int upstreamIfindex,
-            @NonNull final InetAddress dst, @NonNull final MacAddress dstMac) throws Exception {
+            @NonNull MacAddress upstreamMac, @NonNull final InetAddress dst,
+            @NonNull final MacAddress dstMac) throws Exception {
         if (mBpfDeps.isAtLeastS()) {
             verifyWithOrder(inOrder, mBpfDownstream6Map).updateEntry(
-                    makeDownstream6Key(upstreamIfindex, dst), makeDownstream6Value(dstMac));
+                    makeDownstream6Key(upstreamIfindex, upstreamMac, dst),
+                    makeDownstream6Value(dstMac));
         } else {
             verifyWithOrder(inOrder, mNetd).tetherOffloadRuleAdd(matches(upstreamIfindex, dst,
                     dstMac));
@@ -860,10 +870,11 @@
     }
 
     private void verifyNeverTetherOffloadRuleAdd(int upstreamIfindex,
-            @NonNull final InetAddress dst, @NonNull final MacAddress dstMac) throws Exception {
+            @NonNull MacAddress upstreamMac, @NonNull final InetAddress dst,
+            @NonNull final MacAddress dstMac) throws Exception {
         if (mBpfDeps.isAtLeastS()) {
             verify(mBpfDownstream6Map, never()).updateEntry(
-                    makeDownstream6Key(upstreamIfindex, dst),
+                    makeDownstream6Key(upstreamIfindex, upstreamMac, dst),
                     makeDownstream6Value(dstMac));
         } else {
             verify(mNetd, never()).tetherOffloadRuleAdd(matches(upstreamIfindex, dst, dstMac));
@@ -879,10 +890,11 @@
     }
 
     private void verifyTetherOffloadRuleRemove(@Nullable InOrder inOrder, int upstreamIfindex,
-            @NonNull final InetAddress dst, @NonNull final MacAddress dstMac) throws Exception {
+            @NonNull MacAddress upstreamMac, @NonNull final InetAddress dst,
+            @NonNull final MacAddress dstMac) throws Exception {
         if (mBpfDeps.isAtLeastS()) {
             verifyWithOrder(inOrder, mBpfDownstream6Map).deleteEntry(makeDownstream6Key(
-                    upstreamIfindex, dst));
+                    upstreamIfindex, upstreamMac, dst));
         } else {
             // |dstMac| is not required for deleting rules. Used bacause tetherOffloadRuleRemove
             // uses a whole rule to be a argument.
@@ -903,7 +915,8 @@
     private void verifyStartUpstreamIpv6Forwarding(@Nullable InOrder inOrder, int upstreamIfindex)
             throws Exception {
         if (!mBpfDeps.isAtLeastS()) return;
-        final TetherUpstream6Key key = new TetherUpstream6Key(TEST_IFACE_PARAMS.index);
+        final TetherUpstream6Key key = new TetherUpstream6Key(TEST_IFACE_PARAMS.index,
+                TEST_IFACE_PARAMS.macAddr);
         final Tether6Value value = new Tether6Value(upstreamIfindex,
                 MacAddress.ALL_ZEROS_ADDRESS, MacAddress.ALL_ZEROS_ADDRESS,
                 ETH_P_IPV6, NetworkStackConstants.ETHER_MTU);
@@ -913,7 +926,8 @@
     private void verifyStopUpstreamIpv6Forwarding(@Nullable InOrder inOrder)
             throws Exception {
         if (!mBpfDeps.isAtLeastS()) return;
-        final TetherUpstream6Key key = new TetherUpstream6Key(TEST_IFACE_PARAMS.index);
+        final TetherUpstream6Key key = new TetherUpstream6Key(TEST_IFACE_PARAMS.index,
+                TEST_IFACE_PARAMS.macAddr);
         verifyWithOrder(inOrder, mBpfUpstream6Map).deleteEntry(key);
     }
 
@@ -983,14 +997,16 @@
         recvNewNeigh(myIfindex, neighA, NUD_REACHABLE, macA);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighA, macA));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neighA, macA);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macA);
         verifyStartUpstreamIpv6Forwarding(null, UPSTREAM_IFINDEX);
         resetNetdBpfMapAndCoordinator();
 
         recvNewNeigh(myIfindex, neighB, NUD_REACHABLE, macB);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighB, macB));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         verifyNoUpstreamIpv6ForwardingChange(null);
         resetNetdBpfMapAndCoordinator();
 
@@ -1005,7 +1021,8 @@
         recvNewNeigh(myIfindex, neighA, NUD_FAILED, null);
         verify(mBpfCoordinator).tetherOffloadRuleRemove(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighA, macNull));
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neighA, macNull);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macNull);
         verifyNoUpstreamIpv6ForwardingChange(null);
         resetNetdBpfMapAndCoordinator();
 
@@ -1013,7 +1030,8 @@
         recvDelNeigh(myIfindex, neighB, NUD_STALE, macB);
         verify(mBpfCoordinator).tetherOffloadRuleRemove(
                 mIpServer,  makeForwardingRule(UPSTREAM_IFINDEX, neighB, macNull));
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neighB, macNull);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macNull);
         verifyStopUpstreamIpv6Forwarding(null);
         resetNetdBpfMapAndCoordinator();
 
@@ -1028,12 +1046,16 @@
         lp.setInterfaceName(UPSTREAM_IFACE2);
         dispatchTetherConnectionChanged(UPSTREAM_IFACE2, lp, -1);
         verify(mBpfCoordinator).tetherOffloadRuleUpdate(mIpServer, UPSTREAM_IFINDEX2);
-        verifyTetherOffloadRuleRemove(inOrder, UPSTREAM_IFINDEX, neighA, macA);
-        verifyTetherOffloadRuleRemove(inOrder, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleRemove(inOrder,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macA);
+        verifyTetherOffloadRuleRemove(inOrder,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         verifyStopUpstreamIpv6Forwarding(inOrder);
-        verifyTetherOffloadRuleAdd(inOrder, UPSTREAM_IFINDEX2, neighA, macA);
+        verifyTetherOffloadRuleAdd(inOrder,
+                UPSTREAM_IFINDEX2, UPSTREAM_IFACE_PARAMS2.macAddr, neighA, macA);
         verifyStartUpstreamIpv6Forwarding(inOrder, UPSTREAM_IFINDEX2);
-        verifyTetherOffloadRuleAdd(inOrder, UPSTREAM_IFINDEX2, neighB, macB);
+        verifyTetherOffloadRuleAdd(inOrder,
+                UPSTREAM_IFINDEX2, UPSTREAM_IFACE_PARAMS2.macAddr, neighB, macB);
         verifyNoUpstreamIpv6ForwardingChange(inOrder);
         resetNetdBpfMapAndCoordinator();
 
@@ -1044,8 +1066,10 @@
         // - processMessage CMD_IPV6_TETHER_UPDATE for the IPv6 upstream is lost.
         // See dispatchTetherConnectionChanged.
         verify(mBpfCoordinator, times(2)).tetherOffloadRuleClear(mIpServer);
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX2, neighA, macA);
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX2, neighB, macB);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX2, UPSTREAM_IFACE_PARAMS2.macAddr, neighA, macA);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX2, UPSTREAM_IFACE_PARAMS2.macAddr, neighB, macB);
         verifyStopUpstreamIpv6Forwarding(inOrder);
         resetNetdBpfMapAndCoordinator();
 
@@ -1064,17 +1088,20 @@
         recvNewNeigh(myIfindex, neighB, NUD_REACHABLE, macB);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighB, macB));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         verifyStartUpstreamIpv6Forwarding(null, UPSTREAM_IFINDEX);
         verify(mBpfCoordinator, never()).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighA, macA));
-        verifyNeverTetherOffloadRuleAdd(UPSTREAM_IFINDEX, neighA, macA);
+        verifyNeverTetherOffloadRuleAdd(
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macA);
 
         // If upstream IPv6 connectivity is lost, rules are removed.
         resetNetdBpfMapAndCoordinator();
         dispatchTetherConnectionChanged(UPSTREAM_IFACE, null, 0);
         verify(mBpfCoordinator).tetherOffloadRuleClear(mIpServer);
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         verifyStopUpstreamIpv6Forwarding(null);
 
         // When the interface goes down, rules are removed.
@@ -1084,18 +1111,22 @@
         recvNewNeigh(myIfindex, neighB, NUD_REACHABLE, macB);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighA, macA));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neighA, macA);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macA);
         verifyStartUpstreamIpv6Forwarding(null, UPSTREAM_IFINDEX);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neighB, macB));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         resetNetdBpfMapAndCoordinator();
 
         mIpServer.stop();
         mLooper.dispatchAll();
         verify(mBpfCoordinator).tetherOffloadRuleClear(mIpServer);
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neighA, macA);
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neighB, macB);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighA, macA);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neighB, macB);
         verifyStopUpstreamIpv6Forwarding(null);
         verify(mIpNeighborMonitor).stop();
         resetNetdBpfMapAndCoordinator();
@@ -1124,14 +1155,16 @@
         recvNewNeigh(myIfindex, neigh, NUD_REACHABLE, macA);
         verify(mBpfCoordinator).tetherOffloadRuleAdd(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neigh, macA));
-        verifyTetherOffloadRuleAdd(null, UPSTREAM_IFINDEX, neigh, macA);
+        verifyTetherOffloadRuleAdd(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neigh, macA);
         verifyStartUpstreamIpv6Forwarding(null, UPSTREAM_IFINDEX);
         resetNetdBpfMapAndCoordinator();
 
         recvDelNeigh(myIfindex, neigh, NUD_STALE, macA);
         verify(mBpfCoordinator).tetherOffloadRuleRemove(
                 mIpServer, makeForwardingRule(UPSTREAM_IFINDEX, neigh, macNull));
-        verifyTetherOffloadRuleRemove(null, UPSTREAM_IFINDEX, neigh, macNull);
+        verifyTetherOffloadRuleRemove(null,
+                UPSTREAM_IFINDEX, UPSTREAM_IFACE_PARAMS.macAddr, neigh, macNull);
         verifyStopUpstreamIpv6Forwarding(null);
         resetNetdBpfMapAndCoordinator();
 
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index 293d0df..233f6db 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -23,9 +23,21 @@
 import static android.net.NetworkStats.TAG_NONE;
 import static android.net.NetworkStats.UID_ALL;
 import static android.net.NetworkStats.UID_TETHERING;
+import static android.net.ip.ConntrackMonitor.ConntrackEvent;
+import static android.net.netlink.ConntrackMessage.DYING_MASK;
+import static android.net.netlink.ConntrackMessage.ESTABLISHED_MASK;
+import static android.net.netlink.ConntrackMessage.Tuple;
+import static android.net.netlink.ConntrackMessage.TupleIpv4;
+import static android.net.netlink.ConntrackMessage.TupleProto;
+import static android.net.netlink.NetlinkConstants.IPCTNL_MSG_CT_DELETE;
+import static android.net.netlink.NetlinkConstants.IPCTNL_MSG_CT_NEW;
 import static android.net.netstats.provider.NetworkStatsProvider.QUOTA_UNLIMITED;
+import static android.system.OsConstants.ETH_P_IP;
 import static android.system.OsConstants.ETH_P_IPV6;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.IPPROTO_UDP;
 
+import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.staticMockMarker;
 import static com.android.networkstack.tethering.BpfCoordinator.StatsType;
 import static com.android.networkstack.tethering.BpfCoordinator.StatsType.STATS_PER_IFACE;
@@ -43,9 +55,9 @@
 import static org.mockito.Matchers.anyInt;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.argThat;
 import static org.mockito.Mockito.clearInvocations;
-import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
@@ -55,6 +67,8 @@
 import android.app.usage.NetworkStatsManager;
 import android.net.INetd;
 import android.net.InetAddresses;
+import android.net.LinkAddress;
+import android.net.LinkProperties;
 import android.net.MacAddress;
 import android.net.NetworkStats;
 import android.net.TetherOffloadRuleParcel;
@@ -62,6 +76,8 @@
 import android.net.ip.ConntrackMonitor;
 import android.net.ip.ConntrackMonitor.ConntrackEventConsumer;
 import android.net.ip.IpServer;
+import android.net.netlink.NetlinkConstants;
+import android.net.util.InterfaceParams;
 import android.net.util.SharedLog;
 import android.os.Build;
 import android.os.Handler;
@@ -76,6 +92,8 @@
 import com.android.dx.mockito.inline.extended.ExtendedMockito;
 import com.android.net.module.util.NetworkStackConstants;
 import com.android.net.module.util.Struct;
+import com.android.networkstack.tethering.BpfCoordinator.BpfConntrackEventConsumer;
+import com.android.networkstack.tethering.BpfCoordinator.ClientInfo;
 import com.android.networkstack.tethering.BpfCoordinator.Ipv6ForwardingRule;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
@@ -93,6 +111,7 @@
 import org.mockito.MockitoAnnotations;
 import org.mockito.MockitoSession;
 
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.util.ArrayList;
@@ -108,13 +127,22 @@
     @Rule
     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
 
-    private static final int DOWNSTREAM_IFINDEX = 10;
-    private static final MacAddress DOWNSTREAM_MAC = MacAddress.ALL_ZEROS_ADDRESS;
-    private static final InetAddress NEIGH_A = InetAddresses.parseNumericAddress("2001:db8::1");
-    private static final InetAddress NEIGH_B = InetAddresses.parseNumericAddress("2001:db8::2");
+    private static final int UPSTREAM_IFINDEX = 1001;
+    private static final int DOWNSTREAM_IFINDEX = 1002;
+
+    private static final String UPSTREAM_IFACE = "rmnet0";
+
+    private static final MacAddress DOWNSTREAM_MAC = MacAddress.fromString("12:34:56:78:90:ab");
     private static final MacAddress MAC_A = MacAddress.fromString("00:00:00:00:00:0a");
     private static final MacAddress MAC_B = MacAddress.fromString("11:22:33:00:00:0b");
 
+    private static final InetAddress NEIGH_A = InetAddresses.parseNumericAddress("2001:db8::1");
+    private static final InetAddress NEIGH_B = InetAddresses.parseNumericAddress("2001:db8::2");
+
+    private static final InterfaceParams UPSTREAM_IFACE_PARAMS = new InterfaceParams(
+            UPSTREAM_IFACE, UPSTREAM_IFINDEX, null /* macAddr, rawip */,
+            NetworkStackConstants.ETHER_MTU);
+
     // The test fake BPF map class is needed because the test has no privilege to access the BPF
     // map. All member functions which eventually call JNI to access the real native BPF map need
     // to be overridden.
@@ -183,6 +211,11 @@
     // Late init since methods must be called by the thread that created this object.
     private TestableNetworkStatsProviderCbBinder mTetherStatsProviderCb;
     private BpfCoordinator.BpfTetherStatsProvider mTetherStatsProvider;
+
+    // Late init since the object must be initialized by the BPF coordinator instance because
+    // it has to access non-static functions of the BPF coordinator.
+    private BpfConntrackEventConsumer mConsumer;
+
     private final ArgumentCaptor<ArrayList> mStringArrayCaptor =
             ArgumentCaptor.forClass(ArrayList.class);
     private final TestLooper mTestLooper = new TestLooper();
@@ -262,6 +295,8 @@
         mTestLooper.dispatchAll();
     }
 
+    // TODO: Remove unnecessary calls on R because BPF map access has been moved into the
+    // module.
     private void setupFunctioningNetdInterface() throws Exception {
         when(mNetd.tetherOffloadGetStats()).thenReturn(new TetherStatsParcel[0]);
     }
@@ -269,6 +304,8 @@
     @NonNull
     private BpfCoordinator makeBpfCoordinator() throws Exception {
         final BpfCoordinator coordinator = new BpfCoordinator(mDeps);
+
+        mConsumer = coordinator.getBpfConntrackEventConsumerForTesting();
         final ArgumentCaptor<BpfCoordinator.BpfTetherStatsProvider>
                 tetherStatsProviderCaptor =
                 ArgumentCaptor.forClass(BpfCoordinator.BpfTetherStatsProvider.class);
@@ -278,6 +315,7 @@
         assertNotNull(mTetherStatsProvider);
         mTetherStatsProviderCb = new TestableNetworkStatsProviderCbBinder();
         mTetherStatsProvider.setProviderCallbackBinder(mTetherStatsProviderCb);
+
         return coordinator;
     }
 
@@ -383,19 +421,20 @@
     }
 
     private void verifyStartUpstreamIpv6Forwarding(@Nullable InOrder inOrder, int downstreamIfIndex,
-            int upstreamIfindex) throws Exception {
+            MacAddress downstreamMac, int upstreamIfindex) throws Exception {
         if (!mDeps.isAtLeastS()) return;
-        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfIndex);
+        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfIndex, downstreamMac);
         final Tether6Value value = new Tether6Value(upstreamIfindex,
                 MacAddress.ALL_ZEROS_ADDRESS, MacAddress.ALL_ZEROS_ADDRESS,
                 ETH_P_IPV6, NetworkStackConstants.ETHER_MTU);
         verifyWithOrder(inOrder, mBpfUpstream6Map).insertEntry(key, value);
     }
 
-    private void verifyStopUpstreamIpv6Forwarding(@Nullable InOrder inOrder, int downstreamIfIndex)
+    private void verifyStopUpstreamIpv6Forwarding(@Nullable InOrder inOrder, int downstreamIfIndex,
+            MacAddress downstreamMac)
             throws Exception {
         if (!mDeps.isAtLeastS()) return;
-        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfIndex);
+        final TetherUpstream6Key key = new TetherUpstream6Key(downstreamIfIndex, downstreamMac);
         verifyWithOrder(inOrder, mBpfUpstream6Map).deleteEntry(key);
     }
 
@@ -465,7 +504,7 @@
         }
     }
 
-    private void verifyNeverTetherOffloadSetInterfaceQuota(@Nullable InOrder inOrder)
+    private void verifyNeverTetherOffloadSetInterfaceQuota(@NonNull InOrder inOrder)
             throws Exception {
         if (mDeps.isAtLeastS()) {
             inOrder.verify(mBpfStatsMap, never()).getValue(any());
@@ -476,7 +515,7 @@
         }
     }
 
-    private void verifyTetherOffloadGetAndClearStats(@Nullable InOrder inOrder, int ifIndex)
+    private void verifyTetherOffloadGetAndClearStats(@NonNull InOrder inOrder, int ifIndex)
             throws Exception {
         if (mDeps.isAtLeastS()) {
             inOrder.verify(mBpfStatsMap).getValue(new TetherStatsKey(ifIndex));
@@ -732,9 +771,10 @@
 
         final TetherDownstream6Key key = rule.makeTetherDownstream6Key();
         assertEquals(key.iif, (long) mobileIfIndex);
+        assertEquals(key.dstMac, MacAddress.ALL_ZEROS_ADDRESS);  // rawip upstream
         assertTrue(Arrays.equals(key.neigh6, NEIGH_A.getAddress()));
-        // iif (4) + neigh6 (16) = 20.
-        assertEquals(20, key.writeToBytes().length);
+        // iif (4) + dstMac(6) + padding(2) + neigh6 (16) = 28.
+        assertEquals(28, key.writeToBytes().length);
     }
 
     @Test
@@ -797,7 +837,7 @@
 
     // TODO: Test the case in which the rules are changed from different IpServer objects.
     @Test
-    public void testSetDataLimitOnRuleChange() throws Exception {
+    public void testSetDataLimitOnRule6Change() throws Exception {
         setupFunctioningNetdInterface();
 
         final BpfCoordinator coordinator = makeBpfCoordinator();
@@ -875,7 +915,7 @@
         verifyTetherOffloadRuleAdd(inOrder, ethernetRuleA);
         verifyTetherOffloadSetInterfaceQuota(inOrder, ethIfIndex, QUOTA_UNLIMITED,
                 true /* isInit */);
-        verifyStartUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, ethIfIndex);
+        verifyStartUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC, ethIfIndex);
         coordinator.tetherOffloadRuleAdd(mIpServer, ethernetRuleB);
         verifyTetherOffloadRuleAdd(inOrder, ethernetRuleB);
 
@@ -892,12 +932,13 @@
         coordinator.tetherOffloadRuleUpdate(mIpServer, mobileIfIndex);
         verifyTetherOffloadRuleRemove(inOrder, ethernetRuleA);
         verifyTetherOffloadRuleRemove(inOrder, ethernetRuleB);
-        verifyStopUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX);
+        verifyStopUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC);
         verifyTetherOffloadGetAndClearStats(inOrder, ethIfIndex);
         verifyTetherOffloadRuleAdd(inOrder, mobileRuleA);
         verifyTetherOffloadSetInterfaceQuota(inOrder, mobileIfIndex, QUOTA_UNLIMITED,
                 true /* isInit */);
-        verifyStartUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, mobileIfIndex);
+        verifyStartUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC,
+                mobileIfIndex);
         verifyTetherOffloadRuleAdd(inOrder, mobileRuleB);
 
         // [3] Clear all rules for a given IpServer.
@@ -906,7 +947,7 @@
         coordinator.tetherOffloadRuleClear(mIpServer);
         verifyTetherOffloadRuleRemove(inOrder, mobileRuleA);
         verifyTetherOffloadRuleRemove(inOrder, mobileRuleB);
-        verifyStopUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX);
+        verifyStopUpstreamIpv6Forwarding(inOrder, DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC);
         verifyTetherOffloadGetAndClearStats(inOrder, mobileIfIndex);
 
         // [4] Force pushing stats update to verify that the last diff of stats is reported on all
@@ -1216,4 +1257,192 @@
         coordinator.stopMonitoring(mIpServer);
         verify(mConntrackMonitor).stop();
     }
+
+    // Test network topology:
+    //
+    //         public network (rawip)                 private network
+    //                   |                 UE                |
+    // +------------+    V    +------------+------------+    V    +------------+
+    // |   Server   +---------+  Upstream  | Downstream +---------+   Client   |
+    // +------------+         +------------+------------+         +------------+
+    // remote ip              public ip                           private ip
+    // 140.112.8.116:443      100.81.179.1:62449                  192.168.80.12:62449
+    //
+    private static final Inet4Address REMOTE_ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("140.112.8.116");
+    private static final Inet4Address PUBLIC_ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("100.81.179.1");
+    private static final Inet4Address PRIVATE_ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("192.168.80.12");
+
+    // IPv4-mapped IPv6 addresses
+    // Remote address ::ffff:140.112.8.116
+    // Public address ::ffff:100.81.179.1
+    // Private address ::ffff:192.168.80.12
+    private static final byte[] REMOTE_ADDR_V4MAPPED_BYTES = new byte[] {
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+            (byte) 0x8c, (byte) 0x70, (byte) 0x08, (byte) 0x74 };
+    private static final byte[] PUBLIC_ADDR_V4MAPPED_BYTES = new byte[] {
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+            (byte) 0x64, (byte) 0x51, (byte) 0xb3, (byte) 0x01 };
+    private static final byte[] PRIVATE_ADDR_V4MAPPED_BYTES = new byte[] {
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00,
+            (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0xff, (byte) 0xff,
+            (byte) 0xc0, (byte) 0xa8, (byte) 0x50, (byte) 0x0c };
+
+    // Generally, public port and private port are the same in the NAT conntrack message.
+    // TODO: consider using different private port and public port for testing.
+    private static final short REMOTE_PORT = (short) 443;
+    private static final short PUBLIC_PORT = (short) 62449;
+    private static final short PRIVATE_PORT = (short) 62449;
+
+    @NonNull
+    private Tether4Key makeUpstream4Key(int proto) {
+        if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
+            fail("Not support protocol " + proto);
+        }
+        return new Tether4Key(DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC, (short) proto,
+            PRIVATE_ADDR.getAddress(), REMOTE_ADDR.getAddress(), PRIVATE_PORT, REMOTE_PORT);
+    }
+
+    @NonNull
+    private Tether4Key makeDownstream4Key(int proto) {
+        if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
+            fail("Not support protocol " + proto);
+        }
+        return new Tether4Key(UPSTREAM_IFINDEX,
+                MacAddress.ALL_ZEROS_ADDRESS /* dstMac (rawip) */, (short) proto,
+                REMOTE_ADDR.getAddress(), PUBLIC_ADDR.getAddress(), REMOTE_PORT, PUBLIC_PORT);
+    }
+
+    @NonNull
+    private Tether4Value makeUpstream4Value() {
+        return new Tether4Value(UPSTREAM_IFINDEX,
+                MacAddress.ALL_ZEROS_ADDRESS /* ethDstMac (rawip) */,
+                MacAddress.ALL_ZEROS_ADDRESS /* ethSrcMac (rawip) */, ETH_P_IP,
+                NetworkStackConstants.ETHER_MTU, PUBLIC_ADDR_V4MAPPED_BYTES,
+                REMOTE_ADDR_V4MAPPED_BYTES, PUBLIC_PORT, REMOTE_PORT, 0 /* lastUsed */);
+    }
+
+    @NonNull
+    private Tether4Value makeDownstream4Value() {
+        return new Tether4Value(DOWNSTREAM_IFINDEX, MAC_A /* client mac */, DOWNSTREAM_MAC,
+                ETH_P_IP, NetworkStackConstants.ETHER_MTU, REMOTE_ADDR_V4MAPPED_BYTES,
+                PRIVATE_ADDR_V4MAPPED_BYTES, REMOTE_PORT, PRIVATE_PORT, 0 /* lastUsed */);
+    }
+
+    @NonNull
+    private ConntrackEvent makeTestConntrackEvent(short msgType, int proto) {
+        if (msgType != IPCTNL_MSG_CT_NEW && msgType != IPCTNL_MSG_CT_DELETE) {
+            fail("Not support message type " + msgType);
+        }
+        if (proto != IPPROTO_TCP && proto != IPPROTO_UDP) {
+            fail("Not support protocol " + proto);
+        }
+
+        final int status = (msgType == IPCTNL_MSG_CT_NEW) ? ESTABLISHED_MASK : DYING_MASK;
+        final int timeoutSec = (msgType == IPCTNL_MSG_CT_NEW) ? 100 /* nonzero, new */
+                : 0 /* unused, delete */;
+        return new ConntrackEvent(
+                (short) (NetlinkConstants.NFNL_SUBSYS_CTNETLINK << 8 | msgType),
+                new Tuple(new TupleIpv4(PRIVATE_ADDR, REMOTE_ADDR),
+                        new TupleProto((byte) proto, PRIVATE_PORT, REMOTE_PORT)),
+                new Tuple(new TupleIpv4(REMOTE_ADDR, PUBLIC_ADDR),
+                        new TupleProto((byte) proto, REMOTE_PORT, PUBLIC_PORT)),
+                status,
+                timeoutSec);
+    }
+
+    private void setUpstreamInformationTo(final BpfCoordinator coordinator) {
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(UPSTREAM_IFACE);
+        lp.addLinkAddress(new LinkAddress(PUBLIC_ADDR, 32 /* prefix length */));
+        coordinator.addUpstreamIfindexToMap(lp);
+    }
+
+    private void setDownstreamAndClientInformationTo(final BpfCoordinator coordinator) {
+        final ClientInfo clientInfo = new ClientInfo(DOWNSTREAM_IFINDEX, DOWNSTREAM_MAC,
+                PRIVATE_ADDR, MAC_A /* client mac */);
+        coordinator.tetherOffloadClientAdd(mIpServer, clientInfo);
+    }
+
+    // TODO: Test the IPv4 and IPv6 exist concurrently.
+    // TODO: Test the IPv4 rule delete failed.
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testSetDataLimitOnRule4Change() throws Exception {
+        final BpfCoordinator coordinator = makeBpfCoordinator();
+        coordinator.startPolling();
+
+        // Needed because tetherOffloadRuleRemove of api31.BpfCoordinatorShimImpl only decreases
+        // the count when the entry is actually deleted, i.e. when deleteEntry returns true.
+        doReturn(true).when(mBpfDownstream4Map).deleteEntry(any());
+
+        // Needed because BpfCoordinator#addUpstreamIfindexToMap queries interface parameter for
+        // interface index.
+        doReturn(UPSTREAM_IFACE_PARAMS).when(mDeps).getInterfaceParams(UPSTREAM_IFACE);
+
+        coordinator.addUpstreamNameToLookupTable(UPSTREAM_IFINDEX, UPSTREAM_IFACE);
+        setUpstreamInformationTo(coordinator);
+        setDownstreamAndClientInformationTo(coordinator);
+
+        // Applying a data limit to the current upstream does not take any immediate action.
+        // The data limit could be only set on an upstream which has rules.
+        final long limit = 12345;
+        final InOrder inOrder = inOrder(mNetd, mBpfUpstream4Map, mBpfDownstream4Map, mBpfLimitMap,
+                mBpfStatsMap);
+        mTetherStatsProvider.onSetLimit(UPSTREAM_IFACE, limit);
+        waitForIdle();
+        verifyNeverTetherOffloadSetInterfaceQuota(inOrder);
+
+        // Build TCP and UDP rules for testing. Note that the values of {TCP, UDP} are the same
+        // because the protocol is not an element of the value. Consider using different address
+        // or port to make them different for better testing.
+        // TODO: Make the values of {TCP, UDP} rules different.
+        final Tether4Key expectedUpstream4KeyTcp = makeUpstream4Key(IPPROTO_TCP);
+        final Tether4Key expectedDownstream4KeyTcp = makeDownstream4Key(IPPROTO_TCP);
+        final Tether4Value expectedUpstream4ValueTcp = makeUpstream4Value();
+        final Tether4Value expectedDownstream4ValueTcp = makeDownstream4Value();
+
+        final Tether4Key expectedUpstream4KeyUdp = makeUpstream4Key(IPPROTO_UDP);
+        final Tether4Key expectedDownstream4KeyUdp = makeDownstream4Key(IPPROTO_UDP);
+        final Tether4Value expectedUpstream4ValueUdp = makeUpstream4Value();
+        final Tether4Value expectedDownstream4ValueUdp = makeDownstream4Value();
+
+        // [1] Adding the first rule on current upstream immediately sends the quota.
+        mConsumer.accept(makeTestConntrackEvent(IPCTNL_MSG_CT_NEW, IPPROTO_TCP));
+        verifyTetherOffloadSetInterfaceQuota(inOrder, UPSTREAM_IFINDEX, limit, true /* isInit */);
+        inOrder.verify(mBpfUpstream4Map)
+                .insertEntry(eq(expectedUpstream4KeyTcp), eq(expectedUpstream4ValueTcp));
+        inOrder.verify(mBpfDownstream4Map)
+                .insertEntry(eq(expectedDownstream4KeyTcp), eq(expectedDownstream4ValueTcp));
+        inOrder.verifyNoMoreInteractions();
+
+        // [2] Adding the second rule on current upstream does not send the quota.
+        mConsumer.accept(makeTestConntrackEvent(IPCTNL_MSG_CT_NEW, IPPROTO_UDP));
+        verifyNeverTetherOffloadSetInterfaceQuota(inOrder);
+        inOrder.verify(mBpfUpstream4Map)
+                .insertEntry(eq(expectedUpstream4KeyUdp), eq(expectedUpstream4ValueUdp));
+        inOrder.verify(mBpfDownstream4Map)
+                .insertEntry(eq(expectedDownstream4KeyUdp), eq(expectedDownstream4ValueUdp));
+        inOrder.verifyNoMoreInteractions();
+
+        // [3] Removing the second rule on current upstream does not send the quota.
+        mConsumer.accept(makeTestConntrackEvent(IPCTNL_MSG_CT_DELETE, IPPROTO_UDP));
+        verifyNeverTetherOffloadSetInterfaceQuota(inOrder);
+        inOrder.verify(mBpfUpstream4Map).deleteEntry(eq(expectedUpstream4KeyUdp));
+        inOrder.verify(mBpfDownstream4Map).deleteEntry(eq(expectedDownstream4KeyUdp));
+        inOrder.verifyNoMoreInteractions();
+
+        // [4] Removing the last rule on current upstream immediately sends the cleanup stuff.
+        updateStatsEntryForTetherOffloadGetAndClearStats(
+                buildTestTetherStatsParcel(UPSTREAM_IFINDEX, 0, 0, 0, 0));
+        mConsumer.accept(makeTestConntrackEvent(IPCTNL_MSG_CT_DELETE, IPPROTO_TCP));
+        inOrder.verify(mBpfUpstream4Map).deleteEntry(eq(expectedUpstream4KeyTcp));
+        inOrder.verify(mBpfDownstream4Map).deleteEntry(eq(expectedDownstream4KeyTcp));
+        verifyTetherOffloadGetAndClearStats(inOrder, UPSTREAM_IFINDEX);
+        inOrder.verifyNoMoreInteractions();
+    }
 }
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java
index 3636b03..d045bf1 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TestConnectivityManager.java
@@ -30,12 +30,12 @@
 import android.net.NetworkRequest;
 import android.os.Handler;
 import android.os.UserHandle;
+import android.util.ArrayMap;
 
-import java.util.HashMap;
-import java.util.HashSet;
+import androidx.annotation.Nullable;
+
 import java.util.Map;
 import java.util.Objects;
-import java.util.Set;
 
 /**
  * Simulates upstream switching and sending NetworkCallbacks and CONNECTIVITY_ACTION broadcasts.
@@ -60,17 +60,20 @@
  *   that state changes), this may become less important or unnecessary.
  */
 public class TestConnectivityManager extends ConnectivityManager {
-    public Map<NetworkCallback, Handler> allCallbacks = new HashMap<>();
-    public Set<NetworkCallback> trackingDefault = new HashSet<>();
-    public TestNetworkAgent defaultNetwork = null;
-    public Map<NetworkCallback, NetworkRequest> listening = new HashMap<>();
-    public Map<NetworkCallback, NetworkRequest> requested = new HashMap<>();
-    public Map<NetworkCallback, Integer> legacyTypeMap = new HashMap<>();
+    public static final boolean BROADCAST_FIRST = false;  // Send broadcasts, then callbacks.
+    public static final boolean CALLBACKS_FIRST = true;  // Post callbacks, then send broadcasts.
+
+    final Map<NetworkCallback, NetworkRequestInfo> mAllCallbacks = new ArrayMap<>();
+    final Map<NetworkCallback, NetworkRequestInfo> mTrackingDefault = new ArrayMap<>();
+    final Map<NetworkCallback, NetworkRequestInfo> mListening = new ArrayMap<>();
+    final Map<NetworkCallback, NetworkRequestInfo> mRequested = new ArrayMap<>();
+    final Map<NetworkCallback, Integer> mLegacyTypeMap = new ArrayMap<>();
 
     private final NetworkRequest mDefaultRequest;
     private final Context mContext;
 
     private int mNetworkId = 100;
+    private TestNetworkAgent mDefaultNetwork = null;  // Set by makeDefaultNetwork.
 
     /**
      * Constructs a TestConnectivityManager.
@@ -86,28 +89,37 @@
         mDefaultRequest = defaultRequest;
     }
 
+    static class NetworkRequestInfo {  // Request and Handler a callback was registered with.
+        public final NetworkRequest request;
+        public final Handler handler;  // Handler that callbacks are posted to.
+        NetworkRequestInfo(NetworkRequest r, Handler h) {
+            request = r;
+            handler = h;
+        }
+    }
+
     boolean hasNoCallbacks() {
-        return allCallbacks.isEmpty()
-                && trackingDefault.isEmpty()
-                && listening.isEmpty()
-                && requested.isEmpty()
-                && legacyTypeMap.isEmpty();
+        return mAllCallbacks.isEmpty()
+                && mTrackingDefault.isEmpty()
+                && mListening.isEmpty()
+                && mRequested.isEmpty()
+                && mLegacyTypeMap.isEmpty();
     }
 
     boolean onlyHasDefaultCallbacks() {
-        return (allCallbacks.size() == 1)
-                && (trackingDefault.size() == 1)
-                && listening.isEmpty()
-                && requested.isEmpty()
-                && legacyTypeMap.isEmpty();
+        return (mAllCallbacks.size() == 1)
+                && (mTrackingDefault.size() == 1)
+                && mListening.isEmpty()
+                && mRequested.isEmpty()
+                && mLegacyTypeMap.isEmpty();
     }
 
     boolean isListeningForAll() {
         final NetworkCapabilities empty = new NetworkCapabilities();
         empty.clearAll();
 
-        for (NetworkRequest req : listening.values()) {
-            if (req.networkCapabilities.equalRequestableCapabilities(empty)) {
+        for (NetworkRequestInfo nri : mListening.values()) {
+            if (nri.request.networkCapabilities.equalRequestableCapabilities(empty)) {
                 return true;
             }
         }
@@ -118,40 +130,67 @@
         return ++mNetworkId;
     }
 
-    void makeDefaultNetwork(TestNetworkAgent agent) {
-        if (Objects.equals(defaultNetwork, agent)) return;
-
-        final TestNetworkAgent formerDefault = defaultNetwork;
-        defaultNetwork = agent;
-
+    private void sendDefaultNetworkBroadcasts(TestNetworkAgent formerDefault,
+            TestNetworkAgent defaultNetwork) {
         if (formerDefault != null) {
             sendConnectivityAction(formerDefault.legacyType, false /* connected */);
         }
         if (defaultNetwork != null) {
             sendConnectivityAction(defaultNetwork.legacyType, true /* connected */);
         }
+    }
 
-        for (NetworkCallback cb : trackingDefault) {
+    private void sendDefaultNetworkCallbacks(TestNetworkAgent formerDefault,
+            TestNetworkAgent defaultNetwork) {
+        for (NetworkCallback cb : mTrackingDefault.keySet()) {
+            final NetworkRequestInfo nri = mTrackingDefault.get(cb);
             if (defaultNetwork != null) {
-                cb.onAvailable(defaultNetwork.networkId);
-                cb.onCapabilitiesChanged(
-                        defaultNetwork.networkId, defaultNetwork.networkCapabilities);
-                cb.onLinkPropertiesChanged(
-                        defaultNetwork.networkId, defaultNetwork.linkProperties);
+                nri.handler.post(() -> cb.onAvailable(defaultNetwork.networkId));
+                nri.handler.post(() -> cb.onCapabilitiesChanged(
+                        defaultNetwork.networkId, defaultNetwork.networkCapabilities));
+                nri.handler.post(() -> cb.onLinkPropertiesChanged(
+                        defaultNetwork.networkId, defaultNetwork.linkProperties));
+            } else if (formerDefault != null) {  // No new default network: report the loss.
+                nri.handler.post(() -> cb.onLost(formerDefault.networkId));
             }
         }
     }
 
+    void makeDefaultNetwork(TestNetworkAgent agent, boolean order, @Nullable Runnable inBetween) {
+        if (Objects.equals(mDefaultNetwork, agent)) return;
+
+        final TestNetworkAgent formerDefault = mDefaultNetwork;
+        mDefaultNetwork = agent;
+
+        if (order == CALLBACKS_FIRST) {  // Callbacks are queued on handlers, not run.
+            sendDefaultNetworkCallbacks(formerDefault, mDefaultNetwork);
+            if (inBetween != null) inBetween.run();
+            sendDefaultNetworkBroadcasts(formerDefault, mDefaultNetwork);
+        } else {
+            sendDefaultNetworkBroadcasts(formerDefault, mDefaultNetwork);
+            if (inBetween != null) inBetween.run();
+            sendDefaultNetworkCallbacks(formerDefault, mDefaultNetwork);
+        }
+    }
+
+    void makeDefaultNetwork(TestNetworkAgent agent, boolean order) {
+        makeDefaultNetwork(agent, order, null /* inBetween */);
+    }
+
+    void makeDefaultNetwork(TestNetworkAgent agent) {
+        makeDefaultNetwork(agent, BROADCAST_FIRST, null /* inBetween */);
+    }
+
     @Override
     public void requestNetwork(NetworkRequest req, NetworkCallback cb, Handler h) {
-        assertFalse(allCallbacks.containsKey(cb));
-        allCallbacks.put(cb, h);
+        assertFalse(mAllCallbacks.containsKey(cb));
+        mAllCallbacks.put(cb, new NetworkRequestInfo(req, h));
         if (mDefaultRequest.equals(req)) {
-            assertFalse(trackingDefault.contains(cb));
-            trackingDefault.add(cb);
+            assertFalse(mTrackingDefault.containsKey(cb));
+            mTrackingDefault.put(cb, new NetworkRequestInfo(req, h));
         } else {
-            assertFalse(requested.containsKey(cb));
-            requested.put(cb, req);
+            assertFalse(mRequested.containsKey(cb));
+            mRequested.put(cb, new NetworkRequestInfo(req, h));
         }
     }
 
@@ -163,22 +202,22 @@
     @Override
     public void requestNetwork(NetworkRequest req,
             int timeoutMs, int legacyType, Handler h, NetworkCallback cb) {
-        assertFalse(allCallbacks.containsKey(cb));
-        allCallbacks.put(cb, h);
-        assertFalse(requested.containsKey(cb));
-        requested.put(cb, req);
-        assertFalse(legacyTypeMap.containsKey(cb));
+        assertFalse(mAllCallbacks.containsKey(cb));
+        mAllCallbacks.put(cb, new NetworkRequestInfo(req, h));
+        assertFalse(mRequested.containsKey(cb));
+        mRequested.put(cb, new NetworkRequestInfo(req, h));
+        assertFalse(mLegacyTypeMap.containsKey(cb));
         if (legacyType != ConnectivityManager.TYPE_NONE) {
-            legacyTypeMap.put(cb, legacyType);
+            mLegacyTypeMap.put(cb, legacyType);
         }
     }
 
     @Override
     public void registerNetworkCallback(NetworkRequest req, NetworkCallback cb, Handler h) {
-        assertFalse(allCallbacks.containsKey(cb));
-        allCallbacks.put(cb, h);
-        assertFalse(listening.containsKey(cb));
-        listening.put(cb, req);
+        assertFalse(mAllCallbacks.containsKey(cb));
+        mAllCallbacks.put(cb, new NetworkRequestInfo(req, h));
+        assertFalse(mListening.containsKey(cb));
+        mListening.put(cb, new NetworkRequestInfo(req, h));
     }
 
     @Override
@@ -198,22 +237,22 @@
 
     @Override
     public void unregisterNetworkCallback(NetworkCallback cb) {
-        if (trackingDefault.contains(cb)) {
-            trackingDefault.remove(cb);
-        } else if (listening.containsKey(cb)) {
-            listening.remove(cb);
-        } else if (requested.containsKey(cb)) {
-            requested.remove(cb);
-            legacyTypeMap.remove(cb);
+        if (mTrackingDefault.containsKey(cb)) {
+            mTrackingDefault.remove(cb);
+        } else if (mListening.containsKey(cb)) {
+            mListening.remove(cb);
+        } else if (mRequested.containsKey(cb)) {
+            mRequested.remove(cb);
+            mLegacyTypeMap.remove(cb);
         } else {
             fail("Unexpected callback removed");
         }
-        allCallbacks.remove(cb);
+        mAllCallbacks.remove(cb);
 
-        assertFalse(allCallbacks.containsKey(cb));
-        assertFalse(trackingDefault.contains(cb));
-        assertFalse(listening.containsKey(cb));
-        assertFalse(requested.containsKey(cb));
+        assertFalse(mAllCallbacks.containsKey(cb));
+        assertFalse(mTrackingDefault.containsKey(cb));
+        assertFalse(mListening.containsKey(cb));
+        assertFalse(mRequested.containsKey(cb));
     }
 
     private void sendConnectivityAction(int type, boolean connected) {
@@ -275,34 +314,38 @@
         }
 
         public void fakeConnect() {
-            for (NetworkRequest request : cm.requested.values()) {
-                if (matchesLegacyType(request.legacyType)) {
+            for (NetworkRequestInfo nri : cm.mRequested.values()) {
+                if (matchesLegacyType(nri.request.legacyType)) {
                     cm.sendConnectivityAction(legacyType, true /* connected */);
                     // In practice, a given network can match only one legacy type.
                     break;
                 }
             }
-            for (NetworkCallback cb : cm.listening.keySet()) {
-                cb.onAvailable(networkId);
-                cb.onCapabilitiesChanged(networkId, copy(networkCapabilities));
-                cb.onLinkPropertiesChanged(networkId, copy(linkProperties));
+            for (NetworkCallback cb : cm.mListening.keySet()) {
+                final NetworkRequestInfo nri = cm.mListening.get(cb);
+                nri.handler.post(() -> cb.onAvailable(networkId));
+                nri.handler.post(() -> cb.onCapabilitiesChanged(
+                        networkId, copy(networkCapabilities)));
+                nri.handler.post(() -> cb.onLinkPropertiesChanged(networkId, copy(linkProperties)));
             }
+            // mTrackingDefault will be updated if/when the caller calls makeDefaultNetwork
         }
 
         public void fakeDisconnect() {
-            for (NetworkRequest request : cm.requested.values()) {
-                if (matchesLegacyType(request.legacyType)) {
+            for (NetworkRequestInfo nri : cm.mRequested.values()) {
+                if (matchesLegacyType(nri.request.legacyType)) {
                     cm.sendConnectivityAction(legacyType, false /* connected */);
                     break;
                 }
             }
-            for (NetworkCallback cb : cm.listening.keySet()) {
+            for (NetworkCallback cb : cm.mListening.keySet()) {  // Unlike fakeConnect: not posted.
                 cb.onLost(networkId);
             }
+            // mTrackingDefault will be updated if/when the caller calls makeDefaultNetwork
         }
 
         public void sendLinkProperties() {
-            for (NetworkCallback cb : cm.listening.keySet()) {
+            for (NetworkCallback cb : cm.mListening.keySet()) {  // Unlike fakeConnect: not posted.
                 cb.onLinkPropertiesChanged(networkId, copy(linkProperties));
             }
         }
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index 0611086..e042df4 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -59,6 +59,8 @@
 
 import static com.android.net.module.util.Inet4AddressUtils.inet4AddressToIntHTH;
 import static com.android.net.module.util.Inet4AddressUtils.intToInet4AddressHTH;
+import static com.android.networkstack.tethering.TestConnectivityManager.BROADCAST_FIRST;
+import static com.android.networkstack.tethering.TestConnectivityManager.CALLBACKS_FIRST;
 import static com.android.networkstack.tethering.Tethering.UserRestrictionActionListener;
 import static com.android.networkstack.tethering.TetheringNotificationUpdater.DOWNSTREAM_NONE;
 import static com.android.networkstack.tethering.UpstreamNetworkMonitor.EVENT_ON_CAPABILITIES;
@@ -332,13 +334,14 @@
             assertTrue("Non-mocked interface " + ifName,
                     ifName.equals(TEST_USB_IFNAME)
                             || ifName.equals(TEST_WLAN_IFNAME)
+                            || ifName.equals(TEST_WIFI_IFNAME)
                             || ifName.equals(TEST_MOBILE_IFNAME)
                             || ifName.equals(TEST_P2P_IFNAME)
                             || ifName.equals(TEST_NCM_IFNAME)
                             || ifName.equals(TEST_ETH_IFNAME));
             final String[] ifaces = new String[] {
-                    TEST_USB_IFNAME, TEST_WLAN_IFNAME, TEST_MOBILE_IFNAME, TEST_P2P_IFNAME,
-                    TEST_NCM_IFNAME, TEST_ETH_IFNAME};
+                    TEST_USB_IFNAME, TEST_WLAN_IFNAME, TEST_WIFI_IFNAME, TEST_MOBILE_IFNAME,
+                    TEST_P2P_IFNAME, TEST_NCM_IFNAME, TEST_ETH_IFNAME};
             return new InterfaceParams(ifName, ArrayUtils.indexOf(ifaces, ifName) + IFINDEX_OFFSET,
                     MacAddress.ALL_ZEROS_ADDRESS);
         }
@@ -618,6 +621,7 @@
         when(mOffloadHardwareInterface.getForwardedStats(any())).thenReturn(mForwardedStats);
 
         mServiceContext = new TestContext(mContext);
+        mServiceContext.setUseRegisteredHandlers(true);
         mContentResolver = new MockContentResolver(mServiceContext);
         mContentResolver.addProvider(Settings.AUTHORITY, new FakeSettingsProvider());
         setTetheringSupported(true /* supported */);
@@ -716,6 +720,7 @@
         final Intent intent = new Intent(WifiManager.WIFI_AP_STATE_CHANGED_ACTION);
         intent.putExtra(EXTRA_WIFI_AP_STATE, state);
         mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
+        mLooper.dispatchAll();
     }
 
     private void sendWifiApStateChanged(int state, String ifname, int ipmode) {
@@ -724,6 +729,7 @@
         intent.putExtra(EXTRA_WIFI_AP_INTERFACE_NAME, ifname);
         intent.putExtra(EXTRA_WIFI_AP_MODE, ipmode);
         mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
+        mLooper.dispatchAll();
     }
 
     private static final String[] P2P_RECEIVER_PERMISSIONS_FOR_BROADCAST = {
@@ -750,6 +756,7 @@
 
         mServiceContext.sendBroadcastAsUserMultiplePermissions(intent, UserHandle.ALL,
                 P2P_RECEIVER_PERMISSIONS_FOR_BROADCAST);
+        mLooper.dispatchAll();
     }
 
     private void sendUsbBroadcast(boolean connected, boolean configured, boolean function,
@@ -763,11 +770,13 @@
             intent.putExtra(USB_FUNCTION_NCM, function);
         }
         mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
+        mLooper.dispatchAll();
     }
 
     private void sendConfigurationChanged() {
         final Intent intent = new Intent(Intent.ACTION_CONFIGURATION_CHANGED);
         mServiceContext.sendStickyBroadcastAsUser(intent, UserHandle.ALL);
+        mLooper.dispatchAll();
     }
 
     private void verifyDefaultNetworkRequestFiled() {
@@ -809,7 +818,6 @@
             mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         }
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED);
-        mLooper.dispatchAll();
 
         // If, and only if, Tethering received an interface status changed then
         // it creates a IpServer and sends out a broadcast indicating that the
@@ -857,7 +865,6 @@
 
         // Pretend we then receive USB configured broadcast.
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
         // Now we should see the start of tethering mechanics (in this case:
         // tetherMatchingInterfaces() which starts by fetching all interfaces).
         verify(mNetd, times(1)).interfaceGetList();
@@ -886,7 +893,6 @@
             mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         }
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_LOCAL_ONLY);
-        mLooper.dispatchAll();
 
         verifyInterfaceServingModeStarted(TEST_WLAN_IFNAME);
         verifyTetheringBroadcast(TEST_WLAN_IFNAME, EXTRA_AVAILABLE_TETHER);
@@ -948,7 +954,6 @@
         initTetheringUpstream(upstreamState);
         prepareUsbTethering();
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
     }
 
     private void assertSetIfaceToDadProxy(final int numOfCalls, final String ifaceName) {
@@ -1099,29 +1104,44 @@
         // Start USB tethering with no current upstream.
         prepareUsbTethering();
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).startObserveAllNetworks();
-        inOrder.verify(mUpstreamNetworkMonitor).registerMobileNetworkRequest();
+        inOrder.verify(mUpstreamNetworkMonitor).setTryCell(true);
 
         // Pretend cellular connected and expect the upstream to be set.
         TestNetworkAgent mobile = new TestNetworkAgent(mCm, buildMobileDualStackUpstreamState());
         mobile.fakeConnect();
-        mCm.makeDefaultNetwork(mobile);
+        mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
 
         // Switch upstreams a few times.
-        // TODO: there may be a race where if the effects of the CONNECTIVITY_ACTION happen before
-        // UpstreamNetworkMonitor gets onCapabilitiesChanged on CALLBACK_DEFAULT_INTERNET, the
-        // upstream does not change. Extend TestConnectivityManager to simulate this condition and
-        // write a test for this.
         TestNetworkAgent wifi = new TestNetworkAgent(mCm, buildWifiUpstreamState());
         wifi.fakeConnect();
-        mCm.makeDefaultNetwork(wifi);
+        mCm.makeDefaultNetwork(wifi, BROADCAST_FIRST);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
 
-        mCm.makeDefaultNetwork(mobile);
+        // This code has historically been racy, so test different orderings of CONNECTIVITY_ACTION
+        // broadcasts and callbacks, optionally running mLooper.dispatchAll() between the two.
+        final Runnable doDispatchAll = () -> mLooper.dispatchAll();
+
+        mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST, doDispatchAll);
+        mLooper.dispatchAll();
+        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+
+        mCm.makeDefaultNetwork(wifi, BROADCAST_FIRST, doDispatchAll);
+        mLooper.dispatchAll();
+        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+
+        mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST);
+        mLooper.dispatchAll();
+        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+
+        mCm.makeDefaultNetwork(wifi, CALLBACKS_FIRST);
+        mLooper.dispatchAll();
+        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+
+        mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
 
@@ -1133,14 +1153,15 @@
         // Lose and regain upstream.
         assertTrue(mUpstreamNetworkMonitor.getCurrentPreferredUpstream().linkProperties
                 .hasIPv4Address());
+        mCm.makeDefaultNetwork(null, BROADCAST_FIRST, doDispatchAll);
+        mLooper.dispatchAll();
         mobile.fakeDisconnect();
-        mCm.makeDefaultNetwork(null);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
 
         mobile = new TestNetworkAgent(mCm, buildMobile464xlatUpstreamState());
         mobile.fakeConnect();
-        mCm.makeDefaultNetwork(mobile);
+        mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
 
@@ -1148,16 +1169,31 @@
         // mobile upstream, even though the netId is (unrealistically) the same.
         assertFalse(mUpstreamNetworkMonitor.getCurrentPreferredUpstream().linkProperties
                 .hasIPv4Address());
+
+        // Lose and regain upstream again.
+        mCm.makeDefaultNetwork(null, CALLBACKS_FIRST, doDispatchAll);
         mobile.fakeDisconnect();
-        mCm.makeDefaultNetwork(null);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
+
+        mobile = new TestNetworkAgent(mCm, buildMobileDualStackUpstreamState());
+        mobile.fakeConnect();
+        mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST, doDispatchAll);
+        mLooper.dispatchAll();
+        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+
+        assertTrue(mUpstreamNetworkMonitor.getCurrentPreferredUpstream().linkProperties
+                .hasIPv4Address());
+
+        // Check that the code does not crash if onLinkPropertiesChanged is received after onLost.
+        mobile.fakeDisconnect();
+        mobile.sendLinkProperties();
+        mLooper.dispatchAll();
     }
 
     private void runNcmTethering() {
         prepareNcmTethering();
         sendUsbBroadcast(true, true, true, TETHERING_NCM);
-        mLooper.dispatchAll();
     }
 
     @Test
@@ -1205,7 +1241,6 @@
         // tethering mode is to be started.
         mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED);
-        mLooper.dispatchAll();
 
         // There is 1 IpServer state change event: STATE_AVAILABLE
         verify(mNotificationUpdater, times(1)).onDownstreamChanged(DOWNSTREAM_NONE);
@@ -1233,7 +1268,6 @@
         // tethering mode is to be started.
         mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
-        mLooper.dispatchAll();
 
         verifyInterfaceServingModeStarted(TEST_WLAN_IFNAME);
         verifyTetheringBroadcast(TEST_WLAN_IFNAME, EXTRA_AVAILABLE_TETHER);
@@ -1251,7 +1285,7 @@
         verify(mUpstreamNetworkMonitor, times(1)).startObserveAllNetworks();
         // In tethering mode, in the default configuration, an explicit request
         // for a mobile network is also made.
-        verify(mUpstreamNetworkMonitor, times(1)).registerMobileNetworkRequest();
+        verify(mUpstreamNetworkMonitor, times(1)).setTryCell(true);
         // There are 2 IpServer state change events: STATE_AVAILABLE -> STATE_TETHERED
         verify(mNotificationUpdater, times(1)).onDownstreamChanged(DOWNSTREAM_NONE);
         verify(mNotificationUpdater, times(1)).onDownstreamChanged(eq(1 << TETHERING_WIFI));
@@ -1310,7 +1344,6 @@
         // tethering mode is to be started.
         mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
-        mLooper.dispatchAll();
 
         // We verify get/set called three times here: twice for setup and once during
         // teardown because all events happen over the course of the single
@@ -1634,7 +1667,6 @@
 
         mTethering.startTethering(createTetheringRequestParcel(TETHERING_WIFI), null);
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
-        mLooper.dispatchAll();
         tetherState = callback.pollTetherStatesChanged();
         assertArrayEquals(tetherState.tetheredList, new String[] {TEST_WLAN_IFNAME});
         callback.expectUpstreamChanged(upstreamState.network);
@@ -1656,7 +1688,6 @@
         mLooper.dispatchAll();
         mTethering.stopTethering(TETHERING_WIFI);
         sendWifiApStateChanged(WifiManager.WIFI_AP_STATE_DISABLED);
-        mLooper.dispatchAll();
         tetherState = callback2.pollTetherStatesChanged();
         assertArrayEquals(tetherState.availableList, new String[] {TEST_WLAN_IFNAME});
         mLooper.dispatchAll();
@@ -1749,7 +1780,6 @@
             mTethering.interfaceStatusChanged(TEST_P2P_IFNAME, true);
         }
         sendWifiP2pConnectionChanged(true, true, TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
 
         verifyInterfaceServingModeStarted(TEST_P2P_IFNAME);
         verifyTetheringBroadcast(TEST_P2P_IFNAME, EXTRA_AVAILABLE_TETHER);
@@ -1767,7 +1797,6 @@
         // is being removed.
         sendWifiP2pConnectionChanged(false, true, TEST_P2P_IFNAME);
         mTethering.interfaceRemoved(TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
 
         verify(mNetd, times(1)).tetherApplyDnsInterfaces();
         verify(mNetd, times(1)).tetherInterfaceRemove(TEST_P2P_IFNAME);
@@ -1790,7 +1819,6 @@
             mTethering.interfaceStatusChanged(TEST_P2P_IFNAME, true);
         }
         sendWifiP2pConnectionChanged(true, false, TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
 
         verify(mNetd, never()).interfaceSetCfg(any(InterfaceConfigurationParcel.class));
         verify(mNetd, never()).tetherInterfaceAdd(TEST_P2P_IFNAME);
@@ -1802,7 +1830,6 @@
         // is being removed.
         sendWifiP2pConnectionChanged(false, false, TEST_P2P_IFNAME);
         mTethering.interfaceRemoved(TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
 
         verify(mNetd, never()).tetherApplyDnsInterfaces();
         verify(mNetd, never()).tetherInterfaceRemove(TEST_P2P_IFNAME);
@@ -1838,7 +1865,6 @@
             mTethering.interfaceStatusChanged(TEST_P2P_IFNAME, true);
         }
         sendWifiP2pConnectionChanged(true, true, TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
 
         verify(mNetd, never()).interfaceSetCfg(any(InterfaceConfigurationParcel.class));
         verify(mNetd, never()).tetherInterfaceAdd(TEST_P2P_IFNAME);
@@ -1968,7 +1994,6 @@
         // Expect that when USB comes up, the DHCP server is configured with the requested address.
         mTethering.interfaceStatusChanged(TEST_USB_IFNAME, true);
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
         verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS).times(1)).startWithCallbacks(
                 any(), any());
         verify(mNetd).interfaceSetCfg(argThat(cfg -> serverAddr.equals(cfg.ipv4Addr)));
@@ -1988,7 +2013,6 @@
         verify(mUsbManager, times(1)).setCurrentFunctions(UsbManager.FUNCTION_RNDIS);
         mTethering.interfaceStatusChanged(TEST_USB_IFNAME, true);
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
         verify(mNetd).interfaceSetCfg(argThat(cfg -> serverAddr.equals(cfg.ipv4Addr)));
         verify(mIpServerDependencies, times(1)).makeDhcpServer(any(), dhcpParamsCaptor.capture(),
                 any());
@@ -2212,7 +2236,6 @@
 
         mTethering.interfaceStatusChanged(TEST_USB_IFNAME, true);
         sendUsbBroadcast(true, true, true, TETHERING_USB);
-        mLooper.dispatchAll();
         assertContains(Arrays.asList(mTethering.getTetherableIfaces()), TEST_USB_IFNAME);
         assertContains(Arrays.asList(mTethering.getTetherableIfaces()), TEST_ETH_IFNAME);
         assertEquals(TETHER_ERROR_IFACE_CFG_ERROR, mTethering.getLastTetherError(TEST_USB_IFNAME));
@@ -2251,7 +2274,6 @@
         // Run local only tethering.
         mTethering.interfaceStatusChanged(TEST_P2P_IFNAME, true);
         sendWifiP2pConnectionChanged(true, true, TEST_P2P_IFNAME);
-        mLooper.dispatchAll();
         verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
                 any(), dhcpEventCbsCaptor.capture());
         eventCallbacks = dhcpEventCbsCaptor.getValue();
@@ -2268,7 +2290,6 @@
         // Run wifi tethering.
         mTethering.interfaceStatusChanged(TEST_WLAN_IFNAME, true);
         sendWifiApStateChanged(WIFI_AP_STATE_ENABLED, TEST_WLAN_IFNAME, IFACE_IP_MODE_TETHERED);
-        mLooper.dispatchAll();
         verify(mDhcpServer, timeout(DHCPSERVER_START_TIMEOUT_MS)).startWithCallbacks(
                 any(), dhcpEventCbsCaptor.capture());
         eventCallbacks = dhcpEventCbsCaptor.getValue();
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
index 7d735fc..bc21692 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/UpstreamNetworkMonitorTest.java
@@ -41,7 +41,6 @@
 import static org.mockito.Mockito.when;
 
 import android.content.Context;
-import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.IConnectivityManager;
 import android.net.IpPrefix;
@@ -51,13 +50,16 @@
 import android.net.NetworkRequest;
 import android.net.util.SharedLog;
 import android.os.Handler;
+import android.os.Looper;
 import android.os.Message;
+import android.os.test.TestLooper;
 
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.internal.util.State;
 import com.android.internal.util.StateMachine;
+import com.android.networkstack.tethering.TestConnectivityManager.NetworkRequestInfo;
 import com.android.networkstack.tethering.TestConnectivityManager.TestNetworkAgent;
 
 import org.junit.After;
@@ -101,6 +103,8 @@
     private TestConnectivityManager mCM;
     private UpstreamNetworkMonitor mUNM;
 
+    private final TestLooper mLooper = new TestLooper();  // Looper backing mSM (see setUp).
+
     @Before public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
         reset(mContext);
@@ -110,9 +114,8 @@
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
 
         mCM = spy(new TestConnectivityManager(mContext, mCS, sDefaultRequest));
-        mSM = new TestStateMachine();
-        mUNM = new UpstreamNetworkMonitor(
-                (ConnectivityManager) mCM, mSM, mLog, EVENT_UNM_UPDATE);
+        mSM = new TestStateMachine(mLooper.getLooper());
+        mUNM = new UpstreamNetworkMonitor(mCM, mSM, mLog, EVENT_UNM_UPDATE);
     }
 
     @After public void tearDown() throws Exception {
@@ -134,9 +137,9 @@
         assertTrue(mCM.hasNoCallbacks());
         assertFalse(mUNM.mobileNetworkRequested());
 
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         assertTrue(mCM.hasNoCallbacks());
-        mUNM.updateMobileRequiresDun(false);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, false /* dunRequired */);
         assertTrue(mCM.hasNoCallbacks());
     }
 
@@ -146,7 +149,7 @@
         mUNM.startTrackDefaultNetwork(sDefaultRequest, mEntitleMgr);
 
         mUNM.startObserveAllNetworks();
-        assertEquals(1, mCM.trackingDefault.size());
+        assertEquals(1, mCM.mTrackingDefault.size());
 
         mUNM.stop();
         assertTrue(mCM.onlyHasDefaultCallbacks());
@@ -154,11 +157,11 @@
 
     @Test
     public void testListensForAllNetworks() throws Exception {
-        assertTrue(mCM.listening.isEmpty());
+        assertTrue(mCM.mListening.isEmpty());
 
         mUNM.startTrackDefaultNetwork(sDefaultRequest, mEntitleMgr);
         mUNM.startObserveAllNetworks();
-        assertFalse(mCM.listening.isEmpty());
+        assertFalse(mCM.mListening.isEmpty());
         assertTrue(mCM.isListeningForAll());
 
         mUNM.stop();
@@ -181,17 +184,17 @@
     @Test
     public void testRequestsMobileNetwork() throws Exception {
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
         mUNM.startObserveAllNetworks();
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
-        mUNM.updateMobileRequiresDun(false);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, false /* dunRequired */);
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
-        mUNM.registerMobileNetworkRequest();
+        mUNM.setTryCell(true);
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
         assertFalse(isDunRequested());
@@ -204,16 +207,16 @@
     @Test
     public void testDuplicateMobileRequestsIgnored() throws Exception {
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
         mUNM.startObserveAllNetworks();
         verify(mCM, times(1)).registerNetworkCallback(
                 any(NetworkRequest.class), any(NetworkCallback.class), any(Handler.class));
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
-        mUNM.updateMobileRequiresDun(true);
-        mUNM.registerMobileNetworkRequest();
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
+        mUNM.setTryCell(true);
         verify(mCM, times(1)).requestNetwork(
                 any(NetworkRequest.class), anyInt(), anyInt(), any(Handler.class),
                 any(NetworkCallback.class));
@@ -223,9 +226,9 @@
         assertTrue(isDunRequested());
 
         // Try a few things that must not result in any state change.
-        mUNM.registerMobileNetworkRequest();
-        mUNM.updateMobileRequiresDun(true);
-        mUNM.registerMobileNetworkRequest();
+        mUNM.setTryCell(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
+        mUNM.setTryCell(true);
 
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
@@ -240,17 +243,17 @@
     @Test
     public void testRequestsDunNetwork() throws Exception {
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
         mUNM.startObserveAllNetworks();
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         assertFalse(mUNM.mobileNetworkRequested());
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
-        mUNM.registerMobileNetworkRequest();
+        mUNM.setTryCell(true);
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
         assertTrue(isDunRequested());
@@ -265,18 +268,18 @@
         mUNM.startObserveAllNetworks();
 
         // Test going from no-DUN to DUN correctly re-registers callbacks.
-        mUNM.updateMobileRequiresDun(false);
-        mUNM.registerMobileNetworkRequest();
+        mUNM.setUpstreamConfig(false /* autoUpstream */, false /* dunRequired */);
+        mUNM.setTryCell(true);
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
         assertFalse(isDunRequested());
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_DUN);
         assertTrue(isDunRequested());
 
         // Test going from DUN to no-DUN correctly re-registers callbacks.
-        mUNM.updateMobileRequiresDun(false);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, false /* dunRequired */);
         assertTrue(mUNM.mobileNetworkRequested());
         assertUpstreamTypeRequested(TYPE_MOBILE_HIPRI);
         assertFalse(isDunRequested());
@@ -297,72 +300,78 @@
 
         final TestNetworkAgent wifiAgent = new TestNetworkAgent(mCM, WIFI_CAPABILITIES);
         wifiAgent.fakeConnect();
+        mLooper.dispatchAll();
         // WiFi is up, we should prefer it.
         assertSatisfiesLegacyType(TYPE_WIFI, mUNM.selectPreferredUpstreamType(preferredTypes));
         wifiAgent.fakeDisconnect();
+        mLooper.dispatchAll();
         // There are no networks, so there is nothing to select.
         assertSatisfiesLegacyType(TYPE_NONE, mUNM.selectPreferredUpstreamType(preferredTypes));
 
         final TestNetworkAgent cellAgent = new TestNetworkAgent(mCM, CELL_CAPABILITIES);
         cellAgent.fakeConnect();
+        mLooper.dispatchAll();
         assertSatisfiesLegacyType(TYPE_NONE, mUNM.selectPreferredUpstreamType(preferredTypes));
 
         preferredTypes.add(TYPE_MOBILE_DUN);
         // This is coupled with preferred types in TetheringConfiguration.
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         // DUN is available, but only use regular cell: no upstream selected.
         assertSatisfiesLegacyType(TYPE_NONE, mUNM.selectPreferredUpstreamType(preferredTypes));
         preferredTypes.remove(TYPE_MOBILE_DUN);
         // No WiFi, but our preferred flavour of cell is up.
         preferredTypes.add(TYPE_MOBILE_HIPRI);
         // This is coupled with preferred types in TetheringConfiguration.
-        mUNM.updateMobileRequiresDun(false);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, false /* dunRequired */);
         assertSatisfiesLegacyType(TYPE_MOBILE_HIPRI,
                 mUNM.selectPreferredUpstreamType(preferredTypes));
         // Check to see we filed an explicit request.
-        assertEquals(1, mCM.requested.size());
-        NetworkRequest netReq = (NetworkRequest) mCM.requested.values().toArray()[0];
+        assertEquals(1, mCM.mRequested.size());
+        NetworkRequest netReq = ((NetworkRequestInfo) mCM.mRequested.values().toArray()[0]).request;
         assertTrue(netReq.networkCapabilities.hasTransport(TRANSPORT_CELLULAR));
         assertFalse(netReq.networkCapabilities.hasCapability(NET_CAPABILITY_DUN));
         // mobile is not permitted, we should not use HIPRI.
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(false);
         assertSatisfiesLegacyType(TYPE_NONE, mUNM.selectPreferredUpstreamType(preferredTypes));
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
         // mobile change back to permitted, HIRPI should come back
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
         assertSatisfiesLegacyType(TYPE_MOBILE_HIPRI,
                 mUNM.selectPreferredUpstreamType(preferredTypes));
 
         wifiAgent.fakeConnect();
+        mLooper.dispatchAll();
         // WiFi is up, and we should prefer it over cell.
         assertSatisfiesLegacyType(TYPE_WIFI, mUNM.selectPreferredUpstreamType(preferredTypes));
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
 
         preferredTypes.remove(TYPE_MOBILE_HIPRI);
         preferredTypes.add(TYPE_MOBILE_DUN);
         // This is coupled with preferred types in TetheringConfiguration.
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         assertSatisfiesLegacyType(TYPE_WIFI, mUNM.selectPreferredUpstreamType(preferredTypes));
 
         final TestNetworkAgent dunAgent = new TestNetworkAgent(mCM, DUN_CAPABILITIES);
         dunAgent.fakeConnect();
+        mLooper.dispatchAll();
 
         // WiFi is still preferred.
         assertSatisfiesLegacyType(TYPE_WIFI, mUNM.selectPreferredUpstreamType(preferredTypes));
 
         // WiFi goes down, cell and DUN are still up but only DUN is preferred.
         wifiAgent.fakeDisconnect();
+        mLooper.dispatchAll();
         assertSatisfiesLegacyType(TYPE_MOBILE_DUN,
                 mUNM.selectPreferredUpstreamType(preferredTypes));
         // Check to see we filed an explicit request.
-        assertEquals(1, mCM.requested.size());
-        netReq = (NetworkRequest) mCM.requested.values().toArray()[0];
+        assertEquals(1, mCM.mRequested.size());
+        netReq = ((NetworkRequestInfo) mCM.mRequested.values().toArray()[0]).request;
         assertTrue(netReq.networkCapabilities.hasTransport(TRANSPORT_CELLULAR));
         assertTrue(netReq.networkCapabilities.hasCapability(NET_CAPABILITY_DUN));
         // mobile is not permitted, we should not use DUN.
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(false);
         assertSatisfiesLegacyType(TYPE_NONE, mUNM.selectPreferredUpstreamType(preferredTypes));
-        assertEquals(0, mCM.requested.size());
+        assertEquals(0, mCM.mRequested.size());
         // mobile change back to permitted, DUN should come back
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
         assertSatisfiesLegacyType(TYPE_MOBILE_DUN,
@@ -373,49 +382,72 @@
     public void testGetCurrentPreferredUpstream() throws Exception {
         mUNM.startTrackDefaultNetwork(sDefaultRequest, mEntitleMgr);
         mUNM.startObserveAllNetworks();
-        mUNM.updateMobileRequiresDun(false);
+        mUNM.setUpstreamConfig(true /* autoUpstream */, false /* dunRequired */);
+        mUNM.setTryCell(true);
 
         // [0] Mobile connects, DUN not required -> mobile selected.
         final TestNetworkAgent cellAgent = new TestNetworkAgent(mCM, CELL_CAPABILITIES);
         cellAgent.fakeConnect();
         mCM.makeDefaultNetwork(cellAgent);
+        mLooper.dispatchAll();
         assertEquals(cellAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(0, mCM.mRequested.size());
 
         // [1] Mobile connects but not permitted -> null selected
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(false);
         assertEquals(null, mUNM.getCurrentPreferredUpstream());
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
+        assertEquals(0, mCM.mRequested.size());
 
         // [2] WiFi connects but not validated/promoted to default -> mobile selected.
         final TestNetworkAgent wifiAgent = new TestNetworkAgent(mCM, WIFI_CAPABILITIES);
         wifiAgent.fakeConnect();
+        mLooper.dispatchAll();
         assertEquals(cellAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(0, mCM.mRequested.size());
 
         // [3] WiFi validates and is promoted to the default network -> WiFi selected.
         mCM.makeDefaultNetwork(wifiAgent);
+        mLooper.dispatchAll();
         assertEquals(wifiAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(0, mCM.mRequested.size());
 
         // [4] DUN required, no other changes -> WiFi still selected
-        mUNM.updateMobileRequiresDun(true);
+        mUNM.setUpstreamConfig(false /* autoUpstream */, true /* dunRequired */);
         assertEquals(wifiAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(1, mCM.mRequested.size());
+        assertTrue(isDunRequested());
 
         // [5] WiFi no longer validated, mobile becomes default, DUN required -> null selected.
         mCM.makeDefaultNetwork(cellAgent);
+        mLooper.dispatchAll();
         assertEquals(null, mUNM.getCurrentPreferredUpstream());
-        // TODO: make sure that a DUN request has been filed. This is currently
-        // triggered by code over in Tethering, but once that has been moved
-        // into UNM we should test for this here.
+        assertEquals(1, mCM.mRequested.size());
+        assertTrue(isDunRequested());
 
         // [6] DUN network arrives -> DUN selected
         final TestNetworkAgent dunAgent = new TestNetworkAgent(mCM, CELL_CAPABILITIES);
         dunAgent.networkCapabilities.addCapability(NET_CAPABILITY_DUN);
         dunAgent.networkCapabilities.removeCapability(NET_CAPABILITY_INTERNET);
         dunAgent.fakeConnect();
+        mLooper.dispatchAll();
         assertEquals(dunAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(1, mCM.mRequested.size());
 
         // [7] Mobile is not permitted -> null selected
         when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(false);
         assertEquals(null, mUNM.getCurrentPreferredUpstream());
+        assertEquals(1, mCM.mRequested.size());
+
+        // [8] Mobile is permitted again -> DUN selected
+        when(mEntitleMgr.isCellularUpstreamPermitted()).thenReturn(true);
+        assertEquals(dunAgent.networkId, mUNM.getCurrentPreferredUpstream().network);
+        assertEquals(1, mCM.mRequested.size());
+
+        // [9] DUN no longer required -> request is withdrawn
+        mUNM.setUpstreamConfig(true /* autoUpstream */, false /* dunRequired */);
+        assertEquals(0, mCM.mRequested.size());
+        assertFalse(isDunRequested());
     }
 
     @Test
@@ -445,6 +477,7 @@
         }
         wifiAgent.fakeConnect();
         wifiAgent.sendLinkProperties();
+        mLooper.dispatchAll();
 
         local = mUNM.getLocalPrefixes();
         assertPrefixSet(local, INCLUDES, alreadySeen);
@@ -469,6 +502,7 @@
         }
         cellAgent.fakeConnect();
         cellAgent.sendLinkProperties();
+        mLooper.dispatchAll();
 
         local = mUNM.getLocalPrefixes();
         assertPrefixSet(local, INCLUDES, alreadySeen);
@@ -490,6 +524,7 @@
         }
         dunAgent.fakeConnect();
         dunAgent.sendLinkProperties();
+        mLooper.dispatchAll();
 
         local = mUNM.getLocalPrefixes();
         assertPrefixSet(local, INCLUDES, alreadySeen);
@@ -501,6 +536,7 @@
         // [4] Pretend Wi-Fi disconnected.  It's addresses/prefixes should no
         // longer be included (should be properly removed).
         wifiAgent.fakeDisconnect();
+        mLooper.dispatchAll();
         local = mUNM.getLocalPrefixes();
         assertPrefixSet(local, EXCLUDES, wifiLinkPrefixes);
         assertPrefixSet(local, INCLUDES, cellLinkPrefixes);
@@ -508,6 +544,7 @@
 
         // [5] Pretend mobile disconnected.
         cellAgent.fakeDisconnect();
+        mLooper.dispatchAll();
         local = mUNM.getLocalPrefixes();
         assertPrefixSet(local, EXCLUDES, wifiLinkPrefixes);
         assertPrefixSet(local, EXCLUDES, cellLinkPrefixes);
@@ -515,6 +552,7 @@
 
         // [6] Pretend DUN disconnected.
         dunAgent.fakeDisconnect();
+        mLooper.dispatchAll();
         local = mUNM.getLocalPrefixes();
         assertTrue(local.isEmpty());
     }
@@ -534,6 +572,7 @@
         // Setup mobile network.
         final TestNetworkAgent cellAgent = new TestNetworkAgent(mCM, CELL_CAPABILITIES);
         cellAgent.fakeConnect();
+        mLooper.dispatchAll();
 
         assertSatisfiesLegacyType(TYPE_MOBILE_HIPRI,
                 mUNM.selectPreferredUpstreamType(preferredTypes));
@@ -552,15 +591,15 @@
     }
 
     private void assertUpstreamTypeRequested(int upstreamType) throws Exception {
-        assertEquals(1, mCM.requested.size());
-        assertEquals(1, mCM.legacyTypeMap.size());
+        assertEquals(1, mCM.mRequested.size());
+        assertEquals(1, mCM.mLegacyTypeMap.size());
         assertEquals(Integer.valueOf(upstreamType),
-                mCM.legacyTypeMap.values().iterator().next());
+                mCM.mLegacyTypeMap.values().iterator().next());
     }
 
     private boolean isDunRequested() {
-        for (NetworkRequest req : mCM.requested.values()) {
-            if (req.networkCapabilities.hasCapability(NET_CAPABILITY_DUN)) {
+        for (NetworkRequestInfo nri : mCM.mRequested.values()) {
+            if (nri.request.networkCapabilities.hasCapability(NET_CAPABILITY_DUN)) {
                 return true;
             }
         }
@@ -586,8 +625,8 @@
             }
         }
 
-        public TestStateMachine() {
-            super("UpstreamNetworkMonitor.TestStateMachine");
+        public TestStateMachine(Looper looper) {
+            super("UpstreamNetworkMonitor.TestStateMachine", looper);
             addState(mLoggingState);
             setInitialState(mLoggingState);
             super.start();
diff --git a/tests/cts/hostside/aidl/com/android/cts/net/hostside/IMyService.aidl b/tests/cts/hostside/aidl/com/android/cts/net/hostside/IMyService.aidl
index 5aafdf0..f523745 100644
--- a/tests/cts/hostside/aidl/com/android/cts/net/hostside/IMyService.aidl
+++ b/tests/cts/hostside/aidl/com/android/cts/net/hostside/IMyService.aidl
@@ -24,6 +24,6 @@
     String checkNetworkStatus();
     String getRestrictBackgroundStatus();
     void sendNotification(int notificationId, String notificationType);
-    void registerNetworkCallback(in INetworkCallback cb);
+    void registerNetworkCallback(in NetworkRequest request, in INetworkCallback cb);
     void unregisterNetworkCallback();
 }
diff --git a/tests/cts/hostside/app/Android.bp b/tests/cts/hostside/app/Android.bp
index 50fda6d..a9686de 100644
--- a/tests/cts/hostside/app/Android.bp
+++ b/tests/cts/hostside/app/Android.bp
@@ -27,6 +27,7 @@
         "androidx.test.rules",
         "androidx.test.ext.junit",
         "compatibility-device-util-axt",
+        "cts-net-utils",
         "ctstestrunner-axt",
         "ub-uiautomator",
         "CtsHostsideNetworkTestsAidl",
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
index f423503..1afbfb0 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
@@ -25,7 +25,8 @@
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.getConnectivityManager;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.getContext;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.getInstrumentation;
-import static com.android.cts.net.hostside.NetworkPolicyTestUtils.getWifiManager;
+import static com.android.cts.net.hostside.NetworkPolicyTestUtils.isAppStandbySupported;
+import static com.android.cts.net.hostside.NetworkPolicyTestUtils.isBatterySaverSupported;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.isDozeModeSupported;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.restrictBackgroundValueToString;
 
@@ -46,12 +47,11 @@
 import android.net.ConnectivityManager;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkInfo.State;
-import android.net.wifi.WifiManager;
+import android.net.NetworkRequest;
 import android.os.BatteryManager;
 import android.os.Binder;
 import android.os.Bundle;
 import android.os.SystemClock;
-import android.provider.Settings;
 import android.service.notification.NotificationListenerService;
 import android.util.Log;
 
@@ -628,6 +628,9 @@
     }
 
     protected void setBatterySaverMode(boolean enabled) throws Exception {
+        if (!isBatterySaverSupported()) {
+            return;
+        }
         Log.i(TAG, "Setting Battery Saver Mode to " + enabled);
         if (enabled) {
             turnBatteryOn();
@@ -639,8 +642,9 @@
     }
 
     protected void setDozeMode(boolean enabled) throws Exception {
-        // Check doze mode is supported.
-        assertTrue("Device does not support Doze Mode", isDozeModeSupported());
+        if (!isDozeModeSupported()) {
+            return;
+        }
 
         Log.i(TAG, "Setting Doze Mode to " + enabled);
         if (enabled) {
@@ -660,12 +664,18 @@
     }
 
     protected void setAppIdle(boolean enabled) throws Exception {
+        if (!isAppStandbySupported()) {
+            return;
+        }
         Log.i(TAG, "Setting app idle to " + enabled);
         executeSilentShellCommand("am set-inactive " + TEST_APP2_PKG + " " + enabled );
         assertAppIdle(enabled);
     }
 
     protected void setAppIdleNoAssert(boolean enabled) throws Exception {
+        if (!isAppStandbySupported()) {
+            return;
+        }
         Log.i(TAG, "Setting app idle to " + enabled);
         executeSilentShellCommand("am set-inactive " + TEST_APP2_PKG + " " + enabled );
     }
@@ -704,8 +714,10 @@
         fail("app2 receiver is not ready");
     }
 
-    protected void registerNetworkCallback(INetworkCallback cb) throws Exception {
-        mServiceClient.registerNetworkCallback(cb);
+    protected void registerNetworkCallback(final NetworkRequest request, INetworkCallback cb)
+            throws Exception {
+        Log.i(TAG, "Registering network callback for request: " + request);
+        mServiceClient.registerNetworkCallback(request, cb);
     }
 
     protected void unregisterNetworkCallback() throws Exception {
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/DumpOnFailureRule.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/DumpOnFailureRule.java
index 5ecb399..66cb935 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/DumpOnFailureRule.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/DumpOnFailureRule.java
@@ -24,6 +24,8 @@
 import android.os.ParcelFileDescriptor;
 import android.util.Log;
 
+import androidx.test.platform.app.InstrumentationRegistry;
+
 import com.android.compatibility.common.util.OnFailureRule;
 
 import org.junit.AssumptionViolatedException;
@@ -37,23 +39,20 @@
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 
-import androidx.test.platform.app.InstrumentationRegistry;
-
 public class DumpOnFailureRule extends OnFailureRule {
     private File mDumpDir = new File(Environment.getExternalStorageDirectory(),
             "CtsHostsideNetworkTests");
 
     @Override
     public void onTestFailure(Statement base, Description description, Throwable throwable) {
-        final String testName = description.getClassName() + "_" + description.getMethodName();
-
         if (throwable instanceof AssumptionViolatedException) {
+            final String testName = description.getClassName() + "_" + description.getMethodName();
             Log.d(TAG, "Skipping test " + testName + ": " + throwable);
             return;
         }
 
         prepareDumpRootDir();
-        final File dumpFile = new File(mDumpDir, "dump-" + testName);
+        final File dumpFile = new File(mDumpDir, "dump-" + getShortenedTestName(description));
         Log.i(TAG, "Dumping debug info for " + description + ": " + dumpFile.getPath());
         try (FileOutputStream out = new FileOutputStream(dumpFile)) {
             for (String cmd : new String[] {
@@ -71,6 +70,17 @@
         }
     }
 
+    private String getShortenedTestName(Description description) {
+        final String qualifiedClassName = description.getClassName();
+        final String className = qualifiedClassName.substring(
+                qualifiedClassName.lastIndexOf(".") + 1);
+        final String shortenedClassName = className.chars()
+                .filter(Character::isUpperCase)
+                .collect(StringBuilder::new, StringBuilder::appendCodePoint, StringBuilder::append)
+                .toString();
+        return shortenedClassName + "_" + description.getMethodName();
+    }
+
     void dumpCommandOutput(FileOutputStream out, String cmd) {
         final ParcelFileDescriptor pfd = InstrumentationRegistry.getInstrumentation()
                 .getUiAutomation().executeShellCommand(cmd);
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyActivity.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyActivity.java
index 0d0bc58..55eec11 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyActivity.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyActivity.java
@@ -17,6 +17,7 @@
 package com.android.cts.net.hostside;
 
 import android.app.Activity;
+import android.app.KeyguardManager;
 import android.content.Intent;
 import android.os.Bundle;
 import android.view.WindowManager;
@@ -34,6 +35,11 @@
         getWindow().addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON
                 | WindowManager.LayoutParams.FLAG_TURN_SCREEN_ON
                 | WindowManager.LayoutParams.FLAG_DISMISS_KEYGUARD);
+
+        // Dismiss the keyguard so that the tests can click on the VPN confirmation dialog.
+        // FLAG_DISMISS_KEYGUARD is not sufficient to do this because as soon as the dialog appears,
+        // this activity goes into the background and the keyguard reappears.
+        getSystemService(KeyguardManager.class).requestDismissKeyguard(this, null /* callback */);
     }
 
     @Override
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyServiceClient.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyServiceClient.java
index 6546e26..c37e8d5 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyServiceClient.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/MyServiceClient.java
@@ -20,12 +20,11 @@
 import android.content.Context;
 import android.content.Intent;
 import android.content.ServiceConnection;
+import android.net.NetworkRequest;
 import android.os.ConditionVariable;
 import android.os.IBinder;
 import android.os.RemoteException;
 
-import com.android.cts.net.hostside.IMyService;
-
 public class MyServiceClient {
     private static final int TIMEOUT_MS = 5000;
     private static final String PACKAGE = MyServiceClient.class.getPackage().getName();
@@ -93,12 +92,14 @@
         return mService.getRestrictBackgroundStatus();
     }
 
-    public void sendNotification(int notificationId, String notificationType) throws RemoteException {
+    public void sendNotification(int notificationId, String notificationType)
+            throws RemoteException {
         mService.sendNotification(notificationId, notificationType);
     }
 
-    public void registerNetworkCallback(INetworkCallback cb) throws RemoteException {
-        mService.registerNetworkCallback(cb);
+    public void registerNetworkCallback(final NetworkRequest request, INetworkCallback cb)
+            throws RemoteException {
+        mService.registerNetworkCallback(request, cb);
     }
 
     public void unregisterNetworkCallback() throws RemoteException {
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
index 955317b..36e2ffe 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkCallbackTest.java
@@ -19,6 +19,7 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
 
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.canChangeActiveNetworkMeteredness;
+import static com.android.cts.net.hostside.NetworkPolicyTestUtils.getActiveNetworkCapabilities;
 import static com.android.cts.net.hostside.NetworkPolicyTestUtils.setRestrictBackground;
 import static com.android.cts.net.hostside.Property.BATTERY_SAVER_MODE;
 import static com.android.cts.net.hostside.Property.DATA_SAVER_MODE;
@@ -29,6 +30,7 @@
 
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
 import android.util.Log;
 
 import org.junit.After;
@@ -195,11 +197,16 @@
         setBatterySaverMode(false);
         setRestrictBackground(false);
 
+        // Get transports of the active network, this has to be done before changing meteredness,
+        // since wifi will be disconnected when changing from non-metered to metered.
+        final NetworkCapabilities networkCapabilities = getActiveNetworkCapabilities();
+
         // Mark network as metered.
         mMeterednessConfiguration.configureNetworkMeteredness(true);
 
         // Register callback
-        registerNetworkCallback((INetworkCallback.Stub) mTestNetworkCallback);
+        registerNetworkCallback(new NetworkRequest.Builder()
+                        .setCapabilities(networkCapabilities).build(), mTestNetworkCallback);
         // Wait for onAvailable() callback to ensure network is available before the test
         // and store the default network.
         mNetwork = mTestNetworkCallback.expectAvailableCallbackAndGetNetwork();
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
index b61535b..e62d557 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
@@ -22,12 +22,13 @@
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_METERED;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+import static android.net.wifi.WifiConfiguration.METERED_OVERRIDE_METERED;
+import static android.net.wifi.WifiConfiguration.METERED_OVERRIDE_NONE;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommand;
 import static com.android.cts.net.hostside.AbstractRestrictBackgroundNetworkTestCase.TAG;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
@@ -35,19 +36,21 @@
 
 import android.app.ActivityManager;
 import android.app.Instrumentation;
+import android.app.UiAutomation;
 import android.content.Context;
 import android.location.LocationManager;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.wifi.WifiConfiguration;
 import android.net.wifi.WifiManager;
+import android.net.wifi.WifiManager.ActionListener;
 import android.os.PersistableBundle;
 import android.os.Process;
 import android.telephony.CarrierConfigManager;
 import android.telephony.SubscriptionManager;
 import android.telephony.data.ApnSetting;
-import android.text.TextUtils;
 import android.util.Log;
 
 import androidx.test.platform.app.InstrumentationRegistry;
@@ -57,7 +60,12 @@
 import com.android.compatibility.common.util.ShellIdentityUtils;
 import com.android.compatibility.common.util.ThrowingRunnable;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 
 public class NetworkPolicyTestUtils {
@@ -98,11 +106,11 @@
         if (mDataSaverSupported == null) {
             assertMyRestrictBackgroundStatus(RESTRICT_BACKGROUND_STATUS_DISABLED);
             try {
-                setRestrictBackground(true);
+                setRestrictBackgroundInternal(true);
                 mDataSaverSupported = !isMyRestrictBackgroundStatus(
                         RESTRICT_BACKGROUND_STATUS_DISABLED);
             } finally {
-                setRestrictBackground(false);
+                setRestrictBackgroundInternal(false);
             }
         }
         return mDataSaverSupported;
@@ -182,41 +190,89 @@
     }
 
     private static String getWifiSsid() {
-        final boolean isLocationEnabled = isLocationEnabled();
+        final UiAutomation uiAutomation = getInstrumentation().getUiAutomation();
         try {
-            if (!isLocationEnabled) {
-                setLocationEnabled(true);
-            }
-            final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
+            uiAutomation.adoptShellPermissionIdentity();
+            final String ssid = getWifiManager().getConnectionInfo().getSSID();
             assertNotEquals(WifiManager.UNKNOWN_SSID, ssid);
             return ssid;
         } finally {
-            // Reset the location enabled state
-            if (!isLocationEnabled) {
-                setLocationEnabled(false);
-            }
+            uiAutomation.dropShellPermissionIdentity();
         }
     }
 
-    private static NetworkCapabilities getActiveNetworkCapabilities() {
+    static NetworkCapabilities getActiveNetworkCapabilities() {
         final Network activeNetwork = getConnectivityManager().getActiveNetwork();
         assertNotNull("No active network available", activeNetwork);
         return getConnectivityManager().getNetworkCapabilities(activeNetwork);
     }
 
     private static void setWifiMeteredStatus(String ssid, boolean metered) throws Exception {
-        assertFalse("SSID should not be empty", TextUtils.isEmpty(ssid));
-        final String cmd = "cmd netpolicy set metered-network " + ssid + " " + metered;
-        executeShellCommand(cmd);
-        assertWifiMeteredStatus(ssid, metered);
-        assertActiveNetworkMetered(metered);
+        final UiAutomation uiAutomation = getInstrumentation().getUiAutomation();
+        try {
+            uiAutomation.adoptShellPermissionIdentity();
+            final WifiConfiguration currentConfig = getWifiConfiguration(ssid);
+            currentConfig.meteredOverride = metered
+                    ? METERED_OVERRIDE_METERED : METERED_OVERRIDE_NONE;
+            BlockingQueue<Integer> blockingQueue = new LinkedBlockingQueue<>();
+            getWifiManager().save(currentConfig, createActionListener(
+                    blockingQueue, Integer.MAX_VALUE));
+            Integer resultCode = blockingQueue.poll(TIMEOUT_CHANGE_METEREDNESS_MS,
+                    TimeUnit.MILLISECONDS);
+            if (resultCode == null) {
+                fail("Timed out waiting for meteredness to change; ssid=" + ssid
+                        + ", metered=" + metered);
+            } else if (resultCode != Integer.MAX_VALUE) {
+                fail("Error overriding the meteredness; ssid=" + ssid
+                        + ", metered=" + metered + ", error=" + resultCode);
+            }
+            final boolean success = assertActiveNetworkMetered(metered, false /* throwOnFailure */);
+            if (!success) {
+                Log.i(TAG, "Retry connecting to wifi; ssid=" + ssid);
+                blockingQueue = new LinkedBlockingQueue<>();
+                getWifiManager().connect(currentConfig, createActionListener(
+                        blockingQueue, Integer.MAX_VALUE));
+                resultCode = blockingQueue.poll(TIMEOUT_CHANGE_METEREDNESS_MS,
+                        TimeUnit.MILLISECONDS);
+                if (resultCode == null) {
+                    fail("Timed out waiting for wifi to connect; ssid=" + ssid);
+                } else if (resultCode != Integer.MAX_VALUE) {
+                    fail("Error connecting to wifi; ssid=" + ssid
+                            + ", error=" + resultCode);
+                }
+                assertActiveNetworkMetered(metered, true /* throwOnFailure */);
+            }
+        } finally {
+            uiAutomation.dropShellPermissionIdentity();
+        }
     }
 
-    private static void assertWifiMeteredStatus(String ssid, boolean expectedMeteredStatus) {
-        final String result = executeShellCommand("cmd netpolicy list wifi-networks");
-        final String expectedLine = ssid + ";" + expectedMeteredStatus;
-        assertTrue("Expected line: " + expectedLine + "; Actual result: " + result,
-                result.contains(expectedLine));
+    private static WifiConfiguration getWifiConfiguration(String ssid) {
+        final List<String> ssids = new ArrayList<>();
+        for (WifiConfiguration config : getWifiManager().getConfiguredNetworks()) {
+            if (config.SSID.equals(ssid)) {
+                return config;
+            }
+            ssids.add(config.SSID);
+        }
+        fail("Couldn't find the wifi config; ssid=" + ssid
+                + ", all=" + Arrays.toString(ssids.toArray()));
+        return null;
+    }
+
+    private static ActionListener createActionListener(BlockingQueue<Integer> blockingQueue,
+            int successCode) {
+        return new ActionListener() {
+            @Override
+            public void onSuccess() {
+                blockingQueue.offer(successCode);
+            }
+
+            @Override
+            public void onFailure(int reason) {
+                blockingQueue.offer(reason);
+            }
+        };
     }
 
     private static void setCellularMeteredStatus(int subId, boolean metered) throws Exception {
@@ -225,11 +281,11 @@
                 new String[] {ApnSetting.TYPE_MMS_STRING});
         ShellIdentityUtils.invokeMethodWithShellPermissionsNoReturn(getCarrierConfigManager(),
                 (cm) -> cm.overrideConfig(subId, metered ? null : bundle));
-        assertActiveNetworkMetered(metered);
+        assertActiveNetworkMetered(metered, true /* throwOnFailure */);
     }
 
-    // Copied from cts/tests/tests/net/src/android/net/cts/ConnectivityManagerTest.java
-    private static void assertActiveNetworkMetered(boolean expectedMeteredStatus) throws Exception {
+    private static boolean assertActiveNetworkMetered(boolean expectedMeteredStatus,
+            boolean throwOnFailure) throws Exception {
         final CountDownLatch latch = new CountDownLatch(1);
         final NetworkCallback networkCallback = new NetworkCallback() {
             @Override
@@ -246,16 +302,29 @@
         getConnectivityManager().registerDefaultNetworkCallback(networkCallback);
         try {
             if (!latch.await(TIMEOUT_CHANGE_METEREDNESS_MS, TimeUnit.MILLISECONDS)) {
-                fail("Timed out waiting for active network metered status to change to "
-                        + expectedMeteredStatus + "; network = "
-                        + getConnectivityManager().getActiveNetwork());
+                final String errorMsg = "Timed out waiting for active network metered status "
+                        + "to change to " + expectedMeteredStatus + "; network = "
+                        + getConnectivityManager().getActiveNetwork();
+                if (throwOnFailure) {
+                    fail(errorMsg);
+                }
+                Log.w(TAG, errorMsg);
+                return false;
             }
+            return true;
         } finally {
             getConnectivityManager().unregisterNetworkCallback(networkCallback);
         }
     }
 
     public static void setRestrictBackground(boolean enabled) {
+        if (!isDataSaverSupported()) {
+            return;
+        }
+        setRestrictBackgroundInternal(enabled);
+    }
+
+    private static void setRestrictBackgroundInternal(boolean enabled) {
         executeShellCommand("cmd netpolicy set restrict-background " + enabled);
         final String output = executeShellCommand("cmd netpolicy get restrict-background");
         final String expectedSuffix = enabled ? "enabled" : "disabled";
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index 9b437e6..c0600e7 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -61,6 +61,7 @@
 import android.os.ParcelFileDescriptor;
 import android.os.Process;
 import android.os.SystemProperties;
+import android.os.UserHandle;
 import android.provider.Settings;
 import android.support.test.uiautomator.UiDevice;
 import android.support.test.uiautomator.UiObject;
@@ -76,6 +77,7 @@
 
 import com.android.compatibility.common.util.BlockingBroadcastReceiver;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.testutils.TestableNetworkCallback;
 
 import java.io.Closeable;
 import java.io.FileDescriptor;
@@ -92,6 +94,7 @@
 import java.net.UnknownHostException;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 import java.util.Random;
 import java.util.concurrent.CompletableFuture;
@@ -698,34 +701,6 @@
         setAndVerifyPrivateDns(initialMode);
     }
 
-    private class NeverChangeNetworkCallback extends NetworkCallback {
-        private CountDownLatch mLatch = new CountDownLatch(1);
-        private volatile Network mFirstNetwork;
-        private volatile Network mOtherNetwork;
-
-        public void onAvailable(Network n) {
-            // Don't assert here, as it crashes the test with a hard to debug message.
-            if (mFirstNetwork == null) {
-                mFirstNetwork = n;
-                mLatch.countDown();
-            } else if (mOtherNetwork == null) {
-                mOtherNetwork = n;
-            }
-        }
-
-        public Network getFirstNetwork() throws Exception {
-            assertTrue(
-                    "System default callback got no network after " + TIMEOUT_MS + "ms. "
-                    + "Please ensure the device has a working Internet connection.",
-                    mLatch.await(TIMEOUT_MS, TimeUnit.MILLISECONDS));
-            return mFirstNetwork;
-        }
-
-        public void assertNeverChanged() {
-            assertNull(mOtherNetwork);
-        }
-    }
-
     public void testDefault() throws Exception {
         if (!supportedHardware()) return;
         // If adb TCP port opened, this test may running by adb over network.
@@ -741,13 +716,24 @@
                 getInstrumentation().getTargetContext(), MyVpnService.ACTION_ESTABLISHED);
         receiver.register();
 
-        // Expect the system default network not to change.
-        final NeverChangeNetworkCallback neverChangeCallback = new NeverChangeNetworkCallback();
+        // Test the behaviour of a variety of types of network callbacks.
         final Network defaultNetwork = mCM.getActiveNetwork();
+        final TestableNetworkCallback systemDefaultCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback otherUidCallback = new TestableNetworkCallback();
+        final TestableNetworkCallback myUidCallback = new TestableNetworkCallback();
         if (SdkLevel.isAtLeastS()) {
-            runWithShellPermissionIdentity(() ->
-                    mCM.registerSystemDefaultNetworkCallback(neverChangeCallback,
-                            new Handler(Looper.getMainLooper())), NETWORK_SETTINGS);
+            final int otherUid = UserHandle.getUid(UserHandle.of(5), Process.FIRST_APPLICATION_UID);
+            final Handler h = new Handler(Looper.getMainLooper());
+            runWithShellPermissionIdentity(() -> {
+                mCM.registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
+                mCM.registerDefaultNetworkCallbackAsUid(otherUid, otherUidCallback, h);
+                mCM.registerDefaultNetworkCallbackAsUid(Process.myUid(), myUidCallback, h);
+            }, NETWORK_SETTINGS);
+            for (TestableNetworkCallback callback :
+                    List.of(systemDefaultCallback, otherUidCallback, myUidCallback)) {
+                callback.expectAvailableCallbacks(defaultNetwork, false /* suspended */,
+                        true /* validated */, false /* blocked */, TIMEOUT_MS);
+            }
         }
 
         FileDescriptor fd = openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS);
@@ -767,20 +753,24 @@
 
         checkTrafficOnVpn();
 
-        maybeExpectVpnTransportInfo(mCM.getActiveNetwork());
+        final Network vpnNetwork = mCM.getActiveNetwork();
+        myUidCallback.expectAvailableThenValidatedCallbacks(vpnNetwork, TIMEOUT_MS);
+        assertEquals(vpnNetwork, mCM.getActiveNetwork());
+        assertNotEqual(defaultNetwork, vpnNetwork);
+        maybeExpectVpnTransportInfo(vpnNetwork);
 
-        assertNotEqual(defaultNetwork, mCM.getActiveNetwork());
         if (SdkLevel.isAtLeastS()) {
             // Check that system default network callback has not seen any network changes, even
-            // though the app's default network changed. This needs to be done before testing
-            // private DNS because checkStrictModePrivateDns will set the private DNS server to
-            // a nonexistent name, which will cause validation to fail and cause the default
-            // network to switch (e.g., from wifi to cellular).
-            assertEquals(defaultNetwork, neverChangeCallback.getFirstNetwork());
-            neverChangeCallback.assertNeverChanged();
-            runWithShellPermissionIdentity(
-                    () -> mCM.unregisterNetworkCallback(neverChangeCallback),
-                    NETWORK_SETTINGS);
+            // though the app's default network changed. Also check that otherUidCallback saw no
+            // network changes, because otherUid is in a different user and not subject to the VPN.
+            // This needs to be done before testing private DNS because checkStrictModePrivateDns
+            // will set the private DNS server to a nonexistent name, which will cause validation to
+            // fail and could cause the default network to switch (e.g., from wifi to cellular).
+            systemDefaultCallback.assertNoCallback();
+            otherUidCallback.assertNoCallback();
+            mCM.unregisterNetworkCallback(systemDefaultCallback);
+            mCM.unregisterNetworkCallback(otherUidCallback);
+            mCM.unregisterNetworkCallback(myUidCallback);
         }
 
         checkStrictModePrivateDns();
diff --git a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
index 1c9ff05..8a5e00f 100644
--- a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
+++ b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyService.java
@@ -90,7 +90,7 @@
         }
 
         @Override
-        public void registerNetworkCallback(INetworkCallback cb) {
+        public void registerNetworkCallback(final NetworkRequest request, INetworkCallback cb) {
             if (mNetworkCallback != null) {
                 Log.d(TAG, "unregister previous network callback: " + mNetworkCallback);
                 unregisterNetworkCallback();
@@ -138,7 +138,7 @@
                     }
                 }
             };
-            mCm.registerNetworkCallback(makeNetworkRequest(), mNetworkCallback);
+            mCm.registerNetworkCallback(request, mNetworkCallback);
             try {
                 cb.asBinder().linkToDeath(() -> unregisterNetworkCallback(), 0);
             } catch (RemoteException e) {
@@ -156,12 +156,6 @@
         }
       };
 
-    private NetworkRequest makeNetworkRequest() {
-        return new NetworkRequest.Builder()
-                .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)
-                .build();
-    }
-
     @Override
     public IBinder onBind(Intent intent) {
         return mBinder;
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index bfab497..18f0588 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -111,6 +111,7 @@
 import android.os.Handler;
 import android.os.Looper;
 import android.os.MessageQueue;
+import android.os.Process;
 import android.os.SystemClock;
 import android.os.SystemProperties;
 import android.os.VintfRuntimeInfo;
@@ -587,12 +588,14 @@
         final TestNetworkCallback defaultTrackingCallback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(defaultTrackingCallback);
 
-        final TestNetworkCallback systemDefaultTrackingCallback = new TestNetworkCallback();
+        final TestNetworkCallback systemDefaultCallback = new TestNetworkCallback();
+        final TestNetworkCallback perUidCallback = new TestNetworkCallback();
+        final Handler h = new Handler(Looper.getMainLooper());
         if (shouldTestSApis()) {
-            runWithShellPermissionIdentity(() ->
-                    mCmShim.registerSystemDefaultNetworkCallback(systemDefaultTrackingCallback,
-                            new Handler(Looper.getMainLooper())),
-                    NETWORK_SETTINGS);
+            runWithShellPermissionIdentity(() -> {
+                mCmShim.registerSystemDefaultNetworkCallback(systemDefaultCallback, h);
+                mCmShim.registerDefaultNetworkCallbackAsUid(Process.myUid(), perUidCallback, h);
+            }, NETWORK_SETTINGS);
         }
 
         Network wifiNetwork = null;
@@ -607,22 +610,27 @@
             assertNotNull("Did not receive onAvailable for TRANSPORT_WIFI request",
                     wifiNetwork);
 
+            final Network defaultNetwork = defaultTrackingCallback.waitForAvailable();
             assertNotNull("Did not receive onAvailable on default network callback",
-                    defaultTrackingCallback.waitForAvailable());
+                    defaultNetwork);
 
             if (shouldTestSApis()) {
                 assertNotNull("Did not receive onAvailable on system default network callback",
-                        systemDefaultTrackingCallback.waitForAvailable());
+                        systemDefaultCallback.waitForAvailable());
+                final Network perUidNetwork = perUidCallback.waitForAvailable();
+                assertNotNull("Did not receive onAvailable on per-UID default network callback",
+                        perUidNetwork);
+                assertEquals(defaultNetwork, perUidNetwork);
             }
+
         } catch (InterruptedException e) {
             fail("Broadcast receiver or NetworkCallback wait was interrupted.");
         } finally {
             mCm.unregisterNetworkCallback(callback);
             mCm.unregisterNetworkCallback(defaultTrackingCallback);
             if (shouldTestSApis()) {
-                runWithShellPermissionIdentity(
-                        () -> mCm.unregisterNetworkCallback(systemDefaultTrackingCallback),
-                        NETWORK_SETTINGS);
+                mCm.unregisterNetworkCallback(systemDefaultCallback);
+                mCm.unregisterNetworkCallback(perUidCallback);
             }
         }
     }
@@ -1636,6 +1644,62 @@
     }
 
     /**
+     * Verifies that apps are forbidden from getting ssid information from
+     * {@code NetworkCapabilities} if they do not hold NETWORK_SETTINGS permission.
+     * See b/161370134.
+     */
+    @AppModeFull(reason = "Cannot get WifiManager in instant app mode")
+    @Test
+    public void testSsidInNetworkCapabilities() throws Exception {
+        assumeTrue("testSsidInNetworkCapabilities cannot execute unless device supports WiFi",
+                mPackageManager.hasSystemFeature(FEATURE_WIFI));
+
+        final Network network = mCtsNetUtils.ensureWifiConnected();
+        final String ssid = unquoteSSID(mWifiManager.getConnectionInfo().getSSID());
+        assertNotNull("SSID obtained from WifiManager is null", ssid);
+        // This package should have no NETWORK_SETTINGS permission. Verify that no ssid is contained
+        // in the NetworkCapabilities.
+        verifySsidFromQueriedNetworkCapabilities(network, ssid, false /* hasSsid */);
+        verifySsidFromCallbackNetworkCapabilities(ssid, false /* hasSsid */);
+        // Adopt the shell permission identity so that ssid information becomes visible.
+        runWithShellPermissionIdentity(() -> {
+            verifySsidFromQueriedNetworkCapabilities(network, ssid, true /* hasSsid */);
+            verifySsidFromCallbackNetworkCapabilities(ssid, true /* hasSsid */);
+        });
+    }
+
+    private void verifySsidFromQueriedNetworkCapabilities(@NonNull Network network,
+            @NonNull String ssid, boolean hasSsid) throws Exception {
+        // Verify if ssid is contained (literally) in NetworkCapabilities from ConnectivityManager.
+        final NetworkCapabilities nc = mCm.getNetworkCapabilities(network);
+        assertNotNull("NetworkCapabilities of the network is null", nc);
+        assertEquals(hasSsid, nc.toString().contains(ssid));
+    }
+
+    private void verifySsidFromCallbackNetworkCapabilities(@NonNull String ssid, boolean hasSsid)
+            throws Exception {
+        final CompletableFuture<NetworkCapabilities> foundNc = new CompletableFuture<>();
+        final NetworkCallback callback = new NetworkCallback() {
+            @Override
+            public void onCapabilitiesChanged(Network network, NetworkCapabilities nc) {
+                foundNc.complete(nc);
+            }
+        };
+        try {
+            mCm.registerNetworkCallback(makeWifiNetworkRequest(), callback);
+            // Registering a callback here guarantees onCapabilitiesChanged is called immediately
+            // because WiFi network should be connected.
+            final NetworkCapabilities nc =
+                    foundNc.get(NETWORK_CALLBACK_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+            // Verify if ssid is (literally) contained in the NetworkCapabilities from the callback.
+            assertNotNull("NetworkCapabilities of the network is null", nc);
+            assertEquals(hasSsid, nc.toString().contains(ssid));
+        } finally {
+            mCm.unregisterNetworkCallback(callback);
+        }
+    }
+
+    /**
      * Verify background request can only be requested when acquiring
      * {@link android.Manifest.permission.NETWORK_SETTINGS}.
      */
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index f17e50c..f53a2a8 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -18,6 +18,8 @@
 import android.app.Instrumentation
 import android.content.Context
 import android.net.ConnectivityManager
+import android.net.INetworkAgent
+import android.net.INetworkAgentRegistry
 import android.net.InetAddresses
 import android.net.IpPrefix
 import android.net.KeepalivePacketData
@@ -44,6 +46,7 @@
 import android.net.NetworkInfo
 import android.net.NetworkProvider
 import android.net.NetworkRequest
+import android.net.NetworkScore
 import android.net.RouteInfo
 import android.net.SocketKeepalive
 import android.net.Uri
@@ -52,6 +55,8 @@
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAddKeepalivePacketFilter
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnAutomaticReconnectDisabled
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnBandwidthUpdateRequested
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkCreated
+import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkDestroyed
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnNetworkUnwanted
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnRemoveKeepalivePacketFilter
 import android.net.cts.NetworkAgentTest.TestableNetworkAgent.CallbackEntry.OnSaveAcceptUnvalidated
@@ -65,8 +70,6 @@
 import android.os.Message
 import android.util.DebugUtils.valueToString
 import androidx.test.InstrumentationRegistry
-import com.android.connectivity.aidl.INetworkAgent
-import com.android.connectivity.aidl.INetworkAgentRegistry
 import com.android.modules.utils.build.SdkLevel
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.testutils.CompatUtil
@@ -215,6 +218,8 @@
             object OnAutomaticReconnectDisabled : CallbackEntry()
             data class OnValidationStatus(val status: Int, val uri: Uri?) : CallbackEntry()
             data class OnSignalStrengthThresholdsUpdated(val thresholds: IntArray) : CallbackEntry()
+            object OnNetworkCreated : CallbackEntry()
+            object OnNetworkDestroyed : CallbackEntry()
         }
 
         override fun onBandwidthUpdateRequested() {
@@ -268,6 +273,14 @@
             history.add(OnValidationStatus(status, uri))
         }
 
+        override fun onNetworkCreated() {
+            history.add(OnNetworkCreated)
+        }
+
+        override fun onNetworkDestroyed() {
+            history.add(OnNetworkDestroyed)
+        }
+
         // Expects the initial validation event that always occurs immediately after registering
         // a NetworkAgent whose network does not require validation (which test networks do
         // not, since they lack the INTERNET capability). It always contains the default argument
@@ -346,8 +359,10 @@
         val callback = TestableNetworkCallback(timeoutMs = DEFAULT_TIMEOUT_MS)
         requestNetwork(request, callback)
         val agent = createNetworkAgent(context, name)
+        agent.setTeardownDelayMs(0)
         agent.register()
         agent.markConnected()
+        agent.expectCallback<OnNetworkCreated>()
         return agent to callback
     }
 
@@ -367,6 +382,7 @@
         assertFailsWith<IllegalStateException>("Must not be able to register an agent twice") {
             agent.register()
         }
+        agent.expectCallback<OnNetworkDestroyed>()
     }
 
     @Test
@@ -546,6 +562,7 @@
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.R)
     fun testSetUnderlyingNetworksAndVpnSpecifier() {
+        val mySessionId = "MySession12345"
         val request = NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_TEST)
                 .addTransportType(TRANSPORT_VPN)
@@ -559,7 +576,7 @@
             addTransportType(TRANSPORT_TEST)
             addTransportType(TRANSPORT_VPN)
             removeCapability(NET_CAPABILITY_NOT_VPN)
-            setTransportInfo(VpnTransportInfo(VpnManager.TYPE_VPN_SERVICE))
+            setTransportInfo(VpnTransportInfo(VpnManager.TYPE_VPN_SERVICE, mySessionId))
             if (SdkLevel.isAtLeastS()) {
                 addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
             }
@@ -579,6 +596,8 @@
         assertNotNull(vpnNc)
         assertEquals(VpnManager.TYPE_VPN_SERVICE,
                 (vpnNc.transportInfo as VpnTransportInfo).type)
+        // TODO: b/183938194 please fix the issue and enable following check.
+        // assertEquals(mySessionId, (vpnNc.transportInfo as VpnTransportInfo).sessionId)
 
         val testAndVpn = intArrayOf(TRANSPORT_TEST, TRANSPORT_VPN)
         assertTrue(hasAllTransports(vpnNc, testAndVpn))
@@ -626,12 +645,13 @@
         val mockContext = mock(Context::class.java)
         val mockCm = mock(ConnectivityManager::class.java)
         doReturn(mockCm).`when`(mockContext).getSystemService(Context.CONNECTIVITY_SERVICE)
-        createConnectedNetworkAgent(mockContext)
+        val agent = createNetworkAgent(mockContext)
+        agent.register()
         verify(mockCm).registerNetworkAgent(any(),
                 argThat<NetworkInfo> { it.detailedState == NetworkInfo.DetailedState.CONNECTING },
                 any(LinkProperties::class.java),
                 any(NetworkCapabilities::class.java),
-                any() /* score */,
+                any(NetworkScore::class.java),
                 any(NetworkAgentConfig::class.java),
                 eq(NetworkProvider.ID_NONE))
     }
diff --git a/tests/cts/net/src/android/net/cts/NetworkRequestTest.java b/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
index 30c4e72..9906c30 100644
--- a/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkRequestTest.java
@@ -39,17 +39,20 @@
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
 import android.net.NetworkSpecifier;
-import android.net.UidRange;
 import android.net.wifi.WifiNetworkSpecifier;
 import android.os.Build;
 import android.os.PatternMatcher;
 import android.os.Process;
 import android.util.ArraySet;
+import android.util.Range;
 
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.modules.utils.build.SdkLevel;
 import com.android.networkstack.apishim.ConstantsShim;
+import com.android.networkstack.apishim.NetworkRequestShimImpl;
+import com.android.networkstack.apishim.common.NetworkRequestShim;
+import com.android.networkstack.apishim.common.UnsupportedApiLevelException;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 
@@ -57,6 +60,8 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.util.Set;
+
 @RunWith(AndroidJUnit4.class)
 public class NetworkRequestTest {
     @Rule
@@ -225,6 +230,14 @@
         assertTrue(requestCellularInternet.canBeSatisfiedBy(capCellularVpnMmsInternet));
     }
 
+    private void setUids(NetworkRequest.Builder builder, Set<Range<Integer>> ranges)
+            throws UnsupportedApiLevelException {
+        if (SdkLevel.isAtLeastS()) {
+            final NetworkRequestShim networkRequestShim = NetworkRequestShimImpl.newInstance();
+            networkRequestShim.setUids(builder, ranges);
+        }
+    }
+
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.Q)
     public void testInvariantInCanBeSatisfiedBy() {
@@ -232,15 +245,26 @@
         // NetworkCapabilities.satisfiedByNetworkCapabilities().
         final LocalNetworkSpecifier specifier1 = new LocalNetworkSpecifier(1234 /* id */);
         final int uid = Process.myUid();
-        final ArraySet<UidRange> ranges = new ArraySet<>();
-        ranges.add(new UidRange(uid, uid));
-        final NetworkRequest requestCombination = new NetworkRequest.Builder()
+        final NetworkRequest.Builder nrBuilder = new NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .addCapability(NET_CAPABILITY_INTERNET)
                 .setLinkUpstreamBandwidthKbps(1000)
                 .setNetworkSpecifier(specifier1)
-                .setSignalStrength(-123)
-                .setUids(ranges).build();
+                .setSignalStrength(-123);
+
+        // The uid ranges should be set into the request, but setUids() takes a set of UidRange
+        // that is hidden and inaccessible from shims. Before S, setUids will be a no-op. But
+        // because NetworkRequest.Builder sets the UID of the request to the current UID, the
+        // request contains the current UID both on S and before S.
+        final Set<Range<Integer>> ranges = new ArraySet<>();
+        ranges.add(new Range<Integer>(uid, uid));
+        try {
+            setUids(nrBuilder, ranges);
+        } catch (UnsupportedApiLevelException e) {
+            // Not supported before API31.
+        }
+        final NetworkRequest requestCombination = nrBuilder.build();
+
         final NetworkCapabilities capCell = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR).build();
         assertCorrectlySatisfies(false, requestCombination, capCell);
diff --git a/tests/cts/net/src/android/net/cts/VpnServiceTest.java b/tests/cts/net/src/android/net/cts/VpnServiceTest.java
index 15af23c..5c7b5ca 100644
--- a/tests/cts/net/src/android/net/cts/VpnServiceTest.java
+++ b/tests/cts/net/src/android/net/cts/VpnServiceTest.java
@@ -47,6 +47,7 @@
         assertEquals(1, count);
     }
 
+    @AppModeFull(reason = "establish() requires prepare(), which requires PackageManager access")
     public void testEstablish() throws Exception {
         ParcelFileDescriptor descriptor = null;
         try {
@@ -62,7 +63,7 @@
         }
     }
 
-    @AppModeFull(reason = "Socket cannot bind in instant app mode")
+    @AppModeFull(reason = "Protecting sockets requires prepare(), which requires PackageManager")
     public void testProtect_DatagramSocket() throws Exception {
         DatagramSocket socket = new DatagramSocket();
         try {
@@ -77,6 +78,7 @@
         }
     }
 
+    @AppModeFull(reason = "Protecting sockets requires prepare(), which requires PackageManager")
     public void testProtect_Socket() throws Exception {
         Socket socket = new Socket();
         try {
@@ -91,7 +93,7 @@
         }
     }
 
-    @AppModeFull(reason = "Socket cannot bind in instant app mode")
+    @AppModeFull(reason = "Protecting sockets requires prepare(), which requires PackageManager")
     public void testProtect_int() throws Exception {
         DatagramSocket socket = new DatagramSocket();
         ParcelFileDescriptor descriptor = ParcelFileDescriptor.fromDatagramSocket(socket);