Revert "[aotinductor] Solves a problem where a tensor is returned more than once (#112177)"
This reverts commit a91baaf314999abaaf93260f87b1ee109bb36541.
Reverted https://github.com/pytorch/pytorch/pull/112177 on behalf of https://github.com/PaliC due to breaking internal tests (refer to internal diff) ([comment](https://github.com/pytorch/pytorch/pull/112177#issuecomment-1794153272))
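For context, the reverted change targeted modules whose forward returns the same tensor more than once. A minimal sketch of that pattern, mirroring the test_repeat_output test removed in the diff below:

    import torch

    class RepeatOutput(torch.nn.Module):
        def forward(self, x):
            y = torch.sin(x)
            # The same tensor is returned twice, producing a duplicated
            # graph output, the case the reverted PR tried to handle.
            return y, y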
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index aa53ce6..bf20379 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1011,18 +1011,6 @@
x = torch.randn(5, device=self.device)
self.check_model(Model(self.device), (x,))
- def test_repeat_output(self):
- class Model(torch.nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, x):
- y = torch.sin(x)
- return y, y
-
- example_inputs = (torch.randn(3, 10, device=self.device),)
- self.check_model(Model(), example_inputs)
-
class AOTInductorTestABICompatibleCpu(TestCase):
device = "cpu"
@@ -1048,8 +1036,6 @@
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
- # There is a double-free issue which will be fixed in another PR
- "test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_sdpa": TestFailure(("abi_compatible_cpu",)),
"test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
"test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
@@ -1072,8 +1058,6 @@
{
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
- # There is a double-free issue which will be fixed in another PR
- "test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
},
)
diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py
index 5f8780d..cd99efc 100644
--- a/torch/_export/exported_program.py
+++ b/torch/_export/exported_program.py
@@ -71,9 +71,7 @@
# Step 2: Find all the buffers that were mutated and update them
if node.op == "output":
user_output_nodes = []
- # In the case that the same node is returned multiple times,
- # node.all_input_nodes will only iterate that node once
- for return_node in pytree.tree_flatten(node.args)[0]:
+ for return_node in node.all_input_nodes:
return_node_name = return_node.name
# we found a param/buffer mutation
if return_node_name in buffers_to_mutate:
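Note on the exported_program.py hunk: the revert restores iteration over node.all_input_nodes, whose behavior differs from flattening node.args when the same node is returned more than once. A minimal sketch of that difference using torch.fx tracing (the function f below is hypothetical, chosen only to produce a duplicated output):

    import torch
    import torch.fx
    from torch.utils import _pytree as pytree

    def f(x):
        y = torch.sin(x)
        return y, y  # duplicated graph output

    gm = torch.fx.symbolic_trace(f)
    output_node = next(n for n in gm.graph.nodes if n.op == "output")

    # all_input_nodes deduplicates: the sin node appears only once.
    print(output_node.all_input_nodes)
    # Flattening node.args preserves duplicates: sin appears twice.
    print(pytree.tree_flatten(output_node.args)[0])

As the removed comment noted, all_input_nodes visits a repeated node only once, which is why the reverted PR had switched to pytree.tree_flatten(node.args)[0]; the revert puts the original deduplicating iteration back.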