Added inference to context when only compiling forwards (#83783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83783
Approved by: https://github.com/pyjhzwh, https://github.com/jansel
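
`graph_being_compiled` changes from a single string to a list of graph names, since compilation contexts may eventually be arbitrarily nested. `aot_dispatch_base`, which compiles a forward-only graph with no backward counterpart, now labels that graph "inference" instead of "forward", so a compiler callback can distinguish inference-only compilation from the forward half of a forward/backward pair. Two accessors are added and exported from `functorch.compile`: `get_aot_compilation_context()`, which returns the `(graph_being_compiled, model_name, nth_graph)` triple, and `get_aot_graph_name()`, which joins them into a single graph name; `get_graph_being_compiled` is kept as an alias of `get_aot_graph_name`.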
diff --git a/functorch/functorch/_src/aot_autograd.py b/functorch/functorch/_src/aot_autograd.py
index 743ac58..97a8691 100644
--- a/functorch/functorch/_src/aot_autograd.py
+++ b/functorch/functorch/_src/aot_autograd.py
@@ -184,7 +184,8 @@
     return aten.view(x, shape)


-graph_being_compiled: str = None
+# This is a list since, looking forward, we can have this arbitrarily nested.
+graph_being_compiled: List[str] = []
 nth_graph: int = 0
 model_name: str = "model"
@@ -194,23 +195,30 @@
     model_name = name


-def get_graph_being_compiled() -> str:
+def get_aot_compilation_context() -> Tuple[List[str], str, int]:
+    return list(graph_being_compiled), model_name, nth_graph
+
+
+def get_aot_graph_name() -> str:
     """
     Returns the name of the graph being compiled.
     """
     global model_name, graph_being_compiled, nth_graph
-    return f"{model_name}_{graph_being_compiled}_{nth_graph}"
+    return f"{model_name}_{'_'.join(graph_being_compiled)}_{nth_graph}"
+
+
+get_graph_being_compiled = get_aot_graph_name


 @contextmanager
 def track_graph_compiling(graph_name, increment_index=False):
     global graph_being_compiled
-    graph_being_compiled = graph_name
+    graph_being_compiled = [graph_name]
     yield
     if increment_index:
         global nth_graph
         nth_graph += 1
-    graph_being_compiled = None
+    graph_being_compiled = []


 def make_boxed_func(f):
@@ -264,7 +272,7 @@
 def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
     fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
-    with track_graph_compiling("forward"):
+    with track_graph_compiling("inference"):
         compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

     @wraps(compiled_fw)
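
As a rough illustration of the new context shape and naming, a sketch (not part of the diff) that pokes at the private `functorch._src.aot_autograd` module; it assumes a fresh process where `nth_graph` is still 0 and `model_name` is the default `"model"`:

```python
from functorch._src.aot_autograd import (
    get_aot_compilation_context,
    get_aot_graph_name,
    track_graph_compiling,
)

with track_graph_compiling("inference"):
    # The context now carries a *list* of graph names.
    assert get_aot_compilation_context() == (["inference"], "model", 0)
    # The graph name joins the (possibly nested) names with underscores.
    assert get_aot_graph_name() == "model_inference_0"
```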
diff --git a/functorch/functorch/compile/__init__.py b/functorch/functorch/compile/__init__.py
index 1568d56..99e0456 100644
--- a/functorch/functorch/compile/__init__.py
+++ b/functorch/functorch/compile/__init__.py
@@ -10,6 +10,8 @@
     clear_compile_cache,
     aot_module_simplified,
     get_graph_being_compiled,
+    get_aot_graph_name,
+    get_aot_compilation_context,
     make_boxed_func,
     make_boxed_compiler
 )
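
With the exports above, a user-supplied compiler can query the context to tell which graph it is being handed. A minimal usage sketch mirroring the new test below; the `logging_compiler` name is illustrative:

```python
import torch
from functorch.compile import aot_function, get_aot_compilation_context

def logging_compiler(fx_g, example_inputs):
    graph_names, model_name, nth_graph = get_aot_compilation_context()
    print(graph_names)  # ['forward'], ['inference'], or ['backward']
    return fx_g         # an fx.GraphModule is itself a valid compiled callable

f = aot_function(lambda x: x.sin().sin(), logging_compiler)
out = f(torch.randn(5, requires_grad=True))  # compiles the 'forward' graph
f(torch.randn(5))        # no autograd needed, so this compiles as 'inference'
out.sum().backward()     # the 'backward' graph is compiled here
```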
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 5deeac1..e1d4b3c 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -23,8 +23,10 @@
 from functorch._src.aot_autograd import aot_module_simplified
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
-    min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion, clear_compile_cache
+    min_cut_rematerialization_partition, aot_function, aot_module,
+    decomposition_table, nop,
+    num_of_recompilations, default_partition, default_decompositions,
+    memory_efficient_fusion, clear_compile_cache, get_aot_compilation_context
 )
 from torch.testing._internal.common_device_type import ops
@@ -330,6 +332,22 @@
         inp = [torch.randn(5, requires_grad=True) for _ in range(3)]
         f(*inp).sum().backward()

+    def test_compilation_context(self):
+        def f(x):
+            return x.sin().sin()
+        count = []
+
+        def compiler(fx_g, _):
+            context = get_aot_compilation_context()
+            count.append((context[0], len(fx_g.graph.nodes)))
+            return fx_g
+
+        f = aot_function(f, compiler)
+        out = f(torch.randn(5, requires_grad=True))
+        f(torch.randn(5))
+        out.sum().backward()
+        self.assertEqual(count, [(['forward'], 4), (['inference'], 4), (['backward'], 8)])
+

 class TestEagerFusionOpInfo(AOTTestCase):