blob: 1e72b131b2918a01de7e045d0f5f59d2cb51b746 [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for protorpc.messages."""
import six
__author__ = 'rafek@google.com (Rafe Kaplan)'
import pickle
import re
import sys
import types
import unittest
from protorpc import descriptor
from protorpc import message_types
from protorpc import messages
from protorpc import test_util
class ModuleInterfaceTest(test_util.ModuleInterfaceTest, test_util.TestCase):
    """Runs the shared module-interface checks against protorpc.messages."""

    # Module under inspection by the inherited test_util.ModuleInterfaceTest.
    MODULE = messages
class ValidationErrorTest(test_util.TestCase):
    """Tests for the string form of messages.ValidationError."""

    def testStr_NoFieldName(self):
        """Test string version of ValidationError when no name provided."""
        self.assertEqual('Validation error',
                         str(messages.ValidationError('Validation error')))

    def testStr_FieldName(self):
        """Test string version of ValidationError when a field name is set."""
        validation_error = messages.ValidationError('Validation error')
        validation_error.field_name = 'a_field'
        # Setting field_name is expected to leave the string form unchanged.
        self.assertEqual('Validation error', str(validation_error))
class EnumTest(test_util.TestCase):
    """Tests for messages.Enum: values, look-ups, immutability and naming."""

    def setUp(self):
        """Set up tests."""
        # Redefine the Color class before every test so that a change made
        # to it (an error) by one test does not affect the other tests.
        global Color

        class Color(messages.Enum):
            RED = 20
            ORANGE = 2
            YELLOW = 40
            GREEN = 4
            BLUE = 50
            INDIGO = 5
            VIOLET = 80

    def testNames(self):
        """Test that names iterates over enum names."""
        self.assertEqual(
            set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET',
                 'YELLOW']),
            set(Color.names()))

    def testNumbers(self):
        """Test that numbers iterates over enum numbers."""
        self.assertEqual(set([2, 4, 5, 20, 40, 50, 80]),
                         set(Color.numbers()))

    def testIterate(self):
        """Test that __iter__ iterates over all enum values."""
        self.assertEqual(set(Color),
                         set([Color.RED,
                              Color.ORANGE,
                              Color.YELLOW,
                              Color.GREEN,
                              Color.BLUE,
                              Color.INDIGO,
                              Color.VIOLET]))

    def testNaturalOrder(self):
        """Test that natural order enumeration is in numeric order."""
        self.assertEqual([Color.ORANGE,
                          Color.GREEN,
                          Color.INDIGO,
                          Color.RED,
                          Color.YELLOW,
                          Color.BLUE,
                          Color.VIOLET],
                         sorted(Color))

    def testByName(self):
        """Test look-up by name."""
        self.assertEqual(Color.RED, Color.lookup_by_name('RED'))
        self.assertRaises(KeyError, Color.lookup_by_name, 20)
        self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)

    def testByNumber(self):
        """Test look-up by number."""
        self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
        self.assertEqual(Color.RED, Color.lookup_by_number(20))
        self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)

    def testConstructor(self):
        """Test that the constructor looks up by name or number."""
        self.assertEqual(Color.RED, Color('RED'))
        self.assertEqual(Color.RED, Color(u'RED'))
        self.assertEqual(Color.RED, Color(20))
        if six.PY2:
            # `long` only exists on Python 2, hence the guard.
            self.assertEqual(Color.RED, Color(long(20)))
        self.assertEqual(Color.RED, Color(Color.RED))
        self.assertRaises(TypeError, Color, 'Not exists')
        self.assertRaises(TypeError, Color, 'Red')
        self.assertRaises(TypeError, Color, 100)
        self.assertRaises(TypeError, Color, 10.0)

    def testLen(self):
        """Test that len function works to count enums."""
        self.assertEqual(7, len(Color))

    def testNoSubclasses(self):
        """Test that it is not possible to sub-class enum classes."""
        def declare_subclass():
            class MoreColor(Color):
                pass

        self.assertRaises(messages.EnumDefinitionError, declare_subclass)

    def testClassNotMutable(self):
        """Test that enum classes themselves are not mutable."""
        self.assertRaises(AttributeError,
                          setattr,
                          Color,
                          'something_new',
                          10)

    def testInstancesMutable(self):
        """Test that enum instances are not mutable."""
        self.assertRaises(TypeError,
                          setattr,
                          Color.RED,
                          'something_new',
                          10)

    def testDefEnum(self):
        """Test def_enum works by building enum class from dict."""
        WeekDay = messages.Enum.def_enum({'Monday': 1,
                                          'Tuesday': 2,
                                          'Wednesday': 3,
                                          'Thursday': 4,
                                          'Friday': 6,
                                          'Saturday': 7,
                                          'Sunday': 8},
                                         'WeekDay')
        self.assertEqual('Wednesday', WeekDay(3).name)
        self.assertEqual(6, WeekDay('Friday').number)
        self.assertEqual(WeekDay.Sunday, WeekDay('Sunday'))

    def testNonInt(self):
        """Test that non-integer values are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': '1'},
                          'BadEnum')

    def testNegativeInt(self):
        """Test that negative numbers are rejected by enum def."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': -1},
                          'BadEnum')

    def testLowerBound(self):
        """Test that zero is accepted by enum def."""
        class NotImportant(messages.Enum):
            """Testing for value zero"""
            VALUE = 0

        self.assertEqual(0, int(NotImportant.VALUE))

    def testTooLargeInt(self):
        """Test that numbers too large are rejected."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Bad': (2 ** 29)},
                          'BadEnum')

    def testRepeatedInt(self):
        """Test duplicated numbers are forbidden."""
        self.assertRaises(messages.EnumDefinitionError,
                          messages.Enum.def_enum,
                          {'Ok': 1, 'Repeated': 1},
                          'BadEnum')

    def testStr(self):
        """Test converting to string."""
        self.assertEqual('RED', str(Color.RED))
        self.assertEqual('ORANGE', str(Color.ORANGE))

    def testInt(self):
        """Test converting to int."""
        self.assertEqual(20, int(Color.RED))
        self.assertEqual(2, int(Color.ORANGE))

    def testRepr(self):
        """Test enum representation."""
        self.assertEqual('Color(RED, 20)', repr(Color.RED))
        self.assertEqual('Color(YELLOW, 40)', repr(Color.YELLOW))

    def testDocstring(self):
        """Test that docstring is supported ok."""
        class NotImportant(messages.Enum):
            """I have a docstring."""
            VALUE1 = 1

        self.assertEqual('I have a docstring.', NotImportant.__doc__)

    def testDeleteEnumValue(self):
        """Test that enum values cannot be deleted."""
        self.assertRaises(TypeError, delattr, Color, 'RED')

    def testEnumName(self):
        """Test enum name."""
        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.Color' % module_name, Color.definition_name())
        self.assertEqual(module_name, Color.outer_definition_name())
        self.assertEqual(module_name, Color.definition_package())

    def testDefinitionName_OverrideModule(self):
        """Test enum module is overridden by module package name."""
        # `package` is set as a module-level global for the duration of the
        # assertions and removed again afterwards.
        global package
        try:
            package = 'my.package'
            self.assertEqual('my.package.Color', Color.definition_name())
            self.assertEqual('my.package', Color.outer_definition_name())
            self.assertEqual('my.package', Color.definition_package())
        finally:
            del package

    def testDefinitionName_NoModule(self):
        """Test what happens when there is no module for enum."""
        class Enum1(messages.Enum):
            pass

        # Work on a copy of sys.modules so that removing this module cannot
        # leak into other tests; the original mapping is restored below.
        original_modules = sys.modules
        sys.modules = dict(sys.modules)
        try:
            del sys.modules[__name__]
            self.assertEqual('Enum1', Enum1.definition_name())
            self.assertEqual(None, Enum1.outer_definition_name())
            self.assertEqual(None, Enum1.definition_package())
            self.assertEqual(six.text_type, type(Enum1.definition_name()))
        finally:
            sys.modules = original_modules

    def testDefinitionName_Nested(self):
        """Test nested Enum names."""
        class MyMessage(messages.Message):

            class NestedEnum(messages.Enum):
                pass

            class NestedMessage(messages.Message):

                class NestedEnum(messages.Enum):
                    pass

        module_name = test_util.get_module_name(EnumTest)
        self.assertEqual('%s.MyMessage.NestedEnum' % module_name,
                         MyMessage.NestedEnum.definition_name())
        self.assertEqual('%s.MyMessage' % module_name,
                         MyMessage.NestedEnum.outer_definition_name())
        self.assertEqual(module_name,
                         MyMessage.NestedEnum.definition_package())
        self.assertEqual(
            '%s.MyMessage.NestedMessage.NestedEnum' % module_name,
            MyMessage.NestedMessage.NestedEnum.definition_name())
        self.assertEqual(
            '%s.MyMessage.NestedMessage' % module_name,
            MyMessage.NestedMessage.NestedEnum.outer_definition_name())
        self.assertEqual(
            module_name,
            MyMessage.NestedMessage.NestedEnum.definition_package())

    def testMessageDefinition(self):
        """Test that enumeration knows its enclosing message definition."""
        class OuterEnum(messages.Enum):
            pass

        self.assertEqual(None, OuterEnum.message_definition())

        class OuterMessage(messages.Message):

            class InnerEnum(messages.Enum):
                pass

        self.assertEqual(OuterMessage,
                         OuterMessage.InnerEnum.message_definition())

    def testComparison(self):
        """Test comparing various enums to different types."""
        class Enum1(messages.Enum):
            VAL1 = 1
            VAL2 = 2

        class Enum2(messages.Enum):
            VAL1 = 1

        self.assertEqual(Enum1.VAL1, Enum1.VAL1)
        self.assertNotEqual(Enum1.VAL1, Enum1.VAL2)
        self.assertNotEqual(Enum1.VAL1, Enum2.VAL1)
        self.assertNotEqual(Enum1.VAL1, 'VAL1')
        self.assertNotEqual(Enum1.VAL1, 1)
        self.assertNotEqual(Enum1.VAL1, 2)
        self.assertNotEqual(Enum1.VAL1, None)
        self.assertTrue(Enum1.VAL1 < Enum1.VAL2)
        self.assertTrue(Enum1.VAL2 > Enum1.VAL1)
        self.assertNotEqual(1, Enum2.VAL1)

    def testPickle(self):
        """Testing pickling and unpickling of Enum instances."""
        colors = list(Color)
        unpickled = pickle.loads(pickle.dumps(colors))
        self.assertEqual(colors, unpickled)
        # Unpickling shouldn't create new enum instances.
        for original, restored in zip(colors, unpickled):
            self.assertIs(original, restored)
class FieldListTest(test_util.TestCase):
    """Tests for messages.FieldList construction, mutation and pickling."""

    def setUp(self):
        """Create the repeated integer field shared by the tests."""
        self.integer_field = messages.IntegerField(1, repeated=True)

    def _wrong_type_regexp(self, found):
        """Return the escaped IntegerField wrong-type message for a str value.

        Args:
          found: Text of the rejected value as it appears in the message.

        Returns:
          Regular expression matching the expected ValidationError message.
        """
        return re.escape('Expected type %r '
                         'for IntegerField, found %s (type %r)'
                         % (six.integer_types, found, str))

    def testConstructor(self):
        """Test construction from lists, tuples and the empty list."""
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, [1, 2, 3]))
        self.assertEqual([1, 2, 3],
                         messages.FieldList(self.integer_field, (1, 2, 3)))
        self.assertEqual([], messages.FieldList(self.integer_field, []))

    def testNone(self):
        """Test that None is rejected as the initial sequence."""
        self.assertRaises(TypeError,
                          messages.FieldList, self.integer_field, None)

    def testDoNotAutoConvertString(self):
        """Test that a string is not silently treated as a sequence."""
        string_field = messages.StringField(1, repeated=True)
        self.assertRaises(messages.ValidationError,
                          messages.FieldList, string_field, 'abc')

    def testConstructorCopies(self):
        """Test that the constructor never returns its argument itself."""
        a_list = [1, 3, 6]
        field_list = messages.FieldList(self.integer_field, a_list)
        self.assertIsNot(a_list, field_list)
        self.assertIsNot(field_list,
                         messages.FieldList(self.integer_field, field_list))

    def testNonRepeatedField(self):
        """Test that a non-repeated field cannot back a FieldList."""
        self.assertRaisesWithRegexpMatch(
            messages.FieldDefinitionError,
            'FieldList may only accept repeated fields',
            messages.FieldList,
            messages.IntegerField(1),
            [])

    def testConstructor_InvalidValues(self):
        """Test that wrongly typed initial values are rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('1'),
            messages.FieldList, self.integer_field, ["1", "2", "3"])

    def testConstructor_Scalars(self):
        """Test that scalars and bare iterators are rejected."""
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            "IntegerField is repeated. Found: 3",
            messages.FieldList, self.integer_field, 3)
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            "IntegerField is repeated. Found: "
            "<(list[_]?|sequence)iterator object",
            messages.FieldList, self.integer_field, iter([1, 2, 3]))

    def testSetSlice(self):
        """Test slice assignment with valid values."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        field_list[1:3] = [10, 20]
        self.assertEqual([1, 10, 20, 4, 5], field_list)

    def testSetSlice_InvalidValues(self):
        """Test slice assignment with wrongly typed values."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])

        def setslice():
            field_list[1:3] = ['10', '20']

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('10'),
            setslice)

    def testSetItem(self):
        """Test item assignment with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list[0] = 10
        self.assertEqual([10], field_list)

    def testSetItem_InvalidValues(self):
        """Test item assignment with a wrongly typed value."""
        field_list = messages.FieldList(self.integer_field, [2])

        def setitem():
            field_list[0] = '10'

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('10'),
            setitem)

    def testAppend(self):
        """Test append with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.append(10)
        self.assertEqual([2, 10], field_list)

    def testAppend_InvalidValues(self):
        """Test append with a wrongly typed value."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.name = 'a_field'

        def append():
            field_list.append('10')

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('10'),
            append)

    def testExtend(self):
        """Test extend with valid values."""
        field_list = messages.FieldList(self.integer_field, [2])
        field_list.extend([10])
        self.assertEqual([2, 10], field_list)

    def testExtend_InvalidValues(self):
        """Test extend with wrongly typed values."""
        field_list = messages.FieldList(self.integer_field, [2])

        def extend():
            field_list.extend(['10'])

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('10'),
            extend)

    def testInsert(self):
        """Test insert with a valid value."""
        field_list = messages.FieldList(self.integer_field, [2, 3])
        field_list.insert(1, 10)
        self.assertEqual([2, 10, 3], field_list)

    def testInsert_InvalidValues(self):
        """Test insert with a wrongly typed value."""
        field_list = messages.FieldList(self.integer_field, [2, 3])

        def insert():
            field_list.insert(1, '10')

        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            self._wrong_type_regexp('10'),
            insert)

    def testPickle(self):
        """Testing pickling and unpickling of disconnected FieldList instances."""
        field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
        unpickled = pickle.loads(pickle.dumps(field_list))
        self.assertEqual(field_list, unpickled)
        self.assertIsInstance(unpickled.field, messages.IntegerField)
        self.assertEqual(1, unpickled.field.number)
        self.assertTrue(unpickled.field.repeated)
class FieldTest(test_util.TestCase):
    """Tests for field definition, defaults and validation."""

    def ActionOnAllFieldClasses(self, action):
        """Test all field classes except Message and Enum.

        Message and Enum require separate tests.

        Args:
          action: Callable that takes the field class as a parameter.
        """
        for field_class in (messages.IntegerField,
                            messages.FloatField,
                            messages.BooleanField,
                            messages.BytesField,
                            messages.StringField,
                            ):
            action(field_class)

    def testNumberAttribute(self):
        """Test setting the number attribute."""
        def action(field_class):
            # Check range.
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              0)
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              -1)
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              messages.MAX_FIELD_NUMBER + 1)

            # Check reserved.
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              messages.FIRST_RESERVED_FIELD_NUMBER)
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              messages.LAST_RESERVED_FIELD_NUMBER)
            self.assertRaises(messages.InvalidNumberError,
                              field_class,
                              '1')

            # This one should work.
            field_class(number=1)

        self.ActionOnAllFieldClasses(action)

    def testRequiredAndRepeated(self):
        """Test setting the required and repeated fields."""
        def action(field_class):
            field_class(1, required=True)
            field_class(1, repeated=True)
            self.assertRaises(messages.FieldDefinitionError,
                              field_class,
                              1,
                              required=True,
                              repeated=True)

        self.ActionOnAllFieldClasses(action)

    def testInvalidVariant(self):
        """Test field with invalid variants."""
        def action(field_class):
            if field_class is not message_types.DateTimeField:
                self.assertRaises(messages.InvalidVariantError,
                                  field_class,
                                  1,
                                  variant=messages.Variant.ENUM)

        self.ActionOnAllFieldClasses(action)

    def testDefaultVariant(self):
        """Test that default variant is used when not set."""
        def action(field_class):
            field = field_class(1)
            self.assertEqual(field_class.DEFAULT_VARIANT, field.variant)

        self.ActionOnAllFieldClasses(action)

    def testAlternateVariant(self):
        """Test that an explicitly requested variant is used."""
        field = messages.IntegerField(1, variant=messages.Variant.UINT32)
        self.assertEqual(messages.Variant.UINT32, field.variant)

    def testDefaultFields_Single(self):
        """Test default field is correct type (single)."""
        defaults = {messages.IntegerField: 10,
                    messages.FloatField: 1.5,
                    messages.BooleanField: False,
                    messages.BytesField: b'abc',
                    messages.StringField: u'abc',
                    }

        def action(field_class):
            field_class(1, default=defaults[field_class])

        self.ActionOnAllFieldClasses(action)

        # Run defaults test again checking for str/unicode compatibility.
        defaults[messages.StringField] = 'abc'
        self.ActionOnAllFieldClasses(action)

    def testStringField_BadUnicodeInDefault(self):
        """Test binary values in string field."""
        self.assertRaisesWithRegexpMatch(
            messages.InvalidDefaultError,
            r"Invalid default value for StringField:.*: "
            r"Field encountered non-ASCII string .*: "
            r"'ascii' codec can't decode byte 0x89 in position 0: "
            r"ordinal not in range",
            messages.StringField, 1, default=b'\x89')

    def testDefaultFields_InvalidSingle(self):
        """Test default field is correct type (invalid single)."""
        def action(field_class):
            self.assertRaises(messages.InvalidDefaultError,
                              field_class,
                              1,
                              default=object())

        self.ActionOnAllFieldClasses(action)

    def testDefaultFields_InvalidRepeated(self):
        """Test that repeated fields do not accept defaults."""
        self.assertRaisesWithRegexpMatch(
            messages.FieldDefinitionError,
            'Repeated fields may not have defaults',
            messages.StringField, 1, repeated=True, default=[1, 2, 3])

    def testDefaultFields_None(self):
        """Test none is always acceptable."""
        def action(field_class):
            field_class(1, default=None)
            field_class(1, required=True, default=None)
            field_class(1, repeated=True, default=None)

        self.ActionOnAllFieldClasses(action)

    def testDefaultFields_Enum(self):
        """Test the default for enum fields."""
        class Symbol(messages.Enum):
            ALPHA = 1
            BETA = 2
            GAMMA = 3

        field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
        self.assertEqual(Symbol.ALPHA, field.default)

    def testDefaultFields_EnumStringDelayedResolution(self):
        """Test that enum fields resolve default strings."""
        field = messages.EnumField(
            'protorpc.descriptor.FieldDescriptor.Label',
            1,
            default='OPTIONAL')
        self.assertEqual(descriptor.FieldDescriptor.Label.OPTIONAL,
                         field.default)

    def testDefaultFields_EnumIntDelayedResolution(self):
        """Test that enum fields resolve default integers."""
        field = messages.EnumField(
            'protorpc.descriptor.FieldDescriptor.Label',
            1,
            default=2)
        self.assertEqual(descriptor.FieldDescriptor.Label.REQUIRED,
                         field.default)

    def testDefaultFields_EnumOkIfTypeKnown(self):
        """Test that enum fields accept valid defaults when type is known."""
        field = messages.EnumField(descriptor.FieldDescriptor.Label,
                                   1,
                                   default='REPEATED')
        self.assertEqual(descriptor.FieldDescriptor.Label.REPEATED,
                         field.default)

    def testDefaultFields_EnumForceCheckIfTypeKnown(self):
        """Test that enum fields validate default values if type is known."""
        self.assertRaisesWithRegexpMatch(
            TypeError,
            'No such value for NOT_A_LABEL in Enum Label',
            messages.EnumField,
            descriptor.FieldDescriptor.Label,
            1,
            default='NOT_A_LABEL')

    def testDefaultFields_EnumInvalidDelayedResolution(self):
        """Test that enum fields raise errors upon delayed resolution error."""
        field = messages.EnumField(
            'protorpc.descriptor.FieldDescriptor.Label',
            1,
            default=200)
        # The error is only expected once the default is actually read.
        self.assertRaisesWithRegexpMatch(
            TypeError,
            'No such value for 200 in Enum Label',
            getattr,
            field,
            'default')

    def testValidate_Valid(self):
        """Test validation of valid values."""
        values = {messages.IntegerField: 10,
                  messages.FloatField: 1.5,
                  messages.BooleanField: False,
                  messages.BytesField: b'abc',
                  messages.StringField: u'abc',
                  }

        def action(field_class):
            # Optional.
            field = field_class(1)
            field.validate(values[field_class])

            # Required.
            field = field_class(1, required=True)
            field.validate(values[field_class])

            # Repeated.
            field = field_class(1, repeated=True)
            field.validate([])
            field.validate(())
            field.validate([values[field_class]])
            field.validate((values[field_class],))

            # Right value, but not repeated.
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              values[field_class])

        self.ActionOnAllFieldClasses(action)

    def testValidate_Invalid(self):
        """Test validation of invalid values."""
        values = {messages.IntegerField: "10",
                  messages.FloatField: 1,
                  messages.BooleanField: 0,
                  messages.BytesField: 10.20,
                  messages.StringField: 42,
                  }

        def action(field_class):
            # Optional.
            field = field_class(1)
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              values[field_class])

            # Required.
            field = field_class(1, required=True)
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              values[field_class])

            # Repeated.
            field = field_class(1, repeated=True)
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              [values[field_class]])
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              (values[field_class],))

        self.ActionOnAllFieldClasses(action)

    def testValidate_None(self):
        """Test that None is valid for non-required fields."""
        def action(field_class):
            # Optional.
            field = field_class(1)
            field.validate(None)

            # Required.
            field = field_class(1, required=True)
            self.assertRaisesWithRegexpMatch(messages.ValidationError,
                                             'Required field is missing',
                                             field.validate,
                                             None)

            # Repeated.
            field = field_class(1, repeated=True)
            field.validate(None)
            self.assertRaisesWithRegexpMatch(
                messages.ValidationError,
                'Repeated values for %s may not be None'
                % field_class.__name__,
                field.validate,
                [None])
            self.assertRaises(messages.ValidationError,
                              field.validate,
                              (None,))

        self.ActionOnAllFieldClasses(action)

    def testValidateElement(self):
        """Test validation of individual field elements."""
        values = {messages.IntegerField: 10,
                  messages.FloatField: 1.5,
                  messages.BooleanField: False,
                  # Bytes literal, matching testValidate_Valid, so the value
                  # is a real bytes object on Python 3 as well.
                  messages.BytesField: b'abc',
                  messages.StringField: u'abc',
                  }

        def action(field_class):
            # Optional.
            field = field_class(1)
            field.validate_element(values[field_class])

            # Required.
            field = field_class(1, required=True)
            field.validate_element(values[field_class])

            # Repeated: a single element is expected to validate, while a
            # whole sequence handed to validate_element is expected to fail.
            field = field_class(1, repeated=True)
            self.assertRaises(messages.ValidationError,
                              field.validate_element,
                              [])
            self.assertRaises(messages.ValidationError,
                              field.validate_element,
                              ())
            field.validate_element(values[field_class])

            # Right value, but repeated.
            self.assertRaises(messages.ValidationError,
                              field.validate_element,
                              [values[field_class]])
            self.assertRaises(messages.ValidationError,
                              field.validate_element,
                              (values[field_class],))

        # Without this call the checks above would never execute.
        self.ActionOnAllFieldClasses(action)

    def testReadOnly(self):
        """Test that objects are all read-only."""
        def action(field_class):
            field = field_class(10)
            self.assertRaises(AttributeError,
                              setattr,
                              field,
                              'number',
                              20)
            self.assertRaises(AttributeError,
                              setattr,
                              field,
                              'anything_else',
                              'whatever')

        self.ActionOnAllFieldClasses(action)

    def testMessageField(self):
        """Test the construction of message fields."""
        self.assertRaises(messages.FieldDefinitionError,
                          messages.MessageField,
                          str,
                          10)
        self.assertRaises(messages.FieldDefinitionError,
                          messages.MessageField,
                          messages.Message,
                          10)

        class MyMessage(messages.Message):
            pass

        field = messages.MessageField(MyMessage, 10)
        self.assertEqual(MyMessage, field.type)

    def testMessageField_ForwardReference(self):
        """Test the construction of forward reference message fields."""
        # The classes are defined as module globals so that the string
        # references in the field definitions can be looked up by name.
        global MyMessage
        global ForwardMessage
        try:
            class MyMessage(messages.Message):
                self_reference = messages.MessageField('MyMessage', 1)
                forward = messages.MessageField('ForwardMessage', 2)
                nested = messages.MessageField(
                    'ForwardMessage.NestedMessage', 3)
                inner = messages.MessageField('Inner', 4)

                class Inner(messages.Message):
                    sibling = messages.MessageField('Sibling', 1)

                class Sibling(messages.Message):
                    pass

            class ForwardMessage(messages.Message):

                class NestedMessage(messages.Message):
                    pass

            self.assertEqual(MyMessage,
                             MyMessage.field_by_name('self_reference').type)
            self.assertEqual(ForwardMessage,
                             MyMessage.field_by_name('forward').type)
            self.assertEqual(ForwardMessage.NestedMessage,
                             MyMessage.field_by_name('nested').type)
            self.assertEqual(MyMessage.Inner,
                             MyMessage.field_by_name('inner').type)
            self.assertEqual(MyMessage.Sibling,
                             MyMessage.Inner.field_by_name('sibling').type)
        finally:
            # Remove each global independently: either may be missing if a
            # class definition above failed, and one missing name must not
            # prevent the other from being cleaned up.
            for name in ('MyMessage', 'ForwardMessage'):
                globals().pop(name, None)

    def testMessageField_WrongType(self):
        """Test that forward referencing the wrong type raises an error."""
        global AnEnum
        try:
            class AnEnum(messages.Enum):
                pass

            class AnotherMessage(messages.Message):
                a_field = messages.MessageField('AnEnum', 1)

            self.assertRaises(messages.FieldDefinitionError,
                              getattr,
                              AnotherMessage.field_by_name('a_field'),
                              'type')
        finally:
            # Tolerate a missing global so that a failure while defining the
            # class is not masked by a NameError from the cleanup.
            globals().pop('AnEnum', None)

    def testMessageFieldValidate(self):
        """Test validation on message field."""
        class MyMessage(messages.Message):
            pass

        class AnotherMessage(messages.Message):
            pass

        field = messages.MessageField(MyMessage, 10)
        field.validate(MyMessage())
        self.assertRaises(messages.ValidationError,
                          field.validate,
                          AnotherMessage())

    def testMessageFieldMessageType(self):
        """Test message_type property."""
        class MyMessage(messages.Message):
            pass

        class HasMessage(messages.Message):
            field = messages.MessageField(MyMessage, 1)

        self.assertEqual(HasMessage.field.type,
                         HasMessage.field.message_type)

    def testMessageFieldValueFromMessage(self):
        """Test that value_from_message returns the very same instance."""
        class MyMessage(messages.Message):
            pass

        class HasMessage(messages.Message):
            field = messages.MessageField(MyMessage, 1)

        instance = MyMessage()
        self.assertIs(instance,
                      HasMessage.field.value_from_message(instance))

    def testMessageFieldValueFromMessageWrongType(self):
        """Test that value_from_message rejects a non-message value."""
        class MyMessage(messages.Message):
            pass

        class HasMessage(messages.Message):
            field = messages.MessageField(MyMessage, 1)

        self.assertRaisesWithRegexpMatch(
            messages.DecodeError,
            'Expected type MyMessage, got int: 10',
            HasMessage.field.value_from_message, 10)

    def testMessageFieldValueToMessage(self):
        """Test that value_to_message returns the very same instance."""
        class MyMessage(messages.Message):
            pass

        class HasMessage(messages.Message):
            field = messages.MessageField(MyMessage, 1)

        instance = MyMessage()
        self.assertIs(instance, HasMessage.field.value_to_message(instance))

    def testMessageFieldValueToMessageWrongType(self):
        """Test that value_to_message rejects another message type."""
        class MyMessage(messages.Message):
            pass

        class MyOtherMessage(messages.Message):
            pass

        class HasMessage(messages.Message):
            field = messages.MessageField(MyMessage, 1)

        instance = MyOtherMessage()
        self.assertRaisesWithRegexpMatch(
            messages.EncodeError,
            'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
            HasMessage.field.value_to_message, instance)

    def testIntegerField_AllowLong(self):
        """Test that the integer field allows for longs."""
        if six.PY2:
            # `long` only exists on Python 2, hence the guard.
            messages.IntegerField(10, default=long(10))

    def testMessageFieldValidate_Initialized(self):
        """Test validation on message field."""
        class MyMessage(messages.Message):
            field1 = messages.IntegerField(1, required=True)

        field = messages.MessageField(MyMessage, 10)

        # Will validate messages where is_initialized() is False.
        message = MyMessage()
        field.validate(message)
        message.field1 = 20
        field.validate(message)

    def testEnumField(self):
        """Test the construction of enum fields."""
        self.assertRaises(messages.FieldDefinitionError,
                          messages.EnumField,
                          str,
                          10)
        self.assertRaises(messages.FieldDefinitionError,
                          messages.EnumField,
                          messages.Enum,
                          10)

        class Color(messages.Enum):
            RED = 1
            GREEN = 2
            BLUE = 3

        field = messages.EnumField(Color, 10)
        self.assertEqual(Color, field.type)

        class Another(messages.Enum):
            VALUE = 1

        # A default taken from a different enum type must be rejected.
        self.assertRaises(messages.InvalidDefaultError,
                          messages.EnumField,
                          Color,
                          10,
                          default=Another.VALUE)

    def testEnumField_ForwardReference(self):
        """Test the construction of forward reference enum fields."""
        # The classes are defined as module globals so that the string
        # references in the field definitions can be looked up by name.
        global MyMessage
        global ForwardEnum
        global ForwardMessage
        try:
            class MyMessage(messages.Message):
                forward = messages.EnumField('ForwardEnum', 1)
                nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
                inner = messages.EnumField('Inner', 3)

                class Inner(messages.Enum):
                    pass

            class ForwardEnum(messages.Enum):
                pass

            class ForwardMessage(messages.Message):

                class NestedEnum(messages.Enum):
                    pass

            self.assertEqual(ForwardEnum,
                             MyMessage.field_by_name('forward').type)
            self.assertEqual(ForwardMessage.NestedEnum,
                             MyMessage.field_by_name('nested').type)
            self.assertEqual(MyMessage.Inner,
                             MyMessage.field_by_name('inner').type)
        finally:
            # Remove each global independently: any of them may be missing
            # if a class definition above failed, and one missing name must
            # not prevent the others from being cleaned up.
            for name in ('MyMessage', 'ForwardEnum', 'ForwardMessage'):
                globals().pop(name, None)

    def testEnumField_WrongType(self):
        """Test that forward referencing the wrong type raises an error."""
        global AMessage
        try:
            class AMessage(messages.Message):
                pass

            class AnotherMessage(messages.Message):
                a_field = messages.EnumField('AMessage', 1)

            self.assertRaises(messages.FieldDefinitionError,
                              getattr,
                              AnotherMessage.field_by_name('a_field'),
                              'type')
        finally:
            # Tolerate a missing global so that a failure while defining the
            # class is not masked by a NameError from the cleanup.
            globals().pop('AMessage', None)

    def testMessageDefinition(self):
        """Test that message definition is set on fields."""
        class MyMessage(messages.Message):
            my_field = messages.StringField(1)

        self.assertEqual(
            MyMessage,
            MyMessage.field_by_name('my_field').message_definition())

    def testNoneAssignment(self):
        """Test that assigning None does not change comparison."""
        class MyMessage(messages.Message):
            my_field = messages.StringField(1)

        m1 = MyMessage()
        m2 = MyMessage()
        m2.my_field = None
        self.assertEqual(m1, m2)

    def testNonAsciiStr(self):
        """Test validation fails for non-ascii StringField values."""
        class Thing(messages.Message):
            string_field = messages.StringField(2)

        thing = Thing()
        self.assertRaisesWithRegexpMatch(
            messages.ValidationError,
            'Field string_field encountered non-ASCII string',
            setattr, thing, 'string_field', test_util.BINARY)
class MessageTest(test_util.TestCase):
"""Tests for message class."""
def CreateMessageClass(self):
    """Build and return a simple message class with three fields.

    The field names are in alphabetical order (a3, b1, c2) while their
    field numbers (3, 1, 2) deliberately follow a conflicting order.

    Returns:
      The newly defined ComplexMessage class.
    """
    class ComplexMessage(messages.Message):
        a3 = messages.IntegerField(3)
        b1 = messages.StringField(1)
        c2 = messages.StringField(2)

    return ComplexMessage
def testSameNumbers(self):
    """Check that two fields of one message cannot share a number."""
    def define_bad_message():
        class BadMessage(messages.Message):
            f1 = messages.IntegerField(1)
            f2 = messages.IntegerField(1)

    with self.assertRaises(messages.DuplicateNumberError):
        define_bad_message()
def testStrictAssignment(self):
    """Check that assigning to an undeclared attribute is refused."""
    class SimpleMessage(messages.Message):
        field = messages.IntegerField(1)

    instance = SimpleMessage()
    with self.assertRaises(AttributeError):
        setattr(instance, 'does_not_exist', 10)
def testListAssignmentDoesNotCopy(self):
    """Check that assigning a new list replaces the stored list object."""
    class SimpleMessage(messages.Message):
        repeated = messages.IntegerField(1, repeated=True)

    holder = SimpleMessage()
    before = holder.repeated
    holder.repeated = []
    # The value read back after assignment must be a different object from
    # the one that was held before.
    after = holder.repeated
    self.assertFalse(before is after)
def testValidate_Optional(self):
    """Check that an optional field passes check_initialized set or unset."""
    class SimpleMessage(messages.Message):
        non_required = messages.IntegerField(1)

    instance = SimpleMessage()

    # Field left unset.
    instance.check_initialized()

    # Field given a value.
    instance.non_required = 10
    instance.check_initialized()
def testValidate_Required(self):
    """Check that a required field must be set to pass check_initialized."""
    class SimpleMessage(messages.Message):
        required = messages.IntegerField(1, required=True)

    instance = SimpleMessage()

    # Field left unset: initialization check must fail.
    with self.assertRaises(messages.ValidationError):
        instance.check_initialized()

    # Field given a value: initialization check must pass.
    instance.required = 10
    instance.check_initialized()
def testValidate_Repeated(self):
    """Check which values a repeated integer field accepts and rejects."""
    class SimpleMessage(messages.Message):
        repeated = messages.IntegerField(1, repeated=True)

    instance = SimpleMessage()

    # Lists and tuples of integers, including empty ones, are accepted.
    accepted = ([], [10], [10, 20], (), (10,), (10, 20))
    for candidate in accepted:
        instance.repeated = candidate
        instance.check_initialized()

    # Clearing the field leaves the message initialized.
    instance.repeated = []
    instance.check_initialized()

    # Scalars, wrongly typed elements and None elements are rejected at
    # assignment time.
    rejected = (10, ['10', '20'], [None], (None,))
    for candidate in rejected:
        with self.assertRaises(messages.ValidationError):
            setattr(instance, 'repeated', candidate)
def testIsInitialized(self):
    """Check is_initialized before and after setting a required field."""
    class SimpleMessage(messages.Message):
        required = messages.IntegerField(1, required=True)

    instance = SimpleMessage()
    initialized_before = instance.is_initialized()
    self.assertFalse(initialized_before)

    instance.required = 10
    initialized_after = instance.is_initialized()
    self.assertTrue(initialized_after)
def testIsInitializedNestedField(self):
  """Tests that is_initialized takes nested message fields into account."""
  class SimpleMessage(messages.Message):
    required = messages.IntegerField(1, required=True)

  class NestedMessage(messages.Message):
    simple = messages.MessageField(SimpleMessage, 1)

  inner = SimpleMessage()
  self.assertFalse(inner.is_initialized())

  # The outer message holds the uninitialized inner message, so it reports
  # uninitialized as well.
  outer = NestedMessage(simple=inner)
  self.assertFalse(outer.is_initialized())

  # Completing the inner message makes both report initialized.
  inner.required = 10
  self.assertTrue(inner.is_initialized())
  self.assertTrue(outer.is_initialized())
def testInitializeNestedFieldFromDict(self):
  """Tests initializing nested message fields from plain dicts.

  Covers a singular MessageField populated through the constructor and
  through attribute assignment, and a repeated MessageField populated from
  a list that mixes dicts with an existing message instance.
  """
  class SimpleMessage(messages.Message):
    required = messages.IntegerField(1, required=True)

  class NestedMessage(messages.Message):
    simple = messages.MessageField(SimpleMessage, 1)

  class RepeatedMessage(messages.Message):
    simple = messages.MessageField(SimpleMessage, 1, repeated=True)

  # Singular field: dict handed to the constructor.
  nested_message1 = NestedMessage(simple={'required': 10})
  self.assertTrue(nested_message1.is_initialized())
  self.assertTrue(nested_message1.simple.is_initialized())

  # Singular field: dict assigned after construction.
  nested_message2 = NestedMessage()
  nested_message2.simple = {'required': 10}
  self.assertTrue(nested_message2.is_initialized())
  self.assertTrue(nested_message2.simple.is_initialized())

  # The empty dict produces an element whose required field is unset, so
  # each container starts out uninitialized until that element is completed.
  repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]

  # Repeated field: list handed to the constructor.
  repeated_message1 = RepeatedMessage(simple=repeated_values)
  self.assertEqual(3, len(repeated_message1.simple))
  self.assertFalse(repeated_message1.is_initialized())
  repeated_message1.simple[0].required = 0
  self.assertTrue(repeated_message1.is_initialized())

  # Repeated field: list assigned after construction.
  repeated_message2 = RepeatedMessage()
  repeated_message2.simple = repeated_values
  self.assertEqual(3, len(repeated_message2.simple))
  self.assertFalse(repeated_message2.is_initialized())
  repeated_message2.simple[0].required = 0
  self.assertTrue(repeated_message2.is_initialized())
def testNestedMethodsNotAllowed(self):
  """Test that method definitions on Message classes are not allowed."""
  def define_message_with_method():
    class WithMethods(messages.Message):
      def not_allowed(self):
        pass

  # Creating the class is what must fail, so it happens inside the context.
  with self.assertRaises(messages.MessageDefinitionError):
    define_message_with_method()
def testNestedAttributesNotAllowed(self):
  """Test that plain attribute assignments on Message classes are rejected."""
  def define_with_int():
    class WithMethods(messages.Message):
      not_allowed = 1

  def define_with_string():
    class WithMethods(messages.Message):
      not_allowed = 'not allowed'

  def define_with_enum():
    class WithMethods(messages.Message):
      not_allowed = Color.RED

  # Each definition is attempted separately; all must be refused.
  definers = (define_with_int, define_with_string, define_with_enum)
  for definer in definers:
    with self.assertRaises(messages.MessageDefinitionError):
      definer()
def testNameIsSetOnFields(self):
  """Make sure name is set on fields after Message class init."""
  class HasNamedFields(messages.Message):
    field = messages.StringField(1)

  # The field is looked up by its number; its name must match the class
  # attribute it was assigned to.
  self.assertEqual('field', HasNamedFields.field_by_number(1).name)
def testSubclassingMessageDisallowed(self):
  """Not permitted to create sub-classes of message classes."""
  class SuperClass(messages.Message):
    pass

  def define_subclass():
    class SubClass(SuperClass):
      pass

  # Deriving from a concrete message class must be refused.
  with self.assertRaises(messages.MessageDefinitionError):
    define_subclass()
def testAllFields(self):
  """Test all_fields method."""
  ComplexMessage = self.CreateMessageClass()

  # Order does not matter, so sort by name before comparing.
  fields = sorted(ComplexMessage.all_fields(), key=lambda f: f.name)

  self.assertEqual(3, len(fields))
  self.assertEqual('a3', fields[0].name)
  self.assertEqual('b1', fields[1].name)
  self.assertEqual('c2', fields[2].name)
def testFieldByName(self):
  """Test getting field by name."""
  ComplexMessage = self.CreateMessageClass()

  # Each known name maps to the field carrying the expected number.
  self.assertEqual(3, ComplexMessage.field_by_name('a3').number)
  self.assertEqual(1, ComplexMessage.field_by_name('b1').number)
  self.assertEqual(2, ComplexMessage.field_by_name('c2').number)

  # An unknown name is reported with KeyError.
  self.assertRaises(KeyError,
                    ComplexMessage.field_by_name,
                    'unknown')
def testFieldByNumber(self):
  """Test getting field by number."""
  ComplexMessage = self.CreateMessageClass()

  # Each known number maps to the field carrying the expected name.
  self.assertEqual('a3', ComplexMessage.field_by_number(3).name)
  self.assertEqual('b1', ComplexMessage.field_by_number(1).name)
  self.assertEqual('c2', ComplexMessage.field_by_number(2).name)

  # An unused number is reported with KeyError.
  self.assertRaises(KeyError,
                    ComplexMessage.field_by_number,
                    4)
def testGetAssignedValue(self):
  """Test getting the assigned value of a field."""
  class SomeMessage(messages.Message):
    a_value = messages.StringField(1, default=u'a default')

  message = SomeMessage()

  # Nothing assigned yet: the default is not reported as an assigned value.
  self.assertEqual(None, message.get_assigned_value('a_value'))

  message.a_value = u'a string'
  self.assertEqual(u'a string', message.get_assigned_value('a_value'))

  # Explicitly assigning a value equal to the default counts as assigned.
  message.a_value = u'a default'
  self.assertEqual(u'a default', message.get_assigned_value('a_value'))

  # Asking for a name that is not a field raises AttributeError.
  self.assertRaisesWithRegexpMatch(
      AttributeError,
      'Message SomeMessage has no field no_such_field',
      message.get_assigned_value,
      'no_such_field')
def testReset(self):
  """Test resetting a field value."""
  class SomeMessage(messages.Message):
    a_value = messages.StringField(1, default=u'a default')
    repeated = messages.IntegerField(2, repeated=True)

  message = SomeMessage()

  # Unknown field names are rejected.
  self.assertRaises(AttributeError, message.reset, 'unknown')

  # Resetting a field that was never assigned leaves the default visible.
  self.assertEqual(u'a default', message.a_value)
  message.reset('a_value')
  self.assertEqual(u'a default', message.a_value)

  # Resetting an assigned field brings the default back.
  message.a_value = u'a new value'
  self.assertEqual(u'a new value', message.a_value)
  message.reset('a_value')
  self.assertEqual(u'a default', message.a_value)

  # Resetting a repeated field gives an empty FieldList, while the list
  # obtained before the reset keeps its contents.
  message.repeated = [1, 2, 3]
  self.assertEqual([1, 2, 3], message.repeated)
  saved = message.repeated
  message.reset('repeated')
  self.assertEqual([], message.repeated)
  self.assertIsInstance(message.repeated, messages.FieldList)
  self.assertEqual([1, 2, 3], saved)
def testAllowNestedEnums(self):
  """Test allowing nested enums in a message definition."""
  class Trade(messages.Message):
    class Duration(messages.Enum):
      GTC = 1
      DAY = 2

    class Currency(messages.Enum):
      USD = 1
      GBP = 2
      INR = 3

  # Sorted by name order seems to be the only feasible option.
  self.assertEqual(['Currency', 'Duration'], Trade.__enums__)

  # Message definition will now be set on Enumerated objects.
  self.assertEqual(Trade, Trade.Duration.message_definition())
def testAllowNestedMessages(self):
  """Test allowing nested messages in a message definition."""
  class Trade(messages.Message):
    class Lot(messages.Message):
      pass

    class Agent(messages.Message):
      pass

  # Sorted by name order seems to be the only feasible option.
  self.assertEqual(['Agent', 'Lot'], Trade.__messages__)
  self.assertEqual(Trade, Trade.Agent.message_definition())
  self.assertEqual(Trade, Trade.Lot.message_definition())

  # But not Message itself: assigning the base class as a nested attribute
  # must be refused when the class is created.
  def action():
    class Trade(messages.Message):
      NiceTry = messages.Message

  self.assertRaises(messages.MessageDefinitionError, action)
def testDisallowClassAssignments(self):
  """Test that class attributes may not be set after definition."""
  class MyMessage(messages.Message):
    pass

  # Assigning on the class object itself must raise.
  with self.assertRaises(AttributeError):
    MyMessage.x = 'do not assign'
def testEquality(self):
  """Test message class equality.

  Walks two instances through integer, enum and nested-message fields,
  checking that they compare equal only while every field matches.
  """
  # Comparison against enums must work.
  class MyEnum(messages.Enum):
    val1 = 1
    val2 = 2

  # Comparisons against nested messages must work.
  class AnotherMessage(messages.Message):
    string = messages.StringField(1)

  class MyMessage(messages.Message):
    field1 = messages.IntegerField(1)
    field2 = messages.EnumField(MyEnum, 2)
    field3 = messages.MessageField(AnotherMessage, 3)

  message1 = MyMessage()

  # Unrelated values and other message types never compare equal.
  self.assertNotEqual('hi', message1)
  self.assertNotEqual(AnotherMessage(), message1)
  self.assertEqual(message1, message1)

  message2 = MyMessage()
  self.assertEqual(message1, message2)

  # Integer field.
  message1.field1 = 10
  self.assertNotEqual(message1, message2)
  message2.field1 = 20
  self.assertNotEqual(message1, message2)
  message2.field1 = 10
  self.assertEqual(message1, message2)

  # Enum field.
  message1.field2 = MyEnum.val1
  self.assertNotEqual(message1, message2)
  message2.field2 = MyEnum.val2
  self.assertNotEqual(message1, message2)
  message2.field2 = MyEnum.val1
  self.assertEqual(message1, message2)

  # Nested message field.
  message1.field3 = AnotherMessage()
  message1.field3.string = 'value1'
  self.assertNotEqual(message1, message2)
  message2.field3 = AnotherMessage()
  message2.field3.string = 'value2'
  self.assertNotEqual(message1, message2)
  message2.field3.string = 'value1'
  self.assertEqual(message1, message2)
def testEqualityWithUnknowns(self):
  """Test message class equality with unknown fields."""
  class MyMessage(messages.Message):
    field1 = messages.IntegerField(1)

  message1 = MyMessage()
  message2 = MyMessage()
  self.assertEqual(message1, message2)

  # Unrecognized fields recorded on only one side must not affect equality.
  message1.set_unrecognized_field('unknown1', 'value1',
                                  messages.Variant.STRING)
  self.assertEqual(message1, message2)

  message1.set_unrecognized_field('unknown2', ['asdf', 3],
                                  messages.Variant.STRING)
  message1.set_unrecognized_field('unknown3', 4.7,
                                  messages.Variant.DOUBLE)
  self.assertEqual(message1, message2)
def testUnrecognizedFieldInvalidVariant(self):
  """Tests that set_unrecognized_field raises TypeError for bad variants."""
  class MyMessage(messages.Message):
    field1 = messages.IntegerField(1)

  message1 = MyMessage()

  # Neither None nor an arbitrary integer is accepted in the variant slot.
  for bad_variant in (None, 123):
    with self.assertRaises(TypeError):
      message1.set_unrecognized_field(
          'unknown4', {'unhandled': 'type'}, bad_variant)
def testRepr(self):
  """Test representation of Message object."""
  class MyMessage(messages.Message):
    integer_value = messages.IntegerField(1)
    string_value = messages.StringField(2)
    unassigned = messages.StringField(3)
    unassigned_with_default = messages.StringField(4, default=u'a default')

  instance = MyMessage()
  instance.integer_value = 42
  instance.string_value = u'A string'

  # Only the two assigned fields appear in the expected text; the optional
  # "u" prefix accommodates both unicode repr styles.
  expected = re.compile(r"<MyMessage\n integer_value: 42\n"
                        " string_value: [u]?'A string'>")
  self.assertIsNotNone(expected.match(repr(instance)))
def testValidation(self):
  """Test validation of message values for message-typed fields.

  The same checks are repeated for an optional, a required and a repeated
  MessageField by redefining the local Message class each time.
  """
  class SubMessage(messages.Message):
    pass

  # Optional field.
  class Message(messages.Message):
    val = messages.MessageField(SubMessage, 1)

  message = Message()
  message_field = messages.MessageField(Message, 1)
  message_field.validate(message)
  message.val = SubMessage()
  message_field.validate(message)
  wrapped = [SubMessage()]
  with self.assertRaises(messages.ValidationError):
    # A list is not acceptable for a singular field.
    message.val = wrapped

  # Required field.
  class Message(messages.Message):
    val = messages.MessageField(SubMessage, 1, required=True)

  message = Message()
  message_field = messages.MessageField(Message, 1)
  message_field.validate(message)
  message.val = SubMessage()
  message_field.validate(message)
  wrapped = [SubMessage()]
  with self.assertRaises(messages.ValidationError):
    message.val = wrapped

  # Repeated field.
  class Message(messages.Message):
    val = messages.MessageField(SubMessage, 1, repeated=True)

  message = Message()
  message_field = messages.MessageField(Message, 1)
  message_field.validate(message)
  # A bare message is not acceptable for a repeated field.
  self.assertRaisesWithRegexpMatch(
      messages.ValidationError,
      "Field val is repeated. Found: <SubMessage>",
      setattr, message, 'val', SubMessage())
  message.val = [SubMessage()]
  message_field.validate(message)
def testDefinitionName(self):
  """Test message name."""
  class MyMessage(messages.Message):
    pass

  # Use this test class to resolve the module name, matching
  # testDefinitionName_Nested.
  module_name = test_util.get_module_name(MessageTest)

  self.assertEqual('%s.MyMessage' % module_name,
                   MyMessage.definition_name())
  self.assertEqual(module_name, MyMessage.outer_definition_name())
  self.assertEqual(module_name, MyMessage.definition_package())

  # All three names must be text (unicode) strings.
  self.assertEqual(six.text_type, type(MyMessage.definition_name()))
  self.assertEqual(six.text_type, type(MyMessage.outer_definition_name()))
  self.assertEqual(six.text_type, type(MyMessage.definition_package()))
def testDefinitionName_OverrideModule(self):
  """Test message module is overridden by module package name."""
  class MyMessage(messages.Message):
    pass

  # A module-level `package` variable is set only for the duration of the
  # assertions and removed again in the finally clause.
  global package
  package = 'my.package'
  try:
    self.assertEqual('my.package.MyMessage', MyMessage.definition_name())
    self.assertEqual('my.package', MyMessage.outer_definition_name())
    self.assertEqual('my.package', MyMessage.definition_package())

    # All three names must be text (unicode) strings.
    self.assertEqual(six.text_type, type(MyMessage.definition_name()))
    self.assertEqual(six.text_type, type(MyMessage.outer_definition_name()))
    self.assertEqual(six.text_type, type(MyMessage.definition_package()))
  finally:
    del package
def testDefinitionName_NoModule(self):
  """Test what happens when there is no module for message."""
  class MyMessage(messages.Message):
    pass

  # Swap in a copy of sys.modules so this module can be removed from it
  # without touching the real registry; the original is restored afterwards.
  original_modules = sys.modules
  sys.modules = dict(sys.modules)
  try:
    del sys.modules[__name__]
    self.assertEqual('MyMessage', MyMessage.definition_name())
    self.assertEqual(None, MyMessage.outer_definition_name())
    self.assertEqual(None, MyMessage.definition_package())

    self.assertEqual(six.text_type, type(MyMessage.definition_name()))
  finally:
    sys.modules = original_modules
def testDefinitionName_Nested(self):
  """Test nested message names."""
  class MyMessage(messages.Message):
    class NestedMessage(messages.Message):
      class NestedMessage(messages.Message):
        pass

  module_name = test_util.get_module_name(MessageTest)

  # First nesting level.
  self.assertEqual('%s.MyMessage.NestedMessage' % module_name,
                   MyMessage.NestedMessage.definition_name())
  self.assertEqual('%s.MyMessage' % module_name,
                   MyMessage.NestedMessage.outer_definition_name())
  self.assertEqual(module_name,
                   MyMessage.NestedMessage.definition_package())

  # Second nesting level; the package stays the module name.
  self.assertEqual('%s.MyMessage.NestedMessage.NestedMessage' % module_name,
                   MyMessage.NestedMessage.NestedMessage.definition_name())
  self.assertEqual(
      '%s.MyMessage.NestedMessage' % module_name,
      MyMessage.NestedMessage.NestedMessage.outer_definition_name())
  self.assertEqual(
      module_name,
      MyMessage.NestedMessage.NestedMessage.definition_package())
def testMessageDefinition(self):
  """Test that a nested message knows its enclosing message definition."""
  class OuterMessage(messages.Message):
    class InnerMessage(messages.Message):
      pass

  # A message defined outside any other message has no enclosing definition.
  self.assertEqual(None, OuterMessage.message_definition())
  self.assertEqual(OuterMessage,
                   OuterMessage.InnerMessage.message_definition())
def testConstructorKwargs(self):
  """Test kwargs via constructor."""
  class SomeMessage(messages.Message):
    name = messages.StringField(1)
    number = messages.IntegerField(2)

  # Build the expected message by attribute assignment and compare it with
  # one built from keyword arguments.
  expected = SomeMessage()
  expected.name = 'my name'
  expected.number = 200
  self.assertEqual(expected, SomeMessage(name='my name', number=200))
def testConstructorNotAField(self):
  """Test that constructor kwargs naming no field are rejected."""
  class SomeMessage(messages.Message):
    pass

  expected_error = (
      'May not assign arbitrary value does_not_exist to message SomeMessage')
  self.assertRaisesWithRegexpMatch(
      AttributeError, expected_error, SomeMessage, does_not_exist=10)
def testGetUnsetRepeatedValue(self):
  """Test that an unset repeated field reads as an empty FieldList."""
  class SomeMessage(messages.Message):
    repeated = messages.IntegerField(1, repeated=True)

  instance = SomeMessage()
  self.assertEqual([], instance.repeated)
  self.assertIsInstance(instance.repeated, messages.FieldList)
def testCompareAutoInitializedRepeatedFields(self):
  """Test that an explicit empty repeated field equals an unset one."""
  class SomeMessage(messages.Message):
    repeated = messages.IntegerField(1, repeated=True)

  message1 = SomeMessage(repeated=[])
  message2 = SomeMessage()
  self.assertEqual(message1, message2)
def testUnknownValues(self):
  """Test storing and reading back unrecognized fields on a message."""
  class MyMessage(messages.Message):
    field1 = messages.IntegerField(1)

  message = MyMessage()

  # Nothing recorded yet: lookups fall back to the supplied defaults, or to
  # (None, None) when no defaults are given.
  self.assertEqual([], message.all_unrecognized_fields())
  self.assertEqual((None, None),
                   message.get_unrecognized_field_info('doesntexist'))
  self.assertEqual((None, None),
                   message.get_unrecognized_field_info(
                       'doesntexist', None, None))
  self.assertEqual(('defaultvalue', 'defaultwire'),
                   message.get_unrecognized_field_info(
                       'doesntexist', 'defaultvalue', 'defaultwire'))
  self.assertEqual((3, None),
                   message.get_unrecognized_field_info(
                       'doesntexist', value_default=3))

  # One recorded field: defaults are ignored for it but still apply to
  # names that were never recorded.
  message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
  self.assertEqual(1, len(message.all_unrecognized_fields()))
  self.assertIn('exists', message.all_unrecognized_fields())
  self.assertEqual((9.5, messages.Variant.DOUBLE),
                   message.get_unrecognized_field_info('exists'))
  self.assertEqual((9.5, messages.Variant.DOUBLE),
                   message.get_unrecognized_field_info('exists', 'type',
                                                       1234))
  self.assertEqual((1234, None),
                   message.get_unrecognized_field_info('doesntexist', 1234))

  # A second recorded field does not disturb the first.
  message.set_unrecognized_field('another', 'value', messages.Variant.STRING)
  self.assertEqual(2, len(message.all_unrecognized_fields()))
  self.assertIn('exists', message.all_unrecognized_fields())
  self.assertIn('another', message.all_unrecognized_fields())
  self.assertEqual((9.5, messages.Variant.DOUBLE),
                   message.get_unrecognized_field_info('exists'))
  self.assertEqual(('value', messages.Variant.STRING),
                   message.get_unrecognized_field_info('another'))

  # List values and empty strings are stored and returned unchanged.
  message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
                                 messages.Variant.STRING)
  self.assertEqual((['list', 0, ('test',)], messages.Variant.STRING),
                   message.get_unrecognized_field_info('typetest1'))
  message.set_unrecognized_field('typetest2', '', messages.Variant.STRING)
  self.assertEqual(('', messages.Variant.STRING),
                   message.get_unrecognized_field_info('typetest2'))
def testPickle(self):
  """Testing pickling and unpickling of Message instances."""
  # The classes are published as module globals, presumably so that pickle
  # can resolve them by module-level name when loading.
  global MyEnum
  global AnotherMessage
  global MyMessage

  class MyEnum(messages.Enum):
    val1 = 1
    val2 = 2

  class AnotherMessage(messages.Message):
    string = messages.StringField(1, repeated=True)

  class MyMessage(messages.Message):
    field1 = messages.IntegerField(1)
    field2 = messages.EnumField(MyEnum, 2)
    field3 = messages.MessageField(AnotherMessage, 3)

  message = MyMessage(field1=1, field2=MyEnum.val2,
                      field3=AnotherMessage(string=['a', 'b', 'c']))
  message.set_unrecognized_field('exists', 'value', messages.Variant.STRING)
  message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
                                 messages.Variant.STRING)

  # Round-trip through pickle (trusted, locally produced data only).
  unpickled = pickle.loads(pickle.dumps(message))
  self.assertEqual(message, unpickled)
  # The restored repeated value must be bound to the class's field object.
  self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)

  # NOTE(review): the checks below inspect the original `message`, not
  # `unpickled`, so they do not show that unrecognized fields survive
  # pickling. Confirm whether `unpickled` was intended before changing.
  self.assertIn('exists', message.all_unrecognized_fields())
  self.assertEqual(('value', messages.Variant.STRING),
                   message.get_unrecognized_field_info('exists'))
  self.assertEqual((['list', 0, ('test',)], messages.Variant.STRING),
                   message.get_unrecognized_field_info('repeated'))
class FindDefinitionTest(test_util.TestCase):
  """Test finding definitions relative to various definitions and modules."""

  def setUp(self):
    """Set up module-space.  Starts off empty."""
    self.modules = {}

  def DefineModule(self, name):
    """Define a module and its parents in module space.

    Modules that are already defined in self.modules are not re-created.

    Args:
      name: Fully qualified name of modules to create.

    Returns:
      Deepest nested module.  For example:

        DefineModule('a.b.c')  # Returns c.
    """
    name_path = name.split('.')
    full_path = []
    for node in name_path:
      full_path.append(node)
      full_name = '.'.join(full_path)
      # Only create a module object when this level is not yet registered.
      if full_name not in self.modules:
        self.modules[full_name] = types.ModuleType(full_name)
    return self.modules[name]

  def DefineMessage(self, module, name, children=None, add_to_module=True):
    """Define a new Message class in the context of a module.

    Used for easily describing complex Message hierarchy.  Message is defined
    including all child definitions.

    Args:
      module: Fully qualified name of module to place Message class in.
      name: Name of Message to define within module.
      children: Define any level of nesting of children definitions.  To define
        a message, map the name to another dictionary.  The dictionary can
        itself contain additional definitions, and so on.  To map to an Enum,
        define the Enum class separately and map it by name.  The mapping is
        copied and never modified.  Defaults to no children.
      add_to_module: If True, new Message class is added to module.  If False,
        new Message is not added.

    Returns:
      The newly created Message class.
    """
    # Make sure module exists.
    module_instance = self.DefineModule(module)

    # Work on a private copy so that neither the caller's mapping nor a
    # shared default is mutated by the bookkeeping below.
    attributes = dict(children) if children else {}

    # Recursively define all child messages.
    for attribute, value in list(attributes.items()):
      if isinstance(value, dict):
        attributes[attribute] = self.DefineMessage(
            module, attribute, value, False)

    # Override default __module__ variable.
    attributes['__module__'] = module

    # Instantiate and possibly add to module.
    message_class = type(name, (messages.Message,), attributes)
    if add_to_module:
      setattr(module_instance, name, message_class)
    return message_class

  def Importer(self, module, globals='', locals='', fromlist=None):
    """Importer function.

    Acts like __import__.  Only loads modules from self.modules.  Does not
    try to load real modules defined elsewhere.  Does not try to handle relative
    imports.

    Args:
      module: Fully qualified name of module to load from self.modules.
      globals: Unused; accepted to match the __import__ signature.
      locals: Unused; accepted to match the __import__ signature.
      fromlist: When None, only the first dotted component of `module` is
        looked up; otherwise the full name is looked up.

    Returns:
      The module object registered in self.modules.

    Raises:
      ImportError: If the looked-up name is not in self.modules.
    """
    if fromlist is None:
      module = module.split('.')[0]
    try:
      return self.modules[module]
    except KeyError:
      raise ImportError('No module named %s' % module)

  def testNoSuchModule(self):
    """Test searching for definitions that do not exist."""
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'does.not.exist',
                      importer=self.Importer)

  def testRefersToModule(self):
    """Test that referring to a module does not return that module."""
    self.DefineModule('i.am.a.module')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module',
                      importer=self.Importer)

  def testNoDefinition(self):
    """Test not finding a definition in an existing module."""
    self.DefineModule('i.am.a.module')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module.MyMessage',
                      importer=self.Importer)

  def testNotADefinition(self):
    """Test trying to fetch something that is not a definition."""
    module = self.DefineModule('i.am.a.module')
    setattr(module, 'A', 'a string')
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'i.am.a.module.A',
                      importer=self.Importer)

  def testGlobalFind(self):
    """Test finding definitions from fully qualified module names."""
    A = self.DefineMessage('a.b.c', 'A', {})
    self.assertEqual(A, messages.find_definition('a.b.c.A',
                                                 importer=self.Importer))
    B = self.DefineMessage('a.b.c', 'B', {'C': {}})
    self.assertEqual(B.C, messages.find_definition('a.b.c.B.C',
                                                   importer=self.Importer))

  def testRelativeToModule(self):
    """Test finding definitions relative to modules."""
    # Define modules.
    a = self.DefineModule('a')
    b = self.DefineModule('a.b')
    c = self.DefineModule('a.b.c')

    # Define messages.
    A = self.DefineMessage('a', 'A')
    B = self.DefineMessage('a.b', 'B')
    C = self.DefineMessage('a.b.c', 'C')
    D = self.DefineMessage('a.b.d', 'D')

    # Find A, B, C and D relative to a.
    self.assertEqual(A, messages.find_definition(
        'A', a, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'b.B', a, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'b.c.C', a, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'b.d.D', a, importer=self.Importer))

    # Find A, B, C and D relative to b.
    self.assertEqual(A, messages.find_definition(
        'A', b, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', b, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'c.C', b, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'd.D', b, importer=self.Importer))

    # Find A, B, C and D relative to c.  Module d is the same case as c.
    self.assertEqual(A, messages.find_definition(
        'A', c, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', c, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', c, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'd.D', c, importer=self.Importer))

  def testRelativeToMessages(self):
    """Test finding definitions relative to Message definitions."""
    A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
    B = A.B
    C = A.B.C
    D = A.B.D

    # Find relative to A.
    self.assertEqual(A, messages.find_definition(
        'A', A, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', A, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'B.C', A, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'B.D', A, importer=self.Importer))

    # Find relative to B.
    self.assertEqual(A, messages.find_definition(
        'A', B, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', B, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', B, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'D', B, importer=self.Importer))

    # Find relative to C.
    self.assertEqual(A, messages.find_definition(
        'A', C, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'B', C, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'C', C, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'D', C, importer=self.Importer))

    # Find relative to C using names qualified from module b.
    self.assertEqual(A, messages.find_definition(
        'b.A', C, importer=self.Importer))
    self.assertEqual(B, messages.find_definition(
        'b.A.B', C, importer=self.Importer))
    self.assertEqual(C, messages.find_definition(
        'b.A.B.C', C, importer=self.Importer))
    self.assertEqual(D, messages.find_definition(
        'b.A.B.D', C, importer=self.Importer))

  def testAbsoluteReference(self):
    """Test finding absolute definition names."""
    # Define modules.
    a = self.DefineModule('a')
    self.DefineModule('a.a')

    # Define messages.
    aA = self.DefineMessage('a', 'A')
    aaA = self.DefineMessage('a.a', 'A')

    # Always find a.A, whatever the lookup is relative to.
    self.assertEqual(aA, messages.find_definition('.a.A', None,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', a,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', aA,
                                                  importer=self.Importer))
    self.assertEqual(aA, messages.find_definition('.a.A', aaA,
                                                  importer=self.Importer))

  def testFindEnum(self):
    """Test that Enums are found."""
    class Color(messages.Enum):
      pass

    A = self.DefineMessage('a', 'A', {'Color': Color})

    self.assertEqual(
        Color,
        messages.find_definition('Color', A, importer=self.Importer))

  def testFalseScope(self):
    """Test that Message definitions nested in strange objects are hidden."""
    # X is published as a module global so the second lookup can start from
    # this module.
    global X

    class X(object):
      class A(messages.Message):
        pass

    self.assertRaises(TypeError, messages.find_definition, 'A', X)
    self.assertRaises(messages.DefinitionNotFoundError,
                      messages.find_definition,
                      'X.A', sys.modules[__name__])

  def testSearchAttributeFirst(self):
    """Make sure not faked out by module, but continues searching."""
    A = self.DefineMessage('a', 'A')
    # A module sharing the message's qualified name must not shadow it.
    self.DefineModule('a.A')

    self.assertEqual(A, messages.find_definition(
        'a.A', None, importer=self.Importer))
class FindDefinitionUnicodeTests(test_util.TestCase):
  """Tests for find_definition with unicode names (currently disabled)."""

  # TODO(craigcitro): Fix this test and re-enable it.
  # The method name deliberately lacks the "test" prefix so it is not run.
  def notatestUnicodeString(self):
    """Test using unicode names."""
    from protorpc import registry
    self.assertEqual('ServiceMapping',
                     messages.find_definition(
                         u'protorpc.registry.ServiceMapping',
                         None).__name__)
def main():
  """Run this module's tests through the unittest command-line runner."""
  unittest.main()
# Run the tests only when executed as a script, not when imported.
if __name__ == '__main__':
  main()