blob: dfd33cfb77ca46432d596fe6ee572677fd099fcb [file] [log] [blame]
#include <c10d/TCPStore.hpp>
#include <poll.h>
#include <unistd.h>
#include <algorithm>
#include <system_error>
#include <utility>
namespace c10d {
namespace {
// Opcode that a TCPStore client sends as the first value of every request;
// TCPStoreDaemon::query() reads it and dispatches to the matching handler.
// The enumerator order defines the on-the-wire values, so do not reorder.
enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT };
// Reply to a CHECK request: READY when every requested key is present in the
// store, NOT_READY otherwise.
enum class CheckResponseType : uint8_t { READY, NOT_READY };
// Reply to a WAIT request: STOP_WAITING tells the client it may stop blocking.
enum class WaitResponseType : uint8_t { STOP_WAITING };
} // anonymous namespace
// TCPStoreDaemon class methods
// Creates the control pipe used for shutdown signalling, then launches the
// daemon thread, which executes run() on this instance.
//
// @param storeListenSocket  listening socket on which the daemon accepts
//                           client connections.
// @throws std::runtime_error if the control pipe cannot be created.
TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket)
    : storeListenSocket_(storeListenSocket) {
  // run() polls the read end of this pipe; stop() closes the write end to
  // tell the daemon thread to exit.
  const bool pipeCreated = (pipe(controlPipeFd_.data()) != -1);
  if (!pipeCreated) {
    throw std::runtime_error(
        "Failed to create the control pipe to start the "
        "TCPStoreDaemon run");
  }
  daemonThread_ = std::thread(&TCPStoreDaemon::run, this);
}
// Shuts the daemon down: signals the thread to stop, waits for it to finish,
// and then releases every descriptor this object still holds.
TCPStoreDaemon::~TCPStoreDaemon() {
  // Order matters: the thread must have exited before its descriptors are
  // closed underneath it.
  stop();
  join();

  // Client connections that were never closed by run().
  for (const auto clientFd : sockets_) {
    if (clientFd == -1) {
      continue;
    }
    ::close(clientFd);
  }

  // Whatever is left of the control pipe (stop() already marked the write end
  // as -1 after closing it).
  for (const auto pipeFd : controlPipeFd_) {
    if (pipeFd == -1) {
      continue;
    }
    ::close(pipeFd);
  }
}
// Blocks the caller until the daemon thread has finished executing.
void TCPStoreDaemon::join() {
daemonThread_.join();
}
// Event loop executed by the daemon thread.
//
// Polls three kinds of descriptors:
//   fds[0]   - the store's listening socket (new client connections)
//   fds[1]   - the read end of the control pipe (shutdown signal)
//   fds[2..] - one entry per connected client, parallel to sockets_
//              (fds[i] corresponds to sockets_[i - 2])
//
// The loop ends when the control pipe reports POLLHUP, i.e. after stop()
// closed the write end. A client whose request handling throws is
// disconnected and all of its wait bookkeeping is dropped.
//
// @throws std::system_error on an unexpected poll event on the listening
//         socket or the control pipe.
// NOTE(review): this function is the thread entry point, so such an exception
// is not caught anywhere visible in this file - confirm that is intended.
void TCPStoreDaemon::run() {
  std::vector<struct pollfd> fds;
  fds.push_back({.fd = storeListenSocket_, .events = POLLIN});
  // Push the read end of the pipe to signal the stopping of the daemon run
  fds.push_back({.fd = controlPipeFd_[0], .events = POLLHUP});

  // receive the queries
  bool finished = false;
  while (!finished) {
    // Clear stale results on every polled entry. The previous bound
    // (sockets_.size()) was two short of fds.size() and therefore never
    // reached the most recently added clients.
    for (auto& pfd : fds) {
      pfd.revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      // Anything other than exactly POLLIN is treated as a fatal condition.
      if (fds[0].revents != POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      // revents of the new entry is zero-initialised, so the client loop
      // below skips it until the next poll().
      fds.push_back({.fd = sockFd, .events = POLLIN});
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) {
      // Will be POLLHUP when the write end of the pipe is closed
      if (fds[1].revents != POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }

    // Skipping the fds[0] and fds[1],
    // fds[0] is master's listening socket
    // fds[1] is control pipe's reading fd
    for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
      if (fds[fdIdx].revents == 0) {
        continue;
      }

      // Now query the socket that has the event
      try {
        query(fds[fdIdx].fd);
      } catch (...) {
        // There was an error when processing query. Probably an exception
        // occurred in recv/send what would indicate that socket on the other
        // side has been closed. If the closing was due to normal exit, then
        // the store should continue executing. Otherwise, if it was different
        // exception, other connections will get an exception once they try to
        // use the store. We will go ahead and close this connection whenever
        // we hit an exception here.
        const int closedFd = fds[fdIdx].fd;
        ::close(closedFd);

        // Remove all the tracking state of the closed FD: first from every
        // per-key waiter list (dropping lists that become empty) ...
        for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
          auto& waiters = it->second;
          waiters.erase(
              std::remove(waiters.begin(), waiters.end(), closedFd),
              waiters.end());
          if (waiters.empty()) {
            it = waitingSockets_.erase(it);
          } else {
            ++it;
          }
        }
        // ... then its outstanding-key counter.
        keysAwaited_.erase(closedFd);

        // Drop the descriptor from both parallel containers and step the
        // index back so the element that slid into this slot is not skipped.
        fds.erase(fds.begin() + fdIdx);
        sockets_.erase(sockets_.begin() + fdIdx - 2);
        --fdIdx;
        continue;
      }
    }
  }
}
// Asks the daemon thread to exit by closing the write end of the control
// pipe. Calling it again after the write end is already closed is a no-op.
void TCPStoreDaemon::stop() {
  int& writeEnd = controlPipeFd_[1];
  if (writeEnd == -1) {
    return;
  }
  // close the write end of the pipe and remember that it is gone
  ::close(writeEnd);
  writeEnd = -1;
}
// Reads one request from a client socket and hands it to the matching
// handler. The wire format of a request is:
//   type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait:
//   type of query | number of args | size of arg1 | arg1 | ...
//
// @param socket  connected client socket that has data pending.
// @throws std::runtime_error if the received query type is not recognised.
void TCPStoreDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);

  switch (qt) {
    case QueryType::SET:
      setHandler(socket);
      break;
    case QueryType::ADD:
      addHandler(socket);
      break;
    case QueryType::GET:
      getHandler(socket);
      break;
    case QueryType::CHECK:
      checkHandler(socket);
      break;
    case QueryType::WAIT:
      waitHandler(socket);
      break;
    default:
      throw std::runtime_error("Unexpected query type");
  }
}
// Notifies clients blocked in a WAIT that `key` is now available.
//
// Each socket registered for `key` has its outstanding-key counter
// decremented; a socket whose counter reaches zero is sent STOP_WAITING.
// The waiter list for `key` is removed afterwards.
void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) {
  const auto waiters = waitingSockets_.find(key);
  if (waiters == waitingSockets_.end()) {
    // Nobody is waiting on this key.
    return;
  }

  for (int clientFd : waiters->second) {
    auto& pendingKeys = keysAwaited_[clientFd];
    --pendingKeys;
    if (pendingKeys == 0) {
      tcputil::sendValue<WaitResponseType>(
          clientFd, WaitResponseType::STOP_WAITING);
    }
  }
  waitingSockets_.erase(waiters);
}
// Handles a SET request: receives a key and a value from the client, stores
// the value under that key (overwriting any previous value) and wakes up
// clients waiting on the key.
//
// @param socket  client socket the request is read from.
void TCPStoreDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  // Receive the payload *before* touching the map. Written as a single
  // `tcpStore_[key] = recv...` expression, the operand evaluation order is
  // unspecified before C++17, so a receive that throws could leave a
  // default-inserted empty value behind and make the key look present.
  auto value = tcputil::recvVector<uint8_t>(socket);
  tcpStore_[key] = std::move(value);
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
}
// Handles an ADD request: receives a key and a signed 64-bit increment, adds
// the increment to the integer currently stored under the key (parsed from
// its decimal text form; a missing key counts as no previous value), stores
// the sum back as decimal text, replies with the sum and wakes up waiters.
//
// @param socket  client socket the request is read from and replied on.
void TCPStoreDaemon::addHandler(int socket) {
  const std::string key = tcputil::recvString(socket);
  int64_t total = tcputil::recvValue<int64_t>(socket);

  const auto existing = tcpStore_.find(key);
  if (existing != tcpStore_.end()) {
    // The stored bytes are the decimal representation of the counter.
    const auto& stored = existing->second;
    const std::string text(
        reinterpret_cast<const char*>(stored.data()), stored.size());
    total += std::stoll(text);
  }

  const std::string encoded = std::to_string(total);
  tcpStore_[key] = std::vector<uint8_t>(encoded.begin(), encoded.end());

  // Now send the new value
  tcputil::sendValue<int64_t>(socket, total);
  // On "add", wake up all clients that have been waiting
  wakeupWaitingClients(key);
}
// Handles a GET request: receives a key and sends back the value stored
// under it.
//
// The lookup uses at(), so a key that is not in the store makes this throw
// (std::out_of_range, assuming tcpStore_ is a standard map type) instead of
// inserting an empty entry.
//
// @param socket  client socket the request is read from and replied on.
void TCPStoreDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  // Bind by reference: the value is only read while it is being sent, so
  // copying the whole buffer on every GET is unnecessary.
  const auto& data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}
void TCPStoreDaemon::checkHandler(int socket) const {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
// Now we have received all the keys
if (checkKeys(keys)) {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
} else {
tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
}
}
void TCPStoreDaemon::waitHandler(int socket) {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
std::vector<std::string> keys(nargs);
for (size_t i = 0; i < nargs; i++) {
keys[i] = tcputil::recvString(socket);
}
if (checkKeys(keys)) {
tcputil::sendValue<WaitResponseType>(
socket, WaitResponseType::STOP_WAITING);
} else {
for (auto& key : keys) {
waitingSockets_[key].push_back(socket);
}
keysAwaited_[socket] = keys.size();
}
}
// Reports whether every key in `keys` is currently present in the store.
// An empty list yields true.
bool TCPStoreDaemon::checkKeys(const std::vector<std::string>& keys) const {
  for (const std::string& name : keys) {
    if (tcpStore_.find(name) == tcpStore_.end()) {
      return false;
    }
  }
  return true;
}
// TCPStore class methods
// Builds a store endpoint. On the server side this opens the listening
// socket and starts the daemon; every instance (server included) then
// connects to the daemon as a client.
//
// @param masterAddr   address of the machine hosting the store daemon.
// @param masterPort   port to listen on (server) / connect to (client); the
//                     port actually reported by tcputil::listen replaces it
//                     on the server.
// @param numWorkers   number of workers waitForWorkers() waits for.
// @param isServer     true if this instance hosts the daemon.
// @param timeout      timeout forwarded to the Store base class.
// @param waitWorkers  when true, calls waitForWorkers() before returning.
TCPStore::TCPStore(
    const std::string& masterAddr,
    PortType masterPort,
    int numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : Store(timeout),
      isServer_(isServer),
      tcpStoreAddr_(masterAddr),
      tcpStorePort_(masterPort),
      numWorkers_(numWorkers),
      initKey_("init/"),
      regularPrefix_("/") {
  if (isServer_) {
    // Open the listening socket first; the daemon needs it from the start.
    std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort);
    // Now start the daemon
    tcpStoreDaemon_.reset(new TCPStoreDaemon(masterListenSocket_));
  }

  // Connect to the daemon
  storeSocket_ = tcputil::connect(
      tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);

  if (!waitWorkers) {
    return;
  }
  waitForWorkers();
}
// Closes the client connection and, on the server, shuts the daemon down
// before closing the listening socket.
TCPStore::~TCPStore() {
  ::close(storeSocket_);
  if (!isServer_) {
    return;
  }
  // Store daemon should end because of closed connection.
  // Destroying the daemon object joins its thread.
  tcpStoreDaemon_.reset();
  ::close(masterListenSocket_);
}
// Registers this process under the init key and, on the server, blocks until
// the number of registered workers reaches numWorkers_ or the store timeout
// elapses.
//
// NOTE(review): when the timeout elapses the loop simply exits without
// reporting an error - confirm callers do not need to distinguish that case.
void TCPStore::waitForWorkers() {
  addHelper_(initKey_, 1);
  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // The counter is stored as decimal text; parse it back to an int.
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_) {
        break;
      }
      // Measure in milliseconds, the same unit as the timeout. Truncating
      // the elapsed time to whole seconds let the deadline overshoot by up
      // to a second.
      const auto elapsed =
          std::chrono::duration_cast<std::chrono::milliseconds>(
              std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
// Stores `data` under `key`. The key is prefixed with regularPrefix_ before
// it is sent to the daemon; no reply is read for this request.
void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  const std::string prefixedKey = regularPrefix_ + key;

  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
  tcputil::sendString(storeSocket_, prefixedKey, true);
  tcputil::sendVector<uint8_t>(storeSocket_, data);
}
// Returns the value stored under `key` (after applying regularPrefix_),
// delegating the wait-then-fetch exchange to getHelper_().
std::vector<uint8_t> TCPStore::get(const std::string& key) {
  return getHelper_(regularPrefix_ + key);
}
// Fetches the value for an already-prefixed key. It first waits (using the
// store timeout) until the key exists, then issues a GET and returns the
// bytes received from the daemon.
std::vector<uint8_t> TCPStore::getHelper_(const std::string& key) {
  // Make sure the key is present before asking for it.
  waitHelper_({key}, timeout_);

  tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
  tcputil::sendString(storeSocket_, key);

  std::vector<uint8_t> payload = tcputil::recvVector<uint8_t>(storeSocket_);
  return payload;
}
// Adds `value` to the counter stored under `key` (after applying
// regularPrefix_) and returns the result reported by addHelper_().
int64_t TCPStore::add(const std::string& key, int64_t value) {
  return addHelper_(regularPrefix_ + key, value);
}
// Sends an ADD request for an already-prefixed key and returns the 64-bit
// value the daemon replies with.
int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
  tcputil::sendString(storeSocket_, key, true);
  tcputil::sendValue<int64_t>(storeSocket_, value);

  const int64_t updated = tcputil::recvValue<int64_t>(storeSocket_);
  return updated;
}
bool TCPStore::check(const std::vector<std::string>& keys) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
SizeType nkeys = keys.size();
tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
for (size_t i = 0; i < nkeys; i++) {
std::string regKey = regularPrefix_ + keys[i];
tcputil::sendString(storeSocket_, regKey, (i != (nkeys - 1)));
}
auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
if (checkResponse == CheckResponseType::READY) {
return true;
} else if (checkResponse == CheckResponseType::NOT_READY) {
return false;
} else {
throw std::runtime_error("ready or not_ready response expected");
}
}
void TCPStore::wait(const std::vector<std::string>& keys) {
wait(keys, timeout_);
}
// Waits until every key in `keys` is present in the store. Each key is
// prefixed with regularPrefix_ before the request is handed to waitHelper_().
//
// @param keys     user-facing key names.
// @param timeout  wait timeout forwarded to waitHelper_().
void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const std::string& name : keys) {
    prefixedKeys.push_back(regularPrefix_ + name);
  }
  waitHelper_(prefixedKeys, timeout);
}
// Sends a WAIT request for already-prefixed keys and blocks until the daemon
// replies.
//
// @param keys     keys to wait for, exactly as they are stored by the daemon.
// @param timeout  receive timeout applied to the store socket for this wait;
//                 kNoTimeout means no receive timeout.
// @throws std::runtime_error if the reply is not STOP_WAITING.
void TCPStore::waitHelper_(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  // Always (re)apply the receive timeout. A zeroed timeval disables
  // SO_RCVTIMEO, so a wait with kNoTimeout clears whatever a previous timed
  // wait left on the socket instead of silently inheriting it.
  struct timeval timeoutTV = {};
  if (timeout != kNoTimeout) {
    // Explicit casts: the chrono rep type may be wider than the timeval
    // fields, which a braced initialiser would reject as narrowing.
    timeoutTV.tv_sec =
        static_cast<decltype(timeoutTV.tv_sec)>(timeout.count() / 1000);
    timeoutTV.tv_usec = static_cast<decltype(timeoutTV.tv_usec)>(
        (timeout.count() % 1000) * 1000);
  }
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      storeSocket_,
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));

  tcputil::sendValue<QueryType>(storeSocket_, QueryType::WAIT);
  SizeType nkeys = keys.size();
  tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
  for (size_t i = 0; i < nkeys; i++) {
    // The trailing flag is true for every key except the last one.
    tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
  }

  auto waitResponse = tcputil::recvValue<WaitResponseType>(storeSocket_);
  if (waitResponse != WaitResponseType::STOP_WAITING) {
    throw std::runtime_error("Stop_waiting response is expected");
  }
}
// Returns the store's port. On the server this is the port reported by
// tcputil::listen in the constructor; otherwise it is the masterPort value
// the instance was constructed with.
PortType TCPStore::getPort() {
return tcpStorePort_;
}
} // namespace c10d