Don't iterate over graph when adding graph input (#89084)
helps with https://github.com/pytorch/torchdynamo/issues/1803
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89084
Approved by: https://github.com/jansel
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index ee50795..4578fb9 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -110,6 +110,10 @@
self.tensor_id_to_sym_shape_ref = {}
self.intermediary_symbols = {}
+ # Enables creating unique node names by tracking
+ # all current placeholder node names
+ self.name_to_input = collections.OrderedDict()
+
@property
def output(self):
return self
@@ -147,6 +151,7 @@
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)
+ self.name_to_input.pop(node.name, None)
def count_calls(self):
return count_calls(self.graph)
@@ -162,22 +167,22 @@
return obj
def create_graph_input(self, name, type_expr=None):
- placeholders = [n for n in self.graph.nodes if n.op == "placeholder"]
-
# unique
- used_names = {n.target for n in placeholders}
- if name in used_names:
+ if name in self.name_to_input:
for i in itertools.count():
- if f"{name}_{i}" not in used_names:
+ if f"{name}_{i}" not in self.name_to_input:
name = f"{name}_{i}"
break
- if placeholders:
- ctx = self.graph.inserting_after(placeholders[-1])
+ if self.name_to_input:
+ prev_name = next(reversed(self.name_to_input))
+ ctx = self.graph.inserting_after(self.name_to_input[prev_name])
else:
ctx = self.graph.inserting_before(None)
with ctx:
- return self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
+ proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
+ self.name_to_input[name] = proxy.node
+ return proxy
def new_var(self, name="tmp"):
existing = set(self.code_options["co_varnames"])
@@ -490,6 +495,7 @@
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)
+ self.name_to_input.pop(node.name, None)
self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
@@ -525,6 +531,7 @@
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()
+ self.name_to_input.clear()
def create_proxy(
self,