[dtensor][debug] added c10d allgather, allgather_coalesced, and allgather_into_tensor_coalesced tracing to CommDebugMode (#127334)
**Summary**
Added c10d allgather, allgather_coalesced, and allgather_into_tensor_coalesced tracing to CommDebugMode, and updated the test case in test_comm_mode to cover the added collectives.
**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127334
Approved by: https://github.com/XilunWu, https://github.com/yifuwang
ghstack dependencies: #127025, #127029, #127040, #127134
diff --git a/test/distributed/_tensor/debug/test_comm_mode.py b/test/distributed/_tensor/debug/test_comm_mode.py
index 4143da2..0962d9c 100644
--- a/test/distributed/_tensor/debug/test_comm_mode.py
+++ b/test/distributed/_tensor/debug/test_comm_mode.py
@@ -160,6 +160,34 @@
comm_counts = comm_mode.get_comm_counts()
self.assertEqual(comm_counts[c10d_ops.scatter_], 1)
+ # tests c10d all_gather tracing
+ output_list = []
+
+ with comm_mode:
+ dist.all_gather(output_list, inp, None)
+
+ comm_counts = comm_mode.get_comm_counts()
+ self.assertEqual(comm_counts[c10d_ops.allgather_], 1)
+
+ # tests c10d allgather_coalesced_ tracing
+ output_list = []
+
+ with comm_mode:
+ dist.all_gather_coalesced(output_list, [inp], None)
+
+ comm_counts = comm_mode.get_comm_counts()
+ self.assertEqual(comm_counts[c10d_ops.allgather_coalesced_], 1)
+
+ # tests c10d allgather_into_tensor_coalesced_ tracing
+ comm_mode = CommDebugMode()
+ with comm_mode as A, dist._coalescing_manager() as B:
+ # dist.all_reduce_coalesced(inp)
+ dist.all_gather_into_tensor(all_gather_out, inp)
+
+ comm_counts = comm_mode.get_comm_counts()
+ self.assertEqual(comm_mode.get_total_counts(), 1)
+ self.assertEqual(comm_counts[c10d_ops.allgather_into_tensor_coalesced_], 1)
+
@requires_nccl()
def test_comm_mode_with_c10d_allreduce_coalesced(self):
world_pg = self.world_pg
diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py
index 62e10a1..d566da5 100644
--- a/torch/distributed/_tensor/debug/comm_mode.py
+++ b/torch/distributed/_tensor/debug/comm_mode.py
@@ -25,9 +25,12 @@
}
c10d_collective_ops = {
- c10d_ops.allreduce_,
c10d_ops._allgather_base_,
c10d_ops._reduce_scatter_base_,
+ c10d_ops.allgather_,
+ c10d_ops.allgather_coalesced_,
+ c10d_ops.allgather_into_tensor_coalesced_,
+ c10d_ops.allreduce_,
c10d_ops.allreduce_coalesced_,
c10d_ops.broadcast_,
c10d_ops.gather_,