#pragma once
#include <ATen/core/ivalue.h>

#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace utils {
// FutureError inherits from std::exception; what() returns the stored
// error message as a const char*.
class TORCH_API FutureError final : public std::exception {
public:
FutureError(std::string errorMsg) : errorMsg_(std::move(errorMsg)) {}
FutureError() = default;
const char* what() const noexcept override {
return errorMsg_.c_str();
}
private:
std::string errorMsg_;
};
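//
// Illustrative sketch of how a FutureError surfaces to callers (assumes a
// Future<T> named fut, defined below, that was failed via setError()):
//
//   try {
//     const auto& v = fut.wait();
//   } catch (const torch::utils::FutureError& e) {
//     // e.what() returns the message passed to setError()
//   }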
// This class holds a value of type T that will become ready at some point
// in the future. Most of the implementation is copied from FutureMessage
// and c10::ivalue::Future.
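//
// Illustrative usage sketch (the int payload and the producer thread are
// made up for the example, not part of this header):
//
//   torch::utils::Future<int> fut;
//   fut.addCallback(
//       [](const int& value,
//          const c10::optional<torch::utils::FutureError>& error) {
//         if (!error) {
//           // consume value
//         }
//       });
//   std::thread producer([&fut] { fut.markCompleted(42); });
//   const int& result = fut.wait(); // blocks, rethrows a stored error
//   producer.join();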
template <typename T>
class TORCH_API Future final {
public:
using Callback =
std::function<void(const T&, const c10::optional<FutureError>&)>;
Future() = default;
Future(T value) : completed_(true), value_(std::move(value)) {}
// Blocks until the future is completed, then returns the value. Throws the
// stored FutureError if the future completed with an error.
const T& wait() {
std::unique_lock<std::mutex> lock(mutex_);
finished_cv_.wait(lock, [this] { return completed_.load(); });
if (error_) {
throw *error_;
}
return value_;
}
// Like wait(), but never throws; on error the stored (possibly
// default-constructed) value is returned as-is.
const T& waitNoThrow() {
std::unique_lock<std::mutex> lock(mutex_);
finished_cv_.wait(lock, [this] { return completed_.load(); });
return value_;
}
// Moves the stored value out of the future; only meaningful once the
// future has completed.
T&& moveValue() && {
std::unique_lock<std::mutex> lock(mutex_);
return std::move(value_);
}
// Marks the future as completed with the given value and runs any
// registered callbacks. Must not be called on an already completed future.
void markCompleted(T value) {
std::unique_lock<std::mutex> lock(mutex_);
TORCH_CHECK(!completed_);
// Set value first as completed_ is accessed without lock
value_ = std::move(value);
completed_ = true;
// Move callbacks to a vector on the stack so we can access it without
// holding a lock
std::vector<Callback> cbs;
cbs.swap(callbacks_);
lock.unlock();
finished_cv_.notify_all();
// There is no need to hold the lock while running the callbacks: once
// completed_ is set to true, no new callbacks can be added. Pass value_
// and error_ so each callback can easily check the final state.
for (auto& callback : cbs) {
callback(value_, error_);
}
}
// Sets the error only if the future has not been marked completed yet.
// Useful for avoiding races where multiple threads try to set an error on
// the same future.
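//
// Illustrative sketch (hypothetical error paths): a timeout watchdog and an
// I/O error handler may both try to fail the same future; only the first
// call takes effect and the later one becomes a no-op.
//
//   fut.setErrorIfNeeded("request timed out");
//   fut.setErrorIfNeeded("connection reset"); // ignored, already completed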
void setErrorIfNeeded(std::string errorMsg) {
std::unique_lock<std::mutex> lock(mutex_);
if (completed_) {
return;
} else {
setErrorInternal(std::move(errorMsg), lock);
}
}
// Marks the future as completed with an error. Unlike setErrorIfNeeded(),
// calling this on an already completed future is an error (TORCH_CHECK).
void setError(std::string errorMsg) {
std::unique_lock<std::mutex> lock(mutex_);
setErrorInternal(std::move(errorMsg), lock);
}
bool completed() const {
return completed_;
}
bool hasError() const {
std::unique_lock<std::mutex> lock(mutex_);
return error_.has_value();
}
c10::optional<FutureError> error() const {
std::unique_lock<std::mutex> lock(mutex_);
return error_;
}
// If the future is already completed, the callback is invoked immediately
// on the calling thread.
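//
// Illustrative sketch (values made up): on an already completed future the
// callback runs before addCallback() returns.
//
//   torch::utils::Future<int> done(7); // constructed already completed
//   done.addCallback(
//       [](const int& value,
//          const c10::optional<torch::utils::FutureError>& error) {
//         // runs synchronously on the calling thread
//       });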
void addCallback(const Callback& callback) {
std::unique_lock<std::mutex> lock(mutex_);
if (completed_) {
lock.unlock();
callback(value_, error_);
return;
}
callbacks_.push_back(callback);
}
private:
void setErrorInternal(
std::string errorMsg,
std::unique_lock<std::mutex>& lock) {
TORCH_CHECK(!completed_);
error_ = FutureError(std::move(errorMsg));
completed_ = true;
// Move callbacks to a vector on the stack so we can access it without
// holding a lock
std::vector<Callback> cbs(std::move(callbacks_));
lock.unlock();
finished_cv_.notify_all();
// There is no need to hold the lock while running the callbacks: once
// completed_ is set to true, no new callbacks can be added. Pass value_
// and error_ so each callback can easily check the final state.
for (auto& callback : cbs) {
callback(value_, error_);
}
}
mutable std::mutex mutex_;
std::atomic_bool completed_{false}; // is this future complete
std::condition_variable finished_cv_;
std::vector<Callback> callbacks_;
T value_;
c10::optional<FutureError> error_;
};
} // namespace utils
} // namespace torch