Reorder statements for readability (#11764)
Summary:
I was reading this a couple times before figuring out it's also the entry point for the MPI_COMM_WORLD.
Reordered statements and added comment to clarify.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11764
Differential Revision: D9882834
Pulled By: pietern
fbshipit-source-id: a9282d55368815925fd695a2541354e5aec599da
diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp
index 0369996..63846b4 100644
--- a/torch/lib/c10d/ProcessGroupMPI.cpp
+++ b/torch/lib/c10d/ProcessGroupMPI.cpp
@@ -253,33 +253,30 @@
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &size));
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
- globalLock.unlock();
-
if (rank < 0 || size < 0) {
throw std::runtime_error("Failed to get the world_size / rank");
}
+ // If no ranks are specified, assume we're creating the root group
if (ranks.empty()) {
- return std::make_shared<ProcessGroupMPI>(rank, size, MPI_COMM_WORLD);
- } else {
- std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
-
- MPI_Group worldGroup;
- MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
-
- MPI_Group ranksGroup;
- MPI_CHECK(
- MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
-
- MPI_Comm groupComm;
- MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
-
- MPI_CHECK(MPI_Group_free(&worldGroup));
- MPI_CHECK(MPI_Group_free(&ranksGroup));
-
globalLock.unlock();
- return std::make_shared<ProcessGroupMPI>(rank, size, groupComm);
+ return std::make_shared<ProcessGroupMPI>(rank, size, MPI_COMM_WORLD);
}
+
+ MPI_Group worldGroup;
+ MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
+
+ MPI_Group ranksGroup;
+ MPI_CHECK(MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
+
+ MPI_Comm groupComm;
+ MPI_CHECK(MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm));
+
+ MPI_CHECK(MPI_Group_free(&worldGroup));
+ MPI_CHECK(MPI_Group_free(&ranksGroup));
+
+ globalLock.unlock();
+ return std::make_shared<ProcessGroupMPI>(rank, size, groupComm);
}
ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)