Caffe2 - C++ API
A deep learning, cross platform ML framework
zmq_helper.h
1 
17 #ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
18 #define CAFFE2_UTILS_ZMQ_HELPER_H_
19 
20 #include <zmq.h>
21 
22 #include "caffe2/core/logging.h"
23 
24 namespace caffe2 {
25 
26 class ZmqContext {
27  public:
28  explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
29  CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
30  int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
31  CAFFE_ENFORCE_EQ(rc, 0);
32  rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
33  CAFFE_ENFORCE_EQ(rc, 0);
34  }
35  ~ZmqContext() {
36  int rc = zmq_ctx_destroy(ptr_);
37  CAFFE_ENFORCE_EQ(rc, 0);
38  }
39 
40  void* ptr() { return ptr_; }
41 
42  private:
43  void* ptr_;
44 
45  DISABLE_COPY_AND_ASSIGN(ZmqContext);
46 };
47 
48 class ZmqMessage {
49  public:
50  ZmqMessage() {
51  int rc = zmq_msg_init(&msg_);
52  CAFFE_ENFORCE_EQ(rc, 0);
53  }
54 
55  ~ZmqMessage() {
56  int rc = zmq_msg_close(&msg_);
57  CAFFE_ENFORCE_EQ(rc, 0);
58  }
59 
60  zmq_msg_t* msg() { return &msg_; }
61 
62  void* data() { return zmq_msg_data(&msg_); }
63  size_t size() { return zmq_msg_size(&msg_); }
64 
65  private:
66  zmq_msg_t msg_;
67  DISABLE_COPY_AND_ASSIGN(ZmqMessage);
68 };
69 
70 class ZmqSocket {
71  public:
72  explicit ZmqSocket(int type)
73  : context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
74  CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket.");
75  }
76 
77  ~ZmqSocket() {
78  int rc = zmq_close(ptr_);
79  CAFFE_ENFORCE_EQ(rc, 0);
80  }
81 
82  void Bind(const string& addr) {
83  int rc = zmq_bind(ptr_, addr.c_str());
84  CAFFE_ENFORCE_EQ(rc, 0);
85  }
86 
87  void Unbind(const string& addr) {
88  int rc = zmq_unbind(ptr_, addr.c_str());
89  CAFFE_ENFORCE_EQ(rc, 0);
90  }
91 
92  void Connect(const string& addr) {
93  int rc = zmq_connect(ptr_, addr.c_str());
94  CAFFE_ENFORCE_EQ(rc, 0);
95  }
96 
97  void Disconnect(const string& addr) {
98  int rc = zmq_disconnect(ptr_, addr.c_str());
99  CAFFE_ENFORCE_EQ(rc, 0);
100  }
101 
102  int Send(const string& msg, int flags) {
103  int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
104  if (nbytes) {
105  return nbytes;
106  } else if (zmq_errno() == EAGAIN) {
107  return 0;
108  } else {
109  LOG(FATAL) << "Cannot send zmq message. Error number: "
110  << zmq_errno();
111  return 0;
112  }
113  }
114 
115  int SendTillSuccess(const string& msg, int flags) {
116  CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
117  int nbytes = 0;
118  do {
119  nbytes = Send(msg, flags);
120  } while (nbytes == 0);
121  return nbytes;
122  }
123 
124  int Recv(ZmqMessage* msg) {
125  int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
126  if (nbytes >= 0) {
127  return nbytes;
128  } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
129  return 0;
130  } else {
131  LOG(FATAL) << "Cannot receive zmq message. Error number: "
132  << zmq_errno();
133  return 0;
134  }
135  }
136 
137  int RecvTillSuccess(ZmqMessage* msg) {
138  int nbytes = 0;
139  do {
140  nbytes = Recv(msg);
141  } while (nbytes == 0);
142  return nbytes;
143  }
144 
145  private:
146  ZmqContext context_;
147  void* ptr_;
148 };
149 
150 } // namespace caffe2
151 
152 
153 #endif // CAFFE2_UTILS_ZMQ_HELPER_H_
Copyright (c) 2016-present, Facebook, Inc.