Make the bytecode tests a bit more robust.

They used to require that there were no TCP sockets on the system
at all. Now they only require that there be no established
sockets. Not a huge improvement, but it does make it possible to
write tests that leave-non established sockets around after they
terminate.

(cherry picked from commit 502ceb5ec1e40e65aee8d6acdabd6ae6ec660aef)

Change-Id: Ied6f5aae3b6cf4a5bd25aa4fbeac637010e1f0e8
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
index 8a84ca2..9e4a22d 100755
--- a/net/test/sock_diag.py
+++ b/net/test/sock_diag.py
@@ -227,7 +227,7 @@
 
     return packed
 
-  def Dump(self, diag_req, bytecode=""):
+  def Dump(self, diag_req, bytecode):
     out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
     return out
 
@@ -305,7 +305,7 @@
     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
 
   def FindSockDiagFromReq(self, req):
-    for diag_msg, attrs in self.Dump(req):
+    for diag_msg, attrs in self.Dump(req, ""):
       return diag_msg
     raise ValueError("Dump of %s returned no sockets" % req)
 
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 3c5d0a9..7871f34 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -174,11 +174,14 @@
         "0508100000006566"
         "00040400"
     )
+    states = 1 << tcp_test.TCP_ESTABLISHED
     self.assertMultiLineEqual(expected, bytecode.encode("hex"))
     self.assertEquals(76, len(bytecode))
     self.socketpairs = self._CreateLotsOfSockets()
-    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
-    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
+    filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
+                                                        states=states)
+    allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
+                                                   states=states)
     self.assertItemsEqual(allsockets, filteredsockets)
 
     # Pick a few sockets in hash table order, and check that the bytecode we
@@ -210,7 +213,9 @@
     # TODO: this is only here because the test fails if there are any open
     # sockets other than the ones it creates itself. Make the bytecode more
     # specific and remove it.
-    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, ""))
+    states = 1 << tcp_test.TCP_ESTABLISHED
+    self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "",
+                                                       states=states))
 
     unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
     unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
@@ -221,24 +226,28 @@
         (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
 
     # IPv4/v6 filters must never match IPv6/IPv4 sockets...
-    v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
-    self.assertTrue(v4sockets)
-    self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))
+    v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
+                                                  states=states)
+    self.assertTrue(v4socks)
+    self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))
 
-    v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
-    self.assertTrue(v6sockets)
-    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))
+    v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
+                                                  states=states)
+    self.assertTrue(v6socks)
+    self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))
 
     # Except for mapped addresses, which match both IPv4 and IPv6.
     pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
                                       "::ffff:127.0.0.1")
     diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
-    v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
-                                                                 bytecode4)]
-    v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
-                                                                 bytecode6)]
-    self.assertTrue(all(d in v4sockets for d in diag_msgs))
-    self.assertTrue(all(d in v6sockets for d in diag_msgs))
+    v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+                                                               bytecode4,
+                                                               states=states)]
+    v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+                                                               bytecode6,
+                                                               states=states)]
+    self.assertTrue(all(d in v4socks for d in diag_msgs))
+    self.assertTrue(all(d in v6socks for d in diag_msgs))
 
   def testPortComparisonValidation(self):
     """Checks for a bug in validating port comparison bytecode.