feat: Add a little bit of typing to google.api_core.retry (#453)
* ref(typing): add a little bit of typing to google.api_core.retry
* coverage
---------
Co-authored-by: Anthonios Partheniou <partheniou@google.com>
diff --git a/.coveragerc b/.coveragerc
index d097511..34417c3 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -11,3 +11,5 @@
def __repr__
# Ignore abstract methods
raise NotImplementedError
+ # Ignore coverage for code specific to static type checkers
+ TYPE_CHECKING
diff --git a/google/api_core/retry.py b/google/api_core/retry.py
index df1e65e..84b5d0f 100644
--- a/google/api_core/retry.py
+++ b/google/api_core/retry.py
@@ -54,13 +54,15 @@
"""
-from __future__ import unicode_literals
+from __future__ import annotations
import datetime
import functools
import logging
import random
+import sys
import time
+from typing import Any, Callable, TypeVar, TYPE_CHECKING
import requests.exceptions
@@ -68,6 +70,15 @@
from google.api_core import exceptions
from google.auth import exceptions as auth_exceptions
+if TYPE_CHECKING:
+ if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+ else:
+ from typing_extensions import ParamSpec
+
+ _P = ParamSpec("_P")
+ _R = TypeVar("_R")
+
_LOGGER = logging.getLogger(__name__)
_DEFAULT_INITIAL_DELAY = 1.0 # seconds
_DEFAULT_MAXIMUM_DELAY = 60.0 # seconds
@@ -75,7 +86,9 @@
_DEFAULT_DEADLINE = 60.0 * 2.0 # seconds
-def if_exception_type(*exception_types):
+def if_exception_type(
+ *exception_types: type[BaseException],
+) -> Callable[[BaseException], bool]:
"""Creates a predicate to check if the exception is of a given type.
Args:
@@ -87,7 +100,7 @@
exception is of the given type(s).
"""
- def if_exception_type_predicate(exception):
+ def if_exception_type_predicate(exception: BaseException) -> bool:
"""Bound predicate for checking an exception type."""
return isinstance(exception, exception_types)
@@ -307,14 +320,14 @@
def __init__(
self,
- predicate=if_transient_error,
- initial=_DEFAULT_INITIAL_DELAY,
- maximum=_DEFAULT_MAXIMUM_DELAY,
- multiplier=_DEFAULT_DELAY_MULTIPLIER,
- timeout=_DEFAULT_DEADLINE,
- on_error=None,
- **kwargs
- ):
+ predicate: Callable[[BaseException], bool] = if_transient_error,
+ initial: float = _DEFAULT_INITIAL_DELAY,
+ maximum: float = _DEFAULT_MAXIMUM_DELAY,
+ multiplier: float = _DEFAULT_DELAY_MULTIPLIER,
+ timeout: float = _DEFAULT_DEADLINE,
+ on_error: Callable[[BaseException], Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
self._predicate = predicate
self._initial = initial
self._multiplier = multiplier
@@ -323,7 +336,11 @@
self._deadline = self._timeout
self._on_error = on_error
- def __call__(self, func, on_error=None):
+ def __call__(
+ self,
+ func: Callable[_P, _R],
+ on_error: Callable[[BaseException], Any] | None = None,
+ ) -> Callable[_P, _R]:
"""Wrap a callable with retry behavior.
Args:
@@ -340,7 +357,7 @@
on_error = self._on_error
@functools.wraps(func)
- def retry_wrapped_func(*args, **kwargs):
+ def retry_wrapped_func(*args: _P.args, **kwargs: _P.kwargs) -> _R:
"""A wrapper that calls target function with retry."""
target = functools.partial(func, *args, **kwargs)
sleep_generator = exponential_sleep_generator(