[fx] Bypass custom __setattr__ in Node.__init__ (#135079)

Before:
![image](https://github.com/user-attachments/assets/5f0a6ae6-6049-44d0-b5f2-a549a23ad97f)

After:
![image](https://github.com/user-attachments/assets/51c9f91b-f8a0-4043-8362-65813feec823)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 62daf41..cb3fbeb 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -816,10 +816,10 @@
     def find_nodes(self, *, op: str, target: Optional['Target'] = None):
         if op == "call_function":
             assert target is not None
-            return dict(self.table[(op, target)]).keys()
+            return [*self.table[(op, target)].keys()]
 
         if target is None:
-            return dict(self.table[(op, None)]).keys()
+            return [*self.table[(op, None)].keys()]
 
         # op is call_method, get_attr, call_module
         return [node for node in self.table[(op, None)].keys() if node.target == target]
diff --git a/torch/fx/node.py b/torch/fx/node.py
index f84b23e..323f954 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -168,6 +168,16 @@
     """
     _args: Tuple['Argument', ...]
     _kwargs: Dict[str, 'Argument']
+    graph: 'Graph'
+    name: str
+    op: str
+    target: 'Target'
+    _input_nodes: Dict['Node', None]
+    users: Dict['Node', None]
+    type: Optional[Any]
+    _sort_key: Any
+    _repr_fn: Optional[Callable[['Node'], str]]
+    meta: Dict[str, Any]
 
     @compatibility(is_backward_compatible=True)
     def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target',
@@ -199,11 +209,7 @@
                 annotation of values in the generated code or for other types
                 of analyses.
         """
-        super().__init__()
-        self.graph = graph
-        self.name = name  # unique name of value being created
         assert op in _legal_ops
-        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
         if op == 'call_function':
             if not callable(target):
                 raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
@@ -212,13 +218,22 @@
             if not isinstance(target, str):
                 raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} '
                                  'but a str is expected')
-        self.target = target  # for method/module/function, the name of the method/module/function/attr
+        super().__init__()
+
+        # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects
+        assign = object.__setattr__
+
+        assign(self, "graph", graph)
+        assign(self, "name", name)  # unique name of value being created
+        assign(self, "op", op)  # the kind of operation = placeholder|call_method|call_module|call_function|get_attr
+
+        assign(self, "target", target)  # for method/module/function, the name of the method/module/function/attr
         # being invoked, e.g add, layer1, or torch.add
 
         # All `Node`-valued inputs. Key is the Node, value is don't-care.
         # The public API for this is `all_input_nodes`, this private attribute
         # should not be accessed directly.
-        self._input_nodes : Dict[Node, None] = {}
+        assign(self, "_input_nodes", {})
         self.__update_args_kwargs(args, kwargs)
 
         # All of the nodes that use the value produced by this Node
@@ -226,7 +241,8 @@
         # would appear once here, but represents two uses.
         #
         # Is a dict to act as an "ordered set". Keys are significant, value dont-care
-        self.users : Dict[Node, None] = {}
+        assign(self, "users", {})
+
         # Type expression representing the output value of this node.
         # This should contain the same class of Type objects that would appear
         # as type annotations for function inputs/outputs.
@@ -237,15 +253,15 @@
         # generated function return type. (Note this is a special case. ``return``
         # does not produce a value, it's more of a notation. Thus, this value
         # describes the type of args[0] in the ``return`` node.
-        self.type : Optional[Any] = return_type
-        self._sort_key: Any = ()
+        assign(self, "type", return_type)
+        assign(self, "_sort_key", ())
 
         # If set, use this fn to print this node
-        self._repr_fn : Optional[Callable[[Node], str]] = None
+        assign(self, "_repr_fn", None)
 
         # Dictionary to store metadata passes need to do their
         # transformations. This metadata is preserved across node copies
-        self.meta : Dict[str, Any] = {}
+        assign(self, "meta", {})
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -490,14 +506,14 @@
         # Clear prior users and input_nodes
         for old_use in self._input_nodes.keys():
             old_use.users.pop(self)
-        self._input_nodes = {}
+        object.__setattr__(self, "_input_nodes", {})  # bypass Node.__setattr__
 
         # We do three things in a single pass of the args
         # - Normalize list->immutable_list, dict->immutable_dict, etc
         # - Populate self._input_nodes
         # - Populate arg.users[self] for each arg
-        self._args = map_aggregate(new_args, update_users_and_input_nodes)  # type: ignore[assignment]
-        self._kwargs = map_aggregate(new_kwargs, update_users_and_input_nodes)  # type: ignore[assignment]
+        object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes))
+        object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes))
 
     def __repr__(self) -> str:
         if self._repr_fn:
@@ -740,23 +756,24 @@
         self.graph._graph_namespace._rename_object(self, name)
 
     def __setattr__(self, name: str, value: Any) -> None:
-        if name == 'name' and hasattr(self, "name"):
-            m = self.graph.owning_module
-            if getattr(m, "_replace_hook", None):
-                assert isinstance(value, str)
-                for user in self.users:
-                    m._replace_hook(old=self, new=value, user=user)
-        update = False
-        if (
-                hasattr(self, name) and
-                hasattr(self.graph, "_find_nodes_lookup_table") and
-                self in self.graph._find_nodes_lookup_table
-        ):
-            update = True
-            self.graph._find_nodes_lookup_table.remove(self)
+        if name in _setattr_custom_handling:
+            if name == "name":
+                m = self.graph.owning_module
+                if getattr(m, "_replace_hook", None):
+                    assert isinstance(value, str)
+                    for user in self.users:
+                        m._replace_hook(old=self, new=value, user=user)
+            elif name in ("op", "target"):
+                table = getattr(self.graph, "_find_nodes_lookup_table", None)
+                if table and self in table:
+                    self.graph._find_nodes_lookup_table.remove(self)
+                    object.__setattr__(self, name, value)
+                    self.graph._find_nodes_lookup_table.insert(self)
+                    return
+
         object.__setattr__(self, name, value)
-        if update:
-            self.graph._find_nodes_lookup_table.insert(self)
+
+_setattr_custom_handling = dict.fromkeys(["name", "op", "target"])
 
 @compatibility(is_backward_compatible=True)
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: