Add tests for cgroup v2 bpf and new helper functions

Add several bpf program test to verify the new get_socket_cookie and
get_socket_uid helper function cherry-picked to 4.9 kernel. Also, added
a new test class to test the cgroup v2 bpf functionality backported from
upstream. Refactored the bpf assembly program into reusable code blocks.
Added some helper function to avoid duplication of code. Added IPv6 support
for all exisiting tests.

Test: All test added in this patch should pass in all 4.9 kernel

Signed-off-by: Chenbo Feng <fengc@google.com>
Bug: 30950746
Change-Id: I64ead3d1a04985e7499a1fe1cdee58ef580e9313
diff --git a/net/test/bpf.py b/net/test/bpf.py
index 14ee613..9e8cf1e 100755
--- a/net/test/bpf.py
+++ b/net/test/bpf.py
@@ -36,6 +36,8 @@
 BPF_PROG_LOAD = 5
 BPF_OBJ_PIN = 6
 BPF_OBJ_GET = 7
+BPF_PROG_ATTACH = 8
+BPF_PROG_DETACH = 9
 SO_ATTACH_BPF = 50
 
 # BPF map type constant.
@@ -51,6 +53,16 @@
 BPF_PROG_TYPE_KPROBE = 2
 BPF_PROG_TYPE_SCHED_CLS = 3
 BPF_PROG_TYPE_SCHED_ACT = 4
+BPF_PROG_TYPE_TRACEPOINT = 5
+BPF_PROG_TYPE_XDP = 6
+BPF_PROG_TYPE_PERF_EVENT = 7
+BPF_PROG_TYPE_CGROUP_SKB = 8
+BPF_PROG_TYPE_CGROUP_SOCK = 9
+
+# BPF program attach type.
+BPF_CGROUP_INET_INGRESS = 0
+BPF_CGROUP_INET_EGRESS = 1
+BPF_CGROUP_INET_SOCK_CREATE = 2
 
 # BPF register constant
 BPF_REG_0 = 0
@@ -124,8 +136,12 @@
 BPF_FUNC_map_lookup_elem = 1
 BPF_FUNC_map_update_elem = 2
 BPF_FUNC_map_delete_elem = 3
+BPF_FUNC_get_socket_cookie = 46
+BPF_FUNC_get_socket_uid = 47
 
-# BPF attr struct
+#  These object below belongs to the same kernel union and the types below
+#  (e.g., bpf_attr_create) aren't kernel struct names but just different
+#  variants of the union.
 BpfAttrCreate = cstruct.Struct("bpf_attr_create", "=IIII",
                                "map_type key_size value_size max_entries")
 BpfAttrOps = cstruct.Struct("bpf_attr_ops", "=QQQQ",
@@ -133,6 +149,8 @@
 BpfAttrProgLoad = cstruct.Struct(
     "bpf_attr_prog_load", "=IIQQIIQI", "prog_type insn_cnt insns"
     " license log_level log_size log_buf kern_version")
+BpfAttrProgAttach = cstruct.Struct(
+    "bpf_attr_prog_attach", "=III", "target_fd attach_bpf_fd attach_type")
 BpfInsn = cstruct.Struct("bpf_insn", "=BBhi", "code dst_src_reg off imm")
 
 libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
@@ -140,12 +158,15 @@
 
 
 # BPF program syscalls
-def CreateMap(map_type, key_size, value_size, max_entries):
-  attr = BpfAttrCreate((map_type, key_size, value_size, max_entries))
-  ret = libc.syscall(__NR_bpf, BPF_MAP_CREATE, attr.CPointer(), len(attr))
+def BpfSyscall(op, attr):
+  ret = libc.syscall(__NR_bpf, op, attr.CPointer(), len(attr))
   csocket.MaybeRaiseSocketError(ret)
   return ret
 
+def CreateMap(map_type, key_size, value_size, max_entries):
+  attr = BpfAttrCreate((map_type, key_size, value_size, max_entries))
+  return BpfSyscall(BPF_MAP_CREATE, attr)
+
 
 def UpdateMap(map_fd, key, value, flags=0):
   c_value = ctypes.c_uint32(value)
@@ -153,9 +174,7 @@
   value_ptr = ctypes.addressof(c_value)
   key_ptr = ctypes.addressof(c_key)
   attr = BpfAttrOps((map_fd, key_ptr, value_ptr, flags))
-  ret = libc.syscall(__NR_bpf, BPF_MAP_UPDATE_ELEM,
-                     attr.CPointer(), len(attr))
-  csocket.MaybeRaiseSocketError(ret)
+  BpfSyscall(BPF_MAP_UPDATE_ELEM, attr)
 
 
 def LookupMap(map_fd, key):
@@ -163,9 +182,7 @@
   c_key = ctypes.c_uint32(key)
   attr = BpfAttrOps(
       (map_fd, ctypes.addressof(c_key), ctypes.addressof(c_value), 0))
-  ret = libc.syscall(__NR_bpf, BPF_MAP_LOOKUP_ELEM,
-                     attr.CPointer(), len(attr))
-  csocket.MaybeRaiseSocketError(ret)
+  BpfSyscall(BPF_MAP_LOOKUP_ELEM, attr)
   return c_value
 
 
@@ -174,37 +191,44 @@
   c_next_key = ctypes.c_uint32(0)
   attr = BpfAttrOps(
       (map_fd, ctypes.addressof(c_key), ctypes.addressof(c_next_key), 0))
-  ret = libc.syscall(__NR_bpf, BPF_MAP_GET_NEXT_KEY,
-                     attr.CPointer(), len(attr))
-  csocket.MaybeRaiseSocketError(ret)
+  BpfSyscall(BPF_MAP_GET_NEXT_KEY, attr)
   return c_next_key
 
 
 def DeleteMap(map_fd, key):
   c_key = ctypes.c_uint32(key)
   attr = BpfAttrOps((map_fd, ctypes.addressof(c_key), 0, 0))
-  ret = libc.syscall(__NR_bpf, BPF_MAP_DELETE_ELEM,
-                     attr.CPointer(), len(attr))
-  csocket.MaybeRaiseSocketError(ret)
+  BpfSyscall(BPF_MAP_DELETE_ELEM, attr)
 
 
-def BpfProgLoad(prog_type, insn_ptr, prog_len, insn_len):
+def BpfProgLoad(prog_type, instructions):
+  bpf_prog = "".join(instructions)
+  insn_buff = ctypes.create_string_buffer(bpf_prog)
   gpl_license = ctypes.create_string_buffer(b"GPL")
   log_buf = ctypes.create_string_buffer(b"", LOG_SIZE)
-  attr = BpfAttrProgLoad(
-      (prog_type, prog_len / insn_len, insn_ptr, ctypes.addressof(gpl_license),
-       LOG_LEVEL, LOG_SIZE, ctypes.addressof(log_buf), 0))
-  ret = libc.syscall(__NR_bpf, BPF_PROG_LOAD, attr.CPointer(), len(attr))
-  csocket.MaybeRaiseSocketError(ret)
-  return ret
+  attr = BpfAttrProgLoad((prog_type, len(insn_buff) / len(BpfInsn),
+                          ctypes.addressof(insn_buff),
+                          ctypes.addressof(gpl_license), LOG_LEVEL,
+                          LOG_SIZE, ctypes.addressof(log_buf), 0))
+  return BpfSyscall(BPF_PROG_LOAD, attr)
 
-
+# Attach a socket eBPF filter to a target socket
 def BpfProgAttachSocket(sock_fd, prog_fd):
   prog_ptr = ctypes.c_uint32(prog_fd)
   ret = libc.setsockopt(sock_fd, socket.SOL_SOCKET, SO_ATTACH_BPF,
                         ctypes.addressof(prog_ptr), ctypes.sizeof(prog_ptr))
   csocket.MaybeRaiseSocketError(ret)
 
+# Attach a eBPF filter to a cgroup
+def BpfProgAttach(prog_fd, target_fd, prog_type):
+  attr = BpfAttrProgAttach((target_fd, prog_fd, prog_type))
+  return BpfSyscall(BPF_PROG_ATTACH, attr)
+
+# Detach a eBPF filter from a cgroup
+def BpfProgDetach(target_fd, prog_type):
+  attr = BpfAttrProgAttach((target_fd, 0, prog_type))
+  return BpfSyscall(BPF_PROG_DETACH, attr)
+
 
 # BPF program command constructors
 def BpfMov64Reg(dst, src):
@@ -275,22 +299,8 @@
   return insn1.Pack() + insn2.Pack()
 
 
-def BpfFuncLookupMap():
+def BpfFuncCall(func):
   code = BPF_JMP | BPF_CALL
   dst_src = 0
-  ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_lookup_elem))
-  return ret.Pack()
-
-
-def BpfFuncUpdateMap():
-  code = BPF_JMP | BPF_CALL
-  dst_src = 0
-  ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_update_elem))
-  return ret.Pack()
-
-
-def BpfFuncDeleteMap():
-  code = BPF_JMP | BPF_CALL
-  dst_src = 0
-  ret = BpfInsn((code, dst_src, 0, BPF_FUNC_map_delete_elem))
+  ret = BpfInsn((code, dst_src, 0, func))
   return ret.Pack()
diff --git a/net/test/bpf_test.py b/net/test/bpf_test.py
index 6d17423..33ef22d 100755
--- a/net/test/bpf_test.py
+++ b/net/test/bpf_test.py
@@ -18,14 +18,131 @@
 import errno
 import os
 import socket
+import struct
 import unittest
 
 from bpf import *  # pylint: disable=wildcard-import
 import csocket
 import net_test
+import sock_diag
 
 libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
 HAVE_EBPF_SUPPORT = net_test.LINUX_VERSION >= (4, 4, 0)
+HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0)
+KEY_SIZE = 8
+VALUE_SIZE = 4
+TOTAL_ENTRIES = 20
+# Offset to store the map key in stack register REG10
+key_offset = -8
+# Offset to store the map value in stack register REG10
+value_offset = -16
+
+# Debug usage only.
+def PrintMapInfo(map_fd):
+  # A random key that the map does not contain.
+  key = 10086
+  while 1:
+    try:
+      nextKey = GetNextKey(map_fd, key).value
+      value = LookupMap(map_fd, nextKey)
+      print repr(nextKey) + " : " + repr(value.value)
+      key = nextKey
+    except:
+      print "no value"
+      break
+
+
+# A dummy loopback function that causes a socket to send traffic to itself.
+def SocketUDPLoopBack(packet_count, version, prog_fd):
+  family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
+  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
+  if prog_fd is not None:
+    BpfProgAttachSocket(sock.fileno(), prog_fd)
+  net_test.SetNonBlocking(sock)
+  addr = {4: "127.0.0.1", 6: "::1"}[version]
+  sock.bind((addr, 0))
+  addr = sock.getsockname()
+  sockaddr = csocket.Sockaddr(addr)
+  for i in xrange(packet_count):
+    sock.sendto("foo", addr)
+    data, retaddr = csocket.Recvfrom(sock, 4096, 0)
+    assert "foo" == data
+    assert sockaddr == retaddr
+  return sock
+
+
+# The main code block for eBPF packet counting program. It takes a preloaded
+# key from BPF_REG_0 and use it to look up the bpf map, if the element does not
+# exist in the map yet, the program will update the map with a new <key, 1>
+# pair. Otherwise it will jump to next code block to handle it.
+# REG0: regiter storing return value from helper function and the final return
+# value of eBPF program.
+# REG1 - REG5: temporary register used for storing values and load parameters
+# into eBPF helper function. After calling helper function, the value for these
+# registers will be reset.
+# REG6 - REG9: registers store values that will not be cleared when calling
+# eBPF helper function.
+# REG10: A stack stores values need to be accessed by the address. Program can
+# retrieve the address of a value by specifying the position of the value in
+# the stack.
+def BpfFuncCountPacketInit(map_fd):
+  key_pos = BPF_REG_7
+  insPackCountStart = [
+      # Get a preloaded key from BPF_REG_0 and store it at BPF_REG_7
+      BpfMov64Reg(key_pos, BPF_REG_10),
+      BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
+      # Load map fd and look up the key in the map
+      BpfLoadMapFd(map_fd, BPF_REG_1),
+      BpfMov64Reg(BPF_REG_2, key_pos),
+      BpfFuncCall(BPF_FUNC_map_lookup_elem),
+      # if the map element already exist, jump out of this
+      # code block and let next part to handle it
+      BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10),
+      BpfLoadMapFd(map_fd, BPF_REG_1),
+      BpfMov64Reg(BPF_REG_2, key_pos),
+      # Initial a new <key, value> pair with value equal to 1 and update to map
+      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
+      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
+      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
+      BpfMov64Imm(BPF_REG_4, 0),
+      BpfFuncCall(BPF_FUNC_map_update_elem)
+  ]
+  return insPackCountStart
+
+
+INS_BPF_EXIT_BLOCK = [
+    BpfMov64Imm(BPF_REG_0, 0),
+    BpfExitInsn()
+]
+
+# Bpf instruction for cgroup bpf filter to accept a packet and exit.
+INS_CGROUP_ACCEPT = [
+    # Set return value to 1 and exit.
+    BpfMov64Imm(BPF_REG_0, 1),
+    BpfExitInsn()
+]
+
+# Bpf instruction for socket bpf filter to accept a packet and exit.
+INS_SK_FILTER_ACCEPT = [
+    # Precondition: BPF_REG_6 = sk_buff context
+    # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the
+    # return value.
+    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
+    BpfExitInsn()
+]
+
+# Update a existing map element with +1.
+INS_PACK_COUNT_UPDATE = [
+    # Precondition: BPF_REG_0 = Value retrieved from BPF maps
+    # Add one to the corresponding eBPF value field for a specific eBPF key.
+    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
+    BpfMov64Imm(BPF_REG_1, 1),
+    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
+]
+
+INS_BPF_PARAM_STORE = [
+    BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
+]
 
 @unittest.skipUnless(HAVE_EBPF_SUPPORT,
                      "eBPF function not fully supported")
@@ -34,130 +151,207 @@
   def setUp(self):
     self.map_fd = -1
     self.prog_fd = -1
+    self.sock = None
 
   def tearDown(self):
     if self.prog_fd >= 0:
       os.close(self.prog_fd)
     if self.map_fd >= 0:
       os.close(self.map_fd)
+    if self.sock:
+      self.sock.close()
 
   def testCreateMap(self):
     key, value = 1, 1
-    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
     UpdateMap(self.map_fd, key, value)
-    self.assertEquals(LookupMap(self.map_fd, key).value, value)
+    self.assertEquals(value, LookupMap(self.map_fd, key).value)
     DeleteMap(self.map_fd, key)
     self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)
 
   def testIterateMap(self):
-    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 100)
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
     value = 1024
-    for key in xrange(1, 100):
+    for key in xrange(1, TOTAL_ENTRIES):
       UpdateMap(self.map_fd, key, value)
-    for key in xrange(1, 100):
-      self.assertEquals(LookupMap(self.map_fd, key).value, value)
+    for key in xrange(1, TOTAL_ENTRIES):
+      self.assertEquals(value, LookupMap(self.map_fd, key).value)
     self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
     key = 0
     count = 0
     while 1:
-      if count == 99:
+      if count == TOTAL_ENTRIES - 1:
         self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)
         break
       else:
         result = GetNextKey(self.map_fd, key)
         key = result.value
         self.assertGreater(key, 0)
-        self.assertEquals(LookupMap(self.map_fd, key).value, value)
+        self.assertEquals(value, LookupMap(self.map_fd, key).value)
         count += 1
 
   def testProgLoad(self):
-    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
-    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
-    bpf_prog += BpfExitInsn()
-    insn_buff = ctypes.create_string_buffer(bpf_prog)
-    # Load a program that does nothing except pass every packet it receives
-    # It should not block the packet transmission otherwise the test fails.
-    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
-                          ctypes.addressof(insn_buff),
-                          len(insn_buff), BpfInsn._length)
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-    sock.settimeout(1)
-    BpfProgAttachSocket(sock.fileno(), self.prog_fd)
-    addr = "127.0.0.1"
-    sock.bind((addr, 0))
-    addr = sock.getsockname()
-    sockaddr = csocket.Sockaddr(addr)
-    sock.sendto("foo", addr)
-    data, addr = csocket.Recvfrom(sock, 4096, 0)
-    self.assertEqual("foo", data)
-    self.assertEqual(sockaddr, addr)
+    # Move skb to BPF_REG_6 for further usage
+    instructions = [
+        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
+    ]
+    instructions += INS_SK_FILTER_ACCEPT
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+    SocketUDPLoopBack(1, 4, self.prog_fd)
+    SocketUDPLoopBack(1, 6, self.prog_fd)
 
   def testPacketBlock(self):
-    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
-    bpf_prog += BpfMov64Imm(BPF_REG_0, 0)
-    bpf_prog += BpfExitInsn()
-    insn_buff = ctypes.create_string_buffer(bpf_prog)
-    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
-                          ctypes.addressof(insn_buff),
-                          len(insn_buff), BpfInsn._length)
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-    sock.settimeout(1)
-    BpfProgAttachSocket(sock.fileno(), self.prog_fd)
-    addr = "127.0.0.1"
-    sock.bind((addr, 0))
-    addr = sock.getsockname()
-    sock.sendto("foo", addr)
-    self.assertRaisesErrno(errno.EAGAIN, csocket.Recvfrom, sock, 4096, 0)
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
+    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
+    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)
 
   def testPacketCount(self):
-    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, 4, 4, 10)
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
     key = 0xf0f0
-    bpf_prog = BpfMov64Reg(BPF_REG_6, BPF_REG_1)
-    bpf_prog += BpfLoadMapFd(self.map_fd, BPF_REG_1)
-    bpf_prog += BpfMov64Imm(BPF_REG_7, key)
-    bpf_prog += BpfStxMem(BPF_W, BPF_REG_10, BPF_REG_7, -4)
-    bpf_prog += BpfMov64Reg(BPF_REG_8, BPF_REG_10)
-    bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_8, -4)
-    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
-    bpf_prog += BpfFuncLookupMap()
-    bpf_prog += BpfJumpImm(BPF_AND, BPF_REG_0, 0, 10)
-    bpf_prog += BpfLoadMapFd(self.map_fd, BPF_REG_1)
-    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_8)
-    bpf_prog += BpfStMem(BPF_W, BPF_REG_10, -8, 1)
-    bpf_prog += BpfMov64Reg(BPF_REG_3, BPF_REG_10)
-    bpf_prog += BpfAlu64Imm(BPF_ADD, BPF_REG_3, -8)
-    bpf_prog += BpfMov64Imm(BPF_REG_4, 0)
-    bpf_prog += BpfFuncUpdateMap()
-    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
-    bpf_prog += BpfExitInsn()
-    bpf_prog += BpfMov64Reg(BPF_REG_2, BPF_REG_0)
-    bpf_prog += BpfMov64Imm(BPF_REG_1, 1)
-    bpf_prog += BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1,
-                           0, 0)
-    bpf_prog += BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0)
-    bpf_prog += BpfExitInsn()
-    insn_buff = ctypes.create_string_buffer(bpf_prog)
-    # this program loaded is used to counting the packet transmitted through
-    # a target socket. It will store the packet count into the eBPF map and we
-    # will verify if the counting result is correct.
-    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER,
-                          ctypes.addressof(insn_buff),
-                          len(insn_buff), BpfInsn._length)
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
-    sock.settimeout(1)
-    BpfProgAttachSocket(sock.fileno(), self.prog_fd)
-    addr = "127.0.0.1"
-    sock.bind((addr, 0))
-    addr = sock.getsockname()
-    sockaddr = csocket.Sockaddr(addr)
-    packet_count = 100
-    for i in xrange(packet_count):
-      sock.sendto("foo", addr)
-      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
-      self.assertEqual("foo", data)
-      self.assertEqual(sockaddr, retaddr)
-    self.assertEquals(LookupMap(self.map_fd, key).value, packet_count)
+    # Set up instruction block with key loaded at BPF_REG_0.
+    instructions = [
+        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+        BpfMov64Imm(BPF_REG_0, key)
+    ]
+    # Concatenate the generic packet count bpf program to it.
+    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+                     + INS_SK_FILTER_ACCEPT)
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+    packet_count = 10
+    SocketUDPLoopBack(packet_count, 4, self.prog_fd)
+    SocketUDPLoopBack(packet_count, 6, self.prog_fd)
+    self.assertEquals(packet_count * 2, LookupMap(self.map_fd, key).value)
 
+  @unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+                       "BPF helper function is not fully supported")
+  def testGetSocketCookie(self):
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
+    # Move skb to REG6 for further usage, call helper function to get socket
+    # cookie of current skb and return the cookie at REG0 for next code block
+    instructions = [
+        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+        BpfFuncCall(BPF_FUNC_get_socket_cookie)
+    ]
+    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+                     + INS_SK_FILTER_ACCEPT)
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+    packet_count = 10
+    def PacketCountByCookie(version):
+      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
+      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
+      self.assertEquals(packet_count, LookupMap(self.map_fd, cookie).value)
+      self.sock.close()
+    PacketCountByCookie(4)
+    PacketCountByCookie(6)
+
+  @unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+                       "BPF helper function is not fully supported")
+  def testGetSocketUid(self):
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
+    # Set up the instruction with uid at BPF_REG_0.
+    instructions = [
+        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+        BpfFuncCall(BPF_FUNC_get_socket_uid)
+    ]
+    # Concatenate the generic packet count bpf program to it.
+    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
+                     + INS_SK_FILTER_ACCEPT)
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
+    packet_count = 10
+    uid = 12345
+    with net_test.RunAsUid(uid):
+      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
+      SocketUDPLoopBack(packet_count, 4, self.prog_fd)
+      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+      DeleteMap(self.map_fd, uid);
+      SocketUDPLoopBack(packet_count, 6, self.prog_fd)
+      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+
+@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
+                     "Cgroup BPF is not fully supported")
+class BpfCgroupTest(net_test.NetworkTest):
+
+  @classmethod
+  def setUpClass(cls):
+    if not os.path.isdir("/tmp"):
+      os.mkdir('/tmp')
+    os.system('mount -t cgroup2 cg_bpf /tmp')
+    cls._cg_fd = os.open('/tmp', os.O_DIRECTORY | os.O_RDONLY)
+
+  @classmethod
+  def tearDownClass(cls):
+    os.close(cls._cg_fd)
+    os.system('umount cg_bpf')
+
+  def setUp(self):
+    self.prog_fd = -1
+    self.map_fd = -1
+
+  def tearDown(self):
+    if self.prog_fd >= 0:
+      os.close(self.prog_fd)
+    if self.map_fd >= 0:
+      os.close(self.map_fd)
+    try:
+      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
+    except socket.error:
+      pass
+    try:
+      BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+    except socket.error:
+      pass
+
+  def testCgroupBpfAttach(self):
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+
+  def testCgroupIngress(self):
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
+    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
+    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
+    SocketUDPLoopBack(1, 4, None)
+    SocketUDPLoopBack(1, 6, None)
+
+  def testCgroupEgress(self):
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
+    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
+    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
+    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
+    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
+    SocketUDPLoopBack( 1, 4, None)
+    SocketUDPLoopBack( 1, 6, None)
+
+  def testCgroupBpfUid(self):
+    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
+                            TOTAL_ENTRIES)
+    # Similar to the program used in testGetSocketUid.
+    instructions = [
+        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
+        BpfFuncCall(BPF_FUNC_get_socket_uid)
+    ]
+    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
+                     + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE + INS_CGROUP_ACCEPT)
+    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
+    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
+    uid = os.getuid()
+    packet_count = 20
+    SocketUDPLoopBack(packet_count, 4, None)
+    self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+    DeleteMap(self.map_fd, uid);
+    SocketUDPLoopBack(packet_count, 6, None)
+    self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
+    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/net/test/run_net_test.sh b/net/test/run_net_test.sh
index bfba4db..268c6d7 100755
--- a/net/test/run_net_test.sh
+++ b/net/test/run_net_test.sh
@@ -28,6 +28,7 @@
 OPTIONS="$OPTIONS INET6_XFRM_MODE_TRANSPORT INET6_XFRM_MODE_TUNNEL"
 OPTIONS="$OPTIONS CRYPTO_SHA256 CRYPTO_SHA512 CRYPTO_AES_X86_64"
 OPTIONS="$OPTIONS CRYPTO_ECHAINIV"
+OPTIONS="$OPTIONS SOCK_CGROUP_DATA CGROUP_BPF"
 
 # For 3.1 kernels, where devtmpfs is not on by default.
 OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
index 1865891..c6278e1 100755
--- a/net/test/sock_diag.py
+++ b/net/test/sock_diag.py
@@ -375,6 +375,11 @@
     sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
 
+  @staticmethod
+  def GetSocketCookie(s):
+    cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
+    return struct.unpack("=Q", cookie)[0]
+
   def FindSockInfoFromFd(self, s):
     """Gets a diag_msg and attrs from the kernel for the specified socket."""
     req = self.DiagReqFromSocket(s)