[inductor] Fix an AOTInductor missing output issue (#105496)

Summary: In AOT mode, output tensors are pre-allocated by the caller. When Inductor reuses an existing buffer for a graph output instead of writing directly into the passed-in output tensor, the generated wrapper must explicitly copy the reused buffer into that output tensor; otherwise the caller sees a missing output.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105496
Approved by: https://github.com/jansel
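
For context, the AOT runtime's calling convention is that the caller pre-allocates the output tensors and the compiled wrapper fills them in place. The snippet below is a minimal, hand-written libtorch sketch (run_model and buf0 are illustrative names, not the actual generated wrapper) of why the explicit copy is needed when the final result lives in a reused intermediate buffer rather than being written into outputs[idx] directly:

    #include <torch/torch.h>
    #include <iostream>
    #include <vector>

    // Caller-provided outputs are filled in place, mirroring the AOT calling convention.
    void run_model(const std::vector<at::Tensor>& inputs,
                   std::vector<at::Tensor>& outputs) {
      // buf0 stands in for an intermediate buffer that the generated code
      // reuses for the final result instead of binding it to outputs[0].
      at::Tensor buf0 = at::cos(at::mm(at::sin(inputs[0]), inputs[1]));
      // Without this explicit copy, outputs[0] keeps its original contents
      // and the caller observes a "missing" output.
      outputs[0].copy_(buf0);
    }

    int main() {
      std::vector<at::Tensor> inputs = {torch::randn({10, 10}),
                                        torch::randn({10, 10})};
      std::vector<at::Tensor> outputs = {torch::empty({10, 10})};
      run_model(inputs, outputs);
      std::cout << outputs[0].sum() << std::endl;
      return 0;
    }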
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
new file mode 100644
index 0000000..51ee123
--- /dev/null
+++ b/test/inductor/test_aot_inductor.py
@@ -0,0 +1,116 @@
+# Owner(s): ["module: inductor"]
+
+
+import torch
+
+import torch._export
+import torch._inductor
+import torch.utils.cpp_extension
+
+import torch.fx._pytree as fx_pytree
+
+from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
+
+from torch.testing._internal.inductor_utils import HAS_CUDA
+from torch.utils import _pytree as pytree
+
+aten = torch.ops.aten
+
+
+class AOTInductorModelCache:
+    cache = dict()
+
+    @classmethod
+    def load(cls, model, example_inputs, example_outputs):
+        key = id(model)
+        if key not in cls.cache:
+            # AOTInductorModel relies on the caller to pass in output_tensors,
+            # so we need to explicitly allocate output tensors here.
+            output_tensors = []
+            example_outputs, output_spec = pytree.tree_flatten(example_outputs)
+            for output in example_outputs:
+                output_tensors.append(torch.empty_like(output))
+
+            # The exact API is subject to change
+            exported = torch._export.export(model, example_inputs)
+            param_buffer_values = list(exported.state_dict.values())
+            flat_example_inputs = fx_pytree.tree_flatten_spec(
+                example_inputs, exported.call_spec.in_spec
+            )
+            all_args = (*param_buffer_values, *flat_example_inputs)
+            # AOT compile into a .so
+            so_path = torch._inductor.aot_compile(exported.graph_module, all_args)
+
+            # Wrap the compiled .so in a small C++ extension so the test can call it
+            source = """
+            #include <torch/csrc/inductor/aot_inductor_model.h>
+
+            torch::aot_inductor::AOTInductorModel model;
+
+            void run(
+                    const std::vector<at::Tensor>& input_tensors,
+                    std::vector<at::Tensor>& output_tensors) {
+                model.run(input_tensors, output_tensors, at::cuda::getCurrentCUDAStream());
+            }
+            """
+            module = torch.utils.cpp_extension.load_inline(
+                name="aot_inductor",
+                cpp_sources=[source],
+                functions=["run"],
+                extra_ldflags=[so_path],
+                with_cuda=True,
+            ).run
+
+            value = {
+                "module": module,
+                "exported": exported,
+                "output_tensors": output_tensors,
+                "output_spec": output_spec,
+            }
+            cls.cache[key] = value
+
+        return (
+            cls.cache[key]["module"],
+            cls.cache[key]["exported"],
+            cls.cache[key]["output_tensors"],
+            cls.cache[key]["output_spec"],
+        )
+
+
+class AotInductorTests(TestCase):
+    def test_missing_output(self):
+        class Repro(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                a = torch.sin(x)
+                b = torch.mm(a, y)
+                c = torch.cos(b)
+                return c
+
+        model = Repro()
+        example_inputs = [
+            torch.randn(10, 10, device="cuda"),
+            torch.randn(10, 10, device="cuda"),
+        ]
+        expected = model(*example_inputs)
+
+        optimized, exported, output_tensors, output_spec = AOTInductorModelCache.load(
+            model, example_inputs, expected
+        )
+        param_buffer_values = list(exported.state_dict.values())
+        flat_example_inputs = fx_pytree.tree_flatten_spec(
+            example_inputs, exported.call_spec.in_spec
+        )
+        all_args = (*param_buffer_values, *flat_example_inputs)
+        optimized(all_args, output_tensors)
+        actual = pytree.tree_unflatten(output_tensors, output_spec)
+
+        self.assertTrue(torch.allclose(actual, expected))
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    if HAS_CUDA and not TEST_WITH_ROCM:
+        run_tests(needs="filelock")
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 4cf1bc9..0357fb5 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1058,6 +1058,12 @@
     def generate_return(self, output_refs):
         # Output tensors are allocated by the AOT runtime.
         if V.graph.aot_mode:
+            for idx, output in enumerate(V.graph.graph_outputs):
+                if output.get_name() in self.reuses:
+                    # buffer was reused, so we need to explicitly copy it to the output tensor
+                    self.wrapper_call.writeline(
+                        f"outputs[{idx}].copy_({output.get_name()});"
+                    )
             self.wrapper_call.writeline("\n}")
         else:
             self.wrapper_call.writeline(f"return {{{', '.join(output_refs)}}};\n}}")
@@ -1194,31 +1200,34 @@
         cpp_device = self.codegen_device(device)
         return f"at::TensorOptions({cpp_device}).dtype({DTYPE_TO_ATEN[dtype]}))"
 
+    def codegen_allocation(self, buffer):
+        name = buffer.get_name()
+        # In AOT mode, output tensors are passed in by the caller
+        if V.graph.aot_mode and name in set(V.graph.get_output_names()):
+            output_idx = None
+            for idx, output in enumerate(V.graph.graph_outputs):
+                if hasattr(output, "get_name") and name == output.get_name():
+                    output_idx = idx
+                    break
+
+            assert output_idx is not None, "Unknown output index"
+            self.writeline(f"auto {name} = outputs[{output_idx}];")
+            return
+
+        super().codegen_allocation(buffer)
+
     def make_buffer_allocation(self, buffer):
-        output_idx = None
-        for idx, output in enumerate(V.graph.graph_outputs):
-            if isinstance(output, (ir.NoneAsConstantBuffer, ir.ShapeAsConstantBuffer)):
-                continue
-            if buffer == output.data:
-                output_idx = idx
-                break
-        if output_idx is not None and V.graph.aot_mode:
-            # In aot_mode, output buffers are managed by the AOT runtime.
-            return (
-                f"at::Tensor {buffer.get_name()} = outputs[{output_idx}]{self.ending}"
-            )
-        else:
-            # TODO: map layout here.
-            device = buffer.get_device()
-            dtype = buffer.get_dtype()
-            shape = tuple(buffer.get_size())
-            stride = tuple(buffer.get_stride())
-            return (
-                f"{self.declare}{buffer.get_name()} = {self.namespace}empty_strided("
-                f"{self.codegen_shape_tuple(shape)}, "
-                f"{self.codegen_shape_tuple(stride)}, "
-                f"{self.codegen_tensor_option(device, dtype)}{self.ending}"
-            )
+        # TODO: map layout here.
+        device = buffer.get_device()
+        dtype = buffer.get_dtype()
+        shape = tuple(buffer.get_size())
+        stride = tuple(buffer.get_stride())
+        return (
+            f"{self.declare}{buffer.get_name()} = {self.namespace}empty_strided("
+            f"{self.codegen_shape_tuple(shape)}, "
+            f"{self.codegen_shape_tuple(stride)}, "
+            f"{self.codegen_tensor_option(device, dtype)}{self.ending}"
+        )
 
     def generate_extern_kernel_alloc_and_find_schema_if_needed(
         self,