Updating Types in torch/_dynamo/utils.py (#131001)
Adds type annotations to torch/_dynamo/utils.py: return types for the module-level helpers, a ParamSpec so dynamo_timed preserves wrapped signatures, and @overload stubs (via Literal) for dynamo_timed and compile_times. Also adds two # type: ignore[attr-defined] comments in torch/_dynamo/output_graph.py and an annotation for serialized_storages in torch/serialization.py.
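The core pattern is typing a decorator that works both bare (@dynamo_timed) and parameterized (@dynamo_timed(phase_name=...)): a Literal[None] overload selects the factory form, and ParamSpec ties the wrapper's signature to the wrapped function's. A minimal sketch of that pattern, where `timed` and `label` are hypothetical stand-ins for dynamo_timed / phase_name, not the real API:

    # Sketch only: `timed` is a hypothetical stand-in for dynamo_timed.
    import functools
    from typing import Callable, Optional, TypeVar, overload

    from typing_extensions import Literal, ParamSpec

    _T = TypeVar("_T")
    _P = ParamSpec("_P")


    @overload
    def timed(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        ...


    @overload
    def timed(
        fn: Literal[None] = None, *, label: Optional[str] = None
    ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
        ...


    def timed(fn=None, *, label=None):
        # Works both bare (@timed) and parameterized (@timed(label="x")).
        def deco(func: Callable[_P, _T]) -> Callable[_P, _T]:
            @functools.wraps(func)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                # *args/**kwargs are typed by ParamSpec, so the wrapper
                # exposes the wrapped function's exact signature to mypy.
                return func(*args, **kwargs)

            return wrapper

        return deco if fn is None else deco(fn)

The same Literal-dispatch idea gives compile_times a precise per-mode return type: repr="str" resolves to str, repr="csv" to Tuple[List[str], List[object]].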
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131001
Approved by: https://github.com/aorenste
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index ebcbb02..73e9daa4 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -1331,7 +1331,7 @@
if isinstance(compiled_fn, _LazyGraphModule) or (
isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
- and compiled_fn.__name__ == "_lazy_forward"
+ and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined]
):
# Since dynamo will run the forward method for the GraphModule shortly
# anyways, it does not hurt to do the real recompilation here if
@@ -1341,7 +1341,7 @@
lazy_gm = (
compiled_fn
if isinstance(compiled_fn, _LazyGraphModule)
- else compiled_fn.__self__
+ else compiled_fn.__self__ # type: ignore[attr-defined]
)
_LazyGraphModule.force_recompile(lazy_gm)
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 987d531..b6295ac 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -50,11 +50,12 @@
Union,
ValuesView,
)
-from typing_extensions import TypeGuard
+from typing_extensions import Literal, ParamSpec, TypeGuard
from ..utils.hooks import RemovableHandle
T = TypeVar("T")
+_P = ParamSpec("_P")
try:
import numpy as np
@@ -131,7 +132,10 @@
timer_counter = itertools.count()
-def tabulate(rows, headers):
+def tabulate(
+ rows: Union[List[Tuple[str, object]], List[List[object]]],
+ headers: Union[Tuple[str, ...], List[str]],
+) -> str:
try:
import tabulate
@@ -146,13 +150,13 @@
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
-def increment_frame():
+def increment_frame() -> None:
global curr_frame
curr_frame = curr_frame + 1
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
-def reset_frame_count():
+def reset_frame_count() -> None:
global curr_frame
frame_phase_timing.clear()
compilation_time_metrics.clear()
@@ -162,14 +166,14 @@
op_count = 0
-def increment_op_count(cnt):
+def increment_op_count(cnt: int) -> None:
global op_count
op_count += cnt
# Calculate total time spent so far for each phase
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
-def calculate_time_spent():
+def calculate_time_spent() -> Dict[str, float]:
total_wall_time = 0.0
total_by_key = {}
for timings in frame_phase_timing.values():
@@ -194,7 +198,7 @@
# TIMING:
# entire_frame_compile:8.574629999999999
# backend_compile:5.26806
-def print_time_report():
+def print_time_report() -> None:
total_by_key = calculate_time_spent()
out = "TIMING:"
@@ -204,7 +208,7 @@
print(out)
-def _add_time_spent(key, phase_name, time_spent):
+def _add_time_spent(key: str, phase_name: str, time_spent: float) -> None:
frame_phase_timing[key][phase_name] += time_spent
@@ -229,10 +233,32 @@
# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
-def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
- def dynamo_timed_inner(func):
+@overload
+def dynamo_timed(
+ original_function: Callable[_P, T],
+ phase_name: Optional[str] = None,
+ fwd_only: bool = True,
+) -> Callable[_P, T]:
+ ...
+
+
+@overload
+def dynamo_timed(
+ original_function: Literal[None] = None,
+ phase_name: Optional[str] = None,
+ fwd_only: bool = True,
+) -> Callable[[Callable[_P, T]], Callable[_P, T]]:
+ ...
+
+
+def dynamo_timed(
+ original_function: Optional[Callable[_P, T]] = None,
+ phase_name: Optional[str] = None,
+ fwd_only: bool = True,
+):
+ def dynamo_timed_inner(func: Callable[_P, T]) -> Callable[_P, T]:
@wraps(func)
- def time_wrapper(*args, **kwargs):
+ def time_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> T:
key = func.__qualname__
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []
@@ -309,7 +335,19 @@
return dynamo_timed_inner
-def compile_times(repr="str", aggregate=False):
+@overload
+def compile_times(repr: Literal["str"], aggregate: bool = False) -> str:
+ ...
+
+
+@overload
+def compile_times(
+ repr: Literal["csv"], aggregate: bool = False
+) -> Tuple[List[str], List[object]]:
+ ...
+
+
+def compile_times(repr="str", aggregate: bool = False):
"""
Get metrics about torchdynamo frontend/backend compilation times.
@@ -343,10 +381,11 @@
]
headers = list(compilation_time_metrics.keys())
return headers, values
+ return None
@atexit.register
-def dump_compile_times():
+def dump_compile_times() -> None:
log.info(compile_times(repr="str", aggregate=True))
@@ -365,14 +404,14 @@
class DuplicateWarningChecker:
- def __init__(self, maxsize=4096):
+ def __init__(self, maxsize: int = 4096) -> None:
self.maxsize = maxsize
self.reset()
def reset(self):
self.set = collections.OrderedDict()
- def add(self, key):
+ def add(self, key: Union[str, Tuple[object, object]]) -> bool:
if key in self.set:
self.set.move_to_end(key, last=True)
if not config.verbose:
@@ -396,7 +435,7 @@
return contextlib.ExitStack()
-def reset_graph_break_dup_checker():
+def reset_graph_break_dup_checker() -> None:
graph_break_dup_warning_checker.reset()
@@ -425,12 +464,12 @@
return exitstack
-def gen_record_file_name(exc, code):
+def gen_record_file_name(exc, code) -> str:
return f"{get_debug_dir()}/error_recordings/\
{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
-def write_record_to_file(filename, exec_record):
+def write_record_to_file(filename: str, exec_record) -> None:
try:
if os.path.exists(filename):
log.warning(
@@ -444,7 +483,7 @@
log.exception("Unable to write execution record %s", filename)
-def count_calls(g: fx.Graph):
+def count_calls(g: fx.Graph) -> int:
c = 0
for n in g.nodes:
if "call" in n.op:
diff --git a/torch/serialization.py b/torch/serialization.py
index dd73d1d..02f8b50 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -794,7 +794,7 @@
import torch.nn as nn
serialized_container_types = {}
- serialized_storages = {}
+ serialized_storages: Dict[str, Tuple[torch.UntypedStorage, torch.dtype]] = {}
# Since loading storages that view the same data with different dtypes is
# not supported, we need to keep track of the dtype associated with each