// NOTE(review): this excerpt has the original file's line numbers fused into
// the code text and several source lines missing, so it does not compile as
// shown; the comments below describe only what the visible lines do.
//
// `managers`: registry of ClientSocket connections keyed by manager handle
// string. start_manager() inserts into it and get_manager_socket() looks
// handles up in it (inserting on a miss).
3 #include <unordered_map> 6 #include <libshm/err.h> 7 #include <libshm/socket.h> 8 #include <libshm/libshm.h> 10 std::unordered_map<std::string, ClientSocket> managers;
// Path of the torch_shm_manager executable: assigned by libshm_init() and
// passed to execl() by start_manager().
11 std::string manager_executable_path;
// Fills the `filename` field of an AllocInfo (`info`) from `filename`.
// Throws std::runtime_error("THMapAllocatorContext_filename too long") when the
// name plus its NUL terminator does not fit in info.filename.
// NOTE(review): the declaration of `info`, any other field assignments, the
// closing brace of the length check and the return statement are missing from
// this excerpt.
13 AllocInfo get_alloc_info(
const char* filename) {
17 size_t len = strlen(filename);
// `>=` (not `>`) reserves one byte for the terminating NUL copied below.
18 if (len >=
sizeof(info.filename)) {
19 throw std::runtime_error(
"THMapAllocatorContext_filename too long");
// Copy len + 1 bytes so the NUL terminator is included.
21 memcpy(info.filename, filename, len + 1);
// Launches the torch_shm_manager executable and registers a connection to it.
//
// Visible steps:
//   1. Create a pipe and fork().
//   2. In one branch (presumably the child; the pid test is not visible),
//      make the pipe's write end stdout (fd 1) and execl() the binary at
//      manager_executable_path.
//   3. In the other branch, close the write end and read the manager's
//      output into `handle` until EOF or a trailing '\n'.
//   4. Store `manager` in `managers` under `handle`.
//
// Error handling:
//   - Throws std::runtime_error if nothing was read from the pipe.
//   - Throws std::exception if `handle` equals "ERROR".
//   - Each syscall is wrapped in SYSCHECK_ERR_RETURN_NEG1 (libshm/err.h);
//     its failure behavior is defined there.
//
// NOTE(review): this excerpt is missing the declarations of pipe_ends, pid,
// buffer, handle and manager, the branch and loop headers, and the closing
// braces.
25 void start_manager() {
27 SYSCHECK_ERR_RETURN_NEG1(pipe(pipe_ends));
30 SYSCHECK_ERR_RETURN_NEG1(pid = fork());
// Exec side: close the unused read end, then make the write end this
// process's stdout so whatever the manager prints goes into the pipe.
32 SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
33 SYSCHECK_ERR_RETURN_NEG1(dup2(pipe_ends[1], 1));
34 SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
35 execl(manager_executable_path.c_str(),
"torch_shm_manager", NULL);
// Reader side: drop our copy of the write end. A pipe read() only returns 0
// (EOF) once no process holds the write end open.
38 SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[1]));
44 SYSCHECK_ERR_RETURN_NEG1(bytes_read = read(pipe_ends[0], buffer,
sizeof(buffer)));
45 handle.append(buffer, bytes_read);
// Stop at EOF, or once the accumulated output ends with a newline.
// The `bytes_read == 0` test short-circuits, so `handle` is non-empty
// whenever its last character is examined.
46 if (bytes_read == 0 || handle[handle.length() - 1] ==
'\n') {
50 SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0]));
// Reading nothing from the pipe is treated as failure to execute the manager.
// Part of the message is built on lines missing from this excerpt.
51 if (handle.length() == 0) {
52 std::string msg(
"error executing torch_shm_manager at \"");
53 msg += manager_executable_path;
55 throw std::runtime_error(msg);
// NOTE(review): the read loop stops on a trailing '\n'. If the missing lines
// do not strip it, "ERROR\n" would never equal "ERROR". A bare std::exception
// also carries no message; a std::runtime_error quoting the manager's output
// would be easier to diagnose.
59 if (handle ==
"ERROR")
60 throw std::exception();
63 managers.emplace(std::move(handle), std::move(manager));
// Returns a reference to the ClientSocket for `manager_handle`.
// On a miss in `managers`, a new entry is inserted and a reference to the
// stored element is returned.
// NOTE(review): the construction of `socket` and the hit path (presumably
// returning it->second) are missing from this excerpt.
66 ClientSocket& get_manager_socket(
const std::string& manager_handle) {
67 auto it = managers.find(manager_handle);
68 if (it == managers.end()) {
70 auto result = managers.emplace(manager_handle, std::move(socket));
// Return the element owned by the map, not the moved-from local.
71 return result.first->second;
// Records the path of the torch_shm_manager executable that start_manager()
// later passes to execl().
// `manager_exec_path` must be a valid NUL-terminated string: constructing a
// std::string from a null const char* is undefined behavior.
77 void libshm_init(
const char *manager_exec_path) {
78 manager_executable_path = std::string(manager_exec_path);
// Selects a manager connection and registers an allocation for `filename`
// with it via register_allocation().
//
// How the manager is chosen:
//   - A null `manager_handle` is stored as "".
//   - A non-empty handle is resolved through get_manager_socket().
//   - Otherwise (presumably the else-branch, whose header is not visible) the
//     first entry of `managers` is used and manager_handle_ is updated to its
//     key. An empty registry is checked first; what happens then is not
//     visible in this excerpt.
//
// A catch(std::exception&) handler follows the body. Its try and its handler
// body are not visible here.
81 THManagedMapAllocatorInit::THManagedMapAllocatorInit(
const char* manager_handle,
const char* filename)
82 : manager_handle_(manager_handle ? manager_handle :
"") {
86 if (!manager_handle_.empty()) {
87 socket = &get_manager_socket(manager_handle_);
89 if (managers.size() == 0) {
// NOTE(review): this binds a const reference to the temporary iterator
// returned by begin(). Lifetime extension keeps it valid, but
// `auto it = managers.begin();` would be clearer. Dereferencing it below
// assumes `managers` is non-empty at this point.
92 const auto &manager = managers.begin();
93 manager_handle_ = manager->first;
94 socket = &manager->second;
96 AllocInfo info = get_alloc_info(filename);
97 socket->register_allocation(info);
98 }
catch(std::exception &e) {
// Constructor taking a manager handle, a file name, mapping flags and a size.
// NOTE(review): the member-initializer list and body are missing from this
// excerpt, so what it does with these arguments cannot be described here.
103 THManagedMapAllocator::THManagedMapAllocator(
const char *manager_handle,
const char *filename,
int flags, ptrdiff_t size)
// Closes this allocator and reports the deallocation to its manager, in order:
//   1. Build an AllocInfo for filename().
//   2. Look up the manager socket for manager_handle_.
//   3. Call THRefcountedMapAllocator::close().
//   4. Send register_deallocation() with the AllocInfo.
// NOTE(review): some lines of this method (and its closing brace) are missing
// from this excerpt.
106 void THManagedMapAllocator::close() {
108 AllocInfo info = get_alloc_info(filename());
// The socket is resolved before the base close() runs.
110 ClientSocket &socket = get_manager_socket(manager_handle_);
111 THRefcountedMapAllocator::close();
112 socket.register_deallocation(info);
// Deleter installed in the at::DataPtr built by makeDataPtr().
// `ptr` must be a THManagedMapAllocator created with `new`; it is deleted here.
115 static void deleteTHManagedMapAllocator(
void* ptr) {
116 delete static_cast<THManagedMapAllocator*
>(ptr);
// Heap-allocates a THManagedMapAllocator and wraps its data() in a CPU
// at::DataPtr. The DataPtr holds the allocator as its context and uses
// deleteTHManagedMapAllocator as its deleter, so releasing the DataPtr deletes
// the allocator.
119 at::DataPtr THManagedMapAllocator::makeDataPtr(
const char* manager_handle,
const char* filename,
int flags, ptrdiff_t size) {
120 auto* context =
new THManagedMapAllocator(manager_handle, filename, flags, size);
121 return {context->data(), context, &deleteTHManagedMapAllocator, at::DeviceType::CPU};
// Recovers the THManagedMapAllocator context stored in `dptr`.
// cast_context is given deleteTHManagedMapAllocator, presumably so that it
// only yields a context for DataPtrs created by makeDataPtr(). Confirm the
// mismatch behavior (e.g. nullptr) in the at::DataPtr::cast_context docs.
124 THManagedMapAllocator* THManagedMapAllocator::fromDataPtr(
const at::DataPtr& dptr) {
125 return dptr.cast_context<THManagedMapAllocator>(&deleteTHManagedMapAllocator);