Test for a SOCK_DIAG oops on IPv4-mapped SYN_RECV connections.

Change-Id: Ib091831cefd140161b020d9801bc7b1fa1e1ea76
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 5975931..2ca1bb0 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -252,12 +252,25 @@
                        self.last_packet)
 
   def IncomingConnection(self, version, end_state, netid):
+    if version == 5:
+      mapped = True
+      socket_version = 6
+      version = 4
+    else:
+      socket_version = version
+      mapped = False
+
     self.version = version
-    self.s = self.OpenListenSocket(version)
+    self.s = self.OpenListenSocket(socket_version)
     self.end_state = end_state
 
-    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
-    myaddr = self.myaddr = self.MyAddress(version, netid)
+    def MaybeMappedAddress(addr):
+      return "::ffff:%s" % addr if mapped else addr
+
+    remoteaddr = self.remoteaddr = MaybeMappedAddress(
+        self.GetRemoteAddress(version))
+    myaddr = self.myaddr = MaybeMappedAddress(
+        self.MyAddress(version, netid))
 
     if end_state == sock_diag.TCP_LISTEN:
       return
@@ -427,7 +440,7 @@
   @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testAcceptInterrupted(self):
     """Tests that accept() is interrupted by SOCK_DESTROY."""
-    for version in [4, 6]:
+    for version in [4, 5, 6]:
       self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
       self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
       self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
@@ -436,7 +449,7 @@
   @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testReadInterrupted(self):
     """Tests that read() is interrupted by SOCK_DESTROY."""
-    for version in [4, 6]:
+    for version in [4, 5, 6]:
       self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
       self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                    ECONNABORTED)
@@ -445,7 +458,7 @@
   @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testConnectInterrupted(self):
     """Tests that connect() is interrupted by SOCK_DESTROY."""
-    for version in [4, 6]:
+    for version in [4, 5, 6]:
       family = {4: AF_INET, 6: AF_INET6}[version]
       s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
       self.SelectInterface(s, self.netid, "mark")
@@ -460,6 +473,28 @@
       msg = "SOCK_DESTROY of socket in connect, expected no RST"
       self.ExpectNoPacketsOn(self.netid, msg)
 
+  def testIpv4MappedSynRecvSocket(self):
+    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
+
+    Relevant kernel commits:
+         android-3.4:
+           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
+    """
+    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
+    sock_id = self.sock_diag._EmptyInetDiagSockId()
+    sock_id.sport = self.port
+    states = 1 << sock_diag.TCP_SYN_RECV
+    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
+    children = self.sock_diag.Dump(req)
+
+    self.assertTrue(children)
+    for child, unused_args in children:
+      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
+                       child.id.dst)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
+                       child.id.src)
+
 
 if __name__ == "__main__":
   unittest.main()