blob: 354763326913dd739af2536d69e617c6ab71daff [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for protorpc.remote."""
__author__ = 'rafek@google.com (Rafe Kaplan)'
import sys
import types
import unittest
from wsgiref import headers
from protorpc import descriptor
from protorpc import message_types
from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport
import mox
class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
                          test_util.TestCase):
  """Applies the shared test_util.ModuleInterfaceTest mix-in to `remote`."""

  # Module handed to the inherited mix-in (presumably the module whose public
  # interface gets checked; see test_util.ModuleInterfaceTest to confirm).
  MODULE = remote
class Request(messages.Message):
  """Request message with one string field; input type of MyService."""

  # Single string payload, declared as field number 1.
  value = messages.StringField(1)
class Response(messages.Message):
  """Response message with one string field; output type of MyService."""

  # Single string payload, declared as field number 1.
  value = messages.StringField(1)
class MyService(remote.Service):
  """Service fixture whose only remote method echoes the request value."""

  @remote.method(Request, Response)
  def remote_method(self, request):
    """Return a Response carrying the same value as `request`.

    Args:
      request: Request message instance.

    Returns:
      Response whose `value` equals `request.value`.
    """
    return Response(value=request.value)
class SimpleRequest(messages.Message):
  """Two-string request message; input type of BasicService."""

  # String fields numbered 1 and 2 respectively.
  param1 = messages.StringField(1)
  param2 = messages.StringField(2)
class SimpleResponse(messages.Message):
  """Field-less response message; output type of BasicService."""
class BasicService(remote.Service):
  """Service fixture that records which request objects it was handed."""

  def __init__(self):
    # id() of every request seen by remote_method, in call order.
    self.request_ids = []

  @remote.method(SimpleRequest, SimpleResponse)
  def remote_method(self, request):
    """Remember the identity of `request` and answer with an empty response.

    Args:
      request: SimpleRequest message instance.

    Returns:
      A new, empty SimpleResponse.
    """
    request_id = id(request)
    self.request_ids.append(request_id)
    return SimpleResponse()
class RpcErrorTest(test_util.TestCase):
  """Tests for remote.RpcError."""

  def testFromStatus(self):
    """RpcError.from_state maps the 'SERVER_ERROR' state to ServerError."""
    # A loop over remote.RpcState used to sit here, but it only bound the
    # from_state method to an unused local without ever calling it, so it
    # verified nothing and has been dropped.
    self.assertEquals(remote.ServerError,
                      remote.RpcError.from_state('SERVER_ERROR'))
class ApplicationErrorTest(test_util.TestCase):
  """Tests for remote.ApplicationError."""

  def testErrorCode(self):
    """The second constructor argument is exposed as `error_name`."""
    error = remote.ApplicationError('an error', 'blam')
    self.assertEquals('blam', error.error_name)

  def testStr(self):
    """str() yields the message only, without the error name."""
    error = remote.ApplicationError('an error', 1)
    self.assertEquals('an error', str(error))

  def testRepr(self):
    """repr() shows exactly the constructor arguments that were given."""
    with_name = remote.ApplicationError('an error', 1)
    self.assertEquals("ApplicationError('an error', 1)", repr(with_name))

    without_name = remote.ApplicationError('an error')
    self.assertEquals("ApplicationError('an error')", repr(without_name))
class MethodTest(test_util.TestCase):
  """Exercises the remote.method decorator."""

  def testMethod(self):
    """Decorated methods expose request type, response type and function."""
    info = BasicService.remote_method.remote
    self.assertEquals(SimpleRequest, info.request_type)
    self.assertEquals(SimpleResponse, info.response_type)
    self.assertTrue(isinstance(info.method, types.FunctionType))

  def testMethodMessageResolution(self):
    """Message types given by name resolve to the actual classes."""

    class OtherService(remote.Service):

      @remote.method('SimpleRequest', 'SimpleResponse')
      def remote_method(self, request):
        pass

    info = OtherService.remote_method.remote
    self.assertEquals(SimpleRequest, info.request_type)
    self.assertEquals(SimpleResponse, info.response_type)

  def testMethodMessageResolution_NotFound(self):
    """Unknown message names fail when the type attribute is read."""

    class OtherService(remote.Service):

      @remote.method('NoSuchRequest', 'NoSuchResponse')
      def remote_method(self, request):
        pass

    info = OtherService.remote_method.remote

    # Reading each attribute (via getattr) is what triggers the lookup error.
    self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for NoSuchRequest',
        getattr, info, 'request_type')

    self.assertRaisesWithRegexpMatch(
        messages.DefinitionNotFoundError,
        'Could not find definition for NoSuchResponse',
        getattr, info, 'response_type')

  def testInvocation(self):
    """A direct call hands the very same request object to the method."""
    service = BasicService()
    request = SimpleRequest()

    result = service.remote_method(request)

    self.assertEquals(SimpleResponse(), result)
    self.assertEquals([id(request)], service.request_ids)

  def testInvocation_WrongRequestType(self):
    """Anything other than the declared request type is a RequestError."""
    service = BasicService()
    for bad_request in ('wrong', None, SimpleResponse()):
      self.assertRaises(remote.RequestError,
                        service.remote_method,
                        bad_request)

  def testInvocation_WrongResponseType(self):
    """Returning anything but the declared response type is a ServerError."""

    class AnotherService(object):

      @remote.method(SimpleRequest, SimpleResponse)
      def remote_method(self, unused_request):
        return self.return_this

    service = AnotherService()
    for bad_response in ('wrong', None, SimpleRequest()):
      service.return_this = bad_response
      self.assertRaises(remote.ServerError,
                        service.remote_method,
                        SimpleRequest())

  def testBadRequestType(self):
    """Invalid request types are rejected when the class is declared."""
    for request_type in (None, 1020, messages.Message, str):

      # Declaring the class applies the decorator, so the declaration itself
      # is the callable expected to raise.
      def declare():
        class BadService(object):

          @remote.method(request_type, SimpleResponse)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)

  def testBadResponseType(self):
    """Invalid response types are rejected when the class is declared."""
    for response_type in (None, 1020, messages.Message, str):

      # Declaring the class applies the decorator, so the declaration itself
      # is the callable expected to raise.
      def declare():
        class BadService(object):

          @remote.method(SimpleRequest, response_type)
          def remote_method(self, request):
            pass

      self.assertRaises(TypeError, declare)
class GetRemoteMethodTest(test_util.TestCase):
  """Tests for remote.get_remote_method_info."""

  def testGetRemoteMethod(self):
    """Remote info is found via the class attribute and via a bound method."""

    class Service(object):

      @remote.method(Request, Response)
      def remote_method(self, request):
        pass

    self.assertEquals(Service.remote_method.remote,
                      remote.get_remote_method_info(Service.remote_method))
    # Must be an equality assertion: with assertTrue the second argument is
    # only the failure message, so the lookup result would never be checked.
    self.assertEquals(Service.remote_method.remote,
                      remote.get_remote_method_info(Service().remote_method))

  def testGetNotRemoteMethod(self):
    """Test negative result on many bad values for remote methods."""

    class NotService(object):

      def not_remote_method(self, request):
        pass

    def fn(self):
      pass

    class NotReallyRemote(object):
      """Has a method whose `remote` attribute is not real remote info."""

      def not_really(self, request):
        pass

      not_really.remote = 'something else'

    for not_remote in [NotService.not_remote_method,
                       NotService().not_remote_method,
                       NotReallyRemote.not_really,
                       NotReallyRemote().not_really,
                       None,
                       1,
                       'a string',
                       fn]:
      self.assertEquals(None, remote.get_remote_method_info(not_remote))
class RequestStateTest(test_util.TestCase):
  """Tests for remote.RequestState; subclasses override STATE_CLASS."""

  # Class under test. Every test goes through this attribute so that a
  # subclass can re-run the same checks against another state class.
  STATE_CLASS = remote.RequestState

  def testConstructor(self):
    """Keyword arguments become attributes; omitted ones default to None."""
    populated = self.STATE_CLASS(remote_host='remote-host',
                                 remote_address='remote-address',
                                 server_host='server-host',
                                 server_port=10)
    self.assertEquals('remote-host', populated.remote_host)
    self.assertEquals('remote-address', populated.remote_address)
    self.assertEquals('server-host', populated.server_host)
    self.assertEquals(10, populated.server_port)

    empty = self.STATE_CLASS()
    self.assertEquals(None, empty.remote_host)
    self.assertEquals(None, empty.remote_address)
    self.assertEquals(None, empty.server_host)
    self.assertEquals(None, empty.server_port)

  def testConstructorError(self):
    """An unknown keyword argument raises TypeError."""
    self.assertRaises(TypeError, self.STATE_CLASS, x=10)

  def testRepr(self):
    """repr() lists, in a fixed order, only the attributes that were set."""
    cls = self.STATE_CLASS
    name = cls.__name__

    self.assertEquals('<%s>' % name, repr(cls()))

    self.assertEquals("<%s remote_host='abc'>" % name,
                      repr(cls(remote_host='abc')))

    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def'>" % name,
                      repr(cls(remote_host='abc',
                               remote_address='def')))

    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi'>" % name,
                      repr(cls(remote_host='abc',
                               remote_address='def',
                               server_host='ghi')))

    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi' "
                      'server_port=102>' % name,
                      repr(cls(remote_host='abc',
                               remote_address='def',
                               server_host='ghi',
                               server_port=102)))
class HttpRequestStateTest(RequestStateTest):
  """Re-runs the RequestState tests on HttpRequestState, plus HTTP extras."""

  STATE_CLASS = remote.HttpRequestState

  def testHttpMethod(self):
    """The http_method keyword is exposed as an attribute."""
    state = remote.HttpRequestState(http_method='GET')
    self.assertEquals('GET', state.http_method)

  # This test was previously also named testHttpMethod, which replaced the
  # method above in the class namespace so the http_method check never ran.
  def testServicePath(self):
    """The service_path keyword is exposed as an attribute."""
    state = remote.HttpRequestState(service_path='/bar')
    self.assertEquals('/bar', state.service_path)

  def testHeadersList(self):
    """Headers given as (name, value) pairs keep order and repeated names."""
    state = remote.HttpRequestState(
        headers=[('a', 'b'), ('c', 'd'), ('c', 'e')])
    self.assertEquals(['a', 'c', 'c'], list(state.headers.keys()))
    self.assertEquals(['b'], state.headers.get_all('a'))
    self.assertEquals(['d', 'e'], state.headers.get_all('c'))

  def testHeadersDict(self):
    """Headers given as a dict expand list values into repeated headers."""
    state = remote.HttpRequestState(headers={'a': 'b', 'c': ['d', 'e']})
    # Keys are sorted because dict iteration order is not relied upon here.
    self.assertEquals(['a', 'c', 'c'], sorted(state.headers.keys()))
    self.assertEquals(['b'], state.headers.get_all('a'))
    self.assertEquals(['d', 'e'], state.headers.get_all('c'))

  def testRepr(self):
    """repr() appends the HTTP-specific attributes after the base ones."""
    super(HttpRequestStateTest, self).testRepr()

    self.assertEquals("<%s remote_host='abc' "
                      "remote_address='def' "
                      "server_host='ghi' "
                      'server_port=102 '
                      "http_method='POST' "
                      "service_path='/bar' "
                      "headers=[('a', 'b'), ('c', 'd')]>" %
                      self.STATE_CLASS.__name__,
                      repr(self.STATE_CLASS(remote_host='abc',
                                            remote_address='def',
                                            server_host='ghi',
                                            server_port=102,
                                            http_method='POST',
                                            service_path='/bar',
                                            headers={'a': 'b', 'c': 'd'},
                                           )))
class ServiceTest(test_util.TestCase):
  """Test Service class."""

  def testServiceBase_AllRemoteMethods(self):
    """Test that service base class has no remote methods."""
    self.assertEquals({}, remote.Service.all_remote_methods())

  def testAllRemoteMethods(self):
    """Test all_remote_methods on a direct Service subclass."""
    self.assertEquals({'remote_method': MyService.remote_method},
                      MyService.all_remote_methods())

  def testAllRemoteMethods_SubClass(self):
    """Test all_remote_methods on a sub-class of a service."""

    class SubClass(MyService):

      @remote.method(Request, Response)
      def sub_class_method(self, request):
        pass

    self.assertEquals({'remote_method': SubClass.remote_method,
                       'sub_class_method': SubClass.sub_class_method,
                      },
                      SubClass.all_remote_methods())

  def testOverrideMethod(self):
    """Test overriding a remote method without the remote decorator."""

    class SubClass(MyService):

      def remote_method(self, request):
        response = super(SubClass, self).remote_method(request)
        response.value = '(%s)' % response.value
        return response

    self.assertEquals({'remote_method': SubClass.remote_method,
                      },
                      SubClass.all_remote_methods())

    instance = SubClass()
    self.assertEquals('(Hello)',
                      instance.remote_method(Request(value='Hello')).value)
    # The undecorated override still carries the parent's remote info.
    self.assertEquals(Request, SubClass.remote_method.remote.request_type)
    self.assertEquals(Response, SubClass.remote_method.remote.response_type)

  def testOverrideMethodWithRemote(self):
    """Test trying to override a remote method with remote decorator."""

    def do_override():
      class SubClass(MyService):

        @remote.method(Request, Response)
        def remote_method(self, request):
          pass

    self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                     'Do not use method decorator when '
                                     'overloading remote method remote_method '
                                     'on service SubClass',
                                     do_override)

  def testOverrideMethodWithInvalidValue(self):
    """Test trying to override a remote method with a non-method value."""

    def do_override(bad_value):
      class SubClass(MyService):
        remote_method = bad_value

    for bad_value in [None, 1, 'string', {}]:
      self.assertRaisesWithRegexpMatch(remote.ServiceDefinitionError,
                                       'Must override remote_method in '
                                       'SubClass with a method',
                                       do_override, bad_value)

  def testCallingRemoteMethod(self):
    """Test invoking a remote method."""
    expected = Response()
    expected.value = 'what was passed in'

    request = Request()
    request.value = 'what was passed in'

    service = MyService()
    self.assertEquals(expected, service.remote_method(request))

  def testFactory(self):
    """Test using factory to pass in state."""

    class StatefulService(remote.Service):

      def __init__(self, a, b, c=None):
        self.a = a
        self.b = b
        self.c = c

    state = [1, 2, 3]

    factory = StatefulService.new_factory(1, state)

    module_name = ServiceTest.__module__
    pattern = ('Creates new instances of service StatefulService.\n\n'
               'Returns:\n'
               ' New instance of %s.StatefulService.' % module_name)
    self.assertEqual(pattern, factory.__doc__)
    self.assertEquals('StatefulService_service_factory', factory.__name__)
    self.assertEquals(StatefulService, factory.service_class)

    service = factory()
    self.assertEquals(1, service.a)
    # Identity check: the factory passes the same list object through.
    self.assertEquals(id(state), id(service.b))
    self.assertEquals(None, service.c)

    factory = StatefulService.new_factory(2, b=3, c=4)
    service = factory()
    self.assertEquals(2, service.a)
    self.assertEquals(3, service.b)
    self.assertEquals(4, service.c)

  def testFactoryError(self):
    """Test misusing a factory."""
    # In each case the factory itself is the callable given to assertRaises,
    # so the TypeError is expected when the factory is invoked.

    # Passing positional argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(1))

    # Passing keyword argument that is not accepted by class.
    self.assertRaises(TypeError, remote.Service.new_factory(x=1))

    class StatefulService(remote.Service):

      def __init__(self, a):
        pass

    # Missing required parameter.
    self.assertRaises(TypeError, StatefulService.new_factory())

  def testDefinitionName(self):
    """Test getting service definition name."""

    class TheService(remote.Service):
      pass

    module_name = test_util.get_module_name(ServiceTest)
    self.assertEqual(TheService.definition_name(),
                     '%s.TheService' % module_name)
    # These must be equality assertions: assertTrue(value, msg) would treat
    # module_name as the failure message and never compare against it.
    self.assertEqual(TheService.outer_definition_name(),
                     module_name)
    self.assertEqual(TheService.definition_package(),
                     module_name)

  def testDefinitionNameWithPackage(self):
    """Test getting service definition name when package defined."""
    # Temporarily defines a module-level `package` global; it is removed
    # again in the finally clause so other tests are unaffected.
    global package
    package = 'my.package'
    try:
      class TheService(remote.Service):
        pass

      self.assertEquals('my.package.TheService', TheService.definition_name())
      self.assertEquals('my.package', TheService.outer_definition_name())
      self.assertEquals('my.package', TheService.definition_package())
    finally:
      del package

  def testDefinitionNameWithNoModule(self):
    """Test getting service definition name when module is not registered."""
    # Temporarily removes this module from sys.modules; the finally clause
    # puts it back.
    module = sys.modules[__name__]
    try:
      del sys.modules[__name__]

      class TheService(remote.Service):
        pass

      self.assertEquals('TheService', TheService.definition_name())
      self.assertEquals(None, TheService.outer_definition_name())
      self.assertEquals(None, TheService.definition_package())
    finally:
      sys.modules[__name__] = module
class StubTest(test_util.TestCase):
  """Tests for the client stub class exposed as BasicService.Stub.

  The stub's asynchronous interface lives on an attribute literally named
  `async`. Because `async` is a reserved keyword from Python 3.7 onwards,
  writing `stub.async` would make this whole module a SyntaxError there, so
  the attribute is fetched with getattr(stub, 'async') instead. The lookup is
  otherwise identical.
  """

  def setUp(self):
    self.mox = mox.Mox()
    self.transport = self.mox.CreateMockAnything()

  def _new_request(self):
    """Return a SimpleRequest with param1='val1' and param2='val2'."""
    request = SimpleRequest()
    request.param1 = 'val1'
    request.param2 = 'val2'
    return request

  def _expect_send_rpc(self, request, rpc):
    """Record that the transport is sent `request` and answers with `rpc`."""
    self.transport.send_rpc(BasicService.remote_method.remote,
                            request).AndReturn(rpc)

  def testDefinitionName(self):
    """The stub reports the same definition names as its service."""
    self.assertEquals(BasicService.definition_name(),
                      BasicService.Stub.definition_name())
    self.assertEquals(BasicService.outer_definition_name(),
                      BasicService.Stub.outer_definition_name())
    self.assertEquals(BasicService.definition_package(),
                      BasicService.Stub.definition_package())

  def testRemoteMethods(self):
    """The stub exposes the same remote methods as its service."""
    self.assertEquals(BasicService.all_remote_methods(),
                      BasicService.Stub.all_remote_methods())

  def testSync_WithRequest(self):
    """A synchronous call with a request object returns the rpc response."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()
    rpc = transport.Rpc(request)
    rpc.set_response(SimpleResponse())
    self._expect_send_rpc(request, rpc)

    self.mox.ReplayAll()

    self.assertEquals(SimpleResponse(), stub.remote_method(request))

    self.mox.VerifyAll()

  def testSync_WithKwargs(self):
    """A synchronous call with kwargs sends an equal request message."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()
    rpc = transport.Rpc(request)
    rpc.set_response(SimpleResponse())
    self._expect_send_rpc(request, rpc)

    self.mox.ReplayAll()

    self.assertEquals(SimpleResponse(), stub.remote_method(param1='val1',
                                                           param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequest(self):
    """An async call with a request object returns the transport's rpc."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()
    rpc = transport.Rpc(request)
    self._expect_send_rpc(request, rpc)

    self.mox.ReplayAll()

    self.assertEquals(rpc, getattr(stub, 'async').remote_method(request))

    self.mox.VerifyAll()

  def testAsync_WithKwargs(self):
    """An async call with kwargs returns the transport's rpc."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()
    rpc = transport.Rpc(request)
    self._expect_send_rpc(request, rpc)

    self.mox.ReplayAll()

    self.assertEquals(rpc,
                      getattr(stub, 'async').remote_method(param1='val1',
                                                           param2='val2'))

    self.mox.VerifyAll()

  def testAsync_WithRequestAndKwargs(self):
    """Supplying both a request object and kwargs raises TypeError."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()

    # No expectations are recorded: the transport must not be called.
    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
        TypeError,
        r'May not provide both args and kwargs',
        getattr(stub, 'async').remote_method,
        request,
        param1='val1',
        param2='val2')

    self.mox.VerifyAll()

  def testAsync_WithTooManyPositionals(self):
    """Supplying more than one positional argument raises TypeError."""
    stub = BasicService.Stub(self.transport)

    request = self._new_request()

    # No expectations are recorded: the transport must not be called.
    self.mox.ReplayAll()

    self.assertRaisesWithRegexpMatch(
        TypeError,
        r'remote_method\(\) takes at most 2 positional arguments \(3 given\)',
        getattr(stub, 'async').remote_method,
        request, 'another value')

    self.mox.VerifyAll()
class IsErrorStatusTest(test_util.TestCase):
  """Tests for remote.is_error_status."""

  def testIsError(self):
    """Every state ordered after RUNNING is reported as an error."""
    for state in remote.RpcState:
      if state > remote.RpcState.RUNNING:
        status = remote.RpcStatus(state=state)
        self.assertTrue(remote.is_error_status(status))

  def testIsNotError(self):
    """Every state up to and including RUNNING is not an error."""
    for state in remote.RpcState:
      if state <= remote.RpcState.RUNNING:
        status = remote.RpcStatus(state=state)
        self.assertFalse(remote.is_error_status(status))

  def testStateNone(self):
    """A status with no state set fails validation."""
    self.assertRaises(messages.ValidationError,
                      remote.is_error_status,
                      remote.RpcStatus())
class CheckRpcStatusTest(test_util.TestCase):
  """Tests for remote.check_rpc_status."""

  def testStateNone(self):
    """A status with no state set fails validation."""
    self.assertRaises(messages.ValidationError,
                      remote.check_rpc_status,
                      remote.RpcStatus())

  def testNoError(self):
    """OK and RUNNING statuses pass through without raising."""
    for state in (remote.RpcState.OK, remote.RpcState.RUNNING):
      status = remote.RpcStatus(state=state)
      remote.check_rpc_status(status)

  def testErrorState(self):
    """REQUEST_ERROR raises RequestError carrying the status message."""
    status = remote.RpcStatus(state=remote.RpcState.REQUEST_ERROR,
                              error_message='a request error')
    self.assertRaisesWithRegexpMatch(remote.RequestError,
                                     'a request error',
                                     remote.check_rpc_status,
                                     status)

  def testApplicationErrorState(self):
    """APPLICATION_ERROR raises ApplicationError with message and name."""
    status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
                              error_message='an application error',
                              error_name='blam')
    try:
      remote.check_rpc_status(status)
    except remote.ApplicationError as err:
      self.assertEquals('an application error', str(err))
      self.assertEquals('blam', err.error_name)
    else:
      self.fail('Should have raised application error.')
class ProtocolConfigTest(test_util.TestCase):
  """Tests for remote.ProtocolConfig."""

  def testConstructor(self):
    """Explicit content types are lower-cased and stored as tuples."""
    # The alternates are passed as a one-shot iterator, not a sequence.
    alternates = iter(['text/Json', 'text/JavaScript'])
    config = remote.ProtocolConfig(protojson,
                                   'proto1',
                                   'application/X-Json',
                                   alternates)

    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto1', config.name)
    self.assertEquals('application/x-json', config.default_content_type)
    self.assertEquals(('text/json', 'text/javascript'),
                      config.alternate_content_types)
    self.assertEquals(('application/x-json', 'text/json', 'text/javascript'),
                      config.content_types)

  def testConstructorDefaults(self):
    """Without explicit content types, expected JSON defaults are used."""
    expected_alternates = ('application/x-javascript',
                           'text/javascript',
                           'text/x-javascript',
                           'text/x-json',
                           'text/json')
    config = remote.ProtocolConfig(protojson, 'proto2')

    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto2', config.name)
    self.assertEquals('application/json', config.default_content_type)
    self.assertEquals(expected_alternates, config.alternate_content_types)
    # content_types is the default type followed by the alternates.
    self.assertEquals(('application/json',) + expected_alternates,
                      config.content_types)

  def testEmptyAlternativeTypes(self):
    """An explicit empty alternates tuple leaves only the default type."""
    config = remote.ProtocolConfig(protojson, 'proto2',
                                   alternative_content_types=())

    self.assertEquals(protojson, config.protocol)
    self.assertEquals('proto2', config.name)
    self.assertEquals('application/json', config.default_content_type)
    self.assertEquals((), config.alternate_content_types)
    self.assertEquals(('application/json',), config.content_types)

  def testDuplicateContentTypes(self):
    """A content type listed twice is a ServiceConfigurationError."""
    duplicated_alternates = (
        ('text/plain',),              # Repeats the default content type.
        ('text/html', 'text/html'),   # Repeats within the alternates.
    )
    for alternates in duplicated_alternates:
      self.assertRaises(remote.ServiceConfigurationError,
                        remote.ProtocolConfig,
                        protojson,
                        'json',
                        'text/plain',
                        alternates)

  def testEncodeMessage(self):
    """encode_message serializes a message using the configured protocol."""
    config = remote.ProtocolConfig(protojson, 'proto2')
    status = remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                              error_message='bad error')

    encoded_message = config.encode_message(status)

    # Compare as a dictionary so JSON key order does not matter.
    dict_message = protojson.json.loads(encoded_message)
    self.assertEquals({'state': 'SERVER_ERROR', 'error_message': 'bad error'},
                      dict_message)

  def testDecodeMessage(self):
    """decode_message parses text into the requested message type."""
    config = remote.ProtocolConfig(protojson, 'proto2')
    expected = remote.RpcStatus(state=remote.RpcState.SERVER_ERROR,
                                error_message="bad error")

    decoded = config.decode_message(
        remote.RpcStatus,
        '{"state": "SERVER_ERROR", "error_message": "bad error"}')

    self.assertEquals(expected, decoded)
class ProtocolsTest(test_util.TestCase):
  """Tests for remote.Protocols."""

  def setUp(self):
    self.protocols = remote.Protocols()

  def _assert_default_protocols(self, protocols):
    """Check that `protocols` holds the expected protobuf/protojson pair."""
    self.assertEquals(('protobuf', 'protojson'), protocols.names)

    protobuf_protocol = protocols.lookup_by_name('protobuf')
    self.assertEquals(protobuf, protobuf_protocol.protocol)

    protojson_protocol = protocols.lookup_by_name('protojson')
    self.assertEquals(protojson.ProtoJson.get_default(),
                      protojson_protocol.protocol)

  def testEmpty(self):
    """A new Protocols instance has no names and no content types."""
    self.assertEquals((), self.protocols.names)
    self.assertEquals((), self.protocols.content_types)

  def testAddProtocolAllDefaults(self):
    """Adding protojson without overrides registers its content types."""
    self.protocols.add_protocol(protojson, 'json')
    self.assertEquals(('json',), self.protocols.names)
    self.assertEquals(('application/json',
                       'application/x-javascript',
                       'text/javascript',
                       'text/json',
                       'text/x-javascript',
                       'text/x-json'),
                      self.protocols.content_types)

  def testAddProtocolNoDefaultAlternatives(self):
    """A protocol declaring only CONTENT_TYPE gets no alternate types."""

    class Protocol(object):
      CONTENT_TYPE = 'text/plain'

    self.protocols.add_protocol(Protocol, 'text')
    self.assertEquals(('text',), self.protocols.names)
    self.assertEquals(('text/plain',), self.protocols.content_types)

  def testAddProtocolOverrideDefaults(self):
    """Explicit content types replace the protocol's own defaults."""
    self.protocols.add_protocol(protojson, 'json',
                                default_content_type='text/blar',
                                alternative_content_types=('text/blam',
                                                           'text/blim'))
    self.assertEquals(('json',), self.protocols.names)
    self.assertEquals(('text/blam', 'text/blar', 'text/blim'),
                      self.protocols.content_types)

  def testLookupByName(self):
    """Name lookup ignores case."""
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEquals('json', self.protocols.lookup_by_name('JsOn').name)
    self.assertEquals('json2', self.protocols.lookup_by_name('Json2').name)

  def testLookupByContentType(self):
    """Content-type lookup ignores case and covers alternate types."""
    self.protocols.add_protocol(protojson, 'json')
    self.protocols.add_protocol(protojson, 'json2',
                                default_content_type='text/plain',
                                alternative_content_types=())

    self.assertEquals(
        'json',
        self.protocols.lookup_by_content_type('AppliCation/Json').name)

    self.assertEquals(
        'json',
        self.protocols.lookup_by_content_type('text/x-Json').name)

    self.assertEquals(
        'json2',
        self.protocols.lookup_by_content_type('text/Plain').name)

  def testNewDefault(self):
    """new_default builds a Protocols with protobuf and protojson."""
    self._assert_default_protocols(remote.Protocols.new_default())

  def testGetDefaultProtocols(self):
    """get_default returns the same default instance on every call."""
    protocols = remote.Protocols.get_default()
    self._assert_default_protocols(protocols)

    self.assertTrue(protocols is remote.Protocols.get_default())

  def testSetDefaultProtocols(self):
    """set_default replaces the instance returned by get_default."""
    # The default is shared by the whole process, so remember the current
    # one and put it back afterwards. Otherwise the empty replacement would
    # leak into whichever test calls get_default next.
    original = remote.Protocols.get_default()
    protocols = remote.Protocols()
    try:
      remote.Protocols.set_default(protocols)
      self.assertTrue(protocols is remote.Protocols.get_default())
    finally:
      remote.Protocols.set_default(original)

  def testSetDefaultWithoutProtocols(self):
    """set_default rejects anything that is not a Protocols instance."""
    self.assertRaises(TypeError, remote.Protocols.set_default, None)
    self.assertRaises(TypeError, remote.Protocols.set_default, 'hi protocols')
    self.assertRaises(TypeError, remote.Protocols.set_default, {})
def main():
  """Hand control to unittest's command-line test runner."""
  unittest.main()


# Only run the tests when this file is executed as a script.
if __name__ == '__main__':
  main()