blob: 46642742d3072d7703d278de03bc8dbd4d24b6f1 [file] [log] [blame]
#include <c10d/FileStore.hpp>
#include <assert.h>
#include <fcntl.h>
#include <stdint.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#include <fileapi.h>
#include <io.h>
#else
#include <sys/file.h>
#include <unistd.h>
#endif
#include <chrono>
#include <cstdio>
#include <functional>
#include <iostream>
#include <limits>
#include <sstream>
#include <system_error>
#include <thread>
#include <c10/util/Exception.h>
// Throws std::system_error built from the current errno when `rv` is negative.
// Any extra arguments are forwarded to the std::system_error constructor
// (callers in this file pass a short description of the failing call).
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement: `if (c) SYSASSERT(rv, "x"); else ...` then parses as written
// instead of tripping over a stray `;` / dangling else.
#define SYSASSERT(rv, ...)                                                   \
  do {                                                                       \
    if ((rv) < 0) {                                                          \
      throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
    }                                                                        \
  } while (0)
#ifdef _WIN32
// Operation codes accepted by the flock_ shim defined below. They reuse the
// flock(2) operation names so the rest of this file can pass
// LOCK_SH / LOCK_EX / LOCK_UN regardless of platform; on this (Windows) path
// the numeric values are only interpreted by flock_'s switch statement.
#define LOCK_EX 0x00000001
#define LOCK_SH 0x00000010
#define LOCK_UN 0x00000100
int flock_(int fd, int op) {
HANDLE hdl = (HANDLE) _get_osfhandle(fd);
DWORD low = 1, high = 0;
OVERLAPPED offset = {0, 0, 0, 0, NULL};
if (hdl < 0)
return -1;
switch (op) {
case LOCK_EX:
if (LockFileEx(hdl, LOCKFILE_EXCLUSIVE_LOCK, 0, low, high, &offset))
return 0;
break;
case LOCK_SH:
if (LockFileEx(hdl, 0, 0, low, high, &offset))
return 0;
break;
case LOCK_UN:
if(UnlockFileEx(hdl, 0, low, high, &offset) != 0)
return 0;
break;
default:
break;
}
errno = EINVAL;
return -1;
}
#endif
namespace c10d {
namespace {
// Calls `fn` and transparently retries for as long as it reports an
// interrupted system call, i.e. returns -1 with errno == EINTR.
//
// fn: nullary callable wrapping a system call.
// Returns whatever the first non-interrupted invocation of `fn` returned
// (including -1 with any other errno; error reporting is left to the caller).
//
// The return type is spelled with decltype rather than std::result_of, which
// is deprecated in C++17 and removed in C++20.
template <typename F>
auto syscall(F fn) -> decltype(fn()) {
  while (true) {
    auto rv = fn();
    if (rv == -1 && errno == EINTR) {
      continue;
    }
    return rv;
  }
}
// For a comprehensive overview of file locking methods,
// see: https://gavv.github.io/blog/file-locks/.
// We stick to flock(2) here because we don't care about
// locking byte ranges and don't want locks to be process-wide.
//
// Scoped wrapper around flock(2): the constructor applies `operation` to the
// descriptor and the destructor (or an explicit unlock()) applies LOCK_UN.
// A Lock is movable but not copyable; a moved-from Lock holds no descriptor
// (fd_ == -1) and its unlock()/destructor are no-ops.
//
// The Lock does not own the descriptor itself: it never closes it.
class Lock {
 public:
  // Applies `operation` to `fd`.
  // Throws std::system_error if the underlying flock call fails.
  explicit Lock(int fd, int operation) : fd_(fd) {
    flock(operation);
  }

  ~Lock() {
    unlock();
  }

  Lock(const Lock& that) = delete;
  Lock& operator=(const Lock& that) = delete;

  // Takes over the lock held by `other`, leaving `other` empty.
  Lock(Lock&& other) noexcept : fd_(other.fd_) {
    other.fd_ = -1;
  }

  // Releases the lock currently held by *this (if any) before taking over the
  // one held by `other`; without that release the previous lock would stay in
  // place with nothing left to unlock it.
  Lock& operator=(Lock&& other) noexcept {
    if (this != &other) {
      unlock();
      fd_ = other.fd_;
      other.fd_ = -1;
    }
    return *this;
  }

  // Releases the lock if one is held; does nothing otherwise.
  // Throws std::system_error if the underlying flock call fails.
  void unlock() {
    if (fd_ >= 0) {
      flock(LOCK_UN);
      fd_ = -1;
    }
  }

 protected:
  // Descriptor the lock applies to, or -1 when no lock is held.
  int fd_{-1};

  // Applies `operation` to fd_, retrying on EINTR via syscall().
  void flock(int operation) {
#ifdef _WIN32
    auto rv = syscall(std::bind(::flock_, fd_, operation));
#else
    auto rv = syscall(std::bind(::flock, fd_, operation));
#endif
    SYSASSERT(rv, "flock");
  }
};
class File {
public:
explicit File(
const std::string& path,
int flags,
std::chrono::milliseconds timeout) {
const auto start = std::chrono::steady_clock::now();
while (true) {
#ifdef _WIN32
fd_ = syscall(std::bind(::open, path.c_str(), flags | _O_BINARY, _S_IREAD | _S_IWRITE));
#else
fd_ = syscall(std::bind(::open, path.c_str(), flags, 0644));
#endif
// Only retry when the file doesn't exist, since we are waiting for the
// file to be created in this case to address the following issue:
// https://github.com/pytorch/pytorch/issues/13750
if (fd_ >= 0 || errno != ENOENT) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - start);
if (timeout != c10d::Store::kNoTimeout && elapsed > timeout) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
SYSASSERT(fd_, "open(" + path + ")");
}
~File() {
::close(fd_);
}
Lock lockShared() {
return Lock(fd_, LOCK_SH);
}
Lock lockExclusive() {
return Lock(fd_, LOCK_EX);
}
off_t seek(off_t offset, int whence) {
auto rv = syscall(std::bind(lseek, fd_, offset, whence));
SYSASSERT(rv, "lseek");
return rv;
}
off_t tell() {
auto rv = syscall(std::bind(lseek, fd_, 0, SEEK_CUR));
SYSASSERT(rv, "lseek");
return rv;
}
off_t size() {
auto pos = tell();
auto size = seek(0, SEEK_END);
seek(pos, SEEK_SET);
return size;
}
void write(const void* buf, size_t count) {
while (count > 0) {
auto rv = syscall(std::bind(::write, fd_, buf, count));
SYSASSERT(rv, "write");
buf = (uint8_t*)buf + rv;
count -= rv;
}
}
void read(void* buf, size_t count) {
while (count > 0) {
auto rv = syscall(std::bind(::read, fd_, buf, count));
SYSASSERT(rv, "read");
buf = (uint8_t*)buf + rv;
count -= rv;
}
}
void write(const std::string& str) {
uint32_t len = str.size();
assert(str.size() <= std::numeric_limits<decltype(len)>::max());
write(&len, sizeof(len));
write(str.c_str(), len);
}
void write(const std::vector<uint8_t>& data) {
uint32_t len = data.size();
assert(data.size() <= std::numeric_limits<decltype(len)>::max());
write(&len, sizeof(len));
write(data.data(), len);
}
void read(std::string& str) {
uint32_t len;
read(&len, sizeof(len));
std::vector<uint8_t> buf(len);
read(buf.data(), len);
str.assign(buf.begin(), buf.end());
}
void read(std::vector<uint8_t>& data) {
uint32_t len;
read(&len, sizeof(len));
data.resize(len);
read(data.data(), len);
}
protected:
int fd_;
};
// Replays the entries stored in `file` from offset `pos` up to the current
// end of the file into `cache`; an entry read later overwrites an earlier
// one with the same key.
//
// file:  backing file to read from.
// pos:   offset up to which the file has already been consumed.
// cache: key -> value map updated in place.
//
// Returns the offset reached after the last entry read (or `pos` unchanged
// when nothing was read). The file offset is rewound to 0 before returning.
off_t refresh(
    File& file,
    off_t pos,
    std::unordered_map<std::string, std::vector<uint8_t>>& cache) {
  const auto fileSize = file.size();
  if (fileSize != pos) {
    file.seek(pos, SEEK_SET);
    // Each iteration consumes one (key, value) record and then advances
    // `pos` to wherever the file cursor ended up.
    for (; fileSize > pos; pos = file.tell()) {
      std::string entryKey;
      std::vector<uint8_t> entryValue;
      file.read(entryKey);
      file.read(entryValue);
      cache[entryKey] = std::move(entryValue);
    }
  }
  file.seek(0, SEEK_SET);
  return pos;
}
} // namespace
// Builds a store backed by the file at `path` and shared by `numWorkers`
// workers. The file itself is not touched here.
//
// User keys are stored under the "/" prefix, while the internal worker
// cleanup counter uses the key "cleanup/".
//
// Throws std::runtime_error when `numWorkers` is less than one.
FileStore::FileStore(const std::string& path, int numWorkers)
    : Store(),
      path_(path),
      pos_(0),
      numWorkers_(numWorkers),
      cleanupKey_("cleanup/"),
      regularPrefix_("/") {
  const bool hasWorkers = numWorkers_ >= 1;
  if (!hasWorkers) {
    throw std::runtime_error(
        "Number of workers for FileStore should be greater than zero");
  }
}
// Registers this worker as finished by incrementing the cleanup counter in
// the backing file; the worker that brings the counter to numWorkers_
// removes the file.
//
// The whole sequence is best effort. addHelper opens, locks, reads and
// writes the file and can throw on any of those steps; an exception leaving
// a destructor would call std::terminate, so failures are swallowed here.
FileStore::~FileStore() {
  try {
    // cleanup key will be different from all rest keys since all rest keys
    // will have a regular prefix.
    const auto numFinishedWorker = addHelper(cleanupKey_, 1);
    // The last worker cleans up the file
    if (numFinishedWorker == numWorkers_) {
      // Best effort removal without checking the return
      std::remove(path_.c_str());
    }
  } catch (...) {
    // Nothing sensible can be done during destruction; leave the file as is.
  }
}
void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
std::string regKey = regularPrefix_ + key;
std::unique_lock<std::mutex> l(activeFileOpLock_);
File file(path_, O_RDWR | O_CREAT, timeout_);
auto lock = file.lockExclusive();
file.seek(0, SEEK_END);
file.write(regKey);
file.write(value);
}
// Returns the value stored under `key`, blocking until it is available.
//
// Each attempt opens the backing file, takes a shared lock and, if the file
// has grown or the key is already cached, replays new entries into the cache
// before looking the key up. When there is nothing new and the key is
// unknown, the locks are dropped and the attempt is repeated after 10ms.
//
// Throws std::runtime_error once the wait exceeds timeout_ (unless timeout_
// is kNoTimeout). Failures of the underlying file operations propagate.
std::vector<uint8_t> FileStore::get(const std::string& key) {
  const std::string regKey = regularPrefix_ + key;
  const auto start = std::chrono::steady_clock::now();
  while (true) {
    std::unique_lock<std::mutex> l(activeFileOpLock_);
    File file(path_, O_RDONLY, timeout_);
    auto lock = file.lockShared();
    const auto size = file.size();
    if (cache_.count(regKey) == 0 && size == pos_) {
      // No new entries; release the shared lock and sleep for a bit
      lock.unlock();
      l.unlock();
      // Measure in milliseconds (the unit of timeout_); truncating to whole
      // seconds would let the wait overshoot the timeout by up to a second.
      const auto elapsed =
          std::chrono::duration_cast<std::chrono::milliseconds>(
              std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        throw std::runtime_error("Timeout waiting for key: " + key);
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      continue;
    }
    // Always refresh since even though the key exists in the cache,
    // it might be outdated
    pos_ = refresh(file, pos_, cache_);
    // Single lookup: find() both tests for presence and yields the value.
    const auto entry = cache_.find(regKey);
    if (entry != cache_.end()) {
      return entry->second;
    }
  }
}
// Adds `i` to the integer stored under `key` and appends the new total to
// the backing file, all under an exclusive file lock.
//
// key: full key as stored in the file (no prefix is added here).
// i:   amount to add; a missing or empty value counts as 0.
//
// Returns the new total. The total is stored as its decimal string
// representation. std::stoll throws if the existing value is not a valid
// integer; failures of the underlying file operations also propagate.
int64_t FileStore::addHelper(const std::string& key, int64_t i) {
  std::unique_lock<std::mutex> l(activeFileOpLock_);
  File file(path_, O_RDWR | O_CREAT, timeout_);
  auto lock = file.lockExclusive();
  pos_ = refresh(file, pos_, cache_);

  int64_t ti = i;
  // Look the key up with find() rather than operator[]: the latter would
  // insert an empty placeholder into the cache, making the key look present
  // even if the write below fails.
  const auto entry = cache_.find(key);
  if (entry != cache_.end() && !entry->second.empty()) {
    const auto& value = entry->second;
    ti += std::stoll(std::string(value.begin(), value.end()));
  }

  // Always seek to the end to write
  file.seek(0, SEEK_END);
  // File cursor is at the end of the file now, and we have an
  // exclusive lock, so we can write the new value.
  file.write(key);
  file.write(std::to_string(ti));
  return ti;
}
// Adds `value` to the counter stored under the user-visible `key` and
// returns the new total. The key is given the regular prefix and the work is
// delegated to addHelper.
int64_t FileStore::add(const std::string& key, int64_t value) {
  return addHelper(regularPrefix_ + key, value);
}
// Not supported by FileStore: the TORCH_CHECK below is given a constant
// `false` condition, so every call fails with the "not implemented" message
// and no key count is ever returned.
int64_t FileStore::getNumKeys() {
TORCH_CHECK(false, "getNumKeys not implemented for FileStore");
}
// Not supported by FileStore: the key argument is ignored and the
// TORCH_CHECK below is given a constant `false` condition, so every call
// fails with the "not implemented" message.
bool FileStore::deleteKey(const std::string& /* unused */) {
TORCH_CHECK(false, "deleteKey not implemented for FileStore");
}
bool FileStore::check(const std::vector<std::string>& keys) {
std::unique_lock<std::mutex> l(activeFileOpLock_);
File file(path_, O_RDONLY, timeout_);
auto lock = file.lockShared();
pos_ = refresh(file, pos_, cache_);
for (const auto& key : keys) {
std::string regKey = regularPrefix_ + key;
if (cache_.count(regKey) == 0) {
return false;
}
}
return true;
}
// Waits for all of `keys` by forwarding to the two-argument overload with
// the store's timeout_ as the timeout.
void FileStore::wait(const std::vector<std::string>& keys) {
wait(keys, timeout_);
}
// Blocks until check(keys) reports that every key is present, polling every
// 10ms.
//
// keys:    keys to wait for (without the regular prefix).
// timeout: maximum time to wait; kNoTimeout waits indefinitely.
//
// Throws std::runtime_error("Wait timeout") once the wait exceeds `timeout`.
// Failures raised by check() propagate.
void FileStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  // Not using inotify because it doesn't work on many
  // shared filesystems (such as NFS).
  const auto start = std::chrono::steady_clock::now();
  while (!check(keys)) {
    // Measure in milliseconds (the unit of `timeout`); truncating to whole
    // seconds would let the wait overshoot the timeout by up to a second.
    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start);
    if (timeout != kNoTimeout && elapsed > timeout) {
      throw std::runtime_error("Wait timeout");
    }
    /* sleep override */
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}
} // namespace c10d