Caffe2 - C++ API
A deep learning, cross-platform ML framework
Utils.cpp
#include <c10d/Utils.hpp>

#include <netdb.h>
#include <sys/poll.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>

#include <fcntl.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <thread>
18 
19 namespace c10d {
20 namespace tcputil {
21 
22 namespace {
23 
24 constexpr int LISTEN_QUEUE_SIZE = 64;
25 
26 void setSocketNoDelay(int socket) {
27  int flag = 1;
28  socklen_t optlen = sizeof(flag);
29  SYSCHECK_ERR_RETURN_NEG1(setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen));
30 }
31 
32 PortType getSocketPort(int fd) {
33  PortType listenPort;
34  struct ::sockaddr_storage addrStorage;
35  socklen_t addrLen = sizeof(addrStorage);
36  SYSCHECK_ERR_RETURN_NEG1(getsockname(
37  fd, reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen));
38 
39  if (addrStorage.ss_family == AF_INET) {
40  struct ::sockaddr_in* addr =
41  reinterpret_cast<struct ::sockaddr_in*>(&addrStorage);
42  listenPort = ntohs(addr->sin_port);
43 
44  } else if (addrStorage.ss_family == AF_INET6) { // AF_INET6
45  struct ::sockaddr_in6* addr =
46  reinterpret_cast<struct ::sockaddr_in6*>(&addrStorage);
47  listenPort = ntohs(addr->sin6_port);
48 
49  } else {
50  throw std::runtime_error("unsupported protocol");
51  }
52  return listenPort;
53 }
54 
55 } // namespace
56 
57 std::string sockaddrToString(struct ::sockaddr* addr) {
58  char address[INET6_ADDRSTRLEN + 1];
59  if (addr->sa_family == AF_INET) {
60  struct ::sockaddr_in* s = reinterpret_cast<struct ::sockaddr_in*>(addr);
61  SYSCHECK(::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN), __output != nullptr)
62  address[INET_ADDRSTRLEN] = '\0';
63  } else if (addr->sa_family == AF_INET6) {
64  struct ::sockaddr_in6* s = reinterpret_cast<struct ::sockaddr_in6*>(addr);
65  SYSCHECK(::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN), __output != nullptr)
66  address[INET6_ADDRSTRLEN] = '\0';
67  } else {
68  throw std::runtime_error("unsupported protocol");
69  }
70  return address;
71 }
72 
73 // listen, connect and accept
74 std::pair<int, PortType> listen(PortType port) {
75  struct ::addrinfo hints, *res = NULL;
76  std::memset(&hints, 0x00, sizeof(hints));
77  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
78  hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
79  hints.ai_socktype = SOCK_STREAM; // TCP
80 
81  // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
82  // by editing `/etc/gai.conf`. so there is no need to manual sorting
83  // or protocol preference.
84  int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res);
85  if (err != 0 || !res) {
86  throw std::invalid_argument(
87  "cannot find host to listen on: " + std::string(gai_strerror(err)));
88  }
89 
90  std::shared_ptr<struct ::addrinfo> addresses(
91  res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
92 
93  struct ::addrinfo* nextAddr = addresses.get();
94  int socket;
95  while (true) {
96  try {
97  SYSCHECK_ERR_RETURN_NEG1(
98  socket = ::socket(
99  nextAddr->ai_family,
100  nextAddr->ai_socktype,
101  nextAddr->ai_protocol))
102 
103  int optval = 1;
104  SYSCHECK_ERR_RETURN_NEG1(
105  ::setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(int)))
106 
107  SYSCHECK_ERR_RETURN_NEG1(::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen))
108  SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE))
109  break;
110 
111  } catch (const std::system_error& e) {
112  ::close(socket);
113  nextAddr = nextAddr->ai_next;
114 
115  // we have tried all addresses but could not start
116  // listening on any of them
117  if (!nextAddr) {
118  throw;
119  }
120  }
121  }
122 
123  // get listen port and address
124  return {socket, getSocketPort(socket)};
125 }
126 
127 int connect(
128  const std::string& address,
129  PortType port,
130  bool wait,
131  const std::chrono::milliseconds& timeout) {
132  struct ::addrinfo hints, *res = NULL;
133  std::memset(&hints, 0x00, sizeof(hints));
134  hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric
135  hints.ai_family = AF_UNSPEC; // either IPv4 or IPv6
136  hints.ai_socktype = SOCK_STREAM; // TCP
137 
138  // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked
139  // by editing `/etc/gai.conf`. so there is no need to manual sorting
140  // or protcol preference.
141  int err =
142  ::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res);
143  if (err != 0 || !res) {
144  throw std::invalid_argument(
145  "host not found: " + std::string(gai_strerror(err)));
146  }
147 
148  std::shared_ptr<struct ::addrinfo> addresses(
149  res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); });
150 
151  struct ::addrinfo* nextAddr = addresses.get();
152  int socket;
153  // we'll loop over the addresses only if at least of them gave us ECONNREFUSED
154  // Maybe the host was up, but the server wasn't running.
155  bool anyRefused = false;
156  while (true) {
157  try {
158  SYSCHECK_ERR_RETURN_NEG1(
159  socket = ::socket(
160  nextAddr->ai_family,
161  nextAddr->ai_socktype,
162  nextAddr->ai_protocol))
163 
164  ResourceGuard socketGuard([socket]() { ::close(socket); });
165 
166  // We need to connect in non-blocking mode, so we can use a timeout
167  SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, O_NONBLOCK));
168 
169  int ret = ::connect(socket, nextAddr->ai_addr, nextAddr->ai_addrlen);
170 
171  if (ret != 0 && errno != EINPROGRESS) {
172  throw std::system_error(errno, std::system_category());
173  }
174 
175  struct ::pollfd pfd;
176  pfd.fd = socket;
177  pfd.events = POLLOUT;
178 
179  int numReady = ::poll(&pfd, 1, timeout.count());
180  if (numReady < 0) {
181  throw std::system_error(errno, std::system_category());
182  } else if (numReady == 0) {
183  errno = 0;
184  throw std::runtime_error("connect() timed out");
185  }
186 
187  socklen_t errLen = sizeof(errno);
188  errno = 0;
189  ::getsockopt(socket, SOL_SOCKET, SO_ERROR, &errno, &errLen);
190 
191  // `errno` is set when:
192  // 1. `getsockopt` has failed
193  // 2. there is awaiting error in the socket
194  // (the error is saved to the `errno` variable)
195  if (errno != 0) {
196  throw std::system_error(errno, std::system_category());
197  }
198 
199  // Disable non-blocking mode
200  int flags;
201  SYSCHECK_ERR_RETURN_NEG1(flags = ::fcntl(socket, F_GETFL));
202  SYSCHECK_ERR_RETURN_NEG1(::fcntl(socket, F_SETFL, flags & (~O_NONBLOCK)));
203  socketGuard.release();
204  break;
205 
206  } catch (std::exception& e) {
207  if (errno == ECONNREFUSED) {
208  anyRefused = true;
209  }
210 
211  // We need to move to the next address because this was not available
212  // to connect or to create a socket.
213  nextAddr = nextAddr->ai_next;
214 
215  // We have tried all addresses but could not connect to any of them.
216  if (!nextAddr) {
217  if (!wait || !anyRefused) {
218  throw;
219  }
220  std::this_thread::sleep_for(std::chrono::seconds(1));
221  anyRefused = false;
222  nextAddr = addresses.get();
223  }
224  }
225  }
226 
227  setSocketNoDelay(socket);
228 
229  return socket;
230 }
231 
232 std::tuple<int, std::string> accept(
233  int listenSocket,
234  const std::chrono::milliseconds& timeout) {
235  // poll on listen socket, it allows to make timeout
236  std::unique_ptr<struct ::pollfd[]> events(new struct ::pollfd[1]);
237  events[0] = {.fd = listenSocket, .events = POLLIN};
238 
239  while (true) {
240  int res = ::poll(events.get(), 1, timeout.count());
241  if (res == 0) {
242  throw std::runtime_error(
243  "waiting for processes to "
244  "connect has timed out");
245  } else if (res == -1) {
246  if (errno == EINTR) {
247  continue;
248  }
249  throw std::system_error(errno, std::system_category());
250  } else {
251  if (!(events[0].revents & POLLIN))
252  throw std::system_error(ECONNABORTED, std::system_category());
253  break;
254  }
255  }
256 
257  int socket;
258  SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL))
259 
260  // Get address of the connecting process
261  struct ::sockaddr_storage addr;
262  socklen_t addrLen = sizeof(addr);
263  SYSCHECK_ERR_RETURN_NEG1(::getpeername(
264  socket, reinterpret_cast<struct ::sockaddr*>(&addr), &addrLen))
265 
266  setSocketNoDelay(socket);
267 
268  return std::make_tuple(
269  socket, sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addr)));
270 }
271 
272 } // namespace tcputil
273 } // namespace c10d
Definition: ddp.cpp:21