#ifndef CAFFE2_MPI_MPI_COMMON_H_
#define CAFFE2_MPI_MPI_COMMON_H_

#include <mpi.h>
#include <mutex>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"

namespace caffe2 {

inline void CheckInitializedMPI() {
  int flag;
  MPI_Initialized(&flag);
  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
}

template <typename T>
class MPIDataTypeWrapper;

#define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \
  template <>                                   \
  class MPIDataTypeWrapper<c_type> {            \
   public:                                      \
    inline static MPI_Datatype type() {         \
      return mpi_type;                          \
    }                                           \
  };

MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
// Note(Yangqing): as necessary, add more specializations.
#undef MPI_DATATYPE_WRAPPER
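
// Usage sketch (illustrative, not part of this header's API): the wrapper
// lets templated code select the matching MPI datatype at compile time. A
// hypothetical templated sum-allreduce helper could look like this
// (MPI_CHECK is defined below):
//
//   template <typename T>
//   void AllreduceSum(const T* src, T* dst, int count, MPI_Comm comm) {
//     MPI_CHECK(MPI_Allreduce(
//         src, dst, count, MPIDataTypeWrapper<T>::type(), MPI_SUM, comm));
//   }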

// All Caffe2 MPI calls are wrapped inside a lock guard on this mutex; see the
// MPI_CHECK macro below.
TORCH_API std::mutex& MPIMutex();

#define MPI_CHECK(condition)                                  \
  do {                                                        \
    std::lock_guard<std::mutex> guard(::caffe2::MPIMutex()); \
    int error = (condition);                                  \
    CAFFE_ENFORCE(                                            \
        error == MPI_SUCCESS,                                 \
        "Caffe2 MPI Error at: ",                              \
        __FILE__,                                             \
        ":",                                                  \
        __LINE__,                                             \
        ": ",                                                 \
        error);                                               \
  } while (0)
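
// Usage sketch: wrap every raw MPI call so that it is serialized through
// MPIMutex() and its return code is checked, e.g.
//
//   MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));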

/**
 * @brief Gets the global MPI communicator used by Caffe2. By default this is
 * MPI_COMM_WORLD unless SetGlobalMPIComm() has been called.
 */
TORCH_API MPI_Comm GlobalMPIComm();

/**
 * @brief Sets the global MPI communicator. Caffe2 takes ownership of the
 * passed-in communicator.
 */
TORCH_API void SetGlobalMPIComm(MPI_Comm new_comm);
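
// Usage sketch (illustrative): install a custom communicator, for example a
// duplicate of MPI_COMM_WORLD, as the communicator Caffe2 operates on:
//
//   MPI_Comm my_comm;
//   MPI_CHECK(MPI_Comm_dup(MPI_COMM_WORLD, &my_comm));
//   caffe2::SetGlobalMPIComm(my_comm);  // Caffe2 now owns my_comm.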

/**
 * @brief A helper function that returns the size of the given communicator.
 */
TORCH_API int MPICommSize(MPI_Comm comm);

/**
 * @brief A helper function that returns the rank of the calling process in
 * the given communicator.
 */
TORCH_API int MPICommRank(MPI_Comm comm);
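
// Usage sketch: query the layout of the communicator Caffe2 is using, e.g.
//
//   const int world_size = caffe2::MPICommSize(caffe2::GlobalMPIComm());
//   const int my_rank = caffe2::MPICommRank(caffe2::GlobalMPIComm());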

/**
 * @brief A simple wrapper over an MPI common world.
 */
class MPICommonWorldWrapper {
 public:
  /**
   * @brief Creates a common world wrapper.
   *
   * The new common world is created by splitting the existing communicator
   * passed in as src_comm, using the specified color and rank. By default we
   * split Caffe2's global communicator, using color 0 and the rank implied
   * by src_comm; as a result, the default constructor creates a communicator
   * equivalent to the source common world.
   */
  explicit MPICommonWorldWrapper(
      MPI_Comm src_comm = MPI_COMM_NULL,
      int color = 0,
      int rank = -1) {
    if (src_comm == MPI_COMM_NULL) {
      src_comm = GlobalMPIComm();
    }
    if (rank == -1) {
      MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
    }
    MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
    MPI_CHECK(MPI_Comm_size(comm_, &size_));
    MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
  }

  ~MPICommonWorldWrapper() {
    // Freeing a communicator is not allowed after MPI_Finalize, so only free
    // it if MPI has not been finalized yet.
    int ret;
    MPI_CHECK(MPI_Finalized(&ret));
    if (!ret) {
      MPI_Comm_free(&comm_);
    }
  }

  /**
   * @brief Returns the common world held by the wrapper.
   */
  inline MPI_Comm comm() const {
    return comm_;
  }
  /**
   * @brief Returns the size of the world.
   */
  inline int size() const {
    return size_;
  }
  /**
   * @brief Returns the rank of this process in the world.
   */
  inline int rank() const {
    return rank_;
  }

 private:
  MPI_Comm comm_;
  int size_;
  int rank_;
};
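
// Usage sketch (illustrative): split a fresh common world off Caffe2's
// global communicator and run a collective on it:
//
//   MPICommonWorldWrapper world;  // color 0, rank taken from GlobalMPIComm()
//   float local = static_cast<float>(world.rank());
//   float total = 0;
//   MPI_CHECK(MPI_Allreduce(
//       &local, &total, 1, MPIDataTypeWrapper<float>::type(), MPI_SUM,
//       world.comm()));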

/**
 * A function used to perform peer setup so that one does not need to use
 * mpirun / mpiexec to run the binary. Note that if you use mpirun or mpiexec
 * to set up the common world, do not use this function - MPI_Init would have
 * already set that up.
 *
 * This also assumes that you have a common path (like NFS) that multiple
 * instances can read from.
 *
 * Inputs:
 *   replicas (int): the number of replicas that MPI will run with.
 *   role (string): the role of this process, "server" or "client".
 *   job_path (string): a file name that the server will write its port into
 *       and the clients will read the server's port from.
 */
TORCH_API void MPISetupPeers(
    const int replicas,
    const string& role,
    const string& job_path);
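
// Usage sketch (illustrative; the shared path below is made up): start the
// binary on every participating machine without mpirun, with exactly one
// process taking the "server" role and the rest taking the "client" role,
// all pointing at the same file on shared storage:
//
//   caffe2::MPISetupPeers(4, "server", "/mnt/nfs/job/mpi_port");  // one host
//   caffe2::MPISetupPeers(4, "client", "/mnt/nfs/job/mpi_port");  // the rest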
} // namespace caffe2

#endif // CAFFE2_MPI_MPI_COMMON_H_