[fx] Don't use generators in map_aggregate (#135082)
While the generators avoid a copy, they are slow.
Before:

After:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135082
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 456ab37..f84b23e 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -772,13 +772,16 @@
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
"""
if isinstance(a, tuple):
- t = tuple(map_aggregate(elem, fn) for elem in a)
+ t = tuple([map_aggregate(elem, fn) for elem in a])
# Support NamedTuple (if it has `_fields`) by repacking into original type.
return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type]
elif isinstance(a, list):
- return immutable_list(map_aggregate(elem, fn) for elem in a)
+ return immutable_list([map_aggregate(elem, fn) for elem in a])
elif isinstance(a, dict):
- return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
+ rv = immutable_dict()
+ for k, v in a.items():
+ dict.__setitem__(rv, k, map_aggregate(v, fn))
+ return rv
elif isinstance(a, slice):
return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn))
else: