blob: 48fb53408b47c4a49b65bc029b9a5ba8e4528cd7 [file] [log] [blame]
#include <c10d/TCPStore.hpp>

#include <poll.h>
#include <unistd.h>

#include <algorithm>
#include <system_error>
#include <utility>
Teng Li0d27d262018-05-17 13:38:06 -07008
9namespace c10d {
10
namespace {

// Type tag sent as the first byte of every client request. These numeric
// values form the wire protocol between TCPStore and TCPStoreDaemon, so they
// are written out explicitly and must never be reordered or renumbered.
enum class QueryType : uint8_t { SET = 0, GET = 1, ADD = 2, CHECK = 3, WAIT = 4 };

// Reply the daemon sends for a CHECK request.
enum class CheckResponseType : uint8_t { READY = 0, NOT_READY = 1 };

// Reply the daemon sends for a WAIT request once every awaited key exists.
enum class WaitResponseType : uint8_t { STOP_WAITING = 0 };

} // anonymous namespace
20
21// TCPStoreDaemon class methods
22// Simply start the daemon thread
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070023TCPStoreDaemon::TCPStoreDaemon(int storeListenSocket)
24 : storeListenSocket_(storeListenSocket) {
Pieter Noordhuisc3603302018-11-10 17:52:31 -080025 // Use control pipe to signal instance destruction to the daemon thread.
26 if (pipe(controlPipeFd_.data()) == -1) {
27 throw std::runtime_error(
28 "Failed to create the control pipe to start the "
29 "TCPStoreDaemon run");
30 }
Teng Li0d27d262018-05-17 13:38:06 -070031 daemonThread_ = std::thread(&TCPStoreDaemon::run, this);
32}
33
34TCPStoreDaemon::~TCPStoreDaemon() {
35 // Stop the run
36 stop();
37 // Join the thread
38 join();
39 // Close unclosed sockets
40 for (auto socket : sockets_) {
41 if (socket != -1) {
42 ::close(socket);
43 }
44 }
45 // Now close the rest control pipe
46 for (auto fd : controlPipeFd_) {
47 if (fd != -1) {
48 ::close(fd);
49 }
50 }
51}
52
53void TCPStoreDaemon::join() {
54 daemonThread_.join();
55}
56
57void TCPStoreDaemon::run() {
Teng Li0d27d262018-05-17 13:38:06 -070058 std::vector<struct pollfd> fds;
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070059 fds.push_back({.fd = storeListenSocket_, .events = POLLIN});
Teng Li0d27d262018-05-17 13:38:06 -070060 // Push the read end of the pipe to signal the stopping of the daemon run
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070061 fds.push_back({.fd = controlPipeFd_[0], .events = POLLHUP});
Teng Li0d27d262018-05-17 13:38:06 -070062
63 // receive the queries
64 bool finished = false;
65 while (!finished) {
66 for (size_t i = 0; i < sockets_.size(); i++) {
67 fds[i].revents = 0;
68 }
69
SsnL774705b2019-01-14 15:59:29 -080070 SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));
Teng Li0d27d262018-05-17 13:38:06 -070071
72 // TCPStore's listening socket has an event and it should now be able to
73 // accept new connections.
74 if (fds[0].revents != 0) {
75 if (fds[0].revents ^ POLLIN) {
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070076 throw std::system_error(
77 ECONNABORTED,
78 std::system_category(),
Teng Li0d27d262018-05-17 13:38:06 -070079 "Unexpected poll revent on the master's listening socket: " +
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070080 std::to_string(fds[0].revents));
Teng Li0d27d262018-05-17 13:38:06 -070081 }
82 int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
83 sockets_.push_back(sockFd);
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070084 fds.push_back({.fd = sockFd, .events = POLLIN});
Teng Li0d27d262018-05-17 13:38:06 -070085 }
86 // The pipe receives an event which tells us to shutdown the daemon
87 if (fds[1].revents != 0) {
88 // Will be POLLUP when the pipe is closed
89 if (fds[1].revents ^ POLLHUP) {
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070090 throw std::system_error(
91 ECONNABORTED,
92 std::system_category(),
Teng Li0d27d262018-05-17 13:38:06 -070093 "Unexpected poll revent on the control pipe's reading fd: " +
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -070094 std::to_string(fds[1].revents));
Teng Li0d27d262018-05-17 13:38:06 -070095 }
96 finished = true;
97 break;
98 }
99 // Skipping the fds[0] and fds[1],
100 // fds[0] is master's listening socket
101 // fds[1] is control pipe's reading fd
102 for (size_t fdIdx = 2; fdIdx < fds.size(); ++fdIdx) {
103 if (fds[fdIdx].revents == 0) {
104 continue;
105 }
106
Teng Li0d27d262018-05-17 13:38:06 -0700107 // Now query the socket that has the event
108 try {
109 query(fds[fdIdx].fd);
110 } catch (...) {
111 // There was an error when processing query. Probably an exception
112 // occurred in recv/send what would indicate that socket on the other
113 // side has been closed. If the closing was due to normal exit, then
114 // the store should continue executing. Otherwise, if it was different
115 // exception, other connections will get an exception once they try to
116 // use the store. We will go ahead and close this connection whenever
117 // we hit an exception here.
118 ::close(fds[fdIdx].fd);
119
120 // Remove all the tracking state of the close FD
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700121 for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
122 for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
Teng Li0d27d262018-05-17 13:38:06 -0700123 if (*vecIt == fds[fdIdx].fd) {
124 vecIt = it->second.erase(vecIt);
125 } else {
126 ++vecIt;
127 }
128 }
129 if (it->second.size() == 0) {
130 it = waitingSockets_.erase(it);
131 } else {
132 ++it;
133 }
134 }
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700135 for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
Teng Li0d27d262018-05-17 13:38:06 -0700136 if (it->first == fds[fdIdx].fd) {
137 it = keysAwaited_.erase(it);
138 } else {
139 ++it;
140 }
141 }
142 fds.erase(fds.begin() + fdIdx);
143 sockets_.erase(sockets_.begin() + fdIdx - 2);
144 --fdIdx;
145 continue;
146 }
147 }
148 }
149}
150
151void TCPStoreDaemon::stop() {
152 if (controlPipeFd_[1] != -1) {
153 // close the write end of the pipe
154 ::close(controlPipeFd_[1]);
155 controlPipeFd_[1] = -1;
156 }
157}
158
159// query communicates with the worker. The format
160// of the query is as follows:
161// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
162// or, in the case of wait
163// type of query | number of args | size of arg1 | arg1 | ...
164void TCPStoreDaemon::query(int socket) {
165 QueryType qt;
166 tcputil::recvBytes<QueryType>(socket, &qt, 1);
167
168 if (qt == QueryType::SET) {
169 setHandler(socket);
170
171 } else if (qt == QueryType::ADD) {
172 addHandler(socket);
173
174 } else if (qt == QueryType::GET) {
175 getHandler(socket);
176
177 } else if (qt == QueryType::CHECK) {
178 checkHandler(socket);
179
180 } else if (qt == QueryType::WAIT) {
181 waitHandler(socket);
182
183 } else {
184 throw std::runtime_error("Unexpected query type");
185 }
186}
187
188void TCPStoreDaemon::wakeupWaitingClients(const std::string& key) {
189 auto socketsToWait = waitingSockets_.find(key);
190 if (socketsToWait != waitingSockets_.end()) {
191 for (int socket : socketsToWait->second) {
192 if (--keysAwaited_[socket] == 0) {
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700193 tcputil::sendValue<WaitResponseType>(
194 socket, WaitResponseType::STOP_WAITING);
Teng Li0d27d262018-05-17 13:38:06 -0700195 }
196 }
197 waitingSockets_.erase(socketsToWait);
198 }
199}
200
201void TCPStoreDaemon::setHandler(int socket) {
202 std::string key = tcputil::recvString(socket);
203 tcpStore_[key] = tcputil::recvVector<uint8_t>(socket);
204 // On "set", wake up all clients that have been waiting
205 wakeupWaitingClients(key);
206}
207
208void TCPStoreDaemon::addHandler(int socket) {
209 std::string key = tcputil::recvString(socket);
210 int64_t addVal = tcputil::recvValue<int64_t>(socket);
211
212 if (tcpStore_.find(key) != tcpStore_.end()) {
213 auto buf = reinterpret_cast<const char*>(tcpStore_[key].data());
214 auto len = tcpStore_[key].size();
215 addVal += std::stoll(std::string(buf, len));
216 }
217 auto addValStr = std::to_string(addVal);
218 tcpStore_[key] = std::vector<uint8_t>(addValStr.begin(), addValStr.end());
219 // Now send the new value
220 tcputil::sendValue<int64_t>(socket, addVal);
221 // On "add", wake up all clients that have been waiting
222 wakeupWaitingClients(key);
223}
224
225void TCPStoreDaemon::getHandler(int socket) const {
226 std::string key = tcputil::recvString(socket);
227 auto data = tcpStore_.at(key);
228 tcputil::sendVector<uint8_t>(socket, data);
229}
230
231void TCPStoreDaemon::checkHandler(int socket) const {
232 SizeType nargs;
233 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
234 std::vector<std::string> keys(nargs);
235 for (size_t i = 0; i < nargs; i++) {
236 keys[i] = tcputil::recvString(socket);
237 }
238 // Now we have received all the keys
239 if (checkKeys(keys)) {
240 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
241 } else {
242 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
243 }
244}
245
246void TCPStoreDaemon::waitHandler(int socket) {
247 SizeType nargs;
248 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
249 std::vector<std::string> keys(nargs);
250 for (size_t i = 0; i < nargs; i++) {
251 keys[i] = tcputil::recvString(socket);
252 }
253 if (checkKeys(keys)) {
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700254 tcputil::sendValue<WaitResponseType>(
255 socket, WaitResponseType::STOP_WAITING);
Teng Li0d27d262018-05-17 13:38:06 -0700256 } else {
257 for (auto& key : keys) {
258 waitingSockets_[key].push_back(socket);
259 }
260 keysAwaited_[socket] = keys.size();
261 }
262}
263
264bool TCPStoreDaemon::checkKeys(const std::vector<std::string>& keys) const {
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700265 return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
266 return tcpStore_.count(s) > 0;
267 });
Teng Li0d27d262018-05-17 13:38:06 -0700268}
269
270// TCPStore class methods
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700271TCPStore::TCPStore(
272 const std::string& masterAddr,
273 PortType masterPort,
Teng Lib4bc55b2019-01-18 02:23:51 -0800274 int numWorkers,
Rohan Varmaf57ecd52019-09-24 12:34:20 -0700275 bool isServer,
276 const std::chrono::milliseconds& timeout)
277 : Store(timeout),
278 isServer_(isServer),
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700279 tcpStoreAddr_(masterAddr),
Teng Lib4bc55b2019-01-18 02:23:51 -0800280 tcpStorePort_(masterPort),
281 numWorkers_(numWorkers),
282 initKey_("init/"),
283 regularPrefix_("/") {
Teng Li0d27d262018-05-17 13:38:06 -0700284 if (isServer_) {
285 // Opening up the listening socket
286 std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort);
287 // Now start the daemon
288 tcpStoreDaemon_ = std::unique_ptr<TCPStoreDaemon>(
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700289 new TCPStoreDaemon(masterListenSocket_));
Teng Li0d27d262018-05-17 13:38:06 -0700290 }
291 // Connect to the daemon
Rohan Varmaf57ecd52019-09-24 12:34:20 -0700292 storeSocket_ = tcputil::connect(
293 tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
Teng Lib4bc55b2019-01-18 02:23:51 -0800294
295 waitForWorkers_();
Teng Li0d27d262018-05-17 13:38:06 -0700296}
297
298TCPStore::~TCPStore() {
299 ::close(storeSocket_);
300 if (isServer_) {
301 // Store daemon should end because of closed connection.
302 // daemon destructor should join the thread
303 tcpStoreDaemon_.reset(nullptr);
304 ::close(masterListenSocket_);
305 }
306}
307
Teng Lib4bc55b2019-01-18 02:23:51 -0800308void TCPStore::waitForWorkers_() {
309 addHelper_(initKey_, 1);
310 // Let server block until all workers have completed, this ensures that
311 // the server daemon thread is always running until the very end
312 if (isServer_) {
313 const auto start = std::chrono::steady_clock::now();
314 while (true) {
315 std::vector<uint8_t> value = getHelper_(initKey_);
316 auto buf = reinterpret_cast<const char*>(value.data());
317 auto len = value.size();
318 int numWorkersCompleted = std::stoi(std::string(buf, len));
319 if (numWorkersCompleted >= numWorkers_) {
320 break;
321 }
322 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
323 std::chrono::steady_clock::now() - start);
324 if (timeout_ != kNoTimeout && elapsed > timeout_) {
325 break;
326 }
327 /* sleep override */
328 std::this_thread::sleep_for(std::chrono::milliseconds(10));
329 }
330 }
331}
332
Teng Li0d27d262018-05-17 13:38:06 -0700333void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
Teng Lib4bc55b2019-01-18 02:23:51 -0800334 std::string regKey = regularPrefix_ + key;
Teng Li0d27d262018-05-17 13:38:06 -0700335 tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
Teng Lib4bc55b2019-01-18 02:23:51 -0800336 tcputil::sendString(storeSocket_, regKey, true);
Teng Li0d27d262018-05-17 13:38:06 -0700337 tcputil::sendVector<uint8_t>(storeSocket_, data);
338}
339
340std::vector<uint8_t> TCPStore::get(const std::string& key) {
Teng Lib4bc55b2019-01-18 02:23:51 -0800341 std::string regKey = regularPrefix_ + key;
342 return getHelper_(regKey);
343}
344
345std::vector<uint8_t> TCPStore::getHelper_(const std::string& key) {
346 waitHelper_({key}, timeout_);
Teng Li0d27d262018-05-17 13:38:06 -0700347 tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
348 tcputil::sendString(storeSocket_, key);
349 return tcputil::recvVector<uint8_t>(storeSocket_);
350}
351
352int64_t TCPStore::add(const std::string& key, int64_t value) {
Teng Lib4bc55b2019-01-18 02:23:51 -0800353 std::string regKey = regularPrefix_ + key;
354 return addHelper_(regKey, value);
355}
356
357int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
Teng Li0d27d262018-05-17 13:38:06 -0700358 tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
359 tcputil::sendString(storeSocket_, key, true);
360 tcputil::sendValue<int64_t>(storeSocket_, value);
361 return tcputil::recvValue<int64_t>(storeSocket_);
362}
363
364bool TCPStore::check(const std::vector<std::string>& keys) {
365 tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
366 SizeType nkeys = keys.size();
367 tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
368 for (size_t i = 0; i < nkeys; i++) {
Teng Lib4bc55b2019-01-18 02:23:51 -0800369 std::string regKey = regularPrefix_ + keys[i];
370 tcputil::sendString(storeSocket_, regKey, (i != (nkeys - 1)));
Teng Li0d27d262018-05-17 13:38:06 -0700371 }
372 auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
373 if (checkResponse == CheckResponseType::READY) {
374 return true;
375 } else if (checkResponse == CheckResponseType::NOT_READY) {
376 return false;
377 } else {
378 throw std::runtime_error("ready or not_ready response expected");
379 }
380}
381
Teng Liec195122018-09-06 12:47:20 -0700382void TCPStore::wait(const std::vector<std::string>& keys) {
383 wait(keys, timeout_);
384}
385
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700386void TCPStore::wait(
387 const std::vector<std::string>& keys,
388 const std::chrono::milliseconds& timeout) {
Teng Lib4bc55b2019-01-18 02:23:51 -0800389 std::vector<std::string> regKeys;
390 regKeys.resize(keys.size());
391 for (size_t i = 0; i < keys.size(); ++i) {
392 regKeys[i] = regularPrefix_ + keys[i];
393 }
394 waitHelper_(regKeys, timeout);
395}
396
397void TCPStore::waitHelper_(
398 const std::vector<std::string>& keys,
399 const std::chrono::milliseconds& timeout) {
Teng Li0d27d262018-05-17 13:38:06 -0700400 // Set the socket timeout if there is a wait timeout
401 if (timeout != kNoTimeout) {
402 struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
403 .tv_usec = (timeout.count() % 1000) * 1000};
SsnL774705b2019-01-14 15:59:29 -0800404 SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
Pieter Noordhuis7d0de4f2018-05-23 11:26:35 -0700405 storeSocket_,
406 SOL_SOCKET,
407 SO_RCVTIMEO,
408 reinterpret_cast<char*>(&timeoutTV),
409 sizeof(timeoutTV)));
Teng Li0d27d262018-05-17 13:38:06 -0700410 }
411 tcputil::sendValue<QueryType>(storeSocket_, QueryType::WAIT);
412 SizeType nkeys = keys.size();
413 tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
414 for (size_t i = 0; i < nkeys; i++) {
415 tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
416 }
417 auto waitResponse = tcputil::recvValue<WaitResponseType>(storeSocket_);
418 if (waitResponse != WaitResponseType::STOP_WAITING) {
419 throw std::runtime_error("Stop_waiting response is expected");
420 }
421}
422
423} // namespace c10d