#pragma once

#include <ATen/core/ivalue.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>

namespace torch {
namespace distributed {
namespace rpc {

// Given an RPC message received as a request over the wire, deserialize it
// into the appropriate 'RpcCommandBase' type.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
    const Message& request);
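// Example (illustrative sketch; 'request' is a hypothetical Message received
// over the wire, not part of this header):
//   std::unique_ptr<RpcCommandBase> cmd = deserializeRequest(request);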

// Given an RPC message received as a response over the wire, deserialize it
// into the appropriate 'RpcCommandBase' type. If the response is of type
// FORWARD_AUTOGRAD_RESP, unwrap it, attach recvBackward() functions to the
// received tensors, and set wrappedMsgType to the type of the wrapped message.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType);
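// Example (illustrative sketch; 'response' is a hypothetical Message):
//   MessageType wrappedMsgType = response.type();
//   std::unique_ptr<RpcCommandBase> rpc =
//       deserializeResponse(response, wrappedMsgType);
//   // For a FORWARD_AUTOGRAD_RESP response, wrappedMsgType now holds the
//   // wrapped message type.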

// Given an RPC message received as a response over the wire, deserialize it
// into a valid IValue if the message is for a script RPC result; otherwise
// deserialize it into a dummy None IValue that will never be used. During
// this deserialization we also attach recv RPC backward functions if needed.
IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType);

TORCH_API IValue deserializeRespToIValue(const Message& message);
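// Example (illustrative sketch; 'response' is a hypothetical script RPC
// response Message):
//   IValue result = deserializeRespToIValue(response);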

// Note: format is subject to change and intended for RPCs.
// For saving persistently to disk, use torch::save().
TORCH_API std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors);

TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size);
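// Example round trip (illustrative sketch; 'payload' and 'tensors' are
// hypothetical locals, not part of this header):
//   std::string wire = wireSerialize(payload, tensors);
//   auto unpacked = wireDeserialize(wire.data(), wire.size());
//   // unpacked.first holds the payload bytes, unpacked.second the tensors.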

// Some Tensors are effectively views of larger Tensors, where only a small
// subset of the Storage data is referenced. This is normally desirable and
// avoids copies when the tensor is kept locally, but if we naively push the
// whole Storage over the wire, we end up with excess network traffic. This
// function clones such tensors when doing so would save at least half the
// data and the savings exceed a minimum threshold.
TORCH_API c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors);
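// Example (illustrative sketch; the sizes are made up for demonstration):
//   at::Tensor big = torch::rand({1000, 1000});
//   std::vector<at::Tensor> views = {big.slice(/*dim=*/0, /*start=*/0, /*end=*/1)};
//   // The slice references the full 1000x1000 Storage; cloneSparseTensors
//   // may clone it so that only the referenced row goes over the wire.
//   c10::List<at::Tensor> toSend = cloneSparseTensors(views);
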
} // namespace rpc
} // namespace distributed
} // namespace torch