[Dynamo] Log backward graph compilation metrics (#126629)
Fixes #125313
Sample compilation metric logs produced by the code example in #125313:
```
%s CompilationMetrics(compile_id='0/0', frame_key='1', co_name='forward', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=10, cache_size=0, accumulated_cache_size=0, guard_count=11, shape_env_guard_count=0, graph_op_count=1, graph_node_count=3, graph_input_count=1, start_time=1716247236.6165977, entire_frame_compile_time_s=7.926939964294434, backend_compile_time_s=7.887059926986694, inductor_compile_time_s=4.108498811721802, code_gen_time_s=3.97833514213562, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={"'skip function graph_break in file /home/ybliang/local/pytorch/torch/_dynamo/decorators.py'"}, dynamo_time_before_restart_s=0.025330543518066406, has_guarded_code=True, is_fwd=True)
%s CompilationMetrics(compile_id='1/0', frame_key='2', co_name='torch_dynamo_resume_in_forward_at_12', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=12, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=2, graph_node_count=5, graph_input_count=1, start_time=1716247244.544928, entire_frame_compile_time_s=0.10148310661315918, backend_compile_time_s=0.08753013610839844, inductor_compile_time_s=0.03691983222961426, code_gen_time_s=0.022417306900024414, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons=set(), dynamo_time_before_restart_s=0.0, has_guarded_code=True, is_fwd=True)
tensor([[-0.1622, -0.0000, -0.0000, 0.5643, -0.0000, 0.0000, -0.5087, 0.0914,
-0.0000, -0.0421]], grad_fn=<CompiledFunctionBackward>)
%s CompilationMetrics(compile_id='1/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.026738643646240234, code_gen_time_s=0.016446352005004883, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
%s CompilationMetrics(compile_id='0/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.14563536643981934, code_gen_time_s=0.08652091026306152, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
```
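For reference, the recorded metrics can also be inspected in-process. Below is a minimal sketch (not taken from #125313; the toy module, shapes, and names are illustrative only) that compiles a small model, runs backward to trigger the lazy bwd compilation, and then reads both `CompilationMetrics` and the new `BwdCompilationMetrics` records via `torch._dynamo.utils.get_compilation_metrics()`, assuming a PyTorch build that includes this change:
```
# Minimal sketch (illustrative only): trigger fwd + bwd compilation and read the
# recorded metrics in-process. The module and shapes are made up; the API used is
# torch._dynamo.utils.get_compilation_metrics(), which this PR extends to also
# return BwdCompilationMetrics entries for backward graphs.
import torch
import torch._dynamo.utils as dynamo_utils


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = torch.compile(Toy())
out = model(torch.randn(1, 10, requires_grad=True))
out.sum().backward()  # bwd graph is compiled lazily here, producing a BwdCompilationMetrics record

for m in dynamo_utils.get_compilation_metrics():
    # CompilationMetrics -> fwd compilations; BwdCompilationMetrics -> bwd compilations
    print(type(m).__name__, m)
```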
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126629
Approved by: https://github.com/ezyang
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 6dcb84f..e779cce 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -876,6 +876,7 @@
dynamo_time_before_restart = time.time() - start_time
metrics = CompilationMetrics(
+ str(compile_id),
frame_key,
code.co_name,
code.co_filename,
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 9d43ba5..2b42c8d 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -113,7 +113,9 @@
compilation_time_metrics: Dict[str, List[float]] = {}
# profiling compilation time by frame phase
-frame_phase_timing: Dict[str, Dict[str, float]] = {}
+frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
+ lambda: collections.defaultdict(float)
+)
timer_counter = itertools.count()
@@ -185,6 +187,10 @@
print(out)
+def _add_time_spent(key, phase_name, time_spent):
+ frame_phase_timing[key][phase_name] += time_spent
+
+
# dynamo_timed API works as a function decorator
# By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
# where the key is the function's name.
@@ -201,9 +207,12 @@
# phase_names record an extra record into a separate compilation timing structure,
# one keyed on frame+name rather than function.
# The frame is incremented outside of this function, in def increment_frame() above.
+# `fwd_only` indicates whether this phase or function is only called while
+# compiling fwd graphs, e.g., `entire_frame_compile` and `backend_compile`.
+# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
-def dynamo_timed(original_function=None, phase_name=None):
+def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
def dynamo_timed_inner(func):
if config.cprofile:
return func
@@ -213,19 +222,70 @@
key = func.__qualname__
if key not in compilation_time_metrics:
compilation_time_metrics[key] = []
- with torch.profiler.record_function(f"{key} (dynamo_timed)"):
- t0 = time.time()
- r = func(*args, **kwargs)
- time_spent = time.time() - t0
- compilation_time_metrics[key].append(time_spent)
- if phase_name:
- frame_key = str(curr_frame)
- if frame_key not in frame_phase_timing:
- frame_phase_timing[frame_key] = {}
- if phase_name not in frame_phase_timing[frame_key]:
- frame_phase_timing[frame_key][phase_name] = time_spent
- else:
- frame_phase_timing[frame_key][phase_name] += time_spent
+
+ fail_type: Optional[str] = None
+ fail_reason: Optional[str] = None
+ time_spent = float("-inf")
+ try:
+ with torch.profiler.record_function(f"{key} (dynamo_timed)"):
+ t0 = time.time()
+ r = func(*args, **kwargs)
+ time_spent = time.time() - t0
+ compilation_time_metrics[key].append(time_spent)
+ except Exception as e:
+ fail_type = str(type(e))
+ fail_reason = str(e)
+ raise
+ finally:
+ # Only record backward compilation metrics if phase_name is not None!
+ if phase_name:
+ frame_key = str(curr_frame)
+ # fwd only compilation stages: entire_frame_compile, backend_compile.
+ # use frame_key as time aggregation key.
+ if fwd_only and fail_type is None:
+ _add_time_spent(frame_key, phase_name, time_spent)
+ else:
+ # fwd + bwd compilation stages: inductor_compile, code_gen.
+ # use frame_key as time aggregation key for fwd graphs;
+ # use compile_id as time aggregation key for bwd graphs.
+ if torch._guards.TracingContext.try_get() is not None:
+ aot_graph_name = str(
+ torch._guards.TracingContext.get().aot_graph_name
+ )
+ if (
+ "forward" in aot_graph_name
+ or "inference" in aot_graph_name
+ ) and fail_type is None:
+ _add_time_spent(frame_key, phase_name, time_spent)
+ elif "backward" in aot_graph_name:
+ compile_id = str(
+ torch._guards.CompileContext.current_compile_id()
+ )
+ if fail_type is None:
+ _add_time_spent(compile_id, phase_name, time_spent)
+
+ # log backward compilation metrics at the end of `inductor_compile` of bwd graph,
+ # one record for one bwd graph.
+ if phase_name == "inductor_compile":
+ if fail_type is None:
+ inductor_compile_time = frame_phase_timing[
+ compile_id
+ ].get("inductor_compile", None)
+ code_gen_time = frame_phase_timing[
+ compile_id
+ ].get("code_gen", None)
+ else:
+ inductor_compile_time = None
+ code_gen_time = None
+ metrics = BwdCompilationMetrics(
+ compile_id,
+ inductor_compile_time,
+ code_gen_time,
+ fail_type,
+ fail_reason,
+ )
+ record_compilation_metrics(metrics)
+
return r
return time_wrapper
@@ -598,6 +658,7 @@
@dataclasses.dataclass
class CompilationMetrics:
+ compile_id: str
frame_key: str
co_name: str
co_filename: str
@@ -628,26 +689,44 @@
has_guarded_code: bool
+@dataclasses.dataclass
+class BwdCompilationMetrics:
+ compile_id: str
+ inductor_compile_time_s: Optional[float]
+ code_gen_time_s: Optional[float]
+ fail_type: Optional[str]
+ fail_reason: Optional[str]
+
+
DEFAULT_COMPILATION_METRICS_LIMIT = 64
-_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
- maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
-)
+_compilation_metrics: Deque[
+ Union[CompilationMetrics, BwdCompilationMetrics]
+] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
-def record_compilation_metrics(compilation_metrics: CompilationMetrics):
+def record_compilation_metrics(
+ compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
+):
global _compilation_metrics
_compilation_metrics.append(compilation_metrics)
- torch._logging.trace_structured(
- "compilation_metrics",
- lambda: {
- k: list(v) if isinstance(v, set) else v
- for k, v in dataclasses.asdict(compilation_metrics).items()
- },
- )
- if config.log_compilation_metrics:
- log_compilation_event(compilation_metrics)
+ if isinstance(compilation_metrics, CompilationMetrics):
+ name = "compilation_metrics"
+ else:
+ name = "bwd_compilation_metrics"
+    # Currently we only record fwd compilation metrics; bwd compilation metrics
+    # will be added after the internal Scuba logging changes finish.
+ if isinstance(compilation_metrics, CompilationMetrics):
+ torch._logging.trace_structured(
+ name,
+ lambda: {
+ k: list(v) if isinstance(v, set) else v
+ for k, v in dataclasses.asdict(compilation_metrics).items()
+ },
+ )
+ if config.log_compilation_metrics:
+ log_compilation_event(compilation_metrics)
def set_compilation_metrics_limit(new_size: int) -> None:
@@ -663,7 +742,7 @@
_compilation_metrics.clear()
-def get_compilation_metrics() -> List[CompilationMetrics]:
+def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]:
return list(_compilation_metrics)
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index a450f40..fd188eb 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -7,7 +7,6 @@
"""
import collections
import pprint
-import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
@@ -24,7 +23,6 @@
tracing,
TracingContext,
)
-from torch._logging import trace_structured
from torch._prims_common import CUDARngStateHelper
from torch._subclasses import FakeTensor
@@ -1801,41 +1799,9 @@
with tracing(saved_context), compile_context(
saved_compile_context
), context(), track_graph_compiling(aot_config, "backward"):
- fail_type: Optional[str] = None
- fail_reason: Optional[str] = None
- start_time = time.time()
- try:
- CompiledFunction.compiled_bw = aot_config.bw_compiler(
- bw_module, placeholder_list
- )
- except Exception as e:
- fail_type = str(type(e))
- fail_reason = str(e)
- if saved_compile_context is not None:
- e.compile_id = saved_compile_context.compile_id # type: ignore[attr-defined]
- raise
- finally:
- # TODO: Similar to CompilationMetrics, we would
- # like to report inductor_compile_time, but we
- # cannot conveniently do so because these are
- # keyed on utils.frame, and frame key is not
- # incremented on backwards compilations. Maybe
- # should just bump the frame key here too?
- end_time = time.time()
- # TODO: Put this in scuba? But CompilationMetrics
- # is kind of not a great match, because there's no
- # interaction with Dynamo, so a lot of Dynamo only
- # events don't exist anymore. So we need a new
- # scuba table. Lazy lazy...
- trace_structured(
- "aot_autograd_backward_compilation_metrics",
- lambda: {
- "start_time": start_time,
- "elapsed_time": time.time() - start_time,
- "fail_type": fail_type,
- "fail_reason": fail_reason,
- },
- )
+ CompiledFunction.compiled_bw = aot_config.bw_compiler(
+ bw_module, placeholder_list
+ )
out = call_func_at_runtime_with_args(
CompiledFunction.compiled_bw,
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 77b8925..60db8b7 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -408,7 +408,7 @@
# the backward graph as well.
@_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
@with_fresh_cache_if_config
-@dynamo_utils.dynamo_timed(phase_name="inductor_compile")
+@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False)
def compile_fx_inner(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 412caf5..7430180 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1681,7 +1681,7 @@
node_runtimes.append((node, node.get_estimated_runtime()))
return total_bytes, node_counts, node_runtimes
- @dynamo_timed(phase_name="code_gen")
+ @dynamo_timed(phase_name="code_gen", fwd_only=False)
def compile_to_module(self):
from .codecache import PyCodeCache
diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py
index 7d24be0..bc3a3d0 100644
--- a/torch/_inductor/runtime/runtime_utils.py
+++ b/torch/_inductor/runtime/runtime_utils.py
@@ -187,7 +187,7 @@
dynamo_timed = torch._dynamo.utils.dynamo_timed
except AttributeError: # Compile workers only have a mock version of torch
- def dynamo_timed(original_function=None, phase_name=None):
+ def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
if original_function:
return original_function
return dynamo_timed