#!/usr/bin/env python
"""Message registry for apitools."""

import collections
import contextlib
import json

from protorpc import descriptor
from protorpc import messages
import six

from apitools.gen import extended_descriptor
from apitools.gen import util

# Lightweight (type_name, variant) pair describing the type of a field.
TypeInfo = collections.namedtuple('TypeInfo', 'type_name variant')


class MessageRegistry(object):

    """Registry for message types.

    This closely mirrors a messages.FileDescriptor, but adds additional
    attributes (such as message and field descriptions) and some extra
    code for validation and cycle detection.
    """

    # Type information from these two maps comes from here:
    #  https://developers.google.com/discovery/v1/type-format
    #
    # Maps a schema `type` value to the TypeInfo used for fields of that
    # type when the schema carries no recognized `format`.
    PRIMITIVE_TYPE_INFO_MAP = {
        'string': TypeInfo(type_name='string',
                           variant=messages.StringField.DEFAULT_VARIANT),
        'integer': TypeInfo(type_name='integer',
                            variant=messages.IntegerField.DEFAULT_VARIANT),
        'boolean': TypeInfo(type_name='boolean',
                            variant=messages.BooleanField.DEFAULT_VARIANT),
        'number': TypeInfo(type_name='number',
                           variant=messages.FloatField.DEFAULT_VARIANT),
        'any': TypeInfo(type_name='extra_types.JsonValue',
                        variant=messages.Variant.MESSAGE),
    }

    # Maps a schema `format` value to a more specific TypeInfo. This map is
    # consulted before PRIMITIVE_TYPE_INFO_MAP whenever a schema has a
    # `format` key. Note that the 64-bit integer formats keep the 'string'
    # type name while carrying an integer variant.
    # NOTE(review): presumably because those values travel as JSON strings;
    # confirm against the type-format page linked above.
    PRIMITIVE_FORMAT_MAP = {
        'int32': TypeInfo(type_name='integer',
                          variant=messages.Variant.INT32),
        'uint32': TypeInfo(type_name='integer',
                           variant=messages.Variant.UINT32),
        'int64': TypeInfo(type_name='string',
                          variant=messages.Variant.INT64),
        'uint64': TypeInfo(type_name='string',
                           variant=messages.Variant.UINT64),
        'double': TypeInfo(type_name='number',
                           variant=messages.Variant.DOUBLE),
        'float': TypeInfo(type_name='number',
                          variant=messages.Variant.FLOAT),
        'byte': TypeInfo(type_name='byte',
                         variant=messages.BytesField.DEFAULT_VARIANT),
        'date': TypeInfo(type_name='extra_types.DateField',
                         variant=messages.Variant.STRING),
        'date-time': TypeInfo(
            type_name='protorpc.message_types.DateTimeMessage',
            variant=messages.Variant.MESSAGE),
    }

    def __init__(self, client_info, names, description,
                 root_package_dir, base_files_package):
        """Set up an empty registry for the API described by client_info.

        Args:
          client_info: Client metadata; its `package` attribute is read
              here and its `version` attribute when files are written.
          names: Naming helper used to derive class, field and enum names.
          description: Description of the generated messages file.
          root_package_dir: Stored on the instance; not otherwise used in
              the code visible in this class.
          base_files_package: Package that generated imports of
              `encoding` / `extra_types` are taken from.
        """
        # Plain configuration captured from the caller.
        self.__client_info = client_info
        self.__names = names
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__package = client_info.package
        self.__description = util.CleanDescription(description)

        # The file descriptor that top-level types are attached to; it
        # starts out with the one import every generated file needs.
        file_descriptor = extended_descriptor.ExtendedFileDescriptor(
            package=self.__package, description=self.__description)
        file_descriptor.additional_imports = [
            'from protorpc import messages as _messages',
        ]
        self.__file_descriptor = file_descriptor

        # Scoped name (i.e. Foo.Bar) -> registered descriptor, kept in
        # registration order.
        self.__message_registry = collections.OrderedDict()
        # Names declared but not yet registered (for cycle detection).
        self.__nascent_types = set()
        # Names referenced but never defined; validation fails while this
        # set is nonempty.
        self.__unknown_types = set()
        # Stack of enclosing message names during message creation.
        self.__current_path = []
        # Descriptor that newly registered types get attached to.
        self.__current_env = file_descriptor
        # TODO(craigcitro): Add a `Finalize` method.

    @property
    def file_descriptor(self):
        """File descriptor holding every registered type.

        Raises:
          ValueError: if Validate() rejects the registry.
        """
        self.Validate()
        validated = self.__file_descriptor
        return validated

    def WriteProtoFile(self, printer):
        """Emit the registered messages to printer in proto form.

        Raises:
          ValueError: if the registry does not validate.
        """
        self.Validate()
        api_version = self.__client_info.version
        extended_descriptor.WriteMessagesFile(
            self.__file_descriptor, self.__package, api_version, printer)

    def WriteFile(self, printer):
        """Emit the registered messages to printer as a Python module.

        Raises:
          ValueError: if the registry does not validate.
        """
        self.Validate()
        api_version = self.__client_info.version
        extended_descriptor.WritePythonFile(
            self.__file_descriptor, self.__package, api_version, printer)

    def Validate(self):
        """Check that every declared or referenced type has been defined.

        Raises:
          ValueError: if any type is still being created (declared but not
              yet registered) or was referenced without being defined.
        """
        # Report both kinds of problem together and in a stable order,
        # rather than only whichever set happens to be non-empty first.
        mysteries = sorted(self.__nascent_types | self.__unknown_types)
        if mysteries:
            raise ValueError('Malformed MessageRegistry: %s' % (mysteries,))

    def __ComputeFullName(self, name):
        """Return name qualified by the current nesting path, dot-joined."""
        components = list(self.__current_path)
        components.append(name)
        return '.'.join(six.text_type(part) for part in components)

    def __AddImport(self, new_import):
        """Record new_import on the file descriptor unless already there."""
        known_imports = self.__file_descriptor.additional_imports
        if new_import in known_imports:
            return
        known_imports.append(new_import)

    def __DeclareDescriptor(self, name):
        """Mark name (qualified by the current path) as being created."""
        qualified_name = self.__ComputeFullName(name)
        self.__nascent_types.add(qualified_name)

    def __RegisterDescriptor(self, new_descriptor):
        """Move a previously declared descriptor into this registry.

        The descriptor is stored under its fully qualified name and is
        attached to the environment currently being built (the file
        descriptor or an enclosing message).

        Raises:
          ValueError: if new_descriptor is neither a message nor an enum
              descriptor, is already registered, or was never declared.
        """
        message_cls = extended_descriptor.ExtendedMessageDescriptor
        enum_cls = extended_descriptor.ExtendedEnumDescriptor
        if not isinstance(new_descriptor, (message_cls, enum_cls)):
            raise ValueError('Cannot add descriptor of type %s' % (
                type(new_descriptor),))

        qualified_name = self.__ComputeFullName(new_descriptor.name)
        if qualified_name in self.__message_registry:
            raise ValueError(
                'Attempt to re-register descriptor %s' % qualified_name)
        if qualified_name not in self.__nascent_types:
            raise ValueError('Directly adding types is not supported')

        new_descriptor.full_name = qualified_name
        self.__message_registry[qualified_name] = new_descriptor
        # The type check above guarantees that anything which is not a
        # message descriptor is an enum descriptor.
        if isinstance(new_descriptor, message_cls):
            siblings = self.__current_env.message_types
        else:
            siblings = self.__current_env.enum_types
        siblings.append(new_descriptor)

        # The type is now defined: no longer unknown, no longer nascent.
        self.__unknown_types.discard(qualified_name)
        self.__nascent_types.remove(qualified_name)

    def LookupDescriptor(self, name):
        """Return the descriptor registered under name, or None.

        Raises:
          ValueError: if name refers to a type that is still being created.
        """
        found = self.__GetDescriptorByName(name)
        return found

    def LookupDescriptorOrDie(self, name):
        """Return the descriptor registered under name, raising if absent.

        Args:
          name: Fully qualified name of the descriptor to look up.

        Returns:
          The registered descriptor.

        Raises:
          ValueError: if no descriptor is registered under name, or if the
              named type is still being created.
        """
        message_descriptor = self.LookupDescriptor(name)
        if message_descriptor is None:
            # Interpolate the name into the message; passing it as a second
            # argument to ValueError would leave the '%s' unformatted.
            raise ValueError('No message descriptor named "%s"' % (name,))
        return message_descriptor

    def __GetDescriptor(self, name):
        """Look up name relative to the current nesting path."""
        qualified_name = self.__ComputeFullName(name)
        return self.__GetDescriptorByName(qualified_name)

    def __GetDescriptorByName(self, name):
        """Return the descriptor stored under the exact name, or None.

        Raises:
          ValueError: if name is not registered yet but is currently being
              created.
        """
        try:
            return self.__message_registry[name]
        except KeyError:
            pass
        if name not in self.__nascent_types:
            return None
        raise ValueError(
            'Cannot retrieve type currently being created: %s' % name)

    @contextlib.contextmanager
    def __DescriptorEnv(self, message_descriptor):
        """Context manager that nests registration inside message_descriptor.

        While the context is active, message_descriptor.name is pushed onto
        the current path and newly registered types are attached to
        message_descriptor. The previous path and environment are restored
        on exit, including when the body raises.
        """
        # TODO(craigcitro): Typecheck?
        previous_env = self.__current_env
        self.__current_path.append(message_descriptor.name)
        self.__current_env = message_descriptor
        try:
            yield
        finally:
            # Without the finally, an exception in the with-body would
            # leave the path and environment pointing at this message.
            self.__current_path.pop()
            self.__current_env = previous_env

    def AddEnumDescriptor(self, name, description,
                          enum_values, enum_descriptions):
        """Add a new EnumDescriptor named name with the given enum values.

        Args:
          name: Name for the enum; passed through the naming helper's
              ClassName.
          description: Description of the enum as a whole.
          enum_values: JSON names of the enum values, in order; each value
              is numbered by its position.
          enum_descriptions: Descriptions matching enum_values by position.
              May be shorter than enum_values (or None); values without a
              description get '<no description>'.

        Raises:
          ValueError: if the resulting descriptor cannot be registered.
        """
        enum_values = list(enum_values)
        enum_descriptions = list(enum_descriptions or ())
        # zip() stops at the shorter input, so pad the descriptions to make
        # sure no enum value is silently dropped.
        enum_descriptions.extend(
            [''] * (len(enum_values) - len(enum_descriptions)))

        message = extended_descriptor.ExtendedEnumDescriptor()
        message.name = self.__names.ClassName(name)
        message.description = util.CleanDescription(description)
        self.__DeclareDescriptor(message.name)
        for index, (enum_name, enum_description) in enumerate(
                zip(enum_values, enum_descriptions)):
            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
            enum_value.name = self.__names.NormalizeEnumName(enum_name)
            if enum_value.name != enum_name:
                # The Python name differs from the JSON name, so record the
                # mapping and pull in the encoding module that applies it.
                message.enum_mappings.append(
                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
                        python_name=enum_value.name, json_name=enum_name))
                self.__AddImport('from %s import encoding' %
                                 self.__base_files_package)
            enum_value.number = index
            enum_value.description = util.CleanDescription(
                enum_description or '<no description>')
            message.values.append(enum_value)
        self.__RegisterDescriptor(message)

    def __DeclareMessageAlias(self, schema, alias_for):
        """Register the type named by schema['id'] as an alias of alias_for."""
        # TODO(craigcitro): This is a hack. Remove it.
        alias = extended_descriptor.ExtendedMessageDescriptor()
        alias.name = self.__names.ClassName(schema['id'])
        alias.alias_for = alias_for
        self.__DeclareDescriptor(alias.name)
        extra_types_import = (
            'from %s import extra_types' % self.__base_files_package)
        self.__AddImport(extra_types_import)
        self.__RegisterDescriptor(alias)

    def __AddAdditionalProperties(self, message, schema, properties):
        """Append a repeated additionalProperties field to message.

        A nested AdditionalProperty type is created from
        schema['additionalProperties'], and the new field is numbered one
        past the number of regular properties.
        """
        extra_info = schema['additionalProperties']
        entry_type = self.__AddAdditionalPropertyType(
            message.name, extra_info)

        field_description = util.CleanDescription(
            extra_info.get('description'))
        if field_description is None:
            field_description = (
                'Additional properties of type %s' % message.name)

        field_name = 'additionalProperties'
        field_attrs = {
            'items': {
                '$ref': entry_type,
            },
            'description': field_description,
            'type': 'array',
        }
        field_number = len(properties) + 1
        new_field = self.__FieldDescriptorFromProperties(
            field_name, field_number, field_attrs)
        message.fields.append(new_field)

        # The generated class is decorated so that unrecognized fields are
        # mapped onto this field, which needs the encoding module.
        self.__AddImport('from %s import encoding' % self.__base_files_package)
        decorator = 'encoding.MapUnrecognizedFields(%r)' % field_name
        message.decorators.append(decorator)

    def AddDescriptorFromSchema(self, schema_name, schema):
        """Add a new MessageDescriptor named schema_name based on schema.

        Nothing happens if a descriptor for schema_name already exists in
        the current scope. Schemas with an `enum` entry become enums and
        schemas of type `any` become aliases for extra_types.JsonValue.
        Otherwise the schema must be of type `object`: its properties
        become fields numbered from 1 in sorted-name order, followed by an
        additionalProperties field if the schema asks for one.

        Args:
          schema_name: Name used to check for an existing descriptor.
          schema: Schema dict; schema['id'] provides the message name.

        Raises:
          ValueError: if the schema type is not supported, or if a nested
              type cannot be created or registered.
        """
        # TODO(craigcitro): Is schema_name redundant?
        if self.__GetDescriptor(schema_name):
            return
        if schema.get('enum'):
            self.__DeclareEnum(schema_name, schema)
            return
        if schema.get('type') == 'any':
            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
            return
        if schema.get('type') != 'object':
            # Interpolate the type into the message; passing it as a second
            # argument to ValueError would leave the '%s' unformatted.
            raise ValueError('Cannot create message descriptors for type %s'
                             % (schema.get('type'),))
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.description = util.CleanDescription(schema.get(
            'description', 'A %s object.' % message.name))
        self.__DeclareDescriptor(message.name)
        with self.__DescriptorEnv(message):
            properties = schema.get('properties', {})
            for index, (name, attrs) in enumerate(sorted(properties.items())):
                field = self.__FieldDescriptorFromProperties(
                    name, index + 1, attrs)
                message.fields.append(field)
                if field.name != name:
                    # The Python name differs from the JSON name, so record
                    # the mapping for the encoding layer.
                    message.field_mappings.append(
                        type(message).JsonFieldMapping(
                            python_name=field.name, json_name=name))
                    self.__AddImport(
                        'from %s import encoding' % self.__base_files_package)
            if 'additionalProperties' in schema:
                self.__AddAdditionalProperties(message, schema, properties)
        self.__RegisterDescriptor(message)

    def __AddAdditionalPropertyType(self, name, property_schema):
        """Create the nested AdditionalProperty message; return its name.

        The message has a string `key` field and a `value` field typed by
        property_schema.
        """
        new_type_name = 'AdditionalProperty'
        # The description is dropped from a copy on purpose, so the
        # resulting messages are less repetitive.
        value_schema = dict(property_schema)
        if 'description' in value_schema:
            del value_schema['description']
        key_schema = {
            'type': 'string',
            'description': 'Name of the additional property.',
        }
        schema = {
            'id': new_type_name,
            'type': 'object',
            'description': 'An additional property for a %s object.' % name,
            'properties': {
                'key': key_schema,
                'value': value_schema,
            },
        }
        self.AddDescriptorFromSchema(new_type_name, schema)
        return new_type_name

    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
        """Add a type for a list entry.

        Creates a message named entry_type_name with a single repeated
        `entry` field whose items are described by entry_schema.

        Args:
          entry_type_name: Name of the message to create.
          entry_schema: Schema dict for the items of the inner list. It is
              not modified.
          parent_name: Name of the enclosing list, used in the description.

        Returns:
          entry_type_name.

        Raises:
          ValueError: if entry_schema is None.
        """
        if entry_schema is None:
            raise ValueError(
                'List entry type %s has no item schema' % entry_type_name)
        # Drop the description from a copy, so the caller's schema is left
        # untouched (as __AddAdditionalPropertyType already does).
        entry_schema = dict(entry_schema)
        entry_schema.pop('description', None)
        description = 'Single entry in a %s.' % parent_name
        schema = {
            'id': entry_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'entry': {
                    'type': 'array',
                    'items': entry_schema,
                },
            },
        }
        self.AddDescriptorFromSchema(entry_type_name, schema)
        return entry_type_name

    def __FieldDescriptorFromProperties(self, name, index, attrs):
        """Build an extended field descriptor for one schema property.

        Args:
          name: JSON name of the property.
          index: Field number to assign.
          attrs: Schema dict describing the property.

        Returns:
          An ExtendedFieldDescriptor wrapping the plain field descriptor.
        """
        inner = descriptor.FieldDescriptor()
        inner.name = self.__names.CleanName(name)
        inner.number = index
        inner.label = self.__ComputeLabel(attrs)

        # Any type created for this field is named after the field.
        property_class_name = self.__names.ClassName(name)
        new_type_name_hint = self.__names.ClassName(
            '%sValue' % property_class_name)
        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
        inner.type_name = type_info.type_name
        inner.variant = type_info.variant

        if 'default' in attrs:
            # TODO(craigcitro): Correctly handle non-primitive default values.
            default = attrs['default']
            if inner.variant == messages.Variant.ENUM:
                default = self.__names.NormalizeEnumName(default)
            elif inner.type_name != 'string':
                default = str(json.loads(default))
            inner.default_value = default

        wrapper = extended_descriptor.ExtendedFieldDescriptor()
        wrapper.name = inner.name
        fallback_description = 'A %s attribute.' % inner.type_name
        wrapper.description = util.CleanDescription(
            attrs.get('description', fallback_description))
        wrapper.field_descriptor = inner
        return wrapper

    @staticmethod
    def __ComputeLabel(attrs):
        """Return the field label (required/repeated/optional) for attrs.

        `required` wins over everything else; arrays and attrs flagged
        `repeated` are repeated; anything else is optional.
        """
        labels = descriptor.FieldDescriptor.Label
        if attrs.get('required', False):
            return labels.REQUIRED
        if attrs.get('type') == 'array' or attrs.get('repeated'):
            return labels.REPEATED
        return labels.OPTIONAL

    def __DeclareEnum(self, enum_name, attrs):
        """Create an enum from attrs and return a TypeInfo referring to it.

        Missing `enumDescriptions` default to one empty description per
        enum value.
        """
        enum_description = util.CleanDescription(attrs.get('description', ''))
        values = attrs['enum']
        blank_descriptions = [''] * len(values)
        value_descriptions = attrs.get('enumDescriptions', blank_descriptions)
        self.AddEnumDescriptor(
            enum_name, enum_description, values, value_descriptions)
        self.__AddIfUnknown(enum_name)
        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)

    def __AddIfUnknown(self, type_name):
        """Record type_name as unknown unless it is already registered.

        The name is considered registered if either the name itself or the
        name qualified by the current path is in the registry.
        """
        type_name = self.__names.ClassName(type_name)
        full_type_name = self.__ComputeFullName(type_name)
        # Test membership on the dict itself: `.keys()` builds a list on
        # Python 2, turning each check into a linear scan.
        if (full_type_name not in self.__message_registry and
                type_name not in self.__message_registry):
            self.__unknown_types.add(type_name)

    def __GetTypeInfo(self, attrs, name_hint):
        """Return a TypeInfo object for attrs, creating one if needed.

        Resolution order: a `$ref`, then an inline `enum`, then a known
        `format`, then a primitive `type`, then `array` and `object`.
        Imports needed by the chosen type are added to the file descriptor.

        Args:
          attrs: Schema dict describing the type.
          name_hint: Name to base any newly created type on.

        Returns:
          A TypeInfo for the type described by attrs.

        Raises:
          ValueError: if attrs has neither `$ref` nor `type`, names an
              unknown type or type/format pair, is an array without
              `items`, or is an object with no usable name hint.
        """
        type_ref = self.__names.ClassName(attrs.get('$ref'))
        type_name = attrs.get('type')
        if not (type_ref or type_name):
            raise ValueError('No type found for %s' % attrs)

        if type_ref:
            self.__AddIfUnknown(type_ref)
            # We don't actually know this is a message -- it might be an
            # enum. However, we can't check that until we've created all the
            # types, so we come back and fix this up later.
            return TypeInfo(
                type_name=type_ref, variant=messages.Variant.MESSAGE)

        if 'enum' in attrs:
            enum_name = '%sValuesEnum' % name_hint
            return self.__DeclareEnum(enum_name, attrs)

        type_info = None
        if 'format' in attrs:
            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
        if type_info is None and type_name in self.PRIMITIVE_TYPE_INFO_MAP:
            # Either there is no format, or we don't recognize it; in the
            # latter case the spec says we fall back to the type name.
            type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
        if type_info is not None:
            # Add imports for whichever map the type came from. This also
            # covers type 'any', whose extra_types.JsonValue entry lives in
            # PRIMITIVE_TYPE_INFO_MAP and therefore never reached a
            # dedicated 'any' branch further down.
            if type_info.type_name.startswith(
                    ('protorpc.message_types.', 'message_types.')):
                self.__AddImport(
                    'from protorpc import message_types as _message_types')
            if type_info.type_name.startswith('extra_types.'):
                self.__AddImport(
                    'from %s import extra_types' % self.__base_files_package)
            return type_info
        if 'format' in attrs:
            # Unknown format on a non-primitive type. The values are given
            # in the order the message labels them: type, then format.
            raise ValueError('Unknown type/format "%s"/"%s"' % (
                type_name, attrs['format']))

        if type_name == 'array':
            items = attrs.get('items')
            if not items:
                raise ValueError('Array type with no item type: %s' % attrs)
            entry_name_hint = self.__names.ClassName(
                items.get('title') or '%sListEntry' % name_hint)
            entry_label = self.__ComputeLabel(items)
            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
                # A list of lists: wrap the inner list in its own message.
                parent_name = self.__names.ClassName(
                    items.get('title') or name_hint)
                entry_type_name = self.__AddEntryType(
                    entry_name_hint, items.get('items'), parent_name)
                return TypeInfo(type_name=entry_type_name,
                                variant=messages.Variant.MESSAGE)
            return self.__GetTypeInfo(items, entry_name_hint)

        if type_name == 'object':
            # TODO(craigcitro): Think of a better way to come up with names.
            if not name_hint:
                raise ValueError(
                    'Cannot create subtype without some name hint')
            schema = dict(attrs)
            schema['id'] = name_hint
            self.AddDescriptorFromSchema(name_hint, schema)
            self.__AddIfUnknown(name_hint)
            return TypeInfo(
                type_name=name_hint, variant=messages.Variant.MESSAGE)

        raise ValueError('Unknown type: %s' % type_name)

    def FixupMessageFields(self):
        """Correct field variants on every top-level message.

        Raises:
          ValueError: if the registry does not validate.
        """
        top_level_messages = self.file_descriptor.message_types
        for message_type in top_level_messages:
            self._FixupMessage(message_type)

    def _FixupMessage(self, message_type):
        """Switch MESSAGE-variant fields that refer to enums to ENUM.

        Applies to the fields of message_type and, recursively, to the
        fields of its nested message types.
        """
        enum_cls = extended_descriptor.ExtendedEnumDescriptor
        with self.__DescriptorEnv(message_type):
            for field in message_type.fields:
                inner = field.field_descriptor
                if inner.variant != messages.Variant.MESSAGE:
                    continue
                referenced = self.LookupDescriptor(inner.type_name)
                if isinstance(referenced, enum_cls):
                    inner.variant = messages.Variant.ENUM
            for nested_type in message_type.message_types:
                self._FixupMessage(nested_type)
