feat: add support for asynchronous rest streaming (#686)
* duplicating file to base
* restore original file
* duplicate file to async
* restore original file
* duplicate test file for async
* restore test file
* feat: add support for asynchronous rest streaming
* 🦉 Updates from OwlBot post-processor
See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* fix naming issue
* fix import module name
* pull auth feature branch
* revert setup file
* address PR comments
* 🦉 Updates from OwlBot post-processor
See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* run black
* address PR comments
* update nox coverage
* address PR comments
* fix nox session name in workflow
* use https for remote repo
* add context manager methods
* address PR comments
* update auth error versions
* update import error
---------
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml
index 8adc535..5980f82 100644
--- a/.github/workflows/unittest.yml
+++ b/.github/workflows/unittest.yml
@@ -11,7 +11,7 @@
runs-on: ubuntu-latest
strategy:
matrix:
- option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps"]
+ option: ["", "_grpc_gcp", "_wo_grpc", "_with_prerelease_deps", "_with_auth_aio"]
python:
- "3.7"
- "3.8"
diff --git a/google/api_core/_rest_streaming_base.py b/google/api_core/_rest_streaming_base.py
new file mode 100644
index 0000000..3bc87a9
--- /dev/null
+++ b/google/api_core/_rest_streaming_base.py
@@ -0,0 +1,118 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for server-side streaming in REST."""
+
+from collections import deque
+import string
+from typing import Deque, Union
+import types
+
+import proto
+import google.protobuf.message
+from google.protobuf.json_format import Parse
+
+
+class BaseResponseIterator:
+ """Base Iterator over REST API responses. This class should not be used directly.
+
+ Args:
+ response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
+ class expected to be returned from an API.
+
+ Raises:
+ ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
+ """
+
+ def __init__(
+ self,
+ response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+ ):
+ self._response_message_cls = response_message_cls
+ # Contains a list of JSON responses ready to be sent to user.
+ self._ready_objs: Deque[str] = deque()
+ # Current JSON response being built.
+ self._obj = ""
+ # Keeps track of the nesting level within a JSON object.
+ self._level = 0
+ # Keeps track whether HTTP response is currently sending values
+ # inside of a string value.
+ self._in_string = False
+ # Whether an escape symbol "\" was encountered.
+ self._escape_next = False
+
+ self._grab = types.MethodType(self._create_grab(), self)
+
+ def _process_chunk(self, chunk: str):
+ if self._level == 0:
+ if chunk[0] != "[":
+ raise ValueError(
+ "Can only parse array of JSON objects, instead got %s" % chunk
+ )
+ for char in chunk:
+ if char == "{":
+ if self._level == 1:
+ # Level 1 corresponds to the outermost JSON object
+ # (i.e. the one we care about).
+ self._obj = ""
+ if not self._in_string:
+ self._level += 1
+ self._obj += char
+ elif char == "}":
+ self._obj += char
+ if not self._in_string:
+ self._level -= 1
+ if not self._in_string and self._level == 1:
+ self._ready_objs.append(self._obj)
+ elif char == '"':
+ # Helps to deal with an escaped quotes inside of a string.
+ if not self._escape_next:
+ self._in_string = not self._in_string
+ self._obj += char
+ elif char in string.whitespace:
+ if self._in_string:
+ self._obj += char
+ elif char == "[":
+ if self._level == 0:
+ self._level += 1
+ else:
+ self._obj += char
+ elif char == "]":
+ if self._level == 1:
+ self._level -= 1
+ else:
+ self._obj += char
+ else:
+ self._obj += char
+ self._escape_next = not self._escape_next if char == "\\" else False
+
+ def _create_grab(self):
+ if issubclass(self._response_message_cls, proto.Message):
+
+ def grab(this):
+ return this._response_message_cls.from_json(
+ this._ready_objs.popleft(), ignore_unknown_fields=True
+ )
+
+ return grab
+ elif issubclass(self._response_message_cls, google.protobuf.message.Message):
+
+ def grab(this):
+ return Parse(this._ready_objs.popleft(), this._response_message_cls())
+
+ return grab
+ else:
+ raise ValueError(
+ "Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
+ )
diff --git a/google/api_core/rest_streaming.py b/google/api_core/rest_streaming.py
index 88bcb31..84aa270 100644
--- a/google/api_core/rest_streaming.py
+++ b/google/api_core/rest_streaming.py
@@ -14,17 +14,15 @@
"""Helpers for server-side streaming in REST."""
-from collections import deque
-import string
-from typing import Deque, Union
+from typing import Union
import proto
import requests
import google.protobuf.message
-from google.protobuf.json_format import Parse
+from google.api_core._rest_streaming_base import BaseResponseIterator
-class ResponseIterator:
+class ResponseIterator(BaseResponseIterator):
"""Iterator over REST API responses.
Args:
@@ -33,7 +31,8 @@
class expected to be returned from an API.
Raises:
- ValueError: If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
+ ValueError:
+ - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
"""
def __init__(
@@ -42,68 +41,16 @@
response_message_cls: Union[proto.Message, google.protobuf.message.Message],
):
self._response = response
- self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
self._response_itr = self._response.iter_content(decode_unicode=True)
- # Contains a list of JSON responses ready to be sent to user.
- self._ready_objs: Deque[str] = deque()
- # Current JSON response being built.
- self._obj = ""
- # Keeps track of the nesting level within a JSON object.
- self._level = 0
- # Keeps track whether HTTP response is currently sending values
- # inside of a string value.
- self._in_string = False
- # Whether an escape symbol "\" was encountered.
- self._escape_next = False
+ super(ResponseIterator, self).__init__(
+ response_message_cls=response_message_cls
+ )
def cancel(self):
"""Cancel existing streaming operation."""
self._response.close()
- def _process_chunk(self, chunk: str):
- if self._level == 0:
- if chunk[0] != "[":
- raise ValueError(
- "Can only parse array of JSON objects, instead got %s" % chunk
- )
- for char in chunk:
- if char == "{":
- if self._level == 1:
- # Level 1 corresponds to the outermost JSON object
- # (i.e. the one we care about).
- self._obj = ""
- if not self._in_string:
- self._level += 1
- self._obj += char
- elif char == "}":
- self._obj += char
- if not self._in_string:
- self._level -= 1
- if not self._in_string and self._level == 1:
- self._ready_objs.append(self._obj)
- elif char == '"':
- # Helps to deal with an escaped quotes inside of a string.
- if not self._escape_next:
- self._in_string = not self._in_string
- self._obj += char
- elif char in string.whitespace:
- if self._in_string:
- self._obj += char
- elif char == "[":
- if self._level == 0:
- self._level += 1
- else:
- self._obj += char
- elif char == "]":
- if self._level == 1:
- self._level -= 1
- else:
- self._obj += char
- else:
- self._obj += char
- self._escape_next = not self._escape_next if char == "\\" else False
-
def __next__(self):
while not self._ready_objs:
try:
@@ -115,18 +62,5 @@
raise e
return self._grab()
- def _grab(self):
- # Add extra quotes to make json.loads happy.
- if issubclass(self._response_message_cls, proto.Message):
- return self._response_message_cls.from_json(
- self._ready_objs.popleft(), ignore_unknown_fields=True
- )
- elif issubclass(self._response_message_cls, google.protobuf.message.Message):
- return Parse(self._ready_objs.popleft(), self._response_message_cls())
- else:
- raise ValueError(
- "Response message class must be a subclass of proto.Message or google.protobuf.message.Message."
- )
-
def __iter__(self):
return self
diff --git a/google/api_core/rest_streaming_async.py b/google/api_core/rest_streaming_async.py
new file mode 100644
index 0000000..d1f996f
--- /dev/null
+++ b/google/api_core/rest_streaming_async.py
@@ -0,0 +1,83 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for asynchronous server-side streaming in REST."""
+
+from typing import Union
+
+import proto
+
+try:
+ import google.auth.aio.transport
+except ImportError as e: # pragma: NO COVER
+ raise ImportError(
+ "google-auth>=2.35.0 is required to use asynchronous rest streaming."
+ ) from e
+
+import google.protobuf.message
+from google.api_core._rest_streaming_base import BaseResponseIterator
+
+
+class AsyncResponseIterator(BaseResponseIterator):
+ """Asynchronous Iterator over REST API responses.
+
+ Args:
+ response (google.auth.aio.transport.Response): An API response object.
+ response_message_cls (Union[proto.Message, google.protobuf.message.Message]): A response
+ class expected to be returned from an API.
+
+ Raises:
+ ValueError:
+ - If `response_message_cls` is not a subclass of `proto.Message` or `google.protobuf.message.Message`.
+ """
+
+ def __init__(
+ self,
+ response: google.auth.aio.transport.Response,
+ response_message_cls: Union[proto.Message, google.protobuf.message.Message],
+ ):
+ self._response = response
+ self._chunk_size = 1024
+ self._response_itr = self._response.content().__aiter__()
+ super(AsyncResponseIterator, self).__init__(
+ response_message_cls=response_message_cls
+ )
+
+ async def __aenter__(self):
+ return self
+
+ async def cancel(self):
+ """Cancel existing streaming operation."""
+ await self._response.close()
+
+ async def __anext__(self):
+ while not self._ready_objs:
+ try:
+ chunk = await self._response_itr.__anext__()
+ chunk = chunk.decode("utf-8")
+ self._process_chunk(chunk)
+ except StopAsyncIteration as e:
+ if self._level > 0:
+ raise ValueError("i Unfinished stream: %s" % self._obj)
+ raise e
+ except ValueError as e:
+ raise e
+ return self._grab()
+
+ def __aiter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ """Cancel existing async streaming operation."""
+ await self._response.close()
diff --git a/noxfile.py b/noxfile.py
index a15795e..3fc4a72 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -38,6 +38,7 @@
"unit",
"unit_grpc_gcp",
"unit_wo_grpc",
+ "unit_with_auth_aio",
"cover",
"pytype",
"mypy",
@@ -109,7 +110,7 @@
session.install(*other_deps)
-def default(session, install_grpc=True, prerelease=False):
+def default(session, install_grpc=True, prerelease=False, install_auth_aio=False):
"""Default unit test session.
This is intended to be run **without** an interpreter set, so
@@ -144,6 +145,11 @@
f"{constraints_dir}/constraints-{session.python}.txt",
)
+ if install_auth_aio:
+ session.install(
+ "google-auth @ git+https://git@github.com/googleapis/google-auth-library-python@8833ad6f92c3300d6645355994c7db2356bd30ad"
+ )
+
# Print out package versions of dependencies
session.run(
"python", "-c", "import google.protobuf; print(google.protobuf.__version__)"
@@ -229,6 +235,12 @@
default(session, install_grpc=False)
+@nox.session(python=PYTHON_VERSIONS)
+def unit_with_auth_aio(session):
+ """Run the unit test suite with google.auth.aio installed"""
+ default(session, install_auth_aio=True)
+
+
@nox.session(python=DEFAULT_PYTHON_VERSION)
def lint_setup_py(session):
"""Verify that setup.py is valid (including RST check)."""
diff --git a/tests/asyncio/test_rest_streaming_async.py b/tests/asyncio/test_rest_streaming_async.py
new file mode 100644
index 0000000..35820de
--- /dev/null
+++ b/tests/asyncio/test_rest_streaming_async.py
@@ -0,0 +1,378 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO: set random.seed explicitly in each test function.
+# See related issue: https://github.com/googleapis/python-api-core/issues/689.
+
+import pytest # noqa: I202
+import mock
+
+import datetime
+import logging
+import random
+import time
+from typing import List, AsyncIterator
+
+import proto
+
+try:
+ from google.auth.aio.transport import Response
+
+ AUTH_AIO_INSTALLED = True
+except ImportError:
+ AUTH_AIO_INSTALLED = False
+
+if not AUTH_AIO_INSTALLED: # pragma: NO COVER
+ pytest.skip(
+ "google-auth>=2.35.0 is required to use asynchronous rest streaming.",
+ allow_module_level=True,
+ )
+
+from google.api_core import rest_streaming_async
+from google.api import http_pb2
+from google.api import httpbody_pb2
+
+
+from ..helpers import Composer, Song, EchoResponse, parse_responses
+
+
+__protobuf__ = proto.module(package=__name__)
+SEED = int(time.time())
+logging.info(f"Starting async rest streaming tests with random seed: {SEED}")
+random.seed(SEED)
+
+
+async def mock_async_gen(data, chunk_size=1):
+ for i in range(0, len(data)): # pragma: NO COVER
+ chunk = data[i : i + chunk_size]
+ yield chunk.encode("utf-8")
+
+
+class ResponseMock(Response):
+ class _ResponseItr(AsyncIterator[bytes]):
+ def __init__(self, _response_bytes: bytes, random_split=False):
+ self._responses_bytes = _response_bytes
+ self._idx = 0
+ self._random_split = random_split
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if self._idx >= len(self._responses_bytes):
+ raise StopAsyncIteration
+ if self._random_split:
+ n = random.randint(1, len(self._responses_bytes[self._idx :]))
+ else:
+ n = 1
+ x = self._responses_bytes[self._idx : self._idx + n]
+ self._idx += n
+ return x
+
+ def __init__(
+ self,
+ responses: List[proto.Message],
+ response_cls,
+ random_split=False,
+ ):
+ self._responses = responses
+ self._random_split = random_split
+ self._response_message_cls = response_cls
+
+ def _parse_responses(self):
+ return parse_responses(self._response_message_cls, self._responses)
+
+ @property
+ async def headers(self):
+ raise NotImplementedError()
+
+ @property
+ async def status_code(self):
+ raise NotImplementedError()
+
+ async def close(self):
+ raise NotImplementedError()
+
+ async def content(self, chunk_size=None):
+ itr = self._ResponseItr(
+ self._parse_responses(), random_split=self._random_split
+ )
+ async for chunk in itr:
+ yield chunk
+
+ async def read(self):
+ raise NotImplementedError()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "random_split,resp_message_is_proto_plus",
+ [(False, True), (False, False)],
+)
+async def test_next_simple(random_split, resp_message_is_proto_plus):
+ if resp_message_is_proto_plus:
+ response_type = EchoResponse
+ responses = [EchoResponse(content="hello world"), EchoResponse(content="yes")]
+ else:
+ response_type = httpbody_pb2.HttpBody
+ responses = [
+ httpbody_pb2.HttpBody(content_type="hello world"),
+ httpbody_pb2.HttpBody(content_type="yes"),
+ ]
+
+ resp = ResponseMock(
+ responses=responses, random_split=random_split, response_cls=response_type
+ )
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ idx = 0
+ async for response in itr:
+ assert response == responses[idx]
+ idx += 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+async def test_next_nested(random_split, resp_message_is_proto_plus):
+ if resp_message_is_proto_plus:
+ response_type = Song
+ responses = [
+ Song(title="some song", composer=Composer(given_name="some name")),
+ Song(title="another song", date_added=datetime.datetime(2021, 12, 17)),
+ ]
+ else:
+ # Although `http_pb2.HttpRule`` is used in the response, any response message
+ # can be used which meets this criteria for the test of having a nested field.
+ response_type = http_pb2.HttpRule
+ responses = [
+ http_pb2.HttpRule(
+ selector="some selector",
+ custom=http_pb2.CustomHttpPattern(kind="some kind"),
+ ),
+ http_pb2.HttpRule(
+ selector="another selector",
+ custom=http_pb2.CustomHttpPattern(path="some path"),
+ ),
+ ]
+ resp = ResponseMock(
+ responses=responses, random_split=random_split, response_cls=response_type
+ )
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ idx = 0
+ async for response in itr:
+ assert response == responses[idx]
+ idx += 1
+ assert idx == len(responses)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+async def test_next_stress(random_split, resp_message_is_proto_plus):
+ n = 50
+ if resp_message_is_proto_plus:
+ response_type = Song
+ responses = [
+ Song(title="title_%d" % i, composer=Composer(given_name="name_%d" % i))
+ for i in range(n)
+ ]
+ else:
+ response_type = http_pb2.HttpRule
+ responses = [
+ http_pb2.HttpRule(
+ selector="selector_%d" % i,
+ custom=http_pb2.CustomHttpPattern(path="path_%d" % i),
+ )
+ for i in range(n)
+ ]
+ resp = ResponseMock(
+ responses=responses, random_split=random_split, response_cls=response_type
+ )
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ idx = 0
+ async for response in itr:
+ assert response == responses[idx]
+ idx += 1
+ assert idx == n
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "random_split,resp_message_is_proto_plus",
+ [
+ (True, True),
+ (False, True),
+ (True, False),
+ (False, False),
+ ],
+)
+async def test_next_escaped_characters_in_string(
+ random_split, resp_message_is_proto_plus
+):
+ if resp_message_is_proto_plus:
+ response_type = Song
+ composer_with_relateds = Composer()
+ relateds = ["Artist A", "Artist B"]
+ composer_with_relateds.relateds = relateds
+
+ responses = [
+ Song(
+ title='ti"tle\nfoo\tbar{}', composer=Composer(given_name="name\n\n\n")
+ ),
+ Song(
+ title='{"this is weird": "totally"}',
+ composer=Composer(given_name="\\{}\\"),
+ ),
+ Song(title='\\{"key": ["value",]}\\', composer=composer_with_relateds),
+ ]
+ else:
+ response_type = http_pb2.Http
+ responses = [
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='ti"tle\nfoo\tbar{}',
+ custom=http_pb2.CustomHttpPattern(kind="name\n\n\n"),
+ )
+ ]
+ ),
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='{"this is weird": "totally"}',
+ custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+ )
+ ]
+ ),
+ http_pb2.Http(
+ rules=[
+ http_pb2.HttpRule(
+ selector='\\{"key": ["value",]}\\',
+ custom=http_pb2.CustomHttpPattern(kind="\\{}\\"),
+ )
+ ]
+ ),
+ ]
+ resp = ResponseMock(
+ responses=responses, random_split=random_split, response_cls=response_type
+ )
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ idx = 0
+ async for response in itr:
+ assert response == responses[idx]
+ idx += 1
+ assert idx == len(responses)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_next_not_array(response_type):
+
+ data = '{"hello": 0}'
+ with mock.patch.object(
+ ResponseMock, "content", return_value=mock_async_gen(data)
+ ) as mock_method:
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ with pytest.raises(ValueError):
+ await itr.__anext__()
+ mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_cancel(response_type):
+ with mock.patch.object(
+ ResponseMock, "close", new_callable=mock.AsyncMock
+ ) as mock_method:
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ await itr.cancel()
+ mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_iterator_as_context_manager(response_type):
+ with mock.patch.object(
+ ResponseMock, "close", new_callable=mock.AsyncMock
+ ) as mock_method:
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ async with rest_streaming_async.AsyncResponseIterator(resp, response_type):
+ pass
+ mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ "response_type,return_value",
+ [
+ (EchoResponse, bytes('[{"content": "hello"}, {', "utf-8")),
+ (httpbody_pb2.HttpBody, bytes('[{"content_type": "hello"}, {', "utf-8")),
+ ],
+)
+async def test_check_buffer(response_type, return_value):
+ with mock.patch.object(
+ ResponseMock,
+ "_parse_responses",
+ return_value=return_value,
+ ):
+ resp = ResponseMock(responses=[], response_cls=response_type)
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ with pytest.raises(ValueError):
+ await itr.__anext__()
+ await itr.__anext__()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("response_type", [EchoResponse, httpbody_pb2.HttpBody])
+async def test_next_html(response_type):
+
+ data = "<!DOCTYPE html><html></html>"
+ with mock.patch.object(
+ ResponseMock, "content", return_value=mock_async_gen(data)
+ ) as mock_method:
+ resp = ResponseMock(responses=[], response_cls=response_type)
+
+ itr = rest_streaming_async.AsyncResponseIterator(resp, response_type)
+ with pytest.raises(ValueError):
+ await itr.__anext__()
+ mock_method.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_invalid_response_class():
+ class SomeClass:
+ pass
+
+ resp = ResponseMock(responses=[], response_cls=SomeClass)
+ with pytest.raises(
+ ValueError,
+ match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
+ ):
+ rest_streaming_async.AsyncResponseIterator(resp, SomeClass)
diff --git a/tests/helpers.py b/tests/helpers.py
new file mode 100644
index 0000000..3429d51
--- /dev/null
+++ b/tests/helpers.py
@@ -0,0 +1,71 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helpers for tests"""
+
+import logging
+from typing import List
+
+import proto
+
+from google.protobuf import duration_pb2
+from google.protobuf import timestamp_pb2
+from google.protobuf.json_format import MessageToJson
+
+
+class Genre(proto.Enum):
+ GENRE_UNSPECIFIED = 0
+ CLASSICAL = 1
+ JAZZ = 2
+ ROCK = 3
+
+
+class Composer(proto.Message):
+ given_name = proto.Field(proto.STRING, number=1)
+ family_name = proto.Field(proto.STRING, number=2)
+ relateds = proto.RepeatedField(proto.STRING, number=3)
+ indices = proto.MapField(proto.STRING, proto.STRING, number=4)
+
+
+class Song(proto.Message):
+ composer = proto.Field(Composer, number=1)
+ title = proto.Field(proto.STRING, number=2)
+ lyrics = proto.Field(proto.STRING, number=3)
+ year = proto.Field(proto.INT32, number=4)
+ genre = proto.Field(Genre, number=5)
+ is_five_mins_longer = proto.Field(proto.BOOL, number=6)
+ score = proto.Field(proto.DOUBLE, number=7)
+ likes = proto.Field(proto.INT64, number=8)
+ duration = proto.Field(duration_pb2.Duration, number=9)
+ date_added = proto.Field(timestamp_pb2.Timestamp, number=10)
+
+
+class EchoResponse(proto.Message):
+ content = proto.Field(proto.STRING, number=1)
+
+
+def parse_responses(response_message_cls, all_responses: List[proto.Message]) -> bytes:
+ # json.dumps returns a string surrounded with quotes that need to be stripped
+ # in order to be an actual JSON.
+ json_responses = [
+ (
+ response_message_cls.to_json(response).strip('"')
+ if issubclass(response_message_cls, proto.Message)
+ else MessageToJson(response).strip('"')
+ )
+ for response in all_responses
+ ]
+ logging.info(f"Sending JSON stream: {json_responses}")
+ ret_val = "[{}]".format(",".join(json_responses))
+ return bytes(ret_val, "utf-8")
diff --git a/tests/unit/test_rest_streaming.py b/tests/unit/test_rest_streaming.py
index 0f2b3b3..0f998df 100644
--- a/tests/unit/test_rest_streaming.py
+++ b/tests/unit/test_rest_streaming.py
@@ -26,48 +26,16 @@
from google.api_core import rest_streaming
from google.api import http_pb2
from google.api import httpbody_pb2
-from google.protobuf import duration_pb2
-from google.protobuf import timestamp_pb2
-from google.protobuf.json_format import MessageToJson
+
+from ..helpers import Composer, Song, EchoResponse, parse_responses
__protobuf__ = proto.module(package=__name__)
SEED = int(time.time())
-logging.info(f"Starting rest streaming tests with random seed: {SEED}")
+logging.info(f"Starting sync rest streaming tests with random seed: {SEED}")
random.seed(SEED)
-class Genre(proto.Enum):
- GENRE_UNSPECIFIED = 0
- CLASSICAL = 1
- JAZZ = 2
- ROCK = 3
-
-
-class Composer(proto.Message):
- given_name = proto.Field(proto.STRING, number=1)
- family_name = proto.Field(proto.STRING, number=2)
- relateds = proto.RepeatedField(proto.STRING, number=3)
- indices = proto.MapField(proto.STRING, proto.STRING, number=4)
-
-
-class Song(proto.Message):
- composer = proto.Field(Composer, number=1)
- title = proto.Field(proto.STRING, number=2)
- lyrics = proto.Field(proto.STRING, number=3)
- year = proto.Field(proto.INT32, number=4)
- genre = proto.Field(Genre, number=5)
- is_five_mins_longer = proto.Field(proto.BOOL, number=6)
- score = proto.Field(proto.DOUBLE, number=7)
- likes = proto.Field(proto.INT64, number=8)
- duration = proto.Field(duration_pb2.Duration, number=9)
- date_added = proto.Field(timestamp_pb2.Timestamp, number=10)
-
-
-class EchoResponse(proto.Message):
- content = proto.Field(proto.STRING, number=1)
-
-
class ResponseMock(requests.Response):
class _ResponseItr:
def __init__(self, _response_bytes: bytes, random_split=False):
@@ -97,27 +65,15 @@
self._random_split = random_split
self._response_message_cls = response_cls
- def _parse_responses(self, responses: List[proto.Message]) -> bytes:
- # json.dumps returns a string surrounded with quotes that need to be stripped
- # in order to be an actual JSON.
- json_responses = [
- (
- self._response_message_cls.to_json(r).strip('"')
- if issubclass(self._response_message_cls, proto.Message)
- else MessageToJson(r).strip('"')
- )
- for r in responses
- ]
- logging.info(f"Sending JSON stream: {json_responses}")
- ret_val = "[{}]".format(",".join(json_responses))
- return bytes(ret_val, "utf-8")
+ def _parse_responses(self):
+ return parse_responses(self._response_message_cls, self._responses)
def close(self):
raise NotImplementedError()
def iter_content(self, *args, **kwargs):
return self._ResponseItr(
- self._parse_responses(self._responses),
+ self._parse_responses(),
random_split=self._random_split,
)
@@ -333,9 +289,8 @@
pass
resp = ResponseMock(responses=[], response_cls=SomeClass)
- response_iterator = rest_streaming.ResponseIterator(resp, SomeClass)
with pytest.raises(
ValueError,
match="Response message class must be a subclass of proto.Message or google.protobuf.message.Message",
):
- response_iterator._grab()
+ rest_streaming.ResponseIterator(resp, SomeClass)