| #!/usr/bin/env python |
| # |
| # Copyright 2010 Google Inc. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Tests for protorpc.stub.""" |
| |
| __author__ = 'rafek@google.com (Rafe Kaplan)' |
| |
| |
| import StringIO |
| import sys |
| import types |
| import unittest |
| |
| from protorpc import definition |
| from protorpc import descriptor |
| from protorpc import message_types |
| from protorpc import messages |
| from protorpc import protobuf |
| from protorpc import remote |
| from protorpc import test_util |
| |
| import mox |
| |
| |
| class ModuleInterfaceTest(test_util.ModuleInterfaceTest, |
| test_util.TestCase): |
| |
| MODULE = definition |
| |
| |
| class DefineEnumTest(test_util.TestCase): |
| """Test for define_enum.""" |
| |
| def testDefineEnum_Empty(self): |
| """Test defining an empty enum.""" |
| enum_descriptor = descriptor.EnumDescriptor() |
| enum_descriptor.name = 'Empty' |
| |
| enum_class = definition.define_enum(enum_descriptor, 'whatever') |
| |
| self.assertEquals('Empty', enum_class.__name__) |
| self.assertEquals('whatever', enum_class.__module__) |
| |
| self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) |
| |
| def testDefineEnum(self): |
| """Test defining an enum.""" |
| red = descriptor.EnumValueDescriptor() |
| green = descriptor.EnumValueDescriptor() |
| blue = descriptor.EnumValueDescriptor() |
| |
| red.name = 'RED' |
| red.number = 1 |
| green.name = 'GREEN' |
| green.number = 2 |
| blue.name = 'BLUE' |
| blue.number = 3 |
| |
| enum_descriptor = descriptor.EnumDescriptor() |
| enum_descriptor.name = 'Colors' |
| enum_descriptor.values = [red, green, blue] |
| |
| enum_class = definition.define_enum(enum_descriptor, 'whatever') |
| |
| self.assertEquals('Colors', enum_class.__name__) |
| self.assertEquals('whatever', enum_class.__module__) |
| |
| self.assertEquals(enum_descriptor, descriptor.describe_enum(enum_class)) |
| |
| |
| class DefineFieldTest(test_util.TestCase): |
| """Test for define_field.""" |
| |
| def testDefineField_Optional(self): |
| """Test defining an optional field instance from a method descriptor.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT32 |
| field_descriptor.label = descriptor.FieldDescriptor.Label.OPTIONAL |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.IntegerField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.INT32, field.variant) |
| self.assertFalse(field.required) |
| self.assertFalse(field.repeated) |
| |
| def testDefineField_Required(self): |
| """Test defining a required field instance from a method descriptor.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING |
| field_descriptor.label = descriptor.FieldDescriptor.Label.REQUIRED |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.StringField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) |
| self.assertTrue(field.required) |
| self.assertFalse(field.repeated) |
| |
| def testDefineField_Repeated(self): |
| """Test defining a repeated field instance from a method descriptor.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.DOUBLE |
| field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.FloatField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.DOUBLE, field.variant) |
| self.assertFalse(field.required) |
| self.assertTrue(field.repeated) |
| |
| def testDefineField_Message(self): |
| """Test defining a message field.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE |
| field_descriptor.type_name = 'something.yet.to.be.Defined' |
| field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.MessageField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) |
| self.assertFalse(field.required) |
| self.assertTrue(field.repeated) |
| self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, |
| 'Could not find definition for ' |
| 'something.yet.to.be.Defined', |
| getattr, field, 'type') |
| |
| def testDefineField_DateTime(self): |
| """Test defining a date time field.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_timestamp' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.MESSAGE |
| field_descriptor.type_name = 'protorpc.message_types.DateTimeMessage' |
| field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, message_types.DateTimeField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.MESSAGE, field.variant) |
| self.assertFalse(field.required) |
| self.assertTrue(field.repeated) |
| |
| def testDefineField_Enum(self): |
| """Test defining an enum field.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.ENUM |
| field_descriptor.type_name = 'something.yet.to.be.Defined' |
| field_descriptor.label = descriptor.FieldDescriptor.Label.REPEATED |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.EnumField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.ENUM, field.variant) |
| self.assertFalse(field.required) |
| self.assertTrue(field.repeated) |
| self.assertRaisesWithRegexpMatch(messages.DefinitionNotFoundError, |
| 'Could not find definition for ' |
| 'something.yet.to.be.Defined', |
| getattr, field, 'type') |
| |
| def testDefineField_Default_Bool(self): |
| """Test defining a default value for a bool.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.BOOL |
| field_descriptor.default_value = u'true' |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.BooleanField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.BOOL, field.variant) |
| self.assertFalse(field.required) |
| self.assertFalse(field.repeated) |
| self.assertEqual(field.default, True) |
| |
| field_descriptor.default_value = u'false' |
| |
| field = definition.define_field(field_descriptor) |
| |
| self.assertEqual(field.default, False) |
| |
| def testDefineField_Default_Float(self): |
| """Test defining a default value for a float.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.FLOAT |
| field_descriptor.default_value = u'34.567' |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.FloatField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.FLOAT, field.variant) |
| self.assertFalse(field.required) |
| self.assertFalse(field.repeated) |
| self.assertEqual(field.default, 34.567) |
| |
| def testDefineField_Default_Int(self): |
| """Test defining a default value for an int.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 |
| field_descriptor.default_value = u'34' |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.IntegerField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.INT64, field.variant) |
| self.assertFalse(field.required) |
| self.assertFalse(field.repeated) |
| self.assertEqual(field.default, 34) |
| |
| def testDefineField_Default_Str(self): |
| """Test defining a default value for a str.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.STRING |
| field_descriptor.default_value = u'Test' |
| |
| field = definition.define_field(field_descriptor) |
| |
| # Name will not be set from the original descriptor. |
| self.assertFalse(hasattr(field, 'name')) |
| |
| self.assertTrue(isinstance(field, messages.StringField)) |
| self.assertEquals(1, field.number) |
| self.assertEquals(descriptor.FieldDescriptor.Variant.STRING, field.variant) |
| self.assertFalse(field.required) |
| self.assertFalse(field.repeated) |
| self.assertEqual(field.default, u'Test') |
| |
| def testDefineField_Default_Invalid(self): |
| """Test defining a default value that is not valid.""" |
| field_descriptor = descriptor.FieldDescriptor() |
| |
| field_descriptor.name = 'a_field' |
| field_descriptor.number = 1 |
| field_descriptor.variant = descriptor.FieldDescriptor.Variant.INT64 |
| field_descriptor.default_value = u'Test' |
| |
| # Verify that the string is passed to the Constructor. |
| mock = mox.Mox() |
| mock.StubOutWithMock(messages.IntegerField, '__init__') |
| messages.IntegerField.__init__( |
| default=u'Test', |
| number=1, |
| variant=messages.Variant.INT64 |
| ).AndRaise(messages.InvalidDefaultError) |
| |
| mock.ReplayAll() |
| self.assertRaises(messages.InvalidDefaultError, |
| definition.define_field, field_descriptor) |
| mock.VerifyAll() |
| |
| mock.ResetAll() |
| mock.UnsetStubs() |
| |
| |
| class DefineMessageTest(test_util.TestCase): |
| """Test for define_message.""" |
| |
| def testDefineMessageEmpty(self): |
| """Test definition a message with no fields or enums.""" |
| |
| class AMessage(messages.Message): |
| pass |
| |
| message_descriptor = descriptor.describe_message(AMessage) |
| |
| message_class = definition.define_message(message_descriptor, '__main__') |
| |
| self.assertEquals('AMessage', message_class.__name__) |
| self.assertEquals('__main__', message_class.__module__) |
| |
| self.assertEquals(message_descriptor, |
| descriptor.describe_message(message_class)) |
| |
| def testDefineMessageEnumOnly(self): |
| """Test definition a message with only enums.""" |
| |
| class AMessage(messages.Message): |
| class NestedEnum(messages.Enum): |
| pass |
| |
| message_descriptor = descriptor.describe_message(AMessage) |
| |
| message_class = definition.define_message(message_descriptor, '__main__') |
| |
| self.assertEquals('AMessage', message_class.__name__) |
| self.assertEquals('__main__', message_class.__module__) |
| |
| self.assertEquals(message_descriptor, |
| descriptor.describe_message(message_class)) |
| |
| def testDefineMessageFieldsOnly(self): |
| """Test definition a message with only fields.""" |
| |
| class AMessage(messages.Message): |
| |
| field1 = messages.IntegerField(1) |
| field2 = messages.StringField(2) |
| |
| message_descriptor = descriptor.describe_message(AMessage) |
| |
| message_class = definition.define_message(message_descriptor, '__main__') |
| |
| self.assertEquals('AMessage', message_class.__name__) |
| self.assertEquals('__main__', message_class.__module__) |
| |
| self.assertEquals(message_descriptor, |
| descriptor.describe_message(message_class)) |
| |
| def testDefineMessage(self): |
| """Test defining Message class from descriptor.""" |
| |
| class AMessage(messages.Message): |
| class NestedEnum(messages.Enum): |
| pass |
| |
| field1 = messages.IntegerField(1) |
| field2 = messages.StringField(2) |
| |
| message_descriptor = descriptor.describe_message(AMessage) |
| |
| message_class = definition.define_message(message_descriptor, '__main__') |
| |
| self.assertEquals('AMessage', message_class.__name__) |
| self.assertEquals('__main__', message_class.__module__) |
| |
| self.assertEquals(message_descriptor, |
| descriptor.describe_message(message_class)) |
| |
| |
| class DefineServiceTest(test_util.TestCase): |
| """Test service proxy definition.""" |
| |
| def setUp(self): |
| """Set up mock and request classes.""" |
| self.module = types.ModuleType('stocks') |
| |
| class GetQuoteRequest(messages.Message): |
| __module__ = 'stocks' |
| |
| symbols = messages.StringField(1, repeated=True) |
| |
| class GetQuoteResponse(messages.Message): |
| __module__ = 'stocks' |
| |
| prices = messages.IntegerField(1, repeated=True) |
| |
| self.module.GetQuoteRequest = GetQuoteRequest |
| self.module.GetQuoteResponse = GetQuoteResponse |
| |
| def testDefineService(self): |
| """Test service definition from descriptor.""" |
| method_descriptor = descriptor.MethodDescriptor() |
| method_descriptor.name = 'get_quote' |
| method_descriptor.request_type = 'GetQuoteRequest' |
| method_descriptor.response_type = 'GetQuoteResponse' |
| |
| service_descriptor = descriptor.ServiceDescriptor() |
| service_descriptor.name = 'Stocks' |
| service_descriptor.methods = [method_descriptor] |
| |
| StockService = definition.define_service(service_descriptor, self.module) |
| |
| self.assertTrue(issubclass(StockService, remote.Service)) |
| self.assertTrue(issubclass(StockService.Stub, remote.StubBase)) |
| |
| request = self.module.GetQuoteRequest() |
| service = StockService() |
| self.assertRaises(NotImplementedError, |
| service.get_quote, request) |
| |
| self.assertEquals(self.module.GetQuoteRequest, |
| service.get_quote.remote.request_type) |
| self.assertEquals(self.module.GetQuoteResponse, |
| service.get_quote.remote.response_type) |
| |
| |
| class ModuleTest(test_util.TestCase): |
| """Test for module creation and importation functions.""" |
| |
| def MakeFileDescriptor(self, package): |
| """Helper method to construct FileDescriptors. |
| |
| Creates FileDescriptor with a MessageDescriptor and an EnumDescriptor. |
| |
| Args: |
| package: Package name to give new file descriptors. |
| |
| Returns: |
| New FileDescriptor instance. |
| """ |
| enum_descriptor = descriptor.EnumDescriptor() |
| enum_descriptor.name = u'MyEnum' |
| |
| message_descriptor = descriptor.MessageDescriptor() |
| message_descriptor.name = u'MyMessage' |
| |
| service_descriptor = descriptor.ServiceDescriptor() |
| service_descriptor.name = u'MyService' |
| |
| file_descriptor = descriptor.FileDescriptor() |
| file_descriptor.package = package |
| file_descriptor.enum_types = [enum_descriptor] |
| file_descriptor.message_types = [message_descriptor] |
| file_descriptor.service_types = [service_descriptor] |
| |
| return file_descriptor |
| |
| def testDefineModule(self): |
| """Test define_module function.""" |
| file_descriptor = self.MakeFileDescriptor('my.package') |
| |
| module = definition.define_file(file_descriptor) |
| |
| self.assertEquals('my.package', module.__name__) |
| self.assertEquals('my.package', module.MyEnum.__module__) |
| self.assertEquals('my.package', module.MyMessage.__module__) |
| self.assertEquals('my.package', module.MyService.__module__) |
| |
| self.assertEquals(file_descriptor, descriptor.describe_file(module)) |
| |
| def testDefineModule_ReuseModule(self): |
| """Test updating module with additional definitions.""" |
| file_descriptor = self.MakeFileDescriptor('my.package') |
| |
| module = types.ModuleType('override') |
| self.assertEquals(module, definition.define_file(file_descriptor, module)) |
| |
| self.assertEquals('override', module.MyEnum.__module__) |
| self.assertEquals('override', module.MyMessage.__module__) |
| self.assertEquals('override', module.MyService.__module__) |
| |
| # One thing is different between original descriptor and new. |
| file_descriptor.package = 'override' |
| self.assertEquals(file_descriptor, descriptor.describe_file(module)) |
| |
| def testImportFile(self): |
| """Test importing FileDescriptor in to module space.""" |
| modules = {} |
| file_descriptor = self.MakeFileDescriptor('standalone') |
| definition.import_file(file_descriptor, modules=modules) |
| self.assertEquals(file_descriptor, |
| descriptor.describe_file(modules['standalone'])) |
| |
| def testImportFile_InToExisting(self): |
| """Test importing FileDescriptor in to existing module.""" |
| module = types.ModuleType('standalone') |
| modules = {'standalone': module} |
| file_descriptor = self.MakeFileDescriptor('standalone') |
| definition.import_file(file_descriptor, modules=modules) |
| self.assertEquals(module, modules['standalone']) |
| self.assertEquals(file_descriptor, |
| descriptor.describe_file(modules['standalone'])) |
| |
| def testImportFile_InToGlobalModules(self): |
| """Test importing FileDescriptor in to global modules.""" |
| original_modules = sys.modules |
| try: |
| sys.modules = dict(sys.modules) |
| if 'standalone' in sys.modules: |
| del sys.modules['standalone'] |
| file_descriptor = self.MakeFileDescriptor('standalone') |
| definition.import_file(file_descriptor) |
| self.assertEquals(file_descriptor, |
| descriptor.describe_file(sys.modules['standalone'])) |
| finally: |
| sys.modules = original_modules |
| |
| def testImportFile_Nested(self): |
| """Test importing FileDescriptor in to existing nested module.""" |
| modules = {} |
| file_descriptor = self.MakeFileDescriptor('root.nested') |
| definition.import_file(file_descriptor, modules=modules) |
| self.assertEquals(modules['root'].nested, modules['root.nested']) |
| self.assertEquals(file_descriptor, |
| descriptor.describe_file(modules['root.nested'])) |
| |
| def testImportFile_NoPackage(self): |
| """Test importing FileDescriptor with no package.""" |
| file_descriptor = self.MakeFileDescriptor('does not matter') |
| file_descriptor.reset('package') |
| self.assertRaisesWithRegexpMatch(ValueError, |
| 'File descriptor must have package name', |
| definition.import_file, |
| file_descriptor) |
| |
| def testImportFileSet(self): |
| """Test importing a whole file set.""" |
| file_set = descriptor.FileSet() |
| file_set.files = [self.MakeFileDescriptor(u'standalone'), |
| self.MakeFileDescriptor(u'root.nested'), |
| self.MakeFileDescriptor(u'root.nested.nested'), |
| ] |
| |
| root = types.ModuleType('root') |
| nested = types.ModuleType('root.nested') |
| root.nested = nested |
| modules = { |
| 'root': root, |
| 'root.nested': nested, |
| } |
| |
| definition.import_file_set(file_set, modules=modules) |
| |
| self.assertEquals(root, modules['root']) |
| self.assertEquals(nested, modules['root.nested']) |
| self.assertEquals(nested.nested, modules['root.nested.nested']) |
| |
| self.assertEquals(file_set, |
| descriptor.describe_file_set( |
| [modules['standalone'], |
| modules['root.nested'], |
| modules['root.nested.nested'], |
| ])) |
| |
| def testImportFileSetFromFile(self): |
| """Test importing a whole file set from a file.""" |
| file_set = descriptor.FileSet() |
| file_set.files = [self.MakeFileDescriptor(u'standalone'), |
| self.MakeFileDescriptor(u'root.nested'), |
| self.MakeFileDescriptor(u'root.nested.nested'), |
| ] |
| |
| stream = StringIO.StringIO(protobuf.encode_message(file_set)) |
| |
| self.mox = mox.Mox() |
| opener = self.mox.CreateMockAnything() |
| opener('my-file.dat', 'rb').AndReturn(stream) |
| |
| self.mox.ReplayAll() |
| |
| modules = {} |
| definition.import_file_set('my-file.dat', modules=modules, _open=opener) |
| |
| self.assertEquals(file_set, |
| descriptor.describe_file_set( |
| [modules['standalone'], |
| modules['root.nested'], |
| modules['root.nested.nested'], |
| ])) |
| |
| def testImportBuiltInProtorpcClasses(self): |
| """Test that built in Protorpc classes are skipped.""" |
| file_set = descriptor.FileSet() |
| file_set.files = [self.MakeFileDescriptor(u'standalone'), |
| self.MakeFileDescriptor(u'root.nested'), |
| self.MakeFileDescriptor(u'root.nested.nested'), |
| descriptor.describe_file(descriptor), |
| ] |
| |
| root = types.ModuleType('root') |
| nested = types.ModuleType('root.nested') |
| root.nested = nested |
| modules = { |
| 'root': root, |
| 'root.nested': nested, |
| 'protorpc.descriptor': descriptor, |
| } |
| |
| definition.import_file_set(file_set, modules=modules) |
| |
| self.assertEquals(root, modules['root']) |
| self.assertEquals(nested, modules['root.nested']) |
| self.assertEquals(nested.nested, modules['root.nested.nested']) |
| self.assertEquals(descriptor, modules['protorpc.descriptor']) |
| |
| self.assertEquals(file_set, |
| descriptor.describe_file_set( |
| [modules['standalone'], |
| modules['root.nested'], |
| modules['root.nested.nested'], |
| modules['protorpc.descriptor'], |
| ])) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |