| #include <sys/mman.h> |
| #include <poll.h> |
| // NOLINTNEXTLINE(modernize-deprecated-headers) |
| #include <errno.h> |
| #include <unistd.h> |
| #include <fcntl.h> |
| #include <vector> |
| #include <set> |
| #include <algorithm> |
| #include <memory> |
| #include <unordered_map> |
| |
| #include <c10/util/tempfile.h> |
| |
| #include <libshm/err.h> |
| #include <libshm/socket.h> |
| |
// How long, in milliseconds, the manager keeps polling while no client is
// connected before it shuts down.
constexpr int SHUTDOWN_TIMEOUT = 2000; // 2s

// DEBUG(fmt, ...) prints a printf-style message plus a newline to stderr in
// bold red when DEBUG_LOG is defined, and expands to a no-op expression
// otherwise. Neither expansion carries its own semicolon, so `DEBUG(...);` is
// always exactly one statement and is safe inside an unbraced if/else.
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// The newline travels as a trailing "%c" argument so that __VA_ARGS__ is never
// empty, even for DEBUG("message") without format arguments. The helper avoids
// a double-underscore name because such identifiers are reserved.
#define DEBUG_IMPL(msg, ...) fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
| |
| struct ClientSession { |
| ClientSession(ManagerSocket s): socket(std::move(s)), pid(0) {} |
| |
| ManagerSocket socket; |
| pid_t pid; |
| }; |
| |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::vector<struct pollfd> pollfds; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::unordered_map<int, ClientSession> client_sessions; |
| // TODO: check if objects have been freed from time to time |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| std::set<std::string> used_objects; |
| |
| |
| void register_fd(int fd) { |
| struct pollfd pfd = {0}; |
| pfd.fd = fd; |
| pfd.events = POLLIN; |
| pollfds.push_back(pfd); |
| } |
| |
| |
| void unregister_fd(int fd) { |
| pollfds.erase( |
| std::remove_if(pollfds.begin(), pollfds.end(), |
| [fd](const struct pollfd &pfd) { return pfd.fd == fd; }), |
| pollfds.end()); |
| client_sessions.erase(fd); |
| } |
| |
| |
// Writes all `len` bytes of `data` to `fd`. Resumes after short writes and
// retries when write() is interrupted by a signal (EINTR). Returns false when
// write() fails with any other error or makes no progress.
static bool write_fully(int fd, const char *data, size_t len) {
  while (len > 0) {
    const ssize_t written = write(fd, data, len);
    if (written < 0) {
      if (errno == EINTR) {
        continue;
      }
      return false;
    }
    if (written == 0) {
      return false;  // no progress; give up rather than spin forever
    }
    data += written;
    len -= static_cast<size_t>(written);
  }
  return true;
}

// Writes `message` followed by a newline to standard output using raw
// write() calls (no stdio buffering). The write is best effort: failures are
// not reported, because there is no other channel to report them on.
//
// message: NUL-terminated string; must not be null (strlen is applied to it).
void print_init_message(const char *message) {
  (void)write_fully(STDOUT_FILENO, message, strlen(message));
  (void)write_fully(STDOUT_FILENO, "\n", 1);
}
| |
// Reports whether a POSIX shared-memory object called `name` currently exists.
//
// The check opens the object read-only and closes it again immediately. Only
// an ENOENT failure is taken as proof that the object is gone. Any other
// failure (for example EACCES or EMFILE) says nothing about existence, so the
// object is conservatively reported as still present.
//
// name: NUL-terminated shared-memory object name, passed to shm_open() as is.
// Returns true if the object exists or its absence could not be established.
bool object_exists(const char *name) {
  const int fd = shm_open(name, O_RDONLY, 0);
  if (fd >= 0) {
    close(fd);
    return true;
  }
  // errno is read before any other call can overwrite it.
  return errno != ENOENT;
}
| |
| void free_used_object(const std::string &name) { |
| if (!object_exists(name.c_str())) { |
| DEBUG("object %s appears to have been freed", name.c_str()); |
| used_objects.erase(name); |
| } else { |
| DEBUG("object %s still exists", name.c_str()); |
| } |
| } |
| |
| // NOLINTNEXTLINE(bugprone-exception-escape) |
| int main(int argc, char *argv[]) { |
| setsid(); // Daemonize the process |
| |
| std::unique_ptr<ManagerServerSocket> srv_socket; |
| const auto tempfile = |
| c10::try_make_tempfile(/*name_prefix=*/"torch-shm-file-"); |
| try { |
| if (!tempfile.has_value()) { |
| throw std::runtime_error( |
| "could not generate a random filename for manager socket"); |
| } |
| // TODO: better strategy for generating tmp names |
| // TODO: retry on collisions - this can easily fail |
| // NOLINTNEXTLINE(modernize-make-unique) |
| srv_socket.reset(new ManagerServerSocket(tempfile->name)); |
| register_fd(srv_socket->socket_fd); |
| print_init_message(tempfile->name.c_str()); |
| DEBUG("opened socket %s", tempfile->name.c_str()); |
| } catch (...) { |
| print_init_message("ERROR"); |
| throw; |
| } |
| |
| int timeout = -1; |
| std::vector<int> to_add; |
| std::vector<int> to_remove; |
| for (;;) { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int nevents; |
| if (client_sessions.size() == 0) |
| timeout = SHUTDOWN_TIMEOUT; |
| SYSCHECK_ERR_RETURN_NEG1(nevents = poll(pollfds.data(), pollfds.size(), timeout)); |
| timeout = -1; |
| if (nevents == 0 && client_sessions.size() == 0) |
| break; |
| |
| for (auto &pfd: pollfds) { |
| if (pfd.revents & (POLLERR | POLLHUP)) { |
| // some process died |
| DEBUG("detaching process"); |
| auto &session = client_sessions.at(pfd.fd); |
| (void) session; |
| DEBUG("%d has died", session.pid); |
| to_remove.push_back(pfd.fd); |
| } else if (pfd.revents & POLLIN) { |
| if (pfd.fd == srv_socket->socket_fd) { |
| // someone is joining |
| DEBUG("registered new client"); |
| auto client = srv_socket->accept(); |
| int fd = client.socket_fd; |
| to_add.push_back(fd); |
| client_sessions.emplace(fd, std::move(client)); |
| } else { |
| // someone wants to register a segment |
| DEBUG("got alloc info"); |
| auto &session = client_sessions.at(pfd.fd); |
| AllocInfo info = session.socket.receive(); |
| session.pid = info.pid; |
| DEBUG("got alloc info: %d %d %s", (int)info.free, info.pid, info.filename); |
| if (info.free) { |
| free_used_object(info.filename); |
| } else { |
| used_objects.insert(info.filename); |
| DEBUG("registered object %s", info.filename); |
| session.socket.confirm(); |
| } |
| } |
| } |
| } |
| |
| for (int fd: to_add) |
| register_fd(fd); |
| to_add.clear(); |
| |
| for (int fd: to_remove) |
| unregister_fd(fd); |
| to_remove.clear(); |
| } |
| |
| for (auto &obj_name: used_objects) { |
| DEBUG("freeing %s", obj_name.c_str()); |
| shm_unlink(obj_name.c_str()); |
| } |
| |
| DEBUG("manager done"); |
| return 0; |
| } |