// SHUTDOWN_TIMEOUT is the timeout main() hands to poll() while no client
// sessions are registered (poll() timeouts are expressed in milliseconds).
10 #include <unordered_map> 12 #include <c10/util/tempfile.h> 14 #include <libshm/err.h> 15 #include <libshm/socket.h> 17 const int SHUTDOWN_TIMEOUT = 2000;
// DEBUG(...) either prints the formatted message to stderr wrapped in the
// COLOR/RESET escape sequences and followed by '\n', or expands to (void)0.
// NOTE(review): both definitions appear here; presumably a preprocessor
// conditional that is not visible in this view selects one of them.
// pollfds is the descriptor list that main() passes to poll().
20 #define COLOR "\033[31;1m" 21 #define RESET "\033[0m" 22 #define __DEBUG(msg, ...) fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__); 23 #define DEBUG(...) __DEBUG(__VA_ARGS__, '\n') 25 #define DEBUG(...) (void)0 36 std::vector<struct pollfd> pollfds;
// Per-client state, keyed by the client's socket file descriptor.
37 std::unordered_map<int, ClientSession> client_sessions;
// Names of objects reported by clients; main() calls shm_unlink() on every
// name still present here before it finishes.
39 std::set<std::string> used_objects;
// Appends a pollfd entry to the global `pollfds` list so that main()'s poll()
// call will watch it.
//
// fd: the file descriptor to start watching.
//
// NOTE(review): the entry is zero-initialised below; the statements that fill
// in its fields are not visible in this view -- presumably they set `fd` and
// the requested events. Confirm against the full source.
42 void register_fd(
int fd) {
43 struct pollfd pfd = {0};
46 pollfds.push_back(pfd);
// Stops tracking `fd`: targets every `pollfds` entry whose descriptor equals
// `fd` and drops the matching entry from `client_sessions`.
//
// fd: the file descriptor to stop watching.
//
// NOTE(review): std::remove_if only moves the matching entries to the tail of
// the range; the trailing comma after the lambda suggests its result is an
// argument of an enclosing call (presumably pollfds.erase(...)) that is not
// visible in this view -- confirm.
50 void unregister_fd(
int fd) {
52 std::remove_if(pollfds.begin(), pollfds.end(),
53 [fd](
const struct pollfd &pfd) {
return pfd.fd == fd; }),
// Erasing a key that has no session is a no-op for unordered_map::erase.
55 client_sessions.erase(fd);
// Writes `message` followed by a single newline to file descriptor 1
// (standard output) using raw write() calls.
//
// message: NUL-terminated string; its length is taken with strlen().
//
// The write() results are stored in `unused` (declared outside the visible
// lines) and are not inspected in the code shown here.
59 void print_init_message(
const char *message) {
61 unused = write(1, message, strlen(message));
62 unused = write(1,
"\n", 1);
// Probes for a named shared-memory object by trying to open it read-only
// with shm_open().
//
// name: the shared-memory object name to probe.
//
// NOTE(review): how the returned descriptor is turned into the bool result
// (and whether it is closed) is not visible in this view -- confirm against
// the full source.
65 bool object_exists(
const char *name) {
66 int fd = shm_open(name, O_RDONLY, 0);
// Drops `name` from `used_objects` once the underlying object can no longer
// be found via object_exists().
//
// name: the object name a client reported as freed.
75 void free_used_object(
const std::string &name) {
76 if (!object_exists(name.c_str())) {
77 DEBUG(
"object %s appears to have been freed", name.c_str());
78 used_objects.erase(name);
// NOTE(review): the message below presumably belongs to the else branch
// (object still present, name kept in used_objects); the `else` keyword is
// not visible in this view.
80 DEBUG(
"object %s still exists", name.c_str());
// Manager entry point. In the code visible here it:
//   1. registers the server socket's descriptor for polling and prints the
//      temporary file name on stdout (or "ERROR");
//   2. polls the registered descriptors, accepting new clients and handling
//      AllocInfo messages from existing ones;
//   3. calls shm_unlink() on every name left in `used_objects`.
// NOTE(review): several statements (tempfile/socket creation, the loop
// header, some branch conditions and the return) are not visible in this
// view; comments below hedge wherever they depend on them.
84 int main(
int argc,
char *argv[]) {
87 std::unique_ptr<ManagerServerSocket> srv_socket;
// No temporary file name could be obtained: report it as a runtime_error.
91 if (!tempfile.has_value()) {
92 throw std::runtime_error(
93 "could not generate a random filename for manager socket");
// Watch the server socket and announce the temp-file name on stdout.
98 register_fd(srv_socket->socket_fd);
99 print_init_message(tempfile->name.c_str());
100 DEBUG(
"opened socket %s", tempfile->name.c_str());
// Prints "ERROR" in place of the name -- presumably on a failure path
// (e.g. a catch block) that is not visible here.
102 print_init_message(
"ERROR");
// Descriptors to register / unregister are collected in these vectors while
// `pollfds` is being iterated, rather than changing `pollfds` in place.
107 std::vector<int> to_add;
108 std::vector<int> to_remove;
// With no clients connected, poll() waits at most SHUTDOWN_TIMEOUT.
111 if (client_sessions.size() == 0)
112 timeout = SHUTDOWN_TIMEOUT;
113 SYSCHECK_ERR_RETURN_NEG1(nevents = poll(pollfds.data(), pollfds.size(), timeout));
// poll() reported no events and there are still no clients; the statement
// guarded by this condition is not visible -- presumably it leaves the loop.
115 if (nevents == 0 && client_sessions.size() == 0)
118 for (
auto &pfd: pollfds) {
// Error or hang-up on the descriptor: log the session's pid and queue the
// descriptor for removal.
119 if (pfd.revents & (POLLERR | POLLHUP)) {
121 DEBUG(
"detaching process");
// unordered_map::at throws std::out_of_range if no session exists for this fd.
122 auto &session = client_sessions.at(pfd.fd);
123 DEBUG(
"%d has died", session.pid);
124 to_remove.push_back(pfd.fd);
125 }
else if (pfd.revents & POLLIN) {
// Readable server socket: accept the client, queue its descriptor for
// registration and store its session keyed by that descriptor.
126 if (pfd.fd == srv_socket->socket_fd) {
128 DEBUG(
"registered new client");
129 auto client = srv_socket->accept();
// Copy the descriptor out before `client` is moved into the map.
130 int fd = client.socket_fd;
131 to_add.push_back(fd);
132 client_sessions.emplace(fd, std::move(client));
// Presumably the else branch (readable client socket): receive an AllocInfo
// message and remember the sender's pid on the session.
135 DEBUG(
"got alloc info");
136 auto &session = client_sessions.at(pfd.fd);
137 AllocInfo info = session.socket.receive();
138 session.pid = info.pid;
139 DEBUG(
"got alloc info: %d %d %s", (
int)info.free, info.pid, info.filename);
// NOTE(review): the condition choosing between freeing and registering the
// object is not visible here -- presumably it tests info.free.
141 free_used_object(info.filename);
143 used_objects.insert(info.filename);
144 DEBUG(
"registered object %s", info.filename);
// Acknowledge the message to the client.
145 session.socket.confirm();
// Process the queued removals; the loop body is not visible here --
// presumably it calls unregister_fd(fd).
155 for (
int fd: to_remove)
// Unlink every shared-memory object name still recorded in used_objects.
160 for (
auto &obj_name: used_objects) {
161 DEBUG(
"freeing %s", obj_name.c_str());
162 shm_unlink(obj_name.c_str());
165 DEBUG(
"manager done");
c10::optional< TempFile > try_make_tempfile(std::string name_prefix="torch-file-")
Attempts to return a temporary file or returns nullopt if an error occurred.