Support all digests for RSA.

Also switch to using the EVP APIs where possible for RSA ops.

Cherry-picked from internal.

Change-Id: Ib51c6af9d41bb9bb1601beb93cb18313df22e412
diff --git a/android_keymaster_test.cpp b/android_keymaster_test.cpp
index 698e426..63e1d2c 100644
--- a/android_keymaster_test.cpp
+++ b/android_keymaster_test.cpp
@@ -183,7 +183,10 @@
     keymaster_digest_t* digests;
     ASSERT_EQ(KM_ERROR_OK, device()->get_supported_digests(device(), KM_ALGORITHM_RSA,
                                                            KM_PURPOSE_SIGN, &digests, &len));
-    EXPECT_TRUE(ResponseContains({KM_DIGEST_NONE, KM_DIGEST_SHA_2_256}, digests, len));
+    EXPECT_TRUE(
+        ResponseContains({KM_DIGEST_NONE, KM_DIGEST_MD5, KM_DIGEST_SHA1, KM_DIGEST_SHA_2_224,
+                          KM_DIGEST_SHA_2_256, KM_DIGEST_SHA_2_384, KM_DIGEST_SHA_2_512},
+                         digests, len));
     free(digests);
 
     ASSERT_EQ(KM_ERROR_OK, device()->get_supported_digests(device(), KM_ALGORITHM_EC,
@@ -442,19 +445,6 @@
         EXPECT_EQ(3, GetParam()->keymaster0_calls());
 }
 
-TEST_P(SigningOperationsTest, RsaSha256DigestSuccess) {
-    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
-                                           .RsaSigningKey(384, 3)
-                                           .Digest(KM_DIGEST_SHA_2_256)
-                                           .Padding(KM_PAD_RSA_PSS)));
-    string message(1024, 'a');
-    string signature;
-    SignMessage(message, &signature, KM_DIGEST_SHA_2_256, KM_PAD_RSA_PSS);
-
-    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
-        EXPECT_EQ(3, GetParam()->keymaster0_calls());
-}
-
 TEST_P(SigningOperationsTest, RsaPssSha256Success) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
                                            .RsaSigningKey(512, 3)
@@ -483,8 +473,8 @@
 }
 
 TEST_P(SigningOperationsTest, RsaPssSha256TooSmallKey) {
-    // Key must be at least 10 bytes larger than hash, to provide minimal random salt, so verify
-    // that 9 bytes larger than hash won't work.
+    // Key must be at least 10 bytes larger than hash, to provide eight bytes of random salt, so
+    // verify that nine bytes larger than hash won't work.
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
                                            .RsaSigningKey(256 + 9 * 8, 3)
                                            .Digest(KM_DIGEST_SHA_2_256)
@@ -495,16 +485,7 @@
     AuthorizationSet begin_params(client_params());
     begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA_2_256);
     begin_params.push_back(TAG_PADDING, KM_PAD_RSA_PSS);
-    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_SIGN, begin_params));
-
-    string result;
-    size_t input_consumed;
-    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
-    EXPECT_EQ(message.size(), input_consumed);
-    EXPECT_EQ(KM_ERROR_INCOMPATIBLE_DIGEST, FinishOperation(signature, &result));
-
-    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
-        EXPECT_EQ(2, GetParam()->keymaster0_calls());
+    EXPECT_EQ(KM_ERROR_INCOMPATIBLE_DIGEST, BeginOperation(KM_PURPOSE_SIGN, begin_params));
 }
 
 TEST_P(SigningOperationsTest, RsaAbort) {
@@ -615,7 +596,7 @@
 TEST_P(SigningOperationsTest, EcdsaSuccess) {
     ASSERT_EQ(KM_ERROR_OK,
               GenerateKey(AuthorizationSetBuilder().EcdsaSigningKey(224).Digest(KM_DIGEST_NONE)));
-    string message = "123456789012345678901234567890123456789012345678";
+    string message(1024, 'a');
     string signature;
     SignMessage(message, &signature, KM_DIGEST_NONE);
 
@@ -996,45 +977,6 @@
         EXPECT_EQ(4, GetParam()->keymaster0_calls());
 }
 
-TEST_P(VerificationOperationsTest, RsaSha256DigestSuccess) {
-    GenerateKey(AuthorizationSetBuilder()
-                    .RsaSigningKey(384, 3)
-                    .Digest(KM_DIGEST_SHA_2_256)
-                    .Padding(KM_PAD_RSA_PSS));
-    string message(1024, 'a');
-    string signature;
-    SignMessage(message, &signature, KM_DIGEST_SHA_2_256, KM_PAD_RSA_PSS);
-    VerifyMessage(message, signature, KM_DIGEST_SHA_2_256, KM_PAD_RSA_PSS);
-
-    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
-        EXPECT_EQ(4, GetParam()->keymaster0_calls());
-}
-
-TEST_P(VerificationOperationsTest, RsaSha256CorruptSignature) {
-    GenerateKey(AuthorizationSetBuilder()
-                    .RsaSigningKey(384, 3)
-                    .Digest(KM_DIGEST_SHA_2_256)
-                    .Padding(KM_PAD_RSA_PSS));
-    string message(1024, 'a');
-    string signature;
-    SignMessage(message, &signature, KM_DIGEST_SHA_2_256, KM_PAD_RSA_PSS);
-    ++signature[signature.size() / 2];
-
-    AuthorizationSet begin_params(client_params());
-    begin_params.push_back(TAG_DIGEST, KM_DIGEST_SHA_2_256);
-    begin_params.push_back(TAG_PADDING, KM_PAD_RSA_PSS);
-    EXPECT_EQ(KM_ERROR_OK, BeginOperation(KM_PURPOSE_VERIFY, begin_params));
-
-    string result;
-    size_t input_consumed;
-    EXPECT_EQ(KM_ERROR_OK, UpdateOperation(message, &result, &input_consumed));
-    EXPECT_EQ(message.size(), input_consumed);
-    EXPECT_EQ(KM_ERROR_VERIFICATION_FAILED, FinishOperation(signature, &result));
-
-    if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
-        EXPECT_EQ(4, GetParam()->keymaster0_calls());
-}
-
 TEST_P(VerificationOperationsTest, RsaPssSha256Success) {
     ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
                                            .RsaSigningKey(512, 3)
@@ -1185,13 +1127,19 @@
                                                     &padding_modes, &padding_modes_len));
 
     // Try them.
+    int trial_count = 0;
     for (keymaster_padding_t padding_mode : make_vector(padding_modes, padding_modes_len)) {
         for (keymaster_digest_t digest : make_vector(digests, digests_len)) {
+            if (digest != KM_DIGEST_NONE && padding_mode == KM_PAD_NONE)
+                // Digesting requires padding
+                continue;
+
             // Compute key & message size that will work.
-            size_t key_bits = 256;
+            size_t key_bits = 0;
             size_t message_len = 1000;
-            switch (digest) {
-            case KM_DIGEST_NONE:
+
+            if (digest == KM_DIGEST_NONE) {
+                key_bits = 256;
                 switch (padding_mode) {
                 case KM_PAD_NONE:
                     // Match key size.
@@ -1207,26 +1155,42 @@
                     FAIL() << "Missing padding";
                     break;
                 }
-                break;
+            } else {
+                size_t digest_bits;
+                switch (digest) {
+                case KM_DIGEST_MD5:
+                    digest_bits = 128;
+                    break;
+                case KM_DIGEST_SHA1:
+                    digest_bits = 160;
+                    break;
+                case KM_DIGEST_SHA_2_224:
+                    digest_bits = 224;
+                    break;
+                case KM_DIGEST_SHA_2_256:
+                    digest_bits = 256;
+                    break;
+                case KM_DIGEST_SHA_2_384:
+                    digest_bits = 384;
+                    break;
+                case KM_DIGEST_SHA_2_512:
+                    digest_bits = 512;
+                    break;
+                default:
+                    FAIL() << "Missing digest";
+                }
 
-            case KM_DIGEST_SHA_2_256:
                 switch (padding_mode) {
-                case KM_PAD_NONE:
-                    // Digesting requires padding
-                    continue;
                 case KM_PAD_RSA_PKCS1_1_5_SIGN:
-                    key_bits += 8 * 11;
+                    key_bits = digest_bits + 8 * (11 + 19);
                     break;
                 case KM_PAD_RSA_PSS:
-                    key_bits += 8 * 10;
+                    key_bits = digest_bits + 8 * 10;
                     break;
                 default:
                     FAIL() << "Missing padding";
                     break;
                 }
-                break;
-            default:
-                FAIL() << "Missing digest";
             }
 
             GenerateKey(AuthorizationSetBuilder()
@@ -1237,6 +1201,7 @@
             string signature;
             SignMessage(message, &signature, digest, padding_mode);
             VerifyMessage(message, signature, digest, padding_mode);
+            ++trial_count;
         }
     }
 
@@ -1244,7 +1209,7 @@
     free(digests);
 
     if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
-        EXPECT_EQ(16, GetParam()->keymaster0_calls());
+        EXPECT_EQ(trial_count * 4, GetParam()->keymaster0_calls());
 }
 
 TEST_P(VerificationOperationsTest, EcdsaSuccess) {
diff --git a/openssl_err.cpp b/openssl_err.cpp
index 04ab6f6..38edc05 100644
--- a/openssl_err.cpp
+++ b/openssl_err.cpp
@@ -37,6 +37,7 @@
 static keymaster_error_t TranslateCipherError(int reason);
 static keymaster_error_t TranslatePKCS8Error(int reason);
 static keymaster_error_t TranslateX509v3Error(int reason);
+static keymaster_error_t TranslateRsaError(int reason);
 #endif
 
 keymaster_error_t TranslateLastOpenSslError(bool log_message) {
@@ -60,6 +61,8 @@
         return TranslatePKCS8Error(reason);
     case ERR_LIB_X509V3:
         return TranslateX509v3Error(reason);
+    case ERR_LIB_RSA:
+        return TranslateRsaError(reason);
 #else
     case ERR_LIB_ASN1:
         LOG_E("ASN.1 parsing error %d", reason);
@@ -140,6 +143,16 @@
     }
 }
 
+keymaster_error_t TranslateRsaError(int reason) {
+    switch (reason) {
+    case RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE:
+    case RSA_R_DATA_TOO_SMALL_FOR_KEY_SIZE:
+        return KM_ERROR_INVALID_INPUT_LENGTH;
+    default:
+        return KM_ERROR_UNKNOWN_ERROR;
+    };
+}
+
 #endif  // OPENSSL_IS_BORINGSSL
 
 keymaster_error_t TranslateEvpError(int reason) {
diff --git a/rsa_operation.cpp b/rsa_operation.cpp
index 55ff1b3..3d50ba9 100644
--- a/rsa_operation.cpp
+++ b/rsa_operation.cpp
@@ -31,18 +31,25 @@
 static const int MIN_PSS_SALT_LEN = 8 /* salt len */ + 2 /* overhead */;
 
 /* static */
-RSA* RsaOperationFactory::GetRsaKey(const Key& key, keymaster_error_t* error) {
+EVP_PKEY* RsaOperationFactory::GetRsaKey(const Key& key, keymaster_error_t* error) {
     const RsaKey* rsa_key = static_cast<const RsaKey*>(&key);
     assert(rsa_key);
     if (!rsa_key || !rsa_key->key()) {
         *error = KM_ERROR_UNKNOWN_ERROR;
-        return NULL;
+        return nullptr;
     }
-    RSA_up_ref(rsa_key->key());
-    return rsa_key->key();
+
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKEY_new());
+    if (!rsa_key->InternalToEvp(pkey.get())) {
+        *error = KM_ERROR_UNKNOWN_ERROR;
+        return nullptr;
+    }
+    return pkey.release();
 }
 
-static const keymaster_digest_t supported_digests[] = {KM_DIGEST_NONE, KM_DIGEST_SHA_2_256};
+static const keymaster_digest_t supported_digests[] = {
+    KM_DIGEST_NONE,      KM_DIGEST_MD5,       KM_DIGEST_SHA1,     KM_DIGEST_SHA_2_224,
+    KM_DIGEST_SHA_2_256, KM_DIGEST_SHA_2_384, KM_DIGEST_SHA_2_512};
 static const keymaster_padding_t supported_sig_padding[] = {KM_PAD_NONE, KM_PAD_RSA_PKCS1_1_5_SIGN,
                                                             KM_PAD_RSA_PSS};
 
@@ -63,13 +70,15 @@
                                                          keymaster_error_t* error) {
     keymaster_padding_t padding;
     keymaster_digest_t digest;
-    RSA* rsa;
     if (!GetAndValidateDigest(begin_params, key, &digest, error) ||
-        !GetAndValidatePadding(begin_params, key, &padding, error) ||
-        !(rsa = GetRsaKey(key, error)))
-        return NULL;
+        !GetAndValidatePadding(begin_params, key, &padding, error))
+        return nullptr;
 
-    Operation* op = InstantiateOperation(digest, padding, rsa);
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> rsa(GetRsaKey(key, error));
+    if (!rsa.get())
+        return nullptr;
+
+    Operation* op = InstantiateOperation(digest, padding, rsa.release());
     if (!op)
         *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
     return op;
@@ -82,12 +91,14 @@
                                                         const AuthorizationSet& begin_params,
                                                         keymaster_error_t* error) {
     keymaster_padding_t padding;
-    RSA* rsa;
-    if (!GetAndValidatePadding(begin_params, key, &padding, error) ||
-        !(rsa = GetRsaKey(key, error)))
-        return NULL;
+    if (!GetAndValidatePadding(begin_params, key, &padding, error))
+        return nullptr;
 
-    Operation* op = InstantiateOperation(padding, rsa);
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> rsa(GetRsaKey(key, error));
+    if (!rsa.get())
+        return nullptr;
+
+    Operation* op = InstantiateOperation(padding, rsa.release());
     if (!op)
         *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
     return op;
@@ -107,7 +118,7 @@
 
 RsaOperation::~RsaOperation() {
     if (rsa_key_ != NULL)
-        RSA_free(rsa_key_);
+        EVP_PKEY_free(rsa_key_);
 }
 
 keymaster_error_t RsaOperation::Update(const AuthorizationSet& /* additional_params */,
@@ -134,76 +145,119 @@
     return KM_ERROR_OK;
 }
 
+keymaster_error_t RsaOperation::SetRsaPaddingInEvpContext(EVP_PKEY_CTX* pkey_ctx) {
+    keymaster_error_t error;
+    int openssl_padding = GetOpensslPadding(&error);
+    if (error != KM_ERROR_OK)
+        return error;
+
+    if (EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, openssl_padding) <= 0)
+        return TranslateLastOpenSslError();
+    return KM_ERROR_OK;
+}
+
 RsaDigestingOperation::RsaDigestingOperation(keymaster_purpose_t purpose, keymaster_digest_t digest,
-                                             keymaster_padding_t padding, RSA* key)
+                                             keymaster_padding_t padding, EVP_PKEY* key)
     : RsaOperation(purpose, padding, key), digest_(digest), digest_algorithm_(NULL) {
     EVP_MD_CTX_init(&digest_ctx_);
 }
 RsaDigestingOperation::~RsaDigestingOperation() {
     EVP_MD_CTX_cleanup(&digest_ctx_);
-    memset_s(digest_buf_, 0, sizeof(digest_buf_));
-}
-
-keymaster_error_t RsaDigestingOperation::Begin(const AuthorizationSet& /* input_params */,
-                                               AuthorizationSet* /* output_params */) {
-    if (require_digest() && digest_ == KM_DIGEST_NONE)
-        return KM_ERROR_INCOMPATIBLE_DIGEST;
-    return InitDigest();
-}
-
-keymaster_error_t RsaDigestingOperation::Update(const AuthorizationSet& additional_params,
-                                                const Buffer& input, Buffer* output,
-                                                size_t* input_consumed) {
-    if (digest_ == KM_DIGEST_NONE)
-        return RsaOperation::Update(additional_params, input, output, input_consumed);
-    else
-        return UpdateDigest(input, input_consumed);
 }
 
 keymaster_error_t RsaDigestingOperation::InitDigest() {
+    if (digest_ == KM_DIGEST_NONE) {
+        if (require_digest())
+            return KM_ERROR_INCOMPATIBLE_DIGEST;
+        return KM_ERROR_OK;
+    }
+
     switch (digest_) {
     case KM_DIGEST_NONE:
         return KM_ERROR_OK;
+    case KM_DIGEST_MD5:
+        digest_algorithm_ = EVP_md5();
+        return KM_ERROR_OK;
+    case KM_DIGEST_SHA1:
+        digest_algorithm_ = EVP_sha1();
+        return KM_ERROR_OK;
+    case KM_DIGEST_SHA_2_224:
+        digest_algorithm_ = EVP_sha224();
+        return KM_ERROR_OK;
     case KM_DIGEST_SHA_2_256:
         digest_algorithm_ = EVP_sha256();
-        break;
+        return KM_ERROR_OK;
+    case KM_DIGEST_SHA_2_384:
+        digest_algorithm_ = EVP_sha384();
+        return KM_ERROR_OK;
+    case KM_DIGEST_SHA_2_512:
+        digest_algorithm_ = EVP_sha512();
+        return KM_ERROR_OK;
     default:
         return KM_ERROR_UNSUPPORTED_DIGEST;
     }
-
-    if (!EVP_DigestInit_ex(&digest_ctx_, digest_algorithm_, NULL /* engine */)) {
-        int err = ERR_get_error();
-        LOG_E("Failed to initialize digest: %d %s", err, ERR_error_string(err, NULL));
-        return KM_ERROR_UNKNOWN_ERROR;
-    }
-    return KM_ERROR_OK;
 }
 
-keymaster_error_t RsaDigestingOperation::UpdateDigest(const Buffer& input, size_t* input_consumed) {
-    if (!EVP_DigestUpdate(&digest_ctx_, input.peek_read(), input.available_read())) {
-        int err = ERR_get_error();
-        LOG_E("Failed to update digest: %d %s", err, ERR_error_string(err, NULL));
-        return KM_ERROR_UNKNOWN_ERROR;
+const size_t PSS_OVERHEAD = 2;
+const size_t MIN_SALT_SIZE = 8;
+
+int RsaDigestingOperation::GetOpensslPadding(keymaster_error_t* error) {
+    *error = KM_ERROR_OK;
+    switch (padding_) {
+    case KM_PAD_NONE:
+        return RSA_NO_PADDING;
+    case KM_PAD_RSA_PKCS1_1_5_SIGN:
+
+        return RSA_PKCS1_PADDING;
+    case KM_PAD_RSA_PSS:
+        if (digest_ == KM_DIGEST_NONE) {
+            *error = KM_ERROR_INCOMPATIBLE_PADDING_MODE;
+            return -1;
+        }
+        if (EVP_MD_size(digest_algorithm_) + PSS_OVERHEAD + MIN_SALT_SIZE >
+            (size_t)EVP_PKEY_size(rsa_key_)) {
+            *error = KM_ERROR_INCOMPATIBLE_DIGEST;
+            return -1;
+        }
+        return RSA_PKCS1_PSS_PADDING;
+    default:
+        return -1;
     }
+}
+
+keymaster_error_t RsaSignOperation::Begin(const AuthorizationSet& /* input_params */,
+                                          AuthorizationSet* /* output_params */) {
+    keymaster_error_t error = InitDigest();
+    if (error != KM_ERROR_OK)
+        return error;
+
+    if (digest_ == KM_DIGEST_NONE)
+        return KM_ERROR_OK;
+
+    EVP_PKEY_CTX* pkey_ctx;
+    if (EVP_DigestSignInit(&digest_ctx_, &pkey_ctx, digest_algorithm_, nullptr /* engine */,
+                           rsa_key_) != 1)
+        return TranslateLastOpenSslError();
+    return SetRsaPaddingInEvpContext(pkey_ctx);
+}
+
+keymaster_error_t RsaSignOperation::Update(const AuthorizationSet& additional_params,
+                                           const Buffer& input, Buffer* output,
+                                           size_t* input_consumed) {
+    if (digest_ == KM_DIGEST_NONE)
+        // Just buffer the data.
+        return RsaOperation::Update(additional_params, input, output, input_consumed);
+
+    if (EVP_DigestSignUpdate(&digest_ctx_, input.peek_read(), input.available_read()) != 1)
+        return TranslateLastOpenSslError();
     *input_consumed = input.available_read();
     return KM_ERROR_OK;
 }
 
-keymaster_error_t RsaDigestingOperation::FinishDigest(unsigned* digest_size) {
-    assert(digest_algorithm_ != NULL);
-    if (!EVP_DigestFinal_ex(&digest_ctx_, digest_buf_, digest_size)) {
-        int err = ERR_get_error();
-        LOG_E("Failed to finalize digest: %d %s", err, ERR_error_string(err, NULL));
-        return KM_ERROR_UNKNOWN_ERROR;
-    }
-    assert(*digest_size == static_cast<unsigned>(EVP_MD_size(digest_algorithm_)));
-    return KM_ERROR_OK;
-}
-
 keymaster_error_t RsaSignOperation::Finish(const AuthorizationSet& /* additional_params */,
                                            const Buffer& /* signature */, Buffer* output) {
     assert(output);
-    output->Reinitialize(RSA_size(rsa_key_));
+
     if (digest_ == KM_DIGEST_NONE)
         return SignUndigested(output);
     else
@@ -211,15 +265,23 @@
 }
 
 keymaster_error_t RsaSignOperation::SignUndigested(Buffer* output) {
+    UniquePtr<RSA, RSA_Delete> rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(rsa_key_)));
+    if (!rsa.get())
+        return TranslateLastOpenSslError();
+
+    if (!output->Reinitialize(RSA_size(rsa.get())))
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+
     int bytes_encrypted;
     switch (padding_) {
     case KM_PAD_NONE:
         bytes_encrypted = RSA_private_encrypt(data_.available_read(), data_.peek_read(),
-                                              output->peek_write(), rsa_key_, RSA_NO_PADDING);
+                                              output->peek_write(), rsa.get(), RSA_NO_PADDING);
         break;
     case KM_PAD_RSA_PKCS1_1_5_SIGN:
+        // Does PKCS1 padding without digesting even make sense?  Dunno.  We'll support it.
         bytes_encrypted = RSA_private_encrypt(data_.available_read(), data_.peek_read(),
-                                              output->peek_write(), rsa_key_, RSA_PKCS1_PADDING);
+                                              output->peek_write(), rsa.get(), RSA_PKCS1_PADDING);
         break;
     default:
         return KM_ERROR_UNSUPPORTED_PADDING_MODE;
@@ -232,57 +294,45 @@
 }
 
 keymaster_error_t RsaSignOperation::SignDigested(Buffer* output) {
-    unsigned digest_size = 0;
-    keymaster_error_t error = FinishDigest(&digest_size);
-    if (error != KM_ERROR_OK)
-        return error;
+    size_t siglen;
+    if (EVP_DigestSignFinal(&digest_ctx_, nullptr /* signature */, &siglen) != 1)
+        return TranslateLastOpenSslError();
 
-    UniquePtr<uint8_t[]> padded_digest;
-    switch (padding_) {
-    case KM_PAD_NONE:
-        LOG_E("Digesting requires padding", 0);
-        return KM_ERROR_INCOMPATIBLE_PADDING_MODE;
-    case KM_PAD_RSA_PKCS1_1_5_SIGN:
-        return PrivateEncrypt(digest_buf_, digest_size, RSA_PKCS1_PADDING, output);
-    case KM_PAD_RSA_PSS:
-        // OpenSSL doesn't verify that the key is large enough for the digest size.  This can cause
-        // a segfault in some cases, and in others can result in a unsafely-small salt.
-        if ((unsigned)RSA_size(rsa_key_) < MIN_PSS_SALT_LEN + digest_size) {
-            LOG_E("%d-byte too small for PSS padding and %d-byte digest", RSA_size(rsa_key_),
-                  digest_size);
-            // TODO(swillden): Add a better return code for this.
-            return KM_ERROR_INCOMPATIBLE_DIGEST;
-        }
-
-        if ((error = PssPadDigest(&padded_digest)) != KM_ERROR_OK)
-            return error;
-        return PrivateEncrypt(padded_digest.get(), RSA_size(rsa_key_), RSA_NO_PADDING, output);
-    default:
-        return KM_ERROR_UNSUPPORTED_PADDING_MODE;
-    }
-}
-
-keymaster_error_t RsaSignOperation::PssPadDigest(UniquePtr<uint8_t[]>* padded_digest) {
-    padded_digest->reset(new uint8_t[RSA_size(rsa_key_)]);
-    if (!padded_digest->get())
+    if (!output->Reinitialize(siglen))
         return KM_ERROR_MEMORY_ALLOCATION_FAILED;
 
-    if (!RSA_padding_add_PKCS1_PSS_mgf1(rsa_key_, padded_digest->get(), digest_buf_,
-                                        digest_algorithm_, NULL,
-                                        -2 /* Indicates maximum salt length */)) {
-        LOG_E("%s", "Failed to apply PSS padding");
-        return KM_ERROR_UNKNOWN_ERROR;
-    }
+    if (EVP_DigestSignFinal(&digest_ctx_, output->peek_write(), &siglen) <= 0)
+        return TranslateLastOpenSslError();
+    output->advance_write(siglen);
+
     return KM_ERROR_OK;
 }
 
-keymaster_error_t RsaSignOperation::PrivateEncrypt(uint8_t* to_encrypt, size_t len,
-                                                   int openssl_padding, Buffer* output) {
-    int bytes_encrypted =
-        RSA_private_encrypt(len, to_encrypt, output->peek_write(), rsa_key_, openssl_padding);
-    if (bytes_encrypted <= 0)
-        return KM_ERROR_UNKNOWN_ERROR;
-    output->advance_write(bytes_encrypted);
+keymaster_error_t RsaVerifyOperation::Begin(const AuthorizationSet& /* input_params */,
+                                            AuthorizationSet* /* output_params */) {
+    keymaster_error_t error = InitDigest();
+    if (error != KM_ERROR_OK)
+        return error;
+
+    if (digest_ == KM_DIGEST_NONE)
+        return KM_ERROR_OK;
+
+    EVP_PKEY_CTX* pkey_ctx;
+    if (EVP_DigestVerifyInit(&digest_ctx_, &pkey_ctx, digest_algorithm_, NULL, rsa_key_) != 1)
+        return TranslateLastOpenSslError();
+    return SetRsaPaddingInEvpContext(pkey_ctx);
+}
+
+keymaster_error_t RsaVerifyOperation::Update(const AuthorizationSet& additional_params,
+                                             const Buffer& input, Buffer* output,
+                                             size_t* input_consumed) {
+    if (digest_ == KM_DIGEST_NONE)
+        // Just buffer the data.
+        return RsaOperation::Update(additional_params, input, output, input_consumed);
+
+    if (EVP_DigestVerifyUpdate(&digest_ctx_, input.peek_read(), input.available_read()) != 1)
+        return TranslateLastOpenSslError();
+    *input_consumed = input.available_read();
     return KM_ERROR_OK;
 }
 
@@ -295,37 +345,20 @@
 }
 
 keymaster_error_t RsaVerifyOperation::VerifyUndigested(const Buffer& signature) {
-    return DecryptAndMatch(signature, data_.peek_read(), data_.available_read());
-}
+    UniquePtr<RSA, RSA_Delete> rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(rsa_key_)));
+    if (!rsa.get())
+        return KM_ERROR_UNKNOWN_ERROR;
 
-keymaster_error_t RsaVerifyOperation::VerifyDigested(const Buffer& signature) {
-    unsigned digest_size = 0;
-    keymaster_error_t error = FinishDigest(&digest_size);
-    if (error != KM_ERROR_OK)
-        return error;
-    return DecryptAndMatch(signature, digest_buf_, digest_size);
-}
-
-keymaster_error_t RsaVerifyOperation::DecryptAndMatch(const Buffer& signature,
-                                                      const uint8_t* to_match, size_t len) {
-#ifdef OPENSSL_IS_BORINGSSL
-    size_t key_len = RSA_size(rsa_key_);
-#else
-    size_t key_len = (size_t)RSA_size(rsa_key_);
-#endif
-
+    size_t key_len = RSA_size(rsa.get());
     int openssl_padding;
     switch (padding_) {
     case KM_PAD_NONE:
-        if (len != key_len)
+        if (data_.available_read() != key_len)
             return KM_ERROR_INVALID_INPUT_LENGTH;
-        if (len != signature.available_read())
+        if (data_.available_read() != signature.available_read())
             return KM_ERROR_VERIFICATION_FAILED;
         openssl_padding = RSA_NO_PADDING;
         break;
-    case KM_PAD_RSA_PSS:  // Do a raw decrypt for PSS
-        openssl_padding = RSA_NO_PADDING;
-        break;
     case KM_PAD_RSA_PKCS1_1_5_SIGN:
         openssl_padding = RSA_PKCS1_PADDING;
         break;
@@ -335,67 +368,65 @@
 
     UniquePtr<uint8_t[]> decrypted_data(new uint8_t[key_len]);
     int bytes_decrypted = RSA_public_decrypt(signature.available_read(), signature.peek_read(),
-                                             decrypted_data.get(), rsa_key_, openssl_padding);
+                                             decrypted_data.get(), rsa.get(), openssl_padding);
     if (bytes_decrypted < 0)
         return KM_ERROR_VERIFICATION_FAILED;
 
-    if (padding_ == KM_PAD_RSA_PSS &&
-        RSA_verify_PKCS1_PSS_mgf1(rsa_key_, to_match, digest_algorithm_, NULL, decrypted_data.get(),
-                                  -2 /* salt length recovered from signature */))
-        return KM_ERROR_OK;
-    else if (padding_ != KM_PAD_RSA_PSS && memcmp_s(decrypted_data.get(), to_match, len) == 0)
-        return KM_ERROR_OK;
-
-    return KM_ERROR_VERIFICATION_FAILED;
+    if (memcmp_s(decrypted_data.get(), data_.peek_read(), data_.available_read()) != 0)
+        return KM_ERROR_VERIFICATION_FAILED;
+    return KM_ERROR_OK;
 }
 
-const int OAEP_PADDING_OVERHEAD = 42;
-const int PKCS1_PADDING_OVERHEAD = 11;
+keymaster_error_t RsaVerifyOperation::VerifyDigested(const Buffer& signature) {
+    if (!EVP_DigestVerifyFinal(&digest_ctx_, signature.peek_read(), signature.available_read()))
+        return KM_ERROR_VERIFICATION_FAILED;
+    return KM_ERROR_OK;
+}
+
+int RsaCryptOperation::GetOpensslPadding(keymaster_error_t* error) {
+    *error = KM_ERROR_OK;
+    switch (padding_) {
+    case KM_PAD_RSA_PKCS1_1_5_ENCRYPT:
+        return RSA_PKCS1_PADDING;
+    case KM_PAD_RSA_OAEP:
+        return RSA_PKCS1_OAEP_PADDING;
+    default:
+        return -1;
+    }
+}
+
+struct EVP_PKEY_CTX_Delete {
+    void operator()(EVP_PKEY_CTX* p) { EVP_PKEY_CTX_free(p); }
+};
 
 keymaster_error_t RsaEncryptOperation::Finish(const AuthorizationSet& /* additional_params */,
                                               const Buffer& /* signature */, Buffer* output) {
     assert(output);
-    int openssl_padding;
 
-#if defined(OPENSSL_IS_BORINGSSL)
-    size_t key_len = RSA_size(rsa_key_);
-#else
-    size_t key_len = (size_t)RSA_size(rsa_key_);
-#endif
+    UniquePtr<EVP_PKEY_CTX, EVP_PKEY_CTX_Delete> ctx(
+        EVP_PKEY_CTX_new(rsa_key_, nullptr /* engine */));
+    if (!ctx.get())
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
 
-    size_t message_size = data_.available_read();
-    switch (padding_) {
-    case KM_PAD_RSA_OAEP:
-        openssl_padding = RSA_PKCS1_OAEP_PADDING;
-        if (message_size + OAEP_PADDING_OVERHEAD > key_len) {
-            LOG_E("Cannot encrypt %d bytes with %d-byte key and OAEP padding",
-                  data_.available_read(), key_len);
-            return KM_ERROR_INVALID_INPUT_LENGTH;
-        }
-        break;
-    case KM_PAD_RSA_PKCS1_1_5_ENCRYPT:
-        openssl_padding = RSA_PKCS1_PADDING;
-        if (message_size + PKCS1_PADDING_OVERHEAD > key_len) {
-            LOG_E("Cannot encrypt %d bytes with %d-byte key and PKCS1 padding",
-                  data_.available_read(), key_len);
-            return KM_ERROR_INVALID_INPUT_LENGTH;
-        }
-        break;
-    default:
-        LOG_E("Padding mode %d not supported", padding_);
-        return KM_ERROR_UNSUPPORTED_PADDING_MODE;
-    }
+    if (EVP_PKEY_encrypt_init(ctx.get()) <= 0)
+        return TranslateLastOpenSslError();
 
-    output->Reinitialize(RSA_size(rsa_key_));
-    int bytes_encrypted = RSA_public_encrypt(data_.available_read(), data_.peek_read(),
-                                             output->peek_write(), rsa_key_, openssl_padding);
+    keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
 
-    if (bytes_encrypted < 0) {
-        LOG_E("Error %d encrypting data with RSA", ERR_get_error());
-        return KM_ERROR_UNKNOWN_ERROR;
-    }
-    assert(bytes_encrypted == (int)RSA_size(rsa_key_));
-    output->advance_write(bytes_encrypted);
+    size_t outlen;
+    if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),
+                         data_.available_read()) <= 0)
+        return TranslateLastOpenSslError();
+
+    if (!output->Reinitialize(outlen))
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+
+    if (EVP_PKEY_encrypt(ctx.get(), output->peek_write(), &outlen, data_.peek_read(),
+                         data_.available_read()) <= 0)
+        return TranslateLastOpenSslError();
+    output->advance_write(outlen);
 
     return KM_ERROR_OK;
 }
@@ -403,28 +434,31 @@
 keymaster_error_t RsaDecryptOperation::Finish(const AuthorizationSet& /* additional_params */,
                                               const Buffer& /* signature */, Buffer* output) {
     assert(output);
-    int openssl_padding;
-    switch (padding_) {
-    case KM_PAD_RSA_OAEP:
-        openssl_padding = RSA_PKCS1_OAEP_PADDING;
-        break;
-    case KM_PAD_RSA_PKCS1_1_5_ENCRYPT:
-        openssl_padding = RSA_PKCS1_PADDING;
-        break;
-    default:
-        LOG_E("Padding mode %d not supported", padding_);
-        return KM_ERROR_UNSUPPORTED_PADDING_MODE;
-    }
 
-    output->Reinitialize(RSA_size(rsa_key_));
-    int bytes_decrypted = RSA_private_decrypt(data_.available_read(), data_.peek_read(),
-                                              output->peek_write(), rsa_key_, openssl_padding);
+    UniquePtr<EVP_PKEY_CTX, EVP_PKEY_CTX_Delete> ctx(
+        EVP_PKEY_CTX_new(rsa_key_, nullptr /* engine */));
+    if (!ctx.get())
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
 
-    if (bytes_decrypted < 0) {
-        LOG_E("Error %d decrypting data with RSA", ERR_get_error());
-        return KM_ERROR_UNKNOWN_ERROR;
-    }
-    output->advance_write(bytes_decrypted);
+    if (EVP_PKEY_decrypt_init(ctx.get()) <= 0)
+        return TranslateLastOpenSslError();
+
+    keymaster_error_t error = SetRsaPaddingInEvpContext(ctx.get());
+    if (error != KM_ERROR_OK)
+        return error;
+
+    size_t outlen;
+    if (EVP_PKEY_decrypt(ctx.get(), nullptr /* out */, &outlen, data_.peek_read(),
+                         data_.available_read()) <= 0)
+        return TranslateLastOpenSslError();
+
+    if (!output->Reinitialize(outlen))
+        return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+
+    if (EVP_PKEY_decrypt(ctx.get(), output->peek_write(), &outlen, data_.peek_read(),
+                         data_.available_read()) <= 0)
+        return TranslateLastOpenSslError();
+    output->advance_write(outlen);
 
     return KM_ERROR_OK;
 }
diff --git a/rsa_operation.h b/rsa_operation.h
index 9c0af17..30d62b8 100644
--- a/rsa_operation.h
+++ b/rsa_operation.h
@@ -33,7 +33,7 @@
  */
 class RsaOperation : public Operation {
   public:
-    RsaOperation(keymaster_purpose_t purpose, keymaster_padding_t padding, RSA* key)
+    RsaOperation(keymaster_purpose_t purpose, keymaster_padding_t padding, EVP_PKEY* key)
         : Operation(purpose), rsa_key_(key), padding_(padding) {}
     ~RsaOperation();
 
@@ -46,15 +46,18 @@
     keymaster_error_t Abort() override { return KM_ERROR_OK; }
 
   protected:
-    keymaster_error_t StoreData(const Buffer& input, size_t* input_consumed);
+    virtual int GetOpensslPadding(keymaster_error_t* error) = 0;
 
-    RSA* rsa_key_;
+    keymaster_error_t StoreData(const Buffer& input, size_t* input_consumed);
+    keymaster_error_t SetRsaPaddingInEvpContext(EVP_PKEY_CTX* pkey_ctx);
+
+    EVP_PKEY* rsa_key_;
     keymaster_padding_t padding_;
     Buffer data_;
 };
 
 /**
- * Base class for all RSA operations.
+ * Base class for all digesting RSA operations.
  *
  * This class adds digesting support, for digesting modes.  For non-digesting modes, it falls back
  * on the RsaOperation input buffering.
@@ -62,24 +65,18 @@
 class RsaDigestingOperation : public RsaOperation {
   public:
     RsaDigestingOperation(keymaster_purpose_t purpose, keymaster_digest_t digest,
-                          keymaster_padding_t padding, RSA* key);
+                          keymaster_padding_t padding, EVP_PKEY* key);
     ~RsaDigestingOperation();
 
-    keymaster_error_t Begin(const AuthorizationSet& input_params,
-                            AuthorizationSet* output_params) override;
-    keymaster_error_t Update(const AuthorizationSet& additional_params, const Buffer& input,
-                             Buffer* output, size_t* input_consumed) override;
-
   protected:
-    bool require_digest() const { return padding_ == KM_PAD_RSA_PSS; }
     keymaster_error_t InitDigest();
-    keymaster_error_t UpdateDigest(const Buffer& input, size_t* input_consumed);
-    keymaster_error_t FinishDigest(unsigned* digest_size);
+    int GetOpensslPadding(keymaster_error_t* error) override;
+
+    bool require_digest() const { return padding_ == KM_PAD_RSA_PSS; }
 
     const keymaster_digest_t digest_;
     const EVP_MD* digest_algorithm_;
     EVP_MD_CTX digest_ctx_;
-    uint8_t digest_buf_[EVP_MAX_MD_SIZE];
 };
 
 /**
@@ -87,17 +84,19 @@
  */
 class RsaSignOperation : public RsaDigestingOperation {
   public:
-    RsaSignOperation(keymaster_digest_t digest, keymaster_padding_t padding, RSA* key)
+    RsaSignOperation(keymaster_digest_t digest, keymaster_padding_t padding, EVP_PKEY* key)
         : RsaDigestingOperation(KM_PURPOSE_SIGN, digest, padding, key) {}
+
+    keymaster_error_t Begin(const AuthorizationSet& input_params,
+                            AuthorizationSet* output_params) override;
+    keymaster_error_t Update(const AuthorizationSet& additional_params, const Buffer& input,
+                             Buffer* output, size_t* input_consumed) override;
     keymaster_error_t Finish(const AuthorizationSet& additional_params, const Buffer& signature,
                              Buffer* output) override;
 
   private:
     keymaster_error_t SignUndigested(Buffer* output);
     keymaster_error_t SignDigested(Buffer* output);
-    keymaster_error_t PrivateEncrypt(uint8_t* to_encrypt, size_t len, int openssl_padding,
-                                     Buffer* output);
-    keymaster_error_t PssPadDigest(UniquePtr<uint8_t[]>* padded_digest);
 };
 
 /**
@@ -105,24 +104,40 @@
  */
 class RsaVerifyOperation : public RsaDigestingOperation {
   public:
-    RsaVerifyOperation(keymaster_digest_t digest, keymaster_padding_t padding, RSA* key)
+    RsaVerifyOperation(keymaster_digest_t digest, keymaster_padding_t padding, EVP_PKEY* key)
         : RsaDigestingOperation(KM_PURPOSE_VERIFY, digest, padding, key) {}
+
+    keymaster_error_t Begin(const AuthorizationSet& input_params,
+                            AuthorizationSet* output_params) override;
+    keymaster_error_t Update(const AuthorizationSet& additional_params, const Buffer& input,
+                             Buffer* output, size_t* input_consumed) override;
     keymaster_error_t Finish(const AuthorizationSet& additional_params, const Buffer& signature,
                              Buffer* output) override;
 
   private:
     keymaster_error_t VerifyUndigested(const Buffer& signature);
     keymaster_error_t VerifyDigested(const Buffer& signature);
-    keymaster_error_t DecryptAndMatch(const Buffer& signature, const uint8_t* to_match, size_t len);
+};
+
+/**
+ * Base class for RSA crypting operations.
+ */
+class RsaCryptOperation : public RsaOperation {
+  public:
+    RsaCryptOperation(keymaster_purpose_t, keymaster_padding_t padding, EVP_PKEY* key)
+        : RsaOperation(KM_PURPOSE_ENCRYPT, padding, key) {}
+
+  private:
+    int GetOpensslPadding(keymaster_error_t* error) override;
 };
 
 /**
  * RSA public key encryption operation.
  */
-class RsaEncryptOperation : public RsaOperation {
+class RsaEncryptOperation : public RsaCryptOperation {
   public:
-    RsaEncryptOperation(keymaster_padding_t padding, RSA* key)
-        : RsaOperation(KM_PURPOSE_ENCRYPT, padding, key) {}
+    RsaEncryptOperation(keymaster_padding_t padding, EVP_PKEY* key)
+        : RsaCryptOperation(KM_PURPOSE_ENCRYPT, padding, key) {}
     keymaster_error_t Finish(const AuthorizationSet& additional_params, const Buffer& signature,
                              Buffer* output) override;
 };
@@ -130,10 +145,10 @@
 /**
  * RSA private key decryption operation.
  */
-class RsaDecryptOperation : public RsaOperation {
+class RsaDecryptOperation : public RsaCryptOperation {
   public:
-    RsaDecryptOperation(keymaster_padding_t padding, RSA* key)
-        : RsaOperation(KM_PURPOSE_DECRYPT, padding, key) {}
+    RsaDecryptOperation(keymaster_padding_t padding, EVP_PKEY* key)
+        : RsaCryptOperation(KM_PURPOSE_DECRYPT, padding, key) {}
     keymaster_error_t Finish(const AuthorizationSet& additional_params, const Buffer& signature,
                              Buffer* output) override;
 };
@@ -148,7 +163,7 @@
     virtual keymaster_purpose_t purpose() const = 0;
 
   protected:
-    static RSA* GetRsaKey(const Key& key, keymaster_error_t* error);
+    static EVP_PKEY* GetRsaKey(const Key& key, keymaster_error_t* error);
 };
 
 /**
@@ -165,7 +180,7 @@
 
   private:
     virtual Operation* InstantiateOperation(keymaster_digest_t digest, keymaster_padding_t padding,
-                                            RSA* key) = 0;
+                                            EVP_PKEY* key) = 0;
 };
 
 /**
@@ -180,7 +195,7 @@
     const keymaster_digest_t* SupportedDigests(size_t* digest_count) const override;
 
   private:
-    virtual Operation* InstantiateOperation(keymaster_padding_t padding, RSA* key) = 0;
+    virtual Operation* InstantiateOperation(keymaster_padding_t padding, EVP_PKEY* key) = 0;
 };
 
 /**
@@ -190,7 +205,7 @@
   public:
     keymaster_purpose_t purpose() const override { return KM_PURPOSE_SIGN; }
     Operation* InstantiateOperation(keymaster_digest_t digest, keymaster_padding_t padding,
-                                    RSA* key) override {
+                                    EVP_PKEY* key) override {
         return new RsaSignOperation(digest, padding, key);
     }
 };
@@ -201,7 +216,7 @@
 class RsaVerificationOperationFactory : public RsaDigestingOperationFactory {
     keymaster_purpose_t purpose() const override { return KM_PURPOSE_VERIFY; }
     Operation* InstantiateOperation(keymaster_digest_t digest, keymaster_padding_t padding,
-                                    RSA* key) override {
+                                    EVP_PKEY* key) override {
         return new RsaVerifyOperation(digest, padding, key);
     }
 };
@@ -211,7 +226,7 @@
  */
 class RsaEncryptionOperationFactory : public RsaCryptingOperationFactory {
     keymaster_purpose_t purpose() const override { return KM_PURPOSE_ENCRYPT; }
-    Operation* InstantiateOperation(keymaster_padding_t padding, RSA* key) override {
+    Operation* InstantiateOperation(keymaster_padding_t padding, EVP_PKEY* key) override {
         return new RsaEncryptOperation(padding, key);
     }
 };
@@ -221,7 +236,7 @@
  */
 class RsaDecryptionOperationFactory : public RsaCryptingOperationFactory {
     keymaster_purpose_t purpose() const override { return KM_PURPOSE_DECRYPT; }
-    Operation* InstantiateOperation(keymaster_padding_t padding, RSA* key) override {
+    Operation* InstantiateOperation(keymaster_padding_t padding, EVP_PKEY* key) override {
         return new RsaDecryptOperation(padding, key);
     }
 };