Add code and tests for inet_diag bytecode.

Change-Id: I02af43151cf14905cc762455f282cb7fa5a1b003
diff --git a/net/test/netlink.py b/net/test/netlink.py
index 479d3a4..2b8f744 100644
--- a/net/test/netlink.py
+++ b/net/test/netlink.py
@@ -79,11 +79,12 @@
   def _GetConstantName(self, module, value, prefix):
     thismodule = sys.modules[module]
     for name in dir(thismodule):
+      if name.startswith("INET_DIAG_BC"):
+        continue  # Bytecode opcodes reuse small ints; never report them.
       if (name.startswith(prefix) and
           not name.startswith(prefix + "F_") and
-          name.isupper() and
-          getattr(thismodule, name) == value):
+          name.isupper() and getattr(thismodule, name) == value):
         return name
     return value
 
   def _Decode(self, command, msg, nla_type, nla_data):
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
index cfce751..5cb83cf 100755
--- a/net/test/sock_diag.py
+++ b/net/test/sock_diag.py
@@ -20,6 +20,7 @@
 
 import errno
 from socket import *  # pylint: disable=wildcard-import
+import struct
 
 import cstruct
 import net_test
@@ -37,6 +38,9 @@
 # Message types.
 TCPDIAG_GETSOCK = 18
 
+# Request attributes.
+INET_DIAG_REQ_BYTECODE = 1
+
 # Extensions.
 INET_DIAG_NONE = 0
 INET_DIAG_MEMINFO = 1
@@ -49,6 +53,17 @@
 INET_DIAG_SHUTDOWN = 8
 INET_DIAG_DCTCPINFO = 9
 
+# Bytecode operations.
+INET_DIAG_BC_NOP = 0
+INET_DIAG_BC_JMP = 1
+INET_DIAG_BC_S_GE = 2
+INET_DIAG_BC_S_LE = 3
+INET_DIAG_BC_D_GE = 4
+INET_DIAG_BC_D_LE = 5
+INET_DIAG_BC_AUTO = 6
+INET_DIAG_BC_S_COND = 7
+INET_DIAG_BC_D_COND = 8
+
 # Data structure formats.
 # These aren't constants, they're classes. So, pylint: disable=invalid-name
 InetDiagSockId = cstruct.Struct(
@@ -62,6 +77,9 @@
     [InetDiagSockId])
 InetDiagMeminfo = cstruct.Struct(
     "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
+InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
+InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
+                                  "family prefix_len port")
 
 SkMeminfo = cstruct.Struct(
     "SkMeminfo", "=IIIIIIII",
@@ -133,11 +151,94 @@
   def _EmptyInetDiagSockId():
     return InetDiagSockId(("\x00" * len(InetDiagSockId)))
 
-  def Dump(self, diag_req):
-    out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, "")
+  def PackBytecode(self, instructions):
+    """Compiles instructions to inet_diag bytecode.
+
+    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
+    and no are relative jump offsets measured in instructions. The yes branch
+    is taken if the instruction matches.
+
+    To accept, jump 1 past the last instruction. To reject, jump 2 past the
+    last instruction.
+
+    The target of a no jump is only valid if it is reachable by following
+    only yes jumps from the first instruction - see inet_diag_bc_audit and
+    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
+    filter terms, it is not possible to implement cond1 OR cond2 using:
+
+      ...
+      cond1 2 1 arg
+      cond2 1 2 arg
+      accept
+      reject
+
+    but only using:
+
+      ...
+      cond1 1 2 arg
+      jmp   1 2
+      cond2 1 2 arg
+      accept
+      reject
+
+    The jmp instruction ignores yes and always jumps to no, but yes must be 1
+    or the bytecode won't validate. It doesn't have to be jmp - any instruction
+    that is guaranteed not to match on real data will do.
+
+    Args:
+      instructions: list of instruction tuples
+
+    Returns:
+      A string, the raw bytecode.
+
+    Raises:
+      ValueError: if a jump offset is not positive or the opcode is unsupported.
+    """
+    args = []
+    positions = [0]
+
+    for op, yes, no, arg in instructions:
+
+      if yes <= 0 or no <= 0:
+        raise ValueError("Jumps must be > 0")
+
+      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
+        arg = ""
+      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
+                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
+        arg = "\x00\x00" + struct.pack("=H", arg)
+      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
+        addr, prefixlen, port = arg
+        family = AF_INET6 if ":" in addr else AF_INET
+        addr = inet_pton(family, addr)
+        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
+      else:
+        raise ValueError("Unsupported opcode %d" % op)
+
+      args.append(arg)
+      length = len(InetDiagBcOp) + len(arg)
+      positions.append(positions[-1] + length)
+
+    # Reject label.
+    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
+    assert len(args) == len(instructions) == len(positions) - 2
+
+    # Second pass: every instruction's byte position is now known, so turn the
+    # instruction-relative jumps into byte offsets and emit each instruction.
+    packed = ""
+    for i, (op, yes, no, arg) in enumerate(instructions):
+      yes = positions[i + yes] - positions[i]
+      no = positions[i + no] - positions[i]
+      instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
+      packed += instruction
+
+    return packed
+
+  def Dump(self, diag_req, bytecode=""):
+    out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
     return out
 
-  def DumpAllInetSockets(self, protocol, sock_id=None, ext=0,
+  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                          states=ALL_NON_TIME_WAIT):
     """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
     # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
@@ -145,10 +246,13 @@
     if sock_id is None:
       sock_id = self._EmptyInetDiagSockId()
 
+    if bytecode:
+      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
+
     sockets = []
     for family in [AF_INET, AF_INET6]:
       diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
-      sockets += self.Dump(diag_req)
+      sockets += self.Dump(diag_req, bytecode)
 
     return sockets
 
@@ -255,6 +359,6 @@
   sock_id.dport = 443
   ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
   states = 0xffffffff
-  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP,
+  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
                                    sock_id=sock_id, ext=ext, states=states)
   print diag_msgs
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 97df3bf..b13befe 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -31,6 +31,7 @@
 
 
 NUM_SOCKETS = 100
+NO_BYTECODE = ""
 
 # TODO: Backport SOCK_DESTROY and delete this.
 HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
@@ -115,7 +116,7 @@
 
   def testFindsAllMySockets(self):
     self.socketpairs = self._CreateLotsOfSockets()
-    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP)
+    sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
     self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
 
     # Find the cookies for all of our sockets.
@@ -149,6 +150,54 @@
         diag_msg = self.sock_diag.GetSockDiag(req)
         self.assertSockDiagMatchesSocket(sock, diag_msg)
 
+  def testBytecodeCompilation(self):
+    instructions = [
+        (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
+        (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
+        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
+        (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
+        (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
+        (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
+        (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
+                                                                       # 76 acc
+                                                                       # 80 rej
+    ]
+    bytecode = self.sock_diag.PackBytecode(instructions)
+    expected = (
+        "0208500000000000"
+        "050848000000ffff"
+        "071c20000a800000ffffffff00000000000000000000000000000001"
+        "01041c00"
+        "0718200002200000ffffffff7f000001"
+        "0508100000006566"
+        "00040400"
+    )
+    self.assertMultiLineEqual(expected, bytecode.encode("hex"))
+    self.assertEquals(76, len(bytecode))
+    self.socketpairs = self._CreateLotsOfSockets()
+    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
+    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
+    self.assertEquals(len(allsockets), len(filteredsockets))
+
+    # For up to 20 of our socketpairs (in dict order), filter on each socket's
+    # exact source and destination ports and check only that socket is dumped.
+    for socketpair in self.socketpairs.values()[:20]:
+      for s in socketpair:
+        diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+        instructions = [
+            (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
+            (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
+            (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
+            (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
+        ]
+        bytecode = self.sock_diag.PackBytecode(instructions)
+        self.assertEquals(32, len(bytecode))
+        sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
+        self.assertEquals(1, len(sockets))
+
+        # TODO: why doesn't comparing the cstructs work?
+        self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
+
   @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testClosesSockets(self):
     self.socketpairs = self._CreateLotsOfSockets()
@@ -356,7 +405,7 @@
     req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
     req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
     req.id.cookie = "\x00" * 8
-    children = self.sock_diag.Dump(req)
+    children = self.sock_diag.Dump(req, NO_BYTECODE)
     return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
             for d, _ in children]
 
@@ -486,7 +535,7 @@
     sock_id.sport = self.port
     states = 1 << sock_diag.TCP_SYN_RECV
     req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
-    children = self.sock_diag.Dump(req)
+    children = self.sock_diag.Dump(req, NO_BYTECODE)
 
     self.assertTrue(children)
     for child, unused_args in children: