Improve constant-time padding check in RSA key exchange.

Although the PKCS#1 padding check is internally constant-time, it is not
constant-time at the API boundary between crypto/ and ssl/. Expose a
constant-time RSA_message_index_PKCS1_type_2 function and integrate it into
the timing-sensitive portion of the RSA key exchange logic.

Change-Id: I6fa64ddc9d65564d05529d9b2985da7650d058c3
Reviewed-on: https://boringssl-review.googlesource.com/1301
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/rsa/padding.c b/crypto/rsa/padding.c
index 6b6b0e3..c7b088f 100644
--- a/crypto/rsa/padding.c
+++ b/crypto/rsa/padding.c
@@ -220,12 +220,48 @@
   return ((x - y - 1) >> (sizeof(int) * 8 - 1)) & 1;
 }
 
+int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len,
+                                   size_t *out_index) {
+  size_t i;
+  int first_byte_is_zero, second_byte_is_two, looking_for_index;
+  int valid_index, zero_index = 0;
+
+  /* PKCS#1 v1.5 decryption. See "PKCS #1 v2.2: RSA Cryptography
+   * Standard", section 7.2.2. */
+  if (from_len < RSA_PKCS1_PADDING_SIZE) {
+    return 0;
+  }
+
+  first_byte_is_zero = constant_time_byte_eq(from[0], 0);
+  second_byte_is_two = constant_time_byte_eq(from[1], 2);
+
+  looking_for_index = 1;
+  for (i = 2; i < from_len; i++) {
+    int equals0 = constant_time_byte_eq(from[i], 0);
+    zero_index =
+        constant_time_select(looking_for_index & equals0, i, zero_index);
+    looking_for_index = constant_time_select(equals0, 0, looking_for_index);
+  }
+
+  /* The input must begin with 00 02. */
+  valid_index = first_byte_is_zero;
+  valid_index &= second_byte_is_two;
+
+  /* We must have found the end of PS. */
+  valid_index &= ~looking_for_index;
+
+  /* PS must be at least 8 bytes long, and it starts two bytes into |from|. */
+  valid_index &= constant_time_le(2 + 8, zero_index);
+
+  /* Skip the zero byte. */
+  *out_index = zero_index + 1;
+
+  return valid_index;
+}
+
 int RSA_padding_check_PKCS1_type_2(uint8_t *to, unsigned tlen,
                                    const uint8_t *from, unsigned flen) {
-  size_t i;
-  int ret = -1;
-  int first_byte_is_zero, second_byte_is_two, looking_for_index;
-  int valid_index, zero_index = 0, msg_index;
+  size_t msg_index, msg_len;
 
   if (flen == 0) {
     OPENSSL_PUT_ERROR(RSA, RSA_padding_check_PKCS1_type_2,
@@ -233,44 +269,26 @@
     return -1;
   }
 
-  /* PKCS#1 v1.5 decryption. See "PKCS #1 v2.2: RSA Cryptography
-   * Standard", section 7.2.2. */
-
-  if (flen < RSA_PKCS1_PADDING_SIZE) {
-    goto err;
-  }
-
-  first_byte_is_zero = constant_time_byte_eq(from[0], 0);
-  second_byte_is_two = constant_time_byte_eq(from[1], 2);
-
-  looking_for_index = 1;
-  for (i = 2; i < flen; i++) {
-    int equals0 = constant_time_byte_eq(from[i], 0);
-    zero_index =
-        constant_time_select(looking_for_index & equals0, i, zero_index);
-    looking_for_index = constant_time_select(equals0, 0, looking_for_index);
-  }
-
-  /* PS must be at least 8 bytes long, and it starts two bytes into |from|. */
-  valid_index = constant_time_le(2 + 8, zero_index);
-  /* Skip the zero byte. */
-  msg_index = zero_index + 1;
-  valid_index &= constant_time_le(flen - msg_index, tlen);
-
-  if (!(first_byte_is_zero & second_byte_is_two & ~looking_for_index &
-        valid_index)) {
-    goto err;
-  }
-
-  ret = flen - msg_index;
-  memcpy(to, &from[msg_index], ret);
-
-err:
-  if (ret == -1) {
+  /* NOTE: Although |RSA_message_index_PKCS1_type_2| itself is constant time,
+   * the API contracts of this function and |RSA_decrypt| with
+   * |RSA_PKCS1_PADDING| make it impossible to completely avoid Bleichenbacker's
+   * attack. */
+  if (!RSA_message_index_PKCS1_type_2(from, flen, &msg_index)) {
     OPENSSL_PUT_ERROR(RSA, RSA_padding_check_PKCS1_type_2,
                       RSA_R_PKCS_DECODING_ERROR);
+    return -1;
   }
-  return ret;
+
+  msg_len = flen - msg_index;
+  if (msg_len > tlen) {
+    /* This shouldn't happen because this function is always called with |tlen|
+     * equal to the key size, and |flen| is bounded by the key size. */
+    OPENSSL_PUT_ERROR(RSA, RSA_padding_check_PKCS1_type_2,
+                      RSA_R_PKCS_DECODING_ERROR);
+    return -1;
+  }
+  memcpy(to, &from[msg_index], msg_len);
+  return msg_len;
 }
 
 int RSA_padding_add_none(uint8_t *to, unsigned tlen, const uint8_t *from, unsigned flen) {
diff --git a/include/openssl/rsa.h b/include/openssl/rsa.h
index 89c56ed..6f1fce1 100644
--- a/include/openssl/rsa.h
+++ b/include/openssl/rsa.h
@@ -154,6 +154,20 @@
 int RSA_private_decrypt(int flen, const uint8_t *from, uint8_t *to, RSA *rsa,
                         int padding);
 
+/* RSA_message_index_PKCS1_type_2 performs the first step of a PKCS #1 padding
+ * check for decryption. If the |from_len| bytes pointed to by |from| are a
+ * valid PKCS #1 message, it returns one and sets |*out_index| to the start of
+ * the unpadded message. The unpadded message is a suffix of the input and has
+ * length |from_len - *out_index|. Otherwise, it returns zero and sets
+ * |*out_index| to some undefined value. This function runs in time independent
+ * of the input data and is intended to be used directly to avoid
+ * Bleichenbacker's attack.
+ *
+ * WARNING: This function behaves differently from the usual OpenSSL convention
+ * in that it does NOT put an error on the queue in the error case. */
+int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len,
+                                   size_t *out_index);
+
 
 /* Signing / Verification */
 
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index c9c01fd..6fce8da 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -1897,6 +1897,7 @@
 	size_t premaster_secret_len = 0;
 	int skip_certificate_verify = 0;
 	RSA *rsa=NULL;
+	uint8_t *decrypt_buf = NULL;
 	EVP_PKEY *pkey=NULL;
 #ifndef OPENSSL_NO_DH
 	BIGNUM *pub=NULL;
@@ -1986,10 +1987,10 @@
 	if (alg_k & SSL_kRSA)
 		{
 		CBS encrypted_premaster_secret;
-		unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
-		int decrypt_len, decrypt_good_mask;
-		unsigned char version_good;
-		size_t j;
+		uint8_t rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
+		int decrypt_good_mask;
+		uint8_t version_good;
+		size_t rsa_size, decrypt_len, premaster_index, j;
 
 		pkey=s->cert->pkeys[SSL_PKEY_RSA_ENC].privatekey;
 		if (	(pkey == NULL) ||
@@ -2028,13 +2029,13 @@
 		else
 			encrypted_premaster_secret = client_key_exchange;
 
-		/* Reject overly short RSA ciphertext because we want to be
-		 * sure that the buffer size makes it safe to iterate over the
-		 * entire size of a premaster secret
-		 * (SSL_MAX_MASTER_KEY_LENGTH). The actual expected size is
-		 * larger due to RSA padding, but the bound is sufficient to be
-		 * safe. */
-		if (CBS_len(&encrypted_premaster_secret) < SSL_MAX_MASTER_KEY_LENGTH)
+		/* Reject overly short RSA keys because we want to be sure that
+		 * the buffer size makes it safe to iterate over the entire size
+		 * of a premaster secret (SSL_MAX_MASTER_KEY_LENGTH). The actual
+		 * expected size is larger due to RSA padding, but the bound is
+		 * sufficient to be safe. */
+		rsa_size = RSA_size(rsa);
+		if (rsa_size < SSL_MAX_MASTER_KEY_LENGTH)
 			{
 			al = SSL_AD_DECRYPT_ERROR;
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, SSL_R_DECRYPTION_FAILED);
@@ -2052,25 +2053,55 @@
 			goto err;
 
 		/* Allocate a buffer large enough for an RSA decryption. */
-		premaster_secret = OPENSSL_malloc(RSA_size(rsa));
-		if (premaster_secret == NULL)
+		decrypt_buf = OPENSSL_malloc(rsa_size);
+		if (decrypt_buf == NULL)
 			{
 			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
 			goto err;
 			}
 
-		decrypt_len = RSA_private_decrypt(
-			CBS_len(&encrypted_premaster_secret),
-			CBS_data(&encrypted_premaster_secret),
-			premaster_secret,
-			rsa,
-			RSA_PKCS1_PADDING);
+		/* Decrypt with no padding. PKCS#1 padding will be removed as
+		 * part of the timing-sensitive code below. */
+		if (!RSA_decrypt(rsa, &decrypt_len, decrypt_buf, rsa_size,
+				CBS_data(&encrypted_premaster_secret),
+				CBS_len(&encrypted_premaster_secret),
+				RSA_NO_PADDING))
+			{
+			goto err;
+			}
+		if (decrypt_len != rsa_size)
+			{
+			/* This should never happen, but do a check so we do not
+			 * read uninitialized memory. */
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_INTERNAL_ERROR);
+			goto err;
+			}
 
-		ERR_clear_error();
+		/* Remove the PKCS#1 padding and adjust decrypt_len as
+		 * appropriate. decrypt_good_mask will be zero if the premaster
+		 * is good and non-zero otherwise. */
+		decrypt_good_mask = RSA_message_index_PKCS1_type_2(
+			decrypt_buf, decrypt_len, &premaster_index);
+		decrypt_good_mask--;
+		decrypt_len = decrypt_len - premaster_index;
 
-		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH.
-		 * decrypt_good_mask will be zero if so and non-zero otherwise. */
-		decrypt_good_mask = decrypt_len ^ SSL_MAX_MASTER_KEY_LENGTH;
+		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH. */
+		decrypt_good_mask |= decrypt_len ^ SSL_MAX_MASTER_KEY_LENGTH;
+
+		/* Copy over the unpadded premaster. Whatever the value of
+		 * |decrypt_good_mask|, copy as if the premaster were the right
+		 * length. It is important the memory access pattern be
+		 * constant. */
+		premaster_secret = BUF_memdup(
+			decrypt_buf + (rsa_size - SSL_MAX_MASTER_KEY_LENGTH),
+			SSL_MAX_MASTER_KEY_LENGTH);
+		if (premaster_secret == NULL)
+			{
+			OPENSSL_PUT_ERROR(SSL, ssl3_get_client_key_exchange, ERR_R_MALLOC_FAILURE);
+			goto err;
+			}
+		OPENSSL_free(decrypt_buf);
+		decrypt_buf = NULL;
 
 		/* If the version in the decrypted pre-master secret is correct
 		 * then version_good will be zero. The Klima-Pokorny-Rosa
@@ -2465,6 +2496,8 @@
 			OPENSSL_cleanse(premaster_secret, premaster_secret_len);
 		OPENSSL_free(premaster_secret);
 		}
+	if (decrypt_buf)
+		OPENSSL_free(decrypt_buf);
 #ifndef OPENSSL_NO_ECDH
 	EVP_PKEY_free(clnt_pub_pkey);
 	EC_POINT_free(clnt_ecpoint);