[AOTAutograd] add export entrypoints (#100587)

The main addition in this PR is two new APIs in AOTAutograd.

**APIs**

`aot_export_module`: Given a module, exports it into a functionalized FX graph. Returns an (`fx.GraphModule`, `GraphSignature`) pair. The `GraphSignature` tells you various information about the graph, such as which graph inputs correspond to module params/buffers (and their FQNs), and how to pytree-flatten/unflatten the inputs and outputs of the graph. If you specify `trace_joint=True`, then you'll get back a joint forward-backward graph that also returns parameter gradients in addition to the user outputs.
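
A minimal usage sketch (mirroring the new tests in this PR; the toy module and input here are illustrative, not part of the change):

```python
import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        out = self.lin(x)
        loss = out.sum()
        # With trace_joint=True, every non-loss output must not require grad.
        return loss, out.detach()

mod = M()
x = torch.randn(2, 3)

# Joint forward-backward graph; output_loss_index says which output is the loss.
fx_g, signature = aot_export_module(mod, [x], trace_joint=True, output_loss_index=0)

# Calling convention of the exported graph:
#   inputs  = (*params, *buffers, *user_inputs)
#   outputs = (*mutated_inputs, *user_outputs, *param_gradients)
print(signature.parameters)            # FQNs, e.g. ['lin.weight', 'lin.bias']
print(signature.inputs_to_parameters)  # graph input name -> parameter FQN
print(signature.backward_signature.gradients_to_parameters)
```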

There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: if you want a backward graph, then your module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step.

I (gratefully) used @SherlockNoMad and @suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd.

`aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API guarantees that you **do not** need to make any calling convention changes between the original function and the exported one, provided that you do the following (see the sketch after this list):
* If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function
* If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.
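
A sketch of the joint path, based on the test added in this PR (the `functorch.compile.default_partition` import is my assumption about where the default partitioner is exposed; adjust to wherever you normally import it from):

```python
import torch
from torch._functorch.aot_autograd import aot_export_joint_simple
from functorch.compile import default_partition

def f(x, y):
    return x * y, y * y.detach()

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

# Export a joint forward-backward graph (no scalar-loss requirement here).
joint_g = aot_export_joint_simple(f, [x, y], trace_joint=True)

# Partition it yourself; the resulting forward keeps f's calling convention.
num_fw_outputs = 2
fw_g, bw_g = default_partition(joint_g, [x, y], num_fwd_outputs=num_fw_outputs)

fw_outs = fw_g(x, y)
outs, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:]

# The backward graph consumes (saved activations, grad_outputs).
grad_outs = [torch.ones_like(o) for o in outs]
grads = bw_g(*activations, *grad_outs)
```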

The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition each into fw/bw graphs, and run cond on the `true_bw`, `false_bw` subgraphs. cc @zou3519 @Chillee

**Implementation Strategy**

A lot of the work in this PR went into trying to find a reasonable way to re-use existing AOTAutograd components to expose these APIs. Concretely:

* The two new APIs are both thin wrappers around `_aot_export_function`: this is a general-purpose export API that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
* `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
* A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object (a sketch of how a caller might consume it follows this list). We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a full fwd-bwd graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
* I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting params and buffers, since I also need them in the export code.
* I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
* `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.
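
As a concrete illustration of the `GraphSignature` and calling convention above, here's a hedged sketch of running an exported inference graph and writing the mutated buffers back by hand (the toy module and the write-back logic are illustrative, not part of the PR):

```python
import torch
from torch._functorch.aot_autograd import aot_export_module

class BN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(3)

    def forward(self, x):
        return (self.bn(x),)

mod = BN().train()
x = torch.randn(4, 3)
fx_g, sig = aot_export_module(mod, [x], trace_joint=False)

# Calling convention: inputs  = (*params, *buffers, *user_inputs)
#                     outputs = (*mutated_buffers, *user_outputs)
params = dict(mod.named_parameters())
buffers = dict(mod.named_buffers())
flat_inputs = (
    [params[fqn] for fqn in sig.parameters]
    + [buffers[fqn] for fqn in sig.buffers]
    + [x]
)
outs = fx_g(*flat_inputs)

# The graph is functional: the caller is responsible for writing mutated
# buffer values (e.g. batchnorm running stats) back onto the module.
num_mutated = len(sig.buffers_to_mutate)
mutated_vals, user_outs = outs[:num_mutated], outs[num_mutated:]
with torch.no_grad():
    for new_val, fqn in zip(mutated_vals, sig.buffers_to_mutate.values()):
        buffers[fqn].copy_(new_val)
```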

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100587
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index ccf4733..0cc4ed0 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -31,7 +31,7 @@
     grad, vjp, vmap, jacrev,
     make_fx
 )
-from torch._functorch.aot_autograd import aot_module_simplified
+from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module, aot_export_joint_simple
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module,
@@ -1946,6 +1946,237 @@
                  dynamic=dynamic)(*inps).sum().backward()
     return (fw_graph_cell[0], bw_graph_cell[0])
 
+class TestMod(torch.nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True))
+        self.fn = fn
+
+    def forward(self, *args):
+        return self.fn(self.p, *args)
+
+class TestAOTExport(AOTTestCase):
+
+    def test_aot_export_module_joint(self):
+        class ConvBatchnormRelu(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 3, 1, 1)
+                self.bn = torch.nn.BatchNorm2d(3)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                user_out = torch.nn.functional.relu(x)
+                loss = user_out.sum()
+                return loss, user_out.detach()
+
+        mod = ConvBatchnormRelu()
+        mod.train()
+        inp = torch.randn(1, 1, 3, 3)
+        o_ref = mod(inp)
+        fx_g, signature = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0)
+        # Some important characteristics of the exported graph below:
+        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input
+        # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)
+        self.assertExpectedInline(fx_g.print_readable(print_output=False), """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
+        # No stacktrace found for following nodes
+        convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg1_1 = None
+        add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
+        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  arg3_1 = arg4_1 = arg5_1 = None
+        getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
+        getitem_1: f32[3] = _native_batch_norm_legit_functional[1]
+        getitem_2: f32[3] = _native_batch_norm_legit_functional[2]
+        getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
+        getitem_4: f32[3] = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
+        relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem);  getitem = None
+        detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
+        sum_1: f32[] = torch.ops.aten.sum.default(relu)
+        detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
+        detach_2: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_1);  detach_1 = None
+        ones_like: f32[] = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
+        expand: f32[1, 3, 3, 3] = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]);  ones_like = None
+        threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, relu, 0);  expand = relu = None
+        native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]);  threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
+        getitem_5: f32[1, 3, 3, 3] = native_batch_norm_backward[0]
+        getitem_6: f32[3] = native_batch_norm_backward[1]
+        getitem_7: f32[3] = native_batch_norm_backward[2];  native_batch_norm_backward = None
+        convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  getitem_5 = arg7_1 = arg0_1 = None
+        getitem_8 = convolution_backward[0]
+        getitem_9: f32[3, 1, 1, 1] = convolution_backward[1]
+        getitem_10: f32[3] = convolution_backward[2];  convolution_backward = None
+        return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
+        """)  # noqa: B950
+
+
+        self.assertExpectedInline(str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""")
+        self.assertExpectedInline(str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""")
+        self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""")
+        self.assertExpectedInline(str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""")  # noqa: B950
+        self.assertExpectedInline(str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""")  # noqa: B950
+        self.assertExpectedInline(str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""")  # noqa: B950
+        self.assertExpectedInline(str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""")  # noqa: B950
+        self.assertExpectedInline(str(signature.backward_signature.gradients_to_user_inputs), """{}""")
+        self.assertExpectedInline(str(signature.backward_signature.loss_output), """getitem_3""")
+
+        # Also check the inference graph
+        # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs.
+        fx_g_inference, signature_inference = aot_export_module(mod, [inp], trace_joint=False)
+        self.assertExpectedInline(fx_g_inference.print_readable(print_output=False), """\
+class <lambda>(torch.nn.Module):
+    def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
+        # No stacktrace found for following nodes
+        convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  arg7_1 = arg0_1 = arg1_1 = None
+        add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
+        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05);  convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
+        getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
+        getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
+        getitem_4: f32[3] = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
+        relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem);  getitem = None
+        sum_1: f32[] = torch.ops.aten.sum.default(relu)
+        detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu);  relu = None
+        return (getitem_3, getitem_4, add, sum_1, detach)
+        """)  # noqa: B950
+        # Some important characteristics of the exported inference graph above:
+        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input
+        # 5 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and no gradients (since this is an inference graph)
+
+    def test_aot_export_simplified_basic(self):
+        def f(x, y):
+            return x * y, y * y.detach()
+
+        x = torch.randn(2, requires_grad=True)
+        y = torch.randn(2, requires_grad=True)
+
+        f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False)
+        out_ref = f(x, y)
+        # No calling convention changes necessary to invoke the traced graph
+        out_test = f_graph_fw(x, y)
+        self.assertEqual(out_ref, out_test)
+
+        # Now test the backward
+        x = torch.randn(2, requires_grad=True)
+        y = torch.randn(2, requires_grad=True)
+        x2 = x.clone().detach().requires_grad_(True)
+        y2 = y.clone().detach().requires_grad_(True)
+        x3 = x.clone().detach().requires_grad_(True)
+        y3 = y.clone().detach().requires_grad_(True)
+        f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True)
+        num_fw_outputs = 2
+        fw_g, bw_g = default_partition(f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs)
+        out_ref2 = f(x2, y2)
+        fw_outs = fw_g(x3, y3)
+        out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:]
+        self.assertEqual(out_ref2, out_test2)
+
+        # Test running the traced backward graph with a mocked-up grad_output
+        grad_outs = [torch.ones_like(x) for x in out_ref2]
+        grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs)
+        grads_test = bw_g(*activations, *grad_outs)
+        for g_ref, g_test in zip(grads_ref, grads_test):
+            self.assertEqual(g_ref, g_test)
+
+    def test_aot_export_metadata_mutation_banned(self):
+        def fn(p, x):
+            x.t_()
+            return (x * 2,)
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Found an input that received a metadata mutation"
+        ):
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
+            aot_export_module(mod, [inp], trace_joint=False)
+
+    def test_aot_export_input_mutation_on_parameter_banned(self):
+        def fn(p, x):
+            p.mul_(2)
+            return (p + x,)
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Found a graph input that requires gradients, and received a mutation"
+        ):
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
+            aot_export_module(mod, [inp], trace_joint=False)
+
+    def test_aot_export_synthetic_bases_banned(self):
+        def fn(p, x, y):
+            x.mul_(2)
+            return (x + y,)
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        inp2 = inp.view(-1)
+        with self.assertRaisesRegex(
+            RuntimeError, "Encountered aliased inputs that are mutated"
+        ):
+            aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False)
+            aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True)
+            aot_export_module(mod, [inp, inp2], trace_joint=False)
+
+    def test_aot_export_input_dupes_banned(self):
+        def fn(p, x, y):
+            x.mul_(2)
+            return (x + y,)
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Encountered duplicated inputs that are mutated in the graph"
+        ):
+            aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False)
+            aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True)
+            aot_export_module(mod, [inp, inp], trace_joint=False)
+
+    def test_aot_export_multiple_outputs_require_grad_banned(self):
+        def fn(p, x):
+            out = p * x
+            return out, out.sum()
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "Found an output of the forward that requires gradients, that was not"
+        ):
+            aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)
+
+    def test_aot_export_simplified_input_mutations_banned(self):
+        def fn(x):
+            x.mul_(2)
+            return (x + x,)
+        inp = torch.randn(2)
+        with self.assertRaisesRegex(
+            RuntimeError, "aot_export_joint_simple does not support input mutations"
+        ):
+            aot_export_joint_simple(fn, [inp], trace_joint=False)
+            aot_export_joint_simple(fn, [inp], trace_joint=True)
+
+    def test_aot_export_simplified_pytrees_banned(self):
+        def fn(inps):
+            return (inps[0] + inps[1],)
+        inp1 = torch.randn(2)
+        inp2 = torch.randn(2)
+        inps = [inp1, inp2]
+        with self.assertRaisesRegex(
+            RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees"
+        ):
+            aot_export_joint_simple(fn, [inps], trace_joint=False)
+            aot_export_joint_simple(fn, [inps], trace_joint=True)
+
+    def test_aot_export_functionalized_rng_banned(self):
+        def fn(p, x):
+            return (p + x,)
+        mod = TestMod(fn)
+        inp = torch.randn(2)
+        with patch("functorch.compile.config.functionalize_rng_ops", True), self.assertRaisesRegex(
+            RuntimeError, "Functionalized RNG is not currently supported in the aot_export"
+        ):
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
+            aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
+            aot_export_module(mod, [inp], trace_joint=False)
+
 
 class TestPartitioning(AOTTestCase):
     @unittest.skipIf(not USE_NETWORKX, "networkx not available")
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 888423a..61ab5b2 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -922,6 +922,140 @@
 
     return inner
 
+
+@dataclass
+class BackwardSignature:
+    """
+    Provides information about the backward section of an exported
+    joint forward-backward graph.
+    For a particular fx GraphModule, this class contains information on:
+    (1) A mapping from each gradient (backwards output) to the parameter
+        it corresponds to (forward input)
+    (2) A mapping from each gradient (backwards output) to the user input
+        it corresponds to (forward input)
+    (3) Which of the forward outputs corresponds to the loss, that we backprop on.
+
+    Each string name is the `node.name` of the corresponding node in the fx graph.
+    """
+    gradients_to_parameters: Dict[str, str]
+    gradients_to_user_inputs: Dict[str, str]
+    loss_output: str
+
+GraphOutputName = NewType('GraphOutputName', str)
+GraphInputName = NewType('GraphInputName', str)
+FQN = NewType('FQN', str)
+
+@dataclass
+class GraphSignature:
+    """
+    Provides information about an exported module.
+    For a particular fx GraphModule, this class contains information on:
+    (1) Which graph inputs are parameters, buffers, or user inputs
+    (2) (for params/buffers) a mapping from the name of each graph argument
+        to its parameter/buffer FQN in the original nn.Module.
+    (3) If there are input mutations, these are represented as extra outputs
+        in the fx GraphModule. We provide a mapping from these
+        extra output names to the names of the actual inputs.
+    (4) The pytree metadata on how to flatten/unflatten inputs and outputs.
+        The corresponding FX GraphModule only accepts and returns
+        pytree-flattened inputs/outputs.
+    (5) (Optionally) if the FX is a joint forward-backward graph, we provide
+        a signature on the backward section of the joint graph.
+    """
+
+    parameters: List[FQN]
+    buffers: List[FQN]
+
+    user_inputs: List[GraphInputName]
+    user_outputs: List[GraphOutputName]
+    inputs_to_parameters: Dict[GraphInputName, FQN]
+    inputs_to_buffers: Dict[GraphInputName, FQN]
+
+    # If the user's module mutates a buffer,
+    # it's represented in the graph as an extra graph output.
+    # This dict is a mapping from
+    # "graph outputs that correspond to updated buffers"
+    # to the FQN names of those mutated buffers.
+    buffers_to_mutate: Dict[GraphOutputName, FQN]
+
+    in_spec: pytree.TreeSpec
+    out_spec: pytree.TreeSpec
+
+    backward_signature: Optional[BackwardSignature]
+
+    @classmethod
+    def from_tracing_metadata(
+        cls,
+        *,
+        in_spec: pytree.TreeSpec,
+        out_spec: pytree.TreeSpec,
+        graph_input_names: List[str],
+        graph_output_names: List[str],
+        view_mutation_metadata: ViewAndMutationMeta,
+        named_parameters: List[str],
+        named_buffers: List[str],
+        num_user_inputs: int,
+        num_user_outputs: int,
+        loss_index: Optional[int],
+        backward_signature: Optional[BackwardSignature],
+    ) -> "GraphSignature":
+        graph_inputs = graph_input_names
+        graph_outputs = graph_output_names
+        parameters = list(named_parameters)
+        buffers = list(named_buffers)
+
+        # Calling convention assumptions:
+        # (1) graph inputs = (params, buffers, user_inputs)
+        # (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
+        # (If we are capturing an inference graph, this convention is identical
+        #  except that param_gradients is empty)
+        user_inputs = graph_inputs[len(parameters) + len(buffers) :]
+        assert num_user_inputs == len(user_inputs)
+        assert len(graph_inputs) == (len(parameters) + len(buffers) + len(user_inputs))
+
+        inputs_to_parameters = dict(zip(graph_inputs[: len(parameters)], parameters))
+        inputs_to_buffers = dict(zip(
+            graph_inputs[len(parameters) : len(parameters) + len(buffers)],
+            buffers,
+        ))
+
+        state_names = [*parameters, *buffers]
+        mutated_buffers = []
+        for idx, input_info in enumerate(view_mutation_metadata.input_info):
+            if input_info.mutates_data:
+                # Only buffers can be mutated, not parameters
+                assert idx >= len(parameters)
+                buffer_name = state_names[idx]
+                mutated_buffers.append(buffer_name)
+
+        assert len(mutated_buffers) == view_mutation_metadata.num_mutated_inputs
+
+        start, stop = 0, view_mutation_metadata.num_mutated_inputs
+        buffers_to_mutate = dict(zip(graph_outputs[start:stop], mutated_buffers))
+
+        start, stop = stop, stop + num_user_outputs
+        user_outputs = graph_outputs[start:stop]
+
+        unused_outputs = len(graph_outputs) - stop
+        if backward_signature is not None:
+            unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
+                backward_signature.gradients_to_user_inputs
+            )
+        assert unused_outputs == 0
+
+        return GraphSignature(
+            parameters=parameters,
+            buffers=buffers,
+            user_inputs=user_inputs,
+            user_outputs=user_outputs,
+            inputs_to_buffers=inputs_to_buffers,
+            inputs_to_parameters=inputs_to_parameters,
+            buffers_to_mutate=buffers_to_mutate,
+            in_spec=in_spec,
+            out_spec=out_spec,
+            backward_signature=backward_signature,
+        )
+
 @dataclasses.dataclass
 class AOTConfig:
     """
@@ -935,6 +1069,8 @@
     num_params_buffers: int
     aot_id: int
     keep_inference_input_mutations: bool
+    is_export: bool = False
+    no_tangents: bool = False
     dynamic_shapes: bool = False
     aot_autograd_arg_pos_to_source : Optional[List[Source]] = None
     inference_compiler: Optional[Callable] = None
@@ -1084,7 +1220,7 @@
 #     otherwise, when we compute autograd.grad(), we will not take those input mutations into account
 #     (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
 def create_joint(
-    fn: Callable,
+    fn: Callable, *, aot_config: AOTConfig
 ) -> Any:
     def inner_fn(primals: List[Any], tangents: List[Any]):
         outs, tangent_mask = fn(*primals)
@@ -1124,12 +1260,21 @@
         # Call the backwards pass
         if grad_primals:
             with fx_traceback.preserve_node_meta():
-                backward_out = torch.autograd.grad(
-                    needed_outs,
-                    grad_primals,
-                    grad_outputs=needed_tangents,
-                    allow_unused=True,
-                )
+                # for full graph export, we always export a joint graph where we assume no tangents are needed.
+                if aot_config.no_tangents:
+                    assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
+                    backward_out = torch.autograd.grad(
+                        needed_outs,
+                        grad_primals,
+                        allow_unused=True,
+                    )
+                else:
+                    backward_out = torch.autograd.grad(
+                        needed_outs,
+                        grad_primals,
+                        grad_outputs=needed_tangents,
+                        allow_unused=True,
+                    )
         backward_out_iter = iter(backward_out)
         return outs, [
             next(backward_out_iter) if i else None for i in inputs_needs_grads
@@ -1344,6 +1489,7 @@
         fw_metadata,
         keep_data_input_mutations=aot_config.keep_inference_input_mutations,
     )
+
     fw_module = create_functionalized_graph(
         fn_to_trace,
         flat_args,
@@ -1943,6 +2089,15 @@
     if ok:
         return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
 
+    # export path: ban duplicate inputs for now, add later if requested.
+    if aot_config.is_export:
+        raise RuntimeError(f"""\
+Encountered duplicated inputs that are mutated in the graph you are trying to export.
+This functionality is currently not supported. If needed, please file a github issue.
+
+fw_metadata={str(fw_metadata)}
+        """)
+
     # Strategy 2: Duplicate specialize.
     #
     # In Haskell types, suppose you have:
@@ -2114,6 +2269,17 @@
     if synthetic_base_info is None:
         return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
 
+    # export path: ban synthetic bases for now, add later if requested.
+    if aot_config.is_export:
+        raise RuntimeError(f"""\
+Encountered aliased inputs that are mutated in the graph you are trying to export.
+This functionality is currently not supported. If needed, please file a github issue.
+
+synthetic_base_info={str(synthetic_base_info)}
+
+fw_metadata={str(fw_metadata)}
+        """)
+
     assert len(fw_metadata.input_info) == len(synthetic_base_info)
 
     # Update our forward metadata to take synthetic bases into account
@@ -2467,7 +2633,7 @@
 # are no duplicate arguments in flat_args (e.g., the same Tensor
 # object never shows up twice.  However, two tensor inputs MAY alias
 # the same storage, so long as they have separate TensorImpls.)
-def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
+def aot_dispatch_autograd_graph(flat_fn, flat_args: List[Any], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
     # traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
     # It includes outputs of the original forward, *and* any updated inputs due to input mutations.
     # However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
@@ -2478,13 +2644,12 @@
 
     assert len(fw_metadata.requires_grad_info) == fw_metadata.num_mutated_inputs + fw_metadata.num_outputs
     joint_inputs = (flat_args, traced_tangents)
-    disable_amp = torch._C._is_any_autocast_enabled()
 
     fn_prepared_for_autograd = fn_prepped_for_autograd(
         flat_fn,
         fw_metadata,
     )
-    joint_fn_to_trace = create_joint(fn_prepared_for_autograd)
+    joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
 
     fx_g = create_functionalized_graph(
         joint_fn_to_trace,
@@ -2503,6 +2668,21 @@
     torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
     fx_g.graph.eliminate_dead_code()
     fx_g.recompile()
+    # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
+    # when we need to manually detach() some inputs in the forward.
+    # Higher order ops might eventually need to do the same.
+    return fx_g
+
+def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
+    fx_g = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
+
+    # Copied from aot_dispatch_autograd_graph.
+    traced_tangents = pytree.tree_map(
+        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
+        fw_metadata.traced_tangents,
+    )
+    joint_inputs = (flat_args, traced_tangents)
+    disable_amp = torch._C._is_any_autocast_enabled()
 
     if aot_config.enable_log:
         aot_joint_log.info("%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id))
@@ -2913,6 +3093,10 @@
     inputs in flat_args are parameters and buffers, and the rest are inputs.
 
     We use this to assume that parameters/buffer's shapes don't change.
+
+    Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
+        When aot_config.is_export is True, we return an FX graph + metadata
+        When aot_config.is_export is False, we return an ordinary runtime function
     """
 
     # This is the main entry point.
@@ -2994,18 +3178,56 @@
                     keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd,
                 )(*fake_flat_args)
 
+        if aot_config.is_export:
+            # aot_export: ban input metadata mutations for now to keep shared code paths simpler.
+            # Keeping .resize_() in the graph will require some work
+            # Allowing it but keeping the graph functional will require some calling convention changes.
+            if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
+                raise RuntimeError(f"""\
+Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
+This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
+
+fw_metadata={str(fw_metadata)}""")
+            # In export, banning data mutations on inputs that require grad for now.
+            # This should be rare, and is tricky to get right. When we trace the backward,
+            # we currently trace with autograd.grad instead of .backward(), which makes it difficult
+            # to ensure that we run autograd all the way through the input **before** it saw the mutation.
+            if len([x for x in fw_metadata.requires_grad_info[:fw_metadata.num_mutated_inputs] if x]) != 0:
+                raise RuntimeError(f"""\
+Found a graph input that requires gradients, and received a mutation.
+This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
+
+fw_metadata={str(fw_metadata)}""")
+            # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad,
+            # and turning it on will require a non-trivial calling convention change for any export runtime.
+            if config.functionalize_rng_ops:
+                raise RuntimeError("""\
+Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
+or otherwise set torch._functorch.config.functionalize_rng_ops = False.""")
+
         # crappy version of dispatcher
         # TODO: Do this properly
         if needs_autograd:
-            compiler_fn = aot_dispatch_autograd
+            # For now, aot_dispatch_autograd knows to explicitly return a graph
+            # when run with export, and an opaque callable otherwise.
+            # In theory we could factor these out, but I wanted to let the dust
+            # settle on how functionalized rng fits into export first.
+            compiler_fn = aot_dispatch_autograd_graph if aot_config.is_export else aot_dispatch_autograd
         else:
-            compiler_fn = aot_dispatch_base
+            # aot_dispatch_base_graph contains only the "graph bits", while aot_dispatch_base
+            # includes some extra work around handling a runtime epilogue.
+            compiler_fn = aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base
 
         compiler_fn = partial(aot_wrapper_synthetic_base, compiler_fn=compiler_fn, needs_autograd=needs_autograd)
         compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
         # You can put more passes here
 
         compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
+        if aot_config.is_export:
+            # During export, we don't get back a callable - we get back the raw fx graph
+            # (either a joint or an inference-only graph)
+            assert isinstance(compiled_fn, torch.fx.GraphModule)
+            return compiled_fn, fw_metadata
 
         if not hasattr(compiled_fn, "_boxed_call"):
             compiled_fn = make_boxed_func(compiled_fn)
@@ -3040,6 +3262,168 @@
         return pytree.tree_unflatten(x, self.spec)
 
 
+def create_functional_call(mod, params_spec, params_len):
+    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.
+
+    # Note [Fake Modules and AOTAutograd]
+    #
+    # A simple heuristic for when to use fake versus real tensors is that fake tensors are for compile time
+    # (when we don't want to actually run the compute, but we do want to know about metadata),
+    # and real tensors are for runtime (when we actually want to do the compute.) However, in AOTAutograd,
+    # modules are the exception: we always pass AOTAutograd modules with real tensors.
+    # This is because AOTAutograd will produce a compiled function which needs to directly access any
+    # parameters the compiled function may need, but these parameters will NOT be passed in by the caller (aka Dynamo).
+    # So at compile time, the compiled function we produce must close over any parameters, and those parameters must be
+    # real parameters, and we cannot do this unless at compile time we get a module with real tensors.
+
+    # Even if Dynamo did pass all parameters explicitly at runtime, which would eliminate the need to close over
+    # the parameters, it would still be profitable to pass real tensor parameters to the compiler at compile time,
+    # because some compilation strategies like CUDA graphs want to burn in the pointer addresses where the parameter data live,
+    # and of course we can't do that unless we give the backend a real tensor.
+    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)
+
+    def functional_call(*args, **kwargs):
+        with stateless._reparametrize_module(
+            mod, pytree.tree_unflatten(args[:params_len], params_spec)
+        ):
+            if isinstance(mod, torch.fx.GraphModule):
+                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
+                    warnings.filterwarnings(
+                        "ignore", "Anomaly Detection has been enabled."
+                    )
+                    with torch.autograd.detect_anomaly(check_nan=False):
+                        out = Interpreter(mod).run(*args[params_len:], **kwargs)
+            else:
+                out = mod(*args[params_len:], **kwargs)
+
+        if not isinstance(out, (tuple, list)):
+            raise RuntimeError(
+                "Graph output must be a tuple(). This is so that we can avoid "
+                "pytree processing of the ouputs. Please change the module to "
+                "have tuple outputs or use aot_module instead."
+            )
+        return out
+    return functional_call
+
+# Creates a function that returns flattened inputs and outputs
+# Also returns the output tree spec, which is needed to recover the "unflattened"
+# output tree structure later.
+def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
+    if kwargs is None:
+        kwargs = {}
+    # Save the args_spec for flat_tensor_args to unflatten while tracing
+    _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
+    out_spec = PytreeThunk()
+
+    def flat_fn(*flat_args):
+        # The input are flattened tensor args. Prepare the args in the
+        # order that original function expects. Add static args as well.
+        # They will appear as tensor constants in the traced graph.
+        nonlocal out_spec
+        args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
+        tree_out = fn(*args, **kwargs)
+        flat_out, spec = pytree.tree_flatten(tree_out)
+        for i in flat_out:
+            is_known_type = False
+            for j in KNOWN_TYPES:
+                if isinstance(i, j):
+                    is_known_type = True
+                    break
+            if not is_known_type:
+                raise RuntimeError(
+                    f"Found {type(i)} in output, which is not a known type. "
+                    "If this type holds tensors, you need to register a pytree for it. "
+                    "See https://github.com/pytorch/functorch/issues/475 for a brief "
+                    "explanation why. If you don't need to register a pytree, please "
+                    "leave a comment explaining your use case and we'll make this more "
+                    "ergonomic to deal with"
+                )
+        out_spec.set(spec)
+        return flat_out
+    return flat_fn, out_spec
+
+def _graph_input_names(gm):
+    return [node.name for node in gm.graph.nodes if node.op == "placeholder"]
+
+
+def _graph_output_names(gm):
+    output_node = next(iter(reversed(gm.graph.nodes)))
+    assert output_node.op == "output" and len(output_node.args) == 1
+    return_args = output_node.args[0]
+    return [getattr(return_arg, "name", None) for return_arg in return_args]
+
+
+def create_graph_signature(
+    fx_g: torch.fx.GraphModule,
+    fw_metadata: ViewAndMutationMeta,
+    in_spec: pytree.TreeSpec,
+    out_spec: pytree.TreeSpec,
+    *,
+    user_args_flat: List[torch.Tensor],
+    params_and_buffers_flat: List[torch.Tensor],
+    param_names: List[str],
+    buffer_names: List[str],
+    trace_joint: bool,
+    num_user_fw_outs: Optional[int],
+    loss_index: Optional[int],
+) -> GraphSignature:
+
+    # Retrieve graph input names
+    graph_input_names = _graph_input_names(fx_g)
+    # Retrieve graph output names
+    graph_output_names = _graph_output_names(fx_g)
+
+    num_params_buffers = len(param_names) + len(buffer_names)
+    # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
+    # Such that # graph inps = # user inps + # params + # buffers
+    num_user_args = len(graph_input_names) - num_params_buffers
+
+    if trace_joint:
+        assert num_user_fw_outs is not None
+        num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inputs
+        backward_output_names = graph_output_names[num_fw_outs:]
+
+        grad_index = itertools.count(0)
+        gradients_to_parameters = {
+            backward_output_names[next(grad_index)]: param_names[i]
+            for i, param in enumerate(params_and_buffers_flat)
+            if param.requires_grad
+        }
+
+        gradients_to_user_inputs = {
+            backward_output_names[next(grad_index)]: graph_input_names[i + len(params_and_buffers_flat)]
+            for i, user_input in enumerate(user_args_flat)
+            if user_input.requires_grad
+        }
+
+        assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len(
+            backward_output_names
+        )
+
+        # Check that we have fully accounted for all graph outputs
+        backward_signature = BackwardSignature(
+            gradients_to_parameters,
+            gradients_to_user_inputs,
+            graph_output_names[loss_index],
+        )
+    else:
+        backward_signature = None
+        num_user_fw_outs = len(graph_output_names) - fw_metadata.num_mutated_inputs
+
+    return GraphSignature.from_tracing_metadata(
+        in_spec=in_spec,
+        out_spec=out_spec,
+        graph_input_names=graph_input_names,
+        graph_output_names=graph_output_names,
+        view_mutation_metadata=fw_metadata,
+        named_parameters=param_names,
+        named_buffers=buffer_names,
+        num_user_inputs=num_user_args,
+        num_user_outputs=num_user_fw_outs,
+        loss_index=loss_index,
+        backward_signature=backward_signature,
+    )
+
 def aot_function(
     fn: Callable,
     fw_compiler: Callable,
@@ -3121,6 +3505,8 @@
         keep_inference_input_mutations=keep_inference_input_mutations,
         dynamic_shapes=dynamic,
         aot_autograd_arg_pos_to_source=None,
+        is_export=False,
+        no_tangents=False,
         enable_log=enable_log,
     )
     cached_res = None
@@ -3133,35 +3519,7 @@
 
         # Compile the function and save it in the cache
         if cached_res is None:
-            # Save the args_spec for flat_tensor_args to unflatten while tracing
-            _, tensor_args_spec = pytree.tree_flatten((args, kwargs))
-            out_spec = PytreeThunk()
-
-            def flat_fn(*flat_args):
-                # The input are flattened tensor args. Prepare the args in the
-                # order that original function expects. Add static args as well.
-                # They will appear as tensor constants in the traced graph.
-                nonlocal out_spec
-                args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
-                tree_out = fn(*args, **kwargs)
-                flat_out, spec = pytree.tree_flatten(tree_out)
-                for i in flat_out:
-                    is_known_type = False
-                    for j in KNOWN_TYPES:
-                        if isinstance(i, j):
-                            is_known_type = True
-                            break
-                    if not is_known_type:
-                        raise RuntimeError(
-                            f"Found {type(i)} in output, which is not a known type. "
-                            "If this type holds tensors, you need to register a pytree for it. "
-                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
-                            "explanation why. If you don't need to register a pytree, please "
-                            "leave a comment explaining your use case and we'll make this more "
-                            "ergonomic to deal with"
-                        )
-                out_spec.set(spec)
-                return flat_out
+            flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
 
             compiled_fn = create_aot_dispatcher_function(
                 flat_fn,
@@ -3249,26 +3607,6 @@
 
     :func:`aot_module_simplified` removes these overheads.
     """
-    #########################################################
-
-    # Redudant with dynamo, but worth having in case this gets invoked elsewhere.
-
-    # Note [Fake Modules and AOTAutograd]
-    #
-    # A simple heuristic for when to use fake versus real tensors is that fake tensors are for compile time
-    # (when we don't want to actually run the compute, but we do want to know about metadata),
-    # and real tensors are for runtime (when we actually want to do the compute.) However, in AOTAutograd,
-    # modules are the exception: we always pass AOTAutograd modules with real tensors.
-    # This is because AOTAutograd will produce a compiled function which needs to directly access any
-    # parameters the compiled function may need, but these parameters will NOT be passed in by the caller (aka Dynamo).
-    # So at compile time, the compiled function we produce must close over any parameters, and those parameters must be
-    # real parameters, and we cannot do this unless at compile time we get a module with real tensors.
-
-    # Even if Dynamo did pass all parameters explicitly at runtime, which would eliminate the need to close over
-    # the parameters, it would still be profitable to pass real tensor parameters to the compiler at compile time,
-    # because some compilation strategies like CUDA graphs want to burn in the pointer addresses where the parameter data live,
-    # and of course we can't do that unless we give the backend a real tensor.
-    torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)
 
     params = {
         **dict(mod.named_parameters(remove_duplicate=False)),
@@ -3278,28 +3616,7 @@
     params_flat = tuple(params_flat)
     params_len = len(params_flat)
 
-    def functional_call(*args, **kwargs):
-        with stateless._reparametrize_module(
-            mod, pytree.tree_unflatten(args[:params_len], params_spec)
-        ):
-
-            if isinstance(mod, torch.fx.GraphModule):
-                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
-                    warnings.filterwarnings(
-                        "ignore", "Anomaly Detection has been enabled."
-                    )
-                    with torch.autograd.detect_anomaly(check_nan=False):
-                        out = Interpreter(mod).run(*args[params_len:], **kwargs)
-            else:
-                out = mod(*args[params_len:], **kwargs)
-
-        if not isinstance(out, (tuple, list)):
-            raise RuntimeError(
-                "Graph output must be a tuple(). This is so that we can avoid "
-                "pytree processing of the ouputs. Please change the module to "
-                "have tuple outputs or use aot_module instead."
-            )
-        return out
+    functional_call = create_functional_call(mod, params_spec, params_len)
 
     if bw_compiler is None:
         bw_compiler = fw_compiler
@@ -3361,7 +3678,9 @@
         aot_id=next(AOT_COUNTER),
         keep_inference_input_mutations=keep_inference_input_mutations,
         dynamic_shapes=dynamic_shapes,
-        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source
+        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
+        is_export=False,
+        no_tangents=False,
     )
 
     compiled_fn = create_aot_dispatcher_function(
@@ -3387,6 +3706,304 @@
 
     return forward
 
+def aot_export_module(
+    mod: nn.Module,
+    args,
+    *,
+    decompositions: Optional[Dict] = None,
+    # If true, we'll return a joint forward-backward graph,
+    # As well as metadata on the loss + gradients in the backward.
+    trace_joint: bool,
+    # If trace_joint is True, we expect your module to return a scalar loss.
+    # Your module can return multiple outputs, so you must specify which output the loss is.
+    output_loss_index: Optional[int] = None,
+) -> Tuple[torch.fx.GraphModule, GraphSignature]:
+    """
+    This function takes in a module, and returns:
+    (1) an FX graph that can be exported
+    (2) some metadata about the graph
+
+    If `trace_joint=True` we will return a joint graph of the forward + backward.
+
+    The traced FX graph will have the following properties compared to the original module:
+    (1) Inputs and outputs to the module will be pytree-flattened
+    (2) Parameters and buffers on the module will be lifted into graph inputs,
+        graph_inputs = (*parameters, *buffers, *user_inputs)
+    (3) The graph will be fully functionalized
+    (4) Any input mutations will be converted into additional outputs in the graph,
+        meaning whoever calls this graph is responsible for applying the mutations
+        back to the original inputs.
+    (5) If trace_joint is True, the graph will return parameter gradients in addition to user outputs.
+        The graph output will look like:
+        graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)
+
+    There are also several restrictions on what modules can use this API. In particular:
+    (1) If trace_joint is specified, we expect the loss function to be **fused**
+        into the module forward. One of the outputs to the forward must be a scalar loss,
+        which is specified with `output_loss_index`.
+        All other outputs to the forward are presumed to not require gradients.
+    (2) This API cannot capture optimizers (although in theory we could build an API for this).
+    (3) Metadata mutations on params/buffers/inputs are banned.
+    (4) Data mutations on anything that requires gradients are banned (parameters)
+    (5) If an input is mutated, it is not allowed to alias any other inputs.
+    (6) Parameters must not be duplicated.
+    """
+    named_parameters = dict(mod.named_parameters(remove_duplicate=False))
+    named_buffers = dict(mod.named_buffers(remove_duplicate=False))
+    params_and_buffers = {
+        **dict(named_parameters),
+        **dict(named_buffers),
+    }
+    params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
+    params_and_buffers_flat = tuple(params_and_buffers_flat)
+    params_len = len(params_and_buffers_flat)
+
+    functional_call = create_functional_call(mod, params_spec, params_len)
+
+    num_fw_outs = None
+
+    if trace_joint:
+        # This helper effectively just adds some extra asserts about what the backward will look like:
+        # Outputs must include a scalar loss that we compute gradients w.r.t.
+        # We don't compute gradients w.r.t. anything else, so every other output
+        # is required not to require grad (the user should .detach() them).
+        def fn_to_trace(*args):
+            nonlocal num_fw_outs
+            out = functional_call(*args)
+            if output_loss_index is None:
+                raise RuntimeError("""\
+If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
+You must specify which output is the loss (by index) with output_loss_index.""")
+            if isinstance(out, (torch.Tensor)):
+                out = (out,)
+            if not isinstance(out, (tuple, list)):
+                raise RuntimeError(f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}")
+
+            for i, o in enumerate(out):
+                # We only want to create a backward graph w.r.t. the loss that the user passed in.
+                # This implies that every other output should not require gradients.
+                # For now we make this an error, and require the user to .detach()
+                # all other outputs of their forward.
+                if o.requires_grad and i != output_loss_index:
+                    raise RuntimeError(f"""\
+Found an output of the forward that requires gradients, that was not the scalar loss.
+We require all outputs to the forward that are not the scalar loss to not require gradient,
+because we will only compute a backward graph against the scalar loss.
+You can fix this by calling .detach() on each of your forward outputs that is not the loss.
+You specified that output index {output_loss_index} is the loss, but we found that
+the output at index {i} requires gradients.""")
+            out_loss = out[output_loss_index]
+            num_fw_outs = len(out)
+            if not out_loss.requires_grad:
+                raise RuntimeError(f"""\
+The output at index {output_loss_index} was marked as the loss, but it does not require gradients""")
+            if out_loss.numel() != 1:
+                raise RuntimeError(f"""\
+We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}""")
+            return out
+        ctx = nullcontext
+    else:
+        # Run under no_grad, so our tracing machinery only traces an inference graph.
+        ctx = torch.no_grad
+        fn_to_trace = functional_call
+
+    full_args = []
+    # First, the params
+    full_args.extend(params_and_buffers_flat)
+    # Next, the input args
+    full_args.extend(args)
+
+    with ctx():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            fn_to_trace,
+            full_args,
+            decompositions=decompositions,
+            num_params_buffers=len(params_and_buffers_flat),
+            no_tangents=True,
+        )
+    if trace_joint:
+        def flattened_joint(*args):
+            # The idea here is that the joint graph that AOTAutograd creates has some strict properties:
+            # (1) It accepts two arguments (primals, tangents), and pytree_flattens them
+            # (2) It returns a tuple of (fw_outs, gradients)
+            # This is a very useful convention for anyone who wants to partition the joint graph
+            # into a separate forward and backward graph.
+            # However,
+            # (1) for people exporting a single joint graph, it would be preferable not to have
+            #     any pytrees in the graph.
+            # (2) We are guaranteed in the aot_export_module case that the forward outputs a loss,
+            #     and there are therefore no tangents that are needed to run the joint graph.
+            # (3) AOTAutograd creates a grad_input for every input in the forward,
+            #     including None's for inputs that are not grad-requiring tensors.
+            #     We don't want these in our export graph.
+            # This function "fixes" all of the above by removing any tangent inputs,
+            # filtering out the None gradients, and removing pytrees from the original FX graph.
+            fake_tangents = [None for _ in range(metadata.num_outputs + metadata.num_mutated_inputs)]
+            fw_outs, gradients = fx_g(args, fake_tangents)
+            assert len(gradients) == len(args)
+            output_gradients = []
+            for i, (a, grad) in enumerate(zip(args, gradients)):
+                if isinstance(a, torch.Tensor) and a.requires_grad:
+                    assert grad is not None, """\
+Found a parameter that did not receive a gradient.
+"This is most likely a bug, but if this needs to be supported please comment on this Github issue:
+https://github.com/pytorch/pytorch/issues/101192
+"""
+                    output_gradients.append(grad)
+                else:
+                    assert grad is None
+            return *fw_outs, *output_gradients
+        fx_g = make_fx(flattened_joint)(*full_args)
+
+    user_args_flat, _ = pytree.tree_flatten(args)
+    return fx_g, create_graph_signature(
+        fx_g,
+        metadata,
+        in_spec,
+        out_spec,
+        user_args_flat=user_args_flat,
+        params_and_buffers_flat=params_and_buffers_flat,
+        param_names=list(named_parameters.keys()),
+        buffer_names=list(named_buffers.keys()),
+        trace_joint=trace_joint,
+        num_user_fw_outs=num_fw_outs,
+        loss_index=output_loss_index,
+    )
+
+def aot_export_joint_simple(
+    func: Callable,
+    args,
+    *,
+    trace_joint: bool,
+    # It looks like the main consequence of this API is that for dynamic shapes,
+    # it will assume that params/buffers are static.
+    # With the new inferred dynamic shapes API, maybe this doesn't matter?
+    num_params_buffers: int = 0,
+    decompositions: Optional[Dict] = None,
+) -> torch.fx.GraphModule:
+    """
+    A simplified version of export. Used by higher order operators.
+
+    This function makes a high-level "no calling convention changes" guarantee:
+    - If no inputs require grad (so we export an inference graph),
+      there are *no* calling convention changes between the exported graph and "func".
+    - If at least one input requires grad (so we trace out and export a joint fw-bw graph),
+      then if you partition the graph into a separate forward and backward graph,
+      the forward graph will have no calling convention changes compared to "func".
+
+    The above also relies on some strong restrictions around which functions this API accepts:
+    (1) `args` cannot contain any pytrees (they must have been pytree_flattened already)
+    (2) `func` cannot mutate any inputs
+    (3) The outputs of `func` cannot alias any inputs.
+
+    Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
+    """
+    if trace_joint:
+        ctx = nullcontext
+    else:
+        # Run under no_grad, so our tracing machinery only traces an inference graph.
+        ctx = torch.no_grad
+
+    with ctx():
+        fx_g, metadata, in_spec, out_spec = _aot_export_function(
+            func,
+            args,
+            decompositions=decompositions,
+        )
+    # At this point, we can just directly return the (joint or inference graph) that we traced.
+    # First though: a bunch of assertions to make sure that our graph doesn't require
+    # any calling convention changes compared to the original function.
+    # These restrictions are *in addition to* the general restrictions on export.
+
+    # No input mutations
+    if len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) != 0:
+        raise RuntimeError(f"aot_export_joint_simple does not support input mutations. {str(metadata)}")
+    # No output aliasing
+    if len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) != 0:
+        raise RuntimeError(f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}")
+    # No pytrees
+    if type(in_spec) == pytree.LeafSpec:
+        raise RuntimeError(f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}")
+    if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}")
+    if type(out_spec) == pytree.LeafSpec:
+        raise RuntimeError(f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}")
+    if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
+        raise RuntimeError(f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}")
+    # TODO: we might have to temporarily patch config.functionalize_rng
+    # so that it doesn't run when we're exporting a higher order op.
+
+    if config.debug_assert:
+        # Smoke test that after partitioning, we can run the forward without any calling convention changes.
+        fw_module, bw_module = default_partition(
+            fx_g, args, num_fwd_outputs=len(metadata.output_info)
+        )
+        # Attempt to run the fw_module with the original user inputs
+        fake_mode = detect_fake_mode(args)
+        if fake_mode is None:
+            fake_mode = FakeTensorMode()
+        with fake_mode:
+            fw_module(*args)
+    return fx_g
+
+# Private for now because we aren't providing a contract on what to return
+# for joint graphs (we could when there's a clearer use case)
+# In the future, we may need to add more export API's that provide their own strong guarantees.
+# This is meant as a general helper function for handling various export-y use cases.
+def _aot_export_function(
+    func: Callable,
+    args,
+    *,
+    num_params_buffers: int = 0,
+    decompositions: Optional[Dict] = None,
+    # If we're exporting a joint graph and we don't want any tangent inputs in the graph
+    # (because we are backpropping through a scalar 1 loss),
+    # we need to explicitly specify not to include tangents in the graph.
+    # It's not enough just to check that our tangent is a scalar, since we also
+    # need to know if it is a 1 (no need to make it a graph input), or something else
+    # (requiring it to be a graph input).
+    # We don't know this info at trace time though, so we need to make it an explicit config.
+    no_tangents: bool = False,
+) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
+    dynamic_shapes = False
+    for x in args:
+        if isinstance(x, FakeTensor):
+            dynamic_shapes = x.fake_mode.shape_env is not None
+            break
+
+    flat_fn, out_spec = create_tree_flattened_fn(func, args)
+    flat_args, in_spec = pytree.tree_flatten(args)
+
+    # The export use case doesn't care about several bits of AOTConfig
+    # (1) compilers (we just export the graph)
+    # (2) partitioners (export is only full graph, user can partition themselves)
+    aot_config = AOTConfig(
+        fw_compiler=None,
+        bw_compiler=None,
+        inference_compiler=None,
+        partition_fn=None,
+        decompositions=decompositions,
+        num_params_buffers=num_params_buffers,
+        aot_id=next(AOT_COUNTER),
+        # For now there's no use case involving keeping input mutations in the graph
+        # (which we can only do in the inference case anyway).
+        # We can add this later if we need to.
+        keep_inference_input_mutations=False,
+        dynamic_shapes=dynamic_shapes,
+        aot_autograd_arg_pos_to_source=None,
+        is_export=True,
+        no_tangents=no_tangents,
+    )
+
+    fx_g, meta = create_aot_dispatcher_function(
+        flat_fn,
+        flat_args,
+        aot_config,
+    )
+    return fx_g, meta, in_spec, out_spec.spec
+
 
 compiled_function = aot_function
 compiled_module = aot_module