chore: avoid checking instance on each stream call (#529)
* chore: avoid checking instance on each stream call
* fixed indentation
* added check for unary call
* fixed type check
* fixed tests
* fixed coverage
* added exception to test class
* added comment to test
* 🦉 Updates from OwlBot post-processor
See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
---------
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py
index 9423d2b..718b5f0 100644
--- a/google/api_core/grpc_helpers_async.py
+++ b/google/api_core/grpc_helpers_async.py
@@ -159,7 +159,6 @@
def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
- grpc_helpers._patch_callable_name(callable_)
@functools.wraps(callable_)
def error_remapped_callable(*args, **kwargs):
@@ -169,23 +168,13 @@
return error_remapped_callable
-def _wrap_stream_errors(callable_):
+def _wrap_stream_errors(callable_, wrapper_type):
"""Map errors for streaming RPC async callables."""
- grpc_helpers._patch_callable_name(callable_)
@functools.wraps(callable_)
async def error_remapped_callable(*args, **kwargs):
call = callable_(*args, **kwargs)
-
- if isinstance(call, aio.UnaryStreamCall):
- call = _WrappedUnaryStreamCall().with_call(call)
- elif isinstance(call, aio.StreamUnaryCall):
- call = _WrappedStreamUnaryCall().with_call(call)
- elif isinstance(call, aio.StreamStreamCall):
- call = _WrappedStreamStreamCall().with_call(call)
- else:
- raise TypeError("Unexpected type of call %s" % type(call))
-
+ call = wrapper_type().with_call(call)
await call.wait_for_connection()
return call
@@ -207,10 +196,17 @@
Returns: Callable: The wrapped gRPC callable.
"""
+ grpc_helpers._patch_callable_name(callable_)
if isinstance(callable_, aio.UnaryUnaryMultiCallable):
return _wrap_unary_errors(callable_)
+ elif isinstance(callable_, aio.UnaryStreamMultiCallable):
+ return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall)
+ elif isinstance(callable_, aio.StreamUnaryMultiCallable):
+ return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall)
+ elif isinstance(callable_, aio.StreamStreamMultiCallable):
+ return _wrap_stream_errors(callable_, _WrappedStreamStreamCall)
else:
- return _wrap_stream_errors(callable_)
+ raise TypeError("Unexpected type of callable: {}".format(type(callable_)))
def create_channel(
diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py
index 6bde59c..6e08f10 100644
--- a/tests/asyncio/test_grpc_helpers_async.py
+++ b/tests/asyncio/test_grpc_helpers_async.py
@@ -98,11 +98,39 @@
@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "callable_type,expected_wrapper_type",
+ [
+ (grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall),
+ (grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall),
+ (
+ grpc.aio.StreamStreamMultiCallable,
+ grpc_helpers_async._WrappedStreamStreamCall,
+ ),
+ ],
+)
+async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type):
+ class ConcreteMulticallable(callable_type):
+ def __call__(self, *args, **kwargs):
+ raise NotImplementedError("Should not be called")
+
+ with mock.patch.object(
+ grpc_helpers_async, "_wrap_stream_errors"
+ ) as wrap_stream_errors:
+ callable_ = ConcreteMulticallable()
+ grpc_helpers_async.wrap_errors(callable_)
+ assert wrap_stream_errors.call_count == 1
+ wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type)
+
+
+@pytest.mark.asyncio
async def test_wrap_stream_errors_unary_stream():
mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedUnaryStreamCall
+ )
await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
@@ -114,7 +142,9 @@
mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamUnaryCall
+ )
await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
@@ -126,7 +156,9 @@
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
@@ -134,14 +166,16 @@
@pytest.mark.asyncio
-async def test_wrap_stream_errors_type_error():
+async def test_wrap_errors_type_error():
+ """
+ If wrap_errors is called with an unexpected type, it should raise a TypeError.
+ """
mock_call = mock.Mock()
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
-
- with pytest.raises(TypeError):
- await wrapped_callable()
+ with pytest.raises(TypeError) as exc:
+ grpc_helpers_async.wrap_errors(multicallable)
+ assert "Unexpected type" in str(exc.value)
@pytest.mark.asyncio
@@ -151,7 +185,9 @@
mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error])
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
with pytest.raises(exceptions.InvalidArgument):
await wrapped_callable()
@@ -166,7 +202,9 @@
mock_call.read = mock.AsyncMock(side_effect=grpc_error)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
wrapped_call = await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
@@ -189,7 +227,9 @@
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
wrapped_call = await wrapped_callable()
with pytest.raises(exceptions.InvalidArgument) as exc_info:
@@ -210,7 +250,9 @@
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
wrapped_call = await wrapped_callable()
with pytest.raises(TypeError) as exc_info:
@@ -224,7 +266,9 @@
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
wrapped_call = await wrapped_callable()
assert wrapped_call.__aiter__() == wrapped_call.__aiter__()
@@ -239,7 +283,9 @@
mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error])
multicallable = mock.Mock(return_value=mock_call)
- wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
+ wrapped_callable = grpc_helpers_async._wrap_stream_errors(
+ multicallable, grpc_helpers_async._WrappedStreamStreamCall
+ )
wrapped_call = await wrapped_callable()
@@ -295,7 +341,9 @@
result = grpc_helpers_async.wrap_errors(callable_)
assert result == wrap_stream_errors.return_value
- wrap_stream_errors.assert_called_once_with(callable_)
+ wrap_stream_errors.assert_called_once_with(
+ callable_, grpc_helpers_async._WrappedUnaryStreamCall
+ )
@pytest.mark.parametrize(