| #ifdef _WIN32 |
| #include <c10d/WinSockUtils.hpp> |
| #else |
| #include <c10d/UnixSockUtils.hpp> |
| #include <netdb.h> |
| #include <sys/poll.h> |
| #include <arpa/inet.h> |
| #include <netinet/in.h> |
| #include <netinet/tcp.h> |
| #include <unistd.h> |
| #endif |
| |
| #include <algorithm> |
| #include <cstring> |
| #include <fcntl.h> |
| #include <memory> |
| #include <string> |
| #include <thread> |
| |
| namespace c10d { |
| namespace tcputil { |
| |
| namespace { |
| |
// Backlog argument passed to ::listen() in listen() below: the number of
// pending, not-yet-accepted connections requested from the OS (the OS may
// cap it lower).
constexpr int LISTEN_QUEUE_SIZE{2048};
| |
| void setSocketNoDelay(int socket) { |
| int flag = 1; |
| socklen_t optlen = sizeof(flag); |
| SYSCHECK_ERR_RETURN_NEG1( |
| setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, optlen)); |
| } |
| |
| PortType getSocketPort(int fd) { |
| PortType listenPort; |
| struct ::sockaddr_storage addrStorage; |
| socklen_t addrLen = sizeof(addrStorage); |
| SYSCHECK_ERR_RETURN_NEG1(getsockname( |
| fd, reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen)); |
| |
| if (addrStorage.ss_family == AF_INET) { |
| struct ::sockaddr_in* addr = |
| reinterpret_cast<struct ::sockaddr_in*>(&addrStorage); |
| listenPort = ntohs(addr->sin_port); |
| |
| } else if (addrStorage.ss_family == AF_INET6) { // AF_INET6 |
| struct ::sockaddr_in6* addr = |
| reinterpret_cast<struct ::sockaddr_in6*>(&addrStorage); |
| listenPort = ntohs(addr->sin6_port); |
| |
| } else { |
| throw std::runtime_error("unsupported protocol"); |
| } |
| return listenPort; |
| } |
| |
| } // namespace |
| |
| std::string getLocalSocketAddr(int fd) { |
| struct ::sockaddr_storage addrStorage; |
| socklen_t addrLen = sizeof(addrStorage); |
| SYSCHECK_ERR_RETURN_NEG1(getsockname( |
| fd, reinterpret_cast<struct ::sockaddr*>(&addrStorage), &addrLen)); |
| return sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addrStorage)); |
| } |
| |
| std::string sockaddrToString(struct ::sockaddr* addr) { |
| char address[INET6_ADDRSTRLEN + 1]; |
| if (addr->sa_family == AF_INET) { |
| struct ::sockaddr_in* s = reinterpret_cast<struct ::sockaddr_in*>(addr); |
| SYSCHECK( |
| ::inet_ntop(AF_INET, &(s->sin_addr), address, INET_ADDRSTRLEN), |
| __output != nullptr) |
| address[INET_ADDRSTRLEN] = '\0'; |
| } else if (addr->sa_family == AF_INET6) { |
| struct ::sockaddr_in6* s = reinterpret_cast<struct ::sockaddr_in6*>(addr); |
| SYSCHECK( |
| ::inet_ntop(AF_INET6, &(s->sin6_addr), address, INET6_ADDRSTRLEN), |
| __output != nullptr) |
| address[INET6_ADDRSTRLEN] = '\0'; |
| } else { |
| throw std::runtime_error("unsupported protocol"); |
| } |
| return address; |
| } |
| |
| // listen, connect and accept |
| std::pair<int, PortType> listen(PortType port) { |
| struct ::addrinfo hints, *res = NULL; |
| std::memset(&hints, 0x00, sizeof(hints)); |
| hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; |
| hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux |
| hints.ai_socktype = SOCK_STREAM; // TCP |
| |
| // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked |
| // by editing `/etc/gai.conf`. so there is no need to manual sorting |
| // or protocol preference. |
| int err = ::getaddrinfo(nullptr, std::to_string(port).data(), &hints, &res); |
| if (err != 0 || !res) { |
| throw std::invalid_argument( |
| "cannot find host to listen on: " + std::string(gai_strerror(err))); |
| } |
| |
| std::shared_ptr<struct ::addrinfo> addresses( |
| res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); }); |
| |
| struct ::addrinfo* nextAddr = addresses.get(); |
| int socket; |
| while (true) { |
| try { |
| SYSCHECK_ERR_RETURN_NEG1( |
| socket = ::socket( |
| nextAddr->ai_family, |
| nextAddr->ai_socktype, |
| nextAddr->ai_protocol)) |
| SYSCHECK_ERR_RETURN_NEG1(tcputil::setSocketAddrReUse(socket)) |
| SYSCHECK_ERR_RETURN_NEG1( |
| ::bind(socket, nextAddr->ai_addr, nextAddr->ai_addrlen)) |
| SYSCHECK_ERR_RETURN_NEG1(::listen(socket, LISTEN_QUEUE_SIZE)) |
| break; |
| |
| } catch (const std::system_error& e) { |
| tcputil::closeSocket(socket); |
| nextAddr = nextAddr->ai_next; |
| |
| // we have tried all addresses but could not start |
| // listening on any of them |
| if (!nextAddr) { |
| throw; |
| } |
| } |
| } |
| |
| // get listen port and address |
| return {socket, getSocketPort(socket)}; |
| } |
| |
| void handleConnectException( |
| struct ::addrinfo** nextAddr, |
| int error_code, |
| bool* anyRefused, |
| bool* anyReset, |
| bool wait, |
| std::chrono::time_point<std::chrono::high_resolution_clock> start, |
| std::shared_ptr<struct ::addrinfo> addresses, |
| std::chrono::milliseconds timeout) { |
| // ECONNREFUSED happens if the server is not yet listening. |
| if (error_code == ECONNREFUSED) { |
| *anyRefused = true; |
| } |
| // ECONNRESET happens if the server's listen backlog is exhausted. |
| if (error_code == ECONNRESET) { |
| *anyReset = true; |
| } |
| |
| // We need to move to the next address because this was not available |
| // to connect or to create a socket. |
| *nextAddr = (*nextAddr)->ai_next; |
| |
| // We have tried all addresses but could not connect to any of them. |
| if (!*nextAddr) { |
| if (!wait || (!anyRefused && !anyReset)) { |
| throw; |
| } |
| |
| // if a timeout is specified, check time elapsed to see if we need to |
| // timeout. A timeout is specified if timeout != kNoTimeout. |
| if (timeout != kNoTimeout) { |
| const auto elapsed = std::chrono::high_resolution_clock::now() - start; |
| if (elapsed > timeout) { |
| throw std::runtime_error(kConnectTimeoutMsg); |
| } |
| } |
| std::this_thread::sleep_for(std::chrono::seconds(1)); |
| *anyRefused = false; |
| *anyReset = false; |
| *nextAddr = addresses.get(); |
| } |
| } |
| |
| void handleConnectSystemError( |
| struct ::addrinfo** nextAddr, |
| std::system_error& e, |
| bool* anyRefused, |
| bool* anyReset, |
| bool wait, |
| std::chrono::time_point<std::chrono::high_resolution_clock> start, |
| std::shared_ptr<struct ::addrinfo> addresses, |
| std::chrono::milliseconds timeout) { |
| handleConnectException( |
| nextAddr, |
| e.code().value(), |
| anyRefused, |
| anyReset, |
| wait, |
| start, |
| addresses, |
| timeout); |
| } |
| |
| int connect( |
| const std::string& address, |
| PortType port, |
| bool wait, |
| const std::chrono::milliseconds& timeout) { |
| struct ::addrinfo hints, *res = NULL; |
| std::memset(&hints, 0x00, sizeof(hints)); |
| hints.ai_flags = AI_NUMERICSERV; // specifies that port (service) is numeric |
| hints.ai_family = AF_SELECTED; // IPv4 on Windows, IPv4/6 on Linux |
| hints.ai_socktype = SOCK_STREAM; // TCP |
| |
| // `getaddrinfo` will sort addresses according to RFC 3484 and can be tweeked |
| // by editing `/etc/gai.conf`. so there is no need to manual sorting |
| // or protcol preference. |
| int err = |
| ::getaddrinfo(address.data(), std::to_string(port).data(), &hints, &res); |
| if (err != 0 || !res) { |
| throw std::invalid_argument( |
| "host not found: " + std::string(gai_strerror(err))); |
| } |
| |
| std::shared_ptr<struct ::addrinfo> addresses( |
| res, [](struct ::addrinfo* p) { ::freeaddrinfo(p); }); |
| |
| struct ::addrinfo* nextAddr = addresses.get(); |
| int socket; |
| |
| // Loop over the addresses if at least one of them gave us ECONNREFUSED |
| // or ECONNRESET. This may happen if the server hasn't started listening |
| // yet, or is listening but has its listen backlog exhausted. |
| bool anyRefused = false; |
| bool anyReset = false; |
| const auto start = std::chrono::high_resolution_clock::now(); |
| while (true) { |
| try { |
| SYSCHECK_ERR_RETURN_NEG1( |
| socket = ::socket( |
| nextAddr->ai_family, |
| nextAddr->ai_socktype, |
| nextAddr->ai_protocol)) |
| |
| ResourceGuard socketGuard([socket]() { tcputil::closeSocket(socket); }); |
| |
| // We need to connect in non-blocking mode, so we can use a timeout |
| waitSocketConnected(socket, nextAddr, timeout, start); |
| |
| socketGuard.release(); |
| break; |
| |
| } catch (std::system_error& e) { |
| handleConnectSystemError( |
| &nextAddr, |
| e, |
| &anyRefused, |
| &anyReset, |
| wait, |
| start, |
| addresses, |
| timeout); |
| } catch (std::exception& e) { |
| handleConnectException( |
| &nextAddr, |
| errno, |
| &anyRefused, |
| &anyReset, |
| wait, |
| start, |
| addresses, |
| timeout); |
| } |
| } |
| |
| setSocketNoDelay(socket); |
| |
| return socket; |
| } |
| |
| std::tuple<int, std::string> accept( |
| int listenSocket, |
| const std::chrono::milliseconds& timeout) { |
| // poll on listen socket, it allows to make timeout |
| std::unique_ptr<struct ::pollfd[]> events(new struct ::pollfd[1]); |
| events[0] = tcputil::getPollfd(listenSocket, POLLIN); |
| |
| while (true) { |
| int res = tcputil::poll(events.get(), 1, timeout.count()); |
| if (res == 0) { |
| throw std::runtime_error( |
| "waiting for processes to " |
| "connect has timed out"); |
| } else if (res == -1) { |
| if (errno == EINTR) { |
| continue; |
| } |
| throw std::system_error(errno, std::system_category()); |
| } else { |
| if (!(events[0].revents & POLLIN)) |
| throw std::system_error(ECONNABORTED, std::system_category()); |
| break; |
| } |
| } |
| |
| int socket; |
| SYSCHECK_ERR_RETURN_NEG1(socket = ::accept(listenSocket, NULL, NULL)) |
| |
| // Get address of the connecting process |
| struct ::sockaddr_storage addr; |
| socklen_t addrLen = sizeof(addr); |
| SYSCHECK_ERR_RETURN_NEG1(::getpeername( |
| socket, reinterpret_cast<struct ::sockaddr*>(&addr), &addrLen)) |
| |
| setSocketNoDelay(socket); |
| |
| return std::make_tuple( |
| socket, sockaddrToString(reinterpret_cast<struct ::sockaddr*>(&addr))); |
| } |
| } // namespace tcputil |
| } // namespace c10d |