Caffe2 - C++ API
A deep learning, cross-platform ML framework
mpi_common.h
1 #ifndef CAFFE2_MPI_MPI_COMMON_H_
2 #define CAFFE2_MPI_MPI_COMMON_H_
3 
4 #include <mpi.h>
5 #include <mutex>
6 
7 #include "caffe2/core/common.h"
8 #include "caffe2/core/logging.h"
9 
10 namespace caffe2 {
11 
12 inline void CheckInitializedMPI() {
13  int flag;
14  MPI_Initialized(&flag);
15  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
16 }
17 
18 template <typename T> class MPIDataTypeWrapper;
19 
20 #define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \
21  template<> class MPIDataTypeWrapper<c_type> { \
22  public: \
23  inline static MPI_Datatype type() { return mpi_type; } \
24  };
25 
26 MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
27 MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
28 MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
29 // Note(Yangqing): as necessary, add more specializations.
30 #undef MPI_DATATYPE_WRAPPER
31 
32 // For all Caffe MPI calls, we will wrap it inside an MPI mutex lock guard.
33 CAFFE2_API std::mutex& MPIMutex();
34 
35 #define MPI_CHECK(condition) \
36  do { \
37  std::lock_guard<std::mutex> guard(::caffe2::MPIMutex()); \
38  int error = (condition); \
39  CAFFE_ENFORCE( \
40  error == MPI_SUCCESS, \
41  "Caffe2 MPI Error at: ", \
42  __FILE__, \
43  ":", \
44  __LINE__, \
45  ": ", \
46  error); \
47  } while (0)
48 
53 CAFFE2_API MPI_Comm GlobalMPIComm();
54 
59 CAFFE2_API void SetGlobalMPIComm(MPI_Comm new_comm);
60 
64 CAFFE2_API int MPICommSize(MPI_Comm comm);
65 
69 CAFFE2_API int MPICommRank(MPI_Comm comm);
70 
75  public:
87  MPI_Comm src_comm = MPI_COMM_NULL,
88  int color = 0,
89  int rank = -1) {
90  if (src_comm == MPI_COMM_NULL) {
91  src_comm = GlobalMPIComm();
92  }
93  if (rank == -1) {
94  MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
95  }
96  MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
97  MPI_CHECK(MPI_Comm_size(comm_, &size_));
98  MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
99  }
100 
102  int ret;
103  MPI_CHECK(MPI_Finalized(&ret));
104  if (!ret) {
105  MPI_Comm_free(&comm_);
106  }
107  }
108 
112  inline MPI_Comm comm() const {
113  return comm_;
114  }
118  inline int size() const {
119  return size_;
120  }
124  inline int rank() const {
125  return rank_;
126  }
127 
128  private:
129  MPI_Comm comm_;
130  int size_;
131  int rank_;
132 };
133 
149 void MPISetupPeers(
150  const int replicas,
151  const string& role,
152  const string& job_path);
153 } // namespace caffe2
154 
155 #endif // CAFFE2_MPI_MPI_COMMON_H_
void SetGlobalMPIComm(MPI_Comm new_comm)
Sets the global MPI communicator.
Definition: mpi_common.cc:24
MPI_Comm comm() const
Returns the common world held by the wrapper.
Definition: mpi_common.h:112
int rank() const
Returns the rank of this process in the world.
Definition: mpi_common.h:124
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
int size() const
Returns the size of the world.
Definition: mpi_common.h:118
MPI_Comm GlobalMPIComm()
Gets the global MPI communicator used by Caffe2.
Definition: mpi_common.cc:20
MPICommonWorldWrapper(MPI_Comm src_comm=MPI_COMM_NULL, int color=0, int rank=-1)
Creates a common world wrapper.
Definition: mpi_common.h:86
A simple wrapper over an MPI common world.
Definition: mpi_common.h:74
void MPISetupPeers(const int replicas, const string &role, const string &job_path)
A function used to perform peer setup so one does not need to use mpirun / mpiexec to run the binary...
Definition: mpi_common.cc:94
int MPICommSize(MPI_Comm comm)
A helper function to return the size of the given communicator.
Definition: mpi_common.cc:31
int MPICommRank(MPI_Comm comm)
A helper function to return the rank of the given communicator.
Definition: mpi_common.cc:37