Merge "Make fewer assumptions in AndroidKeyStoreTest." into mnc-dev
diff --git a/tests/tests/keystore/src/android/keystore/cts/AndroidKeyStoreTest.java b/tests/tests/keystore/src/android/keystore/cts/AndroidKeyStoreTest.java
index 577e830..d3f4ccd 100644
--- a/tests/tests/keystore/src/android/keystore/cts/AndroidKeyStoreTest.java
+++ b/tests/tests/keystore/src/android/keystore/cts/AndroidKeyStoreTest.java
@@ -36,9 +36,8 @@
 import java.security.PublicKey;
 import java.security.cert.Certificate;
 import java.security.cert.CertificateFactory;
-import java.security.interfaces.ECPrivateKey;
-import java.security.interfaces.ECPublicKey;
-import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.ECKey;
+import java.security.interfaces.RSAKey;
 import java.security.spec.PKCS8EncodedKeySpec;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -1090,14 +1089,18 @@
         final PrivateKey privKey = keyEntry.getPrivateKey();
         final PublicKey pubKey = keyEntry.getCertificate().getPublicKey();
 
-        if (expectedKey instanceof ECPrivateKey) {
+        if (expectedKey instanceof ECKey) {
+            assertTrue("Returned PrivateKey " + privKey.getClass() + " should be instanceof ECKey",
+                    privKey instanceof ECKey);
             assertEquals("Returned PrivateKey should be what we inserted",
-                    ((ECPrivateKey) expectedKey).getParams().getCurve(),
-                    ((ECPublicKey) pubKey).getParams().getCurve());
-        } else if (expectedKey instanceof RSAPrivateKey) {
+                    ((ECKey) expectedKey).getParams().getCurve(),
+                    ((ECKey) privKey).getParams().getCurve());
+        } else if (expectedKey instanceof RSAKey) {
+            assertTrue("Returned PrivateKey " + privKey.getClass() + " should be instanceof RSAKey",
+                    privKey instanceof RSAKey);
             assertEquals("Returned PrivateKey should be what we inserted",
-                    ((RSAPrivateKey) expectedKey).getModulus(),
-                    ((RSAPrivateKey) privKey).getModulus());
+                    ((RSAKey) expectedKey).getModulus(),
+                    ((RSAKey) privKey).getModulus());
         }
 
         assertNull("getFormat() should return null", privKey.getFormat());
@@ -1143,15 +1146,16 @@
         Key key = mKeyStore.getKey(TEST_ALIAS_1, null);
         assertNotNull("Key should exist", key);
 
-        assertTrue("Should be a RSAPrivateKey", key instanceof RSAPrivateKey);
+        assertTrue("Should be a PrivateKey", key instanceof PrivateKey);
+        assertTrue("Should be a RSAKey", key instanceof RSAKey);
 
-        RSAPrivateKey actualKey = (RSAPrivateKey) key;
+        RSAKey actualKey = (RSAKey) key;
 
         KeyFactory keyFact = KeyFactory.getInstance("RSA");
         PrivateKey expectedKey = keyFact.generatePrivate(new PKCS8EncodedKeySpec(FAKE_RSA_KEY_1));
 
         assertEquals("Inserted key should be same as retrieved key",
-                ((RSAPrivateKey) expectedKey).getModulus(), actualKey.getModulus());
+                ((RSAKey) expectedKey).getModulus(), actualKey.getModulus());
     }
 
     public void testKeyStore_GetKey_Certificate_Unencrypted_Failure() throws Exception {
@@ -1893,14 +1897,17 @@
 
         kpg.generateKeyPair();
 
-        RSAPrivateKey key = (RSAPrivateKey) ks.getKey(alias, null);
-        assertNotNull(key);
-        String cipher = key.getAlgorithm() + "/NONE/NOPADDING";
+        PrivateKey privateKey = (PrivateKey) ks.getKey(alias, null);
+        assertNotNull(privateKey);
+        PublicKey publicKey = ks.getCertificate(alias).getPublicKey();
+        assertNotNull(publicKey);
+        String cipher = privateKey.getAlgorithm() + "/NONE/NOPADDING";
         Cipher encrypt = Cipher.getInstance(cipher);
         assertNotNull(encrypt);
-        encrypt.init(Cipher.ENCRYPT_MODE, key);
+        encrypt.init(Cipher.ENCRYPT_MODE, privateKey);
 
-        byte[] plainText = new byte[encrypt.getBlockSize()];
+        int modulusSizeBytes = (((RSAKey) publicKey).getModulus().bitLength() + 7) / 8;
+        byte[] plainText = new byte[modulusSizeBytes];
         Arrays.fill(plainText, (byte) 0xFF);
 
         // We expect a BadPaddingException here as the message size (plaintext)