#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::WorkerInfo;

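// Creates a SendRpcBackward node whose next edges point at the autograd
// history of every tensor in 'tensors' that requires grad, and registers it
// with 'autogradContext' under autogradMetadata.autogradMessageId. This is the
// hook through which gradients for an outgoing RPC re-enter the local autograd
// graph during the distributed backward pass.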
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

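// Looks up (or creates) the autograd context for
// autogradMetadata.autogradContextId. If any tensor in 'tensors' requires
// grad, a RecvRpcBackward node is installed as the grad_fn of those tensors
// and recorded in the context under autogradMetadata.autogradMessageId.
// Returns the context so the caller can run the rest of the RPC under it.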
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

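// Wraps 'wrappedRpcMessage' in an RpcWithProfilingReq carrying the given
// profiler configuration and a globally unique profiling id. The id -> RPC
// profiling key mapping saved here allows the caller to associate the
// profiling events returned by the remote end with the original call.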
static c10::intrusive_ptr<Message> getMessageWithProfiling(
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
    MessageType msgType,
    torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();

  auto key = remoteProfilerManager.getCurrentProfilingKey();
  // Generate a globally unique id for this profiling call.
  auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
  // Save a mapping of ID -> RPC profiling key and unset the current TLS key.
  remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
  remoteProfilerManager.unsetCurrentKey();
  auto wrappedProfilingMsg = RpcWithProfilingReq(
      msgType,
      std::move(wrappedRpcMessage),
      std::move(profilerConfig),
      globallyUniqueProfilingId);

  return std::move(wrappedProfilingMsg).toMessage();
}

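// Returns 'wrappedRpcMsg' unchanged when there is no valid autograd context,
// or when grad recording is not forced and no tensor in the message requires
// grad. Otherwise wraps it in an RpcWithAutograd message of type 'msgType',
// records the corresponding 'send' autograd function (if tensors require
// grad), and notes 'dstId' as a known worker in the current context.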
c10::intrusive_ptr<Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const rpc::DeviceMap& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // Send the original rpc message if there is no valid autograd context, or if
  // grad recording is not forced and no tensor requires grad. Otherwise,
  // attach autograd information and grad functions and send an RpcWithAutograd
  // message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return wrappedRpcMsg;
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the destination worker id as a known worker for this context.
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

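// Attaches autograd metadata to 'wrappedRpcMsg' (see getMessageWithAutograd)
// and sends it to 'dst' via 'agent' as a FORWARD_AUTOGRAD_REQ. If the legacy
// profiler is active and profiling is not explicitly disabled, the message is
// additionally wrapped with profiling metadata so the remote end also runs the
// request under the profiler.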
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  // If the profiler is enabled, wrap this message with profiling metadata that
  // tells the remote end to process this request with the profiler enabled.
  if (!forceDisableProfiling) {
    switch (torch::profiler::impl::profilerType()) {
      case torch::profiler::impl::ActiveProfilerType::LEGACY: {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ,
            std::move(profilerConfig));
        return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      }
      case torch::profiler::impl::ActiveProfilerType::KINETO:
        TORCH_WARN_ONCE(
            "Profiling a distributed call with the Kineto profiler will profile "
            "the caller, but not the worker.");
        break;
      default:
        break;
    }
  }

  return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}

} // namespace autograd
} // namespace distributed
} // namespace torch