Add finishAndThrow function to ProcessGroup::Work, and use with Gloo (#40405)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40405
This adds a finishAndThrow function that marks the work object as completed,
stores the exception pointer passed by the caller, and rethrows that exception
when it is non-null. All of this happens while holding the lock just once,
which simplifies the wait functions in ProcessGroupGloo.
ghstack-source-id: 106516114
Test Plan: CI
Differential Revision: D22174890
fbshipit-source-id: ea74702216c4328187c8d193bf39e1fea43847f6
diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp
index 0cf4b14..a701183 100644
--- a/torch/lib/c10d/ProcessGroup.cpp
+++ b/torch/lib/c10d/ProcessGroup.cpp
@@ -56,6 +56,15 @@
cv_.notify_all();
}
+// Marks this Work as completed, records `exception` as its stored exception,
+// and rethrows it if it is non-null. All of this happens under a single
+// acquisition of mutex_.
+//
+// Note: exception_ is overwritten unconditionally, so passing a null
+// exception_ptr replaces any previously stored exception and nothing is
+// thrown. This function does not touch cv_.
+void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ completed_ = true;
+ exception_ = exception;
+ if (exception_) {
+ // `lock` releases mutex_ in its destructor as the exception propagates.
+ std::rethrow_exception(exception_);
+ }
+}
+
ProcessGroup::ProcessGroup(int rank, int size) : rank_(rank), size_(size) {
C10_LOG_API_USAGE_ONCE("c10d.process_group");
}
diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp
index ba57aaf..72c66bd 100644
--- a/torch/lib/c10d/ProcessGroup.hpp
+++ b/torch/lib/c10d/ProcessGroup.hpp
@@ -92,6 +92,10 @@
// thread-safe manner. Notifies all waiting condition variables as well.
void finish(std::exception_ptr exception = nullptr);
+ // Similar to finish, but also throws: marks the work as completed, stores
+ // `exception` (replacing any previously stored exception_), and rethrows
+ // the stored exception if it is non-null.
+ void finishAndThrow(std::exception_ptr exception);
+
mutable std::mutex mutex_;
std::condition_variable cv_;
bool completed_ = false;
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 283f890..789c05d 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -341,14 +341,9 @@
} catch (...) {
exception = std::current_exception();
}
- // Lock to write completed_ and exception_, and throw if there is an
- // exception.
- std::lock_guard<std::mutex> lock(mutex_);
- completed_ = true;
- exception_ = exception;
- if (exception_) {
- std::rethrow_exception(exception_);
- }
+
+ // Completes the Work object and throws the exception.
+ finishAndThrow(exception);
return sendCompleted;
}
@@ -374,14 +369,9 @@
} catch (...) {
exception = std::current_exception();
}
- // Lock to write completed_ and exception_, and throw if there is an
- // exception.
- std::lock_guard<std::mutex> lock(mutex_);
- completed_ = true;
- exception_ = exception;
- if (exception_) {
- std::rethrow_exception(exception_);
- }
+
+ // Completes the Work object and throws the exception.
+ finishAndThrow(exception);
return recvCompleted;
}