Implement connectivity state related APIs
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
index 6e187d6..1e9f634 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pxd.pxi
@@ -17,3 +17,4 @@
grpc_channel * channel
CallbackCompletionQueue cq
bytes _target
+ object _loop
diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
index 850fb77..e6ac6ec 100644
--- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
+++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
@@ -13,14 +13,19 @@
# limitations under the License.
+class _WatchConnectivityFailed(Exception): pass
+cdef CallbackFailureHandler _WATCH_CONNECTIVITY_FAILURE_HANDLER = CallbackFailureHandler(
+ 'watch_connectivity_state',
+ 'Maybe timed out.',
+ _WatchConnectivityFailed)
+
+
cdef class AioChannel:
def __cinit__(self, bytes target, tuple options, ChannelCredentials credentials):
if options is None:
options = ()
cdef _ChannelArgs channel_args = _ChannelArgs(options)
self._target = target
- self.cq = CallbackCompletionQueue()
-
if credentials is None:
self.channel = grpc_insecure_channel_create(
<char *>target,
@@ -32,12 +37,42 @@
<char *> target,
channel_args.c_args(),
NULL)
+ self._loop = asyncio.get_event_loop()
def __repr__(self):
class_name = self.__class__.__name__
id_ = id(self)
return f"<{class_name} {id_}>"
+ def check_connectivity_state(self, bint try_to_connect):
+ return grpc_channel_check_connectivity_state(
+ self.channel,
+ try_to_connect,
+ )
+
+ async def watch_connectivity_state(self,
+ grpc_connectivity_state last_observed_state,
+ object deadline):
+ cdef gpr_timespec c_deadline = _timespec_from_time(deadline)
+
+ cdef object future = self._loop.create_future()
+ cdef CallbackWrapper wrapper = CallbackWrapper(
+ future,
+ _WATCH_CONNECTIVITY_FAILURE_HANDLER)
+ grpc_channel_watch_connectivity_state(
+ self.channel,
+ last_observed_state,
+ c_deadline,
+ self.cq.c_ptr(),
+ wrapper.c_functor())
+
+ try:
+ await future
+ except _WatchConnectivityFailed:
+ return None
+ else:
+ return self.check_connectivity_state(False)
+
def close(self):
grpc_channel_destroy(self.channel)
diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py
index 11405a8..4014e82 100644
--- a/src/python/grpcio/grpc/experimental/aio/_channel.py
+++ b/src/python/grpcio/grpc/experimental/aio/_channel.py
@@ -13,7 +13,8 @@
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""
import asyncio
-from typing import Any, Optional, Sequence, Text
+import time
+from typing import Any, Optional, Sequence, Text, Tuple
import grpc
from grpc import _common
@@ -224,6 +225,51 @@
self._channel = cygrpc.AioChannel(_common.encode(target), options,
credentials)
+ def check_connectivity_state(self, try_to_connect: bool = False
+ ) -> grpc.ChannelConnectivity:
+ """Check the connectivity state of a channel.
+
+ This is an EXPERIMENTAL API.
+
+ Args:
+ try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
+
+ Returns:
+ A ChannelConnectivity object.
+ """
+ result = self._channel.check_connectivity_state(try_to_connect)
+ return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]
+
+ async def watch_connectivity_state(
+ self,
+ last_observed_state: grpc.ChannelConnectivity,
+ timeout_seconds: Optional[float] = None,
+ ) -> Optional[grpc.ChannelConnectivity]:
+ """Watch for a change in connectivity state.
+
+ This is an EXPERIMENTAL API.
+
+ Once the channel connectivity state is different from
+ last_observed_state, the function will return the new connectivity
+ state. If deadline expires BEFORE the state is changed, None will be
+ returned.
+
+ Args:
+ try_to_connect: a bool indicate whether the Channel should try to connect to peer or not.
+
+ Returns:
+ A ChannelConnectivity object or None.
+ """
+ deadline = time.time(
+ ) + timeout_seconds if timeout_seconds is not None else None
+ result = await self._channel.watch_connectivity_state(
+ last_observed_state.value[0], deadline)
+ if result is None:
+ return None
+ else:
+ return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[
+ result]
+
def unary_unary(
self,
method: Text,
diff --git a/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
new file mode 100644
index 0000000..eb7ee35
--- /dev/null
+++ b/src/python/grpcio_tests/tests_aio/unit/connectivity_test.py
@@ -0,0 +1,98 @@
+# Copyright 2019 The gRPC Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests behavior of the connectivity state."""
+
+import logging
+import threading
+import unittest
+import time
+import grpc
+
+from grpc.experimental import aio
+from src.proto.grpc.testing import messages_pb2
+from src.proto.grpc.testing import test_pb2_grpc
+from tests.unit.framework.common import test_constants
+from tests_aio.unit._test_server import start_test_server
+from tests_aio.unit._test_base import AioTestBase
+
+_INVALID_BACKEND_ADDRESS = '0.0.0.1:2'
+
+
+class TestChannel(AioTestBase):
+
+ async def setUp(self):
+ self._server_address, self._server = await start_test_server()
+
+ async def tearDown(self):
+ await self._server.stop(None)
+
+ async def test_unavailable_backend(self):
+ channel = aio.insecure_channel(_INVALID_BACKEND_ADDRESS)
+
+ self.assertEqual(grpc.ChannelConnectivity.IDLE,
+ channel.check_connectivity_state(False))
+ self.assertEqual(grpc.ChannelConnectivity.IDLE,
+ channel.check_connectivity_state(True))
+ self.assertEqual(
+ grpc.ChannelConnectivity.CONNECTING, await
+ channel.watch_connectivity_state(grpc.ChannelConnectivity.IDLE))
+ self.assertEqual(
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE, await
+ channel.watch_connectivity_state(grpc.ChannelConnectivity.CONNECTING
+ ))
+
+ await channel.close()
+
+ async def test_normal_backend(self):
+ channel = aio.insecure_channel(self._server_address)
+
+ current_state = channel.check_connectivity_state(True)
+ self.assertEqual(grpc.ChannelConnectivity.IDLE, current_state)
+
+ deadline = time.time() + test_constants.SHORT_TIMEOUT
+
+ while current_state != grpc.ChannelConnectivity.READY:
+ current_state = await channel.watch_connectivity_state(
+ current_state, deadline - time.time())
+ self.assertIsNotNone(current_state)
+
+ await channel.close()
+
+ async def test_timeout(self):
+ channel = aio.insecure_channel(self._server_address)
+
+ self.assertEqual(grpc.ChannelConnectivity.IDLE,
+ channel.check_connectivity_state(False))
+
+ # If timed out, the function should return None.
+ self.assertIsNone(await channel.watch_connectivity_state(
+ grpc.ChannelConnectivity.IDLE, test_constants.SHORT_TIMEOUT))
+
+ await channel.close()
+
+ async def test_shutdown(self):
+ channel = aio.insecure_channel(self._server_address)
+
+ self.assertEqual(grpc.ChannelConnectivity.IDLE,
+ channel.check_connectivity_state(False))
+
+ await channel.close()
+
+ self.assertEqual(grpc.ChannelConnectivity.SHUTDOWN,
+ channel.check_connectivity_state(False))
+
+
+if __name__ == '__main__':
+ logging.basicConfig(level=logging.DEBUG)
+ unittest.main(verbosity=2)