#!/usr/bin/env python
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Assorted utilities shared between parts of apitools."""
from __future__ import print_function
import collections
import contextlib
import json
import keyword
import logging
import os
import re
import six
from six.moves import urllib_parse
import six.moves.urllib.error as urllib_error
import six.moves.urllib.request as urllib_request
class Error(Exception):

    """Root exception type for errors raised during apitools generation."""
class CommunicationError(Error):

    """Signals a failure in network communication."""
def _SortLengthFirstKey(a):
return -len(a), a
class Names(object):

    """Utility class for cleaning and normalizing names in a fixed style.

    An instance carries three settings: a list of prefixes to strip from
    names, the convention used for field names, and whether enum names are
    upper-cased. The static helpers (CleanName, NormalizeRelativePath) do
    not depend on any instance state.
    """

    # Field-name convention used when the caller does not supply one.
    # NOTE(review): this attribute is referenced by __init__ but its
    # definition was not visible in the reviewed source; 'LOWER_CAMEL' is
    # assumed here -- confirm.
    DEFAULT_NAME_CONVENTION = 'LOWER_CAMEL'

    def __init__(self, strip_prefixes,
                 name_convention=None,
                 capitalize_enums=False):
        """Store the naming configuration.

        Args:
          strip_prefixes: iterable of string prefixes removed from the
              front of class and field names.
          name_convention: convention applied by FieldName; the values
              FieldName acts on are 'LOWER_CAMEL' and 'LOWER_WITH_UNDER'.
              A falsy value selects DEFAULT_NAME_CONVENTION.
          capitalize_enums: if true, NormalizeEnumName upper-cases names.
        """
        # NOTE(review): the defaults for name_convention/capitalize_enums
        # were not visible in the reviewed source -- confirm.
        # Longest prefixes first, so the most specific prefix wins in
        # __StripName.
        self.__strip_prefixes = sorted(strip_prefixes, key=_SortLengthFirstKey)
        self.__name_convention = (
            name_convention or self.DEFAULT_NAME_CONVENTION)
        self.__capitalize_enums = capitalize_enums

    @staticmethod
    def __FromCamel(name, separator='_'):
        """Convert camelCase to lower-case words joined by separator."""
        name = re.sub(r'([a-z0-9])([A-Z])', r'\1%s\2' % separator, name)
        return name.lower()

    @staticmethod
    def __ToCamel(name, separator='_'):
        """Convert separator-delimited words to UpperCamelCase."""
        # TODO(craigcitro): Consider what to do about leading or trailing
        # underscores (such as `_refValue` in discovery).
        return ''.join(s[0:1].upper() + s[1:] for s in name.split(separator))

    @staticmethod
    def __ToLowerCamel(name, separator='_'):
        """Convert separator-delimited words to lowerCamelCase."""
        name = Names.__ToCamel(name, separator=separator)
        # Slicing (rather than indexing) keeps the empty string safe.
        return name[:1].lower() + name[1:]

    def __StripName(self, name):
        """Strip strip_prefix entries from name."""
        if not name:
            return name
        for prefix in self.__strip_prefixes:
            if name.startswith(prefix):
                # Only the first (i.e. longest) matching prefix is removed.
                return name[len(prefix):]
        return name

    @staticmethod
    def CleanName(name):
        """Perform generic name cleaning.

        Replaces every character outside [_A-Za-z0-9] with '_', prefixes
        '_' when the name starts with a digit, appends '_' while the name
        is a Python keyword, and prefixes 'f' when the result starts with
        '__'. An empty name is returned unchanged.
        """
        name = re.sub('[^_A-Za-z0-9]', '_', name)
        # name[:1] instead of name[0] so that '' does not raise IndexError.
        if name[:1].isdigit():
            name = '_%s' % name
        while keyword.iskeyword(name):
            name = '%s_' % name
        # If we end up with __ as a prefix, we'll run afoul of python
        # field renaming, so we manually correct for it.
        if name.startswith('__'):
            name = 'f%s' % name
        return name

    @staticmethod
    def NormalizeRelativePath(path):
        """Normalize camelCase entries in path.

        Each '/'-separated component of the form '{identifier}' has its
        identifier passed through CleanName; every other component is kept
        verbatim.
        """
        normalized_components = []
        for component in path.split('/'):
            if re.match(r'{[A-Za-z0-9_]+}$', component):
                component = '{%s}' % Names.CleanName(component[1:-1])
            normalized_components.append(component)
        return '/'.join(normalized_components)

    def NormalizeEnumName(self, enum_name):
        """Clean enum_name, upper-casing it first if so configured."""
        if self.__capitalize_enums:
            enum_name = enum_name.upper()
        return self.CleanName(enum_name)

    def ClassName(self, name, separator='_'):
        """Generate a valid class name from name."""
        # TODO(craigcitro): Get rid of this case here and in MethodName.
        if name is None:
            return name
        # TODO(craigcitro): This is a hack to handle the case of specific
        # protorpc class names; clean this up.
        # NOTE(review): this tuple was truncated in the reviewed source.
        # The first two prefixes are original; the protorpclite prefix is
        # assumed from the apitools package layout -- confirm.
        if name.startswith(('protorpc.', 'message_types.',
                            'apitools.base.protorpclite.')):
            return name
        name = self.__StripName(name)
        name = self.__ToCamel(name, separator=separator)
        return self.CleanName(name)

    def MethodName(self, name, separator='_'):
        """Generate a valid method name from name."""
        if name is None:
            return None
        name = Names.__ToCamel(name, separator=separator)
        return Names.CleanName(name)

    def FieldName(self, name):
        """Generate a valid field name from name."""
        # TODO(craigcitro): We shouldn't need to strip this name, but some
        # of the service names here are excessive. Fix the API and then
        # remove this.
        name = self.__StripName(name)
        if self.__name_convention == 'LOWER_CAMEL':
            name = Names.__ToLowerCamel(name)
        elif self.__name_convention == 'LOWER_WITH_UNDER':
            name = Names.__FromCamel(name)
        # Any other convention leaves the (stripped) name as-is.
        return Names.CleanName(name)
@contextlib.contextmanager
def Chdir(dirname, create=True):
    """Context manager that runs its body with dirname as the working dir.

    Args:
      dirname: directory to switch into.
      create: if true, a missing dirname is created first; if false, a
          missing dirname is an error.

    Raises:
      OSError: dirname does not exist and create is false.

    The previous working directory is restored on exit, including when the
    body raises.
    """
    if not os.path.exists(dirname):
        if not create:
            raise OSError('Cannot find directory %s' % dirname)
        os.makedirs(dirname)
    previous_directory = os.getcwd()
    try:
        os.chdir(dirname)
        yield
    finally:
        os.chdir(previous_directory)
def NormalizeVersion(version):
    """Return version with every '.' replaced by '_'."""
    # Currently, '.' is the only character that might cause us trouble.
    pieces = version.split('.')
    return '_'.join(pieces)
def _ComputePaths(package, version, discovery_doc):
full_path = urllib_parse.urljoin(
discovery_doc['rootUrl'], discovery_doc['servicePath'])
api_path_component = '/'.join((package, version, ''))
if api_path_component not in full_path:
return full_path, ''
prefix, _, suffix = full_path.rpartition(api_path_component)
return prefix + api_path_component, suffix
class ClientInfo(collections.namedtuple('ClientInfo', (
        'package', 'scopes', 'version', 'client_id', 'client_secret',
        'user_agent', 'client_class_name', 'url_version', 'api_key',
        'base_url', 'base_path'))):

    """Container for client-related info and names.

    The tuple fields hold the raw values; the properties derive rule names
    and file names from `package` and `version`.
    """

    @classmethod
    def Create(cls, discovery_doc,
               scope_ls, client_id, client_secret, user_agent, names, api_key):
        """Create a new ClientInfo object from a discovery document.

        Args:
          discovery_doc: dict-like discovery document; 'name', 'version',
              'rootUrl' and 'servicePath' are read, 'auth' is optional.
          scope_ls: extra OAuth2 scopes merged with those declared in the
              discovery document (may be empty or None).
          client_id: stored as-is.
          client_secret: stored as-is.
          user_agent: stored as-is.
          names: object exposing ClassName(str), used to build the client
              class name.
          api_key: stored as-is.

        Returns:
          A ClientInfo instance.
        """
        scopes = set(
            discovery_doc.get('auth', {}).get('oauth2', {}).get('scopes', {}))
        scopes.update(scope_ls or ())
        package = discovery_doc['name']
        url_version = discovery_doc['version']
        base_url, base_path = _ComputePaths(package, url_version,
                                            discovery_doc)

        client_info = {
            'package': package,
            'version': NormalizeVersion(discovery_doc['version']),
            'url_version': url_version,
            'scopes': sorted(scopes),
            'client_id': client_id,
            'client_secret': client_secret,
            'user_agent': user_agent,
            'api_key': api_key,
            'base_url': base_url,
            'base_path': base_path,
        }
        # NOTE(review): the two format arguments were missing from the
        # reviewed source; package + normalized version run through
        # names.ClassName is assumed -- confirm.
        client_class_name = '%s%s' % (
            names.ClassName(client_info['package']),
            names.ClassName(client_info['version']))
        client_info['client_class_name'] = client_class_name
        return cls(**client_info)

    @property
    def default_directory(self):
        """Directory name for this client: the package name."""
        return self.package

    @property
    def cli_rule_name(self):
        """'<package>_<version>'."""
        return '%s_%s' % (self.package, self.version)

    # NOTE(review): the '.py' suffix of the three *_file_name properties
    # below was missing from the reviewed source (the format strings were
    # empty, which raises TypeError); '%s.py' is assumed by analogy with
    # the '%s.proto' properties -- confirm.
    @property
    def cli_file_name(self):
        """File name derived from cli_rule_name."""
        return '%s.py' % self.cli_rule_name

    @property
    def client_rule_name(self):
        """'<package>_<version>_client'."""
        return '%s_%s_client' % (self.package, self.version)

    @property
    def client_file_name(self):
        """File name derived from client_rule_name."""
        return '%s.py' % self.client_rule_name

    @property
    def messages_rule_name(self):
        """'<package>_<version>_messages'."""
        return '%s_%s_messages' % (self.package, self.version)

    @property
    def services_rule_name(self):
        """'<package>_<version>_services'."""
        return '%s_%s_services' % (self.package, self.version)

    @property
    def messages_file_name(self):
        """File name derived from messages_rule_name."""
        return '%s.py' % self.messages_rule_name

    @property
    def messages_proto_file_name(self):
        """'<messages_rule_name>.proto'."""
        return '%s.proto' % self.messages_rule_name

    @property
    def services_proto_file_name(self):
        """'<services_rule_name>.proto'."""
        return '%s.proto' % self.services_rule_name
def CleanDescription(description):
    """Make description safe to embed inside a triple-quoted docstring.

    String values get every triple double-quote broken up with spaces;
    anything that is not a string is handed back untouched.
    """
    if isinstance(description, six.string_types):
        return description.replace('"""', '" " "')
    return description
class SimplePrettyPrinter(object):

    """Simple pretty-printer that supports an indent contextmanager.

    Calling the instance writes one line to `out`, prefixed with the
    current indent. Indent() and CommentContext() are context managers
    that temporarily change the indent and the formatting mode.
    """

    def __init__(self, out):
        """Args: out: file-like object that lines are printed to."""
        self.__out = out
        self.__indent = ''
        self.__skip = False
        self.__comment_context = False

    @property
    def indent(self):
        """The current indent prefix."""
        return self.__indent

    def CalculateWidth(self, max_width=78):
        """Return the columns left on a max_width line after the indent."""
        return max_width - len(self.indent)

    # NOTE(review): the default indent is reproduced exactly as it appears
    # in the reviewed source; if whitespace was collapsed there, the
    # intended default may be wider -- confirm.
    @contextlib.contextmanager
    def Indent(self, indent=' '):
        """Append `indent` to the current indent for the enclosed block."""
        previous_indent = self.__indent
        self.__indent = '%s%s' % (previous_indent, indent)
        try:
            yield
        finally:
            # Restore even if the body raised, so later output is not
            # left permanently shifted.
            self.__indent = previous_indent

    @contextlib.contextmanager
    def CommentContext(self):
        """Print without any argument formatting."""
        old_context = self.__comment_context
        self.__comment_context = True
        try:
            yield
        finally:
            self.__comment_context = old_context

    def __call__(self, *args):
        """Print one line.

        With no arguments (or an empty first argument) an empty line is
        printed. Otherwise args[0] is %-formatted with args[1:], unless a
        CommentContext is active, in which case args[0] is printed
        verbatim. Trailing whitespace is stripped and non-ASCII characters
        are written as backslash escapes.

        Raises:
          Error: extra arguments were passed inside a CommentContext.
        """
        if self.__comment_context and args[1:]:
            raise Error('Cannot do string interpolation in comment context')
        if args and args[0]:
            if not self.__comment_context:
                line = (args[0] % args[1:]).rstrip()
            else:
                line = args[0].rstrip()
            # Round-trip through ASCII so the result is text again; a bare
            # encode() leaves bytes, which Python 3 would render as b'...'.
            line = line.encode('ascii', 'backslashreplace').decode('ascii')
            print('%s%s' % (self.__indent, line), file=self.__out)
        else:
            print('', file=self.__out)
def _NormalizeDiscoveryUrls(discovery_url):
"""Expands a few abbreviations into full discovery urls."""
if discovery_url.startswith('http'):
return [discovery_url]
elif '.' not in discovery_url:
raise ValueError('Unrecognized value "%s" for discovery url')
api_name, _, api_version = discovery_url.partition('.')
return [
'' % (
api_name, api_version),
'$discovery/rest?version=%s' % (
api_name, api_version),
def FetchDiscoveryDoc(discovery_url, retries=5):
    """Fetch the discovery document at the given url.

    Args:
      discovery_url: full URL or abbreviation accepted by
          _NormalizeDiscoveryUrls.
      retries: number of attempts made per candidate URL.

    Returns:
      The discovery document, parsed from JSON.

    Raises:
      CommunicationError: no candidate URL yielded a document.
    """
    discovery_urls = _NormalizeDiscoveryUrls(discovery_url)
    discovery_doc = None
    last_exception = None
    for url in discovery_urls:
        for _ in range(retries):
            try:
                # closing() releases the connection even if read() or the
                # JSON parse fails.
                with contextlib.closing(
                        urllib_request.urlopen(url)) as response:
                    discovery_doc = json.loads(response.read())
                break
            except (urllib_error.HTTPError, urllib_error.URLError) as e:
                logging.info(
                    'Attempting to fetch discovery doc again after "%s"', e)
                last_exception = e
        if discovery_doc is not None:
            # We already have a document; do not hit the remaining
            # candidate URLs.
            break
    if discovery_doc is None:
        raise CommunicationError(
            'Could not find discovery doc at any of %s: %s' % (
                discovery_urls, last_exception))
    return discovery_doc