// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// mpi_common.cc
17 #include "caffe2/mpi/mpi_common.h"
18 
19 #include <thread>
20 
21 #include "caffe2/core/typeid.h"
22 #include "caffe2/utils/proto_utils.h"
23 
24 namespace caffe2 {
25 
// Register MPICommonWorldWrapper with Caffe2's runtime type system so
// instances can be stored in and retrieved from Blobs with type checking.
CAFFE_KNOWN_TYPE(MPICommonWorldWrapper);
27 
/**
 * @brief Returns the process-wide mutex used to serialize Caffe2's MPI calls.
 *
 * The same mutex object is returned on every call; callers lock it around
 * MPI operations that must not interleave.
 */
std::mutex& MPIMutex() {
  // std::mutex has a constexpr constructor, so this function-local static is
  // constant-initialized — no runtime init-order or laziness concerns.
  static std::mutex mpi_mutex;
  return mpi_mutex;
}
33 
// The communicator Caffe2 treats as its "world". Starts as MPI_COMM_WORLD;
// SetGlobalMPIComm() (e.g. via MPISetupPeers) may replace it with a
// dynamically assembled communicator.
static MPI_Comm gCaffe2MPIComm = MPI_COMM_WORLD;

/**
 * @brief Gets the global MPI communicator used by Caffe2.
 */
MPI_Comm GlobalMPIComm() {
  return gCaffe2MPIComm;
}
39 
40 void SetGlobalMPIComm(MPI_Comm new_comm) {
41  if (gCaffe2MPIComm != MPI_COMM_WORLD) {
42  MPI_Comm_free(&gCaffe2MPIComm);
43  }
44  gCaffe2MPIComm = new_comm;
45 }
46 
47 int MPICommSize(MPI_Comm comm) {
48  int comm_size;
49  MPI_CHECK(MPI_Comm_size(comm, &comm_size));
50  return comm_size;
51 }
52 
53 int MPICommRank(MPI_Comm comm) {
54  int comm_rank;
55  MPI_CHECK(MPI_Comm_rank(comm, &comm_rank));
56  return comm_rank;
57 }
58 
/**
 * @brief Merges an inter-communicator into an enlarged intra-communicator.
 *
 * Used during dynamic peer setup: an existing intra-comm (`intra`) and an
 * inter-comm to a newly connected/spawned rank (`inter`) are combined into a
 * single intra-comm spanning all ranks. Members of the pre-existing group
 * pass their intra-comm and (on the accepting rank only) the inter-comm;
 * the joining rank passes intra == MPI_COMM_NULL. This is a collective
 * call — every participating rank must call it.
 *
 * @param intra the pre-existing intra-comm, or MPI_COMM_NULL on the joining
 *        rank; freed here unless it is MPI_COMM_WORLD or the current
 *        GlobalMPIComm().
 * @param inter the inter-comm bridging the two groups, or MPI_COMM_NULL on
 *        ranks that are not an endpoint of it.
 * @return the new intra-comm containing all ranks of both groups.
 */
static MPI_Comm AssimilateComm(MPI_Comm intra, MPI_Comm inter) {
  MPI_Comm peer = MPI_COMM_NULL;
  MPI_Comm newInterComm = MPI_COMM_NULL;
  MPI_Comm newIntraComm = MPI_COMM_NULL;

  // The spawned rank will be the "high" rank in the new intra-comm
  int high = (MPI_COMM_NULL == intra) ? 1 : 0;

  // If this is one of the (two) ranks in the inter-comm,
  // create a new intra-comm from the inter-comm
  if (MPI_COMM_NULL != inter) {
    MPI_CHECK(MPI_Intercomm_merge(inter, high, &peer));
  } else {
    peer = MPI_COMM_NULL;
  }

  // Create a new inter-comm between the pre-existing intra-comm
  // (all of it, not only rank zero), and the remote (spawned) rank,
  // using the just-created intra-comm as the peer communicator.
  int tag = 12345;
  if (MPI_COMM_NULL != intra) {
    // This task is a member of the pre-existing intra-comm
    MPI_CHECK(MPI_Intercomm_create(intra, 0, peer, 1, tag, &newInterComm));
  } else {
    // This is the remote (spawned) task
    MPI_CHECK(
        MPI_Intercomm_create(MPI_COMM_SELF, 0, peer, 0, tag, &newInterComm));
  }

  // Now convert this inter-comm into an intra-comm
  MPI_CHECK(MPI_Intercomm_merge(newInterComm, high, &newIntraComm));

  // Clean up the intermediaries
  if (MPI_COMM_NULL != peer) {
    MPI_CHECK(MPI_Comm_free(&peer));
  }
  MPI_CHECK(MPI_Comm_free(&newInterComm));

  // Delete the original intra-comm unless it is a communicator we must not
  // free: MPI_COMM_WORLD or the one still installed as the global comm.
  if (MPI_COMM_NULL != intra && MPI_COMM_WORLD != intra &&
      GlobalMPIComm() != intra) {
    MPI_CHECK(MPI_Comm_free(&intra));
  }

  // Return the new intra-comm
  return newIntraComm;
}
109 
111  const int replicas,
112  const string& role,
113  const string& job_path) {
114  int flag;
115  MPI_Initialized(&flag);
116  if (!flag) {
117  int mpi_ret;
118  MPI_CHECK(MPI_Init_thread(nullptr, nullptr, MPI_THREAD_MULTIPLE, &mpi_ret));
119  if (mpi_ret != MPI_THREAD_MULTIPLE && mpi_ret != MPI_THREAD_SERIALIZED) {
120  LOG(FATAL) << "This test requires the underlying MPI to support the "
121  << "MPI_THREAD_SERIALIZED or MPI_THREAD_MULTIPLE mode.";
122  return;
123  }
124  }
125 
126  if (MPICommSize(MPI_COMM_WORLD) != 1) {
127  LOG(ERROR) << "MPI_COMM_WORLD size is not 1: did you already run "
128  "MPISetupPeers? Note that if you execute your program with "
129  "mpirun to launch multiple local processes, you should not "
130  "call MPISetupPeers.";
131  return;
132  }
133 
134  if (role == "server") {
135  // Open a port to accept connections.
136  char port_name[MPI_MAX_PORT_NAME] = {'\0'};
137  MPI_CHECK(MPI_Open_port(MPI_INFO_NULL, port_name));
138  VLOG(1) << "MPI server: port: " << port_name;
139 
140  // Writes the port name to the file.
141  CHECK(WriteStringToFile(std::string(port_name), job_path.c_str()));
142  VLOG(1) << "MPI server: wrote to file: " << job_path;
143 
144  int comm_size = MPICommSize(GlobalMPIComm());
145  while (comm_size < replicas) {
146  MPI_Comm icomm;
147  VLOG(1) << "MPI server: waiting for client "
148  << "(" << comm_size << "/" << replicas << " have connected)";
149  MPI_CHECK(
150  MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &icomm));
151  VLOG(1) << "MPI server: accepted client";
152  MPI_Comm new_intra_comm = AssimilateComm(GlobalMPIComm(), icomm);
153  SetGlobalMPIComm(new_intra_comm);
154  comm_size = MPICommSize(new_intra_comm);
155  }
156  } else {
157  // Opens the job path file to obtain server address.
158  std::string port_name;
159  while (!ReadStringFromFile(job_path.c_str(), &port_name) ||
160  port_name.length() == 0) {
161  /* sleep override */
162  std::this_thread::sleep_for(std::chrono::seconds(1));
163  }
164 
165  // Connect to server.
166  MPI_Comm icomm;
167  VLOG(1) << "MPI client: connecting to port: " << port_name;
168  MPI_CHECK(MPI_Comm_connect(
169  const_cast<char*>(port_name.c_str()),
170  MPI_INFO_NULL,
171  0,
172  GlobalMPIComm(),
173  &icomm));
174 
175  VLOG(1) << "MPI client: connected";
176 
177  // Join the server's reference intracommunicator.
178  MPI_Comm new_intra_comm = AssimilateComm(MPI_COMM_NULL, icomm);
179  SetGlobalMPIComm(new_intra_comm);
180 
181  // Let other clients join the intracommunicator we're now a part of.
182  while (MPICommSize(GlobalMPIComm()) < replicas) {
183  MPI_Comm comm = AssimilateComm(GlobalMPIComm(), MPI_COMM_NULL);
184  SetGlobalMPIComm(comm);
185  }
186  }
187 
188  // After all peers have assimilated, do a barrier.
189  MPI_Barrier(GlobalMPIComm());
190  VLOG(1) << "MPI using a communicator of size: "
192 }
193 
194 } // namespace caffe2
// Public API summary (declared in mpi_common.h):
//   MPI_Comm GlobalMPIComm()
//     Gets the global MPI communicator used by Caffe2.
//   void SetGlobalMPIComm(MPI_Comm new_comm)
//     Sets the global MPI communicator.
//   int MPICommSize(MPI_Comm comm)
//     A helper function to return the size of the given communicator.
//   int MPICommRank(MPI_Comm comm)
//     A helper function to return the rank of the given communicator.
//   void MPISetupPeers(const int replicas, const string& role,
//                      const string& job_path)
//     Performs peer setup so one does not need to use mpirun/mpiexec
//     to run the binary.
// Copyright (c) 2016-present, Facebook, Inc.