Remove faulty process group code (#61907)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61907
Remove the code for the faulty process group agent, since it has been replaced by the faulty TensorPipe agent.
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D29794666
Pulled By: H-Huang
fbshipit-source-id: 0b35191cc07220b6774ecacc8d004f25fd2e87f0
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index d41d2d7..4ca72d8 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -382,7 +382,6 @@
"torch/csrc/distributed/rpc/script_resp.cpp",
"torch/csrc/distributed/rpc/tensorpipe_agent.cpp",
"torch/csrc/distributed/rpc/tensorpipe_utils.cpp",
- "torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp",
"torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp",
"torch/csrc/distributed/rpc/torchscript_functions.cpp",
"torch/csrc/distributed/rpc/types.cpp",
diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp
deleted file mode 100644
index bb980ee..0000000
--- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.cpp
+++ /dev/null
@@ -1,152 +0,0 @@
-#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
-#include <torch/csrc/distributed/rpc/utils.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-std::string fromVec(const std::vector<char>& vec) {
- return std::string(vec.begin(), vec.end());
-}
-
-FaultyProcessGroupAgent::FaultyProcessGroupAgent(
- const c10::intrusive_ptr<::c10d::Store>& store,
- std::string workerName,
- c10::intrusive_ptr<::c10d::ProcessGroup> pg,
- int numSendRecvThreads,
- std::chrono::milliseconds rpcTimeout,
- std::unique_ptr<RequestCallback> cb,
- const std::vector<std::string>& messagesToFail,
- const std::unordered_map<std::string, float>& messageTypesToDelay,
- int failNumSends)
- : ProcessGroupAgent(
- store,
- std::move(workerName),
- std::move(pg),
- numSendRecvThreads,
- rpcTimeout,
- std::move(cb)),
- failNumSends_(failNumSends),
- messageTypesToFail_(parseMessagesToFailInput(messagesToFail)),
- messageTypesToDelay_(parseMessagesToDelay(messageTypesToDelay)) {}
-
-std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
- const std::vector<std::string>& messagesToFail) const {
- // Since we can only pass strings corresponding to the Message Types from the
- // python tests, we must parse the list of strings and resolve the actual
- // types. We will then check this list of types in the send function to
- // determine whether we should fail or not.
- std::vector<MessageType> messageTypesToFail;
- messageTypesToFail.reserve(messagesToFail.size());
- for (const auto& msgString : messagesToFail) {
- messageTypesToFail.push_back(messageStringToType(msgString));
- }
- return messageTypesToFail;
-}
-
-std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
- parseMessagesToDelay(const std::unordered_map<std::string, float>&
- messageTypesToDelay) const {
- std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
- for (const auto& messagePair : messageTypesToDelay) {
- float delay = messagePair.second;
- TORCH_CHECK(
- delay >= 0,
- "Delays passed to FaultyProcessGroupAgent must be non-negative.")
- delayMessages.insert({messageStringToType(messagePair.first), delay});
- }
- return delayMessages;
-}
-
-c10::intrusive_ptr<JitFuture> FaultyProcessGroupAgent::send(
- const WorkerInfo& to,
- c10::intrusive_ptr<Message> message,
- const float rpcTimeoutSeconds,
- const std::unordered_map<c10::Device, c10::Device>& /* unused */) {
- // We only fail control messages that have been specified by the test case.
- // For all other messages, we just send them without any failures.
- if (!shouldFailMessage(message->type())) {
- return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
- }
- // This send function checks the failMessageCountMap_ to check whether
- // we must fail the next send. If the send must be failed, we set an error
- // on the returned future immediately and increment the counter in the map,
- // otherwise we just call the ProcessGroupAgent send.
- const auto key = fromVec(message->payload());
- std::unique_lock<std::mutex> lock(failMapMutex_);
- auto it = failMessageCountMap_.find(key);
- if (it == failMessageCountMap_.end()) {
- failMessageCountMap_[key] = 0;
- }
- if (failMessageCountMap_[key] < failNumSends_) {
- failMessageCountMap_[key]++;
- lock.unlock();
- auto jitFuture = c10::make_intrusive<JitFuture>(at::AnyClassType::get());
- jitFuture->setError(std::make_exception_ptr(std::runtime_error(makeRPCError(
- c10::str("Send attempt failed intentionally for ", key),
- RPCErrorType::INTENTIONAL_FAILURE))));
- return jitFuture;
- } else {
- lock.unlock();
- return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
- }
-}
-
-void FaultyProcessGroupAgent::enqueueSend(SendWork work) {
- float msgDelay = getDelayForMessage(work.message_->type());
- if (msgDelay != 0) {
- // Sleep for the specified delay for the message.
- std::this_thread::sleep_for(std::chrono::milliseconds(
- static_cast<int>(msgDelay * kSecToMsConversion)));
- }
- ProcessGroupAgent::enqueueSend(std::move(work));
-}
-
-void FaultyProcessGroupAgent::sendToSelf(c10::intrusive_ptr<Message> message) {
- float msgDelay = getDelayForMessage(message->type());
- if (msgDelay != 0) {
- // Sleep for the specified delay for the message.
- std::this_thread::sleep_for(std::chrono::milliseconds(
- static_cast<int>(msgDelay * kSecToMsConversion)));
- }
- ProcessGroupAgent::sendToSelf(std::move(message));
-}
-
-bool FaultyProcessGroupAgent::shouldFailMessage(MessageType type) const {
- // Return true if the input message type is in the messageTypesToFail_ list
- return (
- std::find(messageTypesToFail_.begin(), messageTypesToFail_.end(), type) !=
- messageTypesToFail_.end());
-}
-
-float FaultyProcessGroupAgent::getDelayForMessage(MessageType type) const {
- const auto& it = messageTypesToDelay_.find(type);
- return it == messageTypesToDelay_.end() ? 0 : it->second;
-}
-
-MessageType FaultyProcessGroupAgent::messageStringToType(
- const std::string& messageString) const {
- // Lazily constructed map that returns string to message type mapping
- static std::unordered_map<std::string, MessageType> msgMap = {
- {"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
- {"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
- {"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
- {"CLEANUP_AUTOGRAD_CONTEXT_REQ",
- MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
- {"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
- {"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
- {"PYTHON_CALL", MessageType::PYTHON_CALL},
- {"SCRIPT_CALL", MessageType::SCRIPT_CALL},
- {"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
- {"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
- const auto& it = msgMap.find(messageString);
- TORCH_CHECK(
- it != msgMap.end(),
- "No mapping to rpc::MessageType exists for ",
- messageString);
- return it->second;
-}
-
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h b/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h
deleted file mode 100644
index d0bbb33..0000000
--- a/torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h
+++ /dev/null
@@ -1,99 +0,0 @@
-#pragma once
-
-#include <torch/csrc/distributed/rpc/message.h>
-#include <torch/csrc/distributed/rpc/process_group_agent.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-struct TORCH_API FaultyProcessGroupRpcBackendOptions
- : public ProcessGroupRpcBackendOptions {
- FaultyProcessGroupRpcBackendOptions(
- int num_send_recv_threads,
- float rpc_timeout,
- std::string init_method,
- std::vector<std::string> messages_to_fail,
- std::unordered_map<std::string, float> messages_to_delay,
- int num_fail_sends = 0)
- : ProcessGroupRpcBackendOptions(
- num_send_recv_threads,
- rpc_timeout,
- std::move(init_method)),
- messagesToFail(std::move(messages_to_fail)),
- messagesToDelay(std::move(messages_to_delay)),
- numFailSends(num_fail_sends) {
- TORCH_CHECK(numFailSends >= 0, "numFailSends should be non-negative");
- }
-
- std::vector<std::string> messagesToFail;
- std::unordered_map<std::string, float> messagesToDelay;
- int numFailSends;
-};
-
-class TORCH_API FaultyProcessGroupAgent : public ProcessGroupAgent {
- public:
- FaultyProcessGroupAgent(
- const c10::intrusive_ptr<::c10d::Store>& store,
- std::string workerName,
- c10::intrusive_ptr<c10d::ProcessGroup> pg,
- int numSendRecvThreads,
- std::chrono::milliseconds rpcTimeout,
- std::unique_ptr<RequestCallback> cb,
- const std::vector<std::string>& messagesToFail,
- const std::unordered_map<std::string, float>& messageTypesToDelay,
- int failNumSends = 0);
-
- // Faulty send function for this class.
- c10::intrusive_ptr<JitFuture> send(
- const WorkerInfo& to,
- c10::intrusive_ptr<Message> message,
- const float rpcTimeoutSeconds = torch::distributed::rpc::kUnsetRpcTimeout,
- const std::unordered_map<c10::Device, c10::Device>& deviceMap = {})
- override;
-
- protected:
- // This function checks the messageTypesToFail_ to determine whether to use
- // the faulty send or not.
- virtual bool shouldFailMessage(MessageType type) const;
-
- private:
- // Overrides ProcessGroupAgent's enqueueSend to inject delays.
- void enqueueSend(SendWork work) override;
- // Override ProcessGroupAgent's sendToSelf to inject delays.
- void sendToSelf(c10::intrusive_ptr<Message> message) override;
- // This function parses the list of strings passed in by the python tests and
- // resolves the Message Types that must use the faulty send.
- std::vector<MessageType> parseMessagesToFailInput(
- const std::vector<std::string>& messagesToFail) const;
-
- // Returns amount of time in seconds to delay sending of the given message
- // type.
- float getDelayForMessage(MessageType type) const;
-
- // Parse message types that we should inject arbitrary delays for.
- std::unordered_map<MessageType, float, std::hash<int>> parseMessagesToDelay(
- const std::unordered_map<std::string, float>& messageTypesToDelay) const;
-
- // Number of sends to intentionally fail before allowing one to succeed.
- const int failNumSends_;
-
- // Vector of the MessageTypes that we must use the faulty send for. This is
- // parsed based on a list of strings passed in by the python tests.
- const std::vector<MessageType> messageTypesToFail_;
-
- // Mapping of message types to amount we should delay send for in the ::send()
- // function.
- std::unordered_map<MessageType, float, std::hash<int>> messageTypesToDelay_;
-
- // Map to track the number of sends we've failed for each RPC.
- std::unordered_map<std::string, int> failMessageCountMap_;
-
- // Mutex to guard failMessageCountMap_
- std::mutex failMapMutex_;
-
- MessageType messageStringToType(const std::string& messageString) const;
-};
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h
index 5d60597..01d2b3f 100644
--- a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h
+++ b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h
@@ -67,7 +67,7 @@
protected:
// This function checks the messageTypesToFail_ to determine whether to use
// the faulty send or not.
- bool shouldFailMessage(MessageType type) const;
+ virtual bool shouldFailMessage(MessageType type) const;
private:
// This function parses the list of strings passed in by the python tests and
diff --git a/torch/csrc/distributed/rpc/testing/init.cpp b/torch/csrc/distributed/rpc/testing/init.cpp
index 4569cc3..714db36 100644
--- a/torch/csrc/distributed/rpc/testing/init.cpp
+++ b/torch/csrc/distributed/rpc/testing/init.cpp
@@ -1,10 +1,8 @@
#include <torch/csrc/python_headers.h>
-#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
-#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h>
#include <torch/csrc/utils/pybind.h>
@@ -21,8 +19,8 @@
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
- // Add the FaultyProcessGroupAgent / FaultyTensorPipeAgent and its backend
- // options object to the python module torch._C._distributed_rpc_testing
+ // Add the FaultyTensorPipeAgent and its backend options object
+ // to the python module torch._C._distributed_rpc_testing
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module) {
throw python_error();
@@ -37,36 +35,6 @@
// TensorPipeAgent
py::module rpc_module = py::module::import("torch.distributed.rpc");
- shared_ptr_class_<FaultyProcessGroupRpcBackendOptions>(
- module,
- "FaultyProcessGroupRpcBackendOptions",
- rpc_module.attr("ProcessGroupRpcBackendOptions"))
- .def(
- py::init<
- int,
- float,
- std::string,
- std::vector<std::string>,
- std::unordered_map<std::string, float>,
- int>(),
- py::arg("num_send_recv_threads"),
- py::arg("rpc_timeout"),
- py::arg("init_method"),
- py::arg("messages_to_fail"),
- py::arg("messages_to_delay"),
- py::arg("num_fail_sends"))
- .def_readwrite(
- "num_send_recv_threads",
- &ProcessGroupRpcBackendOptions::numSendRecvThreads)
- .def_readwrite(
- "messages_to_fail",
- &FaultyProcessGroupRpcBackendOptions::messagesToFail)
- .def_readwrite(
- "messages_to_delay",
- &FaultyProcessGroupRpcBackendOptions::messagesToDelay)
- .def_readwrite(
- "num_fail_sends", &FaultyProcessGroupRpcBackendOptions::numFailSends);
-
shared_ptr_class_<FaultyTensorPipeRpcBackendOptions>(
module,
"FaultyTensorPipeRpcBackendOptions",
@@ -96,69 +64,6 @@
.def_readwrite(
"num_fail_sends", &FaultyTensorPipeRpcBackendOptions::numFailSends);
- shared_ptr_class_<FaultyProcessGroupAgent>(
- module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
- .def(
- py::init([](const c10::intrusive_ptr<::c10d::Store> store,
- std::string name,
- c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
- int num_send_recv_threads,
- std::chrono::milliseconds rpc_timeout,
- const std::vector<std::string>& messages_to_fail,
- const std::unordered_map<std::string, float>&
- messages_to_delay,
- int failNumSends) {
- return std::shared_ptr<FaultyProcessGroupAgent>(
- new FaultyProcessGroupAgent(
- store,
- std::move(name),
- process_group,
- num_send_recv_threads,
- rpc_timeout,
- std::make_unique<RequestCallbackImpl>(),
- messages_to_fail,
- messages_to_delay,
- failNumSends),
- impl::destroy_without_gil<FaultyProcessGroupAgent>);
- }),
- py::arg("store"),
- py::arg("name"),
- py::arg("process_group"),
- py::arg("num_send_recv_threads"),
- py::arg("rpc_timeout"),
- py::arg("messages_to_fail"),
- py::arg("messages_to_delay"),
- py::arg("failNumSends"))
- .def(
- "join",
- &ProcessGroupAgent::join,
- py::call_guard<py::gil_scoped_release>(),
- py::arg("shutdown") = false)
- .def(
- "shutdown",
- &ProcessGroupAgent::shutdown,
- py::call_guard<py::gil_scoped_release>())
- .def(
- "get_worker_info",
- (const WorkerInfo& (ProcessGroupAgent::*)(void) const) &
- RpcAgent::getWorkerInfo,
- py::call_guard<py::gil_scoped_release>())
- .def(
- "get_worker_info",
- (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&) const) &
- ProcessGroupAgent::getWorkerInfo,
- py::call_guard<py::gil_scoped_release>())
- .def(
- "get_worker_info",
- (const WorkerInfo& (ProcessGroupAgent::*)(worker_id_t id) const) &
- ProcessGroupAgent::getWorkerInfo,
- py::call_guard<py::gil_scoped_release>())
- .def(
- "get_worker_infos",
- (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) &
- ProcessGroupAgent::getWorkerInfos,
- py::call_guard<py::gil_scoped_release>());
-
shared_ptr_class_<FaultyTensorPipeAgent>(
module, "FaultyTensorPipeAgent", rpc_module.attr("TensorPipeAgent"))
.def(
diff --git a/torch/distributed/rpc/_testing/__init__.py b/torch/distributed/rpc/_testing/__init__.py
index c763ee7..5755b99 100644
--- a/torch/distributed/rpc/_testing/__init__.py
+++ b/torch/distributed/rpc/_testing/__init__.py
@@ -10,11 +10,9 @@
raise RuntimeError("Failed to initialize torch.distributed.rpc._testing")
if is_available():
- # Registers FAULTY_PROCESS_GROUP and FAULTY_TENSORPIPE RPC backends.
+ # Registers FAULTY_TENSORPIPE RPC backend.
from . import faulty_agent_backend_registry
from torch._C._distributed_rpc_testing import (
- FaultyProcessGroupRpcBackendOptions,
- FaultyProcessGroupAgent,
FaultyTensorPipeRpcBackendOptions,
FaultyTensorPipeAgent,
)
diff --git a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
index c9df38a..43c7f72 100644
--- a/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
+++ b/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py
@@ -2,87 +2,8 @@
import torch.distributed as dist
import torch.distributed.rpc as rpc
-import torch.distributed.distributed_c10d as dc10d
from torch.distributed.rpc import constants as rpc_constants
-from datetime import timedelta
-
-
-def _faulty_process_group_construct_rpc_backend_options_handler(
- rpc_timeout,
- init_method,
- num_send_recv_threads,
- messages_to_fail,
- messages_to_delay,
- num_fail_sends,
- **kwargs
-):
- from . import FaultyProcessGroupRpcBackendOptions
-
- return FaultyProcessGroupRpcBackendOptions(
- rpc_timeout=rpc_timeout,
- init_method=init_method,
- num_send_recv_threads=num_send_recv_threads,
- messages_to_fail=messages_to_fail,
- messages_to_delay=messages_to_delay,
- num_fail_sends=num_fail_sends,
- )
-
-
-def _faulty_process_group_init_backend_handler(
- store, name, rank, world_size, rpc_backend_options
-):
- from . import FaultyProcessGroupAgent
-
- if dist.is_initialized():
- raise RuntimeError("Process group must not be initialized before init_rpc.")
-
- process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT
-
- dist.init_process_group(
- backend=dist.Backend.GLOO,
- store=store,
- rank=rank,
- world_size=world_size,
- timeout=process_group_timeout,
- )
-
- try:
- group = dc10d._get_default_group()
- assert group is not None, "Failed to initialize default ProcessGroup."
-
- if (rank != -1) and (rank != group.rank()):
- raise RuntimeError(
- "rank argument {} doesn't match pg rank {}".format(rank, group.rank())
- )
- if (world_size != -1) and (world_size != group.size()):
- raise RuntimeError(
- "world_size argument {} doesn't match pg size {}".format(
- world_size, group.size()
- )
- )
-
- return FaultyProcessGroupAgent(
- store,
- name,
- group,
- rpc_backend_options.num_send_recv_threads,
- timedelta(seconds=rpc_backend_options.rpc_timeout),
- rpc_backend_options.messages_to_fail,
- rpc_backend_options.messages_to_delay,
- rpc_backend_options.num_fail_sends,
- )
- except Exception as ex:
- dist.destroy_process_group()
- raise ex
-
-
-rpc.backend_registry.register_backend(
- "FAULTY_PROCESS_GROUP",
- _faulty_process_group_construct_rpc_backend_options_handler,
- _faulty_process_group_init_backend_handler,
-)
-
def _init_process_group(store, rank, world_size):
# Initialize ProcessGroup.
process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT