| # Copyright 2014 Google Inc. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import base64 |
| import os |
| |
| import mock |
| import unittest2 |
| |
| from oauth2client import _helpers |
| from oauth2client import client |
| from oauth2client import crypt |
| from oauth2client import service_account |
| |
| |
| def data_filename(filename): |
| return os.path.join(os.path.dirname(__file__), 'data', filename) |
| |
| |
| def datafile(filename): |
| with open(data_filename(filename), 'rb') as file_obj: |
| return file_obj.read() |
| |
| |
| class Test__bad_pkcs12_key_as_pem(unittest2.TestCase): |
| |
| def test_fails(self): |
| with self.assertRaises(NotImplementedError): |
| crypt._bad_pkcs12_key_as_pem() |
| |
| |
| class Test_pkcs12_key_as_pem(unittest2.TestCase): |
| |
| def _make_svc_account_creds(self, private_key_file='privatekey.p12'): |
| filename = data_filename(private_key_file) |
| credentials = ( |
| service_account.ServiceAccountCredentials.from_p12_keyfile( |
| 'some_account@example.com', filename, |
| scopes='read+write')) |
| credentials._kwargs['sub'] = 'joe@example.org' |
| return credentials |
| |
| def _succeeds_helper(self, password=None): |
| self.assertEqual(True, client.HAS_OPENSSL) |
| |
| credentials = self._make_svc_account_creds() |
| if password is None: |
| password = credentials._private_key_password |
| pem_contents = crypt.pkcs12_key_as_pem( |
| credentials._private_key_pkcs12, password) |
| pkcs12_key_as_pem = datafile('pem_from_pkcs12.pem') |
| pkcs12_key_as_pem = _helpers._parse_pem_key(pkcs12_key_as_pem) |
| alternate_pem = datafile('pem_from_pkcs12_alternate.pem') |
| self.assertTrue(pem_contents in [pkcs12_key_as_pem, alternate_pem]) |
| |
| def test_succeeds(self): |
| self._succeeds_helper() |
| |
| def test_succeeds_with_unicode_password(self): |
| password = u'notasecret' |
| self._succeeds_helper(password) |
| |
| |
| class Test__verify_signature(unittest2.TestCase): |
| |
| def test_success_single_cert(self): |
| cert_value = 'cert-value' |
| certs = [cert_value] |
| message = object() |
| signature = object() |
| |
| verifier = mock.MagicMock() |
| verifier.verify = mock.MagicMock(name='verify', return_value=True) |
| with mock.patch('oauth2client.crypt.Verifier') as Verifier: |
| Verifier.from_string = mock.MagicMock(name='from_string', |
| return_value=verifier) |
| result = crypt._verify_signature(message, signature, certs) |
| self.assertEqual(result, None) |
| |
| # Make sure our mocks were called as expected. |
| Verifier.from_string.assert_called_once_with(cert_value, |
| is_x509_cert=True) |
| verifier.verify.assert_called_once_with(message, signature) |
| |
| def test_success_multiple_certs(self): |
| cert_value1 = 'cert-value1' |
| cert_value2 = 'cert-value2' |
| cert_value3 = 'cert-value3' |
| certs = [cert_value1, cert_value2, cert_value3] |
| message = object() |
| signature = object() |
| |
| verifier = mock.MagicMock() |
| # Use side_effect to force all 3 cert values to be used by failing |
| # to verify on the first two. |
| verifier.verify = mock.MagicMock(name='verify', |
| side_effect=[False, False, True]) |
| with mock.patch('oauth2client.crypt.Verifier') as Verifier: |
| Verifier.from_string = mock.MagicMock(name='from_string', |
| return_value=verifier) |
| result = crypt._verify_signature(message, signature, certs) |
| self.assertEqual(result, None) |
| |
| # Make sure our mocks were called three times. |
| expected_from_string_calls = [ |
| mock.call(cert_value1, is_x509_cert=True), |
| mock.call(cert_value2, is_x509_cert=True), |
| mock.call(cert_value3, is_x509_cert=True), |
| ] |
| self.assertEqual(Verifier.from_string.mock_calls, |
| expected_from_string_calls) |
| expected_verify_calls = [mock.call(message, signature)] * 3 |
| self.assertEqual(verifier.verify.mock_calls, |
| expected_verify_calls) |
| |
| def test_failure(self): |
| cert_value = 'cert-value' |
| certs = [cert_value] |
| message = object() |
| signature = object() |
| |
| verifier = mock.MagicMock() |
| verifier.verify = mock.MagicMock(name='verify', return_value=False) |
| with mock.patch('oauth2client.crypt.Verifier') as Verifier: |
| Verifier.from_string = mock.MagicMock(name='from_string', |
| return_value=verifier) |
| with self.assertRaises(crypt.AppIdentityError): |
| crypt._verify_signature(message, signature, certs) |
| |
| # Make sure our mocks were called as expected. |
| Verifier.from_string.assert_called_once_with(cert_value, |
| is_x509_cert=True) |
| verifier.verify.assert_called_once_with(message, signature) |
| |
| |
| class Test__check_audience(unittest2.TestCase): |
| |
| def test_null_audience(self): |
| result = crypt._check_audience(None, None) |
| self.assertEqual(result, None) |
| |
| def test_success(self): |
| audience = 'audience' |
| payload_dict = {'aud': audience} |
| result = crypt._check_audience(payload_dict, audience) |
| # No exception and no result. |
| self.assertEqual(result, None) |
| |
| def test_missing_aud(self): |
| audience = 'audience' |
| payload_dict = {} |
| with self.assertRaises(crypt.AppIdentityError): |
| crypt._check_audience(payload_dict, audience) |
| |
| def test_wrong_aud(self): |
| audience1 = 'audience1' |
| audience2 = 'audience2' |
| self.assertNotEqual(audience1, audience2) |
| payload_dict = {'aud': audience1} |
| with self.assertRaises(crypt.AppIdentityError): |
| crypt._check_audience(payload_dict, audience2) |
| |
| |
| class Test__verify_time_range(unittest2.TestCase): |
| |
| def _exception_helper(self, payload_dict): |
| exception_caught = None |
| try: |
| crypt._verify_time_range(payload_dict) |
| except crypt.AppIdentityError as exc: |
| exception_caught = exc |
| |
| return exception_caught |
| |
| def test_without_issued_at(self): |
| payload_dict = {} |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'No iat field in token')) |
| |
| def test_without_expiration(self): |
| payload_dict = {'iat': 'iat'} |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'No exp field in token')) |
| |
| def test_with_bad_token_lifetime(self): |
| current_time = 123456 |
| payload_dict = { |
| 'iat': 'iat', |
| 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS + 1, |
| } |
| with mock.patch('oauth2client.crypt.time') as time: |
| time.time = mock.MagicMock(name='time', |
| return_value=current_time) |
| |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'exp field too far in future')) |
| |
| def test_with_issued_at_in_future(self): |
| current_time = 123456 |
| payload_dict = { |
| 'iat': current_time + crypt.CLOCK_SKEW_SECS + 1, |
| 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1, |
| } |
| with mock.patch('oauth2client.crypt.time') as time: |
| time.time = mock.MagicMock(name='time', |
| return_value=current_time) |
| |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'Token used too early')) |
| |
| def test_with_expiration_in_the_past(self): |
| current_time = 123456 |
| payload_dict = { |
| 'iat': current_time, |
| 'exp': current_time - crypt.CLOCK_SKEW_SECS - 1, |
| } |
| with mock.patch('oauth2client.crypt.time') as time: |
| time.time = mock.MagicMock(name='time', |
| return_value=current_time) |
| |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'Token used too late')) |
| |
| def test_success(self): |
| current_time = 123456 |
| payload_dict = { |
| 'iat': current_time, |
| 'exp': current_time + crypt.MAX_TOKEN_LIFETIME_SECS - 1, |
| } |
| with mock.patch('oauth2client.crypt.time') as time: |
| time.time = mock.MagicMock(name='time', |
| return_value=current_time) |
| |
| exception_caught = self._exception_helper(payload_dict) |
| self.assertEqual(exception_caught, None) |
| |
| |
| class Test_verify_signed_jwt_with_certs(unittest2.TestCase): |
| |
| def test_jwt_no_segments(self): |
| exception_caught = None |
| try: |
| crypt.verify_signed_jwt_with_certs(b'', None) |
| except crypt.AppIdentityError as exc: |
| exception_caught = exc |
| |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'Wrong number of segments in token')) |
| |
| def test_jwt_payload_bad_json(self): |
| header = signature = b'' |
| payload = base64.b64encode(b'{BADJSON') |
| jwt = b'.'.join([header, payload, signature]) |
| |
| exception_caught = None |
| try: |
| crypt.verify_signed_jwt_with_certs(jwt, None) |
| except crypt.AppIdentityError as exc: |
| exception_caught = exc |
| |
| self.assertNotEqual(exception_caught, None) |
| self.assertTrue(str(exception_caught).startswith( |
| 'Can\'t parse token')) |
| |
| @mock.patch('oauth2client.crypt._check_audience') |
| @mock.patch('oauth2client.crypt._verify_time_range') |
| @mock.patch('oauth2client.crypt._verify_signature') |
| def test_success(self, verify_sig, verify_time, check_aud): |
| certs = mock.MagicMock() |
| cert_values = object() |
| certs.values = mock.MagicMock(name='values', |
| return_value=cert_values) |
| audience = object() |
| |
| header = b'header' |
| signature_bytes = b'signature' |
| signature = base64.b64encode(signature_bytes) |
| payload_dict = {'a': 'b'} |
| payload = base64.b64encode(b'{"a": "b"}') |
| jwt = b'.'.join([header, payload, signature]) |
| |
| result = crypt.verify_signed_jwt_with_certs( |
| jwt, certs, audience=audience) |
| self.assertEqual(result, payload_dict) |
| |
| message_to_sign = header + b'.' + payload |
| verify_sig.assert_called_once_with( |
| message_to_sign, signature_bytes, cert_values) |
| verify_time.assert_called_once_with(payload_dict) |
| check_aud.assert_called_once_with(payload_dict, audience) |
| certs.values.assert_called_once_with() |