Unlift mutated buffers (#107643)
In this PR, we extend the ExportedProgram.module() functionality by also unlifting mutated buffers. We only need to handle top-level buffers, since buffer mutation is not allowed inside HigherOrderOps.
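A minimal sketch of the user-visible behavior this enables, mirroring the tests below (the Counter module and its buffer name are illustrative):

import torch

class Counter(torch.nn.Module):  # illustrative module, same shape as the tests
    def __init__(self):
        super().__init__()
        self.register_buffer("buf", torch.zeros(1))

    def forward(self, x):
        self.buf.add_(1)  # in-place buffer mutation
        return x.sum() + self.buf.sum()

exported = torch._export.export(Counter(), (torch.ones(5, 5),))
stateful_gm = exported.module()  # now also unlifts the mutated buffer
stateful_gm(torch.ones(5, 5))
# The unlifted module carries the mutation: buf has been incremented to 1,
# matching eager-mode behavior.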
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107643
Approved by: https://github.com/avikchaudhuri
diff --git a/test/export/test_export.py b/test/export/test_export.py
index ea9a670..c80ae48 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -728,6 +728,126 @@
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
).run(ep.graph_module.code)
+ def test_to_module_with_mutated_buffer(self):
+
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buf", torch.zeros(1))
+
+ def forward(self, x):
+ self.buf.add_(1)
+ return x.sum() + self.buf.sum()
+
+ exported = torch._export.export(Foo(), (torch.ones(5, 5),))
+ stateful_gm = exported.module()
+ export_return_val = stateful_gm(torch.ones(5, 5))
+ eager = Foo()
+ eager_return_val = eager(torch.ones(5, 5))
+ self.assertTrue(torch.allclose(eager_return_val, export_return_val))
+
+ for name, buffer in stateful_gm.named_buffers():
+ self.assertTrue(torch.allclose(torch.ones(1), buffer))
+
+ changed = stateful_gm.graph.eliminate_dead_code()
+ self.assertFalse(changed)
+ self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
+
+ for name, buffer in stateful_gm.named_buffers():
+ self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
+
+ def test_to_module_with_mutated_buffer_multiple(self):
+
+ class Bar(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buf", torch.ones(1))
+
+ def forward(self, x):
+ self.buf.add_(1)
+ return x.sum() + self.buf.sum()
+
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buf", torch.zeros(1))
+ self.bar = Bar()
+
+ def forward(self, x):
+ self.buf.add_(1)
+ self.bar.buf.add_(2)
+ bar = self.bar(x)
+ return bar.sum() + self.buf.sum()
+
+ exported = torch._export.export(Foo(), (torch.ones(5, 5),))
+ stateful_gm = exported.module()
+ export_return_val = stateful_gm(torch.ones(5, 5))
+ eager = Foo()
+ eager_return_val = eager(torch.ones(5, 5))
+ self.assertTrue(torch.allclose(eager_return_val, export_return_val))
+
+ for name, buffer in stateful_gm.named_buffers():
+ if name == "L__self___buf":
+ self.assertTrue(torch.allclose(torch.ones(1), buffer))
+ if name == "L__self___bar_buf":
+ self.assertTrue(torch.allclose(torch.tensor(4, dtype=torch.float), buffer))
+
+ changed = stateful_gm.graph.eliminate_dead_code()
+ self.assertFalse(changed)
+ self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
+
+ for name, buffer in stateful_gm.named_buffers():
+ if name == "L__self___buf":
+ self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
+ if name == "L__self___bar_buf":
+ self.assertTrue(torch.allclose(torch.tensor(7, dtype=torch.float), buffer))
+
+ def test_to_module_with_mutated_buffer_multiple_update_sub_later(self):
+
+ class Bar(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buf", torch.ones(1))
+
+ def forward(self, x):
+ self.buf.add_(1)
+ return x.sum() + self.buf.sum()
+
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("buf", torch.zeros(1))
+ self.bar = Bar()
+
+ def forward(self, x):
+ self.buf.add_(1)
+ bar = self.bar(x)
+ self.bar.buf.add_(2)
+ return bar.sum() + self.buf.sum()
+
+ exported = torch._export.export(Foo(), (torch.ones(5, 5),))
+ stateful_gm = exported.module()
+ export_return_val = stateful_gm(torch.ones(5, 5))
+ eager = Foo()
+ eager_return_val = eager(torch.ones(5, 5))
+ self.assertTrue(torch.allclose(eager_return_val, export_return_val))
+
+ for name, buffer in stateful_gm.named_buffers():
+ if name == "L__self___buf":
+ self.assertTrue(torch.allclose(torch.ones(1), buffer))
+ if name == "L__self___bar_buf":
+ self.assertTrue(torch.allclose(torch.tensor(4, dtype=torch.float), buffer))
+
+ changed = stateful_gm.graph.eliminate_dead_code()
+ self.assertFalse(changed)
+ self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
+
+ for name, buffer in stateful_gm.named_buffers():
+ if name == "L__self___buf":
+ self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
+ if name == "L__self___bar_buf":
+ self.assertTrue(torch.allclose(torch.tensor(7, dtype=torch.float), buffer))
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py
index 67a907d..9b6caf5 100644
--- a/torch/_export/exported_program.py
+++ b/torch/_export/exported_program.py
@@ -119,8 +119,9 @@
     signature: Optional[ModuleCallSignature] = None
 
 
-def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
+def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict, buffers_to_mutate, user_outputs):
     count = 0
+    buffer_name_to_node = {}
     # Step 1: make lifted params as get_attr
     for node in gm.graph.nodes:
         if node.op == "placeholder":
@@ -133,9 +134,32 @@
                     metadata = node.meta
                     gm.graph.erase_node(node)
                     getattr_node.meta = metadata
-            count += 1
+                    buffer_name_to_node[inp_pos_to_param_buffer_name[count]] = getattr_node
 
-    # Step 2: Fix the input/output of the graph now that we deleted
+            count += 1
+        # Step 2: Find all the buffers that were mutated and update them
+        if node.op == "output":
+            user_output_nodes = []
+            for return_node in node.all_input_nodes:
+                return_node_name = return_node.name
+                # we found a param/buffer mutation
+                if return_node_name in buffers_to_mutate:
+                    buffer_node_name = buffers_to_mutate[return_node_name]
+                    assert buffer_node_name in buffer_name_to_node
+                    buffer_node = buffer_name_to_node[buffer_node_name]
+                    with gm.graph.inserting_before(node):
+                        buffer_update_node = gm.graph.call_function(
+                            torch.ops.aten.copy_.default, (buffer_node, return_node)
+                        )
+                else:
+                    user_output_nodes.append(return_node)
+            with gm.graph.inserting_before(node):
+                # Only return user outputs
+                new_output = gm.graph.output(tuple(user_output_nodes))
+                node.replace_all_uses_with(new_output)
+                gm.graph.erase_node(node)
+
+    # Step 3: Fix the input/output of the graph now that we deleted
     # some args.
     gm.graph.lint()
     names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
@@ -148,7 +172,7 @@
     )
     gm.recompile()
 
-    # Step 3: Find state references in HigherOrderOps and recursively
+    # Step 4: Find state references in HigherOrderOps and recursively
     # fix them.
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.cond:
@@ -174,6 +198,8 @@
                 in_spec,
                 None,
                 state_dict,
+                buffers_to_mutate,
+                user_outputs,
             )
             _unlift(
                 false_gm,
@@ -181,6 +207,8 @@
                 in_spec,
                 None,
                 state_dict,
+                buffers_to_mutate,
+                user_outputs,
             )
         if node.op == "call_function" and node.target.__name__ == "map_impl":
             body_graph, num_mapped, *operands = node.args
@@ -198,7 +226,13 @@
             _, in_spec = pytree.tree_flatten(real_operands)
             _unlift(
-                body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
+                body_gm,
+                inp_pos_to_buffer_name_for_submod,
+                in_spec,
+                None,
+                state_dict,
+                buffers_to_mutate,
+                user_outputs,
             )
 
     gm.graph.lint()
     gm.graph.eliminate_dead_code()
@@ -254,6 +288,8 @@
         ep.call_spec.in_spec,
         ep.call_spec.out_spec,
         ep.state_dict,
+        ep.graph_signature.buffers_to_mutate,
+        ep.graph_signature.user_outputs,
     )
 
     return new_gm