[AOTI] Improve the two-pass wrapper codegen (#114067)

Summary: For the second pass, we do not need to rerun the whole Inductor flow. This PR moves the second pass to codegen time, which both speeds up compilation and removes the kernel-scheduling inconsistency between the two passes. A future improvement is to have the second pass reuse the scheduler and redo only the wrapper codegen. A simplified sketch of the new flow is shown before the diff below.

This is a copy of https://github.com/pytorch/pytorch/pull/113762, created so that it can land in GitHub first.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114067
Approved by: https://github.com/chenyang78
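
The change replaces the compile-level two-pass wrapper (`inner_compile_with_cpp_wrapper`, removed from `compile_fx.py`) with a codegen-level two-pass inside `GraphLowering` (`codegen_with_cpp_wrapper` in `graph.py`). The Python sketch below is a simplified, standalone illustration of that control flow, not the actual PyTorch code; the class name `GraphLoweringSketch` and the stubbed methods are stand-ins for the real `GraphLowering` APIs shown in the diff.

```python
# Simplified sketch of the two-pass wrapper codegen flow introduced by this PR.
# GraphLowering internals are stubbed out; only the control flow of
# codegen_with_cpp_wrapper (see torch/_inductor/graph.py in the diff) is shown.

class GraphLoweringSketch:
    def __init__(self, example_inputs, device_types):
        # example_inputs are kept on the graph so the generated Python wrapper
        # can be dry-run to produce autotuned Triton kernel binaries (cubins).
        self.example_inputs = example_inputs
        self.device_types = device_types
        self.cpp_wrapper = True

    def compile_to_module_and_run(self):
        # Stand-in for compile_to_module().call plus running it with
        # materialized (real) inputs.
        print("pass 1: python wrapper codegen, run once to autotune and store cubins")

    def codegen(self):
        # Stand-in for the real wrapper codegen.
        kind = "cpp" if self.cpp_wrapper else "python"
        return f"// {kind} wrapper code", []

    def codegen_with_cpp_wrapper(self):
        if "cuda" in self.device_types:
            # First pass: generate Python wrapper code and execute it once with
            # concrete inputs so Triton kernels get tuned and their cubins are
            # stored (triton.store_cubin=True is set by compile_fx).
            self.cpp_wrapper = False
            self.compile_to_module_and_run()

            # Second pass: regenerate only the wrapper, this time as C++,
            # without rerunning the whole Inductor flow.
            self.cpp_wrapper = True
            return self.codegen()
        # On CPU the C++ wrapper is generated in a single pass.
        return self.codegen()


graph = GraphLoweringSketch(example_inputs=[], device_types={"cuda"})
code, linemap = graph.codegen_with_cpp_wrapper()
```

Compared with the previous approach, the wrapper-level scheduling state is shared across both passes, so the kernels scheduled in the Python pass match the ones referenced by the C++ wrapper.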
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index b1dc743..fa7411d 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -18,6 +18,7 @@
 
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_quantization import skip_if_no_torchvision
 
 from torch.testing._internal.common_utils import (
     IS_CI,
@@ -1060,6 +1061,42 @@
         example_inputs = (torch.randn(3, 10, device=self.device),)
         self.check_model(Model(), example_inputs)
 
+    @skip_if_no_torchvision
+    def test_missing_cubin(self):
+        from torchvision.models.resnet import Bottleneck, ResNet
+
+        class Model(ResNet):
+            def __init__(self):
+                super().__init__(
+                    block=Bottleneck,
+                    layers=[3, 4, 6, 3],
+                    replace_stride_with_dilation=[False, False, True],
+                    norm_layer=None,
+                )
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn1(x)
+                x = self.relu(x)
+                f1 = x
+                x = self.maxpool(x)
+                x = self.layer1(x)
+                f2 = x
+                x = self.layer2(x)
+                f3 = x
+                x = self.layer3(x)
+                x = self.layer4(x)
+                f4 = x
+                return [f1, f2, f3, f4]
+
+        # Call eval() here so that batch_norm won't update the running stats
+        # Use float64 to avoid numeric difference failure
+        model = Model().to(device=self.device, dtype=torch.float64).eval()
+        example_inputs = (
+            torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
+        )
+        self.check_model(model, example_inputs)
+
     @common_utils.parametrize("grid_type", [1, 2, 3])
     @common_utils.parametrize("num_dims", [1, 2])
     @common_utils.parametrize("dynamic", [False, True])
@@ -1194,6 +1231,8 @@
         # TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
         #   NotImplementedError: Cannot access storage of OpaqueTensorImpl
         "test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
+        # Need to support convolution
+        "test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
         "test_normal_functional": TestFailure(("abi_compatible_cpu",)),
         "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
         # There is a double-free issue which will be fixed in another PR
@@ -1219,6 +1258,8 @@
     # test_failures, xfail by default, set is_skip=True to skip
     {
         "test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
+        # Need to support convolution
+        "test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
         "test_normal_functional": TestFailure(("abi_compatible_cuda",)),
         # There is a double-free issue which will be fixed in another PR
         "test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py
index 726cfce..6de2d05 100644
--- a/test/inductor/test_pattern_matcher.py
+++ b/test/inductor/test_pattern_matcher.py
@@ -50,10 +50,6 @@
         if len(codes) == 1:
             codes = codes[0]
         torch.testing.assert_close(actual, expected)
-        if inductor_config.cpp_wrapper:
-            # CPP wrapper runs everything twice, so we'll match the pattern twice
-            expected_matches *= 2
-            expected_nodes *= 2
 
         self.assertEqual(
             counters["inductor"]["pattern_matcher_count"], expected_matches
@@ -519,13 +515,6 @@
         self.common(fn, args, 2, 5)
 
     def test_cat_slice_cat(self):
-        def check_counter(counter, expected):
-            if not inductor_config.cpp_wrapper:
-                self.assertEqual(counter, expected)
-            else:
-                # cpp_wrapper for the CUDA backend runs two passes
-                self.assertEqual(counter, 2 * expected)
-
         def fn(a, b):
             cat_1 = torch.ops.aten.cat.default([a, b], 1)
             slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
@@ -548,8 +537,8 @@
         torch.testing.assert_close(actual, expected)
         # We don't recompile for dynamic-shape cases.
         if dynamo_config.assume_static_by_default:
-            check_counter(counters["inductor"]["pattern_matcher_count"], 1)
-            check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
+            self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
+            self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)
 
         # Verify we fallback to non-optimal path for negative `end`.
         def fn(a, b):
diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py
index 28c1761..c1ee714 100644
--- a/test/inductor/test_select_algorithm.py
+++ b/test/inductor/test_select_algorithm.py
@@ -44,13 +44,6 @@
 
 
 class TestSelectAlgorithm(TestCase):
-    def check_counter(self, counter, expected):
-        if not inductor_config.cpp_wrapper:
-            self.assertEqual(counter, expected)
-        else:
-            # cpp_wrapper for the CUDA backend runs two passes
-            self.assertEqual(counter, 2 * expected)
-
     @expectedFailureDynamicWrapper
     @patches
     def test_linear_relu(self):
@@ -64,7 +57,7 @@
             torch.randn(1, 16, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
         # It would be nice to assert this got fused into a single kernel, but that
         # only happens if we select a triton template (and not aten).
 
@@ -82,7 +75,7 @@
         )
 
         foo(*inps)
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
     @patches
@@ -112,7 +105,7 @@
             torch.randn(8, 32, device="cuda"),
             torch.randn(32, 8, device="cuda"),
         )
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test__int_mm(self):
@@ -206,7 +199,7 @@
             torch.randn(512, 512, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test_mm_dup_args(self):
@@ -215,7 +208,7 @@
             return torch.mm(a, a)
 
         foo(torch.randn(32, 32, device="cuda"))
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @patches
     def test_mm_dup_args_view(self):
@@ -226,7 +219,7 @@
             return torch.mm(q, k.transpose(0, 1))
 
         foo(torch.randn(64, 64, device="cuda"))
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @skipIfRocm
     @expectedFailureDynamicWrapper
@@ -252,7 +245,7 @@
             torch.randn(34, device="cuda"),
         )
         # Autotuning checks correctness of each version
-        self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
     @skipIfRocm
     @patches
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index b81088e..740c80d 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -51,7 +51,6 @@
 from .fx_passes.pre_grad import pre_grad_passes
 from .graph import GraphLowering
 from .ir import ExternKernelNode
-from .pattern_matcher import clone_graph
 from .utils import get_dtype_size, has_incompatible_cudagraph_ops
 from .virtualized import V
 
@@ -217,79 +216,6 @@
     return make_boxed_func(gm.forward)
 
 
-def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
-    @functools.wraps(inner_compile)
-    def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
-        """
-        Compile into cpp wrapper:
-        For CPU, this is currently done in one pass.
-        For GPU, this is done in two passes: JIT-compile the model with python wrapper code
-        and run it to generate autotuned kernel binaries in the first pass; and then generate
-        cpp wrapper code and compile it to a dynamic library in the second pass.
-        """
-        devices = (
-            {t.device.type for t in gm.parameters()}
-            | {t.device.type for t in gm.buffers()}
-            | {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
-        )
-
-        if "cuda" not in devices:
-            kwargs_patched = {**kwargs, "cpp_wrapper": True}
-            return inner_compile(gm, example_inputs, **kwargs_patched)
-        else:
-            with config.patch(
-                {
-                    "triton.store_cubin": True,
-                }
-            ):
-                # first pass with regular python wrapper code
-                kwargs_patched = {
-                    **kwargs,
-                    "cpp_wrapper": False,
-                }
-                # clone_graph(gm) makes sure no graph modification from the first pass will
-                # leak to the second pass. It does increase memory pressure, but the problem
-                # can be alleviated once we have parameters as FakeTensor.
-
-                compiled = inner_compile(
-                    clone_graph(gm), example_inputs, **kwargs_patched
-                )
-
-                def materialize(x):
-                    if isinstance(x, (torch.SymInt, torch.SymFloat)):
-                        # Need concrete value to run dynamic shapes and tune the result
-                        return x.node.hint
-                    else:
-                        assert not isinstance(x, FakeTensor)
-                        return x
-
-                if tracing_context := torch._guards.TracingContext.try_get():
-                    if tracing_context.output_strides:
-                        tracing_context.output_strides.clear()
-
-                    params_flat = [
-                        param
-                        for param in tracing_context.params_flat  # type: ignore[union-attr]
-                        if param is not None
-                    ]
-                    real_inputs = [
-                        materialize(x) for x in (params_flat + V.real_inputs)
-                    ]
-                else:
-                    real_inputs = [materialize(x) for x in V.real_inputs]
-
-                with torch.utils._python_dispatch._disable_current_modes():
-                    compiled(real_inputs)
-
-                del real_inputs
-
-                # second pass
-                kwargs_patched = {**kwargs, "cpp_wrapper": True}
-                return inner_compile(gm, example_inputs, **kwargs_patched)
-
-    return wrapper
-
-
 def fake_tensor_prop(
     gm: torch.fx.GraphModule,
     example_inputs: List[torch.Tensor],
@@ -592,6 +518,10 @@
     with V.set_fake_mode(fake_mode):
         graph = GraphLowering(
             gm,
+            # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
+            # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
+            # we currently use fake tensors and defake them later.
+            example_inputs=V.real_inputs if is_inference else example_inputs,
             shape_env=shape_env,
             num_static_inputs=num_fixed,
             graph_id=graph_id,
@@ -1033,6 +963,7 @@
                 "cpp_wrapper": False,
                 "triton.autotune_cublasLt": False,
                 "triton.cudagraphs": False,
+                "triton.store_cubin": True,
             }
         ), V.set_real_inputs(example_inputs_):
             inputs_ = example_inputs_
@@ -1055,7 +986,7 @@
             return compile_fx(
                 model_,
                 inputs_,
-                inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
+                inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
                 decompositions=decompositions,
             )
 
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 3c32076..1027fbf 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -15,8 +15,9 @@
 import torch._logging
 import torch.fx
 from torch._decomp import get_decompositions
-from torch._dynamo.utils import dynamo_timed
+from torch._dynamo.utils import defake, dynamo_timed
 from torch._logging import LazyString
+from torch._subclasses.fake_tensor import FakeTensor
 from torch.fx.experimental.sym_node import magic_methods, method_to_operator
 from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
 from torch.utils._mode_utils import no_dispatch
@@ -164,6 +165,7 @@
     def __init__(
         self,
         gm: torch.fx.GraphModule,
+        example_inputs: Optional[List[torch.Tensor]] = None,
         shape_env=None,
         num_static_inputs=None,
         graph_id=None,
@@ -176,6 +178,7 @@
     ):
         super().__init__(gm)
 
+        self.example_inputs = example_inputs
         self.layout_opt = (
             layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
         )
@@ -921,6 +924,46 @@
         assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
         self.wrapper_code = wrapper_code_gen_cls()
 
+    def codegen_with_cpp_wrapper(self):
+        """
+        For CPU, the cpp wrapper codegen is done in one pass.
+        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
+        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
+        generate cpp wrapper code and compile it to a dynamic library in the second pass.
+        """
+        if "cuda" in self.device_types:
+            # first pass
+            self.cpp_wrapper = False
+            compiled = self.compile_to_module().call
+
+            def materialize(x):
+                if isinstance(x, (torch.SymInt, torch.SymFloat)):
+                    # Need concrete value to run dynamic shapes and tune the result
+                    return x.node.hint
+                elif isinstance(x, FakeTensor):
+                    return defake(x)
+                else:
+                    assert isinstance(
+                        x, torch.Tensor
+                    ), "Unknown type when creating real inputs"
+                    return x
+
+            with torch.utils._python_dispatch._disable_current_modes():
+                assert self.example_inputs is not None
+                real_inputs = [materialize(x) for x in self.example_inputs]
+                compiled(real_inputs)
+            del real_inputs
+
+            # second pass
+            # TODO: reuse self.scheduler from the first pass to speed up the second pass
+            self.cpp_wrapper = True
+            self.removed_buffers.clear()
+            self.inplaced_to_remove.clear()
+            return self.codegen()
+        else:
+            # cpu
+            return self.codegen()
+
     def codegen(self):
         from .scheduler import Scheduler
 
@@ -952,7 +995,9 @@
     def compile_to_module(self):
         from .codecache import PyCodeCache
 
-        code, linemap = self.codegen()
+        code, linemap = (
+            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
+        )
         linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
         key, path = PyCodeCache.write(code)
         mod = PyCodeCache.load_by_key_path(
@@ -975,10 +1020,11 @@
         return mod
 
     def compile_to_fn(self):
-        if self.aot_mode and self.cpp_wrapper:
+        if self.aot_mode:
             from .codecache import AotCodeCache
 
-            code, linemap = self.codegen()
+            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
+            code, linemap = self.codegen_with_cpp_wrapper()
             output_code_log.debug("Output code: \n%s", code)
 
             serialized_extern_kernel_nodes = None