Test for a cross-family bytecode comparison bug.

Change-Id: I251088dc09d803a7448930cd155fc3a1c6c5bddf
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index b13befe..1baafbd 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -198,6 +198,41 @@
         # TODO: why doesn't comparing the cstructs work?
         self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
 
+  def testCrossFamilyBytecode(self):
+    """Checks for a cross-family bug in inet_diag_hostcond matching.
+
+    Relevant kernel commits:
+      android-3.4:
+        f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
+    """
+    pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
+    pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
+
+    bytecode4 = self.sock_diag.PackBytecode([
+        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
+    bytecode6 = self.sock_diag.PackBytecode([
+        (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
+
+    # IPv4/v6 filters must never match IPv6/IPv4 sockets...
+    v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
+    self.assertTrue(v4sockets)
+    self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))
+
+    v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
+    self.assertTrue(v6sockets)
+    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))
+
+    # Except for mapped addresses, which match both IPv4 and IPv6.
+    pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
+                                      "::ffff:127.0.0.1")
+    diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
+    v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+                                                                 bytecode4)]
+    v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+                                                                 bytecode6)]
+    self.assertTrue(all(d in v4sockets for d in diag_msgs))
+    self.assertTrue(all(d in v6sockets for d in diag_msgs))
+
   @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testClosesSockets(self):
     self.socketpairs = self._CreateLotsOfSockets()