blob: 8fa39c3d464c5f6700bfc79285955ddee77d7492 [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import errno
import six.moves.http_client
import os
import socket
import unittest
from protorpc import messages
from protorpc import protobuf
from protorpc import protojson
from protorpc import remote
from protorpc import test_util
from protorpc import transport
from protorpc import webapp_test_util
from protorpc.wsgi import util as wsgi_util
import mox
package = 'transport_test'
class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
test_util.TestCase):
MODULE = transport
class Message(messages.Message):
value = messages.StringField(1)
class Service(remote.Service):
@remote.method(Message, Message)
def method(self, request):
pass
# Remove when RPC is no longer subclasses.
class TestRpc(transport.Rpc):
waited = False
def _wait_impl(self):
self.waited = True
class RpcTest(test_util.TestCase):
def setUp(self):
self.request = Message(value=u'request')
self.response = Message(value=u'response')
self.status = remote.RpcStatus(state=remote.RpcState.APPLICATION_ERROR,
error_message='an error',
error_name='blam')
self.rpc = TestRpc(self.request)
def testConstructor(self):
self.assertEquals(self.request, self.rpc.request)
self.assertEquals(remote.RpcState.RUNNING, self.rpc.state)
self.assertEquals(None, self.rpc.error_message)
self.assertEquals(None, self.rpc.error_name)
def response(self):
self.assertFalse(self.rpc.waited)
self.assertEquals(None, self.rpc.response)
self.assertTrue(self.rpc.waited)
def testSetResponse(self):
self.rpc.set_response(self.response)
self.assertEquals(self.request, self.rpc.request)
self.assertEquals(remote.RpcState.OK, self.rpc.state)
self.assertEquals(self.response, self.rpc.response)
self.assertEquals(None, self.rpc.error_message)
self.assertEquals(None, self.rpc.error_name)
def testSetResponseAlreadySet(self):
self.rpc.set_response(self.response)
self.assertRaisesWithRegexpMatch(
transport.RpcStateError,
'RPC must be in RUNNING state to change to OK',
self.rpc.set_response,
self.response)
def testSetResponseAlreadyError(self):
self.rpc.set_status(self.status)
self.assertRaisesWithRegexpMatch(
transport.RpcStateError,
'RPC must be in RUNNING state to change to OK',
self.rpc.set_response,
self.response)
def testSetStatus(self):
self.rpc.set_status(self.status)
self.assertEquals(self.request, self.rpc.request)
self.assertEquals(remote.RpcState.APPLICATION_ERROR, self.rpc.state)
self.assertEquals('an error', self.rpc.error_message)
self.assertEquals('blam', self.rpc.error_name)
self.assertRaisesWithRegexpMatch(remote.ApplicationError,
'an error',
getattr, self.rpc, 'response')
def testSetStatusAlreadySet(self):
self.rpc.set_response(self.response)
self.assertRaisesWithRegexpMatch(
transport.RpcStateError,
'RPC must be in RUNNING state to change to OK',
self.rpc.set_response,
self.response)
def testSetNonMessage(self):
self.assertRaisesWithRegexpMatch(
TypeError,
'Expected Message type, received 10',
self.rpc.set_response,
10)
def testSetStatusAlreadyError(self):
self.rpc.set_status(self.status)
self.assertRaisesWithRegexpMatch(
transport.RpcStateError,
'RPC must be in RUNNING state to change to OK',
self.rpc.set_response,
self.response)
def testSetUninitializedStatus(self):
self.assertRaises(messages.ValidationError,
self.rpc.set_status,
remote.RpcStatus())
class TransportTest(test_util.TestCase):
def setUp(self):
remote.Protocols.set_default(remote.Protocols.new_default())
def do_test(self, protocol, trans):
request = Message()
request.value = u'request'
response = Message()
response.value = u'response'
encoded_request = protocol.encode_message(request)
encoded_response = protocol.encode_message(response)
self.assertEquals(protocol, trans.protocol)
received_rpc = [None]
def transport_rpc(remote, rpc_request):
self.assertEquals(remote, Service.method.remote)
self.assertEquals(request, rpc_request)
rpc = TestRpc(request)
rpc.set_response(response)
return rpc
trans._start_rpc = transport_rpc
rpc = trans.send_rpc(Service.method.remote, request)
self.assertEquals(response, rpc.response)
def testDefaultProtocol(self):
trans = transport.Transport()
self.do_test(protobuf, trans)
self.assertEquals(protobuf, trans.protocol_config.protocol)
self.assertEquals('default', trans.protocol_config.name)
def testAlternateProtocol(self):
trans = transport.Transport(protocol=protojson)
self.do_test(protojson, trans)
self.assertEquals(protojson, trans.protocol_config.protocol)
self.assertEquals('default', trans.protocol_config.name)
def testProtocolConfig(self):
protocol_config = remote.ProtocolConfig(
protojson, 'protoconfig', 'image/png')
trans = transport.Transport(protocol=protocol_config)
self.do_test(protojson, trans)
self.assertTrue(trans.protocol_config is protocol_config)
def testProtocolByName(self):
remote.Protocols.get_default().add_protocol(
protojson, 'png', 'image/png', ())
trans = transport.Transport(protocol='png')
self.do_test(protojson, trans)
@remote.method(Message, Message)
def my_method(self, request):
self.fail('self.my_method should not be directly invoked.')
class FakeConnectionClass(object):
def __init__(self, mox):
self.request = mox.CreateMockAnything()
self.response = mox.CreateMockAnything()
class HttpTransportTest(webapp_test_util.WebServerTestBase):
def setUp(self):
# Do not need much parent construction functionality.
self.schema = 'http'
self.server = None
self.request = Message(value=u'The request value')
self.encoded_request = protojson.encode_message(self.request)
self.response = Message(value=u'The response value')
self.encoded_response = protojson.encode_message(self.response)
def testCallSucceeds(self):
self.ResetServer(wsgi_util.static_page(self.encoded_response,
content_type='application/json'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
self.assertEquals(self.response, rpc.response)
def testHttps(self):
self.schema = 'https'
self.ResetServer(wsgi_util.static_page(self.encoded_response,
content_type='application/json'))
# Create a fake https connection function that really just calls http.
self.used_https = False
def https_connection(*args, **kwargs):
self.used_https = True
return six.moves.http_client.HTTPConnection(*args, **kwargs)
original_https_connection = six.moves.http_client.HTTPSConnection
six.moves.http_client.HTTPSConnection = https_connection
try:
rpc = self.connection.send_rpc(my_method.remote, self.request)
finally:
six.moves.http_client.HTTPSConnection = original_https_connection
self.assertEquals(self.response, rpc.response)
self.assertTrue(self.used_https)
def testHttpSocketError(self):
self.ResetServer(wsgi_util.static_page(self.encoded_response,
content_type='application/json'))
bad_transport = transport.HttpTransport('http://localhost:-1/blar')
try:
bad_transport.send_rpc(my_method.remote, self.request)
except remote.NetworkError as err:
self.assertTrue(str(err).startswith('Socket error: error ('))
self.assertEquals(errno.ECONNREFUSED, err.cause.errno)
else:
self.fail('Expected error')
def testHttpRequestError(self):
self.ResetServer(wsgi_util.static_page(self.encoded_response,
content_type='application/json'))
def request_error(*args, **kwargs):
raise TypeError('Generic Error')
original_request = six.moves.http_client.HTTPConnection.request
six.moves.http_client.HTTPConnection.request = request_error
try:
try:
self.connection.send_rpc(my_method.remote, self.request)
except remote.NetworkError as err:
self.assertEquals('Error communicating with HTTP server', str(err))
self.assertEquals(TypeError, type(err.cause))
self.assertEquals('Generic Error', str(err.cause))
else:
self.fail('Expected error')
finally:
six.moves.http_client.HTTPConnection.request = original_request
def testHandleGenericServiceError(self):
self.ResetServer(wsgi_util.error(six.moves.http_client.INTERNAL_SERVER_ERROR,
'arbitrary error',
content_type='text/plain'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.ServerError as err:
self.assertEquals('HTTP Error 500: arbitrary error', str(err).strip())
else:
self.fail('Expected ServerError')
def testHandleGenericServiceErrorNoMessage(self):
self.ResetServer(wsgi_util.error(six.moves.http_client.NOT_IMPLEMENTED,
' ',
content_type='text/plain'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.ServerError as err:
self.assertEquals('HTTP Error 501: Not Implemented', str(err).strip())
else:
self.fail('Expected ServerError')
def testHandleStatusContent(self):
self.ResetServer(wsgi_util.static_page('{"state": "REQUEST_ERROR",'
' "error_message": "a request error"'
'}',
status=six.moves.http_client.BAD_REQUEST,
content_type='application/json'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.RequestError as err:
self.assertEquals('a request error', str(err))
else:
self.fail('Expected RequestError')
def testHandleApplicationError(self):
self.ResetServer(wsgi_util.static_page('{"state": "APPLICATION_ERROR",'
' "error_message": "an app error",'
' "error_name": "MY_ERROR_NAME"}',
status=six.moves.http_client.BAD_REQUEST,
content_type='application/json'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.ApplicationError as err:
self.assertEquals('an app error', str(err))
self.assertEquals('MY_ERROR_NAME', err.error_name)
else:
self.fail('Expected RequestError')
def testHandleUnparsableErrorContent(self):
self.ResetServer(wsgi_util.static_page('oops',
status=six.moves.http_client.BAD_REQUEST,
content_type='application/json'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.ServerError as err:
self.assertEquals('HTTP Error 400: oops', str(err))
else:
self.fail('Expected ServerError')
def testHandleEmptyBadRpcStatus(self):
self.ResetServer(wsgi_util.static_page('{"error_message": "x"}',
status=six.moves.http_client.BAD_REQUEST,
content_type='application/json'))
rpc = self.connection.send_rpc(my_method.remote, self.request)
try:
rpc.response
except remote.ServerError as err:
self.assertEquals('HTTP Error 400: {"error_message": "x"}', str(err))
else:
self.fail('Expected ServerError')
def testUseProtocolConfigContentType(self):
expected_content_type = 'image/png'
def expect_content_type(environ, start_response):
self.assertEquals(expected_content_type, environ['CONTENT_TYPE'])
app = wsgi_util.static_page('', content_type=environ['CONTENT_TYPE'])
return app(environ, start_response)
self.ResetServer(expect_content_type)
protocol_config = remote.ProtocolConfig(protojson, 'json', 'image/png')
self.connection = self.CreateTransport(self.service_url, protocol_config)
rpc = self.connection.send_rpc(my_method.remote, self.request)
self.assertEquals(Message(), rpc.response)
class SimpleRequest(messages.Message):
content = messages.StringField(1)
class SimpleResponse(messages.Message):
content = messages.StringField(1)
factory_value = messages.StringField(2)
remote_host = messages.StringField(3)
remote_address = messages.StringField(4)
server_host = messages.StringField(5)
server_port = messages.IntegerField(6)
class LocalService(remote.Service):
def __init__(self, factory_value='default'):
self.factory_value = factory_value
@remote.method(SimpleRequest, SimpleResponse)
def call_method(self, request):
return SimpleResponse(content=request.content,
factory_value=self.factory_value,
remote_host=self.request_state.remote_host,
remote_address=self.request_state.remote_address,
server_host=self.request_state.server_host,
server_port=self.request_state.server_port)
@remote.method()
def raise_totally_unexpected(self, request):
raise TypeError('Kablam')
@remote.method()
def raise_unexpected(self, request):
raise remote.RequestError('Huh?')
@remote.method()
def raise_application_error(self, request):
raise remote.ApplicationError('App error', 10)
class LocalTransportTest(test_util.TestCase):
def CreateService(self, factory_value='default'):
return
def testBasicCallWithClass(self):
stub = LocalService.Stub(transport.LocalTransport(LocalService))
response = stub.call_method(content='Hello')
self.assertEquals(SimpleResponse(content='Hello',
factory_value='default',
remote_host=os.uname()[1],
remote_address='127.0.0.1',
server_host=os.uname()[1],
server_port=-1),
response)
def testBasicCallWithFactory(self):
stub = LocalService.Stub(
transport.LocalTransport(LocalService.new_factory('assigned')))
response = stub.call_method(content='Hello')
self.assertEquals(SimpleResponse(content='Hello',
factory_value='assigned',
remote_host=os.uname()[1],
remote_address='127.0.0.1',
server_host=os.uname()[1],
server_port=-1),
response)
def testTotallyUnexpectedError(self):
stub = LocalService.Stub(transport.LocalTransport(LocalService))
self.assertRaisesWithRegexpMatch(
remote.ServerError,
'Unexpected error TypeError: Kablam',
stub.raise_totally_unexpected)
def testUnexpectedError(self):
stub = LocalService.Stub(transport.LocalTransport(LocalService))
self.assertRaisesWithRegexpMatch(
remote.ServerError,
'Unexpected error RequestError: Huh?',
stub.raise_unexpected)
def testApplicationError(self):
stub = LocalService.Stub(transport.LocalTransport(LocalService))
self.assertRaisesWithRegexpMatch(
remote.ApplicationError,
'App error',
stub.raise_application_error)
def main():
unittest.main()
if __name__ == '__main__':
main()