#!/usr/bin/env python
"""Common code for converting proto to other formats, such as JSON."""
import base64
import collections
import datetime
import json
import logging
import os
import sys
from protorpc import message_types
from protorpc import messages
from protorpc import protojson
import six
from apitools.base.py import exceptions
__all__ = [
'CopyProtoMessage',
'JsonToMessage',
'MessageToJson',
'DictToMessage',
'MessageToDict',
'PyValueToMessage',
'MessageToPyValue',
'MessageToRepr',
'GetCustomJsonFieldMapping',
'AddCustomJsonFieldMapping',
'GetCustomJsonEnumMapping',
'AddCustomJsonEnumMapping',
]
_Codec = collections.namedtuple('_Codec', ['encoder', 'decoder'])
CodecResult = collections.namedtuple('CodecResult', ['value', 'complete'])
# TODO(craigcitro): Make these non-global.
_UNRECOGNIZED_FIELD_MAPPINGS = {}
_CUSTOM_MESSAGE_CODECS = {}
_CUSTOM_FIELD_CODECS = {}
_FIELD_TYPE_CODECS = {}
def MapUnrecognizedFields(field_name):
"""Register field_name as a container for unrecognized fields."""
def Register(cls):
_UNRECOGNIZED_FIELD_MAPPINGS[cls] = field_name
return cls
return Register
def RegisterCustomMessageCodec(encoder, decoder):
"""Register a custom encoder/decoder for this message class."""
def Register(cls):
_CUSTOM_MESSAGE_CODECS[cls] = _Codec(encoder=encoder, decoder=decoder)
return cls
return Register
def RegisterCustomFieldCodec(encoder, decoder):
"""Register a custom encoder/decoder for this field."""
def Register(field):
_CUSTOM_FIELD_CODECS[field] = _Codec(encoder=encoder, decoder=decoder)
return field
return Register
def RegisterFieldTypeCodec(encoder, decoder):
"""Register a custom encoder/decoder for all fields of this type."""
def Register(field_type):
_FIELD_TYPE_CODECS[field_type] = _Codec(
encoder=encoder, decoder=decoder)
return field_type
return Register
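# An illustrative sketch of how one of these registration decorators might be
# used; ``ExampleMessage`` and the encoder/decoder callables are hypothetical
# and are not part of this module:
#
#   def _EncodeExample(message):
#       return json.dumps({'value': message.value})
#
#   def _DecodeExample(encoded_message):
#       return ExampleMessage(value=json.loads(encoded_message)['value'])
#
#   @RegisterCustomMessageCodec(encoder=_EncodeExample, decoder=_DecodeExample)
#   class ExampleMessage(messages.Message):
#       value = messages.StringField(1)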
# TODO(craigcitro): Delete this function with the switch to proto2.
def CopyProtoMessage(message):
codec = protojson.ProtoJson()
return codec.decode_message(type(message), codec.encode_message(message))
def MessageToJson(message, include_fields=None):
"""Convert the given message to JSON."""
result = _ProtoJsonApiTools.Get().encode_message(message)
return _IncludeFields(result, message, include_fields)
def JsonToMessage(message_type, message):
"""Convert the given JSON to a message of type message_type."""
return _ProtoJsonApiTools.Get().decode_message(message_type, message)
# TODO(craigcitro): Do this directly, instead of via JSON.
def DictToMessage(d, message_type):
"""Convert the given dictionary to a message of type message_type."""
return JsonToMessage(message_type, json.dumps(d))
def MessageToDict(message):
"""Convert the given message to a dictionary."""
return json.loads(MessageToJson(message))
def PyValueToMessage(message_type, value):
"""Convert the given python value to a message of type message_type."""
return JsonToMessage(message_type, json.dumps(value))
def MessageToPyValue(message):
"""Convert the given message to a python value."""
return json.loads(MessageToJson(message))
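# Usage sketch for the conversion helpers above, assuming a hypothetical
# message class ``FooMessage`` with a single string field ``name`` (not
# defined in this module):
#
#   msg = FooMessage(name='x')
#   as_json = MessageToJson(msg)          # '{"name": "x"}'
#   as_dict = MessageToDict(msg)          # {'name': 'x'}
#   from_json = JsonToMessage(FooMessage, as_json)
#   from_dict = DictToMessage(as_dict, FooMessage)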
def MessageToRepr(msg, multiline=False, **kwargs):
"""Return a repr-style string for a protorpc message.
protorpc.Message.__repr__ does not return anything that could be considered
python code. Adding this function lets us print a protorpc message in such
a way that it could be pasted into code later, and used to compare against
other things.
Args:
msg: protorpc.Message, the message to be repr'd.
multiline: bool, True if the returned string should have each field
assignment on its own line.
**kwargs: {str:str}, Additional flags for how to format the string.
Known **kwargs:
        shortstrings: bool, True if all string values should be
            truncated at 100 characters, since, when mocking, the contents
            typically don't matter except for IDs, and IDs are usually
            less than 100 characters.
no_modules: bool, True if the long module name should not be printed with
each type.
Returns:
str, A string of valid python (assuming the right imports have been made)
that recreates the message passed into this function.
"""
# TODO(user): craigcitro suggests a pretty-printer from apitools/gen.
indent = kwargs.get('indent', 0)
def IndentKwargs(kwargs):
kwargs = dict(kwargs)
kwargs['indent'] = kwargs.get('indent', 0) + 4
return kwargs
if isinstance(msg, list):
s = '['
for item in msg:
if multiline:
s += '\n' + ' ' * (indent + 4)
s += MessageToRepr(
item, multiline=multiline, **IndentKwargs(kwargs)) + ','
if multiline:
s += '\n' + ' ' * indent
s += ']'
return s
if isinstance(msg, messages.Message):
s = type(msg).__name__ + '('
if not kwargs.get('no_modules'):
s = msg.__module__ + '.' + s
names = sorted([field.name for field in msg.all_fields()])
for name in names:
field = msg.field_by_name(name)
if multiline:
s += '\n' + ' ' * (indent + 4)
value = getattr(msg, field.name)
s += field.name + '=' + MessageToRepr(
value, multiline=multiline, **IndentKwargs(kwargs)) + ','
if multiline:
s += '\n' + ' ' * indent
s += ')'
return s
if isinstance(msg, six.string_types):
if kwargs.get('shortstrings') and len(msg) > 100:
msg = msg[:100]
if isinstance(msg, datetime.datetime):
class SpecialTZInfo(datetime.tzinfo):
def __init__(self, offset):
super(SpecialTZInfo, self).__init__()
self.offset = offset
def __repr__(self):
s = 'TimeZoneOffset(' + repr(self.offset) + ')'
if not kwargs.get('no_modules'):
s = 'protorpc.util.' + s
return s
msg = datetime.datetime(
msg.year, msg.month, msg.day, msg.hour, msg.minute, msg.second,
msg.microsecond, SpecialTZInfo(msg.tzinfo.utcoffset(0)))
return repr(msg)
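# For instance, with a hypothetical ``FooMessage`` defining fields ``count``
# and ``name``, the output looks roughly like:
#
#   MessageToRepr(FooMessage(name='x'), no_modules=True)
#   # => "FooMessage(count=None,name='x',)"
#
#   MessageToRepr(FooMessage(name='x'), multiline=True, no_modules=True)
#   # => "FooMessage(\n    count=None,\n    name='x',\n)"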
def _GetField(message, field_path):
for field in field_path:
if field not in dir(message):
raise KeyError('no field "%s"' % field)
message = getattr(message, field)
return message
def _SetField(dictblob, field_path, value):
for field in field_path[:-1]:
dictblob[field] = {}
dictblob = dictblob[field]
dictblob[field_path[-1]] = value
def _IncludeFields(encoded_message, message, include_fields):
"""Add the requested fields to the encoded message."""
if include_fields is None:
return encoded_message
result = json.loads(encoded_message)
for field_name in include_fields:
try:
value = _GetField(message, field_name.split('.'))
nullvalue = None
if isinstance(value, list):
nullvalue = []
except KeyError:
raise exceptions.InvalidDataError(
'No field named %s in message of type %s' % (
field_name, type(message)))
_SetField(result, field_name.split('.'), nullvalue)
return json.dumps(result)
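# For example, passing include_fields to MessageToJson forces otherwise-unset
# fields onto the wire as explicit nulls (``FooMessage`` is hypothetical):
#
#   MessageToJson(FooMessage(), include_fields=['name'])
#   # => '{"name": null}'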
def _GetFieldCodecs(field, attr):
result = [
getattr(_CUSTOM_FIELD_CODECS.get(field), attr, None),
getattr(_FIELD_TYPE_CODECS.get(type(field)), attr, None),
]
return [x for x in result if x is not None]
class _ProtoJsonApiTools(protojson.ProtoJson):
"""JSON encoder used by apitools clients."""
_INSTANCE = None
@classmethod
def Get(cls):
if cls._INSTANCE is None:
cls._INSTANCE = cls()
return cls._INSTANCE
def decode_message(self, message_type, encoded_message):
if message_type in _CUSTOM_MESSAGE_CODECS:
return _CUSTOM_MESSAGE_CODECS[
message_type].decoder(encoded_message)
# We turn off the default logging in protorpc. We may want to
# remove this later.
old_level = logging.getLogger().level
logging.getLogger().setLevel(logging.ERROR)
try:
result = _DecodeCustomFieldNames(message_type, encoded_message)
result = super(_ProtoJsonApiTools, self).decode_message(
message_type, result)
finally:
logging.getLogger().setLevel(old_level)
result = _ProcessUnknownEnums(result, encoded_message)
result = _ProcessUnknownMessages(result, encoded_message)
return _DecodeUnknownFields(result, encoded_message)
def decode_field(self, field, value):
"""Decode the given JSON value.
Args:
field: a messages.Field for the field we're decoding.
value: a python value we'd like to decode.
Returns:
A value suitable for assignment to field.
"""
for decoder in _GetFieldCodecs(field, 'decoder'):
result = decoder(field, value)
value = result.value
if result.complete:
return value
if isinstance(field, messages.MessageField):
field_value = self.decode_message(
field.message_type, json.dumps(value))
elif isinstance(field, messages.EnumField):
value = GetCustomJsonEnumMapping(
field.type, json_name=value) or value
try:
field_value = super(
_ProtoJsonApiTools, self).decode_field(field, value)
except messages.DecodeError:
if not isinstance(value, six.string_types):
raise
field_value = None
else:
field_value = super(
_ProtoJsonApiTools, self).decode_field(field, value)
return field_value
def encode_message(self, message):
if isinstance(message, messages.FieldList):
return '[%s]' % (', '.join(self.encode_message(x)
for x in message))
message_type = type(message)
if message_type in _CUSTOM_MESSAGE_CODECS:
return _CUSTOM_MESSAGE_CODECS[type(message)].encoder(message)
message = _EncodeUnknownFields(message)
result = super(_ProtoJsonApiTools, self).encode_message(message)
return _EncodeCustomFieldNames(message, result)
def encode_field(self, field, value):
"""Encode the given value as JSON.
Args:
field: a messages.Field for the field we're encoding.
value: a value for field.
Returns:
A python value suitable for json.dumps.
"""
for encoder in _GetFieldCodecs(field, 'encoder'):
result = encoder(field, value)
value = result.value
if result.complete:
return value
if isinstance(field, messages.EnumField):
if field.repeated:
remapped_value = [GetCustomJsonEnumMapping(
field.type, python_name=e.name) or e.name for e in value]
else:
remapped_value = GetCustomJsonEnumMapping(
field.type, python_name=value.name)
if remapped_value:
return remapped_value
if (isinstance(field, messages.MessageField) and
not isinstance(field, message_types.DateTimeField)):
value = json.loads(self.encode_message(value))
return super(_ProtoJsonApiTools, self).encode_field(field, value)
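# Field codecs registered via RegisterCustomFieldCodec or
# RegisterFieldTypeCodec are consulted first by encode_field/decode_field;
# returning CodecResult(..., complete=True) short-circuits the default
# protojson handling. A minimal sketch with hypothetical names:
#
#   def _EncodeUpper(unused_field, value):
#       return CodecResult(value=value.upper(), complete=True)
#
#   def _DecodeLower(unused_field, value):
#       return CodecResult(value=value.lower(), complete=True)
#
#   class ShoutingMessage(messages.Message):
#       word = RegisterCustomFieldCodec(_EncodeUpper, _DecodeLower)(
#           messages.StringField(1))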
# TODO(craigcitro): Fold this and _IncludeFields in as codecs.
def _DecodeUnknownFields(message, encoded_message):
"""Rewrite unknown fields in message into message.destination."""
destination = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
if destination is None:
return message
pair_field = message.field_by_name(destination)
if not isinstance(pair_field, messages.MessageField):
raise exceptions.InvalidDataFromServerError(
'Unrecognized fields must be mapped to a compound '
'message type.')
pair_type = pair_field.message_type
# TODO(craigcitro): Add more error checking around the pair
# type being exactly what we suspect (field names, etc).
if isinstance(pair_type.value, messages.MessageField):
new_values = _DecodeUnknownMessages(
message, json.loads(encoded_message), pair_type)
else:
new_values = _DecodeUnrecognizedFields(message, pair_type)
setattr(message, destination, new_values)
# We could probably get away with not setting this, but
# why not clear it?
setattr(message, '_Message__unrecognized_fields', {})
return message
def _DecodeUnknownMessages(message, encoded_message, pair_type):
"""Process unknown fields in encoded_message of a message type."""
field_type = pair_type.value.type
new_values = []
all_field_names = [x.name for x in message.all_fields()]
for name, value_dict in encoded_message.items():
if name in all_field_names:
continue
value = PyValueToMessage(field_type, value_dict)
if pair_type.value.repeated:
value = _AsMessageList(value)
new_pair = pair_type(key=name, value=value)
new_values.append(new_pair)
return new_values
def _DecodeUnrecognizedFields(message, pair_type):
"""Process unrecognized fields in message."""
new_values = []
for unknown_field in message.all_unrecognized_fields():
# TODO(craigcitro): Consider validating the variant if
# the assignment below doesn't take care of it. It may
# also be necessary to check it in the case that the
# type has multiple encodings.
value, _ = message.get_unrecognized_field_info(unknown_field)
value_type = pair_type.field_by_name('value')
if isinstance(value_type, messages.MessageField):
decoded_value = DictToMessage(value, pair_type.value.message_type)
elif isinstance(value_type, messages.EnumField):
decoded_value = pair_type.value.type(value)
else:
decoded_value = value
new_pair = pair_type(key=str(unknown_field), value=decoded_value)
new_values.append(new_pair)
return new_values
def _EncodeUnknownFields(message):
"""Remap unknown fields in message out of message.source."""
source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
if source is None:
return message
result = CopyProtoMessage(message)
pairs_field = message.field_by_name(source)
if not isinstance(pairs_field, messages.MessageField):
raise exceptions.InvalidUserInputError(
'Invalid pairs field %s' % pairs_field)
pairs_type = pairs_field.message_type
value_variant = pairs_type.field_by_name('value').variant
pairs = getattr(message, source)
for pair in pairs:
if value_variant == messages.Variant.MESSAGE:
encoded_value = MessageToDict(pair.value)
else:
encoded_value = pair.value
result.set_unrecognized_field(pair.key, encoded_value, value_variant)
setattr(result, source, [])
return result
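# Taken together, MapUnrecognizedFields and the helpers above support the
# common "additionalProperties" pattern: unknown wire fields are folded into a
# repeated key/value pair field on decode and unfolded again on encode. A
# hedged sketch (the classes below are hypothetical):
#
#   class AdditionalProperty(messages.Message):
#       key = messages.StringField(1)
#       value = messages.StringField(2)
#
#   @MapUnrecognizedFields('additionalProperties')
#   class ExtensibleMessage(messages.Message):
#       additionalProperties = messages.MessageField(
#           AdditionalProperty, 1, repeated=True)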
def _SafeEncodeBytes(field, value):
"""Encode the bytes in value as urlsafe base64."""
try:
if field.repeated:
result = [base64.urlsafe_b64encode(byte) for byte in value]
else:
result = base64.urlsafe_b64encode(value)
complete = True
except TypeError:
result = value
complete = False
return CodecResult(value=result, complete=complete)
def _SafeDecodeBytes(unused_field, value):
"""Decode the urlsafe base64 value into bytes."""
try:
result = base64.urlsafe_b64decode(str(value))
complete = True
except TypeError:
result = value
complete = False
return CodecResult(value=result, complete=complete)
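# In effect, bytes fields round-trip through urlsafe base64 on the wire; for
# example (illustrative only, values shown for Python 3):
#
#   _SafeEncodeBytes(messages.BytesField(1), b'\x00\xff')
#   # => CodecResult(value=b'AP8=', complete=True)
#   _SafeDecodeBytes(None, 'AP8=')
#   # => CodecResult(value=b'\x00\xff', complete=True)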
def _ProcessUnknownEnums(message, encoded_message):
"""Add unknown enum values from encoded_message as unknown fields.
ProtoRPC diverges from the usual protocol buffer behavior here and
doesn't allow unknown fields. Throwing on unknown fields makes it
impossible to let servers add new enum values and stay compatible
with older clients, which isn't reasonable for us. We simply store
unrecognized enum values as unknown fields, and all is well.
Args:
message: Proto message we've decoded thus far.
encoded_message: JSON string we're decoding.
Returns:
message, with any unknown enums stored as unrecognized fields.
"""
if not encoded_message:
return message
decoded_message = json.loads(encoded_message)
for field in message.all_fields():
if (isinstance(field, messages.EnumField) and
field.name in decoded_message and
message.get_assigned_value(field.name) is None):
message.set_unrecognized_field(
field.name, decoded_message[field.name], messages.Variant.ENUM)
return message
def _ProcessUnknownMessages(message, encoded_message):
"""Store any remaining unknown fields as strings.
ProtoRPC currently ignores unknown values for which no type can be
determined (and logs a "No variant found" message). For the purposes
of reserializing, this is quite harmful (since it throws away
information). Here we simply add those as unknown fields of type
string (so that they can easily be reserialized).
Args:
message: Proto message we've decoded thus far.
encoded_message: JSON string we're decoding.
Returns:
message, with any remaining unrecognized fields saved.
"""
if not encoded_message:
return message
decoded_message = json.loads(encoded_message)
message_fields = [x.name for x in message.all_fields()] + list(
message.all_unrecognized_fields())
missing_fields = [x for x in decoded_message.keys()
if x not in message_fields]
for field_name in missing_fields:
message.set_unrecognized_field(field_name, decoded_message[field_name],
messages.Variant.STRING)
return message
RegisterFieldTypeCodec(_SafeEncodeBytes, _SafeDecodeBytes)(messages.BytesField)
# Note that these could share a dictionary, since they're keyed by
# distinct types, but it's not really worth it.
_JSON_ENUM_MAPPINGS = {}
_JSON_FIELD_MAPPINGS = {}
def _GetTypeKey(message_type, package):
"""Get the prefix for this message type in mapping dicts."""
key = message_type.definition_name()
if package and key.startswith(package + '.'):
module_name = message_type.__module__
# We normalize '__main__' to something unique, if possible.
if module_name == '__main__':
try:
file_name = sys.modules[module_name].__file__
except (AttributeError, KeyError):
pass
else:
base_name = os.path.basename(file_name)
split_name = os.path.splitext(base_name)
if len(split_name) == 1:
                    module_name = six.text_type(base_name)
else:
module_name = u'.'.join(split_name[:-1])
key = module_name + '.' + key.partition('.')[2]
return key
def AddCustomJsonEnumMapping(enum_type, python_name, json_name,
package=''):
"""Add a custom wire encoding for a given enum value.
This is primarily used in generated code, to handle enum values
which happen to be Python keywords.
Args:
enum_type: (messages.Enum) An enum type
python_name: (string) Python name for this value.
json_name: (string) JSON name to be used on the wire.
package: (basestring, optional) Package prefix for this enum, if
present. We strip this off the enum name in order to generate
unique keys.
"""
if not issubclass(enum_type, messages.Enum):
raise exceptions.TypecheckError(
'Cannot set JSON enum mapping for non-enum "%s"' % enum_type)
enum_name = _GetTypeKey(enum_type, package)
if python_name not in enum_type.names():
raise exceptions.InvalidDataError(
'Enum value %s not a value for type %s' % (python_name, enum_type))
field_mappings = _JSON_ENUM_MAPPINGS.setdefault(enum_name, {})
_CheckForExistingMappings('enum', enum_type, python_name, json_name)
field_mappings[python_name] = json_name
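# For example, a generated enum value whose natural wire name is a Python
# keyword might be remapped like so (``ExampleEnum`` is hypothetical):
#
#   class ExampleEnum(messages.Enum):
#       pass_ = 1
#
#   AddCustomJsonEnumMapping(ExampleEnum, python_name='pass_', json_name='pass')
#   GetCustomJsonEnumMapping(ExampleEnum, python_name='pass_')  # => 'pass'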
def AddCustomJsonFieldMapping(message_type, python_name, json_name,
package=''):
"""Add a custom wire encoding for a given message field.
    This is primarily used in generated code, to handle fields whose names
    happen to be Python keywords.
Args:
message_type: (messages.Message) A message type
python_name: (string) Python name for this value.
json_name: (string) JSON name to be used on the wire.
package: (basestring, optional) Package prefix for this message, if
present. We strip this off the message name in order to generate
unique keys.
"""
if not issubclass(message_type, messages.Message):
raise exceptions.TypecheckError(
'Cannot set JSON field mapping for '
'non-message "%s"' % message_type)
message_name = _GetTypeKey(message_type, package)
try:
_ = message_type.field_by_name(python_name)
except KeyError:
raise exceptions.InvalidDataError(
'Field %s not recognized for type %s' % (
python_name, message_type))
field_mappings = _JSON_FIELD_MAPPINGS.setdefault(message_name, {})
_CheckForExistingMappings('field', message_type, python_name, json_name)
field_mappings[python_name] = json_name
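# Similarly for fields whose wire name is a Python keyword (the message class
# below is hypothetical):
#
#   class ExampleMessage(messages.Message):
#       global_ = messages.StringField(1)
#
#   AddCustomJsonFieldMapping(ExampleMessage, python_name='global_',
#                             json_name='global')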
def GetCustomJsonEnumMapping(enum_type, python_name=None, json_name=None):
"""Return the appropriate remapping for the given enum, or None."""
return _FetchRemapping(enum_type.definition_name(), 'enum',
python_name=python_name, json_name=json_name,
mappings=_JSON_ENUM_MAPPINGS)
def GetCustomJsonFieldMapping(message_type, python_name=None, json_name=None):
"""Return the appropriate remapping for the given field, or None."""
return _FetchRemapping(message_type.definition_name(), 'field',
python_name=python_name, json_name=json_name,
mappings=_JSON_FIELD_MAPPINGS)
def _FetchRemapping(type_name, mapping_type, python_name=None, json_name=None,
mappings=None):
"""Common code for fetching a key or value from a remapping dict."""
if python_name and json_name:
raise exceptions.InvalidDataError(
'Cannot specify both python_name and json_name '
'for %s remapping' % mapping_type)
if not (python_name or json_name):
raise exceptions.InvalidDataError(
'Must specify either python_name or json_name for %s remapping' % (
mapping_type,))
field_remappings = mappings.get(type_name, {})
if field_remappings:
if python_name:
return field_remappings.get(python_name)
elif json_name:
if json_name in list(field_remappings.values()):
return [k for k in field_remappings
if field_remappings[k] == json_name][0]
return None
def _CheckForExistingMappings(mapping_type, message_type,
python_name, json_name):
"""Validate that no mappings exist for the given values."""
if mapping_type == 'field':
getter = GetCustomJsonFieldMapping
elif mapping_type == 'enum':
getter = GetCustomJsonEnumMapping
remapping = getter(message_type, python_name=python_name)
if remapping is not None:
raise exceptions.InvalidDataError(
'Cannot add mapping for %s "%s", already mapped to "%s"' % (
mapping_type, python_name, remapping))
remapping = getter(message_type, json_name=json_name)
if remapping is not None:
raise exceptions.InvalidDataError(
'Cannot add mapping for %s "%s", already mapped to "%s"' % (
mapping_type, json_name, remapping))
def _EncodeCustomFieldNames(message, encoded_value):
message_name = type(message).definition_name()
field_remappings = list(_JSON_FIELD_MAPPINGS.get(message_name, {}).items())
if field_remappings:
decoded_value = json.loads(encoded_value)
for python_name, json_name in field_remappings:
            if python_name in decoded_value:
decoded_value[json_name] = decoded_value.pop(python_name)
encoded_value = json.dumps(decoded_value)
return encoded_value
def _DecodeCustomFieldNames(message_type, encoded_message):
message_name = message_type.definition_name()
field_remappings = _JSON_FIELD_MAPPINGS.get(message_name, {})
if field_remappings:
decoded_message = json.loads(encoded_message)
for python_name, json_name in list(field_remappings.items()):
if json_name in decoded_message:
decoded_message[python_name] = decoded_message.pop(json_name)
encoded_message = json.dumps(decoded_message)
return encoded_message
def _AsMessageList(msg):
"""Convert the provided list-as-JsonValue to a list."""
# This really needs to live in extra_types, but extra_types needs
# to import this file to be able to register codecs.
# TODO(craigcitro): Split out a codecs module and fix this ugly
# import.
from apitools.base.py import extra_types
def _IsRepeatedJsonValue(msg):
"""Return True if msg is a repeated value as a JsonValue."""
if isinstance(msg, extra_types.JsonArray):
return True
if isinstance(msg, extra_types.JsonValue) and msg.array_value:
return True
return False
if not _IsRepeatedJsonValue(msg):
raise ValueError('invalid argument to _AsMessageList')
if isinstance(msg, extra_types.JsonValue):
msg = msg.array_value
if isinstance(msg, extra_types.JsonArray):
msg = msg.entries
return msg