use timeout in connect function to guard against an infinite loop (#26364)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26364
Per https://github.com/pytorch/pytorch/issues/25769, we sometimes get
an infinite loop when `TCPStore` calls `tcputil::connect`, and the server
continually returns `ECONNRESET` or `ECONNREFUSED`. If a proper timeout is passed
in, we guard against this by throwing an exception once the timeout has passed.
Testing: Tested by modifying `TCPStore` to connect to an invalid port, thus getting
`ECONNREFUSED`. If a valid timeout is passed in, the function correctly throws an
exception. Steps below:
1) in TCPStore.cpp's constructor, replace the `connect` call with this line:
`storeSocket_ = tcputil::connect(tcpStoreAddr_, 1, true, std::chrono::milliseconds(3000));`
2) Build the `TCPStoreTest` binary.
3) Run the binary. Expected output:
```
terminate called after throwing an instance of 'std::runtime_error'
what(): connect() timed out.
Aborted (core dumped)
```
ghstack-source-id: 90480086
Test Plan: See above.
Differential Revision: D17430164
fbshipit-source-id: 1482aca72fcc3ddb95ea25649ec057edda5d1934
diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp
index 6076012..ec0cd47 100644
--- a/torch/lib/c10d/Utils.cpp
+++ b/torch/lib/c10d/Utils.cpp
@@ -22,6 +22,7 @@
namespace {
constexpr int LISTEN_QUEUE_SIZE = 2048;
+const std::string kConnectTimeoutMsg = "connect() timed out.";
void setSocketNoDelay(int socket) {
int flag = 1;
@@ -162,6 +163,7 @@
// yet, or is listening but has its listen backlog exhausted.
bool anyRefused = false;
bool anyReset = false;
+ const auto start = std::chrono::high_resolution_clock::now();
while (true) {
try {
SYSCHECK_ERR_RETURN_NEG1(
@@ -185,12 +187,22 @@
pfd.fd = socket;
pfd.events = POLLOUT;
- int numReady = ::poll(&pfd, 1, timeout.count());
+ int64_t pollTimeout = -1;
+ if (timeout != kNoTimeout) {
+ // calculate remaining time and use that as timeout for poll()
+ const auto elapsed = std::chrono::high_resolution_clock::now() - start;
+ const auto remaining =
+ std::chrono::duration_cast<std::chrono::milliseconds>(timeout) -
+ std::chrono::duration_cast<std::chrono::milliseconds>(elapsed);
+ pollTimeout = std::max(
+ static_cast<int64_t>(0), static_cast<int64_t>(remaining.count()));
+ }
+ int numReady = ::poll(&pfd, 1, pollTimeout);
if (numReady < 0) {
throw std::system_error(errno, std::system_category());
} else if (numReady == 0) {
errno = 0;
- throw std::runtime_error("connect() timed out");
+ throw std::runtime_error(kConnectTimeoutMsg);
}
socklen_t errLen = sizeof(errno);
@@ -231,6 +243,16 @@
if (!wait || (!anyRefused && !anyReset)) {
throw;
}
+
+ // if a timeout is specified, check time elapsed to see if we need to
+ // timeout. A timeout is specified if timeout != kNoTimeout.
+ if (timeout != kNoTimeout) {
+ const auto elapsed =
+ std::chrono::high_resolution_clock::now() - start;
+ if (elapsed > timeout) {
+ throw std::runtime_error(kConnectTimeoutMsg);
+ }
+ }
std::this_thread::sleep_for(std::chrono::seconds(1));
anyRefused = false;
anyReset = false;