#pragma once

#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>

namespace torch {
namespace distributed {
namespace rpc {
// Given an RPC message received as a request over the wire, deserialize it into
// the appropriate 'RpcCommandBase' type.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
const Message& request);
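
// Usage sketch (illustrative only; `msg` and the server-side handler context
// are hypothetical): a request handler would typically recover the typed
// command and then dispatch on the message type.
//
//   std::unique_ptr<RpcCommandBase> cmd = deserializeRequest(msg);
//   // Dispatch on msg.type() to downcast to the concrete command type.
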
// Given an RPC message received as a response over the wire, deserialize it
// into the appropriate 'RpcCommandBase' type.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
const Message& response);
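
// Usage sketch (illustrative only; `reply` is a hypothetical Message received
// for a previously issued RPC):
//
//   std::unique_ptr<RpcCommandBase> resp = deserializeResponse(reply);
//   // The concrete command type depends on the original request.
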
// Note: format is subject to change and intended for RPCs.
// For saving persistently to disk, use torch::save().
TORCH_API std::string wireSerialize(
const std::vector<char>& payload,
const std::vector<at::Tensor>& tensors);
TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
const void* data,
size_t data_size);
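
// Round-trip sketch (illustrative only; values are arbitrary): payload bytes
// and tensors serialized with wireSerialize should come back unchanged from
// wireDeserialize.
//
//   std::vector<char> payload = {'r', 'p', 'c'};
//   std::vector<at::Tensor> tensors = {at::ones({2, 3})};
//   std::string wire = wireSerialize(payload, tensors);
//   auto restored = wireDeserialize(wire.data(), wire.size());
//   // restored.first holds the payload bytes, restored.second the tensors.
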
// Some Tensors are effectively views of larger Tensors, where only a small
// subset of the Storage data is referenced. This normally is good and avoids
// copies when kept locally, but if we naively push the whole Storage over the
// wire, we'll end up with excess network traffic. This function clones tensors
// if doing so would save at least half the data and the savings exceed a
// minimum threshold.
TORCH_API c10::List<at::Tensor> cloneSparseTensors(
const std::vector<at::Tensor>& tensors);
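
// Sketch of the intent (illustrative only; sizes are arbitrary): a narrow
// view of a large tensor would otherwise drag its whole Storage over the
// wire, so it is cloned down to just the viewed elements.
//
//   at::Tensor big = at::rand({1000000});
//   at::Tensor view = big.slice(/*dim=*/0, /*start=*/0, /*end=*/10);
//   c10::List<at::Tensor> toSend = cloneSparseTensors({view});
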
} // namespace rpc
} // namespace distributed
} // namespace torch