blob: b00085c4b2367a8b17a012705790218d8a516063 [file] [log] [blame]
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import datetime
import sys
import contextlib
import six
from six.moves import http_client
from six.moves import urllib_parse
import unittest2
from apitools.base.protorpclite import message_types
from apitools.base.protorpclite import messages
from apitools.base.py import base_api
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import http_wrapper
@contextlib.contextmanager
def mock(module, fn_name, patch):
    """Swap ``module.<fn_name>`` for ``patch`` for the duration of a with-block.

    The previous value is captured before patching and is put back when the
    with-block exits, including when its body raises.

    Args:
        module: Object (typically a module) whose attribute is replaced.
        fn_name: Name of the attribute to replace.
        patch: Replacement value installed while the context is active.
    """
    previous = getattr(module, fn_name)
    setattr(module, fn_name, patch)
    try:
        yield
    finally:
        # Restore unconditionally so one failing test cannot leak its fake
        # into later tests.
        setattr(module, fn_name, previous)
class SimpleMessage(messages.Message):
    """Test message with one string field and one bytes field."""
    field = messages.StringField(1)
    # testQueryBytesRequest puts non-UTF-8 bytes here and expects them in the
    # URL as urlsafe base64.
    bytes_field = messages.BytesField(2)
class MessageWithTime(messages.Message):
    """Test message holding a single datetime field."""
    # testQueryEncoding expects this in the URL as the quoted isoformat() text.
    timestamp = message_types.DateTimeField(1)
class MessageWithRemappings(messages.Message):
    """Test message whose field and enum value are given custom JSON names.

    The custom names are registered with ``encoding`` by the module-level
    calls that follow this class definition.
    """

    class AnEnum(messages.Enum):
        """Two-valued enum; value_one receives a custom JSON name."""
        value_one = 1
        value_two = 2

    str_field = messages.StringField(1)
    # The enum type is referenced by name; presumably resolved against the
    # nested AnEnum above -- confirm against protorpclite's lookup rules.
    enum_field = messages.EnumField('AnEnum', 2)
# Import-time registration of the custom JSON names used by the remapping
# tests: Python field ``str_field`` is exposed as ``remapped_field`` and enum
# value ``value_one`` as ``ONE/TWO``.  The slash in the enum name lets
# testPathRemapping check both escaped ({name}) and reserved ({+name}) path
# expansion.
encoding.AddCustomJsonFieldMapping(
    MessageWithRemappings, 'str_field', 'remapped_field')
encoding.AddCustomJsonEnumMapping(
    MessageWithRemappings.AnEnum, 'value_one', 'ONE/TWO')
class StandardQueryParameters(messages.Message):
    """Global query parameters for the fake API.

    Presumably picked up by name from the client's MESSAGES_MODULE as its
    global-parameter type (TODO confirm in base_api); the tests also pass
    instances explicitly via ``global_params=``.  Field numbers 2-4 are not
    used.
    """
    field = messages.StringField(1)
    # testPrettyPrintEncoding expects prettyPrint/pp to be omitted from the URL
    # while they hold their True default, and sent as 0 when set to False.
    prettyPrint = messages.BooleanField(
        5, default=True)  # pylint: disable=invalid-name
    pp = messages.BooleanField(6, default=True)
    # Bytes-valued global parameter; testQueryBytesGlobalParams expects it to
    # be urlsafe-base64 encoded in the query string.
    nextPageToken = messages.BytesField(7)  # pylint:disable=invalid-name
class FakeCredentials(object):
    """Credentials stub for tests; authorize() hands back nothing."""

    def authorize(self, _):  # pylint: disable=invalid-name
        """Ignore the supplied http object and return None."""
class FakeClient(base_api.BaseApiClient):
    """Minimal BaseApiClient subclass used throughout these tests.

    MESSAGES_MODULE points at this test module itself.  The tests name their
    message types as strings (e.g. ``request_type_name='SimpleMessage'``),
    which are presumably looked up in that module.
    """

    _CLIENT_SECRET = 'client_secret'
    _CLIENT_ID = 'client_id'
    _SCOPES = ['scope1']
    _PACKAGE = 'package'
    MESSAGES_MODULE = sys.modules[__name__]
class FakeService(base_api.BaseApiService):
    """BaseApiService wired to a FakeClient for tests."""

    def __init__(self, client=None):
        # When no (truthy) client is supplied, build one pointing at
        # example.com with stub credentials.
        if not client:
            client = FakeClient('http://www.example.com/',
                                credentials=FakeCredentials())
        super(FakeService, self).__init__(client)
class BaseApiTest(unittest2.TestCase):
    """Tests for base_api.BaseApiClient/BaseApiService using local fakes."""

    def __GetFakeClient(self):
        """Return a FakeClient with an empty base URL and stub credentials."""
        return FakeClient('', credentials=FakeCredentials())

    def testUrlNormalization(self):
        client = FakeClient('http://www.googleapis.com', get_credentials=False)
        # The client is expected to add a trailing slash to its base URL.
        self.assertTrue(client.url.endswith('/'))

    def testNoCredentials(self):
        client = FakeClient('', get_credentials=False)
        self.assertIsNotNone(client)
        self.assertIsNone(client._credentials)

    def testIncludeEmptyFieldsClient(self):
        msg = SimpleMessage()
        client = self.__GetFakeClient()
        self.assertEqual('{}', client.SerializeMessage(msg))
        # Inside IncludeFields the named field is serialized even when unset.
        with client.IncludeFields(('field',)):
            self.assertEqual('{"field": null}', client.SerializeMessage(msg))

    def testJsonResponse(self):
        method_config = base_api.ApiMethodInfo(
            response_type_name='SimpleMessage')
        service = FakeService()
        http_response = http_wrapper.Response(
            info={'status': '200'}, content='{"field": "abc"}',
            request_url='http://www.google.com')
        response_message = SimpleMessage(field='abc')
        self.assertEqual(response_message, service.ProcessHttpResponse(
            method_config, http_response))
        # Under JsonResponseModel the raw content is returned undecoded.
        with service.client.JsonResponseModel():
            self.assertEqual(
                http_response.content,
                service.ProcessHttpResponse(method_config, http_response))

    def testJsonResponseEncoding(self):
        # On Python 3, httplib2 always returns bytes, so we need to check that
        # we can correctly decode the message content using the given encoding.
        method_config = base_api.ApiMethodInfo(
            response_type_name='SimpleMessage')
        service = FakeService(FakeClient(
            'http://www.example.com/', credentials=FakeCredentials(),
            response_encoding='utf8'))
        http_response = http_wrapper.Response(
            info={'status': '200'}, content=b'{"field": "abc"}',
            request_url='http://www.google.com')
        response_message = SimpleMessage(field=u'abc')
        self.assertEqual(response_message, service.ProcessHttpResponse(
            method_config, http_response))
        with service.client.JsonResponseModel():
            self.assertEqual(
                http_response.content.decode('utf8'),
                service.ProcessHttpResponse(method_config, http_response))

    def testAdditionalHeaders(self):
        additional_headers = {'Request-Is-Awesome': '1'}
        client = self.__GetFakeClient()

        # No headers to start
        http_request = http_wrapper.Request('http://www.example.com')
        new_request = client.ProcessHttpRequest(http_request)
        self.assertNotIn('Request-Is-Awesome', new_request.headers)

        # Add a new header and ensure it's added to the request.
        client.additional_http_headers = additional_headers
        http_request = http_wrapper.Request('http://www.example.com')
        new_request = client.ProcessHttpRequest(http_request)
        self.assertIn('Request-Is-Awesome', new_request.headers)

    def testCustomCheckResponse(self):
        def check_response():
            pass

        # Records each invocation so the test cannot pass vacuously if
        # MakeRequest is never reached.
        calls = []

        def fakeMakeRequest(*_, **kwargs):
            calls.append(kwargs)
            self.assertEqual(check_response, kwargs['check_response_func'])
            return http_wrapper.Response(
                info={'status': '200'}, content='{"field": "abc"}',
                request_url='http://www.google.com')

        method_config = base_api.ApiMethodInfo(
            request_type_name='SimpleMessage',
            response_type_name='SimpleMessage')
        client = self.__GetFakeClient()
        client.check_response_func = check_response
        service = FakeService(client=client)
        request = SimpleMessage()
        with mock(base_api.http_wrapper, 'MakeRequest', fakeMakeRequest):
            service._RunMethod(method_config, request)
        self.assertTrue(calls, 'MakeRequest was never called')

    def testCustomRetryFunc(self):
        def retry_func():
            pass

        # Records each invocation so the test cannot pass vacuously if
        # MakeRequest is never reached.
        calls = []

        def fakeMakeRequest(*_, **kwargs):
            calls.append(kwargs)
            self.assertEqual(retry_func, kwargs['retry_func'])
            return http_wrapper.Response(
                info={'status': '200'}, content='{"field": "abc"}',
                request_url='http://www.google.com')

        method_config = base_api.ApiMethodInfo(
            request_type_name='SimpleMessage',
            response_type_name='SimpleMessage')
        client = self.__GetFakeClient()
        client.retry_func = retry_func
        service = FakeService(client=client)
        request = SimpleMessage()
        with mock(base_api.http_wrapper, 'MakeRequest', fakeMakeRequest):
            service._RunMethod(method_config, request)
        self.assertTrue(calls, 'MakeRequest was never called')

    def testHttpError(self):
        def fakeMakeRequest(*unused_args, **unused_kwargs):
            return http_wrapper.Response(
                info={'status': http_client.BAD_REQUEST},
                content='{"field": "abc"}',
                request_url='http://www.google.com')

        method_config = base_api.ApiMethodInfo(
            request_type_name='SimpleMessage',
            response_type_name='SimpleMessage')
        client = self.__GetFakeClient()
        service = FakeService(client=client)
        request = SimpleMessage()
        with mock(base_api.http_wrapper, 'MakeRequest', fakeMakeRequest):
            with self.assertRaises(exceptions.HttpBadRequestError) as err:
                service._RunMethod(method_config, request)
        http_error = err.exception
        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual('http://www.google.com', http_error.url)
        self.assertEqual('{"field": "abc"}', http_error.content)
        self.assertEqual(method_config, http_error.method_config)
        self.assertEqual(request, http_error.request)

    def testQueryEncoding(self):
        method_config = base_api.ApiMethodInfo(
            request_type_name='MessageWithTime', query_params=['timestamp'])
        service = FakeService()
        request = MessageWithTime(
            timestamp=datetime.datetime(2014, 10, 7, 12, 53, 13))
        http_request = service.PrepareHttpRequest(method_config, request)

        url_timestamp = urllib_parse.quote(request.timestamp.isoformat())
        self.assertTrue(http_request.url.endswith(url_timestamp))

    def testPrettyPrintEncoding(self):
        method_config = base_api.ApiMethodInfo(
            request_type_name='MessageWithTime', query_params=['timestamp'])
        service = FakeService()
        request = MessageWithTime(
            timestamp=datetime.datetime(2014, 10, 7, 12, 53, 13))

        # At their True defaults, prettyPrint/pp are left out of the URL.
        global_params = StandardQueryParameters()
        http_request = service.PrepareHttpRequest(method_config, request,
                                                  global_params=global_params)
        self.assertNotIn('prettyPrint', http_request.url)
        self.assertNotIn('pp', http_request.url)

        # When explicitly disabled they are sent as 0.
        global_params.prettyPrint = False  # pylint: disable=invalid-name
        global_params.pp = False
        http_request = service.PrepareHttpRequest(method_config, request,
                                                  global_params=global_params)
        self.assertIn('prettyPrint=0', http_request.url)
        self.assertIn('pp=0', http_request.url)

    def testQueryBytesRequest(self):
        method_config = base_api.ApiMethodInfo(
            request_type_name='SimpleMessage', query_params=['bytes_field'])
        service = FakeService()
        non_unicode_message = b''.join((six.int2byte(100),
                                        six.int2byte(200)))
        request = SimpleMessage(bytes_field=non_unicode_message)
        global_params = StandardQueryParameters()
        http_request = service.PrepareHttpRequest(method_config, request,
                                                  global_params=global_params)
        want = urllib_parse.urlencode({
            'bytes_field': base64.urlsafe_b64encode(non_unicode_message),
        })
        self.assertIn(want, http_request.url)

    def testQueryBytesGlobalParams(self):
        method_config = base_api.ApiMethodInfo(
            request_type_name='SimpleMessage', query_params=['bytes_field'])
        service = FakeService()
        non_unicode_message = b''.join((six.int2byte(100),
                                        six.int2byte(200)))
        request = SimpleMessage()
        global_params = StandardQueryParameters(
            nextPageToken=non_unicode_message)
        http_request = service.PrepareHttpRequest(method_config, request,
                                                  global_params=global_params)
        want = urllib_parse.urlencode({
            'nextPageToken': base64.urlsafe_b64encode(non_unicode_message),
        })
        self.assertIn(want, http_request.url)

    def testQueryRemapping(self):
        method_config = base_api.ApiMethodInfo(
            request_type_name='MessageWithRemappings',
            query_params=['remapped_field', 'enum_field'])
        request = MessageWithRemappings(
            str_field='foo', enum_field=MessageWithRemappings.AnEnum.value_one)
        http_request = FakeService().PrepareHttpRequest(method_config, request)
        result_params = urllib_parse.parse_qs(
            urllib_parse.urlparse(http_request.url).query)
        # parse_qs percent-decodes values and wraps each in a list, so the
        # remapped enum value (sent as ONE%2FTWO) comes back as 'ONE/TWO'.
        # The previous assertTrue(expected, actual) always passed.
        expected_params = {'enum_field': ['ONE/TWO'],
                           'remapped_field': ['foo']}
        self.assertEqual(expected_params, result_params)

    def testPathRemapping(self):
        method_config = base_api.ApiMethodInfo(
            relative_path='parameters/{remapped_field}/remap/{enum_field}',
            request_type_name='MessageWithRemappings',
            path_params=['remapped_field', 'enum_field'])
        request = MessageWithRemappings(
            str_field='gonna',
            enum_field=MessageWithRemappings.AnEnum.value_one)
        service = FakeService()

        # Simple expansion escapes the slash in the remapped enum value.
        expected_url = service.client.url + 'parameters/gonna/remap/ONE%2FTWO'
        http_request = service.PrepareHttpRequest(method_config, request)
        self.assertEqual(expected_url, http_request.url)

        # Reserved ({+...}) expansion leaves the slash intact.
        method_config.relative_path = (
            'parameters/{+remapped_field}/remap/{+enum_field}')
        expected_url = service.client.url + 'parameters/gonna/remap/ONE/TWO'
        http_request = service.PrepareHttpRequest(method_config, request)
        self.assertEqual(expected_url, http_request.url)

    def testColonInRelativePath(self):
        method_config = base_api.ApiMethodInfo(
            relative_path='path:withJustColon',
            request_type_name='SimpleMessage')
        service = FakeService()
        request = SimpleMessage()
        http_request = service.PrepareHttpRequest(method_config, request)
        self.assertEqual('http://www.example.com/path:withJustColon',
                         http_request.url)
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest2.main()