blob: 97bbb6f69b1830c25b946fb18b1954dd23eb03f1 [file] [log] [blame]
#!/usr/bin/env python
"""Base class for api services."""
import base64
import contextlib
import datetime
import logging
import pprint
from protorpc import message_types
from protorpc import messages
import six
from six.moves import http_client
from six.moves import urllib
from apitools.base.py import credentials_lib
from apitools.base.py import encoding
from apitools.base.py import exceptions
from apitools.base.py import http_wrapper
from apitools.base.py import util
# Public names exported by a star-import of this module.
__all__ = [
'ApiMethodInfo',
'ApiUploadInfo',
'BaseApiClient',
'BaseApiService',
'NormalizeApiEndpoint',
]
# TODO(craigcitro): Remove this once we quiet the spurious logging in
# oauth2client (or drop oauth2client).
# Only messages at ERROR level or above from 'oauth2client.util' are emitted.
logging.getLogger('oauth2client.util').setLevel(logging.ERROR)
# URL length threshold (as measured by len()). BaseApiService's
# __FinalizeRequest rewrites a GET whose url is longer than this into a POST
# carrying an 'x-http-method-override: GET' header.
_MAX_URL_LENGTH = 2048
class ApiUploadInfo(messages.Message):
    """Describes the media-upload capabilities of a single API method.

    Fields:
      accept: (repeated) MIME media ranges this method accepts for
          uploaded media.
      max_size: (integer) Largest permitted media upload, such as 3MB
          or 1TB, already converted to an integer.
      resumable_path: Path used for resumable uploads.
      resumable_multipart: (boolean) Whether the resumable endpoint
          accepts multipart uploads.
      simple_path: Path used for simple uploads.
      simple_multipart: (boolean) Whether the simple endpoint accepts
          multipart uploads.
    """

    # Field numbers are part of the message definition; keep them stable.
    accept = messages.StringField(1, repeated=True)
    max_size = messages.IntegerField(2)
    resumable_path = messages.StringField(3)
    resumable_multipart = messages.BooleanField(4)
    simple_path = messages.StringField(5)
    simple_multipart = messages.BooleanField(6)
class ApiMethodInfo(messages.Message):
    """Static configuration describing one API method.

    Every field is a string unless stated otherwise.

    Fields:
      relative_path: Path of this method, relative to the API endpoint.
      method_id: Identifier of this method.
      http_method: HTTP verb used to invoke this method.
      path_params: (repeated) Names of the path parameters.
      query_params: (repeated) Names of the query parameters.
      ordered_params: (repeated) Parameters of this method, in order.
      description: Human-readable description of this method.
      request_type_name: Name of the request message type.
      response_type_name: Name of the response message type.
      request_field: When non-empty, the request field sent as the
          request body. May also be the module-level REQUEST_IS_BODY
          value, meaning the whole request message is the body.
          Defaults to the empty string.
      upload_config: (ApiUploadInfo) Upload configuration supported by
          this method.
      supports_download: (boolean) If True, this method supports
          downloading via the `alt=media` query parameter. Defaults
          to False.
    """

    # Field numbers are part of the message definition; keep them stable.
    relative_path = messages.StringField(1)
    method_id = messages.StringField(2)
    http_method = messages.StringField(3)
    path_params = messages.StringField(4, repeated=True)
    query_params = messages.StringField(5, repeated=True)
    ordered_params = messages.StringField(6, repeated=True)
    description = messages.StringField(7)
    request_type_name = messages.StringField(8)
    response_type_name = messages.StringField(9)
    request_field = messages.StringField(10, default='')
    upload_config = messages.MessageField(ApiUploadInfo, 11)
    supports_download = messages.BooleanField(12, default=False)
# Special value for ApiMethodInfo.request_field: when a method's
# request_field equals this, BaseApiService's __SetBody serializes the whole
# request message as the body instead of a single field of it.
REQUEST_IS_BODY = '<request>'
def _LoadClass(name, messages_module):
if name.startswith('message_types.'):
_, _, classname = name.partition('.')
return getattr(message_types, classname)
elif '.' not in name:
return getattr(messages_module, name)
else:
raise exceptions.GeneratedClientError('Unknown class %s' % name)
def _RequireClassAttrs(obj, attrs):
for attr in attrs:
attr_name = attr.upper()
if not hasattr(obj, '%s' % attr_name) or not getattr(obj, attr_name):
msg = 'No %s specified for object of class %s.' % (
attr_name, type(obj).__name__)
raise exceptions.GeneratedClientError(msg)
def NormalizeApiEndpoint(api_endpoint):
    """Return api_endpoint with a trailing '/' appended if it lacks one."""
    if api_endpoint.endswith('/'):
        return api_endpoint
    return api_endpoint + '/'
def _urljoin(base, url): # pylint: disable=invalid-name
"""Custom urljoin replacement supporting : before / in url."""
# In general, it's unsafe to simply join base and url. However, for
# the case of discovery documents, we know:
# * base will never contain params, query, or fragment
# * url will never contain a scheme or net_loc.
# In general, this means we can safely join on /; we just need to
# ensure we end up with precisely one / joining base and url. The
# exception here is the case of media uploads, where url will be an
# absolute url.
if url.startswith('http://') or url.startswith('https://'):
return urllib.parse.urljoin(base, url)
new_base = base if base.endswith('/') else base + '/'
new_url = url[1:] if url.startswith('/') else url
return new_base + new_url
class _UrlBuilder(object):
"""Convenient container for url data."""
def __init__(self, base_url, relative_path=None, query_params=None):
components = urllib.parse.urlsplit(_urljoin(
base_url, relative_path or ''))
if components.fragment:
raise exceptions.ConfigurationValueError(
'Unexpected url fragment: %s' % components.fragment)
self.query_params = urllib.parse.parse_qs(components.query or '')
if query_params is not None:
self.query_params.update(query_params)
self.__scheme = components.scheme
self.__netloc = components.netloc
self.relative_path = components.path or ''
@classmethod
def FromUrl(cls, url):
urlparts = urllib.parse.urlsplit(url)
query_params = urllib.parse.parse_qs(urlparts.query)
base_url = urllib.parse.urlunsplit((
urlparts.scheme, urlparts.netloc, '', None, None))
relative_path = urlparts.path or ''
return cls(
base_url, relative_path=relative_path, query_params=query_params)
@property
def base_url(self):
return urllib.parse.urlunsplit(
(self.__scheme, self.__netloc, '', '', ''))
@base_url.setter
def base_url(self, value):
components = urllib.parse.urlsplit(value)
if components.path or components.query or components.fragment:
raise exceptions.ConfigurationValueError(
'Invalid base url: %s' % value)
self.__scheme = components.scheme
self.__netloc = components.netloc
@property
def query(self):
# TODO(craigcitro): In the case that some of the query params are
# non-ASCII, we may silently fail to encode correctly. We should
# figure out who is responsible for owning the object -> str
# conversion.
return urllib.parse.urlencode(self.query_params, doseq=True)
@property
def url(self):
if '{' in self.relative_path or '}' in self.relative_path:
raise exceptions.ConfigurationValueError(
'Cannot create url with relative path %s' % self.relative_path)
return urllib.parse.urlunsplit((
self.__scheme, self.__netloc, self.relative_path, self.query, ''))
def _SkipGetCredentials():
"""Hook for skipping credentials. For internal use."""
return False
class BaseApiClient(object):
    """Base class for client libraries.

    The constructor requires _PACKAGE, _SCOPES and MESSAGES_MODULE to be
    set to truthy values on the instance's class (it raises
    GeneratedClientError via _RequireClassAttrs otherwise), so concrete
    clients are expected to override them.
    """

    MESSAGES_MODULE = None

    _API_KEY = ''
    _CLIENT_ID = ''
    _CLIENT_SECRET = ''
    _PACKAGE = ''
    _SCOPES = []
    _USER_AGENT = ''

    def __init__(self, url, credentials=None, get_credentials=True, http=None,
                 model=None, log_request=False, log_response=False,
                 num_retries=5, max_retry_wait=60, credentials_args=None,
                 default_global_params=None, additional_http_headers=None):
        """Initialize the client.

        Args:
          url: API endpoint; stored with a trailing '/' ensured.
          credentials: Optional credentials object. If given (and
              truthy), no credentials are fetched.
          get_credentials: If True (and _SkipGetCredentials() is False)
              and no credentials were given, fetch them with
              _SetCredentials.
          http: Optional http object; defaults to http_wrapper.GetHttp().
          model: Unused; kept for backward compatibility.
          log_request: If True, log requests in the Process* hooks.
          log_response: If True, log responses in ProcessResponse.
          num_retries: Non-negative integer, validated by the property
              setter.
          max_retry_wait: Positive integer, validated by the property
              setter.
          credentials_args: Optional dict of extra keyword arguments for
              _SetCredentials.
          default_global_params: Optional instance of params_type used
              as the default global parameters.
          additional_http_headers: Optional dict of headers merged into
              every http request by ProcessHttpRequest.
        """
        _RequireClassAttrs(self, ('_package', '_scopes', 'messages_module'))
        if default_global_params is not None:
            util.Typecheck(default_global_params, self.params_type)
        self.__default_global_params = default_global_params
        self.log_request = log_request
        self.log_response = log_response
        self.__num_retries = 5
        self.__max_retry_wait = 60
        # We let the @property machinery below do our validation.
        self.num_retries = num_retries
        self.max_retry_wait = max_retry_wait
        self._credentials = credentials
        get_credentials = get_credentials and not _SkipGetCredentials()
        if get_credentials and not credentials:
            credentials_args = credentials_args or {}
            self._SetCredentials(**credentials_args)
        self._url = NormalizeApiEndpoint(url)
        self._http = http or http_wrapper.GetHttp()
        # Note that "no credentials" is totally possible.
        if self._credentials is not None:
            self._http = self._credentials.authorize(self._http)
        # TODO(craigcitro): Remove this field when we switch to proto2.
        self.__include_fields = None
        self.additional_http_headers = additional_http_headers or {}
        # TODO(craigcitro): Finish deprecating these fields.
        _ = model
        self.__response_type_model = 'proto'

    def _SetCredentials(self, **kwds):
        """Fetch credentials, and set them for this client.

        Note that we can't simply return credentials, since creating them
        may involve side-effecting self.

        Args:
          **kwds: Additional keyword arguments are passed on to
              GetCredentials; they override the class-level defaults.

        Returns:
          None. Sets self._credentials.
        """
        args = {
            'api_key': self._API_KEY,
            'client': self,
            'client_id': self._CLIENT_ID,
            'client_secret': self._CLIENT_SECRET,
            'package_name': self._PACKAGE,
            'scopes': self._SCOPES,
            'user_agent': self._USER_AGENT,
        }
        args.update(kwds)
        # TODO(craigcitro): It's a bit dangerous to pass this
        # still-half-initialized self into this method, but we might need
        # to set attributes on it associated with our credentials.
        # Consider another way around this (maybe a callback?) and whether
        # or not it's worth it.
        self._credentials = credentials_lib.GetCredentials(**args)

    @classmethod
    def ClientInfo(cls):
        """Return a dict with client id/secret, scope string and user agent.

        The 'scope' value is the normalized scopes, sorted and joined
        with single spaces.
        """
        return {
            'client_id': cls._CLIENT_ID,
            'client_secret': cls._CLIENT_SECRET,
            'scope': ' '.join(sorted(util.NormalizeScopes(cls._SCOPES))),
            'user_agent': cls._USER_AGENT,
        }

    @property
    def base_model_class(self):
        """Always None in this base class."""
        return None

    @property
    def http(self):
        """The http object used for requests (authorized if credentialed)."""
        return self._http

    @property
    def url(self):
        """The normalized API endpoint (always ends with '/')."""
        return self._url

    @classmethod
    def GetScopes(cls):
        """Return the class-level _SCOPES value."""
        return cls._SCOPES

    @property
    def params_type(self):
        """The StandardQueryParameters class from MESSAGES_MODULE."""
        return _LoadClass('StandardQueryParameters', self.MESSAGES_MODULE)

    @property
    def user_agent(self):
        """The class-level _USER_AGENT value."""
        return self._USER_AGENT

    @property
    def _default_global_params(self):
        """The default global params, created lazily on first access."""
        if self.__default_global_params is None:
            self.__default_global_params = self.params_type()
        return self.__default_global_params

    def AddGlobalParam(self, name, value):
        """Set attribute name to value on the default global params."""
        params = self._default_global_params
        setattr(params, name, value)

    @property
    def global_params(self):
        """A copy of the default global params (a fresh copy per access)."""
        return encoding.CopyProtoMessage(self._default_global_params)

    @contextlib.contextmanager
    def IncludeFields(self, include_fields):
        """Within this context, pass include_fields to SerializeMessage.

        The setting is reset to None on exit, including when the body of
        the `with` block raises.
        """
        self.__include_fields = include_fields
        try:
            yield
        finally:
            # Reset even on error so the setting never leaks past the block.
            self.__include_fields = None

    @property
    def response_type_model(self):
        """Current response model: 'proto' by default, or 'json'."""
        return self.__response_type_model

    @contextlib.contextmanager
    def JsonResponseModel(self):
        """In this context, return raw JSON instead of proto.

        The previous model is restored on exit, including when the body
        of the `with` block raises.
        """
        old_model = self.response_type_model
        self.__response_type_model = 'json'
        try:
            yield
        finally:
            self.__response_type_model = old_model

    @property
    def num_retries(self):
        """Number of retries passed to http_wrapper.MakeRequest."""
        return self.__num_retries

    @num_retries.setter
    def num_retries(self, value):
        """Set num_retries; must be a non-negative integer."""
        util.Typecheck(value, six.integer_types)
        if value < 0:
            raise exceptions.InvalidDataError(
                'Cannot have negative value for num_retries')
        self.__num_retries = value

    @property
    def max_retry_wait(self):
        """Maximum retry wait passed to http_wrapper.MakeRequest."""
        return self.__max_retry_wait

    @max_retry_wait.setter
    def max_retry_wait(self, value):
        """Set max_retry_wait; must be a strictly positive integer."""
        util.Typecheck(value, six.integer_types)
        if value <= 0:
            raise exceptions.InvalidDataError(
                'max_retry_wait must be a positive integer')
        self.__max_retry_wait = value

    @contextlib.contextmanager
    def WithRetries(self, num_retries):
        """Within this context, use num_retries as the retry count.

        The previous value is restored on exit, including when the body
        of the `with` block raises.
        """
        old_num_retries = self.num_retries
        self.num_retries = num_retries
        try:
            yield
        finally:
            self.num_retries = old_num_retries

    def ProcessRequest(self, method_config, request):
        """Hook for pre-processing of requests.

        Logs the call when log_request is set; returns request unchanged.
        """
        if self.log_request:
            logging.info(
                'Calling method %s with %s: %s', method_config.method_id,
                method_config.request_type_name, request)
        return request

    def ProcessHttpRequest(self, http_request):
        """Hook for pre-processing of http requests.

        Merges additional_http_headers into the request headers and, when
        log_request is set, logs the method, url, headers and body.
        Returns the same http_request object.
        """
        http_request.headers.update(self.additional_http_headers)
        if self.log_request:
            logging.info('Making http %s to %s',
                         http_request.http_method, http_request.url)
            logging.info('Headers: %s', pprint.pformat(http_request.headers))
            if http_request.body:
                # TODO(craigcitro): Make this safe to print in the case of
                # non-printable body characters.
                logging.info('Body:\n%s',
                             http_request.loggable_body or http_request.body)
            else:
                logging.info('Body: (none)')
        return http_request

    def ProcessResponse(self, method_config, response):
        """Hook for post-processing of responses.

        Logs the response when log_response is set; returns it unchanged.
        """
        if self.log_response:
            logging.info('Response of type %s: %s',
                         method_config.response_type_name, response)
        return response

    # TODO(craigcitro): Decide where these two functions should live.
    def SerializeMessage(self, message):
        """Serialize message to JSON, honoring any active IncludeFields."""
        return encoding.MessageToJson(
            message, include_fields=self.__include_fields)

    def DeserializeMessage(self, response_type, data):
        """Deserialize the given data as an instance of response_type.

        Raises:
          exceptions.InvalidDataFromServerError: if decoding fails with
              InvalidDataFromServerError or messages.ValidationError; the
              message includes the data, the type name and the cause.
        """
        try:
            message = encoding.JsonToMessage(response_type, data)
        except (exceptions.InvalidDataFromServerError,
                messages.ValidationError) as e:
            raise exceptions.InvalidDataFromServerError(
                'Error decoding response "%s" as type %s: %s' % (
                    data, response_type.__name__, e))
        return message

    def FinalizeTransferUrl(self, url):
        """Modify the url for a given transfer, based on auth and version.

        If the global params carry a truthy `key`, it is set as the 'key'
        query parameter of the returned url.
        """
        url_builder = _UrlBuilder.FromUrl(url)
        # global_params copies the message on every access; read it once.
        api_key = self.global_params.key
        if api_key:
            url_builder.query_params['key'] = api_key
        return url_builder.url
class BaseApiService(object):
    """Base class for generated API services.

    Holds a reference to the owning client plus two dicts,
    _method_configs and _upload_configs, keyed by method name.
    """

    def __init__(self, client):
        self.__client = client
        self._method_configs = {}
        self._upload_configs = {}

    @property
    def _client(self):
        """The client this service was created with."""
        return self.__client

    @property
    def client(self):
        """The client this service was created with."""
        return self.__client

    def GetMethodConfig(self, method):
        """Return the config stored for method (KeyError if absent)."""
        return self._method_configs[method]

    def GetUploadConfig(self, method):
        """Return the upload config stored for method, or None."""
        return self._upload_configs.get(method)

    def GetRequestType(self, method):
        """Return the request class named by method's config."""
        type_name = self.GetMethodConfig(method).request_type_name
        return getattr(self.client.MESSAGES_MODULE, type_name)

    def GetResponseType(self, method):
        """Return the response class named by method's config."""
        type_name = self.GetMethodConfig(method).response_type_name
        return getattr(self.client.MESSAGES_MODULE, type_name)

    def __CombineGlobalParams(self, global_params, default_params):
        """Merge global_params over default_params into a new message.

        For each field, the value assigned on global_params is used; if
        it is None, the value assigned on default_params is used. Values
        equal to None, [] or () are left unset on the result.
        """
        params_type = self.__client.params_type
        util.Typecheck(global_params, (type(None), params_type))
        combined = self.__client.params_type()
        overrides = global_params or self.__client.params_type()
        for field in combined.all_fields():
            chosen = overrides.get_assigned_value(field.name)
            if chosen is None:
                chosen = default_params.get_assigned_value(field.name)
            if chosen in (None, [], ()):
                continue
            setattr(combined, field.name, chosen)
        return combined

    def __EncodePrettyPrint(self, query_info):
        """Apply the custom encoding of the pretty-print flags in place.

        prettyPrint should be encoded as 0 if False, and dropped otherwise
        (True is the default). pp is the One Platform equivalent and gets
        the same treatment. Returns the same dict.
        """
        for flag in ('prettyPrint', 'pp'):
            if not query_info.pop(flag, True):
                query_info[flag] = 0
        return query_info

    def __FinalUrlValue(self, value, field):
        """Encode value for the URL, using field to skip encoding for bytes."""
        if isinstance(field, messages.BytesField) and value is not None:
            return base64.urlsafe_b64encode(value)
        if isinstance(value, six.text_type):
            return value.encode('utf8')
        if isinstance(value, six.binary_type):
            return value.decode('utf8')
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        return value

    def __ConstructQueryParams(self, query_params, request, global_params):
        """Construct a dictionary of query parameters for this request."""
        client = self.__client
        # Global params come first ...
        merged_globals = self.__CombineGlobalParams(
            global_params, client.global_params)
        global_names = util.MapParamNames(
            [field.name for field in client.params_type.all_fields()],
            client.params_type)
        globals_type = type(merged_globals)
        query_info = {
            name: self.__FinalUrlValue(getattr(merged_globals, name),
                                       getattr(globals_type, name))
            for name in global_names
        }
        # ... then the per-request query params, which may override them.
        request_names = util.MapParamNames(query_params, type(request))
        request_type = type(request)
        query_info.update(
            (name,
             self.__FinalUrlValue(getattr(request, name, None),
                                  getattr(request_type, name)))
            for name in request_names)
        # Params without a value are not sent at all.
        query_info = {
            key: value for key, value in query_info.items()
            if value is not None
        }
        query_info = self.__EncodePrettyPrint(query_info)
        return util.MapRequestParams(query_info, type(request))

    def __ConstructRelativePath(self, method_config, request,
                                relative_path=None):
        """Determine the relative path for request."""
        request_type = type(request)
        path_names = util.MapParamNames(
            method_config.path_params, request_type)
        path_values = {
            name: getattr(request, name, None) for name in path_names
        }
        path_values = util.MapRequestParams(path_values, type(request))
        return util.ExpandRelativePath(
            method_config, path_values, relative_path=relative_path)

    def __FinalizeRequest(self, http_request, url_builder):
        """Make any final general adjustments to the request.

        A GET whose current url exceeds _MAX_URL_LENGTH is turned into a
        form-encoded POST with an x-http-method-override header, moving
        the query string into the body. The request url is then set from
        url_builder.
        """
        # NOTE(review): this measures http_request.url as it is *before*
        # the assignment at the end of this method; PrepareHttpRequest does
        # not set the url itself. Confirm it is populated earlier (e.g. by
        # an upload/download ConfigureRequest) when this check matters.
        oversized_get = (
            http_request.http_method == 'GET' and
            len(http_request.url) > _MAX_URL_LENGTH)
        if oversized_get:
            http_request.http_method = 'POST'
            http_request.headers['x-http-method-override'] = 'GET'
            http_request.headers[
                'content-type'] = 'application/x-www-form-urlencoded'
            http_request.body = url_builder.query
            url_builder.query_params = {}
        http_request.url = url_builder.url

    def __ProcessHttpResponse(self, method_config, http_response):
        """Process the given http response.

        Raises HttpError for any status other than OK / NO_CONTENT.
        Returns the raw content when the client's response model is
        'json', and the deserialized message otherwise.
        """
        status = http_response.status_code
        if status not in (http_client.OK, http_client.NO_CONTENT):
            raise exceptions.HttpError.FromResponse(http_response)
        if http_response.status_code == http_client.NO_CONTENT:
            # TODO(craigcitro): Find out why _replace doesn't seem to work
            # here.
            # Substitute an empty JSON object as the content.
            http_response = http_wrapper.Response(
                info=http_response.info, content='{}',
                request_url=http_response.request_url)
        if self.__client.response_type_model == 'json':
            return http_response.content
        response_type = _LoadClass(method_config.response_type_name,
                                   self.__client.MESSAGES_MODULE)
        return self.__client.DeserializeMessage(
            response_type, http_response.content)

    def __SetBaseHeaders(self, http_request, client):
        """Fill in the basic headers on http_request."""
        # TODO(craigcitro): Make the default a little better here, and
        # include the apitools version.
        headers = http_request.headers
        headers['user-agent'] = client.user_agent or 'apitools-client/1.0'
        headers['accept'] = 'application/json'
        headers['accept-encoding'] = 'gzip, deflate'

    def __SetBody(self, http_request, method_config, request, upload):
        """Fill in the body on http_request.

        Does nothing when the method has no request_field. Otherwise the
        body is either the whole request (REQUEST_IS_BODY) or the named
        field of it, serialized as JSON.
        """
        body_field_name = method_config.request_field
        if not body_field_name:
            return

        request_type = _LoadClass(
            method_config.request_type_name, self.__client.MESSAGES_MODULE)
        if body_field_name == REQUEST_IS_BODY:
            body_value, body_type = request, request_type
        else:
            body_value = getattr(request, body_field_name)
            body_field = request_type.field_by_name(body_field_name)
            util.Typecheck(body_field, messages.MessageField)
            body_type = body_field.type

        # If there was no body provided, we use an empty message of the
        # appropriate type.
        if not body_value:
            body_value = body_type()
        if upload and not body_value:
            # We're going to fill in the body later.
            return
        util.Typecheck(body_value, body_type)
        http_request.headers['content-type'] = 'application/json'
        http_request.body = self.__client.SerializeMessage(body_value)

    def PrepareHttpRequest(self, method_config, request, global_params=None,
                           upload=None, upload_config=None, download=None):
        """Prepares an HTTP request to be sent.

        Type-checks request, runs the client's ProcessRequest hook, then
        builds headers, body, query params and path, and finally runs
        the client's ProcessHttpRequest hook on the result.
        """
        client = self.__client
        expected_type = _LoadClass(
            method_config.request_type_name, client.MESSAGES_MODULE)
        util.Typecheck(request, expected_type)
        request = client.ProcessRequest(method_config, request)

        http_request = http_wrapper.Request(
            http_method=method_config.http_method)
        self.__SetBaseHeaders(http_request, client)
        self.__SetBody(http_request, method_config, request, upload)

        url_builder = _UrlBuilder(
            client.url, relative_path=method_config.relative_path)
        url_builder.query_params = self.__ConstructQueryParams(
            method_config.query_params, request, global_params)

        # It's important that upload and download go before we fill in the
        # relative path, so that they can replace it.
        if upload is not None:
            upload.ConfigureRequest(upload_config, http_request, url_builder)
        if download is not None:
            download.ConfigureRequest(http_request, url_builder)

        url_builder.relative_path = self.__ConstructRelativePath(
            method_config, request, relative_path=url_builder.relative_path)
        self.__FinalizeRequest(http_request, url_builder)

        return client.ProcessHttpRequest(http_request)

    def _RunMethod(self, method_config, request, global_params=None,
                   upload=None, upload_config=None, download=None):
        """Call this method with request.

        Returns None for downloads (after initializing the download);
        otherwise returns the processed response.

        Raises:
          exceptions.NotYetImplementedError: if both upload and download
              are given.
        """
        if upload is not None and download is not None:
            # TODO(craigcitro): This just involves refactoring the logic
            # below into callbacks that we can pass around; in particular,
            # the order should be that the upload gets the initial request,
            # and then passes its reply to a download if one exists, and
            # then that goes to ProcessResponse and is returned.
            raise exceptions.NotYetImplementedError(
                'Cannot yet use both upload and download at once')

        http_request = self.PrepareHttpRequest(
            method_config, request, global_params, upload, upload_config,
            download)

        # TODO(craigcitro): Make num_retries customizable on Transfer
        # objects, and pass in self.__client.num_retries when initializing
        # an upload or download.
        if download is not None:
            download.InitializeDownload(http_request, client=self.client)
            return

        http_response = None
        if upload is not None:
            http_response = upload.InitializeUpload(
                http_request, client=self.client)
        if http_response is None:
            # No response came back from the upload (or there is no upload),
            # so issue the request ourselves.
            transport = self.__client.http
            if upload and upload.bytes_http:
                transport = upload.bytes_http
            http_response = http_wrapper.MakeRequest(
                transport, http_request,
                retries=self.__client.num_retries,
                max_retry_wait=self.__client.max_retry_wait)

        return self.ProcessHttpResponse(method_config, http_response)

    def ProcessHttpResponse(self, method_config, http_response):
        """Convert an HTTP response to the expected message type."""
        processed = self.__ProcessHttpResponse(method_config, http_response)
        return self.__client.ProcessResponse(method_config, processed)