[AOTInductor] Remove call to aot_autograd when receiving ExportedProgram (#105977)
https://github.com/pytorch/pytorch/issues/105555
The existing flow first exports the model and then calls torch._inductor.aot_compile. However, export already calls aot_autograd with the core ATen decomposition table, and torch._inductor.aot_compile then calls aot_autograd a second time with the inductor decomposition table. This second call to aot_autograd reportedly causes problems and is redundant, so instead we introduce a new function, torch._export.aot_compile, which exports using the inductor decomposition table and passes the result to inductor's compile_fx_aot. Because the program has already been exported, the second call to aot_autograd is avoided.
```
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
) -> Tuple[str, ExportedProgram]:
```
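For illustration, a minimal usage sketch of the new entry point, modeled on the updated test in test/cpp/aot_inductor/test.py (the toy module, tensor shapes, and CUDA device below are illustrative, and the API is not yet stable):
```
import torch
import torch._export


class Net(torch.nn.Module):
    def forward(self, x, y):
        return x + y


x = torch.randn(10, device="cuda")
y = torch.randn(10, device="cuda")

with torch.no_grad():
    # Exports with the inductor decomposition table and AOT-compiles to a .so,
    # returning the shared library path and the ExportedProgram.
    so_path, exported = torch._export.aot_compile(Net().cuda(), (x, y))
```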
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105977
Approved by: https://github.com/desertfire, https://github.com/zhxchen17, https://github.com/eellison
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index e014860..a354299 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1132,14 +1132,9 @@
example_args, example_kwargs
)
- exported = torch._export.export(model, example_args, example_kwargs)
- param_buffer_values = list(exported.state_dict.values())
- flat_example_inputs = fx_pytree.tree_flatten_spec(
- example_inputs, exported.call_spec.in_spec
+ so_path, exported = torch._export.aot_compile(
+ model, example_args, example_kwargs
)
- all_args = (*param_buffer_values, *flat_example_inputs)
- # AOT compile into a .so
- so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
output_node = list(exported.graph.nodes)[-1]
output_tensors = [
diff --git a/test/cpp/aot_inductor/test.py b/test/cpp/aot_inductor/test.py
index 20fbe23..2de3899 100644
--- a/test/cpp/aot_inductor/test.py
+++ b/test/cpp/aot_inductor/test.py
@@ -1,8 +1,7 @@
import shutil
import torch
-import torch._dynamo
-import torch._inductor
+import torch._export
class Net(torch.nn.Module):
@@ -23,7 +22,6 @@
torch._dynamo.reset()
with torch.no_grad():
- module, _ = torch._dynamo.export(Net().cuda())(x, y)
- lib_path = torch._inductor.aot_compile(module, [x, y])
+ lib_path, module = torch._export.aot_compile(Net().cuda(), (x, y))
shutil.copy(lib_path, "libaot_inductor_output.so")
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 67f95af..e8faefa 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -32,14 +32,10 @@
output_tensors.append(torch.empty_like(output))
# The exact API is subject to change
- exported = torch._export.export(model, example_inputs)
- param_buffer_values = list(exported.state_dict.values())
- flat_example_inputs = fx_pytree.tree_flatten_spec(
- example_inputs, exported.call_spec.in_spec
+ so_path, exported = torch._export.aot_compile(
+ model,
+ example_inputs,
)
- all_args = (*param_buffer_values, *flat_example_inputs)
- # AOT compile into a .so
- so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
# Use a utility function for easier testing
source = """
@@ -79,6 +75,24 @@
class AotInductorTests(TestCase):
+ def test_simple(self):
+ class Repro(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(10, 10, device="cuda")
+
+ def forward(self, x, y):
+ return x + torch.nn.functional.linear(y, self.weight)
+
+ model = Repro()
+ example_inputs = (
+ torch.randn(10, 10, device="cuda"),
+ torch.randn(10, 10, device="cuda"),
+ )
+ expected = model(*example_inputs)
+ actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+ self.assertTrue(same(actual, expected))
+
def test_missing_output(self):
class Repro(torch.nn.Module):
def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 284444e..30802df 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -10,6 +10,7 @@
import torch
import torch._dynamo
import torch.fx
+import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch._decomp import core_aten_decompositions, get_decompositions
@@ -344,3 +345,55 @@
f"but got {len(args)} positional args, {len(kwargs)} kwargs."
)
return OrderedDict({kw_name: kwargs[kw_name] for kw_name in arg_names[len(args):]})
+
+
+def aot_compile(
+ f: Callable,
+ args: Tuple[Any],
+ kwargs: Optional[Dict[str, Any]] = None,
+ constraints: Optional[List[Constraint]] = None,
+ options: Optional[Dict[str, Any]] = None,
+) -> Tuple[str, ExportedProgram]:
+ """
+ Note: this function is not stable yet
+
+ Traces either an nn.Module's forward function or just a callable with PyTorch
+ operations inside, generates executable cpp code from the program, and returns
+ the path to the generated shared library
+
+ Args:
+ f: the `nn.Module` or callable to trace.
+
+ args: example positional inputs.
+
+ kwargs: optional example keyword inputs.
+
+ constraints: An optional list of constraints on the dynamic arguments,
+ specifying the possible ranges of their shapes.
+
+ options: A dictionary of options to control inductor
+
+ Returns:
+ Path to the generated shared library, and the exported program
+ """
+ from torch._inductor.compile_fx import compile_fx_aot
+ from torch._inductor.decomposition import select_decomp_table
+
+ global DECOMP_TABLE
+ DECOMP_TABLE = select_decomp_table()
+ ep = export(f, args, kwargs, constraints)
+ # Reset the global value
+ DECOMP_TABLE = core_aten_decompositions()
+
+ param_buffer_values = list(ep.state_dict.values())
+ flat_example_inputs = fx_pytree.tree_flatten_spec(
+ combine_args_kwargs(args, kwargs), ep.call_spec.in_spec # type: ignore[arg-type]
+ )
+ all_args = (*param_buffer_values, *flat_example_inputs)
+
+ so_path = compile_fx_aot(
+ ep.graph_module,
+ all_args, # type: ignore[arg-type]
+ config_patches=options,
+ )
+ return so_path, ep
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 234965c..8adb985 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -894,11 +894,8 @@
cls.cache[key] = output_so
- def wrapper_call(*args):
- assert len(graph.graph_outputs) > 0
- return cls.cache[key], *(None for i in range(len(graph.graph_outputs) - 1))
-
- return wrapper_call
+ return cls.cache[key]
+ return None
# Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index c58d717..29c7f9e 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -206,7 +206,10 @@
compiled = inner_compile(
clone_graph(gm), example_inputs, **kwargs_patched
)
- if torch._guards.TracingContext.get().output_strides:
+ if (
+ torch._guards.TracingContext.get()
+ and torch._guards.TracingContext.get().output_strides
+ ):
torch._guards.TracingContext.get().output_strides.clear()
def materialize(x):
@@ -217,18 +220,17 @@
assert not isinstance(x, FakeTensor)
return x
- assert torch._guards.TracingContext.get()
- real_inputs = [
- materialize(x)
- for x in [
- *[
- param
- for param in torch._guards.TracingContext.get().params_flat
- if param is not None
- ],
- *V.real_inputs,
+ if torch._guards.TracingContext.get():
+ params_flat = [
+ param
+ for param in torch._guards.TracingContext.get().params_flat
+ if param is not None
]
- ]
+ real_inputs = [
+ materialize(x) for x in [*params_flat, *V.real_inputs]
+ ]
+ else:
+ real_inputs = [materialize(x) for x in V.real_inputs]
with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)
@@ -531,6 +533,9 @@
context.output_strides.append(None)
compiled_fn = graph.compile_to_fn()
+ if _in_aot_compilation:
+ return compiled_fn
+
if graph.disable_cudagraphs:
BoxedBool.disable(cudagraphs)
@@ -1079,6 +1084,10 @@
tracing_context = (
torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
)
+ if _in_aot_compilation:
+ with V.set_fake_mode(fake_mode), compiled_autograd.disable():
+ return fw_compiler(model_, example_inputs_)
+
with V.set_fake_mode(fake_mode), torch._guards.tracing(
tracing_context
), compiled_autograd.disable():
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 4705ccf..51fb2fa 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -984,6 +984,7 @@
code, linemap = self.codegen()
output_code_log.debug("Output code: \n%s", code)
+ # Directly return the file path with the compiled code
return AotCodeCache.compile(self, code, cuda=self.cuda)
else:
return self.compile_to_module().call