Snap for 4742838 from 0146dc63d10d2f5d904d293eb0942f04d1ea314b to pi-release

Change-Id: I5bc6ea8ad0c8d92ecf2481bf65a50ffb1a96a2d5
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 9419587..e25035b 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -457,7 +457,7 @@
   def run(self):
     try:
       self.operation(self.sock)
-    except IOError, e:
+    except (IOError, AssertionError), e:  # Also record failed assertions.
       self.exception = e
 
 
@@ -712,14 +712,47 @@
     super(PollOnCloseTest, self).setUp()
     self.netid = random.choice(self.tuns.keys())
 
-  def BlockingPoll(self, sock, mask, expected_event):
+  # poll() bits and their names, for readable assertion failures. POLLNVAL is
+  # included because poll() can report it even if it was not requested.
+  POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
+                (select.POLLERR, "ERR"), (select.POLLHUP, "HUP"),
+                (select.POLLNVAL, "NVAL")]
+
+  def PollResultToString(self, poll_events, ignoremask):
+    """Converts poll() results into a comparable, human-readable form.
+
+    Args:
+      poll_events: list of (fd, event) tuples, as returned by poll.poll().
+      ignoremask: bits to clear from each event before formatting.
+
+    Returns:
+      A list of (fd, names) tuples, where names is the "|"-joined names of the
+      remaining bits. Any bits not in POLL_FLAGS are appended in hex so that
+      they cannot silently compare equal.
+    """
+    out = []
+    for fd, event in poll_events:
+      event &= ~ignoremask
+      flags = []
+      for flag, name in self.POLL_FLAGS:
+        if event & flag:
+          flags.append(name)
+          event &= ~flag
+      if event:
+        flags.append("0x%x" % event)
+      out.append((fd, "|".join(flags)))
+    return out
+
+  def BlockingPoll(self, sock, mask, expected, ignoremask):
+    """Polls sock for mask and checks the result, ignoring ignoremask bits."""
     p = select.poll()
     p.register(sock, mask)
-    expected_fds = [(sock.fileno(), expected_event)]
+    expected_fds = [(sock.fileno(), expected)]
     # Don't block forever or we'll hang continuous test runs on failure.
     # A 5-second timeout should be long enough not to be flaky.
     actual_fds = p.poll(5000)
-    self.assertEqual(expected_fds, actual_fds)
+    self.assertEqual(self.PollResultToString(expected_fds, ignoremask),
+                     self.PollResultToString(actual_fds, ignoremask))
 
   def RstDuringBlockingCall(self, sock, call, expected_errno):
     self._EventDuringBlockingCall(
@@ -735,43 +747,53 @@
     self.assertEquals("", self.accepted.recv(4096))
     self.assertEquals("", self.accepted.recv(4096))
 
-  def CheckPollDestroy(self, mask, expected_event):
+  def CheckPollDestroy(self, mask, expected, ignoremask):
     """Interrupts a poll() with SOCK_DESTROY."""
     for version in [4, 5, 6]:
       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
       self.CloseDuringBlockingCall(
           self.accepted,
-          lambda sock: self.BlockingPoll(sock, mask, expected_event),
+          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
           None)
       self.assertSocketErrors(ECONNABORTED)
 
-  def CheckPollRst(self, mask, expected_event):
+  def CheckPollRst(self, mask, expected, ignoremask):
     """Interrupts a poll() by receiving a TCP RST."""
     for version in [4, 5, 6]:
       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
       self.RstDuringBlockingCall(
           self.accepted,
-          lambda sock: self.BlockingPoll(sock, mask, expected_event),
+          lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
           None)
       self.assertSocketErrors(ECONNRESET)
 
   def testReadPollRst(self):
-    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP)
+    # Before 3d4762639d ("tcp: remove poll() flakes when receiving RST"), a
+    # kernel race made poll() return either POLLERR or POLLIN|POLLERR|POLLHUP.
+    # The race is timing-dependent and is seen on physical hardware but not on
+    # the VM. Ignore POLLIN and POLLHUP on kernels that may lack the fix.
+    if net_test.LINUX_VERSION < (4, 14, 0):
+      ignoremask = select.POLLIN | select.POLLHUP
+    else:
+      ignoremask = 0
+    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
 
   def testWritePollRst(self):
-    self.CheckPollRst(select.POLLOUT, select.POLLOUT)
+    self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)
 
   def testReadWritePollRst(self):
-    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT)
+    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)
 
   def testReadPollDestroy(self):
-    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP)
+    # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
+    ignoremask = select.POLLIN | select.POLLHUP
+    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
 
   def testWritePollDestroy(self):
-    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT)
+    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)
 
   def testReadWritePollDestroy(self):
-    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT)
+    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
 
 
 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")