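// Utilities for recording distributed autograd information on RPC messages:
// attaching send/recv autograd functions to tensors, wrapping messages with
// autograd and profiling metadata, and sending them via an RpcAgent.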
#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

namespace torch {
namespace distributed {
namespace autograd {
using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::WorkerInfo;
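
// Attaches a SendRpcBackward grad function to the outgoing tensors that
// require grad and records it in the given distributed autograd context under
// the autograd message ID, so the distributed backward pass can later resume
// the local graph from this send.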
void addSendRpcBackward(
const ContextPtr& autogradContext,
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors) {
// Attach autograd information only for tensors requiring grad.
std::vector<torch::Tensor> tensors_with_grad;
std::copy_if(
tensors.begin(),
tensors.end(),
std::back_inserter(tensors_with_grad),
[](const torch::Tensor& t) { return t.requires_grad(); });
// Attach the appropriate autograd edges.
auto grad_fn = std::make_shared<SendRpcBackward>();
grad_fn->set_next_edges(
torch::autograd::collect_next_edges(tensors_with_grad));
// Add the appropriate input metadata for the grad_fn.
for (const auto& tensor : tensors_with_grad) {
grad_fn->add_input_metadata(tensor);
}
// Record the send autograd function in our current context.
autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}
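
// Installs a RecvRpcBackward grad function as the history of the received
// tensors and records it in the autograd context identified by the metadata,
// creating that context on this worker if it does not exist yet. Returns the
// context so callers can use it for the rest of the request.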
ContextPtr addRecvRpcBackward(
const AutogradMetadata& autogradMetadata,
std::vector<torch::Tensor>& tensors,
rpc::worker_id_t fromWorkerId,
const rpc::DeviceMap& deviceMap) {
// Initialize autograd context if necessary.
auto& autogradContainer = DistAutogradContainer::getInstance();
auto autogradContext =
autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);
if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
// Attach the tensors as inputs to the autograd function.
auto grad_fn = std::make_shared<RecvRpcBackward>(
autogradMetadata, autogradContext, fromWorkerId, deviceMap);
for (auto& tensor : tensors) {
if (tensor.requires_grad()) {
torch::autograd::set_history(tensor, grad_fn);
}
}
// Now update the autograd context with the necessary information.
autogradContext->addRecvFunction(
grad_fn, autogradMetadata.autogradMessageId);
}
return autogradContext;
}
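
// Wraps an outgoing RPC message in an RpcWithProfilingReq carrying the
// caller's profiler configuration and a globally unique profiling ID. The
// saved ID -> profiling key mapping lets the caller associate remote profiling
// events with the local key once the response arrives.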
static c10::intrusive_ptr<Message> getMessageWithProfiling(
c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
MessageType msgType,
torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
auto& remoteProfilerManager =
torch::distributed::rpc::RemoteProfilerManager::getInstance();
auto key = remoteProfilerManager.getCurrentProfilingKey();
// Generate a globally unique profiling ID.
auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
// Save a mapping of ID -> RPC profiling key and unset the current TLS key.
remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
remoteProfilerManager.unsetCurrentKey();
auto wrappedProfilingMsg = RpcWithProfilingReq(
msgType,
std::move(wrappedRpcMessage),
std::move(profilerConfig),
globallyUniqueProfilingId);
return std::move(wrappedProfilingMsg).toMessage();
}
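
// Wraps `wrappedRpcMsg` in an RpcWithAutograd message with the given message
// type when distributed autograd needs to record this send; otherwise returns
// the original message untouched.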
c10::intrusive_ptr<Message> getMessageWithAutograd(
const rpc::worker_id_t dstId,
c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
MessageType msgType,
bool forceGradRecording,
const rpc::DeviceMap& deviceMap) {
auto& autogradContainer = DistAutogradContainer::getInstance();
// If there is no valid autograd context and no tensor requires grad (and grad
// recording isn't forced), send the original RPC message. Otherwise, attach
// autograd metadata and grad functions and send an RpcWithAutograd message.
auto tensorsRequireGrad =
torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
if (!autogradContainer.hasValidContext() ||
(!forceGradRecording && !tensorsRequireGrad)) {
return wrappedRpcMsg;
}
// Retrieve the appropriate context to modify.
auto autogradContext = autogradContainer.currentContext();
// Wrap the original rpc with autograd information.
AutogradMetadata autogradMetadata(
autogradContext->contextId(), autogradContainer.newAutogradMessageId());
auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
msgType,
autogradMetadata,
std::move(wrappedRpcMsg),
deviceMap);
if (tensorsRequireGrad) {
// Record autograd information for 'send'.
addSendRpcBackward(
autogradContext, autogradMetadata, rpcWithAutograd->tensors());
}
// Record the destination worker ID in the current autograd context.
autogradContext->addKnownWorkerId(dstId);
return std::move(*rpcWithAutograd).toMessage();
}
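
// Convenience wrapper: attaches autograd metadata (and, for the legacy
// profiler, profiling metadata) to `wrappedRpcMsg` and sends it to `dst` via
// `agent`.
//
// A typical call site might look like the following (a sketch only; the
// worker name and the wrapped message `wrappedRpcMsg` are illustrative, not
// taken from this file):
//
//   auto& agent = *RpcAgent::getCurrentRpcAgent();
//   const auto& dst = agent.getWorkerInfo("worker1");
//   auto fut = sendMessageWithAutograd(
//       agent,
//       dst,
//       std::move(wrappedRpcMsg),
//       /*forceGradRecording=*/false,
//       /*rpcTimeoutSeconds=*/60.0,
//       /*forceDisableProfiling=*/false);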
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
RpcAgent& agent,
const WorkerInfo& dst,
c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
bool forceGradRecording,
const float rpcTimeoutSeconds,
bool forceDisableProfiling) {
auto msg = getMessageWithAutograd(
dst.id_,
std::move(wrappedRpcMsg),
MessageType::FORWARD_AUTOGRAD_REQ,
forceGradRecording,
agent.getDeviceMap(dst));
// If the profiler is enabled, wrap this message with profiling metadata that
// tells the remote end to process this request with the profiler enabled.
if (!forceDisableProfiling) {
switch (torch::profiler::impl::profilerType()) {
case torch::profiler::impl::ActiveProfilerType::LEGACY: {
auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
auto msgWithProfiling = getMessageWithProfiling(
std::move(msg),
rpc::MessageType::RUN_WITH_PROFILING_REQ,
std::move(profilerConfig));
return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
}
case torch::profiler::impl::ActiveProfilerType::KINETO:
TORCH_WARN_ONCE(
"Profiling a distributed call with the Kineto profiler will profile "
"the caller, but not the worker.");
break;
default:
break;
}
}
return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}
} // namespace autograd
} // namespace distributed
} // namespace torch