fix node manipulation in partition class (#48016)
Summary:
This PR fixes the add_node and remove_node in partition class and also add a unit test for node manipulation in partition
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48016
Reviewed By: gcatron
Differential Revision: D24996368
Pulled By: scottxu0730
fbshipit-source-id: 0ddffd5ed3f95e5285fffcaee8c4b671929b4df3
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 417dbce..7d11c22 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -138,6 +138,33 @@
self.assertEqual(traced(a, b), module_with_submodules(a, b))
assert dag.nodes[0].logical_device_ids == [0]
+ def test_partition_node_manipulation(self):
+ class TestModule(torch.nn.Module):
+ def forward(self, a, b):
+ add_1 = a + b
+ add_2 = add_1 + torch.rand(4)
+ add_3 = add_2 + torch.rand(4)
+ return add_3
+
+ m = TestModule()
+ traced = symbolic_trace(m)
+ a, b = torch.rand(4), torch.rand(4)
+ graph_manipulation.get_size_of_all_nodes(traced, [a, b])
+ partitioner = Partitioner()
+ devices = [Device('dev_0', 1000, 0)]
+ partitioner_config = PartitionerConfig(devices)
+ ret = partitioner.partition_graph(traced, m, partitioner_config)
+ partition = partitioner.partitions[0]
+ assert partition.used_mem_bytes == 112
+ # Select add_3 node to remove
+ selected_node = None
+ for node in partition.nodes:
+ if node.name == 'add_3':
+ selected_node = node
+ partition.remove_node(selected_node)
+ assert(partition.used_mem_bytes == 80)
+
+
def test_size_based_partition(self):
class TestModule(torch.nn.Module):
def __init__(self):
diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py
index e7f38e8..83e2144 100644
--- a/torch/fx/experimental/accelerator_partitioner.py
+++ b/torch/fx/experimental/accelerator_partitioner.py
@@ -368,7 +368,6 @@
partition.logical_device_ids.append(device.logical_id)
partition.add_node(node)
partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
- partition.used_mem_bytes += total_size_of_input_nodes
# No device left, create single node partitions
else:
self.create_single_node_partition(node)
@@ -428,7 +427,6 @@
"""
partition = self.create_partition()
partition.add_node(node)
- partition.recalculate_mem_size()
return
def sparse_nn_partition(self, available_mem_bytes: int) -> None:
@@ -551,7 +549,6 @@
if total_size_of_input_nodes > available_mem_bytes:
raise RuntimeError(node.target + 'is too large to fit into a device')
partition.add_node(node)
- partition.used_mem_bytes += total_size_of_input_nodes
reset_partition_in_sparse_nn(partition, new_partition=False)
# Set parents and children for partitions
set_parents_and_children(self.partitions)
diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py
index a2e6332..0858256 100644
--- a/torch/fx/experimental/partitioner_utils.py
+++ b/torch/fx/experimental/partitioner_utils.py
@@ -30,6 +30,7 @@
if n.op in {'placeholder', 'get_attr'}:
self.nodes.add(n)
self.nodes.add(node)
+ self.recalculate_mem_size()
def remove_node(self, node):
# Remove a node only if the node is in the partition
@@ -43,8 +44,9 @@
# and this input node is not used by some other nodes in this partition,
# the remove this input node
for input_node in input_nodes:
- if all([n not in self.nodes for n in input_node.users]):
+ if all([n not in self.nodes for n in input_node.users]) and input_node.op in {'placeholder', 'get_attr'}:
self.nodes.remove(input_node)
+ self.recalculate_mem_size()
class Device(NamedTuple):
name: str