Change direct uses of MutationOutput to `mark_node_as_mutating` (#127149)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127149
Approved by: https://github.com/oulgen
ghstack dependencies: #127148
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index b51aa49..807c2c0 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -4749,18 +4749,19 @@
return [i.get_name() for i in self.mutable_args]
-def mark_node_as_mutating(cur_buffer, *mutated_ops: IRNode):
+def mark_node_as_mutating(cur_buffer, *mutated_nodes: IRNode):
"""
- Allows ops in mutated_ops to be marked as being mutated as well as
+ Allows ops in mutated_nodes to be marked as being mutated as well as
indicates to the scheduler that these ops depend on cur_buffer.
+
+ NB: Use this instead of directly constructing MutationOutput
"""
- for op in mutated_ops:
+ for node in mutated_nodes:
assert isinstance(
- op, IRNode
- ), f"{op} op is type {type(op)} and is not an IRNode"
- V.graph.mark_buffer_mutated(op.get_name())
- assert hasattr(op, "layout")
- MutationOutput(op.layout, op, cur_buffer)
+ node, IRNode
+ ), f"{node} node is type {type(node)} and is not an IRNode"
+ V.graph.mark_buffer_mutated(node.get_name())
+ MutationOutput(node.get_layout(), node, cur_buffer)
class MutationOutput(ExternKernel):
@@ -4768,6 +4769,7 @@
return [self.inputs[0].get_name()]
def __init__(self, layout, input, parent):
+ # NB: Do not directly construct this - use `mark_node_as_mutating`
super().__init__(None, layout, [input, parent], ())
self.name = V.graph.register_buffer(self)
@@ -7424,7 +7426,7 @@
@property
def layout(self):
- return self.data.layout # type: ignore[attr-defined]
+ return self.data.get_layout()
def get_layout(self):
return self.layout
@@ -8245,12 +8247,7 @@
packed.cpp_kernel_name = cpp_kernel_name
packed.python_kernel_name = python_kernel_name
- def mark_mutation(x):
- if isinstance(x.data, BaseView):
- x = x.data.unwrap_view()
- MutationOutput(x.layout, x, packed)
-
- pytree.tree_map(lambda inp: mark_mutation(inp), inputs)
+ mark_node_as_mutating(packed, *pytree.tree_leaves(inputs))
# NOTE: [Out-of-Place Collective Safety]
# Between the initiation and completion of an out-of-place collective:
@@ -8366,9 +8363,8 @@
non_tensor_args,
unflatten_args,
)
- if isinstance(inp.data, BaseView):
- inp = inp.data.unwrap_view()
- MutationOutput(inp.layout, inp, packed)
+
+ mark_node_as_mutating(packed, inp)
def get_read_writes(self):
read_writes = super().get_read_writes()