| #include <torch/csrc/distributed/rpc/utils.h> | 
 |  | 
 | #include <fmt/format.h> | 
 | #include <torch/csrc/autograd/profiler.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h> | 
 | #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h> | 
 | #include <torch/csrc/distributed/autograd/utils.h> | 
 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> | 
 | #include <torch/csrc/distributed/rpc/python_call.h> | 
 | #include <torch/csrc/distributed/rpc/python_remote_call.h> | 
 | #include <torch/csrc/distributed/rpc/python_resp.h> | 
 | #include <torch/csrc/distributed/rpc/rref_proto.h> | 
 | #include <torch/csrc/distributed/rpc/script_call.h> | 
 | #include <torch/csrc/distributed/rpc/script_remote_call.h> | 
 | #include <torch/csrc/distributed/rpc/script_resp.h> | 
 | #include <torch/csrc/jit/serialization/pickler.h> | 
 | #include <torch/csrc/jit/serialization/unpickler.h> | 
 |  | 
 | #include <c10/util/irange.h> | 
 |  | 
 | using namespace torch::autograd::profiler; | 
 |  | 
 | namespace torch::distributed::rpc { | 
 | namespace { | 
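// Merges events profiled on the remote end of an RPC into the local
// thread-local profiler. Each remote event's name is prefixed with the
// locally recorded profiling key (plus REMOTE_PROFILING_KEY_PREFIX) so the
// events can be attributed to the RPC that produced them.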
 | void processRemoteProfiledEvents( | 
 |     autograd::RpcWithProfilingResp& rpcWithProfilingResp) { | 
 |   // Check if the profiler is enabled | 
 |   auto enabled = profilerEnabled(); | 
 |   TORCH_CHECK( | 
 |       enabled, | 
 |       "Profiler was expected to be enabled. This can happen in callback " | 
 |       " continuations that run in different threads, and the TLS of the " | 
 |       " profiler was not propagated."); | 
 |   std::vector<LegacyEvent> events = rpcWithProfilingResp.getProfiledEvents(); | 
 |   const auto& profilingId = rpcWithProfilingResp.getProfilingId(); | 
 |   auto& remoteProfilerManager = RemoteProfilerManager::getInstance(); | 
 |   auto key = remoteProfilerManager.retrieveRPCProfilingKey(profilingId); | 
 |   remoteProfilerManager.eraseKey(profilingId); | 
 |   auto keyPrefixStr = key + rpc::REMOTE_PROFILING_KEY_PREFIX; | 
 |   std::for_each( | 
 |       events.begin(), events.end(), [&keyPrefixStr](LegacyEvent& event) { | 
 |         std::string name = keyPrefixStr + std::string(event.name()); | 
 |         event.setName(at::StringView(name)); | 
 |       }); | 
 |   // Add event list to the thread local profiler. | 
 |   addEventList(std::move(events)); | 
 | } | 
 |  | 
 | } // namespace | 
 |  | 
const std::string kRPCErrorPrefix = "RPCErr";
 |  | 
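// makeRPCError() below produces error strings of the form
//   "RPCErr:<int(errorType)>:<message>"
// and getRPCErrorType() parses the integer error type back out. As a
// hypothetical example, a timeout error might serialize as
// "RPCErr:0:RPC ran for more than 5000ms and timed out" (assuming
// RPCErrorType::TIMEOUT casts to 0).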
 | RPCErrorType getRPCErrorType(const JitFuture& jitFuture) { | 
 |   TORCH_INTERNAL_ASSERT( | 
 |       jitFuture.hasError(), | 
 |       "JitFuture of Message passed to getRPCErrorType does not have an error."); | 
 |  | 
  // Attempt to parse the error type out of an error string produced by
  // makeRPCError(); if the message is not in that format, return
  // UNKNOWN_ERROR.
 |   auto err = jitFuture.tryRetrieveErrorMessage(); | 
 |   size_t pos = err.find(kRPCErrorPrefix); | 
 |   if (pos != std::string::npos) { | 
    // Parse the RPCErrorType. The +1 skips the ':' separator after the
    // prefix.
    auto errStartIdx = pos + kRPCErrorPrefix.size() + 1;
 |     auto errEndIdx = err.find(':', errStartIdx); | 
 |     if (errEndIdx == std::string::npos) { | 
 |       // Indicates error was not formatted correctly. | 
 |       return RPCErrorType::UNKNOWN_ERROR; | 
 |     } | 
 |     auto errStr = err.substr(errStartIdx, errEndIdx - errStartIdx); | 
 |     auto errType = static_cast<RPCErrorType>(std::stoi(errStr)); | 
 |     return errType; | 
 |   } else { | 
 |     return RPCErrorType::UNKNOWN_ERROR; | 
 |   } | 
 | } | 
 |  | 
 | std::string makeRPCError( | 
 |     const std::string& rpcErrorStr, | 
 |     RPCErrorType errorType) { | 
  return fmt::format(
      "{}:{}:{}", kRPCErrorPrefix, static_cast<int>(errorType), rpcErrorStr);
 | } | 
 |  | 
 | std::unique_ptr<RpcCommandBase> deserializeRequest(const Message& request) { | 
 |   switch (request.type()) { | 
 |     case MessageType::SCRIPT_CALL: { | 
 |       return ScriptCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::PYTHON_CALL: { | 
 |       return PythonCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::SCRIPT_REMOTE_CALL: { | 
 |       return ScriptRemoteCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::PYTHON_REMOTE_CALL: { | 
 |       return PythonRemoteCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::SCRIPT_RREF_FETCH_CALL: { | 
 |       return ScriptRRefFetchCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::PYTHON_RREF_FETCH_CALL: { | 
 |       return PythonRRefFetchCall::fromMessage(request); | 
 |     } | 
 |     case MessageType::RREF_USER_DELETE: { | 
 |       return RRefUserDelete::fromMessage(request); | 
 |     } | 
 |     case MessageType::RREF_CHILD_ACCEPT: { | 
 |       return RRefChildAccept::fromMessage(request); | 
 |     } | 
 |     case MessageType::RREF_FORK_REQUEST: { | 
 |       return RRefForkRequest::fromMessage(request); | 
 |     } | 
 |     case MessageType::FORWARD_AUTOGRAD_REQ: { | 
 |       return autograd::RpcWithAutograd::fromMessage(request); | 
 |     } | 
 |     case MessageType::BACKWARD_AUTOGRAD_REQ: { | 
 |       return autograd::PropagateGradientsReq::fromMessage(request); | 
 |     } | 
 |     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: { | 
 |       return autograd::CleanupAutogradContextReq::fromMessage(request); | 
 |     } | 
 |     case MessageType::RUN_WITH_PROFILING_REQ: { | 
 |       return autograd::RpcWithProfilingReq::fromMessage(request); | 
 |     } | 
 |     case MessageType::RREF_BACKWARD_REQ: { | 
 |       return autograd::RRefBackwardReq::fromMessage(request); | 
 |     } | 
 |     default: { | 
 |       TORCH_INTERNAL_ASSERT( | 
 |           false, "Request type ", request.type(), " not supported."); | 
 |     } | 
 |   } | 
 | } | 
 |  | 
 | std::unique_ptr<RpcCommandBase> deserializeResponse( | 
 |     const Message& response, | 
 |     MessageType& wrappedMsgType) { | 
 |   switch (response.type()) { | 
 |     case MessageType::SCRIPT_RET: { | 
 |       return ScriptResp::fromMessage(response); | 
 |     } | 
 |     case MessageType::PYTHON_RET: { | 
 |       return PythonResp::fromMessage(response); | 
 |     } | 
 |     case MessageType::REMOTE_RET: { | 
 |       return RemoteRet::fromMessage(response); | 
 |     } | 
 |     case MessageType::SCRIPT_RREF_FETCH_RET: { | 
 |       return ScriptRRefFetchRet::fromMessage(response); | 
 |     } | 
 |     case MessageType::PYTHON_RREF_FETCH_RET: { | 
 |       return PythonRRefFetchRet::fromMessage(response); | 
 |     } | 
 |     case MessageType::RREF_ACK: { | 
 |       return RRefAck::fromMessage(response); | 
 |     } | 
 |     case MessageType::FORWARD_AUTOGRAD_RESP: { | 
 |       std::unique_ptr<RpcCommandBase> rpcPtr = | 
 |           autograd::RpcWithAutograd::fromMessage(response); | 
 |       RpcCommandBase& rpc = *rpcPtr; | 
 |       auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc); | 
 |  | 
 |       // Need to reverse the device map for the backward pass of distributed | 
 |       // autograd. | 
 |       DeviceMap reverseDeviceMap; | 
 |       for (const auto& mapEntry : rpcWithAutograd.deviceMap()) { | 
 |         reverseDeviceMap.insert({mapEntry.second, mapEntry.first}); | 
 |       } | 
 |  | 
 |       // Attach 'recv' autograd function. | 
 |       addRecvRpcBackward( | 
 |           rpcWithAutograd.autogradMetadata(), | 
 |           rpcWithAutograd.tensors(), | 
 |           rpcWithAutograd.fromWorkerId(), | 
 |           reverseDeviceMap); | 
 |  | 
 |       wrappedMsgType = rpcWithAutograd.wrappedMessageType(); | 
 |  | 
 |       return std::move(rpcWithAutograd).moveWrappedRpc(); | 
 |     } | 
 |     case MessageType::BACKWARD_AUTOGRAD_RESP: { | 
 |       return autograd::PropagateGradientsResp::fromMessage(response); | 
 |     } | 
 |     case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: { | 
 |       return autograd::CleanupAutogradContextResp::fromMessage(response); | 
 |     } | 
 |     case MessageType::RUN_WITH_PROFILING_RESP: { | 
 |       std::unique_ptr<RpcCommandBase> rpcPtr = | 
 |           autograd::RpcWithProfilingResp::fromMessage(response); | 
 |       RpcCommandBase& rpc = *rpcPtr; | 
 |       auto& rpcWithProfilingResp = | 
 |           static_cast<autograd::RpcWithProfilingResp&>(rpc); | 
 |       // Process remotely profiled events. | 
 |       processRemoteProfiledEvents(rpcWithProfilingResp); | 
 |  | 
 |       wrappedMsgType = rpcWithProfilingResp.wrappedMessageType(); | 
      return std::move(rpcWithProfilingResp).moveWrappedRpc();
 |     } | 
 |     case MessageType::RREF_BACKWARD_RESP: { | 
 |       return autograd::RRefBackwardResp::fromMessage(response); | 
 |     } | 
 |     default: { | 
 |       TORCH_INTERNAL_ASSERT( | 
 |           false, "Response type ", response.type(), " not supported."); | 
 |     } | 
 |   } | 
 | } | 
 |  | 
 | IValue deserializeResptoIValueInternal( | 
 |     RpcCommandBase& rpc, | 
 |     MessageType messageType) { | 
 |   switch (messageType) { | 
 |     case MessageType::SCRIPT_RET: { | 
 |       auto& ret = static_cast<ScriptResp&>(rpc); | 
 |       return ret.value(); | 
 |     } | 
 |     default: { | 
 |       TORCH_INTERNAL_ASSERT( | 
 |           false, | 
 |           "Response type ", | 
 |           messageType, | 
 |           " is not supported to be deserialized to IValue."); | 
 |     } | 
 |   } | 
 | } | 
 |  | 
 | IValue deserializeRespToIValue(const Message& message) { | 
 |   MessageType msgType = message.type(); | 
 |   auto response = deserializeResponse(message, msgType); | 
 |   return deserializeResptoIValueInternal(*response, msgType); | 
 | } | 
 |  | 
 | namespace { | 
 |  | 
 | // Helper for wireDeserialize() below. | 
 | // | 
 | // The format we use below looks like: | 
 | //    section_name_1 size_1\n | 
 | //    section_name_2 size_2\n | 
 | //    .. | 
 | //    \n | 
 | //    [sections in order] | 
 | // | 
 | // Sections themselves include: | 
 | //    - "payload" - the payload bits | 
 | //    - "meta"    - metadata for the unpickler | 
 | //    - "0" ...   - tensor sections for the unpickler | 
 | // | 
 | // Note that per the header comments, the format is subject to change, | 
// and is best used for RPCs, rather than for persistent disk storage.
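//
// For example, an 8-byte payload plus one pickled tensor might serialize
// as (all sizes here are illustrative):
//    payload 8\n
//    meta 52\n
//    0 16\n
//    \n
//    [8 payload bytes][52 meta bytes][16 tensor bytes]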
 | std::unordered_map<std::string, std::pair<const char*, size_t>> | 
 | parseWireSections(const void* data, size_t data_size) { | 
 |   const char* ptr = static_cast<const char*>(data); | 
 |   const char* endp = ptr + data_size; | 
 |  | 
 |   std::vector<std::pair<std::string, size_t>> headerEnts; | 
 |   bool ok = false; | 
 |   while (ptr != endp) { | 
 |     if (*ptr == '\n') { | 
 |       ok = true; // The only "correct" exit point. | 
 |       ++ptr; | 
 |       break; | 
 |     } | 
 |     // Parse name | 
 |     const char* namePtr = ptr; | 
 |     while (ptr != endp && *ptr != ' ') { | 
 |       ptr++; | 
 |     } | 
 |     if (ptr == endp) { | 
 |       break; | 
 |     } | 
 |     std::string name(namePtr, ptr - namePtr); | 
 |     if (++ptr == endp) { | 
 |       break; // past the ' ' | 
 |     } | 
 |     // Parse size | 
 |     const char* sizePtr = ptr; | 
 |     while (ptr != endp && *ptr != '\n') { | 
 |       ptr++; | 
 |     } | 
 |     if (ptr == endp) { | 
 |       break; | 
 |     } | 
 |     size_t sz = std::stoll(std::string(sizePtr, ptr - sizePtr)); | 
 |     headerEnts.emplace_back(name, sz); | 
 |     ++ptr; // past the '\n' | 
 |   } | 
  if (!ok) {
    TORCH_CHECK(false, "failed to parse wire section headers");
  }
 |  | 
 |   std::unordered_map<std::string, std::pair<const char*, size_t>> out; | 
 |   for (const auto& headerEnt : headerEnts) { | 
 |     out[headerEnt.first] = {ptr, headerEnt.second}; | 
 |     ptr += headerEnt.second; | 
 |   } | 
  if (ptr != endp) {
    TORCH_CHECK(false, "wire section sizes do not match buffer size");
  }
 |   return out; | 
 | } | 
 |  | 
 | static const char* kMeta = "meta"; | 
 | static const char* kPayload = "payload"; | 
} // namespace
 |  | 
 | c10::List<at::Tensor> cloneSparseTensors( | 
 |     const std::vector<at::Tensor>& tensors) { | 
 |   // Sanity-check: If the majority of bits don't need to go over the wire, | 
 |   // force a clone(). Some Tensors are effectively small views, only using | 
 |   // ~1% of the underlying Storage. | 
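  // Example with illustrative numbers: a 100-element float row sliced out
  // of a 100x100 matrix uses 400 of the 40000 storage bytes, so
  // storageSize (40000) >= kMinRecopyBytes (8192) and >= 2 * usefulSize;
  // such a tensor is cloned rather than shipping its whole storage.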
 |   auto worthRecopying = [](const at::Tensor& t) -> bool { | 
 |     if (!t.has_storage()) { | 
 |       return false; // avoid throwing below. | 
 |     } | 
 |     auto storageSize = t.storage().nbytes(); | 
 |     auto usefulSize = t.element_size() * t.numel(); | 
 |     constexpr size_t kMinMultiple = 2; | 
 |     constexpr size_t kMinRecopyBytes = 8 * 1024; | 
 |     return storageSize >= kMinRecopyBytes && | 
 |         storageSize >= usefulSize * kMinMultiple; | 
 |   }; | 
 |   c10::List<at::Tensor> pTensors; | 
 |   pTensors.reserve(tensors.size()); | 
 |   for (const auto& t : tensors) { | 
 |     pTensors.push_back(worthRecopying(t) ? t.clone() : t); | 
 |   } | 
 |   return pTensors; | 
 | } | 
 |  | 
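// Serializes `payload` and `tensors` into the section format documented
// above parseWireSections(): the raw payload bytes go in the "payload"
// section, the pickled tensor-list metadata in "meta", and each tensor's
// raw bytes in a numbered section ("0", "1", ...).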
 | std::string wireSerialize( | 
 |     const std::vector<char>& payload, | 
 |     const std::vector<at::Tensor>& tensors) { | 
 |   for (const auto& tensor : tensors) { | 
 |     TORCH_CHECK( | 
 |         tensor.device().is_cpu(), | 
 |         "ProcessGroup RPC backend only supports", | 
 |         " CPU tensors, please move your tensors to CPU before sending ", | 
 |         "them over RPC. Found tensor on device: ", | 
 |         tensor.device()); | 
 |   } | 
 |  | 
 |   struct Ent { | 
 |     std::string name; | 
 |     const char* data; | 
 |     size_t size; | 
 |   }; | 
 |   std::vector<Ent> entries; | 
 |   std::string metaEntry; | 
 |   std::vector<at::Tensor> tensorData; | 
 |  | 
 |   if (!payload.empty()) { | 
 |     entries.push_back({kPayload, payload.data(), payload.size()}); | 
 |   } | 
 |  | 
 |   if (!tensors.empty()) { | 
 |     torch::jit::Pickler pickler([&](const void* buf, size_t sz) -> size_t { | 
 |       metaEntry.append(static_cast<const char*>(buf), sz); | 
 |       return sz; | 
 |     }); | 
 |     pickler.protocol(); | 
 |     pickler.pushIValue(cloneSparseTensors(tensors)); | 
 |     pickler.stop(); | 
 |     tensorData = pickler.tensorData(); | 
 |     entries.push_back({kMeta, metaEntry.data(), metaEntry.size()}); | 
 |     for (const auto i : c10::irange(tensorData.size())) { | 
      // Construct a WritableTensorData for each tensor in the pickler's
      // tensorData. Since tensorData stays in scope for the rest of this
      // function and getWriteableTensorData only records the tensors, the
      // data() pointers remain valid for CPU tensors. Note that RPC serde
      // doesn't support CUDA tensors yet; if it ever does, care is needed:
      // getWriteableTensorData copies a CUDA tensor to CPU, and that
      // copy's data() would dangle once the loop iteration ends.
 |       auto writeableTensorData = jit::getWriteableTensorData(tensorData[i]); | 
 |       entries.push_back( | 
 |           {std::to_string(i), | 
 |            writeableTensorData.data(), | 
 |            writeableTensorData.sizeInBytes()}); | 
 |     } | 
 |   } | 
 |  | 
 |   std::string header; | 
 |   size_t tot = 0; | 
 |   for (const auto& e : entries) { | 
 |     tot += e.size; | 
 |     header.append(e.name) | 
 |         .append(" ") | 
 |         .append(std::to_string(e.size)) | 
 |         .append("\n"); | 
 |   } | 
 |   header.push_back('\n'); | 
 |  | 
 |   std::string out; | 
 |   out.reserve(header.size() + tot); | 
 |   out.append(header); | 
 |   for (const auto& e : entries) { | 
 |     out.append(e.data, e.size); | 
 |   } | 
 |   return out; | 
 | } | 
 |  | 
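// Inverse of wireSerialize(): splits the buffer into sections, copies the
// payload bytes out, and unpickles the tensor list, materializing each
// tensor section into a freshly allocated CPU buffer.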
 | std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize( | 
 |     const void* data, | 
 |     size_t data_size) { | 
 |   auto sections = parseWireSections(data, data_size); | 
 |  | 
 |   std::vector<char> payload; | 
 |   auto payloadIt = sections.find(kPayload); | 
 |   if (payloadIt != sections.end() && payloadIt->second.second != 0) { | 
 |     payload.assign( | 
 |         payloadIt->second.first, | 
 |         payloadIt->second.first + payloadIt->second.second); | 
 |   } | 
 |  | 
 |   std::vector<at::Tensor> tensors; | 
 |   auto metaIt = sections.find(kMeta); | 
 |   if (metaIt != sections.end()) { | 
 |     const auto& metaData = metaIt->second; | 
 |     size_t metaDataPos = 0; | 
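    // Reader for the Unpickler: copies up to n bytes of the pickled
    // metadata into buf, advancing metaDataPos, and returns the number of
    // bytes actually copied (0 signals end of data).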
 |     auto metaDataReadFunc = [&](char* buf, size_t n) -> size_t { | 
 |       if (metaDataPos >= metaData.second || n == 0) { | 
 |         return 0; | 
 |       } | 
 |       size_t toCopy = std::min(metaDataPos + n, metaData.second) - metaDataPos; | 
 |       memcpy(buf, metaData.first + metaDataPos, toCopy); | 
 |       metaDataPos += toCopy; | 
 |       return toCopy; | 
 |     }; | 
 |     auto sectionReadFunc = [&](const std::string& ename) -> at::DataPtr { | 
 |       auto it = sections.find(ename); | 
 |       if (it == sections.end()) { | 
 |         TORCH_CHECK(false, "Couldn't find entity " + ename); | 
 |       } | 
 |       const auto& idat = it->second; | 
 |       auto dptr = at::getCPUAllocator()->allocate(idat.second); | 
 |       if (idat.second != 0) { | 
 |         memcpy(dptr.get(), idat.first, idat.second); | 
 |       } | 
 |       return dptr; | 
 |     }; | 
 |  | 
    // No need to pass a typeResolver here, as the unpickler only processes
    // strings and tensors.
 |     torch::jit::Unpickler unpickler( | 
 |         metaDataReadFunc, nullptr, nullptr, sectionReadFunc, {}); | 
 |     auto ival = unpickler.parse_ivalue(); | 
 |     for (auto&& t : ival.toTensorList()) { | 
 |       tensors.emplace_back(std::move(t)); | 
 |     } | 
 |   } | 
 |   return {std::move(payload), std::move(tensors)}; | 
 | } | 
 |  | 
 | void writeWrappedPayload( | 
 |     std::vector<char>& originalPayload, | 
 |     std::vector<char>& additionalPayload) { | 
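  // Resulting layout: [original payload][additional payload]
  // [8-byte big-endian length of the additional payload].
  // readWrappedPayload() below parses this back apart.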
 |   originalPayload.insert( | 
 |       originalPayload.end(), | 
 |       additionalPayload.begin(), | 
 |       additionalPayload.end()); | 
 |  | 
 |   // Add size of the additional payload | 
 |   int64_t indexToWrite = originalPayload.size(); | 
 |   originalPayload.resize(originalPayload.size() + sizeof(int64_t)); | 
 |   const int64_t additionalPayloadSize = additionalPayload.size(); | 
 |   torch::utils::THP_encodeInt64Buffer( | 
 |       reinterpret_cast<uint8_t*>(originalPayload.data()) + indexToWrite, | 
 |       &additionalPayloadSize, | 
 |       torch::utils::THPByteOrder::THP_BIG_ENDIAN, | 
 |       1); | 
 | } | 
 |  | 
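// Inverse of writeWrappedPayload(): strips the trailing 8-byte length and
// the wrapped payload bytes off `payload`, then unpickles the wrapped
// bytes and returns the resulting tuple's elements.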
 | std::vector<at::IValue> readWrappedPayload( | 
 |     std::vector<char>& payload, | 
 |     const rpc::Message& message) { | 
  // Read the additional payload and remove it from the payload.
 |   // NOLINTNEXTLINE(cppcoreguidelines-init-variables) | 
 |   int64_t additionalPayloadSize; | 
 |   TORCH_INTERNAL_ASSERT(payload.size() >= sizeof(int64_t)); | 
 |   size_t indexToRead = payload.size() - sizeof(int64_t); | 
 |   torch::utils::THP_decodeInt64Buffer( | 
 |       &additionalPayloadSize, | 
 |       reinterpret_cast<uint8_t*>(payload.data()) + indexToRead, | 
 |       torch::utils::THPByteOrder::THP_BIG_ENDIAN, | 
 |       1); | 
 |   payload.resize(indexToRead); | 
 |  | 
 |   TORCH_INTERNAL_ASSERT( | 
 |       additionalPayloadSize > 0 && | 
 |           static_cast<int64_t>(payload.size()) > additionalPayloadSize, | 
 |       "Wrong payload sizes: payload.size() is ", | 
 |       payload.size(), | 
 |       " but additional payload size is ", | 
 |       additionalPayloadSize); | 
 |   auto wrappedPayloadBegin = | 
 |       static_cast<const char*>(message.payload().data()) + payload.size() - | 
 |       additionalPayloadSize; | 
 |   std::vector<torch::Tensor> tensorTable; | 
 |   IValue tuple = jit::unpickle( | 
 |       wrappedPayloadBegin, | 
 |       additionalPayloadSize, | 
 |       *rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(), | 
 |       tensorTable); | 
 |   std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec(); | 
 |   payload.resize(payload.size() - additionalPayloadSize); | 
 |   return tupleElements; | 
 | } | 
 |  | 
 | void populateRemoteProfiledEvents( | 
 |     std::vector<LegacyEvent>& profiledEvents, | 
 |     const ProfilerConfig& profilingConfig, | 
 |     const std::vector<std::vector<LegacyEvent>>& eventLists) { | 
 |   // Gather all events into a vector | 
  for (const auto& l : eventLists) {
    for (const auto& e : l) {
 |       profiledEvents.push_back(e); | 
 |     } | 
 |   } | 
 |   // find __start_profile event | 
 |   bool cudaProfilingEnabled = profilingConfig.state == ProfilerState::CUDA; | 
 |   const LegacyEvent* profilerStart = nullptr; | 
 |  | 
 |   for (auto& e : profiledEvents) { | 
 |     if (std::string(e.name()) == "__start_profile") { | 
 |       profilerStart = &e; | 
 |       break; | 
 |     } | 
 |   } | 
 |   // We should always find __start_profile. | 
 |   TORCH_CHECK( | 
 |       profilerStart != nullptr, "Expected to find __start_profile event."); | 
 |  | 
 |   if (cudaProfilingEnabled) { | 
 |     // Deserialized events don't have the corresponding CUDA events, making it | 
    // impossible to use cudaEventElapsedTime on the receiving end. To avoid this,
 |     // find all push/pop pairs of CUDA events and set the corresponding CUDA | 
 |     // time to zero for the push event and to the elapsed time for the pop | 
 |     // event, to be used later for the elapsed CUDA time computation. | 
 |     std::unordered_map<at::RecordFunctionHandle, const LegacyEvent*> | 
 |         startEvents; | 
 |     for (auto& e : profiledEvents) { | 
 |       if (e.hasCuda()) { | 
 |         if (e.kind() == EventKind::PushRange) { | 
 |           startEvents[e.handle()] = &e; | 
 |         } | 
 |       } | 
 |     } | 
 |     for (auto& e : profiledEvents) { | 
 |       if (e.hasCuda()) { | 
 |         if (e.kind() == EventKind::PopRange) { | 
 |           auto it = startEvents.find(e.handle()); | 
 |           if (it != startEvents.end()) { | 
 |             e.setCudaUs(it->second->cudaElapsedUs(e)); | 
 |           } else { | 
 |             TORCH_WARN("Found a pop event without a corresponding push event"); | 
 |             e.setCudaUs(0); | 
 |           } | 
 |         } else { | 
 |           e.setCudaUs(0); | 
 |         } | 
 |       } | 
 |     } | 
 |   } | 
 | } | 
 |  | 
 | } // namespace torch::distributed::rpc |