[dynamo][compile-time] Remove unnecessary tree_map_only (#121052)

Reduces the torch.compile(backend="eager") compile time for this code by 1-2 seconds.

~~~
def fn(x):
    for _ in range(10000):
        # x = torch.sin(x)
        x = torch.ops.aten.sin(x)
        # x = sin(x)

    return x
~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121052
Approved by: https://github.com/jansel
ghstack dependencies: #121053
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b228efd..43c3c18 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1593,14 +1593,16 @@
 def get_fake_values_from_nodes(tx, nodes, allow_non_graph_fake):
     def visit(n: torch.fx.Node):
         if n.op == "call_function" and "example_value" not in n.meta:
+            # fake tensor validity is checked inside get_fake_value using
+            # ensure_graph_fake
             return get_fake_value(n, tx, allow_non_graph_fake)
 
-        return n.meta["example_value"]
+        out = n.meta["example_value"]
+        if not allow_non_graph_fake and isinstance(out, torch.Tensor):
+            return ensure_graph_fake(out, tx)
+        return out
 
-    args_kwargs = torch.fx.node.map_arg(nodes, visit)
-    return tree_map_only(
-        torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), args_kwargs
-    )
+    return torch.fx.node.map_arg(nodes, visit)
 
 
 def get_fake_value(node, tx, allow_non_graph_fake=False):