fix node manipulation in partition class (#48016)

Summary:
This PR fixes add_node and remove_node in the Partition class so that used_mem_bytes stays consistent with the partition's node set, and adds a unit test for node manipulation in a partition.
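
A condensed sketch of the behavior after this change (adapted from the unit
test added here; the import paths follow the files this PR touches, the tiny
module is illustrative, and the final assertion is deliberately loose rather
than the test's exact byte counts):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.experimental import graph_manipulation
    from torch.fx.experimental.accelerator_partitioner import Partitioner
    from torch.fx.experimental.partitioner_utils import Device, PartitionerConfig

    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b + torch.rand(4)

    traced = symbolic_trace(M())
    # Attach per-node size metadata so the partitioner can track memory
    graph_manipulation.get_size_of_all_nodes(traced, [torch.rand(4), torch.rand(4)])

    partitioner = Partitioner()
    partitioner.partition_graph(traced, M(), PartitionerConfig([Device('dev_0', 1000, 0)]))
    partition = partitioner.partitions[0]

    before = partition.used_mem_bytes
    node = next(n for n in partition.nodes if n.op not in {'placeholder', 'get_attr'})
    partition.remove_node(node)  # now recalculates used_mem_bytes internally
    assert partition.used_mem_bytes <= before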

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48016

Reviewed By: gcatron

Differential Revision: D24996368

Pulled By: scottxu0730

fbshipit-source-id: 0ddffd5ed3f95e5285fffcaee8c4b671929b4df3
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 417dbce..7d11c22 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -138,6 +138,33 @@
         self.assertEqual(traced(a, b), module_with_submodules(a, b))
         assert dag.nodes[0].logical_device_ids == [0]
 
+    def test_partition_node_manipulation(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, a, b):
+                add_1 = a + b
+                add_2 = add_1 + torch.rand(4)
+                add_3 = add_2 + torch.rand(4)
+                return add_3
+
+        m = TestModule()
+        traced = symbolic_trace(m)
+        a, b = torch.rand(4), torch.rand(4)
+        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
+        partitioner = Partitioner()
+        devices = [Device('dev_0', 1000, 0)]
+        partitioner_config = PartitionerConfig(devices)
+        partitioner.partition_graph(traced, m, partitioner_config)
+        partition = partitioner.partitions[0]
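+        # used_mem_bytes is the total size of all nodes currently in the partition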
+        assert partition.used_mem_bytes == 112
+        # Select the add_3 node to remove
+        selected_node = None
+        for node in partition.nodes:
+            if node.name == 'add_3':
+                selected_node = node
+                break
+        assert selected_node is not None
+        partition.remove_node(selected_node)
+        assert partition.used_mem_bytes == 80
+
+
     def test_size_based_partition(self):
         class TestModule(torch.nn.Module):
             def __init__(self):
diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py
index e7f38e8..83e2144 100644
--- a/torch/fx/experimental/accelerator_partitioner.py
+++ b/torch/fx/experimental/accelerator_partitioner.py
@@ -368,7 +368,6 @@
                             partition.logical_device_ids.append(device.logical_id)
                     partition.add_node(node)
                     partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
-                    partition.used_mem_bytes += total_size_of_input_nodes
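+                    # used_mem_bytes is now updated inside add_node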
                 # No device left, create single node partitions
                 else:
                     self.create_single_node_partition(node)
@@ -428,7 +427,6 @@
         """
         partition = self.create_partition()
         partition.add_node(node)
-        partition.recalculate_mem_size()
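+        # add_node now calls recalculate_mem_size itself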
         return
 
     def sparse_nn_partition(self, available_mem_bytes: int) -> None:
@@ -551,7 +549,6 @@
                     if total_size_of_input_nodes > available_mem_bytes:
                         raise RuntimeError(f'{node.target} is too large to fit into a device')
                 partition.add_node(node)
-                partition.used_mem_bytes += total_size_of_input_nodes
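+                # used_mem_bytes is now updated inside add_node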
         reset_partition_in_sparse_nn(partition, new_partition=False)
         # Set parents and children for partitions
         set_parents_and_children(self.partitions)
diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py
index a2e6332..0858256 100644
--- a/torch/fx/experimental/partitioner_utils.py
+++ b/torch/fx/experimental/partitioner_utils.py
@@ -30,6 +30,7 @@
             if n.op in {'placeholder', 'get_attr'}:
                 self.nodes.add(n)
         self.nodes.add(node)
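+        # Keep used_mem_bytes in sync with the updated node set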
+        self.recalculate_mem_size()
 
     def remove_node(self, node):
         # Remove a node only if the node is in the partition
@@ -43,8 +44,9 @@
             # and this input node is not used by any other node in this partition,
             # then remove this input node
             for input_node in input_nodes:
-                if all([n not in self.nodes for n in input_node.users]):
+                if all([n not in self.nodes for n in input_node.users]) and input_node.op in {'placeholder', 'get_attr'}:
                     self.nodes.remove(input_node)
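+            # Recompute used_mem_bytes after the node and any orphaned inputs are removed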
+            self.recalculate_mem_size()
 
 class Device(NamedTuple):
     name: str