Merge "do not drop ingress dns replies with non bypassable vpn"
diff --git a/bpf_progs/netd.c b/bpf_progs/netd.c
index 70e0ae7..f347028 100644
--- a/bpf_progs/netd.c
+++ b/bpf_progs/netd.c
@@ -30,6 +30,7 @@
 #include "netdbpf/bpf_shared.h"
 
 // This is defined for cgroup bpf filter only.
+#define BPF_DROP_UNLESS_DNS 2
 #define BPF_PASS 1
 #define BPF_DROP 0
 
@@ -206,7 +207,7 @@
     if (direction == BPF_INGRESS && (uidRules & IIF_MATCH)) {
         // Drops packets not coming from lo nor the whitelisted interface
         if (allowed_iif && skb->ifindex != 1 && skb->ifindex != allowed_iif) {
-            return BPF_DROP;
+            return BPF_DROP_UNLESS_DNS;
         }
     }
     return BPF_PASS;
@@ -247,6 +248,17 @@
         tag = 0;
     }
 
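+// Workaround for secureVPN with VpnIsolation enabled (b/159994981): when tag is TAG_SYSTEM_DNS and uid is
+// AID_DNS, uid is reset to sock_uid and BPF_DROP_UNLESS_DNS becomes BPF_PASS; otherwise it becomes BPF_DROP.
+// Keep TAG_SYSTEM_DNS in sync with DnsResolver/include/netd_resolv/resolv.h and TrafficStatsConstants.java.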
+#define TAG_SYSTEM_DNS 0xFFFFFF82
+    if (tag == TAG_SYSTEM_DNS && uid == AID_DNS) {
+        uid = sock_uid;
+        if (match == BPF_DROP_UNLESS_DNS) match = BPF_PASS;
+    } else {
+        if (match == BPF_DROP_UNLESS_DNS) match = BPF_DROP;
+    }
+
     StatsKey key = {.uid = uid, .tag = tag, .counterSet = 0, .ifaceIndex = skb->ifindex};
 
     uint8_t* counterSet = bpf_uid_counterset_map_lookup_elem(&uid);
diff --git a/server/main.cpp b/server/main.cpp
index 0a86b0a..4949ff6 100644
--- a/server/main.cpp
+++ b/server/main.cpp
@@ -83,6 +83,8 @@
 }
 
 int tagSocketCallback(int sockFd, uint32_t tag, uid_t uid, pid_t) {
+    // Workaround for secureVPN with VpnIsolation enabled (b/159994981): sockets tagged TAG_SYSTEM_DNS are tagged under AID_DNS instead of the caller-supplied uid.
+    if (tag == TAG_SYSTEM_DNS) uid = AID_DNS;
     return gCtls->trafficCtrl.tagSocket(sockFd, tag, uid, geteuid());
 }