| # Copyright 2015 Google Inc. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Tests for oauth2client.contrib.devshell.""" |
| |
| import datetime |
| import json |
| import os |
| import socket |
| import threading |
| |
| import mock |
| import unittest2 |
| |
| from oauth2client import _helpers |
| from oauth2client import client |
| from oauth2client.contrib import devshell |
| |
| # A dummy value to use for the expires_in field |
| # in CredentialInfoResponse. |
| EXPIRES_IN = 1000 |
| DEFAULT_CREDENTIAL_JSON = json.dumps([ |
| 'joe@example.com', |
| 'fooproj', |
| 'sometoken', |
| EXPIRES_IN |
| ]) |
| |
| |
| class TestCredentialInfoResponse(unittest2.TestCase): |
| |
| def test_constructor_with_non_list(self): |
| json_non_list = '{}' |
| with self.assertRaises(ValueError): |
| devshell.CredentialInfoResponse(json_non_list) |
| |
| def test_constructor_with_bad_json(self): |
| json_non_list = '{BADJSON' |
| with self.assertRaises(ValueError): |
| devshell.CredentialInfoResponse(json_non_list) |
| |
| def test_constructor_empty_list(self): |
| info_response = devshell.CredentialInfoResponse('[]') |
| self.assertEqual(info_response.user_email, None) |
| self.assertEqual(info_response.project_id, None) |
| self.assertEqual(info_response.access_token, None) |
| self.assertEqual(info_response.expires_in, None) |
| |
| def test_constructor_full_list(self): |
| user_email = 'user_email' |
| project_id = 'project_id' |
| access_token = 'access_token' |
| expires_in = 1 |
| json_string = json.dumps( |
| [user_email, project_id, access_token, expires_in]) |
| info_response = devshell.CredentialInfoResponse(json_string) |
| self.assertEqual(info_response.user_email, user_email) |
| self.assertEqual(info_response.project_id, project_id) |
| self.assertEqual(info_response.access_token, access_token) |
| self.assertEqual(info_response.expires_in, expires_in) |
| |
| |
| class Test_SendRecv(unittest2.TestCase): |
| |
| def test_port_zero(self): |
| with mock.patch('oauth2client.contrib.devshell.os') as os_mod: |
| os_mod.getenv = mock.MagicMock(name='getenv', return_value=0) |
| with self.assertRaises(devshell.NoDevshellServer): |
| devshell._SendRecv() |
| os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV, 0) |
| |
| def test_no_newline_in_received_header(self): |
| non_zero_port = 1 |
| sock = mock.MagicMock() |
| |
| header_without_newline = '' |
| sock.recv(6).decode = mock.MagicMock( |
| name='decode', return_value=header_without_newline) |
| |
| with mock.patch('oauth2client.contrib.devshell.os') as os_mod: |
| os_mod.getenv = mock.MagicMock(name='getenv', |
| return_value=non_zero_port) |
| with mock.patch('oauth2client.contrib.devshell.socket') as socket: |
| socket.socket = mock.MagicMock(name='socket', |
| return_value=sock) |
| with self.assertRaises(devshell.CommunicationError): |
| devshell._SendRecv() |
| os_mod.getenv.assert_called_once_with(devshell.DEVSHELL_ENV, 0) |
| socket.socket.assert_called_once_with() |
| sock.recv(6).decode.assert_called_once_with() |
| |
| data = devshell.CREDENTIAL_INFO_REQUEST_JSON |
| msg = _helpers._to_bytes( |
| '{0}\n{1}'.format(len(data), data), encoding='utf-8') |
| expected_sock_calls = [ |
| mock.call.recv(6), # From the set-up above |
| mock.call.connect(('localhost', non_zero_port)), |
| mock.call.sendall(msg), |
| mock.call.recv(6), |
| mock.call.recv(6), # From the check above |
| ] |
| self.assertEqual(sock.method_calls, expected_sock_calls) |
| |
| |
| class _AuthReferenceServer(threading.Thread): |
| |
| def __init__(self, response=None): |
| super(_AuthReferenceServer, self).__init__(None) |
| self.response = response or DEFAULT_CREDENTIAL_JSON |
| self.bad_request = False |
| |
| def __enter__(self): |
| return self.start_server() |
| |
| def start_server(self): |
| self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| self._socket.bind(('localhost', 0)) |
| port = self._socket.getsockname()[1] |
| os.environ[devshell.DEVSHELL_ENV] = str(port) |
| self._socket.listen(0) |
| self.daemon = True |
| self.start() |
| return self |
| |
| def __exit__(self, e_type, value, traceback): |
| self.stop_server() |
| |
| def stop_server(self): |
| del os.environ[devshell.DEVSHELL_ENV] |
| self._socket.close() |
| |
| def run(self): |
| s = None |
| try: |
| # Do not set the timeout on the socket, leave it in the blocking |
| # mode as setting the timeout seems to cause spurious EAGAIN |
| # errors on OSX. |
| self._socket.settimeout(None) |
| |
| s, unused_addr = self._socket.accept() |
| resp_buffer = '' |
| resp_1 = s.recv(6).decode() |
| nstr, extra = resp_1.split('\n', 1) |
| resp_buffer = extra |
| n = int(nstr) |
| to_read = n - len(extra) |
| if to_read > 0: |
| resp_buffer += _helpers._from_bytes( |
| s.recv(to_read, socket.MSG_WAITALL)) |
| if resp_buffer != devshell.CREDENTIAL_INFO_REQUEST_JSON: |
| self.bad_request = True |
| l = len(self.response) |
| s.sendall('{0}\n{1}'.format(l, self.response).encode()) |
| finally: |
| # Will fail if s is None, but these tests never encounter |
| # that scenario. |
| s.close() |
| |
| |
| class DevshellCredentialsTests(unittest2.TestCase): |
| |
| def test_signals_no_server(self): |
| with self.assertRaises(devshell.NoDevshellServer): |
| devshell.DevshellCredentials() |
| |
| def test_bad_message_to_mock_server(self): |
| request_content = devshell.CREDENTIAL_INFO_REQUEST_JSON + 'extrastuff' |
| request_message = _helpers._to_bytes( |
| '{0}\n{1}'.format(len(request_content), request_content)) |
| response_message = 'foobar' |
| with _AuthReferenceServer(response_message) as auth_server: |
| self.assertFalse(auth_server.bad_request) |
| sock = socket.socket() |
| port = int(os.getenv(devshell.DEVSHELL_ENV, 0)) |
| sock.connect(('localhost', port)) |
| sock.sendall(request_message) |
| |
| # Mimic the receive part of _SendRecv |
| header = sock.recv(6).decode() |
| len_str, result = header.split('\n', 1) |
| to_read = int(len_str) - len(result) |
| result += sock.recv(to_read, socket.MSG_WAITALL).decode() |
| |
| self.assertTrue(auth_server.bad_request) |
| self.assertEqual(result, response_message) |
| |
| def test_request_response(self): |
| with _AuthReferenceServer(): |
| response = devshell._SendRecv() |
| self.assertEqual(response.user_email, 'joe@example.com') |
| self.assertEqual(response.project_id, 'fooproj') |
| self.assertEqual(response.access_token, 'sometoken') |
| |
| def test_no_refresh_token(self): |
| with _AuthReferenceServer(): |
| creds = devshell.DevshellCredentials() |
| self.assertEquals(None, creds.refresh_token) |
| |
| @mock.patch('oauth2client.client._UTCNOW') |
| def test_reads_credentials(self, utcnow): |
| NOW = datetime.datetime(1992, 12, 31) |
| utcnow.return_value = NOW |
| with _AuthReferenceServer(): |
| creds = devshell.DevshellCredentials() |
| self.assertEqual('joe@example.com', creds.user_email) |
| self.assertEqual('fooproj', creds.project_id) |
| self.assertEqual('sometoken', creds.access_token) |
| self.assertEqual( |
| NOW + datetime.timedelta(seconds=EXPIRES_IN), |
| creds.token_expiry) |
| utcnow.assert_called_once_with() |
| |
| def test_handles_skipped_fields(self): |
| with _AuthReferenceServer('["joe@example.com"]'): |
| creds = devshell.DevshellCredentials() |
| self.assertEqual('joe@example.com', creds.user_email) |
| self.assertEqual(None, creds.project_id) |
| self.assertEqual(None, creds.access_token) |
| self.assertEqual(None, creds.token_expiry) |
| |
| def test_handles_tiny_response(self): |
| with _AuthReferenceServer('[]'): |
| creds = devshell.DevshellCredentials() |
| self.assertEqual(None, creds.user_email) |
| self.assertEqual(None, creds.project_id) |
| self.assertEqual(None, creds.access_token) |
| |
| def test_handles_ignores_extra_fields(self): |
| with _AuthReferenceServer( |
| '["joe@example.com", "fooproj", "sometoken", 1, "extra"]'): |
| creds = devshell.DevshellCredentials() |
| self.assertEqual('joe@example.com', creds.user_email) |
| self.assertEqual('fooproj', creds.project_id) |
| self.assertEqual('sometoken', creds.access_token) |
| |
| def test_refuses_to_save_to_well_known_file(self): |
| ORIGINAL_ISDIR = os.path.isdir |
| try: |
| os.path.isdir = lambda path: True |
| with _AuthReferenceServer(): |
| creds = devshell.DevshellCredentials() |
| with self.assertRaises(NotImplementedError): |
| client.save_to_well_known_file(creds) |
| finally: |
| os.path.isdir = ORIGINAL_ISDIR |
| |
| def test_from_json(self): |
| with self.assertRaises(NotImplementedError): |
| devshell.DevshellCredentials.from_json(None) |
| |
| def test_serialization_data(self): |
| with _AuthReferenceServer('[]'): |
| credentials = devshell.DevshellCredentials() |
| with self.assertRaises(NotImplementedError): |
| getattr(credentials, 'serialization_data') |