[Dynamo] Log backward graph compilation metrics (#126629)

Fixes #125313

This PR also records compilation metrics for backward graphs: `dynamo_timed` gains a `fwd_only` flag so that phases shared between forward and backward compilation (`inductor_compile`, `code_gen`) can attribute their time to the backward graph's `compile_id`, and a separate metrics record is emitted for each backward graph at the end of its `inductor_compile` phase. `CompilationMetrics` also gains a `compile_id` field.

Compilation metrics logged for the code example in #125313:
```
%s CompilationMetrics(compile_id='0/0', frame_key='1', co_name='forward', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=10, cache_size=0, accumulated_cache_size=0, guard_count=11, shape_env_guard_count=0, graph_op_count=1, graph_node_count=3, graph_input_count=1, start_time=1716247236.6165977, entire_frame_compile_time_s=7.926939964294434, backend_compile_time_s=7.887059926986694, inductor_compile_time_s=4.108498811721802, code_gen_time_s=3.97833514213562, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={"'skip function graph_break in file /home/ybliang/local/pytorch/torch/_dynamo/decorators.py'"}, dynamo_time_before_restart_s=0.025330543518066406, has_guarded_code=True, is_fwd=True)
%s CompilationMetrics(compile_id='1/0', frame_key='2', co_name='torch_dynamo_resume_in_forward_at_12', co_filename='/data/users/ybliang/debug/debug2.py', co_firstlineno=12, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=2, graph_node_count=5, graph_input_count=1, start_time=1716247244.544928, entire_frame_compile_time_s=0.10148310661315918, backend_compile_time_s=0.08753013610839844, inductor_compile_time_s=0.03691983222961426, code_gen_time_s=0.022417306900024414, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons=set(), dynamo_time_before_restart_s=0.0, has_guarded_code=True, is_fwd=True)
tensor([[-0.1622, -0.0000, -0.0000,  0.5643, -0.0000,  0.0000, -0.5087,  0.0914,
         -0.0000, -0.0421]], grad_fn=<CompiledFunctionBackward>)
%s CompilationMetrics(compile_id='1/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.026738643646240234, code_gen_time_s=0.016446352005004883, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
%s CompilationMetrics(compile_id='0/0', frame_key=None, co_name=None, co_filename=None, co_firstlineno=None, cache_size=None, accumulated_cache_size=None, guard_count=None, shape_env_guard_count=None, graph_op_count=None, graph_node_count=None, graph_input_count=None, start_time=None, entire_frame_compile_time_s=None, backend_compile_time_s=None, inductor_compile_time_s=0.14563536643981934, code_gen_time_s=0.08652091026306152, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=None, compliant_custom_ops=None, restart_reasons=None, dynamo_time_before_restart_s=None, has_guarded_code=None, is_fwd=False)
```
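
For reference, here is a minimal stand-in for the kind of script used in #125313 (illustrative only, not the exact `debug2.py` repro): a module with a graph break so Dynamo compiles two frames, followed by a `backward()` call so Inductor lazily compiles the backward graphs and produces the `is_fwd=False` records above.
```python
# Illustrative sketch, not the exact repro from #125313.
import torch
import torch._dynamo


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = self.linear(x)
        torch._dynamo.graph_break()  # splits compilation into two frames (0/0 and 1/0)
        return torch.nn.functional.relu(x)


model = torch.compile(Toy())
out = model(torch.randn(1, 10))
print(out)
out.sum().backward()  # backward graphs are compiled lazily here (is_fwd=False records)
```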

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126629
Approved by: https://github.com/ezyang
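
With this change, `torch._dynamo.utils.get_compilation_metrics()` returns both `CompilationMetrics` and `BwdCompilationMetrics` entries, so consumers may want to branch on the record type. A minimal sketch (the print formatting is only illustrative):
```python
# Minimal sketch: distinguish forward records (CompilationMetrics) from backward
# records (BwdCompilationMetrics) after running a compiled model fwd + bwd.
from torch._dynamo.utils import (
    BwdCompilationMetrics,
    CompilationMetrics,
    get_compilation_metrics,
)

for m in get_compilation_metrics():
    if isinstance(m, BwdCompilationMetrics):
        print(f"bwd {m.compile_id}: inductor_compile={m.inductor_compile_time_s}s "
              f"code_gen={m.code_gen_time_s}s")
    else:
        assert isinstance(m, CompilationMetrics)
        print(f"fwd {m.compile_id}: frame={m.frame_key} "
              f"entire_frame_compile={m.entire_frame_compile_time_s}s")
```
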
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 6dcb84f..e779cce 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -876,6 +876,7 @@
                 dynamo_time_before_restart = time.time() - start_time
 
             metrics = CompilationMetrics(
+                str(compile_id),
                 frame_key,
                 code.co_name,
                 code.co_filename,
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 9d43ba5..2b42c8d 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -113,7 +113,9 @@
 compilation_time_metrics: Dict[str, List[float]] = {}
 
 # profiling compilation time by frame phase
-frame_phase_timing: Dict[str, Dict[str, float]] = {}
+frame_phase_timing: Dict[str, Dict[str, float]] = collections.defaultdict(
+    lambda: collections.defaultdict(float)
+)
 
 timer_counter = itertools.count()
 
@@ -185,6 +187,10 @@
     print(out)
 
 
+def _add_time_spent(key, phase_name, time_spent):
+    frame_phase_timing[key][phase_name] += time_spent
+
+
 # dynamo_timed API works as a function decorator
 # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics
 # where the key is the functions name.
@@ -201,9 +207,12 @@
 # phase_names record an extra record into a separate compilation timing structure,
 # one keyed on frame+name rather than function.
 # The frame is incremented outside of this function, in def increment_frame() above.
+# `fwd_only` indicates whether this phase or function is only called while
+# compiling fwd graphs, e.g., `entire_frame_compile` and `backend_compile`.
+# The other phases (`inductor_compile` and `code_gen`) are called for both fwd and bwd graphs.
 
 
-def dynamo_timed(original_function=None, phase_name=None):
+def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
     def dynamo_timed_inner(func):
         if config.cprofile:
             return func
@@ -213,19 +222,70 @@
             key = func.__qualname__
             if key not in compilation_time_metrics:
                 compilation_time_metrics[key] = []
-            with torch.profiler.record_function(f"{key} (dynamo_timed)"):
-                t0 = time.time()
-                r = func(*args, **kwargs)
-                time_spent = time.time() - t0
-            compilation_time_metrics[key].append(time_spent)
-            if phase_name:
-                frame_key = str(curr_frame)
-                if frame_key not in frame_phase_timing:
-                    frame_phase_timing[frame_key] = {}
-                if phase_name not in frame_phase_timing[frame_key]:
-                    frame_phase_timing[frame_key][phase_name] = time_spent
-                else:
-                    frame_phase_timing[frame_key][phase_name] += time_spent
+
+            fail_type: Optional[str] = None
+            fail_reason: Optional[str] = None
+            time_spent = float("-inf")
+            try:
+                with torch.profiler.record_function(f"{key} (dynamo_timed)"):
+                    t0 = time.time()
+                    r = func(*args, **kwargs)
+                    time_spent = time.time() - t0
+                compilation_time_metrics[key].append(time_spent)
+            except Exception as e:
+                fail_type = str(type(e))
+                fail_reason = str(e)
+                raise
+            finally:
+                # Only record backward compilation metrics if phase_name is not None!
+                if phase_name:
+                    frame_key = str(curr_frame)
+                    # fwd only compilation stages: entire_frame_compile, backend_compile.
+                    # use frame_key as time aggregation key.
+                    if fwd_only and fail_type is None:
+                        _add_time_spent(frame_key, phase_name, time_spent)
+                    else:
+                        # fwd + bwd compilation stages: inductor_compile, code_gen.
+                        # use frame_key as time aggregation key for fwd graphs;
+                        # use compile_id as time aggregation key for bwd graphs.
+                        if torch._guards.TracingContext.try_get() is not None:
+                            aot_graph_name = str(
+                                torch._guards.TracingContext.get().aot_graph_name
+                            )
+                            if (
+                                "forward" in aot_graph_name
+                                or "inference" in aot_graph_name
+                            ) and fail_type is None:
+                                _add_time_spent(frame_key, phase_name, time_spent)
+                            elif "backward" in aot_graph_name:
+                                compile_id = str(
+                                    torch._guards.CompileContext.current_compile_id()
+                                )
+                                if fail_type is None:
+                                    _add_time_spent(compile_id, phase_name, time_spent)
+
+                                # Log backward compilation metrics at the end of the bwd
+                                # graph's `inductor_compile` phase: one record per bwd graph.
+                                if phase_name == "inductor_compile":
+                                    if fail_type is None:
+                                        inductor_compile_time = frame_phase_timing[
+                                            compile_id
+                                        ].get("inductor_compile", None)
+                                        code_gen_time = frame_phase_timing[
+                                            compile_id
+                                        ].get("code_gen", None)
+                                    else:
+                                        inductor_compile_time = None
+                                        code_gen_time = None
+                                    metrics = BwdCompilationMetrics(
+                                        compile_id,
+                                        inductor_compile_time,
+                                        code_gen_time,
+                                        fail_type,
+                                        fail_reason,
+                                    )
+                                    record_compilation_metrics(metrics)
+
             return r
 
         return time_wrapper
@@ -598,6 +658,7 @@
 
 @dataclasses.dataclass
 class CompilationMetrics:
+    compile_id: str
     frame_key: str
     co_name: str
     co_filename: str
@@ -628,26 +689,44 @@
     has_guarded_code: bool
 
 
+@dataclasses.dataclass
+class BwdCompilationMetrics:
+    compile_id: str
+    inductor_compile_time_s: Optional[float]
+    code_gen_time_s: Optional[float]
+    fail_type: Optional[str]
+    fail_reason: Optional[str]
+
+
 DEFAULT_COMPILATION_METRICS_LIMIT = 64
 
 
-_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
-    maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
-)
+_compilation_metrics: Deque[
+    Union[CompilationMetrics, BwdCompilationMetrics]
+] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
 
 
-def record_compilation_metrics(compilation_metrics: CompilationMetrics):
+def record_compilation_metrics(
+    compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics]
+):
     global _compilation_metrics
     _compilation_metrics.append(compilation_metrics)
-    torch._logging.trace_structured(
-        "compilation_metrics",
-        lambda: {
-            k: list(v) if isinstance(v, set) else v
-            for k, v in dataclasses.asdict(compilation_metrics).items()
-        },
-    )
-    if config.log_compilation_metrics:
-        log_compilation_event(compilation_metrics)
+    if isinstance(compilation_metrics, CompilationMetrics):
+        name = "compilation_metrics"
+    else:
+        name = "bwd_compilation_metrics"
+    # Currently we only log fwd compilation metrics; bwd compilation metrics will be
+    # added once the internal Scuba logging changes are finished.
+    if isinstance(compilation_metrics, CompilationMetrics):
+        torch._logging.trace_structured(
+            name,
+            lambda: {
+                k: list(v) if isinstance(v, set) else v
+                for k, v in dataclasses.asdict(compilation_metrics).items()
+            },
+        )
+        if config.log_compilation_metrics:
+            log_compilation_event(compilation_metrics)
 
 
 def set_compilation_metrics_limit(new_size: int) -> None:
@@ -663,7 +742,7 @@
     _compilation_metrics.clear()
 
 
-def get_compilation_metrics() -> List[CompilationMetrics]:
+def get_compilation_metrics() -> List[Union[CompilationMetrics, BwdCompilationMetrics]]:
     return list(_compilation_metrics)
 
 
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index a450f40..fd188eb 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -7,7 +7,6 @@
 """
 import collections
 import pprint
-import time
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
@@ -24,7 +23,6 @@
     tracing,
     TracingContext,
 )
-from torch._logging import trace_structured
 
 from torch._prims_common import CUDARngStateHelper
 from torch._subclasses import FakeTensor
@@ -1801,41 +1799,9 @@
                         with tracing(saved_context), compile_context(
                             saved_compile_context
                         ), context(), track_graph_compiling(aot_config, "backward"):
-                            fail_type: Optional[str] = None
-                            fail_reason: Optional[str] = None
-                            start_time = time.time()
-                            try:
-                                CompiledFunction.compiled_bw = aot_config.bw_compiler(
-                                    bw_module, placeholder_list
-                                )
-                            except Exception as e:
-                                fail_type = str(type(e))
-                                fail_reason = str(e)
-                                if saved_compile_context is not None:
-                                    e.compile_id = saved_compile_context.compile_id  # type: ignore[attr-defined]
-                                raise
-                            finally:
-                                # TODO: Similar to CompilationMetrics, we would
-                                # like to report inductor_compile_time, but we
-                                # cannot conveniently do so because these are
-                                # keyed on utils.frame, and frame key is not
-                                # incremented on backwards compilations.  Maybe
-                                # should just bump the frame key here too?
-                                end_time = time.time()
-                                # TODO: Put this in scuba?  But CompilationMetrics
-                                # is kind of not a great match, because there's no
-                                # interaction with Dynamo, so a lot of Dynamo only
-                                # events don't exist anymore.  So we need a new
-                                # scuba table. Lazy lazy...
-                                trace_structured(
-                                    "aot_autograd_backward_compilation_metrics",
-                                    lambda: {
-                                        "start_time": start_time,
-                                        "elapsed_time": time.time() - start_time,
-                                        "fail_type": fail_type,
-                                        "fail_reason": fail_reason,
-                                    },
-                                )
+                            CompiledFunction.compiled_bw = aot_config.bw_compiler(
+                                bw_module, placeholder_list
+                            )
 
                     out = call_func_at_runtime_with_args(
                         CompiledFunction.compiled_bw,
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 77b8925..60db8b7 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -408,7 +408,7 @@
 # the backward graph as well.
 @_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)
 @with_fresh_cache_if_config
-@dynamo_utils.dynamo_timed(phase_name="inductor_compile")
+@dynamo_utils.dynamo_timed(phase_name="inductor_compile", fwd_only=False)
 def compile_fx_inner(
     gm: torch.fx.GraphModule,
     example_inputs: List[torch.Tensor],
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 412caf5..7430180 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1681,7 +1681,7 @@
             node_runtimes.append((node, node.get_estimated_runtime()))
         return total_bytes, node_counts, node_runtimes
 
-    @dynamo_timed(phase_name="code_gen")
+    @dynamo_timed(phase_name="code_gen", fwd_only=False)
     def compile_to_module(self):
         from .codecache import PyCodeCache
 
diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py
index 7d24be0..bc3a3d0 100644
--- a/torch/_inductor/runtime/runtime_utils.py
+++ b/torch/_inductor/runtime/runtime_utils.py
@@ -187,7 +187,7 @@
     dynamo_timed = torch._dynamo.utils.dynamo_timed
 except AttributeError:  # Compile workers only have a mock version of torch
 
-    def dynamo_timed(original_function=None, phase_name=None):
+    def dynamo_timed(original_function=None, phase_name=None, fwd_only=True):
         if original_function:
             return original_function
         return dynamo_timed
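
For completeness, the call-site convention this PR establishes for `dynamo_timed`: phases that only run while compiling fwd graphs keep the default `fwd_only=True` (time keyed on `frame_key`), while phases shared with bwd graphs pass `fwd_only=False` so their time can be keyed on the bwd graph's `compile_id`. A sketch with hypothetical placeholder functions (the real call sites are `compile_fx_inner` and `GraphLowering.compile_to_module` in the diff above):
```python
# Hypothetical placeholders illustrating the fwd_only convention; not real compiler entry points.
from torch._dynamo.utils import dynamo_timed


@dynamo_timed(phase_name="backend_compile")  # fwd-only phase: default fwd_only=True
def my_fwd_only_phase():
    ...


@dynamo_timed(phase_name="inductor_compile", fwd_only=False)  # shared by fwd and bwd graphs
def my_shared_phase():
    ...
```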