# blob: 0d615923e04915a8429252317025ac8e79f9bb4e [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Base class for RpcOp tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
from tensorflow.contrib.rpc.python.kernel_tests import test_example_pb2
from tensorflow.contrib.rpc.python.ops import rpc_op
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
# Error text that the failure tests in RpcOpTestBase expect to find in the op
# error / status message returned for a failing RPC.
I_WARNED_YOU = 'I warned you!'

# Public names exported by this module.
__all__ = [
    'I_WARNED_YOU',
    'RpcOpTestBase',
]
class RpcOpTestBase(object):
  # pylint: disable=missing-docstring,invalid-name
  """Base class for RpcOp tests.

  This class only defines test methods; it is meant to be mixed into a
  concrete test class that supplies everything the methods read from `self`
  but that is not defined here:

  * `_protocol`: protocol forwarded to `rpc_op.rpc` / `rpc_op.try_rpc`.
  * `_address`: address of the server the RPCs are sent to.
  * `invalid_method_string`: text expected in the error raised / status
    message returned for an unknown method.
  * `connect_failed_string`: text expected in the status message when the
    address cannot be reached.
  * `get_method_name(suffix)`: must be overridden, see below.
  * `cached_session()` and the `assert*` helpers (presumably inherited from a
    TensorFlow test case class -- confirm in the concrete test).
  """

  # Address that no server listens on; used to provoke UNAVAILABLE errors.
  _NONEXISTENT_ADDRESS = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'

  @staticmethod
  def _flatten(nested):
    """Returns a flat list with the items of every iterable in `nested`."""
    return list(itertools.chain.from_iterable(nested))

  def get_method_name(self, suffix):
    """Returns the method string to pass to the rpc op for `suffix`.

    Args:
      suffix: Name of the server method, e.g. 'Increment'.

    Raises:
      NotImplementedError: Always; subclasses must override this method.
    """
    raise NotImplementedError

  def rpc(self, *args, **kwargs):
    """Calls `rpc_op.rpc` with this test's protocol filled in."""
    return rpc_op.rpc(*args, protocol=self._protocol, **kwargs)

  def try_rpc(self, *args, **kwargs):
    """Calls `rpc_op.try_rpc` with this test's protocol filled in."""
    return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)

  def testScalarHostPortRpc(self):
    # A scalar request yields a scalar response with every value incremented.
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(response_tensors.shape, ())
      response_values = sess.run(response_tensors)
      response_message = test_example_pb2.TestCase()
      self.assertTrue(response_message.ParseFromString(response_values))
      self.assertAllEqual([2, 3, 4], response_message.values)

  def testScalarHostPortTryRpc(self):
    # Same as the scalar rpc test, but also checks the scalar status outputs.
    with self.cached_session() as sess:
      request_tensors = (
          test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      self.assertEqual(response_tensors.shape, ())
      response_values, status_code_values, status_message_values = (
          sess.run((response_tensors, status_code, status_message)))
      response_message = test_example_pb2.TestCase()
      self.assertTrue(response_message.ParseFromString(response_values))
      self.assertAllEqual([2, 3, 4], response_message.values)
      # For the base Rpc op, don't expect to get error status back.
      self.assertEqual(errors.OK, status_code_values)
      self.assertEqual(b'', status_message_values)

  def testEmptyHostPortRpc(self):
    # An empty request list produces an empty (shape [0]) response.
    with self.cached_session() as sess:
      request_tensors = []
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertAllEqual(response_tensors.shape, [0])
      response_values = sess.run(response_tensors)
      self.assertAllEqual(response_values.shape, [0])

  def testInvalidMethod(self):
    # Unknown services/methods raise from rpc and report UNIMPLEMENTED from
    # try_rpc.
    for method in [
        '/InvalidService.Increment',
        self.get_method_name('InvalidMethodName')
    ]:
      with self.cached_session() as sess:
        with self.assertRaisesOpError(self.invalid_method_string):
          sess.run(self.rpc(method=method, address=self._address, request=''))
        _, status_code_value, status_message_value = sess.run(
            self.try_rpc(method=method, address=self._address, request=''))
        self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
        self.assertIn(self.invalid_method_string,
                      status_message_value.decode('ascii'))

  def testInvalidAddress(self):
    # This covers the case of address='' and address='localhost:293874293874'
    address = self._NONEXISTENT_ADDRESS
    with self.cached_session() as sess:
      with self.assertRaises(errors.UnavailableError):
        sess.run(
            self.rpc(
                method=self.get_method_name('Increment'),
                address=address,
                request=''))
      _, status_code_value, status_message_value = sess.run(
          self.try_rpc(
              method=self.get_method_name('Increment'),
              address=address,
              request=''))
      self.assertEqual(errors.UNAVAILABLE, status_code_value)
      self.assertIn(self.connect_failed_string,
                    status_message_value.decode('ascii'))

  def testAlwaysFailingMethod(self):
    # rpc raises with the server's message; try_rpc returns it as a status.
    with self.cached_session() as sess:
      response_tensors = self.rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('AlwaysFailWithInvalidArgument'),
          address=self._address,
          request='')
      self.assertEqual(response_tensors.shape, ())
      self.assertEqual(status_code.shape, ())
      self.assertEqual(status_message.shape, ())
      status_code_value, status_message_value = sess.run((status_code,
                                                          status_message))
      self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
      self.assertIn(I_WARNED_YOU, status_message_value.decode('ascii'))

  def testSometimesFailingMethodWithManyRequests(self):
    # With 20 requests, rpc fails as a whole while try_rpc reports a status
    # per request.
    with self.cached_session() as sess:
      # Fail hard by default.
      response_tensors = self.rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      with self.assertRaisesOpError(I_WARNED_YOU):
        sess.run(response_tensors)
      # Don't fail hard, use TryRpc - return the failing status instead.
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesFailWithInvalidArgument'),
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values, status_message_values = sess.run((status_code,
                                                            status_message))
      # all() is required: asserting on the bare list would always pass,
      # because a non-empty list is truthy whatever it contains.
      self.assertTrue(all(
          x in (errors.OK, errors.INVALID_ARGUMENT)
          for x in status_code_values))
      # Failed requests must carry the server's message, successful ones none.
      expected_message_values = np.where(
          status_code_values == errors.INVALID_ARGUMENT,
          I_WARNED_YOU.encode('ascii'), b'')
      self.assertAllEqual(expected_message_values, status_message_values)

  def testVecHostPortRpc(self):
    # A vector of 20 requests yields 20 responses, each incremented by one.
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      self.assertEqual(response_tensors.shape, (20,))
      response_values = sess.run(response_tensors)
      self.assertEqual(response_values.shape, (20,))
      for i in range(20):
        response_message = test_example_pb2.TestCase()
        self.assertTrue(response_message.ParseFromString(response_values[i]))
        self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortManyParallelRpcs(self):
    # Ten rpc ops, each with 20 requests, evaluated in a single sess.run.
    with self.cached_session() as sess:
      request_tensors = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      many_response_tensors = [
          self.rpc(
              method=self.get_method_name('Increment'),
              address=self._address,
              request=request_tensors) for _ in range(10)
      ]
      # Launch parallel 10 calls to the RpcOp, each containing 20 rpc requests.
      many_response_values = sess.run(many_response_tensors)
      self.assertEqual(10, len(many_response_values))
      for response_values in many_response_values:
        self.assertEqual(response_values.shape, (20,))
        for i in range(20):
          response_message = test_example_pb2.TestCase()
          self.assertTrue(response_message.ParseFromString(response_values[i]))
          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)

  def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
    # Builds the requests and parses the responses with the proto ops instead
    # of the Python protobuf API.
    with self.cached_session() as sess:
      request_tensors = encode_proto_op.encode_proto(
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          sizes=[[3]] * 20,
          values=[
              [[i, i + 1, i + 2] for i in range(20)],
          ])
      response_tensor_strings = self.rpc(
          method=self.get_method_name('Increment'),
          address=self._address,
          request=request_tensors)
      _, (response_shape,) = decode_proto_op.decode_proto(
          bytes=response_tensor_strings,
          message_type='tensorflow.contrib.rpc.TestCase',
          field_names=['values'],
          output_types=[dtypes.int32])
      response_shape_values = sess.run(response_shape)
      self.assertAllEqual([[i + 1, i + 2, i + 3]
                           for i in range(20)], response_shape_values)

  def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
    # The timeout comes from the session RunOptions, not from the op.
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          request=request_tensors)
      for timeout_ms in [1, 500, 1000]:
        options = config_pb2.RunOptions(timeout_in_ms=timeout_ms)
        with self.assertRaises((errors.UnavailableError,
                                errors.DeadlineExceededError)):
          sess.run(response_tensors, options=options)

  def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
    # The timeout comes from the op's own timeout_in_ms argument.
    with self.cached_session() as sess:
      request_tensors = [''] * 25  # This will launch 25 RPC requests.
      response_tensors = self.rpc(
          method=self.get_method_name('SleepForever'),
          address=self._address,
          timeout_in_ms=1000,
          request=request_tensors)
      with self.assertRaises(errors.DeadlineExceededError):
        sess.run(response_tensors)

  def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
    # try_rpc must report each request as either OK or DEADLINE_EXCEEDED.
    with self.cached_session() as sess:
      response_tensors, status_code, status_message = self.try_rpc(
          method=self.get_method_name('SometimesSleepForever'),
          timeout_in_ms=1000,
          address=self._address,
          request=[''] * 20)
      self.assertEqual(response_tensors.shape, (20,))
      self.assertEqual(status_code.shape, (20,))
      self.assertEqual(status_message.shape, (20,))
      status_code_values = sess.run(status_code)
      # all() is required: asserting on the bare list would always pass,
      # because a non-empty list is truthy whatever it contains.
      self.assertTrue(all(
          x in (errors.OK, errors.DEADLINE_EXCEEDED)
          for x in status_code_values))

  def testTryRpcWithMultipleAddressesSingleRequest(self):
    # One request is sent to 20 addresses that alternate between the test
    # server (even indices) and an unreachable socket (odd indices).
    with self.cached_session() as sess:
      addresses = self._flatten(
          [self._address, self._NONEXISTENT_ADDRESS] for _ in range(10))
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
          status_code_values)
      for i in range(10):
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleMethodsSingleRequest(self):
    # One request is sent to 20 methods that alternate between 'Increment'
    # (even indices) and an invalid method name (odd indices).
    with self.cached_session() as sess:
      methods = self._flatten(
          [self.get_method_name('Increment'), 'InvalidMethodName']
          for _ in range(10))
      request = test_example_pb2.TestCase(values=[0, 1, 2]).SerializeToString()
      response_tensors, status_code, _ = self.try_rpc(
          method=methods, address=self._address, request=request)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNIMPLEMENTED] for _ in range(10)),
          status_code_values)
      for i in range(10):
        self.assertTrue(response_tensors_values[2 * i])
        self.assertFalse(response_tensors_values[2 * i + 1])

  def testTryRpcWithMultipleAddressesAndRequests(self):
    # 20 distinct requests are paired element-wise with 20 addresses that
    # alternate between the test server (even indices) and an unreachable
    # socket (odd indices).
    with self.cached_session() as sess:
      addresses = self._flatten(
          [self._address, self._NONEXISTENT_ADDRESS] for _ in range(10))
      requests = [
          test_example_pb2.TestCase(
              values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
      ]
      response_tensors, status_code, _ = self.try_rpc(
          method=self.get_method_name('Increment'),
          address=addresses,
          request=requests)
      response_tensors_values, status_code_values = sess.run((response_tensors,
                                                              status_code))
      self.assertAllEqual(
          self._flatten([errors.OK, errors.UNAVAILABLE] for _ in range(10)),
          status_code_values)
      for i in range(20):
        if i % 2 == 1:
          # Odd indices went to the unreachable address: empty response.
          self.assertFalse(response_tensors_values[i])
        else:
          response_message = test_example_pb2.TestCase()
          self.assertTrue(
              response_message.ParseFromString(response_tensors_values[i]))
          self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)