feat: add type annotations to wrapped grpc calls (#554)

* add types to grpc call wrappers

* fixed tests

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* changed type

* changed async types

* added tests

* fixed lint issues

* Update tests/asyncio/test_grpc_helpers_async.py

Co-authored-by: Anthonios Partheniou <partheniou@google.com>

* turned GrpcStream into a type alias

* added test for GrpcStream

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* added comment

* reordered types

* changed type var to P

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
Co-authored-by: Anthonios Partheniou <partheniou@google.com>
diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py
index f52e180..793c884 100644
--- a/google/api_core/grpc_helpers.py
+++ b/google/api_core/grpc_helpers.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """Helpers for :mod:`grpc`."""
+from typing import Generic, TypeVar, Iterator
 
 import collections
 import functools
@@ -54,6 +55,9 @@
 
 _LOGGER = logging.getLogger(__name__)
 
+# denotes the proto response type for grpc calls
+P = TypeVar("P")
+
 
 def _patch_callable_name(callable_):
     """Fix-up gRPC callable attributes.
@@ -79,7 +83,7 @@
     return error_remapped_callable
 
 
-class _StreamingResponseIterator(grpc.Call):
+class _StreamingResponseIterator(Generic[P], grpc.Call):
     def __init__(self, wrapped, prefetch_first_result=True):
         self._wrapped = wrapped
 
@@ -97,11 +101,11 @@
             # ignore stop iteration at this time. This should be handled outside of retry.
             pass
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[P]:
         """This iterator is also an iterable that returns itself."""
         return self
 
-    def __next__(self):
+    def __next__(self) -> P:
         """Get the next response from the stream.
 
         Returns:
@@ -144,6 +148,10 @@
         return self._wrapped.trailing_metadata()
 
 
+# public type alias denoting the return type of streaming gapic calls
+GrpcStream = _StreamingResponseIterator[P]
+
+
 def _wrap_stream_errors(callable_):
     """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.
 
diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py
index d1f69d9..5685e6f 100644
--- a/google/api_core/grpc_helpers_async.py
+++ b/google/api_core/grpc_helpers_async.py
@@ -21,11 +21,15 @@
 import asyncio
 import functools
 
+from typing import Generic, Iterator, AsyncGenerator, TypeVar
+
 import grpc
 from grpc import aio
 
 from google.api_core import exceptions, grpc_helpers
 
+# denotes the proto response type for grpc calls
+P = TypeVar("P")
 
 # NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform
 # automatic patching for us. But that means the overhead of creating an
@@ -75,8 +79,8 @@
             raise exceptions.from_grpc_error(rpc_error) from rpc_error
 
 
-class _WrappedUnaryResponseMixin(_WrappedCall):
-    def __await__(self):
+class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall):
+    def __await__(self) -> Iterator[P]:
         try:
             response = yield from self._call.__await__()
             return response
@@ -84,17 +88,17 @@
             raise exceptions.from_grpc_error(rpc_error) from rpc_error
 
 
-class _WrappedStreamResponseMixin(_WrappedCall):
+class _WrappedStreamResponseMixin(Generic[P], _WrappedCall):
     def __init__(self):
         self._wrapped_async_generator = None
 
-    async def read(self):
+    async def read(self) -> P:
         try:
             return await self._call.read()
         except grpc.RpcError as rpc_error:
             raise exceptions.from_grpc_error(rpc_error) from rpc_error
 
-    async def _wrapped_aiter(self):
+    async def _wrapped_aiter(self) -> AsyncGenerator[P, None]:
         try:
             # NOTE(lidiz) coverage doesn't understand the exception raised from
             # __anext__ method. It is covered by test case:
@@ -104,7 +108,7 @@
         except grpc.RpcError as rpc_error:
             raise exceptions.from_grpc_error(rpc_error) from rpc_error
 
-    def __aiter__(self):
+    def __aiter__(self) -> AsyncGenerator[P, None]:
         if not self._wrapped_async_generator:
             self._wrapped_async_generator = self._wrapped_aiter()
         return self._wrapped_async_generator
@@ -127,26 +131,32 @@
 # NOTE(lidiz) Implementing each individual class separately, so we don't
 # expose any API that should not be seen. E.g., __aiter__ in unary-unary
 # RPC, or __await__ in stream-stream RPC.
-class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall):
+class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall):
     """Wrapped UnaryUnaryCall to map exceptions."""
 
 
-class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall):
+class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall):
     """Wrapped UnaryStreamCall to map exceptions."""
 
 
 class _WrappedStreamUnaryCall(
-    _WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall
+    _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall
 ):
     """Wrapped StreamUnaryCall to map exceptions."""
 
 
 class _WrappedStreamStreamCall(
-    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall
+    _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall
 ):
     """Wrapped StreamStreamCall to map exceptions."""
 
 
+# public type alias denoting the return type of async streaming gapic calls
+GrpcAsyncStream = _WrappedStreamResponseMixin[P]
+# public type alias denoting the return type of unary gapic calls
+AwaitableGrpcCall = _WrappedUnaryResponseMixin[P]
+
+
 def _wrap_unary_errors(callable_):
     """Map errors for Unary-Unary async callables."""
     grpc_helpers._patch_callable_name(callable_)
diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py
index 95242f6..67c9b33 100644
--- a/tests/asyncio/test_grpc_helpers_async.py
+++ b/tests/asyncio/test_grpc_helpers_async.py
@@ -266,6 +266,28 @@
     wrap_unary_errors.assert_called_once_with(callable_)
 
 
+def test_grpc_async_stream():
+    """
+    GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call.
+    """
+    instance = grpc_helpers_async.GrpcAsyncStream[int]()
+    assert isinstance(instance, grpc.aio.Call)
+    # should implement __aiter__ and __anext__
+    assert hasattr(instance, "__aiter__")
+    it = instance.__aiter__()
+    assert hasattr(it, "__anext__")
+
+
+def test_awaitable_grpc_call():
+    """
+    AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call.
+    """
+    instance = grpc_helpers_async.AwaitableGrpcCall[int]()
+    assert isinstance(instance, grpc.aio.Call)
+    # should implement __await__
+    assert hasattr(instance, "__await__")
+
+
 @mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors")
 def test_wrap_errors_streaming(wrap_stream_errors):
     callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable)
diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py
index 4eccbca..58a6a32 100644
--- a/tests/unit/test_grpc_helpers.py
+++ b/tests/unit/test_grpc_helpers.py
@@ -195,6 +195,23 @@
         wrapped.trailing_metadata.assert_called_once_with()
 
 
+class TestGrpcStream(Test_StreamingResponseIterator):
+    @staticmethod
+    def _make_one(wrapped, **kw):
+        return grpc_helpers.GrpcStream(wrapped, **kw)
+
+    def test_grpc_stream_attributes(self):
+        """
+        Should be both a grpc.Call and an iterable
+        """
+        call = self._make_one(None)
+        assert isinstance(call, grpc.Call)
+        # should implement __iter__
+        assert hasattr(call, "__iter__")
+        it = call.__iter__()
+        assert hasattr(it, "__next__")
+
+
 def test_wrap_stream_okay():
     expected_responses = [1, 2, 3]
     callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))