[AOTInductor] Remove call to aot_autograd when receiving ExportedProgram (#105977)

https://github.com/pytorch/pytorch/issues/105555

The existing flow first exports and then calls torch._inductor.aot_compile. However, export calls aot_autograd with the core ATen decomposition table, and then torch._inductor.aot_compile calls aot_autograd again with the inductor decomposition table. The second call to aot_autograd reportedly causes problems and is redundant, so instead we add a new function, torch._export.aot_compile, which exports using the inductor decomposition table, passes the result to inductor's compile_fx_aot, and, because the program has already been exported, avoids calling aot_autograd again.
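
For reference, a rough sketch of the two-step flow being replaced, adapted from the benchmark code removed in this PR (model, example_args, example_kwargs, and example_inputs are placeholders):

```
import torch
import torch._export
import torch._inductor
import torch.fx._pytree as fx_pytree

# Old two-step flow (as removed from benchmarks/dynamo/common.py):
# export with the core ATen decompositions, then AOT-compile the exported
# graph module, which calls aot_autograd a second time inside inductor.
exported = torch._export.export(model, example_args, example_kwargs)
param_buffer_values = list(exported.state_dict.values())
flat_example_inputs = fx_pytree.tree_flatten_spec(
    example_inputs, exported.call_spec.in_spec
)
all_args = (*param_buffer_values, *flat_example_inputs)
so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
```

The new torch._export.aot_compile entry point has the following signature: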

```
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
    options: Optional[Dict[str, Any]] = None,
) -> Tuple[str, ExportedProgram]:
```
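
A minimal usage sketch, mirroring the updated test/cpp/aot_inductor/test.py (the Net module body and the CUDA inputs here are illustrative assumptions, not copied verbatim from the test):

```
import torch
import torch._export


class Net(torch.nn.Module):
    def forward(self, x, y):
        # Placeholder forward; the real test module's body is not shown here.
        return x + y


x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")

with torch.no_grad():
    # One call now performs export (with the inductor decomposition table)
    # and AOT compilation into a shared library.
    so_path, exported = torch._export.aot_compile(Net().cuda(), (x, y))

print(so_path)  # filesystem path to the generated .so
```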

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105977
Approved by: https://github.com/desertfire, https://github.com/zhxchen17, https://github.com/eellison
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index e014860..a354299 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1132,14 +1132,9 @@
                 example_args, example_kwargs
             )
 
-            exported = torch._export.export(model, example_args, example_kwargs)
-            param_buffer_values = list(exported.state_dict.values())
-            flat_example_inputs = fx_pytree.tree_flatten_spec(
-                example_inputs, exported.call_spec.in_spec
+            so_path, exported = torch._export.aot_compile(
+                model, example_args, example_kwargs
             )
-            all_args = (*param_buffer_values, *flat_example_inputs)
-            # AOT compile into a .so
-            so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
 
             output_node = list(exported.graph.nodes)[-1]
             output_tensors = [
diff --git a/test/cpp/aot_inductor/test.py b/test/cpp/aot_inductor/test.py
index 20fbe23..2de3899 100644
--- a/test/cpp/aot_inductor/test.py
+++ b/test/cpp/aot_inductor/test.py
@@ -1,8 +1,7 @@
 import shutil
 
 import torch
-import torch._dynamo
-import torch._inductor
+import torch._export
 
 
 class Net(torch.nn.Module):
@@ -23,7 +22,6 @@
     torch._dynamo.reset()
 
     with torch.no_grad():
-        module, _ = torch._dynamo.export(Net().cuda())(x, y)
-        lib_path = torch._inductor.aot_compile(module, [x, y])
+        lib_path, module = torch._export.aot_compile(Net().cuda(), (x, y))
 
 shutil.copy(lib_path, "libaot_inductor_output.so")
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 67f95af..e8faefa 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -32,14 +32,10 @@
             output_tensors.append(torch.empty_like(output))
 
         # The exact API is subject to change
-        exported = torch._export.export(model, example_inputs)
-        param_buffer_values = list(exported.state_dict.values())
-        flat_example_inputs = fx_pytree.tree_flatten_spec(
-            example_inputs, exported.call_spec.in_spec
+        so_path, exported = torch._export.aot_compile(
+            model,
+            example_inputs,
         )
-        all_args = (*param_buffer_values, *flat_example_inputs)
-        # AOT compile into a .so
-        so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
 
         # Use a utility function for easier testing
         source = """
@@ -79,6 +75,24 @@
 
 
 class AotInductorTests(TestCase):
+    def test_simple(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.randn(10, 10, device="cuda")
+
+            def forward(self, x, y):
+                return x + torch.nn.functional.linear(y, self.weight)
+
+        model = Repro()
+        example_inputs = (
+            torch.randn(10, 10, device="cuda"),
+            torch.randn(10, 10, device="cuda"),
+        )
+        expected = model(*example_inputs)
+        actual = AOTInductorModelRunner.run(model, example_inputs, expected)
+        self.assertTrue(same(actual, expected))
+
     def test_missing_output(self):
         class Repro(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 284444e..30802df 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -10,6 +10,7 @@
 import torch
 import torch._dynamo
 import torch.fx
+import torch.fx._pytree as fx_pytree
 
 import torch.utils._pytree as pytree
 from torch._decomp import core_aten_decompositions, get_decompositions
@@ -344,3 +345,55 @@
         f"but got {len(args)} positional args, {len(kwargs)} kwargs."
     )
     return OrderedDict({kw_name: kwargs[kw_name] for kw_name in arg_names[len(args):]})
+
+
+def aot_compile(
+    f: Callable,
+    args: Tuple[Any],
+    kwargs: Optional[Dict[str, Any]] = None,
+    constraints: Optional[List[Constraint]] = None,
+    options: Optional[Dict[str, Any]] = None,
+) -> Tuple[str, ExportedProgram]:
+    """
+    Note: this function is not stable yet
+
+    Traces either an nn.Module's forward function or just a callable with PyTorch
+    operations inside, generates executable cpp code from the program, and returns
+    the path to the generated shared library
+
+    Args:
+        f: the `nn.Module` or callable to trace.
+
+        args: example positional inputs.
+
+        kwargs: optional example keyword inputs.
+
+        constraints: An optional list of constraints on the dynamic arguments, specifying
+            the possible ranges of their shapes.
+
+        options: A dictionary of options to control inductor
+
+    Returns:
+        Path to the generated shared library, and the exported program
+    """
+    from torch._inductor.compile_fx import compile_fx_aot
+    from torch._inductor.decomposition import select_decomp_table
+
+    global DECOMP_TABLE
+    DECOMP_TABLE = select_decomp_table()
+    ep = export(f, args, kwargs, constraints)
+    # Reset the global value
+    DECOMP_TABLE = core_aten_decompositions()
+
+    param_buffer_values = list(ep.state_dict.values())
+    flat_example_inputs = fx_pytree.tree_flatten_spec(
+        combine_args_kwargs(args, kwargs), ep.call_spec.in_spec  # type: ignore[arg-type]
+    )
+    all_args = (*param_buffer_values, *flat_example_inputs)
+
+    so_path = compile_fx_aot(
+        ep.graph_module,
+        all_args,  # type: ignore[arg-type]
+        config_patches=options,
+    )
+    return so_path, ep
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 234965c..8adb985 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -894,11 +894,8 @@
 
                 cls.cache[key] = output_so
 
-        def wrapper_call(*args):
-            assert len(graph.graph_outputs) > 0
-            return cls.cache[key], *(None for i in range(len(graph.graph_outputs) - 1))
-
-        return wrapper_call
+            return cls.cache[key]
+        return None
 
 
 # Putting this fn in cpp.py (unfortunately) causes a deadlock, which is why it's in codecache.py.
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index c58d717..29c7f9e 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -206,7 +206,10 @@
                 compiled = inner_compile(
                     clone_graph(gm), example_inputs, **kwargs_patched
                 )
-                if torch._guards.TracingContext.get().output_strides:
+                if (
+                    torch._guards.TracingContext.get()
+                    and torch._guards.TracingContext.get().output_strides
+                ):
                     torch._guards.TracingContext.get().output_strides.clear()
 
                 def materialize(x):
@@ -217,18 +220,17 @@
                         assert not isinstance(x, FakeTensor)
                         return x
 
-                assert torch._guards.TracingContext.get()
-                real_inputs = [
-                    materialize(x)
-                    for x in [
-                        *[
-                            param
-                            for param in torch._guards.TracingContext.get().params_flat
-                            if param is not None
-                        ],
-                        *V.real_inputs,
+                if torch._guards.TracingContext.get():
+                    params_flat = [
+                        param
+                        for param in torch._guards.TracingContext.get().params_flat
+                        if param is not None
                     ]
-                ]
+                    real_inputs = [
+                        materialize(x) for x in [*params_flat, *V.real_inputs]
+                    ]
+                else:
+                    real_inputs = [materialize(x) for x in V.real_inputs]
 
                 with torch.utils._python_dispatch._disable_current_modes():
                     compiled(real_inputs)
@@ -531,6 +533,9 @@
                         context.output_strides.append(None)
             compiled_fn = graph.compile_to_fn()
 
+            if _in_aot_compilation:
+                return compiled_fn
+
             if graph.disable_cudagraphs:
                 BoxedBool.disable(cudagraphs)
 
@@ -1079,6 +1084,10 @@
     tracing_context = (
         torch._guards.TracingContext.get() or torch._guards.TracingContext(fake_mode)
     )
+    if _in_aot_compilation:
+        with V.set_fake_mode(fake_mode), compiled_autograd.disable():
+            return fw_compiler(model_, example_inputs_)
+
     with V.set_fake_mode(fake_mode), torch._guards.tracing(
         tracing_context
     ), compiled_autograd.disable():
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 4705ccf..51fb2fa 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -984,6 +984,7 @@
             code, linemap = self.codegen()
             output_code_log.debug("Output code: \n%s", code)
 
+            # Directly return the file path with the compiled code
             return AotCodeCache.compile(self, code, cuda=self.cuda)
         else:
             return self.compile_to_module().call