Merge "Test more transforms and more socket types."
diff --git a/tests/tests/net/src/android/net/cts/IpSecManagerTest.java b/tests/tests/net/src/android/net/cts/IpSecManagerTest.java
index 6348d04..75bc528 100644
--- a/tests/tests/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/tests/net/src/android/net/cts/IpSecManagerTest.java
@@ -24,16 +24,24 @@
 import android.os.ParcelFileDescriptor;
 import android.system.ErrnoException;
 import android.system.Os;
-import android.system.OsConstants;
 import android.test.AndroidTestCase;
+import java.net.ServerSocket;
+import android.util.Log;
+
+import java.io.ByteArrayOutputStream;
 import java.io.FileDescriptor;
 import java.net.DatagramSocket;
 import java.net.InetAddress;
+import java.net.Inet6Address;
 import java.net.InetSocketAddress;
-import java.net.ServerSocket;
+import java.net.Socket;
 import java.net.UnknownHostException;
 import java.util.Arrays;
 
+import android.system.OsConstants;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.IPPROTO_UDP;
+
 public class IpSecManagerTest extends AndroidTestCase {
 
     private static final String TAG = IpSecManagerTest.class.getSimpleName();
@@ -73,6 +81,9 @@
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7F
     };
 
+    private static final String IPV4_LOOPBACK = "127.0.0.1"; // Loopback endpoint for IPv4 tests.
+    private static final String IPV6_LOOPBACK = "::1"; // Loopback endpoint for IPv6 tests.
+
     protected void setUp() throws Exception {
         super.setUp();
         mCM = (ConnectivityManager) getContext().getSystemService(Context.CONNECTIVITY_SERVICE);
@@ -112,6 +123,161 @@
         }
     }
 
+    /**
+     * Returns the first {@code bitLength / 8} bytes of {@code AUTH_KEY}, zero-padded if
+     * {@code AUTH_KEY} is shorter than that.
+     */
+    private byte[] getAuthKey(int bitLength) {
+        return Arrays.copyOf(AUTH_KEY, bitLength / 8);
+    }
+
+    /** Returns {@code AF_INET6} for IPv6 addresses and {@code AF_INET} for all others. */
+    private static int getDomain(InetAddress address) {
+        return (address instanceof Inet6Address) ? OsConstants.AF_INET6 : OsConstants.AF_INET;
+    }
+
+    /**
+     * Returns a UDP port that was free at the time of the call. The probe socket is closed
+     * before returning, so another socket may claim the port before the caller binds to it.
+     */
+    private static int findUnusedPort() throws Exception {
+        DatagramSocket probe = new DatagramSocket();
+        try {
+            return probe.getLocalPort();
+        } finally {
+            probe.close();
+        }
+    }
+
+    /**
+     * Creates a UDP socket bound to {@code address} on a kernel-assigned port.
+     *
+     * <p>Binding to port 0 lets the kernel pick a free port atomically, so no other socket can
+     * claim the port between choosing it and binding to it. The socket is closed if the bind
+     * fails.
+     *
+     * @throws java.io.IOException if the bind fails with an {@link ErrnoException}.
+     */
+    private static FileDescriptor getBoundUdpSocket(InetAddress address) throws Exception {
+        FileDescriptor sock =
+                Os.socket(getDomain(address), OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP);
+        boolean bound = false;
+        try {
+            Os.bind(sock, address, 0);
+            bound = true;
+        } catch (ErrnoException e) {
+            throw e.rethrowAsIOException();
+        } finally {
+            if (!bound) {
+                // Don't leak the descriptor when the bind fails.
+                closeQuietly(sock);
+            }
+        }
+        return sock;
+    }
+
+    /** Closes {@code fd} if it is non-null, ignoring errors so cleanup never masks a failure. */
+    private static void closeQuietly(FileDescriptor fd) {
+        if (fd == null) {
+            return;
+        }
+        try {
+            Os.close(fd);
+        } catch (ErrnoException ignored) {
+            // Best-effort cleanup: a close failure should not hide the real test result.
+        }
+    }
+
+    /**
+     * Reads exactly {@code buf.length} bytes from {@code fd} into {@code buf}. A single read on a
+     * stream socket may return fewer bytes than requested, so this loops until the buffer is
+     * full, failing the test if the peer closes the connection first.
+     */
+    private static void readFully(FileDescriptor fd, byte[] buf) throws Exception {
+        int offset = 0;
+        while (offset < buf.length) {
+            int n = Os.read(fd, buf, offset, buf.length - offset);
+            if (n <= 0) {
+                fail("Connection closed after " + offset + " of " + buf.length + " bytes");
+            }
+            offset += n;
+        }
+    }
+
+    /**
+     * Applies {@code transform} to a UDP socket bound to {@code local}, sends a datagram to the
+     * socket's own address and port, and asserts that the payload read back is unchanged. The
+     * transform is removed on success; the socket is closed in all cases.
+     */
+    private void checkUnconnectedUdp(IpSecTransform transform, InetAddress local) throws Exception {
+        FileDescriptor udpSocket = getBoundUdpSocket(local);
+        try {
+            int localPort = getPort(udpSocket);
+
+            mISM.applyTransportModeTransform(udpSocket, transform);
+            byte[] data = ("Best test data ever! Port: " + localPort).getBytes("UTF-8");
+            byte[] in = new byte[data.length];
+
+            Os.sendto(udpSocket, data, 0, data.length, 0, local, localPort);
+            Os.read(udpSocket, in, 0, in.length);
+            assertTrue("Encapsulated data did not match.", Arrays.equals(data, in));
+
+            mISM.removeTransportModeTransform(udpSocket, transform);
+        } finally {
+            closeQuietly(udpSocket);
+        }
+    }
+
+    /**
+     * Sets up a loopback TCP connection on {@code local}, applies {@code transform} to the
+     * listening, client and accepted sockets, and asserts that data round-trips unchanged in
+     * both directions. The transforms are removed on success; all sockets are closed in all
+     * cases.
+     */
+    private void checkTcp(IpSecTransform transform, InetAddress local) throws Exception {
+        FileDescriptor server = null;
+        FileDescriptor client = null;
+        FileDescriptor accepted = null;
+        try {
+            server = Os.socket(getDomain(local), OsConstants.SOCK_STREAM, IPPROTO_TCP);
+            client = Os.socket(getDomain(local), OsConstants.SOCK_STREAM, IPPROTO_TCP);
+
+            Os.bind(server, local, 0);
+            int port = getPort(server);
+
+            mISM.applyTransportModeTransform(client, transform);
+            mISM.applyTransportModeTransform(server, transform);
+
+            Os.listen(server, 10);
+            Os.connect(client, local, port);
+            accepted = Os.accept(server, null);
+
+            mISM.applyTransportModeTransform(accepted, transform);
+
+            byte[] data = ("Best test data ever! Port: " + port).getBytes("UTF-8");
+            byte[] in = new byte[data.length];
+
+            Os.write(client, data, 0, data.length);
+            readFully(accepted, in);
+            assertTrue("Client-to-server encrypted data did not match.", Arrays.equals(data, in));
+
+            data = "Best test data received !!!".getBytes("UTF-8");
+            in = new byte[data.length];
+
+            Os.write(accepted, data, 0, data.length);
+            readFully(client, in);
+            assertTrue("Server-to-client encrypted data did not match.", Arrays.equals(data, in));
+
+            mISM.removeTransportModeTransform(server, transform);
+            mISM.removeTransportModeTransform(client, transform);
+            mISM.removeTransportModeTransform(accepted, transform);
+        } finally {
+            closeQuietly(server);
+            closeQuietly(client);
+            closeQuietly(accepted);
+        }
+    }
+
     /*
      * Alloc outbound SPI
      * Alloc inbound SPI
@@ -174,6 +288,157 @@
         transform.close();
     }
 
+    /**
+     * Builds a transport-mode transform on {@code localAddress} using {@code crypt} and
+     * {@code auth} in both directions, then checks that {@code protocol} traffic over it
+     * round-trips. The transform and both SPIs are released even if the check fails.
+     *
+     * @param protocol {@code IPPROTO_TCP} or {@code IPPROTO_UDP}
+     * @throws IllegalArgumentException if {@code protocol} is neither TCP nor UDP
+     */
+    public void checkTransform(int protocol, String localAddress,
+            IpSecAlgorithm crypt, IpSecAlgorithm auth) throws Exception {
+        // Validate before reserving any SPIs so a bad argument leaks nothing.
+        if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) {
+            throw new IllegalArgumentException("Invalid protocol");
+        }
+        InetAddress local = InetAddress.getByName(localAddress);
+
+        IpSecManager.SecurityParameterIndex outSpi = null;
+        IpSecManager.SecurityParameterIndex inSpi = null;
+        IpSecTransform transform = null;
+        try {
+            outSpi = mISM.reserveSecurityParameterIndex(IpSecTransform.DIRECTION_OUT, local);
+            // Request the outbound SPI value for the inbound SA too; both directions use the
+            // same local address here, presumably so looped-back packets match the inbound SA.
+            inSpi = mISM.reserveSecurityParameterIndex(
+                    IpSecTransform.DIRECTION_IN, local, outSpi.getSpi());
+
+            transform =
+                    new IpSecTransform.Builder(mContext)
+                            .setSpi(IpSecTransform.DIRECTION_OUT, outSpi)
+                            .setEncryption(IpSecTransform.DIRECTION_OUT, crypt)
+                            .setAuthentication(IpSecTransform.DIRECTION_OUT, auth)
+                            .setSpi(IpSecTransform.DIRECTION_IN, inSpi)
+                            .setEncryption(IpSecTransform.DIRECTION_IN, crypt)
+                            .setAuthentication(IpSecTransform.DIRECTION_IN, auth)
+                            .buildTransportModeTransform(local);
+
+            if (protocol == IPPROTO_TCP) {
+                checkTcp(transform, local);
+            } else {
+                // TODO: Also check connected udp.
+                checkUnconnectedUdp(transform, local);
+            }
+        } finally {
+            if (transform != null) {
+                transform.close();
+            }
+            if (outSpi != null) {
+                outSpi.close();
+            }
+            if (inSpi != null) {
+                inSpi.close();
+            }
+        }
+    }
+
+    /**
+     * Runs {@link #checkTransform} with AES-CBC encryption keyed by {@code CRYPT_KEY} and the
+     * given HMAC authentication algorithm.
+     *
+     * @param authAlgorithm one of the {@code IpSecAlgorithm.AUTH_HMAC_*} names
+     * @param authKeyBits length of the authentication key, taken from {@code AUTH_KEY}
+     * @param truncationBits truncation length in bits passed to {@code IpSecAlgorithm}
+     */
+    private void checkAesCbcHmac(String authAlgorithm, int authKeyBits, int truncationBits,
+            int protocol, String localAddress) throws Exception {
+        IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
+        IpSecAlgorithm auth = new IpSecAlgorithm(authAlgorithm, getAuthKey(authKeyBits),
+                truncationBits);
+        checkTransform(protocol, localAddress, crypt, auth);
+    }
+
+    public void testAesCbcHmacMd5Tcp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_MD5, 256, 128, IPPROTO_TCP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacMd5Tcp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_MD5, 256, 128, IPPROTO_TCP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacMd5Udp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_MD5, 256, 128, IPPROTO_UDP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacMd5Udp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_MD5, 256, 128, IPPROTO_UDP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha1Tcp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA1, 256, 128, IPPROTO_TCP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha1Tcp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA1, 256, 128, IPPROTO_TCP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha1Udp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA1, 256, 128, IPPROTO_UDP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha1Udp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA1, 256, 128, IPPROTO_UDP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha256Tcp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA256, 256, 128, IPPROTO_TCP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha256Tcp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA256, 256, 128, IPPROTO_TCP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha256Udp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA256, 256, 128, IPPROTO_UDP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha256Udp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA256, 256, 128, IPPROTO_UDP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha384Tcp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA384, 384, 192, IPPROTO_TCP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha384Tcp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA384, 384, 192, IPPROTO_TCP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha384Udp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA384, 384, 192, IPPROTO_UDP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha384Udp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA384, 384, 192, IPPROTO_UDP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha512Tcp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA512, 512, 256, IPPROTO_TCP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha512Tcp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA512, 512, 256, IPPROTO_TCP, IPV6_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha512Udp4() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA512, 512, 256, IPPROTO_UDP, IPV4_LOOPBACK);
+    }
+
+    public void testAesCbcHmacSha512Udp6() throws Exception {
+        checkAesCbcHmac(IpSecAlgorithm.AUTH_HMAC_SHA512, 512, 256, IPPROTO_UDP, IPV6_LOOPBACK);
+    }
+
     public void testOpenUdpEncapSocketSpecificPort() throws Exception {
         IpSecManager.UdpEncapsulationSocket encapSocket = null;
         int port = -1;
@@ -226,7 +515,7 @@
             // Create user socket, apply transform to it
             FileDescriptor udpSocket = null;
             try {
-                udpSocket = getTestV4UdpSocket(local);
+                udpSocket = getBoundUdpSocket(local);
                 int port = getPort(udpSocket);
 
                 mISM.applyTransportModeTransform(udpSocket, transform);
@@ -285,7 +574,7 @@
             FileDescriptor sock = null;
 
             try {
-                sock = getTestV4UdpSocket(local);
+                sock = getBoundUdpSocket(local);
                 int port = getPort(sock);
 
                 mISM.applyTransportModeTransform(sock, transform);
@@ -336,15 +625,6 @@
         }
     }
 
-    /** This function finds an available port */
-    private static int findUnusedPort() throws Exception {
-        // Get an available port.
-        ServerSocket s = new ServerSocket(0);
-        int port = s.getLocalPort();
-        s.close();
-        return port;
-    }
-
     private static IpSecTransform buildIpSecTransform(
             Context mContext,
             IpSecManager.SecurityParameterIndex inSpi,
@@ -376,28 +656,4 @@
     private static int getPort(FileDescriptor sock) throws Exception {
         return ((InetSocketAddress) Os.getsockname(sock)).getPort();
     }
-
-    private static FileDescriptor getTestV4UdpSocket(InetAddress v4Addr) throws Exception {
-        FileDescriptor sock =
-                Os.socket(OsConstants.AF_INET, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP);
-
-        for (int i = 0; i < MAX_PORT_BIND_ATTEMPTS; i++) {
-            try {
-                int port = findUnusedPort();
-                Os.bind(sock, v4Addr, port);
-                break;
-            } catch (ErrnoException e) {
-                // Someone claimed the port since we called findUnusedPort.
-                if (e.errno == OsConstants.EADDRINUSE) {
-                    if (i == MAX_PORT_BIND_ATTEMPTS - 1) {
-
-                        fail("Failed " + MAX_PORT_BIND_ATTEMPTS + " attempts to bind to a port");
-                    }
-                    continue;
-                }
-                throw e.rethrowAsIOException();
-            }
-        }
-        return sock;
-    }
 }