1 #include "caffe2/mpi/mpi_common.h" 5 #include <c10/util/typeid.h> 6 #include "caffe2/utils/proto_utils.h" 10 CAFFE_KNOWN_TYPE(MPICommonWorldWrapper);
// File-local mutex, handed out by reference through MPIMutex().
// NOTE(review): presumably callers lock this to serialize their MPI calls,
// which matters when MPI only provides MPI_THREAD_SERIALIZED (a level that
// MPISetupPeers accepts) — confirm at the call sites.
12 static std::mutex gCaffe2MPIMutex;
// Returns a reference to gCaffe2MPIMutex. The mutex is not locked here; the
// caller decides when to lock it.
14 std::mutex& MPIMutex() {
15 return gCaffe2MPIMutex;
// Caffe2's global MPI communicator. It starts out as MPI_COMM_WORLD and is
// replaced by SetGlobalMPIComm().
18 static MPI_Comm gCaffe2MPIComm = MPI_COMM_WORLD;
// GlobalMPIComm(): returns the currently installed global communicator.
21 return gCaffe2MPIComm;
// SetGlobalMPIComm(new_comm): frees the currently installed communicator
// unless it is MPI_COMM_WORLD, then installs new_comm. A communicator installed
// here is therefore freed by the next SetGlobalMPIComm call, so callers must
// not keep using it after replacing it.
// NOTE(review): unlike most MPI calls in this file, this MPI_Comm_free result
// is not wrapped in MPI_CHECK. Also, passing the communicator that is already
// installed would free it and then store the now-invalid handle.
25 if (gCaffe2MPIComm != MPI_COMM_WORLD) {
26 MPI_Comm_free(&gCaffe2MPIComm);
28 gCaffe2MPIComm = new_comm;
// MPICommSize(comm): queries the number of processes in `comm` with
// MPI_Comm_size. The call is wrapped in MPI_CHECK.
33 MPI_CHECK(MPI_Comm_size(comm, &comm_size));
// MPICommRank(comm): queries the calling process's rank in `comm` with
// MPI_Comm_rank. The call is wrapped in MPI_CHECK.
39 MPI_CHECK(MPI_Comm_rank(comm, &comm_rank));
// Builds a new intra-communicator (newIntraComm) that joins the processes of
// `intra` with the remote process reached through the inter-communicator
// `inter`.
//
// intra: the existing group's communicator, or MPI_COMM_NULL for the newly
//        connecting process, which has no group yet.
// inter: the inter-communicator from MPI_Comm_accept / MPI_Comm_connect, or
//        MPI_COMM_NULL for existing members that do not hold it.
// NOTE(review): the return statement is not visible in this listing;
// presumably the function returns newIntraComm — confirm.
46 static MPI_Comm AssimilateComm(MPI_Comm intra, MPI_Comm inter) {
47 MPI_Comm peer = MPI_COMM_NULL;
48 MPI_Comm newInterComm = MPI_COMM_NULL;
49 MPI_Comm newIntraComm = MPI_COMM_NULL;
// The side without an existing group passes high = 1 to MPI_Intercomm_merge.
// Its process is therefore ordered after the existing group in every merge
// below.
52 int high = (MPI_COMM_NULL == intra) ? 1 : 0;
// Only processes holding `inter` merge it into the two-process
// intra-communicator `peer`. Because of `high`, the existing side is rank 0
// of `peer` and the newcomer is rank 1.
56 if (MPI_COMM_NULL != inter) {
57 MPI_CHECK(MPI_Intercomm_merge(inter, high, &peer));
// Existing members: the local leader is rank 0 of `intra`, and the remote
// leader is rank 1 of `peer` (the newcomer). Per MPI, the peer communicator
// only matters at the local leader, so members whose `peer` is MPI_COMM_NULL
// can still take part.
66 if (MPI_COMM_NULL != intra) {
68 MPI_CHECK(MPI_Intercomm_create(intra, 0, peer, 1, tag, &newInterComm));
// Newcomer: acts as a one-process group via MPI_COMM_SELF, with rank 0 of
// `peer` (the existing side) as the remote leader.
72 MPI_Intercomm_create(MPI_COMM_SELF, 0, peer, 0, tag, &newInterComm));
// Flatten the inter-communicator into the resulting intra-communicator.
76 MPI_CHECK(MPI_Intercomm_merge(newInterComm, high, &newIntraComm));
// Release the temporary communicators.
79 if (MPI_COMM_NULL != peer) {
80 MPI_CHECK(MPI_Comm_free(&peer));
82 MPI_CHECK(MPI_Comm_free(&newInterComm));
// Free the old `intra` unless it is MPI_COMM_NULL or MPI_COMM_WORLD; the
// condition continues past this line. `intra` is a by-value handle, so the
// caller's copy of it becomes invalid once it is freed.
85 if (MPI_COMM_NULL != intra && MPI_COMM_WORLD != intra &&
87 MPI_CHECK(MPI_Comm_free(&intra));
// Tail of MPISetupPeers(replicas, role, job_path). This joins separately
// launched processes into one communicator without mpirun/mpiexec:
// - The "server" publishes an MPI port name by writing it to the file at
//   job_path, then accepts clients.
// - Clients read that file and connect to the port.
97 const string& job_path) {
// MPI_Initialized fills `flag`. The MPI_Init_thread call below is presumably
// guarded by it (the guard is not visible here). MPI_THREAD_MULTIPLE is
// requested; MPI_THREAD_SERIALIZED is the lowest level accepted.
// NOTE(review): MPI_Initialized's return code is not wrapped in MPI_CHECK.
// The fatal message says "This test" even though this is library code.
99 MPI_Initialized(&flag);
102 MPI_CHECK(MPI_Init_thread(
nullptr,
nullptr, MPI_THREAD_MULTIPLE, &mpi_ret));
103 if (mpi_ret != MPI_THREAD_MULTIPLE && mpi_ret != MPI_THREAD_SERIALIZED) {
104 LOG(FATAL) <<
"This test requires the underlying MPI to support the " 105 <<
"MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.";
// Error logged when MPI_COMM_WORLD already has a size other than 1. The
// condition and any early exit are not visible here.
111 LOG(ERROR) <<
"MPI_COMM_WORLD size is not 1: did you already run " 112 "MPISetupPeers? Note that if you execute your program with " 113 "mpirun to launch multiple local processes, you should not " 114 "call MPISetupPeers.";
// Server path:
// 1. Open an MPI port.
// 2. Publish the port name by writing it to job_path.
// 3. Accept clients one at a time until comm_size reaches `replicas`.
118 if (role ==
"server") {
120 char port_name[MPI_MAX_PORT_NAME] = {
'\0'};
121 MPI_CHECK(MPI_Open_port(MPI_INFO_NULL, port_name));
122 VLOG(1) <<
"MPI server: port: " << port_name;
125 CHECK(WriteStringToFile(std::string(port_name), job_path.c_str()));
126 VLOG(1) <<
"MPI server: wrote to file: " << job_path;
// NOTE(review): comm_size is presumably refreshed after each assimilation
// (not visible here); otherwise this loop would never end.
129 while (comm_size < replicas) {
131 VLOG(1) <<
"MPI server: waiting for client " 132 <<
"(" << comm_size <<
"/" << replicas <<
" have connected)";
// The accept uses MPI_COMM_SELF, so only this process takes part in it.
134 MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &icomm));
135 VLOG(1) <<
"MPI server: accepted client";
// Merge the current global group with the newly accepted client.
136 MPI_Comm new_intra_comm = AssimilateComm(
GlobalMPIComm(), icomm);
// Client path (presumably every role other than "server"): poll job_path once
// per second until it holds a non-empty port name.
// NOTE(review): no timeout is visible on this polling loop. Confirm it cannot
// wait forever if the server never writes the file.
142 std::string port_name;
143 while (!ReadStringFromFile(job_path.c_str(), &port_name) ||
144 port_name.length() == 0) {
146 std::this_thread::sleep_for(std::chrono::seconds(1));
151 VLOG(1) <<
"MPI client: connecting to port: " << port_name;
// const_cast: MPI_Comm_connect takes the port name as a non-const char*.
152 MPI_CHECK(MPI_Comm_connect(
153 const_cast<char*>(port_name.c_str()),
159 VLOG(1) <<
"MPI client: connected";
// Join the existing group as the newcomer (intra == MPI_COMM_NULL).
162 MPI_Comm new_intra_comm = AssimilateComm(MPI_COMM_NULL, icomm);
// Assimilation step without an inter-communicator of its own
// (inter == MPI_COMM_NULL). This is presumably run by an already-joined client
// while later clients connect — confirm against the enclosing loop.
167 MPI_Comm comm = AssimilateComm(
GlobalMPIComm(), MPI_COMM_NULL);
// Log the resulting communicator size at VLOG(1).
174 VLOG(1) <<
"MPI using a communicator of size: " void SetGlobalMPIComm(MPI_Comm new_comm)
Sets the global MPI communicator.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
MPI_Comm GlobalMPIComm()
Gets the global MPI communicator used by Caffe2.
void MPISetupPeers(const int replicas, const string &role, const string &job_path)
A function used to perform peer setup so one does not need to use mpirun / mpiexec to run the binary...
int MPICommSize(MPI_Comm comm)
A helper function to return the size of the given communicator.
int MPICommRank(MPI_Comm comm)
A helper function to return the rank of the given communicator.