Fix RSA upcalls from TLS/SSL into JCA.

When BoringSSL/OpenSSL TLS/SSL stack operates on opaque private keys
(those that don't expose their key material) it upcalls (via
Conscrypt's NativeCrypto) into corresponding JCA Signature and Cipher
primitives.

This CL fixes two issues with RSA-related upcalls, which prevented
the use of opaque RSA private keys for TLS/SSL with Conscrypt backed
by BoringSSL:
* RSA sign was upcalled into RSA Cipher decrypt using private key.
  In JCA, the correct upcall is RSA Signature sign. This is now
  invoked instead of RSA Cipher decrypt.
* RSA decrypt was not implemented. It's now implemented.

As part of implementing RSA decrypt upcall from BoringSSL, it
transpired that BoringSSL requests no padding as opposed to OpenSSL
which requests PKCS#1 padding. As a result, this CL modifies the
decrypt upcall to take a padding parameter. The implementation of
the upcall (see CryptoUpcalls.java) now supports PKCS#1 padding
scheme, OAEP padding scheme, and no padding.

This CL also drops the encrypt/decrypt flag from the RSA
encrypt/decrypt upcall and simplies it into an RSA decrypt upcall. RSA
encrypt upcall is not needed at all.

Bug: 21738458
Change-Id: I2a4610890ea1ed1a2e99eb1d5c34348fbf406e54
diff --git a/src/gen/native/generate_constants.cc b/src/gen/native/generate_constants.cc
index ba3d5d4..4d1b3ab 100644
--- a/src/gen/native/generate_constants.cc
+++ b/src/gen/native/generate_constants.cc
@@ -73,6 +73,7 @@
 
   CONST(RSA_PKCS1_PADDING);
   CONST(RSA_NO_PADDING);
+  CONST(RSA_PKCS1_OAEP_PADDING);
 
   CONST(SSL_MODE_SEND_FALLBACK_SCSV);
   CONST(SSL_MODE_CBC_RECORD_SPLITTING);
diff --git a/src/main/java/org/conscrypt/CryptoUpcalls.java b/src/main/java/org/conscrypt/CryptoUpcalls.java
index 2b71fea..abd9063 100644
--- a/src/main/java/org/conscrypt/CryptoUpcalls.java
+++ b/src/main/java/org/conscrypt/CryptoUpcalls.java
@@ -29,7 +29,6 @@
  * calls to work on delegated key types from native code.
  */
 public final class CryptoUpcalls {
-    private static final String RSA_CRYPTO_ALGORITHM = "RSA/ECB/PKCS1Padding";
 
     private CryptoUpcalls() {
     }
@@ -97,7 +96,7 @@
         }
     }
 
-    public static byte[] rawCipherWithPrivateKey(PrivateKey javaKey, boolean encrypt,
+    public static byte[] rsaDecryptWithPrivateKey(PrivateKey javaKey, int openSSLPadding,
             byte[] input) {
         String keyAlgorithm = javaKey.getAlgorithm();
         if (!"RSA".equals(keyAlgorithm)) {
@@ -105,14 +104,31 @@
             return null;
         }
 
-        Provider p = getExternalProvider("Cipher." + RSA_CRYPTO_ALGORITHM);
+        String jcaPadding;
+        switch (openSSLPadding) {
+            case NativeConstants.RSA_PKCS1_PADDING:
+                jcaPadding = "PKCS1Padding";
+                break;
+            case NativeConstants.RSA_NO_PADDING:
+                jcaPadding = "NoPadding";
+                break;
+            case NativeConstants.RSA_PKCS1_OAEP_PADDING:
+                jcaPadding = "OAEPPadding";
+                break;
+            default:
+                System.err.println("Unsupported OpenSSL/BoringSSL padding: " + openSSLPadding);
+                return null;
+        }
+
+        String transformation = "RSA/ECB/" + jcaPadding;
+        Provider p = getExternalProvider("Cipher." + transformation);
         if (p == null) {
             return null;
         }
 
         Cipher c = null;
         try {
-            c = Cipher.getInstance(RSA_CRYPTO_ALGORITHM, p);
+            c = Cipher.getInstance(transformation, p);
         } catch (NoSuchAlgorithmException e) {
             ;
         } catch (NoSuchPaddingException e) {
@@ -120,16 +136,16 @@
         }
 
         if (c == null) {
-            System.err.println("Unsupported transformation: " + RSA_CRYPTO_ALGORITHM);
+            System.err.println("Unsupported transformation: " + transformation);
             return null;
         }
 
         try {
-            c.init(encrypt ? Cipher.ENCRYPT_MODE : Cipher.DECRYPT_MODE, javaKey);
+            c.init(Cipher.DECRYPT_MODE, javaKey);
             return c.doFinal(input);
         } catch (Exception e) {
-            System.err.println("Exception while ciphering message with " + javaKey.getAlgorithm()
-                    + " private key:");
+            System.err.println("Exception while decrypting message with " + javaKey.getAlgorithm()
+                    + " private key using " + transformation + ":");
             e.printStackTrace();
             return null;
         }
diff --git a/src/main/native/org_conscrypt_NativeCrypto.cpp b/src/main/native/org_conscrypt_NativeCrypto.cpp
index 051301d..29ad4a4 100644
--- a/src/main/native/org_conscrypt_NativeCrypto.cpp
+++ b/src/main/native/org_conscrypt_NativeCrypto.cpp
@@ -1487,33 +1487,41 @@
             cryptoUpcallsClass, rawSignMethod, privateKey, messageArray.get()));
 }
 
-static jbyteArray rawCipherWithPrivateKey(JNIEnv* env, jobject privateKey, jboolean encrypt,
+// rsaDecryptWithPrivateKey uses privateKey to decrypt |ciphertext_len| bytes
+// from |ciphertext|. The ciphertext is expected to be padded using the scheme
+// given in |padding|, which must be one of |RSA_*_PADDING| constants from
+// OpenSSL.
+static jbyteArray rsaDecryptWithPrivateKey(JNIEnv* env, jobject privateKey, jint padding,
         const char* ciphertext, size_t ciphertext_len) {
     ScopedLocalRef<jbyteArray> ciphertextArray(env, env->NewByteArray(ciphertext_len));
     if (env->ExceptionCheck()) {
-        JNI_TRACE("rawCipherWithPrivateKey(%p) => threw exception", privateKey);
+        JNI_TRACE("rsaDecryptWithPrivateKey(%p) => threw exception", privateKey);
         return NULL;
     }
 
     {
         ScopedByteArrayRW ciphertextBytes(env, ciphertextArray.get());
         if (ciphertextBytes.get() == NULL) {
-            JNI_TRACE("rawCipherWithPrivateKey(%p) => using byte array failed", privateKey);
+            JNI_TRACE("rsaDecryptWithPrivateKey(%p) => using byte array failed", privateKey);
             return NULL;
         }
 
         memcpy(ciphertextBytes.get(), ciphertext, ciphertext_len);
     }
 
-    jmethodID rawCipherMethod = env->GetStaticMethodID(cryptoUpcallsClass,
-            "rawCipherWithPrivateKey", "(Ljava/security/PrivateKey;Z[B)[B");
-    if (rawCipherMethod == NULL) {
-        ALOGE("Could not find rawCipherWithPrivateKey");
+    jmethodID rsaDecryptMethod = env->GetStaticMethodID(cryptoUpcallsClass,
+            "rsaDecryptWithPrivateKey", "(Ljava/security/PrivateKey;I[B)[B");
+    if (rsaDecryptMethod == NULL) {
+        ALOGE("Could not find rsaDecryptWithPrivateKey");
         return NULL;
     }
 
     return reinterpret_cast<jbyteArray>(env->CallStaticObjectMethod(
-            cryptoUpcallsClass, rawCipherMethod, privateKey, encrypt, ciphertextArray.get()));
+            cryptoUpcallsClass,
+            rsaDecryptMethod,
+            privateKey,
+            padding,
+            ciphertextArray.get()));
 }
 
 // *********************************************
@@ -1611,11 +1619,6 @@
                      unsigned char* to,
                      RSA* rsa,
                      int padding) {
-    if (padding != RSA_PKCS1_PADDING) {
-        RSAerr(RSA_F_RSA_PRIVATE_DECRYPT, RSA_R_UNKNOWN_PADDING_TYPE);
-        return -1;
-    }
-
     // Retrieve private key JNI reference.
     jobject private_key = reinterpret_cast<jobject>(RSA_get_app_data(rsa));
     if (!private_key) {
@@ -1629,10 +1632,9 @@
         return -1;
     }
 
-    // For RSA keys, this function behaves as RSA_private_decrypt with
-    // PKCS#1 padding.
-    ScopedLocalRef<jbyteArray> cleartext(env, rawCipherWithPrivateKey(env, private_key, false,
-                                         reinterpret_cast<const char*>(from), flen));
+    // This function behaves as RSA_private_decrypt.
+    ScopedLocalRef<jbyteArray> cleartext(env, rsaDecryptWithPrivateKey(env, private_key,
+                                         padding, reinterpret_cast<const char*>(from), flen));
     if (cleartext.get() == NULL) {
         ALOGE("Could not decrypt message in RsaMethodPrivDec!");
         RSAerr(RSA_F_RSA_PRIVATE_DECRYPT, ERR_R_INTERNAL_ERROR);
@@ -1929,18 +1931,19 @@
     return 0;
   }
 
-  // For RSA keys, this function behaves as RSA_private_decrypt with
-  // PKCS#1 v1.5 padding.
-  ScopedLocalRef<jbyteArray> cleartext(
-      env, rawCipherWithPrivateKey(env, ex_data->private_key, false,
-                                   reinterpret_cast<const char*>(in), in_len));
+  // For RSA keys, this function behaves as RSA_private_encrypt with
+  // PKCS#1 padding.
+  ScopedLocalRef<jbyteArray> signature(
+      env, rawSignDigestWithPrivateKey(
+          env, ex_data->private_key,
+          reinterpret_cast<const char*>(in), in_len));
 
-  if (cleartext.get() == NULL) {
+  if (signature.get() == NULL) {
     OPENSSL_PUT_ERROR(RSA, sign_raw, ERR_R_INTERNAL_ERROR);
     return 0;
   }
 
-  ScopedByteArrayRO result(env, cleartext.get());
+  ScopedByteArrayRO result(env, signature.get());
 
   size_t expected_size = static_cast<size_t>(RSA_size(rsa));
   if (result.size() > expected_size) {
@@ -1963,15 +1966,48 @@
   return 1;
 }
 
-int RsaMethodDecrypt(RSA* /* rsa */,
-                     size_t* /* out_len */,
-                     uint8_t* /* out */,
-                     size_t /* max_out */,
-                     const uint8_t* /* in */,
-                     size_t /* in_len */,
-                     int /* padding */) {
-  OPENSSL_PUT_ERROR(RSA, decrypt, RSA_R_UNKNOWN_ALGORITHM_TYPE);
-  return 0;
+int RsaMethodDecrypt(RSA* rsa,
+                     size_t* out_len,
+                     uint8_t* out,
+                     size_t max_out,
+                     const uint8_t* in,
+                     size_t in_len,
+                     int padding) {
+  // Retrieve private key JNI reference.
+  const KeyExData *ex_data = RsaGetExData(rsa);
+  if (!ex_data || !ex_data->private_key) {
+    OPENSSL_PUT_ERROR(RSA, decrypt, ERR_R_INTERNAL_ERROR);
+    return 0;
+  }
+
+  JNIEnv* env = getJNIEnv();
+  if (env == NULL) {
+    OPENSSL_PUT_ERROR(RSA, decrypt, ERR_R_INTERNAL_ERROR);
+    return 0;
+  }
+
+  // This function behaves as RSA_private_decrypt.
+  ScopedLocalRef<jbyteArray> cleartext(
+      env, rsaDecryptWithPrivateKey(
+          env, ex_data->private_key, padding,
+          reinterpret_cast<const char*>(in), in_len));
+  if (cleartext.get() == NULL) {
+    OPENSSL_PUT_ERROR(RSA, decrypt, ERR_R_INTERNAL_ERROR);
+    return 0;
+  }
+
+  ScopedByteArrayRO cleartextBytes(env, cleartext.get());
+
+  if (max_out < cleartextBytes.size()) {
+    OPENSSL_PUT_ERROR(RSA, decrypt, RSA_R_DATA_TOO_LARGE);
+    return 0;
+  }
+
+  // Copy result to OpenSSL-provided buffer.
+  memcpy(out, cleartextBytes.get(), cleartextBytes.size());
+  *out_len = cleartextBytes.size();
+
+  return 1;
 }
 
 int RsaMethodVerifyRaw(RSA* /* rsa */,