Caffe2 - C++ API
A deep learning, cross platform ML framework
manager.cpp
1 #include <sys/mman.h>
2 #include <poll.h>
3 #include <errno.h>
4 #include <unistd.h>
5 #include <fcntl.h>
6 #include <vector>
7 #include <set>
8 #include <algorithm>
9 #include <memory>
10 #include <unordered_map>
11 
12 #include <c10/util/tempfile.h>
13 
14 #include <libshm/err.h>
15 #include <libshm/socket.h>
16 
// Timeout, in milliseconds, passed to poll() while no client is connected.
// If it expires without any event the manager shuts down (2 seconds).
constexpr int SHUTDOWN_TIMEOUT = 2000;
18 
// Debug logging.
// With DEBUG_LOG defined, DEBUG(fmt, ...) prints a printf-style message to
// stderr, wrapped in the ANSI "bold red" escape sequence and followed by a
// newline. Without it, DEBUG(...) expands to a no-op expression.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// The newline is supplied by DEBUG as a trailing argument for the extra "%c",
// which guarantees __VA_ARGS__ is never empty even for plain messages.
// The expansion deliberately has no trailing semicolon so that `DEBUG(...);`
// is exactly one statement and stays safe inside an unbraced if/else.
// (The helper name avoids a leading double underscore, which is reserved.)
#define SHM_MANAGER_DEBUG_IMPL(msg, ...) \
  (void)fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) SHM_MANAGER_DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
27 
28 struct ClientSession {
29  ClientSession(ManagerSocket s): socket(std::move(s)), pid(0) {}
30 
31  ManagerSocket socket;
32  pid_t pid;
33 };
34 
35 
36 std::vector<struct pollfd> pollfds;
37 std::unordered_map<int, ClientSession> client_sessions;
38 // TODO: check if objects have been freed from time to time
39 std::set<std::string> used_objects;
40 
41 
42 void register_fd(int fd) {
43  struct pollfd pfd = {0};
44  pfd.fd = fd;
45  pfd.events = POLLIN;
46  pollfds.push_back(pfd);
47 }
48 
49 
50 void unregister_fd(int fd) {
51  pollfds.erase(
52  std::remove_if(pollfds.begin(), pollfds.end(),
53  [fd](const struct pollfd &pfd) { return pfd.fd == fd; }),
54  pollfds.end());
55  client_sessions.erase(fd);
56 }
57 
58 
// Writes `length` bytes from `data` to `fd`, resuming after short writes and
// retrying when the call is interrupted by a signal (EINTR). Any other error
// ends the attempt silently: the output is best effort.
static void write_all_best_effort(int fd, const char *data, size_t length) {
  while (length > 0) {
    const ssize_t written = write(fd, data, length);
    if (written < 0) {
      if (errno == EINTR) {
        continue;
      }
      return;
    }
    if (written == 0) {
      // No progress was made; stop instead of spinning forever.
      return;
    }
    data += written;
    length -= static_cast<size_t>(written);
  }
}

// Writes `message` followed by a newline to standard output.
// main() uses this to report either the manager socket path or "ERROR";
// presumably the process that launched the manager reads this line.
// The message and the newline are sent together so a reader does not observe
// the text without its line terminator.
void print_init_message(const char *message) {
  std::string line(message);
  line.push_back('\n');
  write_all_best_effort(STDOUT_FILENO, line.data(), line.size());
}
64 
// Returns true when the shared-memory object `name` can be opened read-only.
// Any failure of shm_open() is reported as "does not exist".
bool object_exists(const char *name) {
  const int handle = shm_open(name, O_RDONLY, 0);
  if (handle < 0) {
    return false;
  }
  // The descriptor was only needed as an existence probe.
  close(handle);
  return true;
}
74 
75 void free_used_object(const std::string &name) {
76  if (!object_exists(name.c_str())) {
77  DEBUG("object %s appears to have been freed", name.c_str());
78  used_objects.erase(name);
79  } else {
80  DEBUG("object %s still exists", name.c_str());
81  }
82 }
83 
84 int main(int argc, char *argv[]) {
85  setsid(); // Daemonize the process
86 
87  std::unique_ptr<ManagerServerSocket> srv_socket;
88  const auto tempfile =
89  c10::try_make_tempfile(/*name_prefix=*/"torch-shm-file-");
90  try {
91  if (!tempfile.has_value()) {
92  throw std::runtime_error(
93  "could not generate a random filename for manager socket");
94  }
95  // TODO: better strategy for generating tmp names
96  // TODO: retry on collisions - this can easily fail
97  srv_socket.reset(new ManagerServerSocket(tempfile->name));
98  register_fd(srv_socket->socket_fd);
99  print_init_message(tempfile->name.c_str());
100  DEBUG("opened socket %s", tempfile->name.c_str());
101  } catch (...) {
102  print_init_message("ERROR");
103  throw;
104  }
105 
106  int timeout = -1;
107  std::vector<int> to_add;
108  std::vector<int> to_remove;
109  for (;;) {
110  int nevents;
111  if (client_sessions.size() == 0)
112  timeout = SHUTDOWN_TIMEOUT;
113  SYSCHECK_ERR_RETURN_NEG1(nevents = poll(pollfds.data(), pollfds.size(), timeout));
114  timeout = -1;
115  if (nevents == 0 && client_sessions.size() == 0)
116  break;
117 
118  for (auto &pfd: pollfds) {
119  if (pfd.revents & (POLLERR | POLLHUP)) {
120  // some process died
121  DEBUG("detaching process");
122  auto &session = client_sessions.at(pfd.fd);
123  DEBUG("%d has died", session.pid);
124  to_remove.push_back(pfd.fd);
125  } else if (pfd.revents & POLLIN) {
126  if (pfd.fd == srv_socket->socket_fd) {
127  // someone is joining
128  DEBUG("registered new client");
129  auto client = srv_socket->accept();
130  int fd = client.socket_fd;
131  to_add.push_back(fd);
132  client_sessions.emplace(fd, std::move(client));
133  } else {
134  // someone wants to register a segment
135  DEBUG("got alloc info");
136  auto &session = client_sessions.at(pfd.fd);
137  AllocInfo info = session.socket.receive();
138  session.pid = info.pid;
139  DEBUG("got alloc info: %d %d %s", (int)info.free, info.pid, info.filename);
140  if (info.free) {
141  free_used_object(info.filename);
142  } else {
143  used_objects.insert(info.filename);
144  DEBUG("registered object %s", info.filename);
145  session.socket.confirm();
146  }
147  }
148  }
149  }
150 
151  for (int fd: to_add)
152  register_fd(fd);
153  to_add.clear();
154 
155  for (int fd: to_remove)
156  unregister_fd(fd);
157  to_remove.clear();
158  }
159 
160  for (auto &obj_name: used_objects) {
161  DEBUG("freeing %s", obj_name.c_str());
162  shm_unlink(obj_name.c_str());
163  }
164 
165  DEBUG("manager done");
166  return 0;
167 }
c10::optional< TempFile > try_make_tempfile(std::string name_prefix="torch-file-")
Attempts to return a temporary file or returns nullopt if an error occurred.
Definition: tempfile.h:81