[fx] Bypass custom __setattr__ in Node.__init__ (#135079)
Before:

After:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 62daf41..cb3fbeb 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -816,10 +816,10 @@
def find_nodes(self, *, op: str, target: Optional['Target'] = None):
if op == "call_function":
assert target is not None
- return dict(self.table[(op, target)]).keys()
+ return [*self.table[(op, target)].keys()]
if target is None:
- return dict(self.table[(op, None)]).keys()
+ return [*self.table[(op, None)].keys()]
# op is call_method, get_attr, call_module
return [node for node in self.table[(op, None)].keys() if node.target == target]
diff --git a/torch/fx/node.py b/torch/fx/node.py
index f84b23e..323f954 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -168,6 +168,16 @@
"""
_args: Tuple['Argument', ...]
_kwargs: Dict[str, 'Argument']
+ graph: 'Graph'
+ name: str
+ op: str
+ target: 'Target'
+ _input_nodes: Dict['Node', None]
+ users: Dict['Node', None]
+ type: Optional[Any]
+ _sort_key: Any
+ _repr_fn: Optional[Callable[['Node'], str]]
+ meta: Dict[str, Any]
@compatibility(is_backward_compatible=True)
def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
@@ -199,11 +209,7 @@
annotation of values in the generated code or for other types
of analyses.
"""
- super().__init__()
- self.graph = graph
- self.name = name # unique name of value being created
assert op in _legal_ops
- self.op = op # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
if op == 'call_function':
if not callable(target):
raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
@@ -212,13 +218,22 @@
if not isinstance(target, str):
raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
'but a str is expected')
- self.target = target # for method/module/function, the name of the method/module/function/attr
+ super().__init__()
+
+ # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
+ assign = object.__setattr__
+
+ assign(self, "graph", graph)
+ assign(self, "name", name) # unique name of value being created
+ assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
+
+ assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr
# being invoked, e.g add, layer1, or torch.add
# All `Node`-valued inputs. Key is the Node, value is don't-care.
# The public API for this is `all_input_nodes`, this private attribute
# should not be accessed directly.
- self._input_nodes : Dict[Node, None] = {}
+ assign(self, "_input_nodes", {})
self.__update_args_kwargs(args, kwargs)
# All of the nodes that use the value produced by this Node
@@ -226,7 +241,8 @@
# would appear once here, but represents two uses.
#
# Is a dict to act as an "ordered set". Keys are significant, value dont-care
- self.users : Dict[Node, None] = {}
+ assign(self, "users", {})
+
# Type expression representing the output value of this node.
# This should contain the same class of Type objects that would appear
# as type annotations for function inputs/outputs.
@@ -237,15 +253,15 @@
# generated function return type. (Note this is a special case. ``return``
# does not produce a value, it's more of a notation. Thus, this value
# describes the type of args[0] in the ``return`` node.
- self.type : Optional[Any] = return_type
- self._sort_key: Any = ()
+ assign(self, "type", return_type)
+ assign(self, "_sort_key", ())
# If set, use this fn to print this node
- self._repr_fn : Optional[Callable[[Node], str]] = None
+ assign(self, "_repr_fn", None)
# Dictionary to store metadata passes need to do their
# transformations. This metadata is preserved across node copies
- self.meta : Dict[str, Any] = {}
+ assign(self, "meta", {})
def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
@@ -490,14 +506,14 @@
# Clear prior users and input_nodes
for old_use in self._input_nodes.keys():
old_use.users.pop(self)
- self._input_nodes = {}
+ object.__setattr__(self, "_input_nodes", {}) # bypass Node.__setattr__
# We do three things in a single pass of the args
# - Normalize list->immutable_list, dict->immutable_dict, etc
# - Populate self._input_nodes
# - Populate arg.users[self] for each arg
- self._args = map_aggregate(new_args, update_users_and_input_nodes) # type: ignore[assignment]
- self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes) # type: ignore[assignment]
+ object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes))
+ object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes))
def __repr__(self) -> str:
if self._repr_fn:
@@ -740,23 +756,24 @@
self.graph._graph_namespace._rename_object(self, name)
def __setattr__(self, name: str, value: Any) -> None:
- if name == 'name' and hasattr(self, "name"):
- m = self.graph.owning_module
- if getattr(m, "_replace_hook", None):
- assert isinstance(value, str)
- for user in self.users:
- m._replace_hook(old=self, new=value, user=user)
- update = False
- if (
- hasattr(self, name) and
- hasattr(self.graph, "_find_nodes_lookup_table") and
- self in self.graph._find_nodes_lookup_table
- ):
- update = True
- self.graph._find_nodes_lookup_table.remove(self)
+ if name in _setattr_custom_handling:
+ if name == "name":
+ m = self.graph.owning_module
+ if getattr(m, "_replace_hook", None):
+ assert isinstance(value, str)
+ for user in self.users:
+ m._replace_hook(old=self, new=value, user=user)
+ elif name in ("op", "target"):
+ table = getattr(self.graph, "_find_nodes_lookup_table", None)
+ if table and self in table:
+ self.graph._find_nodes_lookup_table.remove(self)
+ object.__setattr__(self, name, value)
+ self.graph._find_nodes_lookup_table.insert(self)
+ return
+
object.__setattr__(self, name, value)
- if update:
- self.graph._find_nodes_lookup_table.insert(self)
+
+_setattr_custom_handling = dict.fromkeys(["name", "op", "target"])
@compatibility(is_backward_compatible=True)
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: