Revert "DCE inference graphs too (#97275)"
This reverts commit aa3a57b80d39fc803f3f85e6a84a49926d99b4ba.
Reverted https://github.com/pytorch/pytorch/pull/97275 on behalf of https://github.com/ezyang because it broke a test
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index e7704d6..512a046 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -1258,16 +1258,12 @@
with enable_python_dispatcher():
fw_module = make_fx(trace_fn, decomposition_table=aot_config.decompositions)(*flat_args)
- # As long as we opted to remove input mutations, then
- # there should be *NO* mutating ops in the graph at this point.
- copy_count = assert_functional_graph(fw_module.graph, allow_input_mutations=aot_config.keep_inference_input_mutations)
-
- fw_module.graph.eliminate_dead_code()
- fw_module.recompile()
-
- copy_count2 = assert_functional_graph(fw_module.graph, allow_input_mutations=aot_config.keep_inference_input_mutations)
-
- assert copy_count == copy_count2
+ if not aot_config.keep_inference_input_mutations:
+ # As long as we opted to remove input mutations, then
+ # there should be *NO* mutating ops in the graph at this point.
+ assert_functional_graph(fw_module.graph)
+ fw_module.graph.eliminate_dead_code()
+ fw_module.recompile()
aot_graphs_log.info(format_graph_code(f"====== Forward graph {aot_config.aot_id} ======\n", fw_module))
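
Note on the flow this hunk touches: the forward graph is traced with make_fx and then dead-code eliminated. A minimal runnable sketch of that trace-then-DCE step, with a hypothetical toy function `fn` standing in for the real `trace_fn`/`flat_args`:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical toy function: the first line produces a value nothing
# consumes, so it becomes a dead node in the traced graph.
def fn(x):
    _unused = x + 1
    return torch.sin(x) * 2

fw_module = make_fx(fn)(torch.randn(3))
print(len(list(fw_module.graph.nodes)))   # includes the dead aten.add node

fw_module.graph.eliminate_dead_code()     # drops nodes whose outputs are unused
fw_module.recompile()
print(len(list(fw_module.graph.nodes)))   # the dead node is gone
```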
@@ -1288,27 +1284,11 @@
return compiled_fn
-# Returns the number of detected copy_
-def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool = False) -> int:
- placeholders = set()
- copy_count = 0
- # NB: It would also be nice to verify that the mutations all happen at the
- # end, but we also do some administrative views after mutations so this
- # isn't actually true. (TODO: Could this cause problems for Inductor?)
+def assert_functional_graph(fx_g: torch.fx.Graph):
for n in fx_g.nodes:
- if n.op == "placeholder":
- placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
- if n.target is aten.copy_.default and allow_input_mutations:
- suffix = True
- # Can only copy_ into an input, and can only do so once
- assert n.args[0] in placeholders
- placeholders.remove(n.args[0])
- copy_count += 1
- else:
- assert not n.target._schema.is_mutable, \
- f'aot_autograd expected to have an entirely functional graph, but found {n.format_node()}'
- return copy_count
+ assert not n.target._schema.is_mutable, \
+ f'aot_autograd expected to have an entirely functional graph, but found {n.format_node()}'
@contextmanager
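
Standalone, the restored `assert_functional_graph` amounts to the check sketched below; the helper name `check_functional` and the toy functions are illustrative, not part of the diff:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Illustrative version of the restored check: every OpOverload in the graph
# must have a non-mutating schema, i.e. the graph is purely functional.
def check_functional(fx_g: torch.fx.Graph):
    for n in fx_g.nodes:
        if isinstance(n.target, torch._ops.OpOverload):
            assert not n.target._schema.is_mutable, (
                f"expected an entirely functional graph, but found {n.format_node()}"
            )

def functional_fn(x):
    return torch.add(x, 1)   # out-of-place op, passes the check

check_functional(make_fx(functional_fn)(torch.randn(2)).graph)

def mutating_fn(x):
    return x.add_(1)         # in-place op; make_fx alone does not functionalize
                             # it, so aten.add_ stays in the graph

# check_functional(make_fx(mutating_fn)(torch.randn(2)).graph)  # AssertionError
```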
diff --git a/torch/fx/node.py b/torch/fx/node.py
index fb847df..6745667 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -32,7 +32,6 @@
_side_effectful_functions: Set[Callable] = {
torch._assert,
- _ops.aten.copy_.default,
_ops.profiler._record_function_enter,
_ops.profiler._record_function_enter_new,
_ops.profiler._record_function_exit}
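
The effect of dropping `aten.copy_.default` from `_side_effectful_functions` is that FX dead-code elimination no longer treats unused `copy_` nodes as protected. A minimal sketch of that mechanism, using `aten.add.Tensor` as a stand-in for any op that is not marked side-effectful:

```python
import torch
import torch.fx

# FX dead-code elimination keeps a node only if something consumes its output
# or the node counts as impure (e.g. its target is in _side_effectful_functions).
g = torch.fx.Graph()
x = g.placeholder("x")
g.call_function(torch.ops.aten.add.Tensor, (x, 1))   # result never used -> dead
g.output(x)

print(any(n.target is torch.ops.aten.add.Tensor for n in g.nodes))  # True
g.eliminate_dead_code()
print(any(n.target is torch.ops.aten.add.Tensor for n in g.nodes))  # False

# With aten.copy_.default removed from _side_effectful_functions, an unused
# copy_ node is treated the same way and can be eliminated by DCE.
```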