| # -*- coding: utf-8 -*- |
| # |
| # Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu> |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
"""Common functionality shared by several modules."""

from rsa._compat import zip
| |
| |
class NotRelativePrimeError(ValueError):
    """Signals that two numbers expected to be coprime share a divisor.

    The offending numbers are kept as ``a`` and ``b``; ``d`` is the common
    divisor that was found. A custom ``msg`` replaces the default text.
    """

    def __init__(self, a, b, d, msg=None):
        if not msg:
            msg = "%d and %d are not relatively prime, divider=%i" % (a, b, d)
        super(NotRelativePrimeError, self).__init__(msg)
        self.a, self.b, self.d = a, b, d
| |
| |
def bit_size(num):
    """
    Number of bits needed to represent an integer, ignoring leading zero bits.

    Usage::

        >>> bit_size(1023)
        10
        >>> bit_size(1024)
        11
        >>> bit_size(1025)
        11

    :param num:
        Integer value. Zero yields 0. The sign is ignored, so a negative
        number reports the bit length of its absolute value.
    :returns:
        The number of significant bits in the integer.
    :raises TypeError:
        If ``num`` has no ``bit_length`` method (e.g. a float).
    """
    try:
        n_bits = num.bit_length()
    except AttributeError:
        message = 'bit_size(num) only supports integers, not %r' % type(num)
        raise TypeError(message)
    return n_bits
| |
| |
def byte_size(number):
    """
    Number of whole bytes needed to store ``number``, rounding up.

    Usage::

        >>> byte_size(0)
        1
        >>> byte_size(1 << 1023)
        128
        >>> byte_size((1 << 1024) - 1)
        128
        >>> byte_size(1 << 1024)
        129

    :param number:
        An unsigned integer. The size is derived from ``bit_size``, which
        ignores the sign.
    :returns:
        The byte count; zero is reported as occupying a single byte.
    """
    # Zero has no significant bits, but it still takes one byte to store.
    if number == 0:
        return 1
    bits = bit_size(number)
    return ceil_div(bits, 8)
| |
| |
def ceil_div(num, div):
    """
    Divide ``num`` by ``div`` and round the quotient up.

    Usage::

        >>> ceil_div(100, 7)
        15
        >>> ceil_div(100, 10)
        10
        >>> ceil_div(1, 4)
        1

    :param num: Division's numerator, a number
    :param div: Division's divisor, a number

    :return: The floor quotient, plus one whenever the division leaves a
        remainder.
    """
    whole, remainder = divmod(num, div)
    return whole + 1 if remainder else whole
| |
| |
def extended_gcd(a, b):
    """Iterative extended Euclidean algorithm.

    Returns a tuple ``(r, i, j)`` where ``r = gcd(a, b)``.

    The coefficients start out as the Bezout pair with ``r = i*a + j*b``.
    A negative ``i`` is then wrapped by adding the original ``b``, and a
    negative ``j`` by adding the original ``a``. The returned values
    therefore satisfy ``i*a == r (mod b)`` and ``j*b == r (mod a)``. The
    plain identity ``r == i*a + j*b`` holds only when no wrapping was needed.

    When ``r == 1``, ``i`` is the multiplicative inverse of ``a`` modulo ``b``
    and ``j`` is the multiplicative inverse of ``b`` modulo ``a``.

    >>> extended_gcd(3, 7)
    (1, 5, 1)
    """
    # The iterative form avoids the recursion depth and stack usage of the
    # textbook recursive version.
    x, last_x = 0, 1
    y, last_y = 1, 0
    # Remember the original inputs so negative coefficients can be wrapped.
    orig_a, orig_b = a, b
    while b != 0:
        q = a // b
        a, b = b, a % b
        x, last_x = last_x - q * x, x
        y, last_y = last_y - q * y, y
    if last_x < 0:
        last_x += orig_b  # wrap modulo the original b
    if last_y < 0:
        last_y += orig_a  # wrap modulo the original a
    return a, last_x, last_y
| |
| |
def inverse(x, n):
    """Modular multiplicative inverse: the value ``y`` with ``x*y == 1 (mod n)``.

    >>> inverse(7, 4)
    3
    >>> (inverse(143, 4) * 143) % 4
    1

    :raises NotRelativePrimeError: if ``x`` and ``n`` share a divisor > 1,
        in which case no inverse exists.
    """
    gcd_value, coefficient, _unused = extended_gcd(x, n)
    if gcd_value == 1:
        return coefficient
    raise NotRelativePrimeError(x, n, gcd_value)
| |
| |
def crt(a_values, modulo_values):
    """Chinese Remainder Theorem.

    Calculates x such that x = a[i] (mod m[i]) for each i.

    The moduli must be pairwise coprime; otherwise computing one of the
    partial inverses fails and :class:`NotRelativePrimeError` is raised.

    :param a_values: the a-values of the above equation
    :param modulo_values: the m-values of the above equation
    :returns: x such that x = a[i] (mod m[i]) for each i, reduced modulo
        the product of all m-values
    :raises ValueError: if the number of a-values differs from the number
        of m-values


    >>> crt([2, 3], [3, 5])
    8

    >>> crt([2, 3, 2], [3, 5, 7])
    23

    >>> crt([2, 3, 0], [7, 11, 15])
    135
    """
    # Materialize both inputs. modulo_values is traversed twice (a generator
    # would be exhausted by the first pass), and the lengths must match:
    # zip() would silently drop unpaired items, while the product m below
    # would still include every modulus, yielding a meaningless result.
    a_values = list(a_values)
    modulo_values = list(modulo_values)
    if len(a_values) != len(modulo_values):
        raise ValueError(
            'crt() needs one modulus per a-value; got %d a-values and %d moduli'
            % (len(a_values), len(modulo_values)))

    m = 1
    for modulo in modulo_values:
        m *= modulo

    x = 0
    for (m_i, a_i) in zip(modulo_values, a_values):
        M_i = m // m_i
        inv = inverse(M_i, m_i)

        x = (x + a_i * M_i * inv) % m

    return x
| |
| |
if __name__ == '__main__':
    # Running this module directly executes the examples in its docstrings.
    from doctest import testmod

    testmod()