Change direct uses of MutationOutput to `mark_node_as_mutating` (#127149)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127149
Approved by: https://github.com/oulgen
ghstack dependencies: #127148
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index b51aa49..807c2c0 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -4749,18 +4749,19 @@
         return [i.get_name() for i in self.mutable_args]
 
 
-def mark_node_as_mutating(cur_buffer, *mutated_ops: IRNode):
+def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode):
     """
-    Allows ops in mutated_ops to be marked as being mutated as well as
+    Allows ops in mutated_nodes to be marked as being mutated as well as
     indicates to the scheduler that these ops depend on cur_buffer.
+
+    NB: Use this instead of directly constructing MutationOutput
     """
-    for op in mutated_ops:
+    for node in mutated_nodes:
         assert isinstance(
-            op, IRNode
-        ), f"{op} op is type {type(op)} and is not an IRNode"
-        V.graph.mark_buffer_mutated(op.get_name())
-        assert hasattr(op, "layout")
-        MutationOutput(op.layout, op, cur_buffer)
+            node, IRNode
+        ), f"{node} node is type {type(node)} and is not an IRNode"
+        V.graph.mark_buffer_mutated(node.get_name())
+        MutationOutput(node.get_layout(), node, cur_buffer)
 
 
 class MutationOutput(ExternKernel):
@@ -4768,6 +4769,7 @@
         return [self.inputs[0].get_name()]
 
     def __init__(self, layout, input, parent):
+        # NB: Do not directly construct this - use `mark_node_as_mutating`
         super().__init__(None, layout, [input, parent], ())
         self.name = V.graph.register_buffer(self)
 
@@ -7424,7 +7426,7 @@
 
     @property
     def layout(self):
-        return self.data.layout  # type: ignore[attr-defined]
+        return self.data.get_layout()
 
     def get_layout(self):
         return self.layout
@@ -8245,12 +8247,7 @@
         packed.cpp_kernel_name = cpp_kernel_name
         packed.python_kernel_name = python_kernel_name
 
-        def mark_mutation(x):
-            if isinstance(x.data, BaseView):
-                x = x.data.unwrap_view()
-            MutationOutput(x.layout, x, packed)
-
-        pytree.tree_map(lambda inp: mark_mutation(inp), inputs)
+        mark_node_as_mutating(packed, *pytree.tree_leaves(inputs))
 
     # NOTE: [Out-of-Place Collective Safety]
     # Between the initiation and completion of an out-of-place collective:
@@ -8366,9 +8363,8 @@
             non_tensor_args,
             unflatten_args,
         )
-        if isinstance(inp.data, BaseView):
-            inp = inp.data.unwrap_view()
-        MutationOutput(inp.layout, inp, packed)
+
+        mark_node_as_mutating(packed, inp)
 
     def get_read_writes(self):
         read_writes = super().get_read_writes()