Merge "Remove test mapping file for acloud_test"
diff --git a/.travis.yml b/.travis.yml
index 0fed68b..ff7329c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -9,6 +9,7 @@
   - "3.6"
   - "3.7"
   - "3.8"
+  - "3.9"
 
 matrix:
   include:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0fa3054..3552260 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,16 @@
 # Python-RSA changelog
 
+## Version 4.7 - released 2021-01-10
+
+- Fix [#165](https://github.com/sybrenstuvel/python-rsa/issues/165):
+  CVE-2020-25658 - Bleichenbacher-style timing oracle in PKCS#1 v1.5 decryption
+  code
+- Add padding length check as described by PKCS#1 v1.5 (Fixes
+  [#164](https://github.com/sybrenstuvel/python-rsa/issues/164))
+- Reuse of blinding factors to speed up blinding operations.
+  Fixes [#162](https://github.com/sybrenstuvel/python-rsa/issues/162).
+- Declare & test support for Python 3.9
+
 
 ## Version 4.4 & 4.6 - released 2020-06-12
 
@@ -12,7 +23,7 @@
 No functional changes compared to version 4.2.
 
 
-## Version 4.3 - released 2020-06-12
+## Version 4.3 & 4.5 - released 2020-06-12
 
 Version 4.3 and 4.5 are almost a re-tagged release of version 4.0. It is the
 last to support Python 2.7. This is now made explicit in the `python_requires`
diff --git a/METADATA b/METADATA
index 5c228ef..69d3a33 100644
--- a/METADATA
+++ b/METADATA
@@ -9,11 +9,11 @@
     type: GIT
     value: "https://github.com/sybrenstuvel/python-rsa/"
   }
-  version: "version-4.6"
+  version: "version-4.7"
   license_type: NOTICE
   last_upgrade_date {
-    year: 2020
-    month: 7
-    day: 10
+    year: 2021
+    month: 1
+    day: 11
   }
 }
diff --git a/README.md b/README.md
index ea24210..2684060 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,12 @@
 The source code is maintained at [GitHub](https://github.com/sybrenstuvel/python-rsa/) and is
 licensed under the [Apache License, version 2.0](https://www.apache.org/licenses/LICENSE-2.0)
 
+Security
+--------
+
+Because of how Python internally stores numbers, it is very hard (if not impossible) to make a pure-Python program secure against timing attacks. This library is no exception, so use it with care. See https://securitypitfalls.wordpress.com/2018/08/03/constant-time-compare-in-python/ for more info.
+
+
 Major changes in 4.1
 --------------------
 
diff --git a/doc/installation.rst b/doc/installation.rst
index 3ab3ab1..73f56e5 100644
--- a/doc/installation.rst
+++ b/doc/installation.rst
@@ -42,10 +42,10 @@
 
     git clone https://github.com/sybrenstuvel/python-rsa.git
 
-Use Poetry_ to install the development requirements in a virtual environment::
+Use Pipenv_ to install the development requirements in a virtual environment::
 
     cd python-rsa
-    poetry install
+    pipenv install --dev
 
 .. _Git: https://git-scm.com/
-.. _Poetry: https://poetry.eustace.io/
+.. _Pipenv: https://pipenv.pypa.io/en/latest/
diff --git a/doc/usage.rst b/doc/usage.rst
index b1244d4..f76765e 100644
--- a/doc/usage.rst
+++ b/doc/usage.rst
@@ -170,7 +170,7 @@
 :py:func:`rsa.sign` function:
 
     >>> (pubkey, privkey) = rsa.newkeys(512)
-    >>> message = 'Go left at the blue tree'
+    >>> message = 'Go left at the blue tree'.encode()
     >>> signature = rsa.sign(message, privkey, 'SHA-1')
 
 This hashes the message using SHA-1. Other hash methods are also
@@ -182,21 +182,21 @@
 private key on remote server). To hash a message use the :py:func:`rsa.compute_hash`
 function and then use the :py:func:`rsa.sign_hash` function to sign the hash:
 
-    >>> message = 'Go left at the blue tree'
+    >>> message = 'Go left at the blue tree'.encode()
     >>> hash = rsa.compute_hash(message, 'SHA-1')
     >>> signature = rsa.sign_hash(hash, privkey, 'SHA-1')
 
 In order to verify the signature, use the :py:func:`rsa.verify`
 function. This function returns True if the verification is successful:
 
-    >>> message = 'Go left at the blue tree'
+    >>> message = 'Go left at the blue tree'.encode()
     >>> rsa.verify(message, signature, pubkey)
     True
 
 Modify the message, and the signature is no longer valid and a
 :py:class:`rsa.pkcs1.VerificationError` is thrown:
 
-    >>> message = 'Go right at the blue tree'
+    >>> message = 'Go right at the blue tree'.encode()
     >>> rsa.verify(message, signature, pubkey)
     Traceback (most recent call last):
       File "<stdin>", line 1, in <module>
diff --git a/rsa/__init__.py b/rsa/__init__.py
index 1567dc1..26b28ca 100644
--- a/rsa/__init__.py
+++ b/rsa/__init__.py
@@ -26,8 +26,8 @@
     VerificationError, find_signature_hash,  sign_hash, compute_hash
 
 __author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly"
-__date__ = '2020-06-12'
-__version__ = '4.6'
+__date__ = '2021-01-10'
+__version__ = '4.7'
 
 # Do doctest if we're run directly
 if __name__ == "__main__":
diff --git a/rsa/common.py b/rsa/common.py
index e7df21d..b5a966a 100644
--- a/rsa/common.py
+++ b/rsa/common.py
@@ -49,8 +49,8 @@
 
     try:
         return num.bit_length()
-    except AttributeError:
-        raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
+    except AttributeError as ex:
+        raise TypeError('bit_size(num) only supports integers, not %r' % type(num)) from ex
 
 
 def byte_size(number: int) -> int:
diff --git a/rsa/key.py b/rsa/key.py
index b1e2030..e0e7b11 100644
--- a/rsa/key.py
+++ b/rsa/key.py
@@ -49,12 +49,15 @@
 class AbstractKey:
     """Abstract superclass for private and public keys."""
 
-    __slots__ = ('n', 'e')
+    __slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse')
 
     def __init__(self, n: int, e: int) -> None:
         self.n = n
         self.e = e
 
+        # These will be computed properly on the first call to blind().
+        self.blindfac = self.blindfac_inverse = -1
+
     @classmethod
     def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey':
         """Loads a key in PKCS#1 PEM format, implement in a subclass.
@@ -145,7 +148,7 @@
         method = self._assert_format_exists(format, methods)
         return method()
 
-    def blind(self, message: int, r: int) -> int:
+    def blind(self, message: int) -> int:
         """Performs blinding on the message using random number 'r'.
 
         :param message: the message, as integer, to blind.
@@ -159,10 +162,10 @@
 
         See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
         """
+        self._update_blinding_factor()
+        return (message * pow(self.blindfac, self.e, self.n)) % self.n
 
-        return (message * pow(r, self.e, self.n)) % self.n
-
-    def unblind(self, blinded: int, r: int) -> int:
+    def unblind(self, blinded: int) -> int:
         """Performs blinding on the message using random number 'r'.
 
         :param blinded: the blinded message, as integer, to unblind.
@@ -174,8 +177,27 @@
         See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29
         """
 
-        return (rsa.common.inverse(r, self.n) * blinded) % self.n
+        return (self.blindfac_inverse * blinded) % self.n
 
+    def _initial_blinding_factor(self) -> int:
+        for _ in range(1000):
+            blind_r = rsa.randnum.randint(self.n - 1)
+            if rsa.prime.are_relatively_prime(self.n, blind_r):
+                return blind_r
+        raise RuntimeError('unable to find blinding factor')
+
+    def _update_blinding_factor(self):
+        if self.blindfac < 0:
+            # Compute initial blinding factor, which is rather slow to do.
+            self.blindfac = self._initial_blinding_factor()
+            self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n)
+        else:
+            # Reuse previous blinding factor as per section 9 of 'A Timing
+            # Attack against RSA with the Chinese Remainder Theorem' by Werner
+            # Schindler.
+            # See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf
+            self.blindfac = pow(self.blindfac, 2, self.n)
+            self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n)
 
 class PublicKey(AbstractKey):
     """Represents a public RSA key.
@@ -414,13 +436,6 @@
     def __hash__(self) -> int:
         return hash((self.n, self.e, self.d, self.p, self.q, self.exp1, self.exp2, self.coef))
 
-    def _get_blinding_factor(self) -> int:
-        for _ in range(1000):
-            blind_r = rsa.randnum.randint(self.n - 1)
-            if rsa.prime.are_relatively_prime(self.n, blind_r):
-                return blind_r
-        raise RuntimeError('unable to find blinding factor')
-
     def blinded_decrypt(self, encrypted: int) -> int:
         """Decrypts the message using blinding to prevent side-channel attacks.
 
@@ -431,11 +446,9 @@
         :rtype: int
         """
 
-        blind_r = self._get_blinding_factor()
-        blinded = self.blind(encrypted, blind_r)  # blind before decrypting
+        blinded = self.blind(encrypted)  # blind before decrypting
         decrypted = rsa.core.decrypt_int(blinded, self.d, self.n)
-
-        return self.unblind(decrypted, blind_r)
+        return self.unblind(decrypted)
 
     def blinded_encrypt(self, message: int) -> int:
         """Encrypts the message using blinding to prevent side-channel attacks.
@@ -447,10 +460,9 @@
         :rtype: int
         """
 
-        blind_r = self._get_blinding_factor()
-        blinded = self.blind(message, blind_r)  # blind before encrypting
+        blinded = self.blind(message)  # blind before encrypting
         encrypted = rsa.core.encrypt_int(blinded, self.d, self.n)
-        return self.unblind(encrypted, blind_r)
+        return self.unblind(encrypted)
 
     @classmethod
     def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey':
diff --git a/rsa/pkcs1.py b/rsa/pkcs1.py
index 57b0276..07cf85b 100644
--- a/rsa/pkcs1.py
+++ b/rsa/pkcs1.py
@@ -30,6 +30,7 @@
 import os
 import sys
 import typing
+from hmac import compare_digest
 
 from . import common, transform, core, key
 
@@ -252,16 +253,24 @@
     # encrypted value (as leading zeroes do not influence the value of an
     # integer). This fixes CVE-2020-13757.
     if len(crypto) > blocksize:
+        # This is operating on public information, so doesn't need to be constant-time.
         raise DecryptionError('Decryption failed')
 
     # If we can't find the cleartext marker, decryption failed.
-    if cleartext[0:2] != b'\x00\x02':
-        raise DecryptionError('Decryption failed')
+    cleartext_marker_bad = not compare_digest(cleartext[:2], b'\x00\x02')
 
     # Find the 00 separator between the padding and the message
-    try:
-        sep_idx = cleartext.index(b'\x00', 2)
-    except ValueError:
+    sep_idx = cleartext.find(b'\x00', 2)
+
+    # sep_idx indicates the position of the `\x00` separator that separates the
+    # padding from the actual message. The padding should be at least 8 bytes
+    # long (see https://tools.ietf.org/html/rfc8017#section-7.2.2 step 3), which
+    # means the separator should be at least at index 10 (because of the
+    # `\x00\x02` marker that precedes it).
+    sep_idx_bad = sep_idx < 10
+
+    anything_bad = cleartext_marker_bad | sep_idx_bad
+    if anything_bad:
         raise DecryptionError('Decryption failed')
 
     return cleartext[sep_idx + 1:]
diff --git a/setup.cfg b/setup.cfg
index e377bdb..4c5e567 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,3 @@
-[bdist_wheel]
-universal = 1
-
 [metadata]
 license_file = LICENSE
 
diff --git a/setup.py b/setup.py
index 2d22865..b983b1f 100755
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@
 
 if __name__ == '__main__':
     setup(name='rsa',
-          version='4.6',
+          version='4.7',
           description='Pure-Python RSA implementation',
           long_description=long_description,
           long_description_content_type='text/markdown',
@@ -49,6 +49,7 @@
               'Programming Language :: Python :: 3.6',
               'Programming Language :: Python :: 3.7',
               'Programming Language :: Python :: 3.8',
+              'Programming Language :: Python :: 3.9',
               'Programming Language :: Python :: Implementation :: CPython',
               'Programming Language :: Python :: Implementation :: PyPy',
               'Topic :: Security :: Cryptography',
diff --git a/tests/test_key.py b/tests/test_key.py
index 9db30ce..b00e26d 100644
--- a/tests/test_key.py
+++ b/tests/test_key.py
@@ -21,11 +21,20 @@
         message = 12345
         encrypted = rsa.core.encrypt_int(message, pk.e, pk.n)
 
-        blinded = pk.blind(encrypted, 4134431)  # blind before decrypting
-        decrypted = rsa.core.decrypt_int(blinded, pk.d, pk.n)
-        unblinded = pk.unblind(decrypted, 4134431)
+        blinded_1 = pk.blind(encrypted)  # blind before decrypting
+        decrypted = rsa.core.decrypt_int(blinded_1, pk.d, pk.n)
+        unblinded_1 = pk.unblind(decrypted)
 
-        self.assertEqual(unblinded, message)
+        self.assertEqual(unblinded_1, message)
+
+        # Re-blinding should use a different blinding factor.
+        blinded_2 = pk.blind(encrypted)  # blind before decrypting
+        self.assertNotEqual(blinded_1, blinded_2)
+
+        # The unblinding should still work, though.
+        decrypted = rsa.core.decrypt_int(blinded_2, pk.d, pk.n)
+        unblinded_2 = pk.unblind(decrypted)
+        self.assertEqual(unblinded_2, message)
 
 
 class KeyGenTest(unittest.TestCase):
diff --git a/tests/test_pkcs1.py b/tests/test_pkcs1.py
index f7baf7f..64fb0c5 100644
--- a/tests/test_pkcs1.py
+++ b/tests/test_pkcs1.py
@@ -183,3 +183,36 @@
         signature = signature + bytes.fromhex('0000')
         with self.assertRaises(rsa.VerificationError):
             pkcs1.verify(message, signature, self.pub)
+
+
+class PaddingSizeTest(unittest.TestCase):
+    def test_too_little_padding(self):
+        """Padding less than 8 bytes should be rejected."""
+
+        # Construct key that will be small enough to need only 7 bytes of padding.
+        # This key is 168 bit long, and was generated with rsa.newkeys(nbits=168).
+        self.private_key = rsa.PrivateKey.load_pkcs1(b'''
+-----BEGIN RSA PRIVATE KEY-----
+MHkCAQACFgCIGbbNSkIRLtprxka9NgOf5UxgxCMCAwEAAQIVQqymO0gHubdEVS68
+CdCiWmOJxVfRAgwBQM+e1JJwMKmxSF0CCmya6CFxO8Evdn8CDACMM3AlVC4FhlN8
+3QIKC9cjoam/swMirwIMAR7Br9tdouoH7jAE
+-----END RSA PRIVATE KEY-----
+        ''')
+        self.public_key = rsa.PublicKey(n=self.private_key.n, e=self.private_key.e)
+
+        cyphertext = self.encrypt_with_short_padding(b'op je hoofd')
+        with self.assertRaises(rsa.DecryptionError):
+            rsa.decrypt(cyphertext, self.private_key)
+
+    def encrypt_with_short_padding(self, message: bytes) -> bytes:
+        # This is a copy of rsa.pkcs1.encrypt() adjusted to use the wrong padding length.
+        keylength = rsa.common.byte_size(self.public_key.n)
+
+        # The word 'padding' has 7 letters, so is one byte short of a valid padding length.
+        padded = b'\x00\x02padding\x00' + message
+
+        payload = rsa.transform.bytes2int(padded)
+        encrypted_value = rsa.core.encrypt_int(payload, self.public_key.e, self.public_key.n)
+        cyphertext = rsa.transform.int2bytes(encrypted_value, keylength)
+
+        return cyphertext
diff --git a/tox.ini b/tox.ini
index 0552a17..1a033ae 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 # Environment changes have to be manually synced with '.travis.yml'.
-envlist = py35,py36,p37,p38
+envlist = py35,py36,py37,py38,py39
 
 [pytest]
 addopts = -v --cov rsa --cov-report term-missing