[C10D] Reimplement TCPStore wait timeout logic. (#100594)

Current TCPStore wait logic leaves the client socket in a bad state if waiting timesout.

This happens because all recv functions raise an exception on timeout and that's it.
The problem is that on timeout we need to unregister the wait.

We implement this with client side cancelation by adding a new CANCEL_WAIT instruction.

So, if no data arrives before the deadline, the client sends a CANCEL_WAIT command.
The server sends a WAIT_CANCELED response to that command, always.

This gets us down to the last issue, which is that there's a race between timeout'ing,
canceling the wait and the wait completing. The client needs to handle the server sending
a STOP_WAITING followed by a WAIT_CANCELED answer.

This ensures client and server state are synchronized regardless of whether the wait
timeouts or not.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100594
Approved by: https://github.com/H-Huang
diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py
index cf2a37a..67f0bc8 100644
--- a/test/distributed/test_store.py
+++ b/test/distributed/test_store.py
@@ -581,6 +581,26 @@
         time_diff = end - start
         self.assertGreater(test_store_timeout.seconds * 10, time_diff)
 
+    def test_tcp_store_timeout_doest_break_client(self):
+        url = self.create_tcp_url()
+        test_store_timeout = timedelta(seconds=10)
+        gen0 = dist.rendezvous(url + "&rank=0", timeout=test_store_timeout)
+        store0, rank0, size0 = next(gen0)
+        # this should time out in 10s. If the timeout passed into rendezvous was
+        # not respected, it will take much longer to timeout.
+        start = time.time()
+        with self.assertRaisesRegex(RuntimeError, "Timeout"):
+            store0.get("the_key")
+
+        store0.set("the_key", "x")
+
+        self.assertEqual(b"x", store0.get("the_key"))
+
+        end = time.time()
+        time_diff = end - start
+        self.assertGreater(test_store_timeout.seconds * 10, time_diff)
+
+
 class DummyStore(dist.Store):
     def __init__(self):
         self.appends = []
diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp
index f727d60..311178d 100644
--- a/torch/csrc/distributed/c10d/TCPStore.cpp
+++ b/torch/csrc/distributed/c10d/TCPStore.cpp
@@ -149,11 +149,12 @@
   APPEND,
   MULTI_GET,
   MULTI_SET,
+  CANCEL_WAIT,
 };
 
 enum class CheckResponseType : uint8_t { READY, NOT_READY };
 
-enum class WaitResponseType : uint8_t { STOP_WAITING };
+enum class WaitResponseType : uint8_t { STOP_WAITING, WAIT_CANCELED };
 
 enum class WatchResponseType : uint8_t {
   KEY_UPDATED,
@@ -174,6 +175,7 @@
   void run();
   void queryFds(std::vector<struct pollfd>& fds);
   void query(int socket);
+  void clearSocketWaitState(int socket);
 
   // The master runs on a single thread so only
   // one handler can be executed at a time
@@ -189,6 +191,7 @@
   void appendHandler(int socket);
   void multiGetHandler(int socket);
   void multiSetHandler(int socket);
+  void cancelWaitHandler(int socket);
 
   bool checkKeys(const std::vector<std::string>& keys) const;
   // Helper function to alerts waiting workers, used in setHandler, getHandler
@@ -241,29 +244,8 @@
       // exception, other connections will get an exception once they try to
       // use the store. We will go ahead and close this connection whenever
       // we hit an exception here.
+      clearSocketWaitState(fds[fdIdx].fd);
 
-      // Remove all the tracking state of the close FD
-      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
-        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
-          if (*vecIt == fds[fdIdx].fd) {
-            vecIt = it->second.erase(vecIt);
-          } else {
-            ++vecIt;
-          }
-        }
-        if (it->second.empty()) {
-          it = waitingSockets_.erase(it);
-        } else {
-          ++it;
-        }
-      }
-      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
-        if (it->first == fds[fdIdx].fd) {
-          it = keysAwaited_.erase(it);
-        } else {
-          ++it;
-        }
-      }
       fds.erase(fds.begin() + fdIdx);
       sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
       --fdIdx;
@@ -272,6 +254,31 @@
   }
 }
 
+void TCPStoreMasterDaemon::clearSocketWaitState(int socket) {
+  // Remove all the tracking state of the close FD
+  for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
+    for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
+      if (*vecIt == socket) {
+        vecIt = it->second.erase(vecIt);
+      } else {
+        ++vecIt;
+      }
+    }
+    if (it->second.empty()) {
+      it = waitingSockets_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+  for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
+    if (it->first == socket) {
+      it = keysAwaited_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+
 // query communicates with the worker. The format
 // of the query is as follows:
 // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
@@ -312,6 +319,8 @@
     multiGetHandler(socket);
   } else if (qt == QueryType::MULTI_SET) {
     multiSetHandler(socket);
+  } else if (qt == QueryType::CANCEL_WAIT) {
+    cancelWaitHandler(socket);
   } else {
     TORCH_CHECK(false, "Unexpected query type");
   }
@@ -544,6 +553,14 @@
   }
 }
 
+void TCPStoreMasterDaemon::cancelWaitHandler(int socket) {
+  clearSocketWaitState(socket);
+
+  // Send update to TCPStoreWorkerDaemon on client
+  tcputil::sendValue<WaitResponseType>(
+      socket, detail::WaitResponseType::WAIT_CANCELED);
+}
+
 bool TCPStoreMasterDaemon::checkKeys(
     const std::vector<std::string>& keys) const {
   return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
@@ -910,7 +927,13 @@
   T receiveValue() {
     return tcputil::recvValue<T>(socket_.handle());
   }
-
+  template <typename T>
+  bool receiveValueWithTimeout(T& t, std::chrono::milliseconds timeout) {
+    if (!socket_.waitForInput(timeout))
+      return false;
+    t = tcputil::recvValue<T>(socket_.handle());
+    return true;
+  }
   void setTimeout(std::chrono::milliseconds value);
 
   explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}
@@ -1236,20 +1259,43 @@
 void TCPStore::doWait(
     c10::ArrayRef<std::string> keys,
     std::chrono::milliseconds timeout) {
-  // TODO: Should we revert to the original timeout at the end of the call?
-  client_->setTimeout(timeout);
-
-  detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
-  buffer.appendValue(keys.size());
-  for (const std::string& key : keys) {
-    buffer.appendString(key);
+  {
+    detail::SendBuffer buffer(*client_, detail::QueryType::WAIT);
+    buffer.appendValue(keys.size());
+    for (const std::string& key : keys) {
+      buffer.appendString(key);
+    }
+    buffer.flush();
   }
-  buffer.flush();
 
-  auto response = client_->receiveValue<detail::WaitResponseType>();
-  if (response != detail::WaitResponseType::STOP_WAITING) {
-    TORCH_CHECK(false, "Stop_waiting response is expected");
+  detail::WaitResponseType response;
+  if (client_->receiveValueWithTimeout<detail::WaitResponseType>(
+          response, timeout)) {
+    if (response != detail::WaitResponseType::STOP_WAITING) {
+      TORCH_CHECK(false, "Stop_waiting response is expected");
+    }
+    return;
   }
+  // this is the cancel wait timeout, once here we expect the server to respond
+  // in a timely fashion
+  {
+    detail::SendBuffer buffer(*client_, detail::QueryType::CANCEL_WAIT);
+    buffer.flush();
+  }
+
+  response = client_->receiveValue<detail::WaitResponseType>();
+  // this can happen if the server responds before we cancel, just ignore it
+  if (response != detail::WaitResponseType::WAIT_CANCELED) {
+    if (response != detail::WaitResponseType::STOP_WAITING) {
+      TORCH_CHECK(false, "Stop_waiting response is expected");
+    }
+
+    response = client_->receiveValue<detail::WaitResponseType>(); // ignore
+    if (response != detail::WaitResponseType::WAIT_CANCELED) {
+      TORCH_CHECK(false, "wait_canceled response is expected");
+    }
+  }
+  TORCH_CHECK(false, "Socket Timeout");
 }
 
 void TCPStore::append(
diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp
index e3cf537..98e02c5 100644
--- a/torch/csrc/distributed/c10d/socket.cpp
+++ b/torch/csrc/distributed/c10d/socket.cpp
@@ -174,6 +174,8 @@
     return hnd_;
   }
 
+  bool waitForInput(std::chrono::milliseconds timeout);
+
  private:
   bool setSocketFlag(int level, int optname, bool value) noexcept;
 
@@ -398,6 +400,14 @@
   return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0;
 }
 
+bool SocketImpl::waitForInput(std::chrono::milliseconds timeout) {
+  ::pollfd pfd{};
+  pfd.fd = hnd_;
+  pfd.events = POLLIN;
+
+  return pollFd(&pfd, 1, static_cast<int>(timeout.count())) > 0;
+}
+
 namespace {
 
 struct addrinfo_delete {
@@ -983,6 +993,10 @@
 Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
     : impl_{std::move(impl)} {}
 
+bool Socket::waitForInput(std::chrono::milliseconds timeout) {
+  return impl_->waitForInput(timeout);
+}
+
 } // namespace detail
 
 SocketError::~SocketError() = default;
diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h
index 7fceb6c..43f98bc 100644
--- a/torch/csrc/distributed/c10d/socket.h
+++ b/torch/csrc/distributed/c10d/socket.h
@@ -79,6 +79,8 @@
 
   std::uint16_t port() const;
 
+  bool waitForInput(std::chrono::milliseconds timeout);
+
  private:
   explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;