Added inference to context when only compiling forwards (#83783)
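
When AOT Autograd compiles a function that needs no autograd, `aot_dispatch_base`
now records the graph under the name "inference" instead of "forward", so
compiler callbacks can tell the two cases apart. To support this,
`graph_being_compiled` becomes a list of names (anticipating arbitrarily nested
compilation in the future), and a new helper, `get_aot_compilation_context()`,
returns the current `(graph_being_compiled, model_name, nth_graph)` triple. The
old `get_graph_being_compiled()` entry point is kept as an alias of
`get_aot_graph_name()`.

A minimal sketch of how a compiler callback might consume the new context; the
`show_context` callback is invented for illustration and mirrors the new test
below:

    import torch
    from functorch.compile import aot_function, get_aot_compilation_context

    def show_context(fx_g, example_inputs):
        # Report which graph is currently being compiled.
        graph_names, model_name, nth_graph = get_aot_compilation_context()
        print(model_name, graph_names, nth_graph)
        return fx_g  # hand the traced fx.GraphModule back unchanged

    fn = aot_function(lambda x: x.sin().sin(), show_context)
    out = fn(torch.randn(5, requires_grad=True))  # compiles ['forward']
    fn(torch.randn(5))                            # compiles ['inference']
    out.sum().backward()                          # compiles ['backward']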

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83783
Approved by: https://github.com/pyjhzwh, https://github.com/jansel
diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index 743ac58..97a8691 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -184,7 +184,8 @@
     return aten.view(x, shape)
 
 
-graph_being_compiled: str = None
+# This is a list because, going forward, graph compilation can be arbitrarily nested.
+graph_being_compiled: List[str] = []
 nth_graph: int = 0
 model_name: str = "model"
 
@@ -194,23 +195,31 @@
     model_name = name
 
 
-def get_graph_being_compiled() -> str:
+def get_aot_compilation_context() -> Tuple[List[str], str, int]:
+    return list(graph_being_compiled), model_name, nth_graph
+
+
+def get_aot_graph_name() -> str:
     """
     Returns the name of the graph being compiled.
     """
     global model_name, graph_being_compiled, nth_graph
-    return f"{model_name}_{graph_being_compiled}_{nth_graph}"
+    return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}"
+
+
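+# Backwards-compatibility alias for the previous name of this helper.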
+get_graph_being_compiled = get_aot_graph_name
 
 
 @contextmanager
 def track_graph_compiling(graph_name, increment_index=False):
     global graph_being_compiled
-    graph_being_compiled = graph_name
+    graph_being_compiled = [graph_name]
     yield
     if increment_index:
         global nth_graph
         nth_graph += 1
-    graph_being_compiled = None
+    graph_being_compiled = []
 
 
 def make_boxed_func(f):
@@ -264,7 +272,7 @@
 
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
-    with track_graph_compiling("forward"):
+    with track_graph_compiling("inference"):
         compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
 
     @wraps(compiled_fw)
diff --git a/functorch/functorch/compile/__init__.py b/functorch/functorch/compile/__init__.py
index 1568d56..99e0456 100644
--- a/functorch/functorch/compile/__init__.py
+++ b/functorch/functorch/compile/__init__.py
@@ -10,6 +10,8 @@
     clear_compile_cache,
     aot_module_simplified,
     get_graph_being_compiled,
+    get_aot_graph_name,
+    get_aot_compilation_context,
     make_boxed_func,
     make_boxed_compiler
 )
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 5deeac1..e1d4b3c 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -23,8 +23,10 @@
 from functorch._src.aot_autograd import aot_module_simplified
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
-    min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, clear_compile_cache
+    min_cut_rematerialization_partition, aot_function, aot_module,
+    decomposition_table, nop,
+    num_of_recompilations, default_partition, default_decompositions,
+    memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
 )
 
 from torch.testing._internal.common_device_type import ops
@@ -330,6 +332,23 @@
         inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
         f(*inp).sum().backward()
 
+    def test_compilation_context(self):
+        def f(x):
+            return x.sin().sin()
+        count = []
+
+        def compiler(fx_g, _):
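+            # Record the context names and graph size for each compiled graph.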
+            context = get_aot_compilation_context()
+            count.append((context[0], len(fx_g.graph.nodes)))
+            return fx_g
+
+        f = aot_function(f, compiler)
+        out = f(torch.randn(5, requires_grad=True))
+        f(torch.randn(5))
+        out.sum().backward()
+        self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
+
 
 
 class TestEagerFusionOpInfo(AOTTestCase):