| #include <torch/csrc/distributed/rpc/utils.h> |
| #include <torch/csrc/distributed/rpc/python_remote_call.h> |
| #include <torch/csrc/distributed/rpc/python_udf_call.h> |
| #include <torch/csrc/distributed/rpc/python_udf_resp.h> |
| #include <torch/csrc/distributed/rpc/rpc_with_autograd.h> |
| #include <torch/csrc/distributed/rpc/rref_proto.h> |
| #include <torch/csrc/distributed/rpc/script_call.h> |
| #include <torch/csrc/distributed/rpc/script_remote_call.h> |
| #include <torch/csrc/distributed/rpc/script_resp.h> |
| |
| namespace torch { |
| namespace distributed { |
| namespace rpc { |
| |
| std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) { |
| switch (request.type()) { |
| case MessageType::SCRIPT_CALL: { |
| return ScriptCall::fromMessage(request); |
| } |
| case MessageType::PYTHON_CALL: { |
| return PythonUDFCall::fromMessage(request); |
| } |
| case MessageType::SCRIPT_REMOTE_CALL: { |
| return ScriptRemoteCall::fromMessage(request); |
| } |
| case MessageType::PYTHON_REMOTE_CALL: { |
| return PythonRemoteCall::fromMessage(request); |
| } |
| case MessageType::SCRIPT_RREF_FETCH_CALL: { |
| return ScriptRRefFetchCall::fromMessage(request); |
| } |
| case MessageType::PYTHON_RREF_FETCH_CALL: { |
| return PythonRRefFetchCall::fromMessage(request); |
| } |
| case MessageType::RREF_USER_DELETE: { |
| return RRefUserDelete::fromMessage(request); |
| } |
| case MessageType::RREF_CHILD_ACCEPT: { |
| return RRefChildAccept::fromMessage(request); |
| } |
| case MessageType::RREF_FORK_REQUEST: { |
| return RRefForkRequest::fromMessage(request); |
| } |
| case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: { |
| return RpcWithAutograd::fromMessage(request); |
| } |
| default: { |
| TORCH_INTERNAL_ASSERT( |
| false, "Request type ", request.type(), " not supported."); |
| } |
| } |
| } |
| |
| std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) { |
| switch (response.type()) { |
| case MessageType::SCRIPT_RET: { |
| return ScriptResp::fromMessage(response); |
| } |
| case MessageType::PYTHON_RET: { |
| return PythonUDFResp::fromMessage(response); |
| } |
| case MessageType::REMOTE_RET: { |
| return RemoteRet::fromMessage(response); |
| } |
| case MessageType::RREF_FETCH_RET: { |
| return RRefFetchRet::fromMessage(response); |
| } |
| case MessageType::RREF_ACK: { |
| return RRefAck::fromMessage(response); |
| } |
| case MessageType::EXCEPTION: { |
| std::string err(response.payload().begin(), response.payload().end()); |
| throw std::runtime_error(err); |
| } |
| case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: { |
| return RpcWithAutograd::fromMessage(response); |
| } |
| default: { |
| TORCH_INTERNAL_ASSERT( |
| false, "Response type ", response.type(), " not supported."); |
| } |
| } |
| } |
| |
| } // namespace rpc |
| } // namespace distributed |
| } // namespace torch |