Caffe2 - C++ API
A deep learning, cross-platform ML framework
mpi_common.cc
1 #include "caffe2/mpi/mpi_common.h"
2 
3 #include <thread>
4 
5 #include <c10/util/typeid.h>
6 #include "caffe2/utils/proto_utils.h"
7 
8 namespace caffe2 {
9 
// Registers MPICommonWorldWrapper with c10's type-meta registry
// (c10/util/typeid.h, included above). NOTE(review): presumably this is what
// lets the wrapper be held in a Blob elsewhere -- confirm against callers.
CAFFE_KNOWN_TYPE(MPICommonWorldWrapper);
11 
namespace {
// The single process-wide mutex handed out by MPIMutex().
std::mutex gCaffe2MPIMutex;
} // namespace

/// Returns a reference to the one process-wide mutex owned by this file.
/// Every call yields the same object. NOTE(review): presumably callers lock it
/// around MPI calls -- the call sites are not visible here.
std::mutex& MPIMutex() {
  std::mutex& shared_mutex = gCaffe2MPIMutex;
  return shared_mutex;
}
17 
18 static MPI_Comm gCaffe2MPIComm = MPI_COMM_WORLD;
19 
20 MPI_Comm GlobalMPIComm() {
21  return gCaffe2MPIComm;
22 }
23 
24 void SetGlobalMPIComm(MPI_Comm new_comm) {
25  if (gCaffe2MPIComm != MPI_COMM_WORLD) {
26  MPI_Comm_free(&gCaffe2MPIComm);
27  }
28  gCaffe2MPIComm = new_comm;
29 }
30 
31 int MPICommSize(MPI_Comm comm) {
32  int comm_size;
33  MPI_CHECK(MPI_Comm_size(comm, &comm_size));
34  return comm_size;
35 }
36 
37 int MPICommRank(MPI_Comm comm) {
38  int comm_rank;
39  MPI_CHECK(MPI_Comm_rank(comm, &comm_rank));
40  return comm_rank;
41 }
42 
46 static MPI_Comm AssimilateComm(MPI_Comm intra, MPI_Comm inter) {
47  MPI_Comm peer = MPI_COMM_NULL;
48  MPI_Comm newInterComm = MPI_COMM_NULL;
49  MPI_Comm newIntraComm = MPI_COMM_NULL;
50 
51  // The spawned rank will be the "high" rank in the new intra-comm
52  int high = (MPI_COMM_NULL == intra) ? 1 : 0;
53 
54  // If this is one of the (two) ranks in the inter-comm,
55  // create a new intra-comm from the inter-comm
56  if (MPI_COMM_NULL != inter) {
57  MPI_CHECK(MPI_Intercomm_merge(inter, high, &peer));
58  } else {
59  peer = MPI_COMM_NULL;
60  }
61 
62  // Create a new inter-comm between the pre-existing intra-comm
63  // (all of it, not only rank zero), and the remote (spawned) rank,
64  // using the just-created intra-comm as the peer communicator.
65  int tag = 12345;
66  if (MPI_COMM_NULL != intra) {
67  // This task is a member of the pre-existing intra-comm
68  MPI_CHECK(MPI_Intercomm_create(intra, 0, peer, 1, tag, &newInterComm));
69  } else {
70  // This is the remote (spawned) task
71  MPI_CHECK(
72  MPI_Intercomm_create(MPI_COMM_SELF, 0, peer, 0, tag, &newInterComm));
73  }
74 
75  // Now convert this inter-comm into an intra-comm
76  MPI_CHECK(MPI_Intercomm_merge(newInterComm, high, &newIntraComm));
77 
78  // Clean up the intermediaries
79  if (MPI_COMM_NULL != peer) {
80  MPI_CHECK(MPI_Comm_free(&peer));
81  }
82  MPI_CHECK(MPI_Comm_free(&newInterComm));
83 
84  // Delete the original intra-comm
85  if (MPI_COMM_NULL != intra && MPI_COMM_WORLD != intra &&
86  GlobalMPIComm() != intra) {
87  MPI_CHECK(MPI_Comm_free(&intra));
88  }
89 
90  // Return the new intra-comm
91  return newIntraComm;
92 }
93 
95  const int replicas,
96  const string& role,
97  const string& job_path) {
98  int flag;
99  MPI_Initialized(&flag);
100  if (!flag) {
101  int mpi_ret;
102  MPI_CHECK(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpi_ret));
103  if (mpi_ret != MPI_THREAD_MULTIPLE && mpi_ret != MPI_THREAD_SERIALIZED) {
104  LOG(FATAL) << "This test requires the underlying MPI to support the "
105  << "MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.";
106  return;
107  }
108  }
109 
110  if (MPICommSize(MPI_COMM_WORLD) != 1) {
111  LOG(ERROR) << "MPI_COMM_WORLD size is not 1: did you already run "
112  "MPISetupPeers? Note that if you execute your program with "
113  "mpirun to launch multiple local processes, you should not "
114  "call MPISetupPeers.";
115  return;
116  }
117 
118  if (role == "server") {
119  // Open a port to accept connections.
120  char port_name[MPI_MAX_PORT_NAME] = {'\0'};
121  MPI_CHECK(MPI_Open_port(MPI_INFO_NULL, port_name));
122  VLOG(1) << "MPI server: port: " << port_name;
123 
124  // Writes the port name to the file.
125  CHECK(WriteStringToFile(std::string(port_name), job_path.c_str()));
126  VLOG(1) << "MPI server: wrote to file: " << job_path;
127 
128  int comm_size = MPICommSize(GlobalMPIComm());
129  while (comm_size < replicas) {
130  MPI_Comm icomm;
131  VLOG(1) << "MPI server: waiting for client "
132  << "(" << comm_size << "/" << replicas << " have connected)";
133  MPI_CHECK(
134  MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &icomm));
135  VLOG(1) << "MPI server: accepted client";
136  MPI_Comm new_intra_comm = AssimilateComm(GlobalMPIComm(), icomm);
137  SetGlobalMPIComm(new_intra_comm);
138  comm_size = MPICommSize(new_intra_comm);
139  }
140  } else {
141  // Opens the job path file to obtain server address.
142  std::string port_name;
143  while (!ReadStringFromFile(job_path.c_str(), &port_name) ||
144  port_name.length() == 0) {
145  /* sleep override */
146  std::this_thread::sleep_for(std::chrono::seconds(1));
147  }
148 
149  // Connect to server.
150  MPI_Comm icomm;
151  VLOG(1) << "MPI client: connecting to port: " << port_name;
152  MPI_CHECK(MPI_Comm_connect(
153  const_cast<char*>(port_name.c_str()),
154  MPI_INFO_NULL,
155  0,
156  GlobalMPIComm(),
157  &icomm));
158 
159  VLOG(1) << "MPI client: connected";
160 
161  // Join the server's reference intracommunicator.
162  MPI_Comm new_intra_comm = AssimilateComm(MPI_COMM_NULL, icomm);
163  SetGlobalMPIComm(new_intra_comm);
164 
165  // Let other clients join the intracommunicator we're now a part of.
166  while (MPICommSize(GlobalMPIComm()) < replicas) {
167  MPI_Comm comm = AssimilateComm(GlobalMPIComm(), MPI_COMM_NULL);
168  SetGlobalMPIComm(comm);
169  }
170  }
171 
172  // After all peers have assimilated, do a barrier.
173  MPI_Barrier(GlobalMPIComm());
174  VLOG(1) << "MPI using a communicator of size: "
176 }
177 
178 } // namespace caffe2
void SetGlobalMPIComm(MPI_Comm new_comm)
Sets the global MPI communicator.
Definition: mpi_common.cc:24
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
MPI_Comm GlobalMPIComm()
Gets the global MPI communicator used by Caffe2.
Definition: mpi_common.cc:20
void MPISetupPeers(const int replicas, const string &role, const string &job_path)
A function used to perform peer setup so one does not need to use mpirun / mpiexec to run the binary...
Definition: mpi_common.cc:94
int MPICommSize(MPI_Comm comm)
A helper function to return the size of the given communicator.
Definition: mpi_common.cc:31
int MPICommRank(MPI_Comm comm)
A helper function to return the rank of the given communicator.
Definition: mpi_common.cc:37