blob: d467e9bd10656fe1e86fbe8948fdccc369790c80 [file] [log] [blame]
#!/usr/bin/env python
"""Message registry for apitools."""
import collections
import contextlib
import json
from protorpc import descriptor
from protorpc import messages
import six
from apitools.gen import extended_descriptor
from apitools.gen import util
# Resolved type of a generated field: the type name emitted into generated
# code, paired with a protorpc `messages.Variant` value.
TypeInfo = collections.namedtuple('TypeInfo', ['type_name', 'variant'])
class MessageRegistry(object):

    """Registry for message types.

    This closely mirrors a messages.FileDescriptor, but adds additional
    attributes (such as message and field descriptions) and some extra
    code for validation and cycle detection.
    """

    # Type information from these two maps comes from here:
    #  https://developers.google.com/discovery/v1/type-format
    PRIMITIVE_TYPE_INFO_MAP = {
        'string': TypeInfo(type_name='string',
                           variant=messages.StringField.DEFAULT_VARIANT),
        'integer': TypeInfo(type_name='integer',
                            variant=messages.IntegerField.DEFAULT_VARIANT),
        'boolean': TypeInfo(type_name='boolean',
                            variant=messages.BooleanField.DEFAULT_VARIANT),
        'number': TypeInfo(type_name='number',
                           variant=messages.FloatField.DEFAULT_VARIANT),
        'any': TypeInfo(type_name='extra_types.JsonValue',
                        variant=messages.Variant.MESSAGE),
    }

    PRIMITIVE_FORMAT_MAP = {
        'int32': TypeInfo(type_name='integer',
                          variant=messages.Variant.INT32),
        'uint32': TypeInfo(type_name='integer',
                           variant=messages.Variant.UINT32),
        'int64': TypeInfo(type_name='string',
                          variant=messages.Variant.INT64),
        'uint64': TypeInfo(type_name='string',
                           variant=messages.Variant.UINT64),
        'double': TypeInfo(type_name='number',
                           variant=messages.Variant.DOUBLE),
        'float': TypeInfo(type_name='number',
                          variant=messages.Variant.FLOAT),
        'byte': TypeInfo(type_name='byte',
                         variant=messages.BytesField.DEFAULT_VARIANT),
        'date': TypeInfo(type_name='extra_types.DateField',
                         variant=messages.Variant.STRING),
        'date-time': TypeInfo(
            type_name='protorpc.message_types.DateTimeMessage',
            variant=messages.Variant.MESSAGE),
    }

    def __init__(self, client_info, names, description,
                 root_package_dir, base_files_package):
        """Create an empty registry.

        Args:
          client_info: Object exposing `package` (the generated file's
              package) and `version` attributes.
          names: Naming helper; this class calls its `ClassName`,
              `CleanName` and `NormalizeEnumName` methods.
          description: Description of the generated messages file; it is
              passed through util.CleanDescription.
          root_package_dir: Stored on the registry; not used by the
              methods of this class.
          base_files_package: Package from which the generated code
              imports the `encoding` and `extra_types` helper modules.
        """
        self.__names = names
        self.__client_info = client_info
        self.__package = client_info.package
        self.__description = util.CleanDescription(description)
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
            package=self.__package, description=self.__description)
        # Add required imports
        self.__file_descriptor.additional_imports = [
            'from protorpc import messages as _messages',
        ]
        # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
        self.__message_registry = collections.OrderedDict()
        # A set of types that we're currently adding (for cycle detection).
        self.__nascent_types = set()
        # A set of types for which we've seen a reference but no
        # definition; if this set is nonempty, validation fails.
        self.__unknown_types = set()
        # Used for tracking paths during message creation
        self.__current_path = []
        # Where to register created messages
        self.__current_env = self.__file_descriptor
        # TODO(craigcitro): Add a `Finalize` method.

    @property
    def file_descriptor(self):
        """The underlying ExtendedFileDescriptor, after Validate() passes.

        Raises:
          ValueError: if the registry is not in a consistent state.
        """
        self.Validate()
        return self.__file_descriptor

    def WriteProtoFile(self, printer):
        """Write the messages file to out as proto."""
        self.Validate()
        extended_descriptor.WriteMessagesFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def WriteFile(self, printer):
        """Write the messages file to out."""
        self.Validate()
        extended_descriptor.WritePythonFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def Validate(self):
        """Check that every declared or referenced type has been defined.

        Raises:
          ValueError: if some type is still being constructed or has been
              referenced but never registered. The message lists all such
              types, not just those from one of the two categories.
        """
        mysteries = self.__nascent_types | self.__unknown_types
        if mysteries:
            raise ValueError('Malformed MessageRegistry: %s' % mysteries)

    def __ComputeFullName(self, name):
        """Return name qualified by the current nesting path (e.g. Foo.Bar)."""
        return '.'.join(map(six.text_type, self.__current_path + [name]))

    def __AddImport(self, new_import):
        """Append new_import to the generated file's imports, once."""
        if new_import not in self.__file_descriptor.additional_imports:
            self.__file_descriptor.additional_imports.append(new_import)

    def __DeclareDescriptor(self, name):
        """Mark name (in the current scope) as being under construction."""
        self.__nascent_types.add(self.__ComputeFullName(name))

    def __RegisterDescriptor(self, new_descriptor):
        """Register the given descriptor in this registry.

        The descriptor must previously have been declared with
        __DeclareDescriptor in the same scope. It is appended to the current
        environment's message_types or enum_types list.

        Raises:
          ValueError: if the descriptor has an unsupported type, was
              already registered, or was never declared.
        """
        if not isinstance(new_descriptor, (
                extended_descriptor.ExtendedMessageDescriptor,
                extended_descriptor.ExtendedEnumDescriptor)):
            raise ValueError('Cannot add descriptor of type %s' % (
                type(new_descriptor),))
        full_name = self.__ComputeFullName(new_descriptor.name)
        if full_name in self.__message_registry:
            raise ValueError(
                'Attempt to re-register descriptor %s' % full_name)
        if full_name not in self.__nascent_types:
            raise ValueError('Directly adding types is not supported')
        new_descriptor.full_name = full_name
        self.__message_registry[full_name] = new_descriptor
        if isinstance(new_descriptor,
                      extended_descriptor.ExtendedMessageDescriptor):
            self.__current_env.message_types.append(new_descriptor)
        elif isinstance(new_descriptor,
                        extended_descriptor.ExtendedEnumDescriptor):
            self.__current_env.enum_types.append(new_descriptor)
        self.__unknown_types.discard(full_name)
        self.__nascent_types.remove(full_name)

    def LookupDescriptor(self, name):
        """Return the descriptor registered under full name, or None.

        Raises:
          ValueError: if name is currently being constructed.
        """
        return self.__GetDescriptorByName(name)

    def LookupDescriptorOrDie(self, name):
        """Like LookupDescriptor, but raise ValueError if name is absent."""
        message_descriptor = self.LookupDescriptor(name)
        if message_descriptor is None:
            raise ValueError('No message descriptor named "%s"' % name)
        return message_descriptor

    def __GetDescriptor(self, name):
        """Look up name relative to the current nesting path."""
        return self.__GetDescriptorByName(self.__ComputeFullName(name))

    def __GetDescriptorByName(self, name):
        """Look up a fully-qualified name; None if it is not registered."""
        if name in self.__message_registry:
            return self.__message_registry[name]
        if name in self.__nascent_types:
            raise ValueError(
                'Cannot retrieve type currently being created: %s' % name)
        return None

    @contextlib.contextmanager
    def __DescriptorEnv(self, message_descriptor):
        """Temporarily nest registration inside message_descriptor.

        While active, new names are qualified with message_descriptor.name
        and new descriptors are appended to message_descriptor. The previous
        path and environment are restored even if the body raises, so a
        failure cannot leave the registry pointing at a nested scope.
        """
        # TODO(craigcitro): Typecheck?
        previous_env = self.__current_env
        self.__current_path.append(message_descriptor.name)
        self.__current_env = message_descriptor
        try:
            yield
        finally:
            self.__current_path.pop()
            self.__current_env = previous_env

    def AddEnumDescriptor(self, name, description,
                          enum_values, enum_descriptions):
        """Add a new EnumDescriptor named name with the given enum values.

        Args:
          name: Enum name; passed through names.ClassName.
          description: Enum description.
          enum_values: JSON names of the enum values, in order; each value
              is numbered by its position.
          enum_descriptions: Descriptions matched to enum_values by
              position. Missing entries (a shorter or None list) get
              '<no description>' rather than dropping the value.
        """
        message = extended_descriptor.ExtendedEnumDescriptor()
        message.name = self.__names.ClassName(name)
        message.description = util.CleanDescription(description)
        self.__DeclareDescriptor(message.name)
        # zip() would silently truncate enum_values to the length of the
        # description list; pad instead so no value is lost.
        descriptions = list(enum_descriptions or [])
        for index, enum_name in enumerate(enum_values):
            if index < len(descriptions):
                enum_description = descriptions[index]
            else:
                enum_description = None
            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
            enum_value.name = self.__names.NormalizeEnumName(enum_name)
            if enum_value.name != enum_name:
                # The Python name differs from the wire name, so the
                # generated code needs an explicit mapping (via encoding).
                message.enum_mappings.append(
                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
                        python_name=enum_value.name, json_name=enum_name))
                self.__AddImport('from %s import encoding' %
                                 self.__base_files_package)
            enum_value.number = index
            enum_value.description = util.CleanDescription(
                enum_description or '<no description>')
            message.values.append(enum_value)
        self.__RegisterDescriptor(message)

    def __DeclareMessageAlias(self, schema, alias_for):
        """Declare schema as an alias for alias_for."""
        # TODO(craigcitro): This is a hack. Remove it.
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.alias_for = alias_for
        self.__DeclareDescriptor(message.name)
        self.__AddImport('from %s import extra_types' %
                         self.__base_files_package)
        self.__RegisterDescriptor(message)

    def __AddAdditionalProperties(self, message, schema, properties):
        """Add an additionalProperties field to message.

        Creates a nested AdditionalProperty entry type, adds a repeated
        field of that type numbered after the declared properties, and
        decorates the message with encoding.MapUnrecognizedFields.
        """
        additional_properties_info = schema['additionalProperties']
        entries_type_name = self.__AddAdditionalPropertyType(
            message.name, additional_properties_info)
        description = util.CleanDescription(
            additional_properties_info.get('description'))
        if description is None:
            description = 'Additional properties of type %s' % message.name
        attrs = {
            'items': {
                '$ref': entries_type_name,
            },
            'description': description,
            'type': 'array',
        }
        field_name = 'additionalProperties'
        message.fields.append(self.__FieldDescriptorFromProperties(
            field_name, len(properties) + 1, attrs))
        self.__AddImport('from %s import encoding' % self.__base_files_package)
        message.decorators.append(
            'encoding.MapUnrecognizedFields(%r)' % field_name)

    def AddDescriptorFromSchema(self, schema_name, schema):
        """Add a new MessageDescriptor named schema_name based on schema.

        Enum schemas become EnumDescriptors and 'any' schemas become
        aliases for extra_types.JsonValue. Does nothing if schema_name is
        already registered in the current scope.

        Raises:
          ValueError: if the schema type is not 'object' (and the schema is
              neither an enum nor 'any').
        """
        # TODO(craigcitro): Is schema_name redundant?
        if self.__GetDescriptor(schema_name):
            return
        if schema.get('enum'):
            self.__DeclareEnum(schema_name, schema)
            return
        if schema.get('type') == 'any':
            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
            return
        if schema.get('type') != 'object':
            raise ValueError('Cannot create message descriptors for type %s' %
                             schema.get('type'))
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.description = util.CleanDescription(schema.get(
            'description', 'A %s object.' % message.name))
        self.__DeclareDescriptor(message.name)
        with self.__DescriptorEnv(message):
            properties = schema.get('properties', {})
            # Field numbers follow the sorted order of property names.
            for index, (name, attrs) in enumerate(sorted(properties.items())):
                field = self.__FieldDescriptorFromProperties(
                    name, index + 1, attrs)
                message.fields.append(field)
                if field.name != name:
                    message.field_mappings.append(
                        type(message).JsonFieldMapping(
                            python_name=field.name, json_name=name))
                    self.__AddImport(
                        'from %s import encoding' % self.__base_files_package)
            if 'additionalProperties' in schema:
                self.__AddAdditionalProperties(message, schema, properties)
        # Registered outside the env: it was declared in the enclosing scope.
        self.__RegisterDescriptor(message)

    def __AddAdditionalPropertyType(self, name, property_schema):
        """Add a new nested AdditionalProperty message."""
        new_type_name = 'AdditionalProperty'
        property_schema = dict(property_schema)
        # We drop the description here on purpose, so the resulting
        # messages are less repetitive.
        property_schema.pop('description', None)
        description = 'An additional property for a %s object.' % name
        schema = {
            'id': new_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'key': {
                    'type': 'string',
                    'description': 'Name of the additional property.',
                },
                'value': property_schema,
            },
        }
        self.AddDescriptorFromSchema(new_type_name, schema)
        return new_type_name

    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
        """Add a type for a list entry.

        entry_schema is copied before its description is dropped, so the
        caller's schema dict is not modified. A missing (None) entry schema
        is treated as empty, which later fails with a descriptive
        'Array type with no item type' ValueError.
        """
        entry_schema = dict(entry_schema or {})
        entry_schema.pop('description', None)
        description = 'Single entry in a %s.' % parent_name
        schema = {
            'id': entry_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'entry': {
                    'type': 'array',
                    'items': entry_schema,
                },
            },
        }
        self.AddDescriptorFromSchema(entry_type_name, schema)
        return entry_type_name

    def __FieldDescriptorFromProperties(self, name, index, attrs):
        """Create a field descriptor for these attrs.

        Args:
          name: JSON property name; the field name is names.CleanName(name).
          index: Field number to assign.
          attrs: The property's schema dict.

        Returns:
          An ExtendedFieldDescriptor wrapping the new FieldDescriptor.
        """
        field = descriptor.FieldDescriptor()
        field.name = self.__names.CleanName(name)
        field.number = index
        field.label = self.__ComputeLabel(attrs)
        new_type_name_hint = self.__names.ClassName(
            '%sValue' % self.__names.ClassName(name))
        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
        field.type_name = type_info.type_name
        field.variant = type_info.variant
        if 'default' in attrs:
            # TODO(craigcitro): Correctly handle non-primitive default values.
            default = attrs['default']
            if not (field.type_name == 'string' or
                    field.variant == messages.Variant.ENUM):
                default = str(json.loads(default))
            if field.variant == messages.Variant.ENUM:
                default = self.__names.NormalizeEnumName(default)
            field.default_value = default
        extended_field = extended_descriptor.ExtendedFieldDescriptor()
        extended_field.name = field.name
        extended_field.description = util.CleanDescription(
            attrs.get('description', 'A %s attribute.' % field.type_name))
        extended_field.field_descriptor = field
        return extended_field

    @staticmethod
    def __ComputeLabel(attrs):
        """Return REQUIRED, REPEATED or OPTIONAL for a property schema."""
        if attrs.get('required', False):
            return descriptor.FieldDescriptor.Label.REQUIRED
        elif attrs.get('type') == 'array':
            return descriptor.FieldDescriptor.Label.REPEATED
        elif attrs.get('repeated'):
            return descriptor.FieldDescriptor.Label.REPEATED
        return descriptor.FieldDescriptor.Label.OPTIONAL

    def __DeclareEnum(self, enum_name, attrs):
        """Register an enum from a schema with an 'enum' key.

        Returns:
          A TypeInfo with the ENUM variant.
        """
        description = util.CleanDescription(attrs.get('description', ''))
        enum_values = attrs['enum']
        enum_descriptions = attrs.get(
            'enumDescriptions', [''] * len(enum_values))
        self.AddEnumDescriptor(enum_name, description,
                               enum_values, enum_descriptions)
        self.__AddIfUnknown(enum_name)
        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)

    def __AddIfUnknown(self, type_name):
        """Record type_name as unknown unless it is already registered."""
        type_name = self.__names.ClassName(type_name)
        full_type_name = self.__ComputeFullName(type_name)
        if (full_type_name not in self.__message_registry and
                type_name not in self.__message_registry):
            self.__unknown_types.add(type_name)

    def __AddTypeImports(self, type_info):
        """Add any imports generated code needs to reference type_info.

        Returns:
          type_info unchanged, so callers can return the result directly.
        """
        type_name = type_info.type_name
        if (type_name.startswith('protorpc.message_types.') or
                type_name.startswith('message_types.')):
            self.__AddImport(
                'from protorpc import message_types as _message_types')
        if type_name.startswith('extra_types.'):
            self.__AddImport(
                'from %s import extra_types' % self.__base_files_package)
        return type_info

    def __GetTypeInfo(self, attrs, name_hint):
        """Return a TypeInfo object for attrs, creating one if needed.

        Args:
          attrs: A property schema with a '$ref' or a 'type' key.
          name_hint: Class name to use if a new enum, entry or nested
              message type has to be created.

        Raises:
          ValueError: if the type (or type/format pair) is missing or not
              supported.
        """
        type_ref = self.__names.ClassName(attrs.get('$ref'))
        type_name = attrs.get('type')
        if not (type_ref or type_name):
            raise ValueError('No type found for %s' % attrs)

        if type_ref:
            self.__AddIfUnknown(type_ref)
            # We don't actually know this is a message -- it might be an
            # enum. However, we can't check that until we've created all the
            # types, so we come back and fix this up later.
            return TypeInfo(
                type_name=type_ref, variant=messages.Variant.MESSAGE)

        if 'enum' in attrs:
            enum_name = '%sValuesEnum' % name_hint
            return self.__DeclareEnum(enum_name, attrs)

        if 'format' in attrs:
            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
            if type_info is None:
                # If we don't recognize the format, the spec says we fall back
                # to just using the type name.
                if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
                    return self.__AddTypeImports(
                        self.PRIMITIVE_TYPE_INFO_MAP[type_name])
                raise ValueError('Unknown type/format "%s"/"%s"' % (
                    attrs['format'], type_name))
            return self.__AddTypeImports(type_info)

        if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
            # This includes 'any' (extra_types.JsonValue), which needs the
            # extra_types import; a separate 'any' branch after this check
            # would be unreachable.
            return self.__AddTypeImports(
                self.PRIMITIVE_TYPE_INFO_MAP[type_name])

        if type_name == 'array':
            items = attrs.get('items')
            if not items:
                raise ValueError('Array type with no item type: %s' % attrs)
            entry_name_hint = self.__names.ClassName(
                items.get('title') or '%sListEntry' % name_hint)
            entry_label = self.__ComputeLabel(items)
            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
                # An array of arrays: wrap each inner list in an entry
                # message so it can be a repeated message field.
                parent_name = self.__names.ClassName(
                    items.get('title') or name_hint)
                entry_type_name = self.__AddEntryType(
                    entry_name_hint, items.get('items'), parent_name)
                return TypeInfo(type_name=entry_type_name,
                                variant=messages.Variant.MESSAGE)
            return self.__GetTypeInfo(items, entry_name_hint)

        if type_name == 'object':
            # TODO(craigcitro): Think of a better way to come up with names.
            if not name_hint:
                raise ValueError(
                    'Cannot create subtype without some name hint')
            schema = dict(attrs)
            schema['id'] = name_hint
            self.AddDescriptorFromSchema(name_hint, schema)
            self.__AddIfUnknown(name_hint)
            return TypeInfo(
                type_name=name_hint, variant=messages.Variant.MESSAGE)

        raise ValueError('Unknown type: %s' % type_name)

    def FixupMessageFields(self):
        """Retag MESSAGE-variant fields whose type is an enum as ENUM.

        Accesses file_descriptor, so Validate() runs first.
        """
        for message_type in self.file_descriptor.message_types:
            self._FixupMessage(message_type)

    def _FixupMessage(self, message_type):
        """Apply the enum fixup to message_type and its nested messages."""
        with self.__DescriptorEnv(message_type):
            for field in message_type.fields:
                if field.field_descriptor.variant == messages.Variant.MESSAGE:
                    field_type_name = field.field_descriptor.type_name
                    field_type = self.LookupDescriptor(field_type_name)
                    if isinstance(field_type,
                                  extended_descriptor.ExtendedEnumDescriptor):
                        field.field_descriptor.variant = messages.Variant.ENUM
            for submessage_type in message_type.message_types:
                self._FixupMessage(submessage_type)