Caffe2 - C++ API
A deep learning, cross platform ML framework
mpi_common.h
1 
17 #ifndef CAFFE2_MPI_MPI_COMMON_H_
18 #define CAFFE2_MPI_MPI_COMMON_H_
19 
20 #include <mpi.h>
21 #include <mutex>
22 
23 #include "caffe2/core/logging.h"
24 
25 namespace caffe2 {
26 
// Enforces that MPI has been initialized: queries MPI_Initialized and fails
// the CAFFE_ENFORCE below when the reported flag is zero.
// NOTE(review): unlike calls made through MPI_CHECK, this query does not take
// the MPIMutex() lock and does not inspect the MPI return code.
27 inline void CheckInitializedMPI() {
28  int flag;  // written by MPI_Initialized below; not initialized locally
29  MPI_Initialized(&flag);
30  CAFFE_ENFORCE(flag, "MPI does not seem to have been initialized.");
31 }
32 
// Maps a C++ element type T to its MPI_Datatype via
// MPIDataTypeWrapper<T>::type(). Only the primary template's declaration is
// shown here (no definition), so only the explicit specializations generated
// below provide type().
33 template <typename T> class MPIDataTypeWrapper;
34 
// Helper macro: emits a full specialization of MPIDataTypeWrapper for c_type
// whose static type() returns mpi_type. It is #undef'd after the
// specializations below.
35 #define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \
36  template<> class MPIDataTypeWrapper<c_type> { \
37  public: \
38  inline static MPI_Datatype type() { return mpi_type; } \
39  };
40 
// Specializations: char -> MPI_CHAR, float -> MPI_FLOAT, double -> MPI_DOUBLE.
41 MPI_DATATYPE_WRAPPER(char, MPI_CHAR)
42 MPI_DATATYPE_WRAPPER(float, MPI_FLOAT)
43 MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE)
44 // Note(Yangqing): as necessary, add more specializations.
45 #undef MPI_DATATYPE_WRAPPER
46 
47 // All Caffe2 MPI calls are wrapped inside a lock guard on this MPI mutex (see MPI_CHECK).
// Returns a reference to that mutex; the definition is not part of this header.
48 std::mutex& MPIMutex();
49 
// MPI_CHECK(condition): evaluates `condition` (an expression yielding an MPI
// return code) while holding a std::lock_guard on ::caffe2::MPIMutex(), then
// CAFFE_ENFORCEs that the code equals MPI_SUCCESS. The failure message carries
// __FILE__, __LINE__ and the numeric error code. The do { ... } while (0)
// wrapper makes the macro usable as a single statement.
//
// Usage notes:
//  - The lock is a plain std::mutex and is held while `condition` runs, so
//    `condition` must not itself expand MPI_CHECK on the same thread.
//  - The macro declares a local named `error`; a `condition` that mentions a
//    caller variable called `error` would bind to this local instead.
50 #define MPI_CHECK(condition) \
51  do { \
52  std::lock_guard<std::mutex> guard(::caffe2::MPIMutex()); \
53  int error = (condition); \
54  CAFFE_ENFORCE( \
55  error == MPI_SUCCESS, \
56  "Caffe2 MPI Error at: ", \
57  __FILE__, \
58  ":", \
59  __LINE__, \
60  ": ", \
61  error); \
62  } while (0)
63 
// Gets the global MPI communicator used by Caffe2.
68 MPI_Comm GlobalMPIComm();
69 
// Sets the global MPI communicator to new_comm.
74 void SetGlobalMPIComm(MPI_Comm new_comm);
75 
// A helper function to return the size of the given communicator.
79 int MPICommSize(MPI_Comm comm);
80 
// A helper function to return the rank of the given communicator.
84 int MPICommRank(MPI_Comm comm);
85 
// MPICommonWorldWrapper: a simple wrapper over an MPI common world. It holds
// a communicator (comm_) together with its cached size and this process's rank.
// NOTE(review): this listing omits several original lines. Going by the symbol
// index at the end of this page, line 89 is the class header and line 101 is
// the start of the constructor signature; line 116 is presumably the
// destructor signature. Confirm against the real header before editing.
90  public:
 // Constructor parameters (the signature's first line, 101, is not shown).
102  MPI_Comm src_comm = MPI_COMM_NULL,
103  int color = 0,
104  int rank = -1) {
 // MPI_COMM_NULL acts as "use the default": fall back to GlobalMPIComm().
105  if (src_comm == MPI_COMM_NULL) {
106  src_comm = GlobalMPIComm();
107  }
 // rank == -1 acts as "unspecified": use this process's rank in src_comm.
108  if (rank == -1) {
109  MPI_CHECK(MPI_Comm_rank(src_comm, &rank));
110  }
 // Split src_comm by color, passing rank as the key argument, and store the
 // resulting communicator in comm_; then cache its size and our rank in it.
111  MPI_CHECK(MPI_Comm_split(src_comm, color, rank, &comm_));
112  MPI_CHECK(MPI_Comm_size(comm_, &size_));
113  MPI_CHECK(MPI_Comm_rank(comm_, &rank_));
114  }
115 
 // Presumably the destructor body (signature line 116 is not shown): frees
 // comm_ unless MPI reports that it has already been finalized.
 // NOTE(review): MPI_CHECK relies on CAFFE_ENFORCE; if that reports failure by
 // throwing (defined in caffe2/core/logging.h, not shown), an exception could
 // escape from here. Also, MPI_Comm_free below is called directly, so it does
 // not take the MPIMutex() lock and its return code is ignored.
117  int ret;
118  MPI_CHECK(MPI_Finalized(&ret));
119  if (!ret) {
120  MPI_Comm_free(&comm_);
121  }
122  }
123 
 // Returns the common world (communicator) held by the wrapper.
127  inline MPI_Comm comm() const {
128  return comm_;
129  }
 // Returns the size of the world.
133  inline int size() const {
134  return size_;
135  }
 // Returns the rank of this process in the world.
139  inline int rank() const {
140  return rank_;
141  }
142 
143  private:
 // NOTE(review): no copy/move control is visible in this listing; if the
 // class is copyable, two wrappers could end up freeing the same comm_.
 // Confirm against the full header.
144  MPI_Comm comm_;  // communicator produced by MPI_Comm_split above
145  int size_;  // size of comm_, cached at construction
146  int rank_;  // this process's rank within comm_, cached at construction
147 };
148 
// A function used to perform peer setup so one does not need to use
// mpirun / mpiexec to run the binary.
// NOTE(review): the description available on this page is truncated, and the
// meaning of replicas, role and job_path is not shown here; see the
// definition in mpi_common.cc before relying on them.
164 void MPISetupPeers(
165  const int replicas,
166  const string& role,
167  const string& job_path);
168 } // namespace caffe2
169 
170 #endif // CAFFE2_MPI_MPI_COMMON_H_
MPI_Comm comm() const
Returns the common world held by the wrapper.
Definition: mpi_common.h:127
MPI_Comm GlobalMPIComm()
Gets the global MPI communicator used by Caffe2.
Definition: mpi_common.cc:36
void SetGlobalMPIComm(MPI_Comm new_comm)
Sets the global MPI communicator.
Definition: mpi_common.cc:40
int rank() const
Returns the rank of this process in the world.
Definition: mpi_common.h:139
int MPICommRank(MPI_Comm comm)
A helper function to return the rank of the given communicator.
Definition: mpi_common.cc:53
Copyright (c) 2016-present, Facebook, Inc.
int size() const
Returns the size of the world.
Definition: mpi_common.h:133
int MPICommSize(MPI_Comm comm)
A helper function to return the size of the given communicator.
Definition: mpi_common.cc:47
MPICommonWorldWrapper(MPI_Comm src_comm=MPI_COMM_NULL, int color=0, int rank=-1)
Creates a common world wrapper.
Definition: mpi_common.h:101
MPICommonWorldWrapper
A simple wrapper over an MPI common world.
Definition: mpi_common.h:89
void MPISetupPeers(const int replicas, const string &role, const string &job_path)
A function used to perform peer setup so one does not need to use mpirun / mpiexec to run the binary...
Definition: mpi_common.cc:110