Test to verify cgroup socket filter

Add a base test for cgroup socket filter to check the cgroup socket eBPF
program can actually block socket creation for INET socket. This feature
will be used in 4.14 and above kernel to replace paranoid network.

Test: ./bpf_test.py
Bug: 111560739
Change-Id: I10f5c0c6847ec033cf757b8ce9dfa1a6b80c50fb
diff --git a/net/test/bpf.py b/net/test/bpf.py
index aa50f3e..43502bd 100755
--- a/net/test/bpf.py
+++ b/net/test/bpf.py
@@ -152,6 +152,7 @@
 BPF_FUNC_map_lookup_elem = 1
 BPF_FUNC_map_update_elem = 2
 BPF_FUNC_map_delete_elem = 3
+BPF_FUNC_get_current_uid_gid = 15
 BPF_FUNC_get_socket_cookie = 46
 BPF_FUNC_get_socket_uid = 47
 
diff --git a/net/test/bpf_test.py b/net/test/bpf_test.py
index 270243e..823dcbe 100755
--- a/net/test/bpf_test.py
+++ b/net/test/bpf_test.py
@@ -30,10 +30,12 @@
 
 libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
 HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0)
+HAVE_EBPF_SOCKET = net_test.LINUX_VERSION >= (4, 14, 0)
 KEY_SIZE = 8
 VALUE_SIZE = 4
 TOTAL_ENTRIES = 20
 TEST_UID = 54321
+TEST_GID = 12345
 # Offset to store the map key in stack register REG10
 key_offset = -8
 # Offset to store the map value in stack register REG10
@@ -350,6 +352,10 @@
       BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
     except socket.error:
       pass
+    try:
+      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+    except socket.error:
+      pass
 
   def testCgroupBpfAttach(self):
     self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
@@ -392,10 +398,48 @@
       self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
       SocketUDPLoopBack(packet_count, 4, None)
       self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
-      DeleteMap(self.map_fd, uid);
+      DeleteMap(self.map_fd, uid)
       SocketUDPLoopBack(packet_count, 6, None)
       self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
     BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
 
+  def checkSocketCreate(self, family, socktype, success):
+    try:
+      sock = socket.socket(family, socktype, 0)
+    except socket.error, e:
+      if success:
+        self.fail("Failed to create socket family=%d type=%d err=%s" %
+                  (family, socktype, os.strerror(e.errno)))
+      return;
+    if not success:
+      self.fail("unexpected socket family=%d type=%d created, should be blocked" %
+                (family, socktype))
+
+
+  def trySocketCreate(self, success):
+      for family in [socket.AF_INET, socket.AF_INET6]:
+        for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
+          self.checkSocketCreate(family, socktype, success)
+
+  @unittest.skipUnless(HAVE_EBPF_SOCKET,
+                     "Cgroup BPF socket is not supported")
+  def testCgroupSocketCreateBlock(self):
+    instructions = [
+        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
+        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
+        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
+    ]
+    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT;
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
+    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+    with net_test.RunAsUid(TEST_UID):
+      # Socket creation with target uid should fail
+      self.trySocketCreate(False);
+    # Socket create with different uid should success
+    self.trySocketCreate(True)
+    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
+    with net_test.RunAsUid(TEST_UID):
+      self.trySocketCreate(True)
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/net/test/net_test.py b/net/test/net_test.py
index 035ba60..1c7f32f 100755
--- a/net/test/net_test.py
+++ b/net/test/net_test.py
@@ -369,17 +369,17 @@
 
   def __enter__(self):
     if self.uid:
-      self.saved_uid = os.geteuid()
+      self.saved_uids = os.getresuid()
       self.saved_groups = os.getgroups()
       os.setgroups(self.saved_groups + [AID_INET])
-      os.seteuid(self.uid)
+      os.setresuid(self.uid, self.uid, self.saved_uids[0])
     if self.gid:
       self.saved_gid = os.getgid()
       os.setgid(self.gid)
 
   def __exit__(self, unused_type, unused_value, unused_traceback):
     if self.uid:
-      os.seteuid(self.saved_uid)
+      os.setresuid(*self.saved_uids)
       os.setgroups(self.saved_groups)
     if self.gid:
       os.setgid(self.saved_gid)
@@ -390,7 +390,6 @@
   def __init__(self, uid):
     RunAsUidGid.__init__(self, uid, 0)
 
-
 class NetworkTest(unittest.TestCase):
 
   def assertRaisesErrno(self, err_num, f=None, *args):