Caffe2 - C++ API
A deep learning, cross-platform ML framework
socket.h
1 #pragma once
2 
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <poll.h>

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

#include <libshm/err.h>
#include <libshm/alloc_info.h>
16 
17 class Socket {
18 public:
19  int socket_fd;
20 
21 protected:
22  Socket() {
23  SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
24  }
25  Socket(const Socket& other) = delete;
26  Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; };
27  explicit Socket(int fd) : socket_fd(fd) {}
28 
29  virtual ~Socket() {
30  if (socket_fd != -1)
31  close(socket_fd);
32  }
33 
34  struct sockaddr_un prepare_address(const char *path) {
35  struct sockaddr_un address;
36  address.sun_family = AF_UNIX;
37  strcpy(address.sun_path, path);
38  return address;
39  }
40 
41  size_t address_length(struct sockaddr_un address) {
42  return strlen(address.sun_path) + sizeof(address.sun_family);
43  }
44 
45  void recv(void *_buffer, size_t num_bytes) {
46  char *buffer = (char*)_buffer;
47  size_t bytes_received = 0;
48  ssize_t step_received;
49  struct pollfd pfd = {0};
50  pfd.fd = socket_fd;
51  pfd.events = POLLIN;
52  while (bytes_received < num_bytes) {
53  SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
54  if (pfd.revents & POLLIN) {
55  SYSCHECK_ERR_RETURN_NEG1(step_received = ::read(socket_fd, buffer, num_bytes - bytes_received));
56  if (step_received == 0)
57  throw std::runtime_error("Other end has closed the connection");
58  bytes_received += step_received;
59  buffer += step_received;
60  } else if (pfd.revents & (POLLERR | POLLHUP)) {
61  throw std::runtime_error("An error occurred while waiting for the data");
62  } else {
63  throw std::runtime_error("Shared memory manager connection has timed out");
64  }
65  }
66  }
67 
68  void send(const void *_buffer, size_t num_bytes) {
69  const char *buffer = (const char*)_buffer;
70  size_t bytes_sent = 0;
71  ssize_t step_sent;
72  while (bytes_sent < num_bytes) {
73  SYSCHECK_ERR_RETURN_NEG1(step_sent = ::write(socket_fd, buffer, num_bytes));
74  bytes_sent += step_sent;
75  buffer += step_sent;
76  }
77  }
78 
79 
80 };
81 
82 class ManagerSocket: public Socket {
83 public:
84  explicit ManagerSocket(int fd): Socket(fd) {}
85 
86  AllocInfo receive() {
87  AllocInfo info;
88  recv(&info, sizeof(info));
89  return info;
90  }
91 
92  void confirm() {
93  send("OK", 2);
94  }
95 
96 };
97 
98 
99 class ManagerServerSocket: public Socket {
100 public:
101  explicit ManagerServerSocket(const std::string &path) {
102  socket_path = path;
103  try {
104  struct sockaddr_un address = prepare_address(path.c_str());
105  size_t len = address_length(address);
106  SYSCHECK_ERR_RETURN_NEG1(bind(socket_fd, (struct sockaddr *)&address, len));
107  SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
108  } catch(std::exception &e) {
109  SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
110  throw;
111  }
112  }
113 
114  virtual ~ManagerServerSocket() {
115  unlink(socket_path.c_str());
116  }
117 
118  ManagerSocket accept() {
119  int client_fd;
120  struct sockaddr_un addr;
121  socklen_t addr_len = sizeof(addr);
122  SYSCHECK_ERR_RETURN_NEG1(client_fd = ::accept(socket_fd, (struct sockaddr *)&addr, &addr_len));
123  return ManagerSocket(client_fd);
124  }
125 
126  std::string socket_path;
127 };
128 
129 class ClientSocket: public Socket {
130 public:
131  explicit ClientSocket(const std::string &path) {
132  try {
133  struct sockaddr_un address = prepare_address(path.c_str());
134  size_t len = address_length(address);
135  SYSCHECK_ERR_RETURN_NEG1(connect(socket_fd, (struct sockaddr *)&address, len));
136  } catch(std::exception &e) {
137  SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
138  throw;
139  }
140  }
141 
142  void register_allocation(AllocInfo &info) {
143  char buffer[3] = {0, 0, 0};
144  ssize_t bytes_read;
145  send(&info, sizeof(info));
146  recv(buffer, 2);
147  if (strcmp(buffer, "OK") != 0)
148  throw std::runtime_error("Shared memory manager didn't respond with an OK");
149  }
150 
151  void register_deallocation(AllocInfo &info) {
152  send(&info, sizeof(info));
153  }
154 
155 };
Definition: socket.h:17