Add torch.cond support to AOT Inductor (#121120)

Summary: This PR adds `torch.cond` support and the necessary codegen infrastructure to the C++ wrapper (AOTInductor and friends).

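For orientation, below is a minimal sketch of the user-facing flow this enables. It is illustrative only: it assumes the `torch._export.aot_compile` entry point and a plain CPU device, whereas the actual tests added in this PR go through `AOTIRunnerUtil` / `check_model_with_multiple_inputs`.

```python
import torch
from torch.export import Dim


class Simple(torch.nn.Module):
    def forward(self, p, a, b):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        # data-dependent branching captured as a torch.cond higher-order op
        return torch.cond(p, true_fn, false_fn, [a, b])


p = torch.tensor(True)
a = torch.randn(10, 20)
b = torch.randn(10, 20)
dim0 = Dim("s0", min=2, max=1024)
dynamic_shapes = {"p": {}, "a": {0: dim0}, "b": {0: dim0}}

# AOT-compile the model; the returned path points to the generated .so,
# whose C++ wrapper now contains the codegened if/else for torch.cond.
so_path = torch._export.aot_compile(
    Simple(), (p, a, b), dynamic_shapes=dynamic_shapes
)
```
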
Notable additions:

- A new mechanism in the Python wrapper codegen to precompile and save the Triton kernels (generated and user-defined) that aren't covered by the active path through the control flow given the sample inputs. As we can't do runtime autotuning of kernels outside the active path, we precompile and save them with `launchers[0]` (corresponding to the first config).

- Codegen infra for `torch.cond` in the C++ wrapper (ABI- and non-ABI-compatible). The `torch.cond` codegen has been slightly refactored to avoid duplication across the Python and C++ wrappers.

- Caching sites in the wrapper code extended to cache per codegened graph (e.g., `codegen_int_array_var`), plus infrastructure for tracking the currently codegened graph in the wrapper (both during codegen in `Scheduler.codegen` and in `WrapperCodeGen.generate`). A sketch of the caching idea follows this list.

- New unit tests to cover the added AOT Inductor + `torch.cond` functionality.

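A minimal, illustrative-only sketch (not the actual Inductor code; names simplified) of the per-graph caching idea from the third bullet: the currently codegened graph is made part of the cache key, so values cached while emitting a subgraph are not reused in the parent graph or a sibling subgraph.

```python
import functools
import itertools


class WrapperSketch:
    def __init__(self, root_graph):
        # stack of graphs being codegened; subgraph codegen pushes/pops around it
        self.codegened_graph_stack = [root_graph]
        self._names = itertools.count()

    def get_codegened_graph(self):
        return self.codegened_graph_stack[-1]

    def push_codegened_graph(self, graph):
        self.codegened_graph_stack.append(graph)

    def pop_codegened_graph(self):
        return self.codegened_graph_stack.pop()

    @functools.lru_cache(None)
    def codegen_int_array_var(self, int_array, graph=None):
        # `graph` is unused in the body: it is an argument purely so that
        # lru_cache keys on (self, int_array, graph), i.e. the same array
        # literal gets a distinct cached variable in each (sub)graph.
        return f"int_array_{next(self._names)}"


w = WrapperSketch(root_graph="parent")
a = w.codegen_int_array_var("{10L, 20L}", graph=w.get_codegened_graph())
w.push_codegened_graph("subgraph")
b = w.codegen_int_array_var("{10L, 20L}", graph=w.get_codegened_graph())
w.pop_codegened_graph()
assert a != b  # same array literal, different graphs => no cross-graph reuse
```

In the real code, the graph is a `GraphLowering` instance obtained via `self.get_codegened_graph()` and the stack is maintained by `EnterSubgraphLine` / `ExitSubgraphLine` during `torch.cond` subgraph codegen (see the wrapper changes in the diff below).
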
Codegen examples from the new unit tests:

- [`test_cond_simple_abi_compatible_cpu`](https://gist.github.com/aakhundov/862d5de9aa460f5df399e1387f7b342e)
- [`test_cond_simple_abi_compatible_cuda`](https://gist.github.com/aakhundov/d70b81f95fa8cc768cedef9acacb25bb)
- [`test_cond_simple_non_abi_compatible_cpu`](https://gist.github.com/aakhundov/c0ae7a8cbb6fa311c838e1b580f9a3f6)
- [`test_cond_simple_non_abi_compatible_cuda`](https://gist.github.com/aakhundov/08b945d4e8a32c97b7f9ff6272f4a223)
- [`test_cond_nested_abi_compatible_cuda`](https://gist.github.com/aakhundov/ce664f433c53e010ce4c0d96a6c13711)
- [`test_cond_with_parameters_abi_compatible_cuda`](https://gist.github.com/aakhundov/77afbeb8eaab5c5b930a3f922a7baf12)
- [`test_cond_with_multiple_outputs_abi_compatible_cuda`](https://gist.github.com/aakhundov/8cc06105ec8a3fe88be09b3f6e32c690)

Test Plan:

```
$ python test/inductor/test_aot_inductor.py -k test_cond
...
----------------------------------------------------------------------
Ran 42 tests in 170.619s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121120
Approved by: https://github.com/jansel, https://github.com/chenyang78
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 2be31cb..ffac33e 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -53,9 +53,11 @@
 try:
     try:
         from .test_aot_inductor_utils import AOTIRunnerUtil
+        from .test_control_flow import CondModels, prepend_predicates
         from .test_torchinductor import copy_tests, requires_multigpu, TestFailure
     except ImportError:
         from test_aot_inductor_utils import AOTIRunnerUtil
+        from test_control_flow import CondModels, prepend_predicates
         from test_torchinductor import copy_tests, requires_multigpu, TestFailure
 except (unittest.SkipTest, ImportError) as e:
     if __name__ == "__main__":
@@ -755,6 +757,130 @@
         )
         self.check_model(Repro(), example_inputs)
 
+    def test_cond_simple(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dim0_ab = Dim("s0", min=2, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_ab, 1: None},
+            "b": {0: dim0_ab, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.Simple(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_nested(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dim0_abc = Dim("s0", min=2, max=1024)
+        dynamic_shapes = {
+            "p0": {},
+            "p1": {},
+            "p2": {},
+            "a": {0: dim0_abc, 1: None},
+            "b": {0: dim0_abc, 1: None},
+            "c": {0: dim0_abc, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.Nested(),
+            prepend_predicates(inputs, num_predicates=3),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_with_parameters(self):
+        inputs = (torch.randn((10, 20), device=self.device),)
+        dim0_abc = Dim("s0", min=2, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_abc, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.Parameters(self.device),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_with_reinterpret_view_inputs_outputs(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dim0_ab = Dim("s0", min=3, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_ab, 1: None},
+            "b": {0: dim0_ab, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.ReinterpretView(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_with_multiple_outputs(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+            torch.randn((30, 40), device=self.device),
+        )
+        dim0_ab = Dim("s0", min=2, max=1024)
+        dim0_c = Dim("s1", min=2, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_ab, 1: None},
+            "b": {0: dim0_ab, 1: None},
+            "c": {0: dim0_c, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.MultipleOutputs(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_with_outer_code_before_after(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dim0_ab = Dim("s0", min=2, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_ab, 1: None},
+            "b": {0: dim0_ab, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.OuterCode(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
+    def test_cond_use_buffers_from_outer_scope(self):
+        inputs = (
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+            torch.randn((10, 20), device=self.device),
+        )
+        dim0_abc = Dim("s0", min=2, max=1024)
+        dynamic_shapes = {
+            "p": {},
+            "a": {0: dim0_abc, 1: None},
+            "b": {0: dim0_abc, 1: None},
+            "c": {0: dim0_abc, 1: None},
+        }
+        self.check_model_with_multiple_inputs(
+            CondModels.OuterBuffers(),
+            prepend_predicates(inputs),
+            dynamic_shapes=dynamic_shapes,
+        )
+
     def test_zero_grid_with_backed_symbols(self):
         class Repro(torch.nn.Module):
             def __init__(self):
diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py
index 543a672..f79ca1b 100644
--- a/test/inductor/test_control_flow.py
+++ b/test/inductor/test_control_flow.py
@@ -12,8 +12,18 @@
 from torch.testing._internal.triton_utils import requires_cuda
 
 
-class CondTests(TestCase):
-    class SimpleCondModel(torch.nn.Module):
+def prepend_predicates(inputs, num_predicates=1):
+    result = []
+    device = inputs[0].device
+    # iterate over the cartesian product of predicate values
+    for p_values in itertools.product(*([[False, True]] * num_predicates)):
+        predicates = [torch.tensor(v, device=device) for v in p_values]
+        result.append((*predicates, *inputs))
+    return result
+
+
+class CondModels:
+    class Simple(torch.nn.Module):
         def forward(self, p, a, b):
             def true_fn(x, y):
                 return x + y
@@ -23,7 +33,7 @@
 
             return torch.cond(p, true_fn, false_fn, [a, b])
 
-    class NestedCondModel(torch.nn.Module):
+    class Nested(torch.nn.Module):
         def forward(self, p0, p1, p2, a, b, c):
             def true_fn(x0, y0, z0):
                 def true_true_fn(x1, y1, z1):
@@ -61,6 +71,86 @@
 
             return torch.cond(p0, true_fn, false_fn, [a, b, c])
 
+    class Parameters(torch.nn.Module):
+        class InnerModel1(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.layer = torch.nn.Linear(20, 30, device=device)
+
+            def forward(self, x):
+                return self.layer(x + 1) * 3.14
+
+        class InnerModel2(torch.nn.Module):
+            def __init__(self, device):
+                super().__init__()
+                self.layer1 = torch.nn.Linear(20, 10, device=device)
+                self.layer2 = torch.nn.Linear(10, 30, device=device)
+
+            def forward(self, x):
+                return self.layer2(self.layer1(x - 2)) * 3.14
+
+        def __init__(self, device):
+            super().__init__()
+            self.true_fn = self.InnerModel1(device)
+            self.false_fn = self.InnerModel2(device)
+
+        def forward(self, p, a):
+            return torch.cond(p, self.true_fn, self.false_fn, [a])
+
+    class ReinterpretView(torch.nn.Module):
+        def forward(self, p, a, b):
+            def true_fn(x, y):
+                z1 = x + y
+                z2 = x - y
+                return z1[2:], z2[:, 4:]
+
+            def false_fn(x, y):
+                z1 = x - y
+                z2 = x + y
+                return z1[2:], z2[:, 4:]
+
+            return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
+
+    class MultipleOutputs(torch.nn.Module):
+        def forward(self, p, a, b, c):
+            def true_fn(x, y, z):
+                return x * y, z / 2.71, (y - x).sum(dim=1)
+
+            def false_fn(x, y, z):
+                return y / x, z * 3.14, (x + y).mean(dim=1)
+
+            return torch.cond(p, true_fn, false_fn, [a, b, c])
+
+    class OuterCode(torch.nn.Module):
+        def forward(self, p, a, b):
+            c = a * b + 3.14
+            d = a / b - 2.71
+
+            def true_fn(x, y):
+                return x + y
+
+            def false_fn(x, y):
+                return x - y
+
+            e = torch.cond(p, true_fn, false_fn, [c, d])
+
+            return e * e / 1.41
+
+    class OuterBuffers(torch.nn.Module):
+        def forward(self, p, a, b, c):
+            d = a * 2
+            e = b / 2
+
+            def true_fn(x):
+                return x + d
+
+            def false_fn(x):
+                return x - e
+
+            return torch.cond(p, true_fn, false_fn, [c])
+
+
+class CondTests(TestCase):
     def _run_test(
         self,
         model,
@@ -87,11 +177,9 @@
                     torch._dynamo.mark_dynamic(inp, 0)
 
         for inputs in input_sets:
-            # iterate over the cartesian product of predicate values
-            for p_values in itertools.product(*([[False, True]] * num_predicates)):
-                predicates = [torch.tensor(v, device=device) for v in p_values]
-                result = model(*predicates, *inputs)
-                result_compiled = compiled_model(*predicates, *inputs)
+            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
+                result = model(*inputs_with_predicates)
+                result_compiled = compiled_model(*inputs_with_predicates)
                 self.assertEqual(result, result_compiled)
 
         self.assertEqual(cnt.frame_count, 1, "only one compilation expected")
@@ -102,7 +190,7 @@
     def test_simple_control_flow(self, device, dynamic):
         # cond control flow without nesting
         self._run_test(
-            model=self.SimpleCondModel(),
+            model=CondModels.Simple(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -117,7 +205,7 @@
     def test_nested_control_flow(self, device, dynamic):
         # cond control flow with nesting
         self._run_test(
-            model=self.NestedCondModel(),
+            model=CondModels.Nested(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -133,23 +221,8 @@
     @parametrize("dynamic", [False, True])
     def test_outer_code_before_after(self, device, dynamic):
         # some code before and after the conditional
-        class Model(torch.nn.Module):
-            def forward(self, p, a, b):
-                c = a * b + 3.14
-                d = a / b - 2.71
-
-                def true_fn(x, y):
-                    return x + y
-
-                def false_fn(x, y):
-                    return x - y
-
-                e = torch.cond(p, true_fn, false_fn, [c, d])
-
-                return e * e / 1.41
-
         self._run_test(
-            model=Model(),
+            model=CondModels.OuterCode(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -163,18 +236,8 @@
     @parametrize("dynamic", [False, True])
     def test_multiple_outputs(self, device, dynamic):
         # multiple outputs with different shapes
-        class Model(torch.nn.Module):
-            def forward(self, p, a, b, c):
-                def true_fn(x, y, z):
-                    return x * y, z / 2.71, (y - x).sum(dim=1)
-
-                def false_fn(x, y, z):
-                    return y / x, z * 3.14, (x + y).mean(dim=1)
-
-                return torch.cond(p, true_fn, false_fn, [a, b, c])
-
         self._run_test(
-            model=Model(),
+            model=CondModels.MultipleOutputs(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -215,21 +278,8 @@
     @requires_cuda
     def test_use_buffers_from_outer_scope(self):
         # subgraphs input shapes include symbolic expressions
-        class Model(torch.nn.Module):
-            def forward(self, p, a, b, c):
-                d = a * 2
-                e = b / 2
-
-                def true_fn(x):
-                    return x + d
-
-                def false_fn(x):
-                    return x - e
-
-                return torch.cond(p, true_fn, false_fn, [c])
-
         self._run_test(
-            model=Model(),
+            model=CondModels.OuterBuffers(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -242,22 +292,8 @@
     @requires_cuda
     def test_reintepret_view_inputs_outputs(self):
         # ReinterpretView in inputs and outputs of the subgraphs
-        class Model(torch.nn.Module):
-            def forward(self, p, a, b):
-                def true_fn(x, y):
-                    z1 = x + y
-                    z2 = x - y
-                    return z1[2:], z2[:, 4:]
-
-                def false_fn(x, y):
-                    z1 = x - y
-                    z2 = x + y
-                    return z1[2:], z2[:, 4:]
-
-                return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])
-
         self._run_test(
-            model=Model(),
+            model=CondModels.ReinterpretView(),
             inputs=(
                 torch.randn(10, 20),
                 torch.randn(10, 20),
@@ -271,34 +307,8 @@
     @parametrize("dynamic", [False, True])
     def test_subgraphs_with_parameters(self, device, dynamic):
         # nested Modules with parameters
-        class InnerModel1(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.layer = torch.nn.Linear(20, 30, device=device)
-
-            def forward(self, x):
-                return self.layer(x + 1) * 3.14
-
-        class InnerModel2(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.layer1 = torch.nn.Linear(20, 10, device=device)
-                self.layer2 = torch.nn.Linear(10, 30, device=device)
-
-            def forward(self, x):
-                return self.layer2(self.layer1(x - 2)) * 3.14
-
-        class Model(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.true_fn = InnerModel1()
-                self.false_fn = InnerModel2()
-
-            def forward(self, p, a):
-                return torch.cond(p, self.true_fn, self.false_fn, [a])
-
         self._run_test(
-            model=Model(),
+            model=CondModels.Parameters(device),
             inputs=(torch.randn(10, 20),),
             device=device,
             dynamic=dynamic,
@@ -392,7 +402,7 @@
             }
         ):
             self._run_test(
-                model=self.NestedCondModel(),
+                model=CondModels.Nested(),
                 inputs=(
                     torch.randn(10, 20),
                     torch.randn(10, 20),
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py
index d669990..9d421a6 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cpu.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -15,7 +15,7 @@
 from ..utils import cache_on_self, sympy_product
 from ..virtualized import V
 from .common import IndentedBuffer
-from .wrapper import pexpr, WrapperCodeGen
+from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen
 
 
 class CppWrapperCpu(WrapperCodeGen):
@@ -1230,7 +1230,11 @@
 
     @functools.lru_cache(None)
     def codegen_int_array_var(
-        self, int_array: str, writer=None, known_statically=False
+        self,
+        int_array: str,
+        writer=None,
+        known_statically=False,
+        graph=None,  # for per-graph caching
     ):
         # Because the memory planning is done in two passes (see the implementation
         # of self.generate), the writeline behavior is different in the two passes.
@@ -1273,11 +1277,13 @@
                 size,
                 self.wrapper_call,
                 known_statically=self.is_statically_known_list_of_ints(shape),
+                graph=self.get_codegened_graph(),
             )
             stride_array_var = self.codegen_int_array_var(
                 stride,
                 self.wrapper_call,
                 known_statically=self.is_statically_known_list_of_ints(orig_stride),
+                graph=self.get_codegened_graph(),
             )
             device_type, device_id = device_str.split(",")
             device_idx = "this->device_idx_" if V.graph.aot_mode else device_id
@@ -1342,8 +1348,12 @@
                 pexpr(offset),  # bytes not numel
                 self.codegen_dtype(dtype),
                 str(len(shape)),
-                self.codegen_int_array_var(size, self.wrapper_call),
-                self.codegen_int_array_var(stride, self.wrapper_call),
+                self.codegen_int_array_var(
+                    size, self.wrapper_call, graph=self.get_codegened_graph()
+                ),
+                self.codegen_int_array_var(
+                    stride, self.wrapper_call, graph=self.get_codegened_graph()
+                ),
                 f"&{tmp_name}",
             ]
             self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};")
@@ -1386,11 +1396,13 @@
                     size,
                     writer,
                     known_statically=self.is_statically_known_list_of_ints(size_list),
+                    graph=self.get_codegened_graph(),
                 ),
                 self.codegen_int_array_var(
                     stride,
                     writer,
                     known_statically=self.is_statically_known_list_of_ints(stride_list),
+                    graph=self.get_codegened_graph(),
                 ),
                 offset,
             ]
@@ -1459,11 +1471,71 @@
         if not config.abi_compatible:
             super().codegen_multi_output(name, value)
 
-    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
-        raise NotImplementedError("Control flow NYI in C++ wrapper codegen.")
+    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
+        for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
+            if config.abi_compatible:
+                # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional
+                # input (outer_input) into another at::Tensor to be used as a subgraph input
+                # (inner_input) in the nested scope. we can't std::move here, as the codegened
+                # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we
+                # can't necessarily std::move it back to the origin (x).
+                self.writeline(f"AtenTensorHandle {inner_input}_handle;")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));"
+                )
+                self.writeline(
+                    f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);"
+                )
+            else:
+                self.writeline(
+                    f"{self.declare}{inner_input} = {outer_input}{self.ending}"
+                )
+
+    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
+        for inner_output, outer_output in zip(
+            subgraph.graph.graph_outputs, outer_outputs
+        ):
+            src = inner_output.codegen_reference()
+            if config.abi_compatible:
+                # in ABI-compatible mode, we need to std::move subgraph output (inner_output)
+                # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy
+                # constructor is deleted.
+                src = f"std::move({src})"
+            self.writeline(f"{outer_output} = {src}{self.ending}")
 
     def codegen_conditional(self, conditional):
-        raise NotImplementedError("Control flow NYI in C++ wrapper codegen.")
+        name = conditional.get_name()
+        outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands]
+        if config.abi_compatible:
+            outer_outputs = []
+            for out in conditional.outputs:
+                # in ABI-compatible mode, ir.MultiOutput is not codegened,
+                # hence pre-declare output variables directly and separately
+                self.writeline(f"RAIIAtenTensorHandle {out.get_name()};")
+                outer_outputs.append(out.get_name())
+            predicate = f"{conditional.predicate.get_name()}_scalar"
+            self.writeline(f"bool {predicate};")
+            # in ABI-compatible mode, we need to use the ABI shim function
+            # to extract a C++ bool from the underlying scalar bool Tensor
+            self.writeline(
+                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));"
+            )
+        else:
+            # in non-ABI-compatible mode, we can codegen the conditional outputs
+            # as array of at::Tensor instances, as the ir.MultiOutput is codegened
+            outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
+            self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];")
+            predicate = f"{conditional.predicate.codegen_reference()}.item<bool>()"
+
+        self.writeline(f"if ({predicate}) {{")
+        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
+        self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("} else {")
+        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
+        self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
+        self.writeline(ExitSubgraphLine(self))
+        self.writeline("}")
 
     def generate_extern_kernel_args_decl_if_needed(
         self, op_overload, raw_args, output_args
@@ -1696,6 +1768,12 @@
 
         self.extern_call_ops.add(cpp_kernel_key)
 
+    def generate_reset_kernel_saved_flags(self):
+        pass
+
+    def generate_save_uncompiled_kernels(self):
+        pass
+
     def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
         if (
             config.abi_compatible
@@ -1760,7 +1838,12 @@
             if config.abi_compatible:
                 static = self.is_statically_known_list_of_ints(val)
                 # Need to pass the array length because we can't use std::vector
-                return f"{self.codegen_int_array_var(result, known_statically=static)}, {len(val)}"
+                int_var_array = self.codegen_int_array_var(
+                    result,
+                    known_statically=static,
+                    graph=self.get_codegened_graph(),
+                )
+                return f"{int_var_array}, {len(val)}"
             else:
                 return result
         else:
diff --git a/torch/_inductor/codegen/cpp_wrapper_cuda.py b/torch/_inductor/codegen/cpp_wrapper_cuda.py
index 3fc207f..caa3549 100644
--- a/torch/_inductor/codegen/cpp_wrapper_cuda.py
+++ b/torch/_inductor/codegen/cpp_wrapper_cuda.py
@@ -1,7 +1,7 @@
 import functools
 import os
 from itertools import chain, count
-from typing import Any, List, Optional
+from typing import Any, List, Optional, TYPE_CHECKING
 
 import sympy
 
@@ -14,6 +14,9 @@
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg
 
+if TYPE_CHECKING:
+    from ..graph import GraphLowering
+
 
 def is_int(s: str) -> bool:
     # Cpp code gen adds L at the end of ints
@@ -167,7 +170,12 @@
 
     @functools.lru_cache(None)
     def generate_load_kernel_once(
-        self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
+        self,
+        name: str,
+        mangled_name: str,
+        cubin_path: str,
+        shared_mem: int,
+        graph: "GraphLowering",  # for per-graph caching
     ):
         if V.graph.aot_mode:
             self.writeline(f"if (kernels.{name} == nullptr) {{")
@@ -267,7 +275,9 @@
         ), f"cubin file should already exist at this moment: {cubin_path}"
         shared_mem = params.get("shared_mem", 0)
 
-        self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)
+        self.generate_load_kernel_once(
+            name, mangled_name, cubin_path, shared_mem, V.graph
+        )
 
         # args with value 1 are added into equal_to_1 and constants
         # in triton_meta (in the Python codegen) which makes them
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 71a1635..c0f6864 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -47,6 +47,8 @@
 if TYPE_CHECKING:
     import triton
 
+    from ..graph import GraphLowering
+
 
 pexpr = PythonPrinter().doprint
 
@@ -223,13 +225,22 @@
     pass
 
 
-class EnterScopeLine(WrapperLine):
+@dataclasses.dataclass
+class EnterSubgraphLine(WrapperLine):
+    wrapper: "WrapperCodeGen"
+    graph: "GraphLowering"
+
     def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper.push_codegened_graph(self.graph)
         code.do_indent()
 
 
-class ExitScopeLine(WrapperLine):
+@dataclasses.dataclass
+class ExitSubgraphLine(WrapperLine):
+    wrapper: "WrapperCodeGen"
+
     def codegen(self, code: IndentedBuffer) -> None:
+        self.wrapper.pop_codegened_graph()
         code.do_unindent()
 
 
@@ -403,7 +414,7 @@
         # If the generated source code is exactly the same, reuse the
         # pre-existing kernel for it
         self.src_to_kernel: Dict[str, str] = {}
-        self.kernel_numel_expr: Set[str] = set()
+        self.kernel_numel_expr: Set[Tuple[str, "GraphLowering"]] = set()
         self.lines: List[Union[MemoryPlanningLine, LineContext]] = []
         self.declare = ""
         self.declare_maybe_reference = ""
@@ -424,6 +435,12 @@
         self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {}
         self.computed_sizes: Set[sympy.Symbol] = set()
 
+        # this is used for tracking which GraphLowering instance---parent graph
+        # or (nested) subgraph---is currently codegened; the primary use case is
+        # including the graph instance into a cache key to avoid cross-graph
+        # caching during lowering of nested subgraphs
+        self.codegened_graph_stack = [V.graph]
+
         self.write_header()
         self.write_prefix()
 
@@ -570,6 +587,15 @@
         self.writeline(f"{name} = get_raw_stream({device_idx})")
         return name
 
+    def get_codegened_graph(self):
+        return self.codegened_graph_stack[-1]
+
+    def push_codegened_graph(self, graph):
+        self.codegened_graph_stack.append(graph)
+
+    def pop_codegened_graph(self):
+        return self.codegened_graph_stack.pop()
+
     def next_kernel_suffix(self) -> str:
         return f"{next(self._names_iter)}"
 
@@ -701,6 +727,9 @@
             else:
                 self.memory_plan_reuse()
 
+            if config.triton.store_cubin:
+                self.generate_reset_kernel_saved_flags()
+
             for line in self.lines:
                 if isinstance(line, WrapperLine):
                     line.codegen(self.wrapper_call)
@@ -715,6 +744,9 @@
             if config.profile_bandwidth:
                 self.generate_end_graph()
 
+            if config.triton.store_cubin:
+                self.generate_save_uncompiled_kernels()
+
             self.generate_return(output_refs)
 
         self.finalize_prefix()
@@ -756,9 +788,9 @@
             line = self.lines[i]
             if isinstance(line, MemoryPlanningLine):
                 self.lines[i] = line.plan(planning_states[-1])
-            elif isinstance(line, EnterScopeLine):
+            elif isinstance(line, EnterSubgraphLine):
                 planning_states.append(MemoryPlanningState())
-            elif isinstance(line, ExitScopeLine):
+            elif isinstance(line, ExitSubgraphLine):
                 past_planning_states.append(planning_states.pop())
         past_planning_states.append(planning_states.pop())
         assert len(planning_states) == 0
@@ -1152,8 +1184,9 @@
 
     def generate_numel_expr(self, kernel_name: str, tree):
         expr = f"{kernel_name}_{tree.prefix}numel"
-        if expr not in self.kernel_numel_expr:
-            self.kernel_numel_expr.add(expr)
+        if (expr, V.graph) not in self.kernel_numel_expr:
+            # declare expr once in each graph (scope)
+            self.kernel_numel_expr.add((expr, V.graph))
             self.writeline(
                 f"{self.declare}{expr} = {self.expr_printer(tree.numel)}{self.ending}"
             )
@@ -1192,6 +1225,41 @@
     def generate_end_graph(self):
         self.wrapper_call.writeline("end_graph()")
 
+    def generate_reset_kernel_saved_flags(self):
+        self.wrapper_call.splice(
+            """
+            for kernel in globals().values():
+                if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
+                    kernel.cuda_kernel_saved = False
+            """
+        )
+
+    def generate_save_uncompiled_kernels(self):
+        """
+        Precompile and save the CUBINs of the Triton kernels that haven't
+        been precompiled and saved as a side effect of running the generated
+        JIT model (Python wrapper). This can happen when the model contains
+        control flow: only one pass through the control flow operators covers
+        the kernels that are saved, the remaining kernels are not launched,
+        hence not saved. The main purpose of this codegen is to compile and
+        save the Triton kernels outside the active control flow path for
+        subsequent AOTInductor code generation and compilation.
+        """
+        self.wrapper_call.splice(
+            """
+            for kernel in globals().values():
+                if isinstance(kernel, torch._inductor.triton_heuristics.CachingAutotuner):
+                    if not kernel.cuda_kernel_saved:
+                        if len(kernel.launchers) == 0:
+                            kernel.precompile()
+                        kernel.save_cuda_kernel(
+                            grid=(0, 0, 0),   # use dummy grid
+                            stream="stream",  # use dummy stream
+                            launcher=kernel.launchers[0],
+                        )
+            """
+        )
+
     def generate_default_grid(self, name: str, grid_args: List[Any]):
         return grid_args
 
@@ -1415,38 +1483,46 @@
             self.unbacked_symbol_decls.add(name)
             return self.declare + name
 
-    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
-        self.writeline(f"# subgraph: {subgraph.name}")
+    def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
         for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs):
             self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}")
-        parent_graph = V.graph
-        with V.set_graph_handler(subgraph.graph):
-            subgraph.graph.codegen_subgraph(
-                parent_graph=parent_graph,
-            )
+
+    def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs):
         for inner_output, outer_output in zip(
             subgraph.graph.graph_outputs, outer_outputs
         ):
             self.writeline(
-                f"{self.declare}{outer_output} = {inner_output.codegen_reference()}{self.ending}"
+                f"{outer_output} = {inner_output.codegen_reference()}{self.ending}"
             )
 
+    def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
+        try:
+            self.push_codegened_graph(subgraph.graph)
+            self.writeline(f"{self.comment} subgraph: {subgraph.name}")
+            self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs)
+            parent_graph = V.graph
+            with V.set_graph_handler(subgraph.graph):
+                subgraph.graph.codegen_subgraph(
+                    parent_graph=parent_graph,
+                )
+            self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs)
+        finally:
+            self.pop_codegened_graph()
+
     def codegen_conditional(self, conditional):
         name = conditional.get_name()
         outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
         outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
 
-        # predefine the list of outer outputs before entering the conditional
-        # TODO(aakhundov): make this work for C++ wrapper codegen (and ABI mode)
         self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
         self.writeline(f"if {conditional.predicate.codegen_reference()}.item():")
-        self.writeline(EnterScopeLine())
+        self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
         self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
-        self.writeline(ExitScopeLine())
+        self.writeline(ExitSubgraphLine(self))
         self.writeline("else:")
-        self.writeline(EnterScopeLine())
+        self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
         self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
-        self.writeline(ExitScopeLine())
+        self.writeline(ExitSubgraphLine(self))
 
     @staticmethod
     def statically_known_int_or_none(x):
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index cc49910..e66e025 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1195,6 +1195,7 @@
 
         self.scheduler = Scheduler(self.buffers)
         V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
+
         self.scheduler.codegen()
         return self.wrapper_code.generate(self.is_inference)
 
@@ -1212,6 +1213,7 @@
 
         self.wrapper_code = parent_graph.wrapper_code
         self.device_ops = parent_graph.device_ops
+        self.cpp_wrapper = parent_graph.cpp_wrapper
 
         self.scheduler = Scheduler(self.buffers)
         self.scheduler.codegen()
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 9f675ec..27cbacf 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -157,6 +157,7 @@
         self.configs = configs
         self.heuristic_type = heuristic_type
         self.custom_kernel = custom_kernel
+        self.cuda_kernel_saved = False
 
         # Align the default design that default as cuda
         self.device_type = (
@@ -564,6 +565,8 @@
             ).read_bytes()
             CudaKernelParamCache.set(key, params, launcher.bin.asm["hsaco"])
 
+        self.cuda_kernel_saved = True
+
     def coordinate_descent_tuning(self, launcher, *args, **kwargs):
         """
         Coordinate descent tuning can be run with or without max-autotune.
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index beab68a..2995e81 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -387,6 +387,13 @@
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst);
 
+// Make a shallow copy of the tensor referred to by src and assign
+// it to the handle in the ret_dst. This is similar to the above
+// aoti_torch_assign_tensors function, but creates and sets the
+// ret_dst from within.
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle* ret_dst);
+
 // This function will create a new tensor object and its pointer is returned
 // through *ret. The caller is responsible for wrapping the tensor pointer
 // with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index 4e8d3ca..f8b3062 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -594,6 +594,16 @@
   });
 }
 
+AOTITorchError aoti_torch_assign_tensors_out(
+    AtenTensorHandle src,
+    AtenTensorHandle* ret_dst) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* src_tensor_ptr = tensor_handle_to_tensor_pointer(src);
+    at::Tensor dst_tensor = *src_tensor_ptr;
+    *ret_dst = new_tensor_handle(std::move(dst_tensor));
+  });
+}
+
 AOTITorchError aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);